Coverage for mcpgateway / services / tool_service.py: 98%
2225 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/tool_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Tool Service Implementation.
8This module implements tool management and invocation according to the MCP specification.
9It handles:
10- Tool registration and validation
11- Tool invocation with schema validation
12- Tool federation across gateways
13- Event notifications for tool changes
14- Active/inactive tool management
15"""
17# Standard
18import asyncio
19import base64
20import binascii
21from datetime import datetime, timezone
22from functools import lru_cache
23import json # NOTE: httpx uses stdlib json, not orjson, so response.json() raises json.JSONDecodeError
24import os
25import re
26import ssl
27import time
28from types import SimpleNamespace
29from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
30from urllib.parse import parse_qs, urlparse
31import uuid
33# Third-Party
34import anyio
35import httpx
36import jq
37import jsonschema
38from jsonschema import Draft4Validator, Draft6Validator, Draft7Validator, validators
39from mcp import ClientSession, types
40from mcp.client.sse import sse_client
41from mcp.client.streamable_http import streamablehttp_client
42import orjson
43from pydantic import ValidationError
44from sqlalchemy import and_, delete, desc, or_, select
45from sqlalchemy.exc import IntegrityError, OperationalError
46from sqlalchemy.orm import joinedload, selectinload, Session
48# First-Party
49from mcpgateway.cache.global_config_cache import global_config_cache
50from mcpgateway.common.models import Gateway as PydanticGateway
51from mcpgateway.common.models import TextContent
52from mcpgateway.common.models import Tool as PydanticTool
53from mcpgateway.common.models import ToolResult
54from mcpgateway.common.validators import SecurityValidator
55from mcpgateway.config import settings
56from mcpgateway.db import A2AAgent as DbA2AAgent
57from mcpgateway.db import fresh_db_session
58from mcpgateway.db import Gateway as DbGateway
59from mcpgateway.db import get_for_update, server_tool_association
60from mcpgateway.db import Tool as DbTool
61from mcpgateway.db import ToolMetric, ToolMetricsHourly
62from mcpgateway.observability import create_child_span, create_span, inject_trace_context_headers, set_span_attribute, set_span_error
63from mcpgateway.plugins.framework import (
64 get_plugin_manager,
65 GlobalContext,
66 HttpHeaderPayload,
67 PluginContextTable,
68 PluginError,
69 PluginManager,
70 PluginViolationError,
71 ToolHookType,
72 ToolPostInvokePayload,
73 ToolPreInvokePayload,
74)
75from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA
76from mcpgateway.schemas import AuthenticationValues, ToolCreate, ToolMetrics, ToolRead, ToolUpdate, TopPerformer
77from mcpgateway.services.audit_trail_service import get_audit_trail_service
78from mcpgateway.services.base_service import BaseService
79from mcpgateway.services.event_service import EventService
80from mcpgateway.services.logging_service import LoggingService
81from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, TransportType
82from mcpgateway.services.metrics_buffer_service import get_metrics_buffer_service
83from mcpgateway.services.metrics_cleanup_service import delete_metrics_in_batches, pause_rollup_during_purge
84from mcpgateway.services.metrics_query_service import get_top_performers_combined
85from mcpgateway.services.oauth_manager import OAuthManager
86from mcpgateway.services.observability_service import current_trace_id, ObservabilityService
87from mcpgateway.services.performance_tracker import get_performance_tracker
88from mcpgateway.services.structured_logger import get_structured_logger
89from mcpgateway.services.team_management_service import TeamManagementService
90from mcpgateway.utils.correlation_id import get_correlation_id
91from mcpgateway.utils.create_slug import slugify
92from mcpgateway.utils.display_name import generate_display_name
93from mcpgateway.utils.gateway_access import build_gateway_auth_headers, check_gateway_access, extract_gateway_id_from_headers
94from mcpgateway.utils.metrics_common import build_top_performers
95from mcpgateway.utils.pagination import decode_cursor, encode_cursor, unified_paginate
96from mcpgateway.utils.passthrough_headers import compute_passthrough_headers_cached
97from mcpgateway.utils.retry_manager import ResilientHttpClient
98from mcpgateway.utils.services_auth import decode_auth, encode_auth
99from mcpgateway.utils.sqlalchemy_modifier import json_contains_tag_expr
100from mcpgateway.utils.ssl_context_cache import get_cached_ssl_context
101from mcpgateway.utils.trace_context import format_trace_team_scope
102from mcpgateway.utils.trace_redaction import is_input_capture_enabled, is_output_capture_enabled, serialize_trace_payload
103from mcpgateway.utils.url_auth import apply_query_param_auth, sanitize_exception_message, sanitize_url_for_logging
104from mcpgateway.utils.validate_signature import validate_signature
# Cache import (lazy to avoid circular dependencies)
# Both slots start as None and are populated on first use by
# _get_registry_cache() and _get_tool_lookup_cache() respectively.
_REGISTRY_CACHE = None
_TOOL_LOOKUP_CACHE = None
def _get_registry_cache():
    """Return the shared registry cache, importing it on first use.

    The import happens at call time rather than at module import time, and
    the imported object is remembered in the module-level ``_REGISTRY_CACHE``
    slot so later calls skip the import.

    Returns:
        RegistryCache instance.
    """
    global _REGISTRY_CACHE  # pylint: disable=global-statement
    if _REGISTRY_CACHE is not None:
        return _REGISTRY_CACHE

    # First-Party
    from mcpgateway.cache.registry_cache import registry_cache as loaded_cache  # pylint: disable=import-outside-toplevel

    _REGISTRY_CACHE = loaded_cache
    return _REGISTRY_CACHE
def _get_tool_lookup_cache():
    """Return the shared tool lookup cache, importing it on first use.

    The import happens at call time rather than at module import time, and
    the imported object is remembered in the module-level
    ``_TOOL_LOOKUP_CACHE`` slot so later calls skip the import.

    Returns:
        ToolLookupCache instance.
    """
    global _TOOL_LOOKUP_CACHE  # pylint: disable=global-statement
    if _TOOL_LOOKUP_CACHE is not None:
        return _TOOL_LOOKUP_CACHE

    # First-Party
    from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache as loaded_cache  # pylint: disable=import-outside-toplevel

    _TOOL_LOOKUP_CACHE = loaded_cache
    return _TOOL_LOOKUP_CACHE
# Initialize logging service first
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

# Initialize performance tracker, structured logger, audit trail, and metrics buffer for tool operations
perf_tracker = get_performance_tracker()
structured_logger = get_structured_logger("tool_service")
audit_trail = get_audit_trail_service()
metrics_buffer = get_metrics_buffer_service()

# Dict key that marks a stored header value as an encrypted envelope
# (checked by _is_encrypted_tool_header_value).
_ENCRYPTED_TOOL_HEADER_VALUE_KEY = "_mcpgateway_encrypted_header_value_v1"
# Key under which the plain header value is wrapped before encode_auth()
# (written by _encrypt_tool_header_value).
_TOOL_HEADER_DATA_KEY = "data"
# Alternate payload key that _decrypt_tool_header_value also accepts.
_TOOL_HEADER_LEGACY_VALUE_KEY = "value"
154_SENSITIVE_TOOL_HEADER_PATTERNS = (
155 re.compile(r"^authorization$", re.IGNORECASE),
156 re.compile(r"^proxy-authorization$", re.IGNORECASE),
157 re.compile(r"^x-api-key$", re.IGNORECASE),
158 re.compile(r"^api-key$", re.IGNORECASE),
159 re.compile(r"^apikey$", re.IGNORECASE),
160 # Keep broad-enough auth matching while avoiding operational noise from
161 # non-secret tracing/idempotency headers (e.g. X-Correlation-Token).
162 re.compile(r"^x-(?:auth|api|access|refresh|client|bearer|session|security)[-_]?(?:token|secret|key)$", re.IGNORECASE),
163 re.compile(r"^(?:auth|api|access|refresh|client|bearer|session|security)[-_]?(?:token|secret|key)$", re.IGNORECASE),
164)
167def _is_sensitive_tool_header_name(name: str) -> bool:
168 """Return whether a tool header name should be treated as sensitive.
170 Args:
171 name: Header name to evaluate.
173 Returns:
174 ``True`` when header value should be protected.
175 """
176 normalized_name = str(name).strip().lower()
177 return any(pattern.match(normalized_name) for pattern in _SENSITIVE_TOOL_HEADER_PATTERNS)
180def _is_encrypted_tool_header_value(value: Any) -> bool:
181 """Return whether a header value uses encrypted envelope format.
183 Args:
184 value: Header value candidate.
186 Returns:
187 ``True`` when value is an encrypted envelope mapping.
188 """
189 return isinstance(value, dict) and isinstance(value.get(_ENCRYPTED_TOOL_HEADER_VALUE_KEY), str)
192def _encrypt_tool_header_value(value: Any, existing_value: Any = None) -> Any:
193 """Encrypt a single sensitive tool header value.
195 Args:
196 value: Incoming header value from payload.
197 existing_value: Existing stored value used for masked-value merges.
199 Returns:
200 Encrypted envelope, preserved existing value, or ``None`` when cleared.
201 """
202 if value is None or value == "":
203 return value
205 if value == settings.masked_auth_value:
206 if _is_encrypted_tool_header_value(existing_value):
207 return existing_value
208 if existing_value in (None, ""):
209 return None
210 return _encrypt_tool_header_value(existing_value, None)
212 if _is_encrypted_tool_header_value(value):
213 return value
215 encrypted = encode_auth({_TOOL_HEADER_DATA_KEY: str(value)})
216 return {_ENCRYPTED_TOOL_HEADER_VALUE_KEY: encrypted}
219def _protect_tool_headers_for_storage(headers: Optional[Dict[str, Any]], existing_headers: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
220 """Encrypt sensitive tool header values before persistence.
222 Args:
223 headers: Incoming tool headers payload.
224 existing_headers: Existing stored headers used for masked-value merges.
226 Returns:
227 Header mapping with sensitive values protected for storage, or ``None``.
228 """
229 if headers is None:
230 return None
231 if not isinstance(headers, dict):
232 return None
234 existing_by_lower: Dict[str, Any] = {}
235 if isinstance(existing_headers, dict):
236 for key, existing_value in existing_headers.items():
237 existing_by_lower[str(key).strip().lower()] = existing_value
239 protected: Dict[str, Any] = {}
240 for key, value in headers.items():
241 if _is_sensitive_tool_header_name(key):
242 existing_value = existing_by_lower.get(str(key).strip().lower())
243 protected[key] = _encrypt_tool_header_value(value, existing_value)
244 else:
245 protected[key] = value
246 return protected
def _decrypt_tool_header_value(value: Any) -> Any:
    """Recover the plain header value from an encrypted envelope.

    Values that are not envelopes, or whose envelope payload is empty, are
    returned unchanged. If decoding fails, a warning is logged and the
    original value is returned; the same happens when the decoded payload
    is not a dict or contains neither expected key.

    Args:
        value: Stored header value, possibly encrypted.

    Returns:
        Decrypted plain value when envelope is valid, else original value.
    """
    if not _is_encrypted_tool_header_value(value) or not value.get(_ENCRYPTED_TOOL_HEADER_VALUE_KEY):
        return value

    envelope_payload = value.get(_ENCRYPTED_TOOL_HEADER_VALUE_KEY)
    try:
        decoded = decode_auth(envelope_payload)
        if isinstance(decoded, dict):
            # Current key first, then the legacy key.
            for payload_key in (_TOOL_HEADER_DATA_KEY, _TOOL_HEADER_LEGACY_VALUE_KEY):
                if payload_key in decoded:
                    return decoded[payload_key]
    except Exception as exc:
        logger.warning("Failed to decrypt tool header value: %s", exc)
    return value
277def _decrypt_tool_headers_for_runtime(headers: Optional[Dict[str, Any]]) -> Dict[str, Any]:
278 """Decrypt tool header map for runtime outbound requests.
280 Args:
281 headers: Stored header mapping.
283 Returns:
284 Header mapping with encrypted values decrypted where possible.
285 """
286 if not isinstance(headers, dict):
287 return {}
288 return {key: _decrypt_tool_header_value(value) for key, value in headers.items()}
def _handle_json_parse_error(response, error, is_error_response: bool = False) -> dict:
    """Turn a JSON parsing failure into a fallback payload built from raw text.

    Every path logs a warning describing the failure.

    Args:
        response: The HTTP response object with .text attribute
        error: The exception that was raised during JSON parsing
        is_error_response: If True, logs as "error response", else "response"

    Returns:
        ``{"error": "Empty response body"}`` when the body is empty or None;
        otherwise ``{"response_text": ...}`` holding the raw text, cut to
        ``settings.rest_response_text_max_length`` characters when longer
        (to avoid exposing sensitive data).
    """
    if is_error_response:
        msg = "error response"
    else:
        msg = "response"

    if not response.text:
        logger.warning(f"Failed to parse JSON {msg}: {error}. Response body was empty.")
        return {"error": "Empty response body"}

    max_length = settings.rest_response_text_max_length
    if len(response.text) <= max_length:
        # Short enough: hand the body back untouched.
        logger.warning(f"Failed to parse JSON {msg}: {error}")
        return {"response_text": response.text}

    logger.warning(f"Failed to parse JSON {msg}: {error}. Response truncated from {len(response.text)} to {max_length} characters.")
    return {"response_text": response.text[:max_length]}
@lru_cache(maxsize=256)
def _compile_jq_filter(jq_filter: str):
    """Compile a jq filter, memoising the compiled program per filter string.

    Args:
        jq_filter: The jq filter string to compile.

    Returns:
        Compiled jq program object.

    Raises:
        ValueError: If the jq filter is invalid.
    """
    compiled_program = jq.compile(jq_filter)  # pylint: disable=c-extension-no-member
    return compiled_program
@lru_cache(maxsize=128)
def _get_validator_class_and_check(schema_json: str) -> Tuple[type, dict]:
    """Pick and verify a validator class for a schema, memoised per schema string.

    The cached work is:
    1. Deserializing the schema
    2. Selecting the appropriate validator class based on $schema
    3. Checking the schema is valid

    The auto-detected validator is tried first. If it rejects the schema, the
    Draft 7, Draft 6 and Draft 4 validators are tried in that order, which
    lets schemas using older draft features (e.g., Draft 4 style
    exclusiveMinimum: true) validate even though newer drafts reject them.

    Args:
        schema_json: Canonical JSON string of the schema (used as cache key).

    Returns:
        Tuple of (validator_class, schema_dict) ready for instantiation.

    Raises:
        jsonschema.exceptions.SchemaError: If no candidate validator accepts
            the schema (raised by the auto-detected validator's check).
    """
    schema = orjson.loads(schema_json)
    detected_cls = validators.validator_for(schema)

    # Auto-detected class first, then progressively older drafts.
    for candidate_cls in (detected_cls, Draft7Validator, Draft6Validator, Draft4Validator):
        try:
            candidate_cls.check_schema(schema)
        except jsonschema.exceptions.SchemaError:
            continue
        return candidate_cls, schema

    # Nothing accepted the schema: re-run the auto-detected check so the
    # caller gets its error.
    detected_cls.check_schema(schema)
    return detected_cls, schema
def _canonicalize_schema(schema: dict) -> str:
    """Serialize a schema to a key-sorted JSON string usable as a cache key.

    Args:
        schema: The JSON Schema dictionary.

    Returns:
        Canonical JSON string with sorted keys.
    """
    encoded = orjson.dumps(schema, option=orjson.OPT_SORT_KEYS)
    return encoded.decode()
def _validate_with_cached_schema(instance: Any, schema: dict) -> None:
    """Validate an instance against a schema, reusing the cached validator class.

    The validator class and schema check come from
    ``_get_validator_class_and_check``; a new validator instance is created
    on every call for thread safety. The error raised is the one chosen by
    ``best_match``, which preserves jsonschema.validate() error selection
    semantics.

    Args:
        instance: The data to validate.
        schema: The JSON Schema to validate against.

    Raises:
        error: The best matching ValidationError from jsonschema validation.
        jsonschema.exceptions.ValidationError: If validation fails.
        jsonschema.exceptions.SchemaError: If the schema itself is invalid.
    """
    validator_cls, checked_schema = _get_validator_class_and_check(_canonicalize_schema(schema))
    # Fresh validator per call; only the class and checked schema are shared.
    failures = validator_cls(checked_schema).iter_errors(instance)
    error = jsonschema.exceptions.best_match(failures)
    if error is None:
        return
    raise error
def extract_using_jq(data, jq_filter=""):
    """
    Apply a jq filter to JSON data given as a string, dict, or list.

    Compiled jq programs are cached, so repeated filters are cheap. An empty
    or whitespace-only filter, or a filter that looks like an email address,
    returns the input unchanged.

    Args:
        data (str, dict, list): The input JSON data. Can be a string, dict, or list.
        jq_filter (str): The jq filter string to extract the desired data.

    Returns:
        The result of applying the jq filter to the input data. Problems are
        reported in the return value rather than raised: a one-element list
        holding an error string for unparsable or unsupported input, or a
        one-element list holding a ``TextContent`` when the filter fails or
        yields ``[None]``.

    Examples:
        >>> extract_using_jq('{"a": 1, "b": 2}', '.a')
        [1]
        >>> extract_using_jq({'a': 1, 'b': 2}, '.b')
        [2]
        >>> extract_using_jq('[{"a": 1}, {"a": 2}]', '.[].a')
        [1, 2]
        >>> extract_using_jq('not a json', '.a')
        ['Invalid JSON string provided.']
        >>> extract_using_jq({'a': 1}, '')
        {'a': 1}
    """
    if not jq_filter:
        return data

    jq_filter_str = str(jq_filter).strip()
    if not jq_filter_str:
        return data

    # An email address here usually means the jsonpath_filter field holds
    # corrupted data. The regex is intentionally simple so valid jq
    # expressions such as .foo|.bar are never mistaken for one.
    looks_like_email = re.match(r"^[^.\[\]|]+@[^.\[\]|]+\.[^.\[\]|]+$", jq_filter_str)
    if looks_like_email:
        logger.warning(f"Invalid jq filter (email address): {jq_filter_str}. Treating as empty filter.")
        return data

    if isinstance(data, str):
        # String input must itself be JSON.
        try:
            data = orjson.loads(data)
        except orjson.JSONDecodeError:
            return ["Invalid JSON string provided."]
    elif not isinstance(data, (dict, list)):
        # Unsupported input type is reported in the result, not raised.
        return ["Input data must be a JSON string, dictionary, or list."]

    try:
        result = _compile_jq_filter(jq_filter).input(data).all()
        if result == [None]:
            return [TextContent(type="text", text="Error applying jsonpath filter")]
    except Exception as e:
        message = "Error applying jsonpath filter: " + str(e)
        return [TextContent(type="text", text=message)]

    return result
class ToolError(Exception):
    """Base class for tool-related errors.

    The other tool exceptions defined in this module derive from this class,
    so catching ``ToolError`` handles any of them.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolError
        >>> err = ToolError("Something went wrong")
        >>> str(err)
        'Something went wrong'
    """
class ToolNotFoundError(ToolError):
    """Raised when a requested tool is not found.

    Adds no state of its own; the message passed to the constructor is the
    whole payload.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolNotFoundError
        >>> err = ToolNotFoundError("Tool xyz not found")
        >>> str(err)
        'Tool xyz not found'
        >>> isinstance(err, ToolError)
        True
    """
class ToolNameConflictError(ToolError):
    """Raised when a tool name conflicts with existing (active or inactive) tool.

    The conflicting name, the enabled flag and the ID of the existing tool
    are kept as attributes.
    """

    def __init__(self, name: str, enabled: bool = True, tool_id: Optional[int] = None, visibility: str = "public"):
        """Build the conflict message and record the conflicting tool's details.

        Args:
            name: The conflicting tool name.
            enabled: Whether the existing tool is enabled or not.
            tool_id: ID of the existing tool if available.
            visibility: The visibility of the tool. "team" and "private" get
                their own label in the message; any other value is labelled
                "Public".

        Examples:
            >>> from mcpgateway.services.tool_service import ToolNameConflictError
            >>> err = ToolNameConflictError('test_tool', enabled=False, tool_id=123)
            >>> str(err)
            'Public Tool already exists with name: test_tool (currently inactive, ID: 123)'
            >>> err.name
            'test_tool'
            >>> err.enabled
            False
            >>> err.tool_id
            123
        """
        vis_label = "Team-level" if visibility == "team" else "Private" if visibility == "private" else "Public"

        message = f"{vis_label} Tool already exists with name: {name}"
        if not enabled:
            # Point at the disabled tool that owns the name.
            message = message + f" (currently inactive, ID: {tool_id})"

        self.name, self.enabled, self.tool_id = name, enabled, tool_id
        super().__init__(message)
class ToolLockConflictError(ToolError):
    """Raised when a tool row is locked by another transaction.

    Derives from ``ToolError``, so handlers for generic tool errors also
    catch it.
    """
class ToolValidationError(ToolError):
    """Raised when tool validation fails.

    Adds no state of its own; the message passed to the constructor is the
    whole payload.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolValidationError
        >>> err = ToolValidationError("Invalid tool configuration")
        >>> str(err)
        'Invalid tool configuration'
        >>> isinstance(err, ToolError)
        True
    """
class ToolInvocationError(ToolError):
    """Raised when tool invocation fails.

    Also the base class of ``ToolTimeoutError``, so catching this class
    covers timeouts as well.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolInvocationError
        >>> err = ToolInvocationError("Tool execution failed")
        >>> str(err)
        'Tool execution failed'
        >>> isinstance(err, ToolError)
        True
        >>> # Test with detailed error
        >>> detailed_err = ToolInvocationError("Network timeout after 30 seconds")
        >>> "timeout" in str(detailed_err)
        True
        >>> isinstance(err, ToolError)
        True
    """
class ToolTimeoutError(ToolInvocationError):
    """Raised when tool invocation times out.

    This subclass is used to distinguish timeout errors from other invocation errors.
    Timeout handlers call tool_post_invoke before raising this, so the generic exception
    handler should skip calling post_invoke again to avoid double-counting failures.

    Attributes:
        retry_delay_ms: Delay in milliseconds requested by the retry plugin.
            0 (default) means no retry. Set by the timeout handler after
            invoking the post-invoke hook so the outer catch block can honour
            the signal without calling post_invoke a second time.
    """

    def __init__(self, message: str, retry_delay_ms: int = 0) -> None:
        """Initialise with an optional retry delay from the post-invoke hook.

        Args:
            message: Human-readable error description.
            retry_delay_ms: Milliseconds the gateway should wait before retrying.
        """
        super().__init__(message)
        # Stored as given; no validation or conversion is applied here.
        self.retry_delay_ms = retry_delay_ms
607class ToolService(BaseService):
608 """Service for managing and invoking tools.
610 Handles:
611 - Tool registration and deregistration.
612 - Tool invocation and validation.
613 - Tool federation.
614 - Event notifications.
615 - Active/inactive tool management.
616 """
618 _visibility_model_cls = DbTool
620 def __init__(self) -> None:
621 """Initialize the tool service.
623 Examples:
624 >>> from mcpgateway.services.tool_service import ToolService
625 >>> service = ToolService()
626 >>> isinstance(service._event_service, EventService)
627 True
628 >>> hasattr(service, '_http_client')
629 True
630 """
631 self._event_service = EventService(channel_name="mcpgateway:tool_events")
632 self._http_client = ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify})
633 self._plugin_manager: PluginManager | None = get_plugin_manager()
634 self.oauth_manager = OAuthManager(
635 request_timeout=int(settings.oauth_request_timeout if hasattr(settings, "oauth_request_timeout") else 30),
636 max_retries=int(settings.oauth_max_retries if hasattr(settings, "oauth_max_retries") else 3),
637 )
    async def initialize(self) -> None:
        """Initialize the service.

        Logs the start of initialization, then awaits initialization of the
        event service created in ``__init__``.

        Examples:
            >>> from mcpgateway.services.tool_service import ToolService
            >>> service = ToolService()
            >>> import asyncio
            >>> asyncio.run(service.initialize())  # Should log "Initializing tool service"
        """
        logger.info("Initializing tool service")
        await self._event_service.initialize()
651 async def shutdown(self) -> None:
652 """Shutdown the service.
654 Examples:
655 >>> from mcpgateway.services.tool_service import ToolService
656 >>> service = ToolService()
657 >>> import asyncio
658 >>> asyncio.run(service.shutdown()) # Should log "Tool service shutdown complete"
659 """
660 await self._http_client.aclose()
661 await self._event_service.shutdown()
662 logger.info("Tool service shutdown complete")
664 async def get_top_tools(self, db: Session, limit: Optional[int] = 5, include_deleted: bool = False) -> List[TopPerformer]:
665 """Retrieve the top-performing tools based on execution count.
667 Queries the database to get tools with their metrics, ordered by the number of executions
668 in descending order. Returns a list of TopPerformer objects containing tool details and
669 performance metrics. Results are cached for performance.
671 Args:
672 db (Session): Database session for querying tool metrics.
673 limit (Optional[int]): Maximum number of tools to return. Defaults to 5.
674 include_deleted (bool): Whether to include deleted tools from rollups.
676 Returns:
677 List[TopPerformer]: A list of TopPerformer objects, each containing:
678 - id: Tool ID.
679 - name: Tool name.
680 - execution_count: Total number of executions.
681 - avg_response_time: Average response time in seconds, or None if no metrics.
682 - success_rate: Success rate percentage, or None if no metrics.
683 - last_execution: Timestamp of the last execution, or None if no metrics.
684 """
685 # Check cache first (if enabled)
686 # First-Party
687 from mcpgateway.cache.metrics_cache import is_cache_enabled, metrics_cache # pylint: disable=import-outside-toplevel
689 effective_limit = limit or 5
690 cache_key = f"top_tools:{effective_limit}:include_deleted={include_deleted}"
692 if is_cache_enabled():
693 cached = metrics_cache.get(cache_key)
694 if cached is not None:
695 return cached
697 # Use combined query that includes both raw metrics and rollup data
698 results = get_top_performers_combined(
699 db=db,
700 metric_type="tool",
701 entity_model=DbTool,
702 limit=effective_limit,
703 include_deleted=include_deleted,
704 )
705 top_performers = build_top_performers(results)
707 # Cache the result (if enabled)
708 if is_cache_enabled():
709 metrics_cache.set(cache_key, top_performers)
711 return top_performers
713 def _build_tool_cache_payload(self, tool: DbTool, gateway: Optional[DbGateway]) -> Dict[str, Any]:
714 """Build cache payload for tool lookup by name.
716 Args:
717 tool: Tool ORM instance.
718 gateway: Optional gateway ORM instance.
720 Returns:
721 Cache payload dict for tool lookup.
722 """
723 tool_payload = {
724 "id": str(tool.id),
725 "name": tool.name,
726 "original_name": tool.original_name,
727 "url": tool.url,
728 "description": tool.description,
729 "original_description": tool.original_description,
730 "integration_type": tool.integration_type,
731 "request_type": tool.request_type,
732 "headers": tool.headers or {},
733 "input_schema": tool.input_schema or {"type": "object", "properties": {}},
734 "output_schema": tool.output_schema,
735 "annotations": tool.annotations or {},
736 "auth_type": tool.auth_type,
737 "jsonpath_filter": tool.jsonpath_filter,
738 "custom_name": tool.custom_name,
739 "custom_name_slug": tool.custom_name_slug,
740 "display_name": tool.display_name,
741 "gateway_id": str(tool.gateway_id) if tool.gateway_id else None,
742 "enabled": bool(tool.enabled),
743 "reachable": bool(tool.reachable),
744 "tags": tool.tags or [],
745 "team_id": tool.team_id,
746 "owner_email": tool.owner_email,
747 "visibility": tool.visibility,
748 }
750 gateway_payload = None
751 if gateway:
752 gateway_payload = {
753 "id": str(gateway.id),
754 "name": gateway.name,
755 "url": gateway.url,
756 "description": gateway.description,
757 "slug": gateway.slug,
758 "transport": gateway.transport,
759 "capabilities": gateway.capabilities or {},
760 "passthrough_headers": gateway.passthrough_headers or [],
761 "auth_type": gateway.auth_type,
762 "ca_certificate": getattr(gateway, "ca_certificate", None),
763 "ca_certificate_sig": getattr(gateway, "ca_certificate_sig", None),
764 "enabled": bool(gateway.enabled),
765 "reachable": bool(gateway.reachable),
766 "team_id": gateway.team_id,
767 "owner_email": gateway.owner_email,
768 "visibility": gateway.visibility,
769 "tags": gateway.tags or [],
770 "gateway_mode": getattr(gateway, "gateway_mode", "cache"), # Gateway mode for direct proxy support
771 "client_cert": getattr(gateway, "client_cert", None),
772 "client_key": getattr(gateway, "client_key", None),
773 }
775 return {"status": "active", "tool": tool_payload, "gateway": gateway_payload}
777 def _pydantic_tool_from_payload(self, tool_payload: Dict[str, Any]) -> Optional[PydanticTool]:
778 """Build Pydantic tool metadata from cache payload.
780 Args:
781 tool_payload: Cached tool payload dict.
783 Returns:
784 Pydantic tool metadata or None if validation fails.
785 """
786 try:
787 return PydanticTool.model_validate(tool_payload)
788 except Exception as exc:
789 logger.debug("Failed to build PydanticTool from cache payload: %s", exc)
790 return None
792 def _pydantic_gateway_from_payload(self, gateway_payload: Dict[str, Any]) -> Optional[PydanticGateway]:
793 """Build Pydantic gateway metadata from cache payload.
795 Args:
796 gateway_payload: Cached gateway payload dict.
798 Returns:
799 Pydantic gateway metadata or None if validation fails.
800 """
801 try:
802 return PydanticGateway.model_validate(gateway_payload)
803 except Exception as exc:
804 logger.debug("Failed to build PydanticGateway from cache payload: %s", exc)
805 return None
807 async def _check_tool_access(
808 self,
809 db: Session,
810 tool_payload: Dict[str, Any],
811 user_email: Optional[str],
812 token_teams: Optional[List[str]],
813 ) -> bool:
814 """Check if user has access to a tool based on visibility rules.
816 Implements the same access control logic as list_tools() for consistency.
818 Access Rules:
819 - Public tools: Accessible by all authenticated users
820 - Team tools: Accessible by team members (team_id in user's teams)
821 - Private tools: Accessible only by owner (owner_email matches)
823 Args:
824 db: Database session for team membership lookup if needed.
825 tool_payload: Tool data dict with visibility, team_id, owner_email.
826 user_email: Email of the requesting user (None = unauthenticated).
827 token_teams: List of team IDs from token.
828 - None = unrestricted admin access
829 - [] = public-only token
830 - [...] = team-scoped token
832 Returns:
833 True if access is allowed, False otherwise.
834 """
835 visibility = tool_payload.get("visibility", "public")
836 tool_team_id = tool_payload.get("team_id")
837 tool_owner_email = tool_payload.get("owner_email")
839 # Public tools are accessible by everyone
840 if visibility == "public":
841 return True
843 # Admin bypass: token_teams=None AND user_email=None means unrestricted admin
844 # This happens when is_admin=True and no team scoping in token
845 if token_teams is None and user_email is None:
846 return True
848 # No user context (but not admin) = deny access to non-public tools
849 if not user_email:
850 return False
852 # Public-only tokens (empty teams array) can ONLY access public tools
853 is_public_only_token = token_teams is not None and len(token_teams) == 0
854 if is_public_only_token:
855 return False # Already checked public above
857 # Owner can access their own private tools
858 if visibility == "private" and tool_owner_email and tool_owner_email == user_email:
859 return True
861 # Team tools: check team membership (matches list_tools behavior)
862 if tool_team_id:
863 # Use token_teams if provided, otherwise look up from DB
864 if token_teams is not None:
865 team_ids = token_teams
866 else:
867 team_service = TeamManagementService(db)
868 user_teams = await team_service.get_user_teams(user_email)
869 team_ids = [team.id for team in user_teams]
871 # Team/public visibility allows access if user is in the team
872 if visibility in ["team", "public"] and tool_team_id in team_ids:
873 return True
875 return False
def convert_tool_to_read(
    self,
    tool: DbTool,
    include_metrics: bool = False,
    include_auth: bool = True,
    requesting_user_email: Optional[str] = None,
    requesting_user_is_admin: bool = False,
    requesting_user_team_roles: Optional[Dict[str, str]] = None,
) -> ToolRead:
    """Converts a DbTool instance into a ToolRead model, including aggregated metrics and
    new API gateway fields: request_type and authentication credentials (masked).

    Args:
        tool (DbTool): The ORM instance of the tool.
        include_metrics (bool): Whether to include metrics in the result. Defaults to False.
        include_auth (bool): Whether to decode and include auth details. Defaults to True.
            When False, ``decode_auth`` is not called and only ``auth_type`` is returned.
        requesting_user_email (Optional[str]): Email of the requesting user for header masking.
        requesting_user_is_admin (bool): Whether the requester is an admin.
        requesting_user_team_roles (Optional[Dict[str, str]]): {team_id: role} for the requester.

    Returns:
        ToolRead: The Pydantic model representing the tool, including aggregated metrics and new fields.
    """
    # NOTE: This serves two purposes:
    #   1. It determines whether to decode auth (used later)
    #   2. It forces the tool object to lazily evaluate (required before copy)
    has_encrypted_auth = tool.auth_type and tool.auth_value

    # Copy the dict from the tool
    tool_dict = tool.__dict__.copy()
    tool_dict.pop("_sa_instance_state", None)

    # Compute metrics in a single pass (matches server/resource/prompt service pattern)
    if include_metrics:
        metrics = tool.metrics_summary  # Single-pass computation
        tool_dict["metrics"] = metrics
        tool_dict["execution_count"] = metrics["total_executions"]
    else:
        tool_dict["metrics"] = None
        tool_dict["execution_count"] = None

    tool_dict["request_type"] = tool.request_type
    tool_dict["annotations"] = tool.annotations or {}

    # Only decode auth if include_auth=True AND we have encrypted credentials
    if include_auth and has_encrypted_auth:
        decoded_auth_value = decode_auth(tool.auth_value)
        if tool.auth_type == "basic":
            decoded_bytes = base64.b64decode(decoded_auth_value["Authorization"].split("Basic ")[1])
            # Split on the FIRST colon only: the user-id cannot contain ":" but the
            # password may (RFC 7617), so an unbounded split would fail to unpack.
            username, password = decoded_bytes.decode("utf-8").split(":", 1)
            tool_dict["auth"] = {
                "auth_type": "basic",
                "username": username,
                "password": settings.masked_auth_value if password else None,
            }
        elif tool.auth_type == "bearer":
            tool_dict["auth"] = {
                "auth_type": "bearer",
                "token": settings.masked_auth_value if decoded_auth_value["Authorization"] else None,
            }
        elif tool.auth_type == "authheaders":
            # Support multi-header format (list of {key, value} dicts)
            if decoded_auth_value:
                # Convert decoded dict to list format for frontend
                auth_headers = [
                    {
                        "key": key,
                        "value": settings.masked_auth_value if value else None,
                    }
                    for key, value in decoded_auth_value.items()
                ]
                # Also include legacy single-header fields for backward compatibility
                first_key = next(iter(decoded_auth_value))
                tool_dict["auth"] = {
                    "auth_type": "authheaders",
                    "authHeaders": auth_headers,  # Multi-header format (masked)
                    "auth_header_key": first_key,  # Legacy format
                    "auth_header_value": settings.masked_auth_value if decoded_auth_value[first_key] else None,  # Legacy format
                }
            else:
                tool_dict["auth"] = None
        else:
            tool_dict["auth"] = None
    elif not include_auth and has_encrypted_auth:
        # LIST VIEW: Minimal auth info without decryption
        # Only show auth_type for tools that have encrypted credentials
        tool_dict["auth"] = {"auth_type": tool.auth_type}
    else:
        # No encrypted auth (includes OAuth tools where auth_value=None)
        tool_dict["auth"] = None

    tool_dict["name"] = tool.name
    # Handle displayName with fallback and None checks
    display_name = getattr(tool, "display_name", None)
    custom_name = getattr(tool, "custom_name", tool.original_name)
    tool_dict["displayName"] = display_name or custom_name
    tool_dict["custom_name"] = custom_name
    tool_dict["gateway_slug"] = getattr(tool, "gateway_slug", "") or ""
    tool_dict["custom_name_slug"] = getattr(tool, "custom_name_slug", "") or ""
    tool_dict["tags"] = getattr(tool, "tags", []) or []
    tool_dict["team"] = getattr(tool, "team", None)

    # Mask custom headers unless the requester is allowed to modify this tool.
    # Safe default: if no requester context is provided, mask everything.
    headers = tool_dict.get("headers")
    if headers:
        tool_dict["headers"] = _decrypt_tool_headers_for_runtime(headers)
        headers = tool_dict["headers"]
        can_view = requesting_user_is_admin
        # Require a real requester email before comparing: otherwise a tool with no
        # owner (owner_email=None) would match an anonymous requester (None == None)
        # and leak the decrypted header values.
        if not can_view and requesting_user_email and getattr(tool, "owner_email", None) == requesting_user_email:
            can_view = True
        if (
            not can_view
            and getattr(tool, "visibility", None) == "team"
            and getattr(tool, "team_id", None) is not None
            and requesting_user_team_roles
            and requesting_user_team_roles.get(str(tool.team_id)) == "owner"
        ):
            can_view = True
        if not can_view:
            tool_dict["headers"] = {k: settings.masked_auth_value for k in headers}

    return ToolRead.model_validate(tool_dict)
async def _record_tool_metric(self, db: Session, tool: DbTool, start_time: float, success: bool, error_message: Optional[str]) -> None:
    """
    Persist one ToolMetric row describing a finished tool invocation.

    The response time is the difference between ``time.monotonic()`` at the moment of
    this call and ``start_time``. The row is added to ``db`` and committed immediately.

    Args:
        db (Session): The SQLAlchemy database session.
        tool (DbTool): The tool that was invoked; only its ``id`` is read.
        start_time (float): The monotonic start time of the invocation.
        success (bool): True if the invocation succeeded; otherwise, False.
        error_message (Optional[str]): The error message if the invocation failed, otherwise None.
    """
    elapsed = time.monotonic() - start_time
    db.add(
        ToolMetric(
            tool_id=tool.id,
            response_time=elapsed,
            is_success=success,
            error_message=error_message,
        )
    )
    db.commit()
def _record_tool_metric_by_id(
    self,
    db: Session,
    tool_id: str,
    start_time: float,
    success: bool,
    error_message: Optional[str],
) -> None:
    """Persist one ToolMetric row, identifying the tool by ID rather than by ORM object.

    Intended for use with a fresh database session once the request session has been
    released, so no (possibly detached) ORM tool instance is needed. The row is added
    to ``db`` and committed immediately.

    Args:
        db: A fresh database session (not the request session).
        tool_id: The UUID string of the tool.
        start_time: The monotonic start time of the invocation.
        success: True if the invocation succeeded; otherwise, False.
        error_message: The error message if the invocation failed, otherwise None.
    """
    # Response time is measured against the monotonic clock at recording time.
    duration = time.monotonic() - start_time
    record = ToolMetric(tool_id=tool_id, response_time=duration, is_success=success, error_message=error_message)
    db.add(record)
    db.commit()
def _record_tool_metric_sync(
    self,
    tool_id: str,
    start_time: float,
    success: bool,
    error_message: Optional[str],
) -> None:
    """Record a tool metric synchronously using a session opened just for this call.

    Opens a session via ``fresh_db_session()``, delegates to
    ``_record_tool_metric_by_id`` and leaves the context manager to release the
    session. Designed to be called via asyncio.to_thread() to avoid blocking the
    event loop.

    Args:
        tool_id: The UUID string of the tool.
        start_time: The monotonic start time of the invocation.
        success: True if the invocation succeeded; otherwise, False.
        error_message: The error message if the invocation failed, otherwise None.
    """
    metric_fields = {
        "tool_id": tool_id,
        "start_time": start_time,
        "success": success,
        "error_message": error_message,
    }
    with fresh_db_session() as session:
        self._record_tool_metric_by_id(session, **metric_fields)
1089 def _extract_and_validate_structured_content(self, tool: DbTool, tool_result: "ToolResult", candidate: Optional[Any] = None) -> bool:
1090 """
1091 Extract structured content (if any) and validate it against ``tool.output_schema``.
1093 Args:
1094 tool: The tool with an optional output schema to validate against.
1095 tool_result: The tool result containing content to validate.
1096 candidate: Optional structured payload to validate. If not provided, will attempt
1097 to parse the first TextContent item as JSON.
1099 Behavior:
1100 - If ``candidate`` is provided it is used as the structured payload to validate.
1101 - Otherwise the method will try to parse the first ``TextContent`` item in
1102 ``tool_result.content`` as JSON and use that as the candidate.
1103 - If no output schema is declared on the tool the method returns True (nothing to validate).
1104 - On successful validation the parsed value is attached to ``tool_result.structured_content``.
1105 When structured content is present and valid callers may drop textual ``content`` in favour
1106 of the structured payload.
1107 - On validation failure the method sets ``tool_result.content`` to a single ``TextContent``
1108 containing a compact JSON object describing the validation error, sets
1109 ``tool_result.is_error = True`` and returns False.
1111 Returns:
1112 True when the structured content is valid or when no schema is declared.
1113 False when validation fails.
1115 Examples:
1116 >>> from mcpgateway.services.tool_service import ToolService
1117 >>> from mcpgateway.common.models import TextContent, ToolResult
1118 >>> import json
1119 >>> service = ToolService()
1120 >>> # No schema declared -> nothing to validate
1121 >>> tool = type("T", (object,), {"output_schema": None})()
1122 >>> r = ToolResult(content=[TextContent(type="text", text='{"a":1}')])
1123 >>> service._extract_and_validate_structured_content(tool, r)
1124 True
1126 >>> # Valid candidate provided -> attaches structured_content and returns True
1127 >>> tool = type(
1128 ... "T",
1129 ... (object,),
1130 ... {"output_schema": {"type": "object", "properties": {"foo": {"type": "string"}}, "required": ["foo"]}},
1131 ... )()
1132 >>> r = ToolResult(content=[])
1133 >>> service._extract_and_validate_structured_content(tool, r, candidate={"foo": "bar"})
1134 True
1135 >>> r.structured_content == {"foo": "bar"}
1136 True
1138 >>> # Invalid candidate -> returns False, marks result as error and emits details
1139 >>> tool = type(
1140 ... "T",
1141 ... (object,),
1142 ... {"output_schema": {"type": "object", "properties": {"foo": {"type": "string"}}, "required": ["foo"]}},
1143 ... )()
1144 >>> r = ToolResult(content=[])
1145 >>> ok = service._extract_and_validate_structured_content(tool, r, candidate={"foo": 123})
1146 >>> ok
1147 False
1148 >>> r.is_error
1149 True
1150 >>> details = orjson.loads(r.content[0].text)
1151 >>> "received" in details
1152 True
1153 """
1154 try:
1155 output_schema = getattr(tool, "output_schema", None)
1156 # Nothing to do if the tool doesn't declare a schema
1157 if not output_schema:
1158 return True
1160 structured: Optional[Any] = None
1161 # Prefer explicit candidate
1162 if candidate is not None:
1163 structured = candidate
1164 else:
1165 # Try to parse first TextContent text payload as JSON
1166 for c in getattr(tool_result, "content", []) or []:
1167 try:
1168 if isinstance(c, dict) and "type" in c and c.get("type") == "text" and "text" in c:
1169 structured = orjson.loads(c.get("text") or "null")
1170 break
1171 except (orjson.JSONDecodeError, TypeError, ValueError):
1172 # ignore JSON parse errors and continue
1173 continue
1175 # If no structured data found, treat as valid (nothing to validate)
1176 if structured is None:
1177 return True
1179 # Try to normalize common wrapper shapes to match schema expectations
1180 schema_type = None
1181 try:
1182 if isinstance(output_schema, dict):
1183 schema_type = output_schema.get("type")
1184 except Exception:
1185 schema_type = None
1187 # Unwrap single-element list wrappers when schema expects object
1188 if isinstance(structured, list) and len(structured) == 1 and schema_type == "object":
1189 inner = structured[0]
1190 # If inner is a TextContent-like dict with 'text' JSON string, parse it
1191 if isinstance(inner, dict) and "text" in inner and "type" in inner and inner.get("type") == "text":
1192 try:
1193 structured = orjson.loads(inner.get("text") or "null")
1194 except Exception:
1195 # leave as-is if parsing fails
1196 structured = inner
1197 else:
1198 structured = inner
1200 # Attach structured content
1201 try:
1202 setattr(tool_result, "structured_content", structured)
1203 except Exception:
1204 logger.debug("Failed to set structured_content on ToolResult")
1206 # Validate using cached schema validator
1207 try:
1208 _validate_with_cached_schema(structured, output_schema)
1209 return True
1210 except jsonschema.exceptions.ValidationError as e:
1211 details = {
1212 "code": getattr(e, "validator", "validation_error"),
1213 "expected": e.schema.get("type") if isinstance(e.schema, dict) and "type" in e.schema else None,
1214 "received": type(e.instance).__name__.lower() if e.instance is not None else None,
1215 "path": list(e.absolute_path) if hasattr(e, "absolute_path") else list(e.path or []),
1216 "message": e.message,
1217 }
1218 try:
1219 tool_result.content = [TextContent(type="text", text=orjson.dumps(details).decode())]
1220 except Exception:
1221 tool_result.content = [TextContent(type="text", text=str(details))]
1222 tool_result.is_error = True
1223 logger.debug(f"structured_content validation failed for tool {getattr(tool, 'name', '<unknown>')}: {details}")
1224 return False
1225 except Exception as exc: # pragma: no cover - defensive
1226 logger.error(f"Error extracting/validating structured_content: {exc}")
1227 return False
async def register_tool(
    self,
    db: Session,
    tool: ToolCreate,
    created_by: Optional[str] = None,
    created_from_ip: Optional[str] = None,
    created_via: Optional[str] = None,
    created_user_agent: Optional[str] = None,
    import_batch_id: Optional[str] = None,
    federation_source: Optional[str] = None,
    team_id: Optional[str] = None,
    owner_email: Optional[str] = None,
    visibility: Optional[str] = None,
) -> ToolRead:
    """Register a new tool with team support.

    ``team_id``, ``owner_email`` and ``visibility`` fall back to the values carried by
    ``tool`` when not passed explicitly; visibility finally defaults to "public".
    Name conflicts are checked for public tools (by name) and for team tools (by name
    and team); no pre-check is performed here for other visibilities.

    Args:
        db: Database session.
        tool: Tool creation schema.
        created_by: Username who created this tool.
        created_from_ip: IP address of creator.
        created_via: Creation method (ui, api, import, federation).
        created_user_agent: User agent of creation request.
        import_batch_id: UUID for bulk import operations.
        federation_source: Source gateway for federated tools.
        team_id: Optional team ID to assign tool to.
        owner_email: Optional owner email for tool ownership.
        visibility: Tool visibility (private, team, public).

    Returns:
        Created tool information.

    Raises:
        IntegrityError: If there is a database integrity error.
        ToolNameConflictError: If a public tool with the same name, or a team tool with the
            same name in the same team, already exists.
        ToolError: For other tool registration errors (the original exception is chained as the cause).

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from unittest.mock import MagicMock, AsyncMock
        >>> from mcpgateway.schemas import ToolRead
        >>> service = ToolService()
        >>> db = MagicMock()
        >>> tool = MagicMock()
        >>> tool.name = 'test'
        >>> db.execute.return_value.scalar_one_or_none.return_value = None
        >>> mock_gateway = MagicMock()
        >>> mock_gateway.name = 'test_gateway'
        >>> db.add = MagicMock()
        >>> db.commit = MagicMock()
        >>> def mock_refresh(obj):
        ...     obj.gateway = mock_gateway
        >>> db.refresh = MagicMock(side_effect=mock_refresh)
        >>> service._notify_tool_added = AsyncMock()
        >>> service.convert_tool_to_read = MagicMock(return_value='tool_read')
        >>> ToolRead.model_validate = MagicMock(return_value='tool_read')
        >>> import asyncio
        >>> asyncio.run(service.register_tool(db, tool))
        'tool_read'
    """
    try:
        if tool.auth is None:
            auth_type = None
            auth_value = None
        else:
            auth_type = tool.auth.auth_type
            auth_value = tool.auth.auth_value

        # Explicit arguments win; otherwise fall back to the values on the schema.
        if team_id is None:
            team_id = tool.team_id

        if owner_email is None:
            owner_email = tool.owner_email

        if visibility is None:
            visibility = tool.visibility or "public"
        # Check for existing tool with the same name and visibility
        if visibility.lower() == "public":
            # Check for existing public tool with the same name
            existing_tool = db.execute(select(DbTool).where(DbTool.name == tool.name, DbTool.visibility == "public")).scalar_one_or_none()  # pylint: disable=comparison-with-callable
            if existing_tool:
                raise ToolNameConflictError(existing_tool.name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility)
        elif visibility.lower() == "team" and team_id:
            # Check for existing team tool with the same name, team_id
            existing_tool = db.execute(
                select(DbTool).where(DbTool.name == tool.name, DbTool.visibility == "team", DbTool.team_id == team_id)  # pylint: disable=comparison-with-callable
            ).scalar_one_or_none()
            if existing_tool:
                raise ToolNameConflictError(existing_tool.name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility)

        db_tool = DbTool(
            original_name=tool.name,
            custom_name=tool.name,
            custom_name_slug=slugify(tool.name),
            display_name=tool.displayName or tool.name,
            title=tool.title,
            url=str(tool.url),
            description=tool.description,
            original_description=tool.description,
            integration_type=tool.integration_type,
            request_type=tool.request_type,
            headers=_protect_tool_headers_for_storage(tool.headers),
            input_schema=tool.input_schema,
            output_schema=tool.output_schema,
            annotations=tool.annotations,
            jsonpath_filter=tool.jsonpath_filter,
            auth_type=auth_type,
            auth_value=auth_value,
            gateway_id=tool.gateway_id,
            tags=tool.tags or [],
            # Metadata fields
            created_by=created_by,
            created_from_ip=created_from_ip,
            created_via=created_via,
            created_user_agent=created_user_agent,
            import_batch_id=import_batch_id,
            federation_source=federation_source,
            version=1,
            # Team scoping fields
            team_id=team_id,
            owner_email=owner_email or created_by,
            visibility=visibility,
            # passthrough REST tools fields
            base_url=tool.base_url if tool.integration_type == "REST" else None,
            path_template=tool.path_template if tool.integration_type == "REST" else None,
            query_mapping=tool.query_mapping if tool.integration_type == "REST" else None,
            header_mapping=tool.header_mapping if tool.integration_type == "REST" else None,
            timeout_ms=tool.timeout_ms if tool.integration_type == "REST" else None,
            expose_passthrough=(tool.expose_passthrough if tool.integration_type == "REST" and tool.expose_passthrough is not None else True) if tool.integration_type == "REST" else None,
            allowlist=tool.allowlist if tool.integration_type == "REST" else None,
            plugin_chain_pre=tool.plugin_chain_pre if tool.integration_type == "REST" else None,
            plugin_chain_post=tool.plugin_chain_post if tool.integration_type == "REST" else None,
        )
        db.add(db_tool)
        db.commit()
        db.refresh(db_tool)
        await self._notify_tool_added(db_tool)

        # Structured logging: Audit trail for tool creation
        audit_trail.log_action(
            user_id=created_by or "system",
            action="create_tool",
            resource_type="tool",
            resource_id=db_tool.id,
            resource_name=db_tool.name,
            user_email=owner_email,
            team_id=team_id,
            client_ip=created_from_ip,
            user_agent=created_user_agent,
            new_values={
                "name": db_tool.name,
                "display_name": db_tool.display_name,
                "visibility": visibility,
                "integration_type": db_tool.integration_type,
            },
            context={
                "created_via": created_via,
                "import_batch_id": import_batch_id,
                "federation_source": federation_source,
            },
            db=db,
        )

        # Structured logging: Log successful tool creation
        structured_logger.log(
            level="INFO",
            message="Tool created successfully",
            event_type="tool_created",
            component="tool_service",
            user_id=created_by,
            user_email=owner_email,
            team_id=team_id,
            resource_type="tool",
            resource_id=db_tool.id,
            custom_fields={
                "tool_name": db_tool.name,
                "visibility": visibility,
                "integration_type": db_tool.integration_type,
            },
        )

        # Refresh db_tool after logging commits (they expire the session objects)
        db.refresh(db_tool)

        # Invalidate cache after successful creation
        cache = _get_registry_cache()
        await cache.invalidate_tools()
        tool_lookup_cache = _get_tool_lookup_cache()
        await tool_lookup_cache.invalidate(db_tool.name, gateway_id=str(db_tool.gateway_id) if db_tool.gateway_id else None)
        # Also invalidate tags cache since tool tags may have changed
        # First-Party
        from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

        await admin_stats_cache.invalidate_tags()

        return self.convert_tool_to_read(db_tool, requesting_user_email=getattr(db_tool, "owner_email", None))
    except IntegrityError as ie:
        db.rollback()
        logger.error(f"IntegrityError during tool registration: {ie}")

        # Structured logging: Log database integrity error
        structured_logger.log(
            level="ERROR",
            message="Tool creation failed due to database integrity error",
            event_type="tool_creation_failed",
            component="tool_service",
            user_id=created_by,
            user_email=owner_email,
            error=ie,
            custom_fields={
                "tool_name": tool.name,
            },
        )
        # Bare raise re-raises the active exception with its original traceback.
        raise
    except ToolNameConflictError as tnce:
        db.rollback()
        logger.error(f"ToolNameConflictError during tool registration: {tnce}")

        # Structured logging: Log name conflict error
        structured_logger.log(
            level="WARNING",
            message="Tool creation failed due to name conflict",
            event_type="tool_name_conflict",
            component="tool_service",
            user_id=created_by,
            user_email=owner_email,
            custom_fields={
                "tool_name": tool.name,
                "visibility": visibility,
            },
        )
        raise
    except Exception as e:
        db.rollback()

        # Structured logging: Log generic tool creation failure
        structured_logger.log(
            level="ERROR",
            message="Tool creation failed",
            event_type="tool_creation_failed",
            component="tool_service",
            user_id=created_by,
            user_email=owner_email,
            error=e,
            custom_fields={
                "tool_name": tool.name,
            },
        )
        # Chain the cause so the original failure is not lost behind the ToolError.
        raise ToolError(f"Failed to register tool: {str(e)}") from e
async def register_tools_bulk(
    self,
    db: Session,
    tools: List[ToolCreate],
    created_by: Optional[str] = None,
    created_from_ip: Optional[str] = None,
    created_via: Optional[str] = None,
    created_user_agent: Optional[str] = None,
    import_batch_id: Optional[str] = None,
    federation_source: Optional[str] = None,
    team_id: Optional[str] = None,
    owner_email: Optional[str] = None,
    visibility: Optional[str] = "public",
    conflict_strategy: str = "skip",
) -> Dict[str, Any]:
    """Register many tools at once, processing them in chunks of 500.

    Each chunk is handed to ``_process_tool_chunk`` and its statistics are folded into
    the running totals. After any chunk that created or updated at least one tool, the
    registry cache, the per-tool lookup cache entries for the chunk's tool names, and
    the admin tags cache are invalidated.

    Args:
        db: Database session
        tools: List of tool creation schemas
        created_by: Username who created these tools
        created_from_ip: IP address of creator
        created_via: Creation method (ui, api, import, federation)
        created_user_agent: User agent of creation request
        import_batch_id: UUID for bulk import operations
        federation_source: Source gateway for federated tools
        team_id: Team ID to assign the tools to
        owner_email: Email of the user who owns these tools
        visibility: Tool visibility level (private, team, public)
        conflict_strategy: How to handle conflicts; passed through unchanged to
            ``_process_tool_chunk`` (which documents "skip", "update" and "fail")

    Returns:
        Dict with statistics:
            - created: Number of tools created
            - updated: Number of tools updated
            - skipped: Number of tools skipped
            - failed: Number of tools that failed
            - errors: List of error messages

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from unittest.mock import MagicMock
        >>> service = ToolService()
        >>> db = MagicMock()
        >>> tools = [MagicMock(), MagicMock()]
        >>> import asyncio
        >>> try:
        ...     result = asyncio.run(service.register_tools_bulk(db, tools))
        ... except Exception:
        ...     pass
    """
    totals: Dict[str, Any] = {"created": 0, "updated": 0, "skipped": 0, "failed": 0, "errors": []}
    if not tools:
        return totals

    # Process in chunks to avoid memory issues and SQLite parameter limits
    batch_limit = 500

    for offset in range(0, len(tools), batch_limit):
        batch = tools[offset : offset + batch_limit]
        outcome = self._process_tool_chunk(
            db=db,
            chunk=batch,
            conflict_strategy=conflict_strategy,
            visibility=visibility,
            team_id=team_id,
            owner_email=owner_email,
            created_by=created_by,
            created_from_ip=created_from_ip,
            created_via=created_via,
            created_user_agent=created_user_agent,
            import_batch_id=import_batch_id,
            federation_source=federation_source,
        )

        # Fold this chunk's numbers into the running totals.
        for field, amount in outcome.items():
            if field == "errors":
                totals[field].extend(amount)
            else:
                totals[field] += amount

        # Caches only need refreshing when the chunk reported a create or update.
        if not (outcome["created"] or outcome["updated"]):
            continue

        registry_cache = _get_registry_cache()
        await registry_cache.invalidate_tools()
        lookup_cache = _get_tool_lookup_cache()

        # One entry per tool name; a gateway id seen for a name is kept even if a
        # later tool with the same name carries none.
        gateway_by_name: Dict[str, Optional[str]] = {}
        for entry in batch:
            entry_name = getattr(entry, "name", None)
            if not entry_name:
                continue
            entry_gateway = getattr(entry, "gateway_id", None)
            gateway_by_name[entry_name] = str(entry_gateway) if entry_gateway else gateway_by_name.get(entry_name)
        for entry_name, entry_gateway in gateway_by_name.items():
            await lookup_cache.invalidate(entry_name, gateway_id=entry_gateway)

        # Also invalidate tags cache since tool tags may have changed
        # First-Party
        from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

        await admin_stats_cache.invalidate_tags()

    return totals
def _process_tool_chunk(
    self,
    db: Session,
    chunk: List[ToolCreate],
    conflict_strategy: str,
    visibility: Optional[str],
    team_id: Optional[str],
    owner_email: Optional[str],
    created_by: Optional[str],
    created_from_ip: Optional[str],
    created_via: Optional[str],
    created_user_agent: Optional[str],
    import_batch_id: Optional[str],
    federation_source: Optional[str],
) -> dict:
    """Process a chunk of tools for bulk import.

    Looks up existing tools that conflict by name within the requested visibility scope,
    delegates each tool to ``_process_single_tool_for_bulk``, stages new tools with
    ``db.add_all`` and commits once for the whole chunk. Any exception rolls the session
    back and marks the entire chunk as failed.

    Args:
        db: The SQLAlchemy database session.
        chunk: List of ToolCreate objects to process.
        conflict_strategy: Strategy for handling conflicts ("skip", "update", or "fail").
        visibility: Tool visibility level ("public", "team", or "private"). ``None`` is
            treated as "public" for conflict detection.
        team_id: Team ID for team-scoped tools.
        owner_email: Email of the tool owner.
        created_by: Email of the user creating the tools.
        created_from_ip: IP address of the request origin.
        created_via: Source of the creation (e.g., "api", "ui").
        created_user_agent: User agent string from the request.
        import_batch_id: Batch identifier for bulk imports.
        federation_source: Source identifier for federated tools.

    Returns:
        dict: Statistics dictionary with keys "created", "updated", "skipped", "failed", and "errors".
    """
    stats = {"created": 0, "updated": 0, "skipped": 0, "failed": 0, "errors": []}
    # Tracks whether the chunk's commit completed, so the error path knows
    # whether the staged creates/updates were actually persisted.
    committed = False

    try:
        # Batch check for existing tools to detect conflicts
        tool_names = [tool.name for tool in chunk]

        # The caller's signature permits visibility=None; avoid crashing on .lower().
        scope = (visibility or "public").lower()
        if scope == "public":
            existing_tools_query = select(DbTool).where(DbTool.name.in_(tool_names), DbTool.visibility == "public")
        elif scope == "team" and team_id:
            existing_tools_query = select(DbTool).where(DbTool.name.in_(tool_names), DbTool.visibility == "team", DbTool.team_id == team_id)
        else:
            # Private tools - check by owner
            existing_tools_query = select(DbTool).where(DbTool.name.in_(tool_names), DbTool.visibility == "private", DbTool.owner_email == (owner_email or created_by))

        existing_tools = db.execute(existing_tools_query).scalars().all()
        existing_tools_map = {tool.name: tool for tool in existing_tools}

        tools_to_add = []
        tools_to_update = []

        for tool in chunk:
            result = self._process_single_tool_for_bulk(
                tool=tool,
                existing_tools_map=existing_tools_map,
                conflict_strategy=conflict_strategy,
                visibility=visibility,
                team_id=team_id,
                owner_email=owner_email,
                created_by=created_by,
                created_from_ip=created_from_ip,
                created_via=created_via,
                created_user_agent=created_user_agent,
                import_batch_id=import_batch_id,
                federation_source=federation_source,
            )

            if result["status"] == "add":
                tools_to_add.append(result["tool"])
                stats["created"] += 1
            elif result["status"] == "update":
                tools_to_update.append(result["tool"])
                stats["updated"] += 1
            elif result["status"] == "skip":
                stats["skipped"] += 1
            elif result["status"] == "fail":
                stats["failed"] += 1
                stats["errors"].append(result["error"])

        # Bulk add new tools
        if tools_to_add:
            db.add_all(tools_to_add)

        # Commit the chunk
        db.commit()
        committed = True

        # Refresh tools for notifications and audit trail
        for db_tool in tools_to_add:
            db.refresh(db_tool)
            # Notify subscribers (sync call in async context handled by caller)

        # Log bulk audit trail entry
        if tools_to_add or tools_to_update:
            audit_trail.log_action(
                user_id=created_by or "system",
                action="bulk_create_tools" if tools_to_add else "bulk_update_tools",
                resource_type="tool",
                resource_id=None,
                details={"count": len(tools_to_add) + len(tools_to_update), "import_batch_id": import_batch_id},
                db=db,
            )

    except Exception as e:
        db.rollback()
        logger.error(f"Failed to process tool chunk: {str(e)}")
        if not committed:
            # The rollback discarded everything staged for this chunk, so none of
            # the tools counted above were actually created or updated.
            stats["created"] = 0
            stats["updated"] = 0
        stats["failed"] += len(chunk)
        stats["errors"].append(f"Chunk processing failed: {str(e)}")

    return stats
def _process_single_tool_for_bulk(
    self,
    tool: ToolCreate,
    existing_tools_map: dict,
    conflict_strategy: str,
    visibility: str,
    team_id: Optional[int],
    owner_email: Optional[str],
    created_by: str,
    created_from_ip: Optional[str],
    created_via: Optional[str],
    created_user_agent: Optional[str],
    import_batch_id: Optional[str],
    federation_source: Optional[str],
) -> dict:
    """Decide what a bulk import should do with one incoming tool.

    The tool is looked up by name in ``existing_tools_map``. When no entry is
    found a new DbTool is built. When an entry is found, ``conflict_strategy``
    selects the outcome:

    - ``"skip"``: nothing is changed.
    - ``"update"``: the existing DbTool is mutated in place and its version bumped.
    - ``"rename"``: a new DbTool is built under ``<name>_imported_<unix timestamp>``.
    - ``"fail"``: a name-conflict error is reported.
    - any other value: falls through and builds a new DbTool under the original name.

    Args:
        tool: ToolCreate object to process.
        existing_tools_map: Dictionary mapping tool names to existing DbTool objects.
        conflict_strategy: Strategy applied when the name already exists (see above).
        visibility: Tool visibility level; when None the tool's own value, then "public", is used.
        team_id: Team ID for team-scoped tools; when None the tool's own value is used.
        owner_email: Email of the tool owner; falls back to the tool's own value, then ``created_by``.
        created_by: Email of the user creating the tool.
        created_from_ip: IP address of the request origin.
        created_via: Source of the creation (e.g., "api", "ui").
        created_user_agent: User agent string from the request.
        import_batch_id: Batch identifier for bulk imports.
        federation_source: Source identifier for federated tools.

    Returns:
        dict: Result dictionary with a "status" key ("add", "update", "skip", or "fail").
            "add" and "update" also carry "tool" (DbTool object); "fail" carries
            "error" (error message). Any exception raised while processing is
            logged and reported as a "fail" result rather than propagated.
    """
    try:
        # Pull credentials off the optional auth block
        credentials = tool.auth
        auth_type = None if credentials is None else credentials.auth_type
        auth_value = None if credentials is None else credentials.auth_value

        # Explicit arguments take precedence over values carried by the schema object
        if team_id is not None:
            resolved_team_id = team_id
        else:
            resolved_team_id = getattr(tool, "team_id", None)
        resolved_owner = owner_email or getattr(tool, "owner_email", None) or created_by
        if visibility is not None:
            resolved_visibility = visibility
        else:
            resolved_visibility = getattr(tool, "visibility", None) or "public"

        def build_new(tool_name):
            # Single place that forwards the resolved values to the DbTool factory
            return self._create_tool_object(
                tool,
                tool_name,
                auth_type,
                auth_value,
                resolved_team_id,
                resolved_owner,
                resolved_visibility,
                created_by,
                created_from_ip,
                created_via,
                created_user_agent,
                import_batch_id,
                federation_source,
            )

        current = existing_tools_map.get(tool.name)
        if not current:
            # No name clash: plain creation
            return {"status": "add", "tool": build_new(tool.name)}

        if conflict_strategy == "skip":
            return {"status": "skip"}

        if conflict_strategy == "update":
            # Overwrite the stored definition with the incoming one
            current.display_name = tool.displayName or tool.name
            current.title = tool.title
            current.url = str(tool.url)
            current.description = tool.description
            # original_description is only seeded when it has never been set
            if getattr(current, "original_description", None) is None:
                current.original_description = tool.description
            current.integration_type = tool.integration_type
            current.request_type = tool.request_type
            current.headers = _protect_tool_headers_for_storage(tool.headers, existing_headers=current.headers)
            for attr in ("input_schema", "output_schema", "annotations", "jsonpath_filter"):
                setattr(current, attr, getattr(tool, attr))
            current.auth_type = auth_type
            current.auth_value = auth_value
            current.tags = tool.tags or []
            # The "created_*" request metadata is recorded as modification metadata
            current.modified_by = created_by
            current.modified_from_ip = created_from_ip
            current.modified_via = created_via
            current.modified_user_agent = created_user_agent
            current.updated_at = datetime.now(timezone.utc)
            current.version = (current.version or 1) + 1

            # REST-specific fields are only touched for REST tools
            if tool.integration_type == "REST":
                for attr in ("base_url", "path_template", "query_mapping", "header_mapping", "timeout_ms"):
                    setattr(current, attr, getattr(tool, attr))
                passthrough = tool.expose_passthrough
                current.expose_passthrough = True if passthrough is None else passthrough
                for attr in ("allowlist", "plugin_chain_pre", "plugin_chain_post"):
                    setattr(current, attr, getattr(tool, attr))

            return {"status": "update", "tool": current}

        if conflict_strategy == "rename":
            # Keep the existing tool and store the import under a timestamped name
            renamed = f"{tool.name}_imported_{int(datetime.now().timestamp())}"
            return {"status": "add", "tool": build_new(renamed)}

        if conflict_strategy == "fail":
            return {"status": "fail", "error": f"Tool name conflict: {tool.name}"}

        # Unrecognised strategy: behaves like a plain creation under the original name
        return {"status": "add", "tool": build_new(tool.name)}

    except Exception as exc:
        logger.warning(f"Failed to process tool {tool.name} in bulk operation: {str(exc)}")
        return {"status": "fail", "error": f"Failed to process tool {tool.name}: {str(exc)}"}
def _create_tool_object(
    self,
    tool: ToolCreate,
    name: str,
    auth_type: Optional[str],
    auth_value: Optional[str],
    tool_team_id: Optional[int],
    tool_owner_email: Optional[str],
    tool_visibility: str,
    created_by: str,
    created_from_ip: Optional[str],
    created_via: Optional[str],
    created_user_agent: Optional[str],
    import_batch_id: Optional[str],
    federation_source: Optional[str],
) -> DbTool:
    """Build a new, version-1 DbTool from a ToolCreate schema.

    ``name`` is used for the original name, the custom name and (slugified)
    the custom name slug; it is also the display-name fallback. REST-specific
    columns are copied from ``tool`` only when its integration type is
    "REST" and are set to None otherwise.

    Args:
        tool: ToolCreate schema object containing tool data.
        name: Name of the tool.
        auth_type: Authentication type for the tool.
        auth_value: Authentication value/credentials for the tool.
        tool_team_id: Team ID for team-scoped tools.
        tool_owner_email: Email of the tool owner.
        tool_visibility: Tool visibility level ("public", "team", or "private").
        created_by: Email of the user creating the tool.
        created_from_ip: IP address of the request origin.
        created_via: Source of the creation (e.g., "api", "ui").
        created_user_agent: User agent string from the request.
        import_batch_id: Batch identifier for bulk imports.
        federation_source: Source identifier for federated tools.

    Returns:
        DbTool: Database model instance ready to be added to the session.
    """
    is_rest = tool.integration_type == "REST"

    def rest_only(attr):
        # Read the attribute lazily so non-REST tools never touch REST fields
        return getattr(tool, attr) if is_rest else None

    # For REST tools passthrough defaults to True when unset; non-REST tools store None
    if is_rest:
        passthrough = True if tool.expose_passthrough is None else tool.expose_passthrough
    else:
        passthrough = None

    columns = {
        # Identity
        "original_name": name,
        "custom_name": name,
        "custom_name_slug": slugify(name),
        "display_name": tool.displayName or name,
        "title": tool.title,
        # Definition
        "url": str(tool.url),
        "description": tool.description,
        "original_description": tool.description,
        "integration_type": tool.integration_type,
        "request_type": tool.request_type,
        "headers": _protect_tool_headers_for_storage(tool.headers),
        "input_schema": tool.input_schema,
        "output_schema": tool.output_schema,
        "annotations": tool.annotations,
        "jsonpath_filter": tool.jsonpath_filter,
        "auth_type": auth_type,
        "auth_value": auth_value,
        "gateway_id": tool.gateway_id,
        "tags": tool.tags or [],
        # Creation metadata
        "created_by": created_by,
        "created_from_ip": created_from_ip,
        "created_via": created_via,
        "created_user_agent": created_user_agent,
        "import_batch_id": import_batch_id,
        "federation_source": federation_source,
        "version": 1,
        # Ownership and visibility
        "team_id": tool_team_id,
        "owner_email": tool_owner_email,
        "visibility": tool_visibility,
        # REST-only settings
        "base_url": rest_only("base_url"),
        "path_template": rest_only("path_template"),
        "query_mapping": rest_only("query_mapping"),
        "header_mapping": rest_only("header_mapping"),
        "timeout_ms": rest_only("timeout_ms"),
        "expose_passthrough": passthrough,
        "allowlist": rest_only("allowlist"),
        "plugin_chain_pre": rest_only("plugin_chain_pre"),
        "plugin_chain_post": rest_only("plugin_chain_post"),
    }
    return DbTool(**columns)
async def list_tools(
    self,
    db: Session,
    include_inactive: bool = False,
    cursor: Optional[str] = None,
    tags: Optional[List[str]] = None,
    gateway_id: Optional[str] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    per_page: Optional[int] = None,
    user_email: Optional[str] = None,
    team_id: Optional[str] = None,
    visibility: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
    _request_headers: Optional[Dict[str, str]] = None,
    requesting_user_email: Optional[str] = None,
    requesting_user_is_admin: bool = False,
    requesting_user_team_roles: Optional[Dict[str, str]] = None,
) -> Union[tuple[List[ToolRead], Optional[str]], Dict[str, Any]]:
    """
    List registered tools, newest first, with cursor- or page-based pagination.

    The first cursor page of an unscoped request (no cursor, no user_email,
    no token_teams, no page, and the default converter in place) is served
    from and stored in the registry cache. Tools that fail conversion to
    ToolRead are logged and left out of the result.

    Args:
        db (Session): The SQLAlchemy database session.
        include_inactive (bool): If True, include inactive tools in the result.
            Defaults to False.
        cursor (Optional[str], optional): An opaque cursor token for pagination.
        tags (Optional[List[str]]): Filter tools by tags. If provided, only tools with at least one matching tag will be returned.
        gateway_id (Optional[str]): Filter tools by gateway ID. Accepts the literal value 'null' to match NULL gateway_id.
        limit (Optional[int]): Maximum number of tools to return; forwarded to the pagination helper.
        page: Page number for page-based pagination. When not None the page-based return format is used.
        per_page: Items per page for page-based pagination; forwarded to the pagination helper.
        user_email (Optional[str]): User email forwarded to the access-control filter. Disables caching when set.
        team_id (Optional[str]): Team ID forwarded to the access-control filter.
        visibility (Optional[str]): Filter by visibility (private, team, public).
        token_teams (Optional[List[str]]): Token team scope forwarded to the access-control filter.
            Disables caching when set (including an empty list).
        _request_headers (Optional[Dict[str, str]], optional): Unused in this method; kept for API consistency.
        requesting_user_email (Optional[str]): Email of the requesting user, forwarded to the converter.
        requesting_user_is_admin (bool): Whether the requester is an admin, forwarded to the converter.
        requesting_user_team_roles (Optional[Dict[str, str]]): {team_id: role} for the requester, forwarded to the converter.

    Returns:
        Cursor mode (``page`` is None): tuple[List[ToolRead], Optional[str]] containing:
            - List of tools for current page
            - Next cursor token as returned by the pagination helper
        Page mode (``page`` is not None): dict with keys "data" (list of tools),
            "pagination" and "links".

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from unittest.mock import MagicMock
        >>> service = ToolService()
        >>> db = MagicMock()
        >>> tool_read = MagicMock()
        >>> service.convert_tool_to_read = MagicMock(return_value=tool_read)
        >>> db.execute.return_value.scalars.return_value.all.return_value = [MagicMock()]
        >>> import asyncio
        >>> tools, next_cursor = asyncio.run(service.list_tools(db))
        >>> isinstance(tools, list)
        True
    """
    span_attributes = {
        "include_inactive": include_inactive,
        "tags.count": len(tags) if tags else 0,
        "gateway_id": gateway_id,
        "limit": limit,
        "page": page,
        "per_page": per_page,
        "user.email": user_email,
        "team.scope": format_trace_team_scope(token_teams) if token_teams is not None else None,
        "team.filter": team_id,
        "visibility": visibility,
    }
    with create_span("tool.list", span_attributes):
        registry_cache = _get_registry_cache()
        filters_hash = None

        # The cache is bypassed when convert_tool_to_read has been replaced on the
        # instance (typical in unit tests), otherwise a warm cache would hide the patch.
        try:
            converter_is_default = self.convert_tool_to_read.__func__ is ToolService.convert_tool_to_read  # type: ignore[attr-defined]
        except Exception:
            converter_is_default = False

        # Only the first cursor page of an unscoped request may use the cache:
        # user- or token-scoped results must never be shared between callers.
        cache_eligible = cursor is None and user_email is None and token_teams is None and page is None and converter_is_default

        if cache_eligible:
            # visibility is part of the key so filtered and unfiltered requests
            # never share an entry.
            filters_hash = registry_cache.hash_filters(
                include_inactive=include_inactive,
                tags=sorted(tags) if tags else None,
                gateway_id=gateway_id,
                limit=limit,
                visibility=visibility,
            )
            cached_payload = await registry_cache.get("tools", filters_hash)
            if cached_payload is not None:
                # Cached entries are plain dicts; rebuild ToolRead models from them
                return ([ToolRead.model_validate(item) for item in cached_payload["tools"]], cached_payload.get("next_cursor"))

        # Newest first; gateway and team are eager-loaded to avoid N+1 lookups
        query = select(DbTool).options(joinedload(DbTool.gateway), joinedload(DbTool.email_team)).order_by(desc(DbTool.created_at), desc(DbTool.id))

        if not include_inactive:
            query = query.where(DbTool.enabled)
        query = await self._apply_access_control(query, db, user_email, token_teams, team_id)

        if visibility:
            query = query.where(DbTool.visibility == visibility)

        if gateway_id:
            # The literal string 'null' selects tools without a gateway
            gateway_clause = DbTool.gateway_id.is_(None) if gateway_id.lower() == "null" else DbTool.gateway_id == gateway_id
            query = query.where(gateway_clause)

        if tags:
            # Match tools carrying at least one of the requested tags
            query = query.where(json_contains_tag_expr(db, DbTool.tags, tags, match_any=True))

        # One helper serves both pagination styles
        pag_result = await unified_paginate(
            db=db,
            query=query,
            page=page,
            per_page=per_page,
            cursor=cursor,
            limit=limit,
            base_url="/admin/tools",  # Used for page-based links
            query_params={"include_inactive": include_inactive} if include_inactive else {},
        )

        # Extract tools based on pagination type
        paged = page is not None
        next_cursor = None
        if paged:
            # Page-based: pag_result is a dict
            tools_db = pag_result["data"]
        else:
            # Cursor-based: pag_result is a tuple
            tools_db, next_cursor = pag_result

        db.commit()  # Release transaction to avoid idle-in-transaction

        # Convert rows to ToolRead; a row that fails conversion is logged and dropped
        result = []
        for db_tool in tools_db:
            try:
                converted = self.convert_tool_to_read(
                    db_tool,
                    include_metrics=False,
                    include_auth=False,
                    requesting_user_email=requesting_user_email,
                    requesting_user_is_admin=requesting_user_is_admin,
                    requesting_user_team_roles=requesting_user_team_roles,
                )
                result.append(converted)
            except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as exc:
                logger.exception(f"Failed to convert tool {getattr(db_tool, 'id', 'unknown')} ({getattr(db_tool, 'name', 'unknown')}): {exc}")

        if paged:
            # Page-based format
            return {
                "data": result,
                "pagination": pag_result["pagination"],
                "links": pag_result["links"],
            }

        # Cursor-based format: store the first page under the same conditions
        # that allowed the lookup, so scoped results can never enter the cache.
        if filters_hash is not None and cache_eligible:
            try:
                cache_data = {"tools": [item.model_dump(mode="json") for item in result], "next_cursor": next_cursor}
                await registry_cache.set("tools", cache_data, filters_hash)
            except AttributeError:
                pass  # Skip caching if result objects don't support model_dump (e.g., in doctests)

        return (result, next_cursor)
async def list_server_tools(
    self,
    db: Session,
    server_id: str,
    include_inactive: bool = False,
    include_metrics: bool = False,
    cursor: Optional[str] = None,
    user_email: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
    _request_headers: Optional[Dict[str, str]] = None,
    requesting_user_email: Optional[str] = None,
    requesting_user_is_admin: bool = False,
    requesting_user_team_roles: Optional[Dict[str, str]] = None,
) -> List[ToolRead]:
    """
    List the tools associated with one server as ToolRead objects.

    Visibility filtering is applied whenever ``user_email`` or ``token_teams``
    is not None. Tools that fail conversion to ToolRead are logged and left
    out of the result.

    Args:
        db (Session): The SQLAlchemy database session.
        server_id (str): Server ID
        include_inactive (bool): If True, include inactive tools in the result.
            Defaults to False.
        include_metrics (bool): If True, metrics are loaded and passed through to the converter.
            Defaults to False.
        cursor (Optional[str], optional): An opaque cursor token for pagination. Currently,
            this parameter is ignored. Defaults to None.
        user_email (Optional[str]): User email for visibility filtering. When ``token_teams`` is None
            the user's teams are looked up from the database.
        token_teams (Optional[List[str]]): Team scope that replaces the DB team lookup. An empty list
            restricts the listing to public tools only.
        _request_headers (Optional[Dict[str, str]], optional): Unused in this method; kept for API consistency.
        requesting_user_email (Optional[str]): Email of the requesting user, forwarded to the converter.
        requesting_user_is_admin (bool): Whether the requester is an admin, forwarded to the converter.
        requesting_user_team_roles (Optional[Dict[str, str]]): {team_id: role} for the requester, forwarded to the converter.

    Returns:
        List[ToolRead]: A list of registered tools represented as ToolRead objects.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from unittest.mock import MagicMock
        >>> service = ToolService()
        >>> db = MagicMock()
        >>> tool_read = MagicMock()
        >>> service.convert_tool_to_read = MagicMock(return_value=tool_read)
        >>> db.execute.return_value.scalars.return_value.all.return_value = [MagicMock()]
        >>> import asyncio
        >>> result = asyncio.run(service.list_server_tools(db, 'server1'))
        >>> isinstance(result, list)
        True
    """
    span_attributes = {
        "server_id": server_id,
        "include_inactive": include_inactive,
        "include_metrics": include_metrics,
        "user.email": user_email,
        "team.scope": format_trace_team_scope(token_teams) if token_teams is not None else None,
    }
    with create_span("tool.list", span_attributes):
        # Gateway and team are always eager-loaded; metrics only on request
        query = select(DbTool).options(joinedload(DbTool.gateway), joinedload(DbTool.email_team))
        if include_metrics:
            query = query.options(selectinload(DbTool.metrics)).options(selectinload(DbTool.metrics_hourly))
        # Restrict to tools linked to the requested server
        query = query.join(server_tool_association, DbTool.id == server_tool_association.c.tool_id).where(server_tool_association.c.server_id == server_id)

        cursor = None  # Placeholder for pagination; ignore for now
        logger.debug(f"Listing server tools for server_id={server_id} with include_inactive={include_inactive}, cursor={cursor}")

        if not include_inactive:
            query = query.where(DbTool.enabled)

        # Filter by visibility when a user OR a token scope is present. An empty
        # user_email or an empty token_teams list therefore yields public-only results.
        if user_email is not None or token_teams is not None:
            if token_teams is not None:
                # Token scope wins over the user's stored memberships
                team_ids = token_teams
            elif user_email:
                memberships = await TeamManagementService(db).get_user_teams(user_email)
                team_ids = [team.id for team in memberships]
            else:
                team_ids = []

            # A token with an empty team list may only see public tools, not even its owner's
            is_public_only_token = token_teams is not None and len(token_teams) == 0

            access_conditions = [DbTool.visibility == "public"]
            if user_email and not is_public_only_token:
                access_conditions.append(DbTool.owner_email == user_email)
            if team_ids:
                team_scoped = and_(DbTool.team_id.in_(team_ids), DbTool.visibility.in_(["team", "public"]))
                access_conditions.append(team_scoped)
            query = query.where(or_(*access_conditions))

        # Team names arrive with the rows via joinedload(DbTool.email_team)
        db_tools = db.execute(query).scalars().all()

        db.commit()  # Release transaction to avoid idle-in-transaction

        # A row that fails conversion is logged and dropped; the rest are still returned
        result = []
        for db_tool in db_tools:
            try:
                converted = self.convert_tool_to_read(
                    db_tool,
                    include_metrics=include_metrics,
                    include_auth=False,
                    requesting_user_email=requesting_user_email,
                    requesting_user_is_admin=requesting_user_is_admin,
                    requesting_user_team_roles=requesting_user_team_roles,
                )
                result.append(converted)
            except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as exc:
                logger.exception(f"Failed to convert tool {getattr(db_tool, 'id', 'unknown')} ({getattr(db_tool, 'name', 'unknown')}): {exc}")

        return result
async def list_server_mcp_tool_definitions(
    self,
    db: Session,
    server_id: str,
    *,
    include_inactive: bool = False,
    user_email: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """Return server-scoped MCP tool definitions without building full ToolRead models.

    Only the columns needed for the definition and the visibility filter are
    selected, and each row is turned directly into a plain dictionary.

    Visibility filtering is applied whenever ``user_email`` or ``token_teams``
    is not None. Unlike ``list_server_tools``, team membership is never looked
    up in the database here: when ``token_teams`` is None no team-scoped
    condition is added.

    Args:
        db: Active database session.
        server_id: Virtual server identifier used to scope the tool listing.
        include_inactive: Whether disabled tools should be included.
        user_email: Requester email for owner-scoped visibility checks.
        token_teams: Team scope from the caller token. An empty list restricts
            the listing to public tools only.

    Returns:
        A list of dictionaries with keys "name", "description", "inputSchema"
        and "annotations", plus "outputSchema" when the tool has one.
    """
    span_attributes = {
        "server_id": server_id,
        "include_inactive": include_inactive,
        "user.email": user_email,
        "team.scope": format_trace_team_scope(token_teams) if token_teams is not None else None,
        "mcp.definition_mode": True,
    }

    def to_definition(row):
        # Missing schema/annotations are replaced by empty defaults
        definition: Dict[str, Any] = {
            "name": row["name"],
            "description": row["description"],
            "inputSchema": row["input_schema"] or {"type": "object", "properties": {}},
            "annotations": row["annotations"] or {},
        }
        # outputSchema is only emitted when one is stored
        if row["output_schema"] is not None:
            definition["outputSchema"] = row["output_schema"]
        return definition

    with create_span("tool.list", span_attributes):
        # The name is read from the table column directly
        name_column = DbTool.__table__.c.name
        selected_columns = (
            name_column.label("name"),
            DbTool.description.label("description"),
            DbTool.input_schema.label("input_schema"),
            DbTool.output_schema.label("output_schema"),
            DbTool.annotations.label("annotations"),
            DbTool.owner_email.label("owner_email"),
            DbTool.team_id.label("team_id"),
            DbTool.visibility.label("visibility"),
        )
        query = select(*selected_columns).join(server_tool_association, DbTool.id == server_tool_association.c.tool_id).where(server_tool_association.c.server_id == server_id)

        if not include_inactive:
            query = query.where(DbTool.enabled)

        if user_email is not None or token_teams is not None:
            team_ids = [] if token_teams is None else token_teams
            # A token with an empty team list may only see public tools
            is_public_only_token = token_teams is not None and len(token_teams) == 0

            access_conditions = [DbTool.visibility == "public"]
            if user_email and not is_public_only_token:
                access_conditions.append(DbTool.owner_email == user_email)
            if team_ids:
                team_scoped = and_(DbTool.team_id.in_(team_ids), DbTool.visibility.in_(["team", "public"]))
                access_conditions.append(team_scoped)
            query = query.where(or_(*access_conditions))

        rows = db.execute(query).mappings().all()
        db.commit()

        result: List[Dict[str, Any]] = []
        for row in rows:
            result.append(to_definition(row))

        return result
async def list_tools_for_user(
    self,
    db: Session,
    user_email: str,
    team_id: Optional[str] = None,
    visibility: Optional[str] = None,
    include_inactive: bool = False,
    _skip: int = 0,
    _limit: int = 100,
    *,
    cursor: Optional[str] = None,
    gateway_id: Optional[str] = None,
    tags: Optional[List[str]] = None,
    limit: Optional[int] = None,
) -> tuple[List[ToolRead], Optional[str]]:
    """
    DEPRECATED: Use list_tools() with user_email parameter instead.

    List tools user has access to with team filtering and cursor pagination.

    This method is maintained for backward compatibility but is no longer used.
    New code should call list_tools() with user_email, team_id, and visibility parameters.

    Pagination is keyset-based on the tool ID: rows are ordered by ascending
    ID and a cursor resumes after the ID it carries. An invalid cursor is
    logged and ignored (listing restarts from the beginning). Tools that fail
    conversion to ToolRead are logged and left out of the result.

    Args:
        db: Database session
        user_email: Email of the user requesting tools
        team_id: Optional team ID to filter by specific team. If the user is not a
            member of that team an empty result is returned.
        visibility: Optional visibility filter (private, team, public)
        include_inactive: Whether to include inactive tools
        _skip: Unused (deprecated)
        _limit: Unused (deprecated)
        cursor: Opaque cursor token for pagination
        gateway_id: Filter tools by gateway ID. Accepts literal 'null' for NULL gateway_id.
        tags: Filter tools by tags (match any)
        limit: Maximum number of tools to return. Use 0 for all tools (no limit).
            If not specified, uses pagination_default_page_size. Positive values are
            capped at pagination_max_page_size.

    Returns:
        tuple[List[ToolRead], Optional[str]]: Tools the user has access to and optional next_cursor
    """
    # limit=None: use default, limit=0: no limit (all), limit>0: use specified (capped)
    if limit is None:
        page_size = settings.pagination_default_page_size
    elif limit == 0:
        page_size = None  # No limit - fetch all
    else:
        page_size = min(limit, settings.pagination_max_page_size)

    # Decode cursor to get last_id if provided
    last_id = None
    if cursor:
        try:
            cursor_data = decode_cursor(cursor)
            last_id = cursor_data.get("id")
            logger.debug(f"Decoded cursor: last_id={last_id}")
        except ValueError as e:
            logger.warning(f"Invalid cursor, ignoring: {e}")

    # Resolve the teams the user belongs to
    team_service = TeamManagementService(db)
    user_teams = await team_service.get_user_teams(user_email)
    team_ids = [team.id for team in user_teams]

    # Eager load gateway and email_team to avoid N+1 when accessing gateway_slug and team name
    query = select(DbTool).options(joinedload(DbTool.gateway), joinedload(DbTool.email_team))

    # Apply active/inactive filter
    if not include_inactive:
        query = query.where(DbTool.enabled.is_(True))

    if team_id:
        if team_id not in team_ids:
            return ([], None)  # No access to team

        # Within one team: anything shared with the team, plus the user's own tools
        access_conditions = [
            and_(DbTool.team_id == team_id, DbTool.visibility.in_(["team", "public"])),
            and_(DbTool.team_id == team_id, DbTool.owner_email == user_email),
        ]
        query = query.where(or_(*access_conditions))
    else:
        # Across teams: own tools, public tools, and tools shared with any of the user's teams
        access_conditions = [
            DbTool.owner_email == user_email,
            DbTool.visibility == "public",
        ]
        if team_ids:
            access_conditions.append(and_(DbTool.team_id.in_(team_ids), DbTool.visibility.in_(["team", "public"])))

        query = query.where(or_(*access_conditions))

    # Apply visibility filter if specified
    if visibility:
        query = query.where(DbTool.visibility == visibility)

    if gateway_id:
        # The literal string 'null' selects tools without a gateway
        if gateway_id.lower() == "null":
            query = query.where(DbTool.gateway_id.is_(None))
        else:
            query = query.where(DbTool.gateway_id == gateway_id)

    if tags:
        query = query.where(json_contains_tag_expr(db, DbTool.tags, tags, match_any=True))

    # Apply cursor filter (WHERE id > last_id)
    if last_id:
        query = query.where(DbTool.id > last_id)

    # Keyset pagination needs a stable order on the cursor column: without an
    # ORDER BY the database may return rows in any order, so "id > last_id"
    # combined with LIMIT could skip or repeat tools between pages.
    query = query.order_by(DbTool.id)

    # Fetch one extra row so the presence of a further page can be detected
    if page_size is not None:
        tools = db.execute(query.limit(page_size + 1)).scalars().all()
    else:
        tools = db.execute(query).scalars().all()

    db.commit()  # Release transaction to avoid idle-in-transaction

    # Check if there are more results (only when paginating)
    has_more = page_size is not None and len(tools) > page_size
    if has_more:
        tools = tools[:page_size]

    # Convert to ToolRead objects; a row that fails conversion is logged and dropped
    result = []
    for tool in tools:
        try:
            result.append(self.convert_tool_to_read(tool, include_metrics=False, include_auth=False, requesting_user_email=user_email, requesting_user_is_admin=False))
        except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
            logger.exception(f"Failed to convert tool {getattr(tool, 'id', 'unknown')} ({getattr(tool, 'name', 'unknown')}): {e}")

    # The next cursor points at the last row of this page, even if its conversion failed
    next_cursor = None
    if has_more and tools:
        last_tool = tools[-1]
        next_cursor = encode_cursor({"created_at": last_tool.created_at.isoformat(), "id": last_tool.id})

    return (result, next_cursor)
async def get_tool(
    self,
    db: Session,
    tool_id: str,
    requesting_user_email: Optional[str] = None,
    requesting_user_is_admin: bool = False,
    requesting_user_team_roles: Optional[Dict[str, str]] = None,
) -> ToolRead:
    """
    Fetch a single tool by primary key and return it as a ToolRead.

    A structured "tool_viewed" log entry is written for every successful lookup.

    Args:
        db (Session): The SQLAlchemy database session.
        tool_id (str): The unique identifier of the tool.
        requesting_user_email (Optional[str]): Email of the requesting user, forwarded to the converter.
        requesting_user_is_admin (bool): Whether the requester is an admin, forwarded to the converter.
        requesting_user_team_roles (Optional[Dict[str, str]]): {team_id: role} for the requester, forwarded to the converter.

    Returns:
        ToolRead: The tool object.

    Raises:
        ToolNotFoundError: If the tool is not found.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from unittest.mock import MagicMock
        >>> service = ToolService()
        >>> db = MagicMock()
        >>> tool = MagicMock()
        >>> db.get.return_value = tool
        >>> service.convert_tool_to_read = MagicMock(return_value='tool_read')
        >>> import asyncio
        >>> asyncio.run(service.get_tool(db, 'tool_id'))
        'tool_read'
    """
    db_tool = db.get(DbTool, tool_id)
    if not db_tool:
        raise ToolNotFoundError(f"Tool not found: {tool_id}")

    converted = self.convert_tool_to_read(
        db_tool,
        requesting_user_email=requesting_user_email,
        requesting_user_is_admin=requesting_user_is_admin,
        requesting_user_team_roles=requesting_user_team_roles,
    )

    # "include_metrics" records whether the converted object carries a truthy metrics attribute
    log_fields = {
        "tool_name": db_tool.name,
        "include_metrics": bool(getattr(converted, "metrics", {})),
    }
    structured_logger.log(
        level="INFO",
        message="Tool retrieved successfully",
        event_type="tool_viewed",
        component="tool_service",
        team_id=getattr(db_tool, "team_id", None),
        resource_type="tool",
        resource_id=str(db_tool.id),
        custom_fields=log_fields,
    )

    return converted
async def delete_tool(self, db: Session, tool_id: str, user_email: Optional[str] = None, purge_metrics: bool = False) -> None:
    """
    Delete a tool by its ID.

    Every field needed after the row is gone (id, name, team, gateway) is
    copied into plain locals before the DELETE runs, so the audit, logging
    and cache-invalidation steps never read the deleted ORM instance.

    Args:
        db (Session): The SQLAlchemy database session.
        tool_id (str): The unique identifier of the tool.
        user_email (Optional[str]): Email of user performing delete (for ownership check).
        purge_metrics (bool): If True, delete raw + rollup metrics for this tool.

    Raises:
        ToolNotFoundError: If the tool is not found. It is raised inside the guarded
            section, so the generic handler below re-raises it wrapped in ToolError.
        PermissionError: If user doesn't own the tool.
        ToolError: For other deletion errors.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from unittest.mock import MagicMock, AsyncMock
        >>> service = ToolService()
        >>> db = MagicMock()
        >>> tool = MagicMock()
        >>> db.get.return_value = tool
        >>> db.delete = MagicMock()
        >>> db.commit = MagicMock()
        >>> service._notify_tool_deleted = AsyncMock()
        >>> import asyncio
        >>> asyncio.run(service.delete_tool(db, 'tool_id'))
    """
    try:
        tool = db.get(DbTool, tool_id)
        if not tool:
            raise ToolNotFoundError(f"Tool not found: {tool_id}")

        # Check ownership if user_email provided
        if user_email:
            # First-Party
            from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

            permission_service = PermissionService(db)
            if not await permission_service.check_resource_ownership(user_email, tool):
                raise PermissionError("Only the owner can delete this tool")

        # Snapshot everything that is needed after the row is deleted.
        tool_info = {"id": tool.id, "name": tool.name}
        tool_name = tool.name
        tool_team_id = tool.team_id
        tool_gateway_id = str(tool.gateway_id) if tool.gateway_id else None

        if purge_metrics:
            with pause_rollup_during_purge(reason=f"purge_tool:{tool_id}"):
                delete_metrics_in_batches(db, ToolMetric, ToolMetric.tool_id, tool_id)
                delete_metrics_in_batches(db, ToolMetricsHourly, ToolMetricsHourly.tool_id, tool_id)

        # Use DELETE with rowcount check for database-agnostic atomic delete
        stmt = delete(DbTool).where(DbTool.id == tool_id)
        result = db.execute(stmt)
        if result.rowcount == 0:
            # Tool was already deleted by another concurrent request
            raise ToolNotFoundError(f"Tool not found: {tool_id}")

        db.commit()
        await self._notify_tool_deleted(tool_info)
        logger.info(f"Permanently deleted tool: {tool_info['name']}")

        # Structured logging: Audit trail for tool deletion
        audit_trail.log_action(
            user_id=user_email or "system",
            action="delete_tool",
            resource_type="tool",
            resource_id=tool_info["id"],
            resource_name=tool_name,
            user_email=user_email,
            team_id=tool_team_id,
            old_values={
                "name": tool_name,
            },
            db=db,
        )

        # Structured logging: Log successful tool deletion
        structured_logger.log(
            level="INFO",
            message="Tool deleted successfully",
            event_type="tool_deleted",
            component="tool_service",
            user_email=user_email,
            team_id=tool_team_id,
            resource_type="tool",
            resource_id=tool_info["id"],
            custom_fields={
                "tool_name": tool_name,
                "purge_metrics": purge_metrics,
            },
        )

        # Invalidate cache after successful deletion
        cache = _get_registry_cache()
        await cache.invalidate_tools()
        tool_lookup_cache = _get_tool_lookup_cache()
        # Uses the snapshot taken before the DELETE, not the (now deleted) ORM instance.
        await tool_lookup_cache.invalidate(tool_name, gateway_id=tool_gateway_id)
        # Also invalidate tags cache since tool tags may have changed
        # First-Party
        from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

        await admin_stats_cache.invalidate_tags()
        # Invalidate top performers cache
        # First-Party
        from mcpgateway.cache.metrics_cache import metrics_cache  # pylint: disable=import-outside-toplevel

        metrics_cache.invalidate_prefix("top_tools:")
        metrics_cache.invalidate("tools")
    except PermissionError as pe:
        db.rollback()

        # Structured logging: Log permission error
        structured_logger.log(
            level="WARNING",
            message="Tool deletion failed due to permission error",
            event_type="tool_delete_permission_denied",
            component="tool_service",
            user_email=user_email,
            resource_type="tool",
            resource_id=tool_id,
            error=pe,
        )
        raise
    except Exception as e:
        db.rollback()

        # Structured logging: Log generic tool deletion failure
        structured_logger.log(
            level="ERROR",
            message="Tool deletion failed",
            event_type="tool_deletion_failed",
            component="tool_service",
            user_email=user_email,
            resource_type="tool",
            resource_id=tool_id,
            error=e,
        )
        # NOTE(review): ToolNotFoundError raised above also lands here and is re-raised
        # as ToolError, unlike set_tool_state, which lets it propagate unchanged.
        # Left as-is because callers may depend on it - confirm.
        raise ToolError(f"Failed to delete tool: {str(e)}") from e
2677 async def set_tool_state(self, db: Session, tool_id: str, activate: bool, reachable: bool, user_email: Optional[str] = None, skip_cache_invalidation: bool = False) -> ToolRead:
2678 """
2679 Set the activation status of a tool.
2681 Args:
2682 db (Session): The SQLAlchemy database session.
2683 tool_id (str): The unique identifier of the tool.
2684 activate (bool): True to activate, False to deactivate.
2685 reachable (bool): True if the tool is reachable.
2686 user_email: Optional[str] The email of the user to check if the user has permission to modify.
2687 skip_cache_invalidation: If True, skip cache invalidation (used for batch operations).
2689 Returns:
2690 ToolRead: The updated tool object.
2692 Raises:
2693 ToolNotFoundError: If the tool is not found.
2694 ToolLockConflictError: If the tool row is locked by another transaction.
2695 ToolError: For other errors.
2696 PermissionError: If user doesn't own the agent.
2698 Examples:
2699 >>> from mcpgateway.services.tool_service import ToolService
2700 >>> from unittest.mock import MagicMock, AsyncMock
2701 >>> from mcpgateway.schemas import ToolRead
2702 >>> service = ToolService()
2703 >>> db = MagicMock()
2704 >>> tool = MagicMock()
2705 >>> db.get.return_value = tool
2706 >>> db.commit = MagicMock()
2707 >>> db.refresh = MagicMock()
2708 >>> service._notify_tool_activated = AsyncMock()
2709 >>> service._notify_tool_deactivated = AsyncMock()
2710 >>> service.convert_tool_to_read = MagicMock(return_value='tool_read')
2711 >>> ToolRead.model_validate = MagicMock(return_value='tool_read')
2712 >>> import asyncio
2713 >>> asyncio.run(service.set_tool_state(db, 'tool_id', True, True))
2714 'tool_read'
2715 """
# NOTE(review): the PermissionError entry in the docstring above says "agent"; this
# method operates on tools.
# NOTE(review): this listing has lost its indentation, so it is not visible how many
# of the statements after `if is_activated or is_reachable:` (commit/refresh, cache
# invalidation, notifications, audit and structured logging) belong to that branch.
# Confirm against the real file before relying on what a no-op call does.
2716 try:
2717 # Use nowait=True to fail fast if row is locked, preventing lock contention under high load
2718 try:
2719 tool = get_for_update(db, DbTool, tool_id, nowait=True)
2720 except OperationalError as lock_err:
2721 # Row is locked by another transaction - fail fast with 409
2722 db.rollback()
2723 raise ToolLockConflictError(f"Tool {tool_id} is currently being modified by another request") from lock_err
2724 if not tool:
2725 raise ToolNotFoundError(f"Tool not found: {tool_id}")
# Ownership is only checked when a user_email was supplied.
2727 if user_email:
2728 # First-Party
2729 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel
2731 permission_service = PermissionService(db)
2732 if not await permission_service.check_resource_ownership(user_email, tool):
2733 raise PermissionError("Only the owner can activate the Tool" if activate else "Only the owner can deactivate the Tool")
# Track which of the two flags actually differs from the stored row; only a real
# difference updates the attribute.
2735 is_activated = is_reachable = False
2736 if tool.enabled != activate:
2737 tool.enabled = activate
2738 is_activated = True
2740 if tool.reachable != reachable:
2741 tool.reachable = reachable
2742 is_reachable = True
2744 if is_activated or is_reachable:
2745 tool.updated_at = datetime.now(timezone.utc)
2747 db.commit()
2748 db.refresh(tool)
2750 # Invalidate cache after status change (skip for batch operations)
2751 if not skip_cache_invalidation:
2752 cache = _get_registry_cache()
2753 await cache.invalidate_tools()
2754 tool_lookup_cache = _get_tool_lookup_cache()
2755 await tool_lookup_cache.invalidate(tool.name, gateway_id=str(tool.gateway_id) if tool.gateway_id else None)
# Notification is chosen from the row's state after refresh:
# disabled -> deactivated, enabled but unreachable -> offline, otherwise activated.
2757 if not tool.enabled:
2758 # Inactive
2759 await self._notify_tool_deactivated(tool)
2760 elif tool.enabled and not tool.reachable:
2761 # Offline
2762 await self._notify_tool_offline(tool)
2763 else:
2764 # Active
2765 await self._notify_tool_activated(tool)
# This message is built from the requested `activate` / `reachable` arguments, not
# from tool.enabled / tool.reachable.
2767 logger.info(f"Tool: {tool.name} is {'enabled' if activate else 'disabled'}{' and accessible' if reachable else ' but inaccessible'}")
2769 # Structured logging: Audit trail for tool state change
2770 audit_trail.log_action(
2771 user_id=user_email or "system",
2772 action="set_tool_state",
2773 resource_type="tool",
2774 resource_id=tool.id,
2775 resource_name=tool.name,
2776 user_email=user_email,
2777 team_id=tool.team_id,
2778 new_values={
2779 "enabled": tool.enabled,
2780 "reachable": tool.reachable,
2781 },
2782 context={
2783 "action": "activate" if activate else "deactivate",
2784 },
2785 db=db,
2786 )
2788 # Structured logging: Log successful tool state change
2789 structured_logger.log(
2790 level="INFO",
2791 message=f"Tool {'activated' if activate else 'deactivated'} successfully",
2792 event_type="tool_state_changed",
2793 component="tool_service",
2794 user_email=user_email,
2795 team_id=tool.team_id,
2796 resource_type="tool",
2797 resource_id=tool.id,
2798 custom_fields={
2799 "tool_name": tool.name,
2800 "enabled": tool.enabled,
2801 "reachable": tool.reachable,
2802 },
2803 )
# The read model is built with the tool's own owner_email as the requesting user
# (None when the attribute is absent).
2805 return self.convert_tool_to_read(tool, requesting_user_email=getattr(tool, "owner_email", None))
# NOTE(review): unlike delete_tool's PermissionError handler, this one does not call
# db.rollback() before re-raising. Presumably the row lock taken by get_for_update
# then stays held until the session is closed elsewhere - confirm. `raise e` could
# also be a bare `raise`.
2806 except PermissionError as e:
2807 # Structured logging: Log permission error
2808 structured_logger.log(
2809 level="WARNING",
2810 message="Tool state change failed due to permission error",
2811 event_type="tool_state_change_permission_denied",
2812 component="tool_service",
2813 user_email=user_email,
2814 resource_type="tool",
2815 resource_id=tool_id,
2816 error=e,
2817 )
2818 raise e
2819 except ToolLockConflictError:
2820 # Re-raise lock conflicts without wrapping - allows 409 response
2821 raise
2822 except ToolNotFoundError:
2823 # Re-raise not found without wrapping - allows 404 response
2824 raise
# Anything else: roll back, log, and surface as ToolError. The original exception is
# only embedded in the message text; it is not chained with `from`.
2825 except Exception as e:
2826 db.rollback()
2828 # Structured logging: Log generic tool state change failure
2829 structured_logger.log(
2830 level="ERROR",
2831 message="Tool state change failed",
2832 event_type="tool_state_change_failed",
2833 component="tool_service",
2834 user_email=user_email,
2835 resource_type="tool",
2836 resource_id=tool_id,
2837 error=e,
2838 )
2839 raise ToolError(f"Failed to set tool state: {str(e)}")
async def invoke_tool_direct(
    self,
    gateway_id: str,
    name: str,
    arguments: Dict[str, Any],
    request_headers: Optional[Dict[str, str]] = None,
    meta_data: Optional[Dict[str, Any]] = None,
    user_email: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
) -> types.CallToolResult:
    """
    Invoke a tool directly on a remote MCP gateway in direct_proxy mode.

    This bypasses all gateway processing (caching, plugins, validation) and forwards
    the tool call directly to the remote MCP server, returning the raw result.

    The gateway lookup, access check, header preparation and tool-name resolution
    run inside a short-lived database session. Everything the remote call needs is
    copied into plain locals there, so the network round-trip itself does not
    touch ORM objects.

    Args:
        gateway_id: Gateway ID to invoke the tool on.
        name: Name of tool to invoke.
        arguments: Tool arguments.
        request_headers: Headers from the request to pass through.
        meta_data: Optional metadata dictionary for additional context (e.g., request ID).
        user_email: Email of the requesting user for access control.
        token_teams: Team IDs from the user's token for access control.

    Returns:
        CallToolResult from the remote MCP server (as-is, no normalization).

    Raises:
        ToolNotFoundError: If gateway not found or access denied.
        ToolInvocationError: If the gateway is not in direct_proxy mode or invocation fails.
    """
    # Both values originate from the caller, so neither goes into a log line raw.
    safe_name = SecurityValidator.sanitize_log_message(name)
    logger.info(f"Direct proxy tool invocation: {safe_name} via gateway {SecurityValidator.sanitize_log_message(gateway_id)}")

    # Look up gateway
    # Use a fresh session for this lookup
    with fresh_db_session() as db:
        gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
        if not gateway:
            raise ToolNotFoundError(f"Gateway {gateway_id} not found")

        if getattr(gateway, "gateway_mode", "cache") != "direct_proxy" or not settings.mcpgateway_direct_proxy_enabled:
            raise ToolInvocationError(f"Gateway {gateway_id} is not in direct_proxy mode")

        # SECURITY: Defensive access check - callers should also check,
        # but enforce here to prevent RBAC bypass if called from a new context.
        if not await check_gateway_access(db, gateway, user_email, token_teams):
            raise ToolNotFoundError(f"Tool not found: {name}")

        # Prepare headers with gateway auth
        headers = build_gateway_auth_headers(gateway)

        # Forward passthrough headers if configured
        if gateway.passthrough_headers and request_headers:
            for header_name in gateway.passthrough_headers:
                header_value = request_headers.get(header_name.lower()) or request_headers.get(header_name)
                if header_value:
                    headers[header_name] = header_value

        # Copy scalar values out of the ORM object while the session is open.
        gateway_url = gateway.url
        resolved_gateway_id = str(gateway.id)
        safe_gateway_id = SecurityValidator.sanitize_log_message(gateway.id)

        # Resolve the original (unprefixed) tool name for the remote server.
        # Tools registered via gateways are stored as "{gateway_slug}{separator}{slugified_name}",
        # but the remote server only knows the original name (e.g. "get_system_time" not "get-system-time").
        # Look up the tool's original_name from the DB; fall back to the prefixed name if not found
        # (e.g. when calling a tool that exists on the remote but hasn't been cached locally).
        remote_name = name
        tool_row = db.execute(select(DbTool).where(DbTool.name == name, DbTool.gateway_id == gateway_id)).scalar_one_or_none()
        if tool_row and tool_row.original_name:
            remote_name = tool_row.original_name
        else:
            # Fallback: strip the slug prefix (best-effort for tools not yet in DB)
            gateway_slug = getattr(gateway, "slug", None) or ""
            if gateway_slug:
                # removeprefix leaves the name untouched when the prefix is absent.
                remote_name = name.removeprefix(f"{gateway_slug}{settings.gateway_tool_name_separator}")

    # Use MCP SDK to connect and call tool
    try:
        # Parsed once; kept inside the try so a malformed URL/port is still
        # reported as ToolInvocationError.
        parsed_url = urlparse(gateway_url)
        with create_span(
            "mcp.client.call",
            {
                "mcp.tool.name": remote_name,
                "contextforge.gateway_id": resolved_gateway_id,
                "contextforge.runtime": "python",
                "contextforge.transport": "streamablehttp",
                "network.protocol.name": "mcp",
                "server.address": parsed_url.hostname,
                "server.port": parsed_url.port,
                "url.path": parsed_url.path or "/",
                "url.full": sanitize_url_for_logging(gateway_url, {}),
            },
        ):
            traced_headers = inject_trace_context_headers(headers)
            async with streamablehttp_client(url=gateway_url, headers=traced_headers, timeout=settings.mcpgateway_direct_proxy_timeout) as (read_stream, write_stream, _get_session_id):
                async with ClientSession(read_stream, write_stream) as session:
                    with create_span("mcp.client.initialize", {"contextforge.transport": "streamablehttp", "contextforge.runtime": "python"}):
                        await session.initialize()

                    with create_span(
                        "mcp.client.request",
                        {
                            "mcp.tool.name": remote_name,
                            "contextforge.gateway_id": resolved_gateway_id,
                            "contextforge.runtime": "python",
                        },
                    ):
                        # Call tool with meta if provided
                        if meta_data:
                            logger.debug(f"Forwarding _meta to remote gateway: {meta_data}")
                            tool_result = await session.call_tool(name=remote_name, arguments=arguments, meta=meta_data)
                        else:
                            tool_result = await session.call_tool(name=remote_name, arguments=arguments)
                    # Empty span that only records whether the upstream reported an error.
                    with create_span(
                        "mcp.client.response",
                        {
                            "mcp.tool.name": remote_name,
                            "contextforge.gateway_id": resolved_gateway_id,
                            "contextforge.runtime": "python",
                            "upstream.response.success": not getattr(tool_result, "is_error", False) and not getattr(tool_result, "isError", False),
                        },
                    ):
                        pass

                    logger.info(f"[INVOKE TOOL] Using direct_proxy mode for gateway {safe_gateway_id} (from X-Context-Forge-Gateway-Id header). Meta Attached: {meta_data is not None}")
                    return tool_result
    except Exception as e:
        logger.exception(f"Direct proxy tool invocation failed for {safe_name}: {e}")
        raise ToolInvocationError(f"Direct proxy tool invocation failed: {str(e)}") from e
2973 async def prepare_rust_mcp_tool_execution(
2974 self,
2975 db: Session,
2976 name: str,
2977 arguments: Optional[Dict[str, Any]] = None,
2978 request_headers: Optional[Dict[str, str]] = None,
2979 app_user_email: Optional[str] = None,
2980 user_email: Optional[str] = None,
2981 token_teams: Optional[List[str]] = None,
2982 server_id: Optional[str] = None,
2983 plugin_global_context: Optional[GlobalContext] = None,
2984 plugin_context_table: Optional[PluginContextTable] = None,
2985 ) -> Dict[str, Any]:
2986 """Build a narrow MCP execution plan for the Rust runtime hot path.
2988 This reuses Python's existing auth, scoping, and secret-handling logic,
2989 but stops before the actual upstream MCP call. The Rust runtime can then
2990 execute the call directly for the simple streamable HTTP MCP cases that
2991 dominate load tests, while Python remains the authority for policy.
2993 When tool_pre_invoke hooks are registered, they are executed during plan
2994 resolution and their modifications (cleaned args, injected headers) are
2995 returned in the plan for the Rust runtime to apply.
2997 Args:
2998 db: Active database session.
2999 name: Tool name requested by the caller.
3000 arguments: Tool call arguments from the JSON-RPC params (passed to pre-invoke hooks).
3001 request_headers: Incoming request headers used for passthrough/auth decisions.
3002 app_user_email: OAuth application user email, when present.
3003 user_email: Effective requester email after auth normalization.
3004 token_teams: Normalized team scope from the caller token.
3005 server_id: Optional virtual server identifier restricting tool access.
3006 plugin_global_context: Optional global context from middleware for hook continuity.
3007 plugin_context_table: Optional context table from prior hooks for state sharing.
3009 Returns:
3010 A Rust execution plan dictionary, or a fallback descriptor when direct
3011 Rust execution is not eligible.
3013 Raises:
3014 ToolNotFoundError: If the requested tool is not visible or invocable.
3015 ToolInvocationError: If gateway auth preparation fails or the tool name is ambiguous.
3016 """
# NOTE(review): this listing has lost its indentation; the comments below describe
# each step but cannot pin down every nesting level (for example whether the
# "modifiedArgs" assignment at the end sits under `if has_pre_invoke:`). Confirm
# against the real file before restructuring.
# Both flags are falsy when self._plugin_manager is falsy.
3017 has_pre_invoke = self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_PRE_INVOKE)
3018 has_post_invoke = self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE)
3020 gateway_id_from_header = extract_gateway_id_from_headers(request_headers)
3021 is_direct_proxy = False
3022 tool = None
3023 gateway = None
3024 tool_selected_from_server_scope = False
3025 tool_payload: Dict[str, Any] = {}
3026 gateway_payload: Optional[Dict[str, Any]] = None
# Step 1: a gateway id taken from the request headers may select direct-proxy mode.
# That only happens when the gateway exists, its mode is "direct_proxy" and the
# feature is enabled in settings; access is checked and synthetic gateway/tool
# payloads are built without a tool lookup.
3027 if gateway_id_from_header:
3028 gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id_from_header)).scalar_one_or_none()
3029 if gateway and gateway.gateway_mode == "direct_proxy" and settings.mcpgateway_direct_proxy_enabled:
3030 if not await check_gateway_access(db, gateway, user_email, token_teams):
3031 raise ToolNotFoundError(f"Tool not found: {name}")
3032 is_direct_proxy = True
3033 gateway_payload = {
3034 "id": str(gateway.id),
3035 "name": gateway.name,
3036 "url": gateway.url,
3037 "auth_type": gateway.auth_type,
3038 "auth_value": encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value,
3039 "auth_query_params": gateway.auth_query_params,
3040 "oauth_config": gateway.oauth_config,
3041 "ca_certificate": gateway.ca_certificate,
3042 "ca_certificate_sig": gateway.ca_certificate_sig,
3043 "passthrough_headers": gateway.passthrough_headers,
3044 "gateway_mode": gateway.gateway_mode,
3045 }
3046 tool_payload = {
3047 "id": None,
3048 "name": name,
3049 "original_name": name,
3050 "enabled": True,
3051 "reachable": True,
3052 "integration_type": "MCP",
3053 "request_type": "streamablehttp",
3054 "gateway_id": str(gateway.id),
3055 }
# Step 2: otherwise consult the tool lookup cache (only when it is enabled). A cached
# "missing", "inactive" or "offline" status raises ToolNotFoundError straight away.
3057 if not is_direct_proxy:
3058 tool_lookup_cache = _get_tool_lookup_cache()
3059 cached_payload = await tool_lookup_cache.get(name) if tool_lookup_cache.enabled else None
3061 if cached_payload:
3062 status = cached_payload.get("status", "active")
3063 if status == "missing":
3064 raise ToolNotFoundError(f"Tool not found: {name}")
3065 if status == "inactive":
3066 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive")
3067 if status == "offline":
3068 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.")
3069 tool_payload = cached_payload.get("tool") or {}
3070 gateway_payload = cached_payload.get("gateway")
# Step 3: no usable payload yet, so load candidate rows via _load_invocable_tools.
# tool_selected_from_server_scope records whether a server_id was passed to that
# loader; it later decides whether the separate server-association check runs.
3072 if not tool_payload:
3073 tools = self._load_invocable_tools(db, name, server_id=server_id)
3074 tool_selected_from_server_scope = bool(server_id)
3076 if not tools:
3077 raise ToolNotFoundError(f"Tool not found: {name}")
3079 multiple_found = len(tools) > 1
3080 if not multiple_found:
3081 tool = tools[0]
3082 else:
# Several rows share the name: keep only those the requester may access and rank them
# by visibility. Lower number is preferred (team, then private, then public); any
# other visibility value gets 99. A tie at the best rank is reported as ambiguous.
3083 visibility_priority = {"team": 0, "private": 1, "public": 2}
3084 accessible_tools: list[tuple[int, Any]] = []
3085 for candidate in tools:
3086 tool_dict = {"visibility": candidate.visibility, "team_id": candidate.team_id, "owner_email": candidate.owner_email}
3087 if await self._check_tool_access(db, tool_dict, user_email, token_teams):
3088 priority = visibility_priority.get(candidate.visibility, 99)
3089 accessible_tools.append((priority, candidate))
3091 if not accessible_tools:
3092 raise ToolNotFoundError(f"Tool not found: {name}")
3094 accessible_tools.sort(key=lambda item: item[0])
3095 best_priority = accessible_tools[0][0]
3096 best_tools = [candidate for priority, candidate in accessible_tools if priority == best_priority]
3097 if len(best_tools) > 1:
3098 raise ToolInvocationError(f"Multiple tools found with name '{name}' at same priority level. Tool name is ambiguous.")
3099 tool = best_tools[0]
# A disabled tool raises without touching the cache; an unreachable tool additionally
# records a negative "offline" cache entry before raising.
3101 if not tool.enabled:
3102 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive")
3104 if not tool.reachable:
3105 await tool_lookup_cache.set_negative(name, "offline")
3106 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.")
3108 gateway = tool.gateway
3109 cache_payload = self._build_tool_cache_payload(tool, gateway)
3110 tool_payload = cache_payload.get("tool") or {}
3111 gateway_payload = cache_payload.get("gateway")
# The positive cache entry is written only when exactly one row matched the name.
3112 if not multiple_found:
3113 await tool_lookup_cache.set(name, cache_payload, gateway_id=tool_payload.get("gateway_id"))
# Re-check the flags on the payload itself; this is the check that applies to
# payloads that came from the cache.
3115 if tool_payload.get("enabled") is False:
3116 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive")
3117 if tool_payload.get("reachable") is False:
3118 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.")
# Direct-proxy requests are never planned for Rust execution here.
3120 if is_direct_proxy:
3121 return {"eligible": False, "fallbackReason": "direct-proxy"}
3123 if not await self._check_tool_access(db, tool_payload, user_email, token_teams):
3124 raise ToolNotFoundError(f"Tool not found: {name}")
# When a server_id was given but the tool did not come from the server-scoped load
# above, verify that the tool is associated with that server.
3126 if server_id and not tool_selected_from_server_scope:
3127 tool_id_for_check = tool_payload.get("id")
3128 if not tool_id_for_check:
3129 raise ToolNotFoundError(f"Tool not found: {name}")
3130 server_match = db.execute(
3131 select(server_tool_association.c.tool_id).where(
3132 server_tool_association.c.server_id == server_id,
3133 server_tool_association.c.tool_id == tool_id_for_check,
3134 )
3135 ).first()
3136 if not server_match:
3137 raise ToolNotFoundError(f"Tool not found: {name}")
# Eligibility gates: only MCP tools, over "streamablehttp" or "sse" ("sse" is assumed
# when request_type is empty), and without a jsonpath filter. Anything else returns a
# fallback descriptor instead of a plan.
3139 tool_integration_type = tool_payload.get("integration_type")
3140 if tool_integration_type != "MCP":
3141 return {"eligible": False, "fallbackReason": f"unsupported-integration:{tool_integration_type or 'unknown'}"}
3143 tool_request_type = tool_payload.get("request_type")
3144 transport = tool_request_type.lower() if tool_request_type else "sse"
3145 if transport not in {"streamablehttp", "sse"}:
3146 return {"eligible": False, "fallbackReason": f"unsupported-transport:{transport}"}
3148 tool_jsonpath_filter = tool_payload.get("jsonpath_filter")
3149 if tool_jsonpath_filter:
3150 return {"eligible": False, "fallbackReason": "jsonpath-filter-configured"}
3152 passthrough_allowed = global_config_cache.get_passthrough_headers(db, settings.default_passthrough_headers)
3154 if tool is not None:
3155 gateway = tool.gateway
3157 tool_name_original = tool_payload.get("original_name") or tool_payload.get("name") or name
3158 tool_id = tool_payload.get("id")
3159 tool_gateway_id = tool_payload.get("gateway_id")
3160 tool_timeout_ms = tool_payload.get("timeout_ms")
# timeout_ms is divided by 1000 here and the result is multiplied by 1000 again when
# the plan is built, so effective_timeout is in seconds.
3161 effective_timeout = (tool_timeout_ms / 1000) if tool_timeout_ms else settings.tool_timeout
# Gateway fields are read from the payload first; auth_value, auth_query_params and
# oauth_config are accepted only when they have the expected type (str, dict, dict).
3163 has_gateway = gateway_payload is not None
3164 gateway_url = gateway_payload.get("url") if has_gateway else None
3165 gateway_name = gateway_payload.get("name") if has_gateway else None
3166 gateway_auth_type = gateway_payload.get("auth_type") if has_gateway else None
3167 gateway_auth_value = gateway_payload.get("auth_value") if has_gateway and isinstance(gateway_payload.get("auth_value"), str) else None
3168 gateway_auth_query_params = gateway_payload.get("auth_query_params") if has_gateway and isinstance(gateway_payload.get("auth_query_params"), dict) else None
3169 gateway_oauth_config = gateway_payload.get("oauth_config") if has_gateway and isinstance(gateway_payload.get("oauth_config"), dict) else None
# When a gateway object is in hand, its auth fields replace the payload copies
# (a dict auth_value is passed through encode_auth first).
3170 if has_gateway and gateway is not None:
3171 runtime_gateway_auth_value = getattr(gateway, "auth_value", None)
3172 if isinstance(runtime_gateway_auth_value, dict):
3173 gateway_auth_value = encode_auth(runtime_gateway_auth_value)
3174 elif isinstance(runtime_gateway_auth_value, str):
3175 gateway_auth_value = runtime_gateway_auth_value
3176 runtime_gateway_query_params = getattr(gateway, "auth_query_params", None)
3177 if isinstance(runtime_gateway_query_params, dict):
3178 gateway_auth_query_params = runtime_gateway_query_params
3179 runtime_gateway_oauth_config = getattr(gateway, "oauth_config", None)
3180 if isinstance(runtime_gateway_oauth_config, dict):
3181 gateway_oauth_config = runtime_gateway_oauth_config
3182 gateway_ca_cert = gateway_payload.get("ca_certificate") if has_gateway else None
3183 gateway_id_str = gateway_payload.get("id") if has_gateway else None
# `tool is None` here means no tool row was loaded in this call, i.e. the payload came
# from the cache. For auth types that need credentials, the tool and its gateway are
# re-read from the database and the gateway's auth fields are taken from that row -
# presumably because cached payloads do not carry usable secrets; confirm.
3185 if tool is None and has_gateway:
3186 requires_gateway_auth_hydration = gateway_auth_type in {"basic", "bearer", "authheaders", "oauth", "query_param"}
3187 if requires_gateway_auth_hydration:
3188 tool_id_for_hydration = tool_payload.get("id")
3189 if tool_id_for_hydration:
3190 tool_auth_row = db.execute(select(DbTool).options(joinedload(DbTool.gateway)).where(DbTool.id == tool_id_for_hydration)).scalar_one_or_none()
3191 if tool_auth_row and tool_auth_row.gateway:
3192 hydrated_gateway_auth_value = getattr(tool_auth_row.gateway, "auth_value", None)
3193 if isinstance(hydrated_gateway_auth_value, dict):
3194 gateway_auth_value = encode_auth(hydrated_gateway_auth_value)
3195 elif isinstance(hydrated_gateway_auth_value, str):
3196 gateway_auth_value = hydrated_gateway_auth_value
3197 hydrated_gateway_query_params = getattr(tool_auth_row.gateway, "auth_query_params", None)
3198 if isinstance(hydrated_gateway_query_params, dict):
3199 gateway_auth_query_params = hydrated_gateway_query_params
3200 hydrated_gateway_oauth_config = getattr(tool_auth_row.gateway, "oauth_config", None)
3201 if isinstance(hydrated_gateway_oauth_config, dict):
3202 gateway_oauth_config = hydrated_gateway_oauth_config
# Query-parameter auth: each stored value is decoded with decode_auth and the entry
# under the same key is taken ("" when absent). A value that fails to decode is logged
# at debug level and skipped. The decoded parameters are then applied to the URL.
3204 gateway_auth_query_params_decrypted: Optional[Dict[str, str]] = None
3205 if gateway_auth_type == "query_param" and gateway_auth_query_params:
3206 gateway_auth_query_params_decrypted = {}
3207 for param_key, encrypted_value in gateway_auth_query_params.items():
3208 if encrypted_value:
3209 try:
3210 decrypted = decode_auth(encrypted_value)
3211 gateway_auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
3212 except Exception: # noqa: S110
3213 logger.debug(f"Failed to decrypt query param '{param_key}' for Rust MCP tool execution plan")
3214 if gateway_auth_query_params_decrypted and gateway_url:
3215 gateway_url = apply_query_param_auth(gateway_url, gateway_auth_query_params_decrypted)
# Two more fallback gates: a gateway with a custom CA certificate, or no gateway URL.
3217 if gateway_ca_cert:
3218 return {"eligible": False, "fallbackReason": "custom-ca-certificate"}
3220 if not gateway_url:
3221 return {"eligible": False, "fallbackReason": "missing-gateway-url"}
# Base headers. For OAuth gateways with a non-empty config, "authorization_code" uses
# a per-user token looked up through TokenStorageService in a fresh session (requires
# app_user_email); any other grant type asks self.oauth_manager for an access token.
# Every other auth type decodes gateway_auth_value, or uses no headers when it is empty.
3223 if has_gateway and gateway_auth_type == "oauth" and isinstance(gateway_oauth_config, dict) and gateway_oauth_config:
3224 grant_type = gateway_oauth_config.get("grant_type", "client_credentials")
3225 if grant_type == "authorization_code":
3226 try:
3227 # First-Party
3228 from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel
3230 with fresh_db_session() as token_db:
3231 token_storage = TokenStorageService(token_db)
3232 if not app_user_email:
3233 raise ToolInvocationError(f"User authentication required for OAuth-protected gateway '{gateway_name}'. Please ensure you are authenticated.")
3234 access_token = await token_storage.get_user_token(gateway_id_str, app_user_email)
3236 if access_token:
3237 headers = {"Authorization": f"Bearer {access_token}"}
3238 else:
3239 raise ToolInvocationError(f"Please authorize {gateway_name} first. Visit /oauth/authorize/{gateway_id_str} to complete OAuth flow.")
# NOTE(review): the ToolInvocationError instances raised just above are also caught by
# this `except Exception` and re-raised with the "OAuth token retrieval failed" prefix.
3240 except Exception as e:
3241 logger.error(f"Failed to obtain stored OAuth token for gateway {gateway_name}: {e}")
3242 raise ToolInvocationError(f"OAuth token retrieval failed for gateway: {str(e)}")
3243 else:
3244 try:
3245 access_token = await self.oauth_manager.get_access_token(gateway_oauth_config)
3246 headers = {"Authorization": f"Bearer {access_token}"}
3247 except Exception as e:
3248 logger.error(f"Failed to obtain OAuth access token for gateway {gateway_name}: {e}")
3249 raise ToolInvocationError(f"OAuth authentication failed for gateway: {str(e)}")
3250 else:
3251 headers = decode_auth(gateway_auth_value) if gateway_auth_value else {}
3253 if request_headers:
3254 headers = compute_passthrough_headers_cached(
3255 request_headers,
3256 headers,
3257 passthrough_allowed,
3258 gateway_auth_type=gateway_auth_type,
3259 gateway_passthrough_headers=gateway_payload.get("passthrough_headers") if has_gateway else None,
3260 )
# Stringify header names and values and drop entries whose name or value is falsy.
3262 runtime_headers = {str(header_name): str(header_value) for header_name, header_value in headers.items() if header_name and header_value}
3264 hook_global_context = None
3265 if has_pre_invoke or has_post_invoke:
3266 hook_global_context = self._build_rust_tool_hook_global_context(
3267 app_user_email=app_user_email,
3268 server_id=server_id,
3269 tool_gateway_id=tool_gateway_id,
3270 plugin_global_context=plugin_global_context,
3271 tool_payload=tool_payload,
3272 gateway_payload=gateway_payload,
3273 )
# With post-invoke hooks registered, the helper returns a retry policy and a flag;
# when the flag is set the call falls back instead of producing a plan.
3275 native_post_invoke_retry_policy = None
3276 if has_post_invoke:
3277 native_post_invoke_retry_policy, requires_python_fallback = self._build_rust_native_tool_post_invoke_retry_policy(name, hook_global_context)
3278 if requires_python_fallback:
3279 return {"eligible": False, "fallbackReason": "post-invoke-hooks-configured"}
3281 # Run tool_pre_invoke hooks so that plugins (e.g. wxo_connections) can
3282 # inject credentials and clean arguments before the Rust direct call.
# The hooks are invoked only when `arguments` is not None. A hook result may replace
# the arguments, rename the tool (the new name becomes "remoteToolName" in the plan)
# and add headers, which are merged in under lower-cased names.
3283 modified_args = arguments
3284 if has_pre_invoke and arguments is not None:
3285 pre_result, _ = await self._plugin_manager.invoke_hook(
3286 ToolHookType.TOOL_PRE_INVOKE,
3287 payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=dict(runtime_headers))),
3288 global_context=hook_global_context,
3289 local_contexts=plugin_context_table,
3290 violations_as_exceptions=True,
3291 )
3292 if pre_result.modified_payload:
3293 modified_args = pre_result.modified_payload.args
3294 if pre_result.modified_payload.name and pre_result.modified_payload.name != name:
3295 tool_name_original = pre_result.modified_payload.name
3296 if pre_result.modified_payload.headers is not None:
3297 plugin_headers = pre_result.modified_payload.headers.root if hasattr(pre_result.modified_payload.headers, "root") else {}
3298 for hk, hv in plugin_headers.items():
3299 if hk and hv:
3300 runtime_headers[str(hk).lower()] = str(hv)
3302 runtime_headers = inject_trace_context_headers(runtime_headers)
# Final plan. "timeoutMs" converts the seconds value back to integer milliseconds;
# "toolId" is None when the payload has no (or an empty) id.
3304 plan: Dict[str, Any] = {
3305 "eligible": True,
3306 "transport": transport,
3307 "serverUrl": gateway_url,
3308 "remoteToolName": tool_name_original,
3309 "headers": runtime_headers,
3310 "timeoutMs": int(effective_timeout * 1000),
3311 "gatewayId": tool_gateway_id,
3312 "toolName": name,
3313 "toolId": tool_id or None,
3314 "serverId": server_id,
3315 }
3316 if native_post_invoke_retry_policy is not None:
3317 plan["postInvokeRetryPolicy"] = native_post_invoke_retry_policy
3318 if has_pre_invoke:
3319 plan["hasPreInvokeHooks"] = True
3320 if modified_args is not None:
3321 plan["modifiedArgs"] = modified_args
3322 return plan
def _build_rust_tool_hook_global_context(
    self,
    *,
    app_user_email: Optional[str],
    server_id: Optional[str],
    tool_gateway_id: Optional[str],
    plugin_global_context: Optional[GlobalContext],
    tool_payload: Optional[Dict[str, Any]],
    gateway_payload: Optional[Dict[str, Any]],
) -> GlobalContext:
    """Assemble the plugin global context used while resolving a Rust-direct tool plan.

    A context supplied by the middleware is reused and updated in place;
    when none is supplied a new ``GlobalContext`` is created from the current
    correlation id (or a random UUID hex when no correlation id is set).

    Args:
        app_user_email: Effective authenticated user for plugin context.
        server_id: Explicit virtual server scope from the request. Only used
            for a newly created context when no usable gateway id exists.
        tool_gateway_id: Resolved tool gateway id; takes precedence over
            ``server_id`` when it is a non-empty string.
        plugin_global_context: Existing middleware context if available.
        tool_payload: Resolved tool payload, converted to tool metadata.
        gateway_payload: Resolved gateway payload, converted to gateway metadata.

    Returns:
        The reused or newly created GlobalContext, with tool/gateway metadata
        attached when the corresponding payload produced a model.
    """
    # Only a non-empty string gateway id is allowed to override the server scope.
    gateway_scope = tool_gateway_id if tool_gateway_id and isinstance(tool_gateway_id, str) else None

    if not plugin_global_context:
        ctx = GlobalContext(
            request_id=get_correlation_id() or uuid.uuid4().hex,
            server_id=gateway_scope if gateway_scope else server_id,
            tenant_id=None,
            user=app_user_email,
        )
    else:
        ctx = plugin_global_context
        if gateway_scope:
            ctx.server_id = gateway_scope
        # Never overwrite a user that the middleware already recorded.
        if not ctx.user and app_user_email and isinstance(app_user_email, str):
            ctx.user = app_user_email

    # Convert both payloads before touching ctx.metadata so a conversion
    # failure leaves the metadata untouched.
    tool_model = self._pydantic_tool_from_payload(tool_payload) if tool_payload else None
    gateway_model = self._pydantic_gateway_from_payload(gateway_payload) if gateway_payload else None

    if tool_model:
        ctx.metadata[TOOL_METADATA] = tool_model
    if gateway_model:
        ctx.metadata[GATEWAY_METADATA] = gateway_model
    return ctx
def _build_rust_native_tool_post_invoke_retry_policy(
    self,
    tool_name: str,
    hook_global_context: Optional[GlobalContext],
) -> Tuple[Optional[Dict[str, Any]], bool]:
    """Derive a native Rust retry policy from the active post-invoke hooks, if possible.

    The Rust runtime only supports native post-invoke execution for the
    default retry-with-backoff plugin. Any other active `tool_post_invoke`
    hook must still force the call back to Python to preserve plugin semantics.

    Outcomes:
        * no plugin manager, no post-invoke hooks, or no hook active for this
          tool -> ``(None, False)``
        * anything other than exactly one active hook named
          ``RetryWithBackoffPlugin`` -> ``(None, True)``
        * the retry plugin has ``check_text_content`` enabled -> ``(None, True)``
        * otherwise -> ``(policy_dict, False)``

    Args:
        tool_name: Requested tool name.
        hook_global_context: Resolved plugin context for condition matching.

    Returns:
        Tuple of `(policy, requires_python_fallback)`.
    """
    manager = self._plugin_manager
    if not manager or not manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE):
        return (None, False)

    # First-Party
    from mcpgateway.plugins.framework import PluginMode  # pylint: disable=import-outside-toplevel
    from mcpgateway.plugins.framework.utils import payload_matches  # pylint: disable=import-outside-toplevel

    # Third-Party/Local
    from plugins.retry_with_backoff.retry_with_backoff import RetryConfig  # pylint: disable=import-outside-toplevel

    match_context = hook_global_context or GlobalContext(request_id=get_correlation_id() or uuid.uuid4().hex)
    # Placeholder payload with an empty result, used only for condition matching.
    probe_payload = ToolPostInvokePayload(name=tool_name, result={})
    registered_refs = manager._registry.get_hook_refs_for_hook(hook_type=ToolHookType.TOOL_POST_INVOKE)  # pylint: disable=protected-access

    def _applies(candidate_ref) -> bool:
        """Return True when the hook is enabled and its conditions match this tool."""
        plugin_ref = candidate_ref.plugin_ref
        if plugin_ref.mode == PluginMode.DISABLED:
            return False
        if plugin_ref.conditions and not payload_matches(probe_payload, ToolHookType.TOOL_POST_INVOKE, plugin_ref.conditions, match_context):
            return False
        return True

    applicable_refs = [candidate_ref for candidate_ref in registered_refs if _applies(candidate_ref)]

    if not applicable_refs:
        return (None, False)

    # Anything but a lone retry-with-backoff hook must run in Python.
    if not (len(applicable_refs) == 1 and applicable_refs[0].plugin_ref.name == "RetryWithBackoffPlugin"):
        return (None, True)

    sole_ref = applicable_refs[0]
    retry_cfg = RetryConfig(**(sole_ref.plugin_ref.plugin.config.config or {}))
    max_allowed = settings.max_tool_retries

    def _capped(candidate_cfg):
        """Clamp ``max_retries`` to the gateway-wide ceiling."""
        if candidate_cfg.max_retries > max_allowed:
            return candidate_cfg.model_copy(update={"max_retries": max_allowed})
        return candidate_cfg

    retry_cfg = _capped(retry_cfg)

    per_tool = retry_cfg.tool_overrides.get(tool_name)
    if per_tool:
        # Per-tool overrides win over the plugin-level values; the merged
        # config is re-validated and clamped again because an override may
        # raise max_retries above the ceiling.
        merged_values = retry_cfg.model_dump()
        merged_values.update(per_tool)
        merged_values.pop("tool_overrides", None)
        retry_cfg = _capped(RetryConfig(**merged_values))

    if retry_cfg.check_text_content:
        # This setting forces the call back onto the Python path.
        return (None, True)

    native_policy = {
        "kind": "retry_with_backoff",
        "maxRetries": int(retry_cfg.max_retries),
        "backoffBaseMs": int(retry_cfg.backoff_base_ms),
        "maxBackoffMs": int(retry_cfg.max_backoff_ms),
        "retryOnStatus": list(retry_cfg.retry_on_status),
        "jitter": bool(retry_cfg.jitter),
    }
    return (native_policy, False)
def _load_invocable_tools(self, db: Session, name: str, server_id: Optional[str] = None) -> List[DbTool]:
    """Fetch every tool row named ``name``, optionally restricted to one virtual server.

    The gateway relationship is eager-loaded with ``joinedload`` in the same
    query as the tool rows.

    Args:
        db: Active database session.
        name: Tool name to resolve.
        server_id: Optional virtual server identifier; when truthy, only tools
            linked to that server through ``server_tool_association`` are returned.

    Returns:
        A list of candidate tool ORM rows matching the request.
    """
    stmt = select(DbTool).options(joinedload(DbTool.gateway)).where(DbTool.name == name)
    if not server_id:
        return db.execute(stmt).scalars().all()

    # Narrow to tools attached to the requested virtual server.
    association = server_tool_association
    scoped_stmt = stmt.join(association, DbTool.id == association.c.tool_id).where(association.c.server_id == server_id)
    return db.execute(scoped_stmt).scalars().all()
3458 # ------------------------------------------------------------------
3459 # Retry helpers (used by invoke_tool)
3460 # ------------------------------------------------------------------
async def _run_timeout_post_invoke(
    self,
    name: str,
    effective_timeout: float,
    global_context: Any,
    context_table: Any,
) -> None:
    """Run post-invoke plugins for a timed-out call and surface any retry request.

    Called from each transport-specific timeout handler so the retry plugin
    can record the failure and (optionally) request a retry. Every local
    plugin context is first flagged with ``cb_timeout_failure``. If the
    post-invoke hook result carries ``retry_delay_ms > 0``, a
    ``ToolTimeoutError`` carrying that delay is raised; otherwise the method
    returns normally and the caller raises a plain ``ToolTimeoutError``.

    Args:
        name: Tool name.
        effective_timeout: Timeout duration in seconds.
        global_context: Plugin global context for cross-hook state.
        context_table: Plugin local context table for per-plugin state.

    Raises:
        ToolTimeoutError: When the retry plugin requests a delayed retry.
    """
    # Flag the timeout on every local context, even when no plugin manager is configured.
    for plugin_ctx in context_table.values() if context_table else ():
        plugin_ctx.set_state("cb_timeout_failure", True)

    manager = self._plugin_manager
    if not manager or not manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE):
        return

    timeout_message = f"Tool invocation timed out after {effective_timeout}s"
    # Synthetic error result so post-invoke plugins see the timeout as a failed call.
    synthetic_result = ToolResult(content=[TextContent(type="text", text=timeout_message)], is_error=True)
    hook_result, _ = await manager.invoke_hook(
        ToolHookType.TOOL_POST_INVOKE,
        payload=ToolPostInvokePayload(name=name, result=synthetic_result.model_dump(by_alias=True)),
        global_context=global_context,
        local_contexts=context_table,
        violations_as_exceptions=False,
    )
    if hook_result and hook_result.retry_delay_ms > 0:
        raise ToolTimeoutError(timeout_message, retry_delay_ms=hook_result.retry_delay_ms)
async def _retry_tool_invocation(
    self,
    delay_ms: int,
    retry_attempt: int,
    name: str,
    arguments: Dict[str, Any],
    request_headers: Any,
    app_user_email: Optional[str],
    user_email: Optional[str],
    token_teams: Optional[List[str]],
    server_id: Optional[str],
    context_table: Any,
    global_context: Any,
    meta_data: Optional[Dict[str, Any]],
    skip_pre_invoke: bool,
    path_label: str,
) -> "ToolResult":
    """Wait out the plugin-requested backoff, then re-enter ``invoke_tool`` for the next attempt.

    The wait uses ``asyncio.sleep``, so it is cancellation-aware: if the
    calling task is cancelled (e.g. client disconnect) the ``CancelledError``
    propagates immediately instead of wasting time on a retry that nobody
    will consume. The retried call runs on a session obtained from
    ``fresh_db_session()`` and is issued with ``retry_attempt + 1``.

    Args:
        delay_ms: Backoff delay in milliseconds before retrying.
        retry_attempt: Current zero-based retry counter.
        name: Tool name to re-invoke.
        arguments: Tool arguments to forward.
        request_headers: Original request headers.
        app_user_email: ContextForge user email for OAuth.
        user_email: User email for authorization.
        token_teams: Team IDs from JWT token.
        server_id: Virtual server ID for scoping.
        context_table: Plugin local context table.
        global_context: Plugin global context.
        meta_data: Optional metadata dictionary.
        skip_pre_invoke: Whether to skip pre-invoke hooks.
        path_label: Label for log messages (success/timeout/exception).

    Returns:
        ToolResult from the retried invocation.
    """
    next_attempt = retry_attempt + 1
    logger.debug(
        "tool_service: retry requested (%s) for tool=%s attempt=%d/%d delay_ms=%d",
        path_label,
        name,
        next_attempt,
        settings.max_tool_retries,
        delay_ms,
    )

    # asyncio.sleep takes seconds; the delay arrives in milliseconds.
    await asyncio.sleep(delay_ms / 1000)

    with fresh_db_session() as retry_db:
        retried_result = await self.invoke_tool(
            db=retry_db,
            name=name,
            arguments=arguments,
            request_headers=request_headers,
            app_user_email=app_user_email,
            user_email=user_email,
            token_teams=token_teams,
            server_id=server_id,
            plugin_context_table=context_table,
            plugin_global_context=global_context,
            meta_data=meta_data,
            skip_pre_invoke=skip_pre_invoke,
            retry_attempt=next_attempt,
        )
    return retried_result
3573 async def invoke_tool(
3574 self,
3575 db: Session,
3576 name: str,
3577 arguments: Dict[str, Any],
3578 request_headers: Optional[Dict[str, str]] = None,
3579 app_user_email: Optional[str] = None,
3580 user_email: Optional[str] = None,
3581 token_teams: Optional[List[str]] = None,
3582 server_id: Optional[str] = None,
3583 plugin_context_table: Optional[PluginContextTable] = None,
3584 plugin_global_context: Optional[GlobalContext] = None,
3585 meta_data: Optional[Dict[str, Any]] = None,
3586 skip_pre_invoke: bool = False,
3587 retry_attempt: int = 0,
3588 ) -> ToolResult:
3589 """
3590 Invoke a registered tool and record execution metrics.
3592 Args:
3593 db: Database session.
3594 name: Name of tool to invoke.
3595 arguments: Tool arguments.
3596 request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through.
3597 Defaults to None.
3598 app_user_email (Optional[str], optional): ContextForge user email for OAuth token retrieval.
3599 Required for OAuth-protected gateways.
3600 user_email (Optional[str], optional): User email for authorization checks.
3601 None = unauthenticated request.
3602 token_teams (Optional[List[str]], optional): Team IDs from JWT token for authorization.
3603 None = unrestricted admin, [] = public-only, [...] = team-scoped.
3604 server_id (Optional[str], optional): Virtual server ID for server scoping enforcement.
3605 If provided, tool must be attached to this server.
3606 plugin_context_table: Optional plugin context table from previous hooks for cross-hook state sharing.
3607 plugin_global_context: Optional global context from middleware for consistency across hooks.
3608 meta_data: Optional metadata dictionary for additional context (e.g., request ID).
3609 skip_pre_invoke: When True, skip TOOL_PRE_INVOKE hooks (used by trusted Rust fallback path).
3610 retry_attempt: Zero-based retry counter; 0 = original call. Incremented by the retry
3611 loop and compared against ``settings.max_tool_retries``.
3613 Returns:
3614 Tool invocation result.
3616 Raises:
3617 ToolNotFoundError: If tool not found or access denied.
3618 ToolInvocationError: If invocation fails or A2A authentication decryption fails.
3619 ToolTimeoutError: If tool invocation times out.
3620 PluginViolationError: If plugin blocks tool invocation.
3621 PluginError: If encounters issue with plugin.
3623 Examples:
3624 >>> # Note: This method requires extensive mocking of SQLAlchemy models,
3625 >>> # database relationships, and caching infrastructure, which is not
3626 >>> # suitable for doctests. See tests/unit/mcpgateway/services/test_tool_service.py
3627 >>> pass # doctest: +SKIP
3628 """
3629 # pylint: disable=comparison-with-callable
3630 logger.info(f"Invoking tool: {name} with arguments: {arguments.keys() if arguments else None} and headers: {request_headers.keys() if request_headers else None}, server_id={server_id}")
3631 # ═══════════════════════════════════════════════════════════════════════════
3632 # PHASE 1: Check for X-Context-Forge-Gateway-Id header for direct_proxy mode (no DB lookup)
3633 # ═══════════════════════════════════════════════════════════════════════════
3634 gateway_id_from_header = extract_gateway_id_from_headers(request_headers)
3636 # If X-Context-Forge-Gateway-Id header is present, check if gateway is in direct_proxy mode
3637 is_direct_proxy = False
3638 tool = None
3639 gateway = None
3640 tool_payload: Dict[str, Any] = {}
3641 gateway_payload: Optional[Dict[str, Any]] = None
3643 if gateway_id_from_header:
3644 # Look up gateway to check if it's in direct_proxy mode
3645 gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id_from_header)).scalar_one_or_none()
3646 if gateway and gateway.gateway_mode == "direct_proxy" and settings.mcpgateway_direct_proxy_enabled:
3647 # SECURITY: Check gateway access before allowing direct proxy
3648 # This prevents RBAC bypass where any authenticated user could invoke tools
3649 # on any gateway just by knowing the gateway ID
3650 if not await check_gateway_access(db, gateway, user_email, token_teams):
3651 logger.warning(f"Access denied to gateway {gateway_id_from_header} in direct_proxy mode for user {SecurityValidator.sanitize_log_message(user_email)}")
3652 raise ToolNotFoundError(f"Tool not found: {name}")
3654 is_direct_proxy = True
3655 # Build minimal gateway payload for direct proxy (no tool lookup needed)
3656 gateway_payload = {
3657 "id": str(gateway.id),
3658 "name": gateway.name,
3659 "url": gateway.url,
3660 "auth_type": gateway.auth_type,
3661 # DbGateway.auth_value is JSON (dict); downstream code expects an encoded str.
3662 "auth_value": encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value,
3663 "auth_query_params": gateway.auth_query_params,
3664 "oauth_config": gateway.oauth_config,
3665 "ca_certificate": gateway.ca_certificate,
3666 "ca_certificate_sig": gateway.ca_certificate_sig,
3667 "passthrough_headers": gateway.passthrough_headers,
3668 "gateway_mode": gateway.gateway_mode,
3669 }
3670 # Create minimal tool payload for direct proxy (no DB tool needed)
3671 tool_payload = {
3672 "id": None, # No tool ID in direct proxy mode
3673 "name": name,
3674 "original_name": name,
3675 "enabled": True,
3676 "reachable": True,
3677 "integration_type": "MCP",
3678 "request_type": "streamablehttp", # Default to streamablehttp
3679 "gateway_id": str(gateway.id),
3680 }
3681 logger.info(f"Direct proxy mode via X-Context-Forge-Gateway-Id header: passing tool '{name}' directly to remote MCP server at {gateway.url}")
3682 elif gateway:
3683 logger.debug(f"Gateway {gateway_id_from_header} found but not in direct_proxy mode (mode: {gateway.gateway_mode}), using normal lookup")
3684 else:
3685 logger.warning(f"Gateway {gateway_id_from_header} specified in X-Context-Forge-Gateway-Id header not found")
3687 # Normal mode: look up tool in database/cache
3688 if not is_direct_proxy:
3689 tool_lookup_cache = _get_tool_lookup_cache()
3690 cached_payload = await tool_lookup_cache.get(name) if tool_lookup_cache.enabled else None
3692 if cached_payload:
3693 status = cached_payload.get("status", "active")
3694 if status == "missing":
3695 raise ToolNotFoundError(f"Tool not found: {name}")
3696 if status == "inactive":
3697 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive")
3698 if status == "offline":
3699 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.")
3700 tool_payload = cached_payload.get("tool") or {}
3701 gateway_payload = cached_payload.get("gateway")
3703 if not tool_payload:
3704 # Eager load tool WITH gateway in single query to prevent lazy load N+1
3705 # Use a single query to avoid a race between separate enabled/inactive lookups.
3706 # Use scalars().all() instead of scalar_one_or_none() to handle duplicate
3707 # tool names across teams without crashing on MultipleResultsFound.
3708 tools = self._load_invocable_tools(db, name, server_id=server_id)
3710 if not tools:
3711 raise ToolNotFoundError(f"Tool not found: {name}")
3713 multiple_found = len(tools) > 1
3714 if not multiple_found:
3715 tool = tools[0]
3716 else:
3717 # Multiple tools found with same name — filter by access using
3718 # _check_tool_access (same rules as list_tools) and prioritize.
3719 # Priority (lower is better): team (0) > private (1) > public (2)
3720 visibility_priority = {"team": 0, "private": 1, "public": 2}
3721 accessible_tools: list[tuple[int, Any]] = []
3722 for t in tools:
3723 tool_dict = {"visibility": t.visibility, "team_id": t.team_id, "owner_email": t.owner_email}
3724 if await self._check_tool_access(db, tool_dict, user_email, token_teams):
3725 priority = visibility_priority.get(t.visibility, 99)
3726 accessible_tools.append((priority, t))
3728 if not accessible_tools:
3729 raise ToolNotFoundError(f"Tool not found: {name}")
3731 accessible_tools.sort(key=lambda x: x[0])
3733 # Check for ambiguity at the highest priority level
3734 best_priority = accessible_tools[0][0]
3735 best_tools = [t for p, t in accessible_tools if p == best_priority]
3737 if len(best_tools) > 1:
3738 raise ToolInvocationError(f"Multiple tools found with name '{name}' at same priority level. Tool name is ambiguous.")
3740 tool = best_tools[0]
3742 if not tool.enabled:
3743 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive")
3745 if not tool.reachable:
3746 await tool_lookup_cache.set_negative(name, "offline")
3747 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.")
3749 gateway = tool.gateway
3750 cache_payload = self._build_tool_cache_payload(tool, gateway)
3751 tool_payload = cache_payload.get("tool") or {}
3752 gateway_payload = cache_payload.get("gateway")
3753 # Skip caching when multiple tools share a name — resolution is
3754 # user-dependent, so a cached result could be wrong for other users.
3755 if not multiple_found:
3756 await tool_lookup_cache.set(name, cache_payload, gateway_id=tool_payload.get("gateway_id"))
3758 if tool_payload.get("enabled") is False:
3759 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive")
3760 if tool_payload.get("reachable") is False:
3761 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.")
3763 # ═══════════════════════════════════════════════════════════════════════════
3764 # SECURITY: Check tool access based on visibility and team membership
3765 # Skip these checks for direct_proxy mode (no tool in database)
3766 # ═══════════════════════════════════════════════════════════════════════════
3767 if not is_direct_proxy:
3768 if not await self._check_tool_access(db, tool_payload, user_email, token_teams):
3769 # Don't reveal tool existence - return generic "not found"
3770 raise ToolNotFoundError(f"Tool not found: {name}")
3772 # ═══════════════════════════════════════════════════════════════════════════
3773 # SECURITY: Enforce server scoping if server_id is provided
3774 # Tool must be attached to the specified virtual server
3775 # ═══════════════════════════════════════════════════════════════════════════
3776 if server_id:
3777 tool_id_for_check = tool_payload.get("id")
3778 if not tool_id_for_check:
3779 # Cannot verify server membership without tool ID - deny access
3780 # This should not happen with properly cached tools, but fail safe
3781 logger.warning(f"Tool '{name}' has no ID in payload, cannot verify server membership")
3782 raise ToolNotFoundError(f"Tool not found: {name}")
3784 server_match = db.execute(
3785 select(server_tool_association.c.tool_id).where(
3786 server_tool_association.c.server_id == server_id,
3787 server_tool_association.c.tool_id == tool_id_for_check,
3788 )
3789 ).first()
3790 if not server_match:
3791 raise ToolNotFoundError(f"Tool not found: {name}")
3793 # Extract A2A-related data from annotations (will be used after db.close() if A2A tool)
3794 tool_annotations = tool_payload.get("annotations") or {}
3795 tool_integration_type = tool_payload.get("integration_type")
3797 # Get passthrough headers from in-memory cache (Issue #1715)
3798 # This eliminates 42,000+ redundant DB queries under load
3799 passthrough_allowed = global_config_cache.get_passthrough_headers(db, settings.default_passthrough_headers)
3801 # Access gateway now (already eager-loaded) to prevent later lazy load
3802 if tool is not None:
3803 gateway = tool.gateway
3805 # ═══════════════════════════════════════════════════════════════════════════
3806 # PHASE 2: Extract all needed data to local variables before network I/O
3807 # This allows us to release the DB session before making HTTP calls
3808 # ═══════════════════════════════════════════════════════════════════════════
3809 tool_id = tool_payload.get("id") or (str(tool.id) if tool else "")
3810 tool_name_original = tool_payload.get("original_name") or tool_payload.get("name") or name
3811 tool_name_computed = tool_payload.get("name") or name
3812 tool_url = tool_payload.get("url")
3813 tool_integration_type = tool_payload.get("integration_type")
3814 tool_request_type = tool_payload.get("request_type")
3815 tool_headers = _decrypt_tool_headers_for_runtime(tool_payload.get("headers") or {})
3816 tool_auth_type = tool_payload.get("auth_type")
3817 tool_auth_value = tool_payload.get("auth_value")
3818 if tool is not None:
3819 runtime_tool_auth_value = getattr(tool, "auth_value", None)
3820 if isinstance(runtime_tool_auth_value, str):
3821 tool_auth_value = runtime_tool_auth_value
3822 if not isinstance(tool_auth_value, str):
3823 tool_auth_value = None
3824 tool_jsonpath_filter = tool_payload.get("jsonpath_filter")
3825 tool_output_schema = tool_payload.get("output_schema")
3826 tool_oauth_config = tool_payload.get("oauth_config") if isinstance(tool_payload.get("oauth_config"), dict) else None
3827 if tool is not None:
3828 runtime_tool_oauth_config = getattr(tool, "oauth_config", None)
3829 if isinstance(runtime_tool_oauth_config, dict):
3830 tool_oauth_config = runtime_tool_oauth_config
3831 tool_gateway_id = tool_payload.get("gateway_id")
3833 # Get effective timeout: per-tool timeout_ms (in seconds) or global fallback
3834 # timeout_ms is stored in milliseconds, convert to seconds
3835 tool_timeout_ms = tool_payload.get("timeout_ms")
3836 effective_timeout = (tool_timeout_ms / 1000) if tool_timeout_ms else settings.tool_timeout
3838 # Save gateway existence as local boolean BEFORE db.close()
3839 # to avoid checking ORM object truthiness after session is closed
3840 has_gateway = gateway_payload is not None
3841 gateway_url = gateway_payload.get("url") if has_gateway else None
3842 gateway_name = gateway_payload.get("name") if has_gateway else None
3843 gateway_auth_type = gateway_payload.get("auth_type") if has_gateway else None
3844 gateway_auth_value = gateway_payload.get("auth_value") if has_gateway and isinstance(gateway_payload.get("auth_value"), str) else None
3845 gateway_auth_query_params = gateway_payload.get("auth_query_params") if has_gateway and isinstance(gateway_payload.get("auth_query_params"), dict) else None
3846 gateway_oauth_config = gateway_payload.get("oauth_config") if has_gateway and isinstance(gateway_payload.get("oauth_config"), dict) else None
3847 if has_gateway and gateway is not None:
3848 runtime_gateway_auth_value = getattr(gateway, "auth_value", None)
3849 if isinstance(runtime_gateway_auth_value, dict):
3850 gateway_auth_value = encode_auth(runtime_gateway_auth_value)
3851 elif isinstance(runtime_gateway_auth_value, str):
3852 gateway_auth_value = runtime_gateway_auth_value
3853 runtime_gateway_query_params = getattr(gateway, "auth_query_params", None)
3854 if isinstance(runtime_gateway_query_params, dict):
3855 gateway_auth_query_params = runtime_gateway_query_params
3856 runtime_gateway_oauth_config = getattr(gateway, "oauth_config", None)
3857 if isinstance(runtime_gateway_oauth_config, dict):
3858 gateway_oauth_config = runtime_gateway_oauth_config
3859 gateway_ca_cert = gateway_payload.get("ca_certificate") if has_gateway else None
3860 gateway_ca_cert_sig = gateway_payload.get("ca_certificate_sig") if has_gateway else None
3861 gateway_passthrough = gateway_payload.get("passthrough_headers") if has_gateway else None
3862 gateway_id_str = gateway_payload.get("id") if has_gateway else None
3864 # Cache payload intentionally excludes sensitive auth material. For cache hits
3865 # (tool is None), hydrate auth-related fields from DB only when needed.
3866 if tool is None:
3867 requires_tool_auth_hydration = tool_auth_type in {"basic", "bearer", "authheaders", "oauth"}
3868 requires_gateway_auth_hydration = has_gateway and gateway_auth_type in {"basic", "bearer", "authheaders", "oauth", "query_param"}
3869 if requires_tool_auth_hydration or requires_gateway_auth_hydration:
3870 tool_id_for_hydration = tool_payload.get("id")
3871 if tool_id_for_hydration:
3872 tool_auth_row = db.execute(select(DbTool).options(joinedload(DbTool.gateway)).where(DbTool.id == tool_id_for_hydration)).scalar_one_or_none()
3873 if tool_auth_row:
3874 hydrated_tool_auth_value = getattr(tool_auth_row, "auth_value", None)
3875 if isinstance(hydrated_tool_auth_value, str):
3876 tool_auth_value = hydrated_tool_auth_value
3877 hydrated_tool_oauth_config = getattr(tool_auth_row, "oauth_config", None)
3878 if isinstance(hydrated_tool_oauth_config, dict):
3879 tool_oauth_config = hydrated_tool_oauth_config
3880 if has_gateway and tool_auth_row.gateway:
3881 hydrated_gateway_auth_value = getattr(tool_auth_row.gateway, "auth_value", None)
3882 if isinstance(hydrated_gateway_auth_value, dict):
3883 gateway_auth_value = encode_auth(hydrated_gateway_auth_value)
3884 elif isinstance(hydrated_gateway_auth_value, str):
3885 gateway_auth_value = hydrated_gateway_auth_value
3886 hydrated_gateway_query_params = getattr(tool_auth_row.gateway, "auth_query_params", None)
3887 if isinstance(hydrated_gateway_query_params, dict):
3888 gateway_auth_query_params = hydrated_gateway_query_params
3889 hydrated_gateway_oauth_config = getattr(tool_auth_row.gateway, "oauth_config", None)
3890 if isinstance(hydrated_gateway_oauth_config, dict):
3891 gateway_oauth_config = hydrated_gateway_oauth_config
3893 # Decrypt and apply query param auth to URL if applicable
3894 gateway_auth_query_params_decrypted: Optional[Dict[str, str]] = None
3895 if gateway_auth_type == "query_param" and gateway_auth_query_params:
3896 # Decrypt the query param values
3897 gateway_auth_query_params_decrypted = {}
3898 for param_key, encrypted_value in gateway_auth_query_params.items():
3899 if encrypted_value:
3900 try:
3901 decrypted = decode_auth(encrypted_value)
3902 gateway_auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
3903 except Exception: # noqa: S110 - intentionally skip failed decryptions
3904 # Silently skip params that fail decryption (may be corrupted or use old key)
3905 logger.debug(f"Failed to decrypt query param '{param_key}' for tool invocation")
3906 # Apply query params to gateway URL
3907 if gateway_auth_query_params_decrypted and gateway_url:
3908 gateway_url = apply_query_param_auth(gateway_url, gateway_auth_query_params_decrypted)
3910 # Create Pydantic models for plugins BEFORE HTTP calls (use ORM objects while still valid)
3911 # This prevents lazy loading during HTTP calls
3912 tool_metadata: Optional[PydanticTool] = None
3913 gateway_metadata: Optional[PydanticGateway] = None
3914 if self._plugin_manager:
3915 if tool is not None:
3916 tool_metadata = PydanticTool.model_validate(tool)
3917 if has_gateway and gateway is not None:
3918 gateway_metadata = PydanticGateway.model_validate(gateway)
3919 else:
3920 tool_metadata = self._pydantic_tool_from_payload(tool_payload)
3921 if has_gateway and gateway_payload:
3922 gateway_metadata = self._pydantic_gateway_from_payload(gateway_payload)
3924 tool_for_validation = tool if tool is not None else SimpleNamespace(output_schema=tool_output_schema, name=tool_name_computed)
3926 # ═══════════════════════════════════════════════════════════════════════════
3927 # A2A Agent Data Extraction (must happen before db.close())
3928 # Extract all A2A agent data to local variables so HTTP call can happen after db.close()
3929 # ═══════════════════════════════════════════════════════════════════════════
3930 a2a_agent_name: Optional[str] = None
3931 a2a_agent_endpoint_url: Optional[str] = None
3932 a2a_agent_type: Optional[str] = None
3933 a2a_agent_protocol_version: Optional[str] = None
3934 a2a_agent_auth_type: Optional[str] = None
3935 a2a_agent_auth_value: Optional[str] = None
3936 a2a_agent_auth_query_params: Optional[Dict[str, str]] = None
3938 if tool_integration_type == "A2A" and "a2a_agent_id" in tool_annotations:
3939 a2a_agent_id = tool_annotations.get("a2a_agent_id")
3940 if not a2a_agent_id:
3941 raise ToolNotFoundError(f"A2A tool '{name}' missing agent ID in annotations")
3943 # Query for the A2A agent
3944 agent_query = select(DbA2AAgent).where(DbA2AAgent.id == a2a_agent_id)
3945 a2a_agent = db.execute(agent_query).scalar_one_or_none()
3947 if not a2a_agent:
3948 raise ToolNotFoundError(f"A2A agent not found for tool '{name}' (agent ID: {a2a_agent_id})")
3950 if not a2a_agent.enabled:
3951 raise ToolNotFoundError(f"A2A agent '{a2a_agent.name}' is disabled")
3953 # Extract all needed data to local variables before db.close()
3954 a2a_agent_name = a2a_agent.name
3955 a2a_agent_endpoint_url = a2a_agent.endpoint_url
3956 a2a_agent_type = a2a_agent.agent_type
3957 a2a_agent_protocol_version = a2a_agent.protocol_version
3958 a2a_agent_auth_type = a2a_agent.auth_type
3959 a2a_agent_auth_value = a2a_agent.auth_value
3960 a2a_agent_auth_query_params = a2a_agent.auth_query_params
3962 # ═══════════════════════════════════════════════════════════════════════════
3963 # CRITICAL: Release DB connection back to pool BEFORE making HTTP calls
3964 # This prevents connection pool exhaustion during slow upstream requests.
3965 # All needed data has been extracted to local variables above.
3966 # The session will be closed again by FastAPI's get_db() finally block (safe no-op).
3967 # ═══════════════════════════════════════════════════════════════════════════
3968 db.commit() # End read-only transaction cleanly (commit not rollback to avoid inflating rollback stats)
3969 db.close()
3971 # Plugin hook: tool pre-invoke
3972 # Use existing context_table from previous hooks if available
3973 context_table = plugin_context_table
3975 # Reuse existing global_context from middleware or create new one
3976 # IMPORTANT: Use local variables (tool_gateway_id) instead of ORM object access
3977 if plugin_global_context:
3978 global_context = plugin_global_context
3979 # Update server_id using local variable (not ORM access)
3980 if tool_gateway_id and isinstance(tool_gateway_id, str):
3981 global_context.server_id = tool_gateway_id
3982 # Propagate user email to global context for plugin access
3983 if not plugin_global_context.user and app_user_email and isinstance(app_user_email, str):
3984 global_context.user = app_user_email
3985 else:
3986 # Create new context (fallback when middleware didn't run)
3987 # Use correlation ID from context if available, otherwise generate new one
3988 request_id = get_correlation_id() or uuid.uuid4().hex
3989 context_server_id = tool_gateway_id if tool_gateway_id and isinstance(tool_gateway_id, str) else "unknown"
3990 global_context = GlobalContext(request_id=request_id, server_id=context_server_id, tenant_id=None, user=app_user_email)
3992 start_time = time.monotonic()
3993 success = False
3994 error_message = None
3995 tool_result: Optional[ToolResult] = None
3996 tool_team_scope = format_trace_team_scope(token_teams)
3998 # Get trace_id from context for database span creation
3999 trace_id = current_trace_id.get()
4000 db_span_id = None
4001 db_span_ended = False
4002 observability_service = ObservabilityService() if trace_id else None
4004 # Create database span for observability_spans table
4005 if trace_id and observability_service:
4006 try:
4007 # Re-open database session for span creation (the original session was released by db.close() earlier in this method)
4008 # Use commit=False since fresh_db_session() handles commits on exit
4009 with fresh_db_session() as span_db:
4010 db_span_id = observability_service.start_span(
4011 db=span_db,
4012 trace_id=trace_id,
4013 name="tool.invoke",
4014 kind="client",
4015 resource_type="tool",
4016 resource_name=name,
4017 resource_id=tool_id,
4018 attributes={
4019 "tool.name": name,
4020 "tool.id": tool_id,
4021 "tool.integration_type": tool_integration_type,
4022 "tool.gateway_id": tool_gateway_id,
4023 "arguments_count": len(arguments) if arguments else 0,
4024 "has_headers": bool(request_headers),
4025 },
4026 commit=False,
4027 )
4028 logger.debug(f"✓ Created tool.invoke span: {db_span_id} for tool: {name}")
4029 except Exception as e:
4030 logger.warning(f"Failed to start observability span for tool invocation: {e}")
4031 db_span_id = None
4033 # Create a trace span for OpenTelemetry export (Jaeger, Zipkin, etc.)
4034 span_attributes = {
4035 "tool.name": name,
4036 "tool.id": tool_id,
4037 "tool.integration_type": tool_integration_type,
4038 "tool.gateway_id": tool_gateway_id,
4039 "arguments_count": len(arguments) if arguments else 0,
4040 "has_headers": bool(request_headers),
4041 "user.email": user_email or app_user_email or "anonymous",
4042 "team.scope": tool_team_scope,
4043 "server_id": server_id,
4044 }
4045 if is_input_capture_enabled("tool.invoke"):
4046 span_attributes["langfuse.observation.input"] = serialize_trace_payload(arguments or {})
4048 with create_span("tool.invoke", span_attributes) as span:
4049 try:
4050 # Create a lightweight lookup child span so Langfuse shows the invoke breakdown.
4051 with create_child_span(
4052 "tool.lookup",
4053 {
4054 "tool.name": name,
4055 "tool.id": tool_id,
4056 "tool.integration_type": tool_integration_type,
4057 },
4058 ):
4059 headers = tool_headers.copy()
4060 if tool_integration_type == "REST":
4061 # Handle OAuth authentication for REST tools
4062 if tool_auth_type == "oauth" and isinstance(tool_oauth_config, dict) and tool_oauth_config:
4063 try:
4064 access_token = await self.oauth_manager.get_access_token(tool_oauth_config)
4065 headers["Authorization"] = f"Bearer {access_token}"
4066 except Exception as e:
4067 logger.error(f"Failed to obtain OAuth access token for tool {tool_name_computed}: {e}")
4068 raise ToolInvocationError(f"OAuth authentication failed: {str(e)}")
4069 else:
4070 credentials = decode_auth(tool_auth_value) if tool_auth_value else {}
4071 # Filter out empty header names/values to avoid "Illegal header name" errors
4072 filtered_credentials = {k: v for k, v in credentials.items() if k and v}
4073 headers.update(filtered_credentials)
4075 # Use cached passthrough headers (no DB query needed)
4076 if request_headers:
4077 headers = compute_passthrough_headers_cached(
4078 request_headers,
4079 headers,
4080 passthrough_allowed,
4081 gateway_auth_type=None,
4082 gateway_passthrough_headers=None, # REST tools don't use gateway auth here
4083 )
4084 # Read MCP-Session-Id from downstream client (MCP protocol header)
4085 # and normalize to x-mcp-session-id for our internal session affinity logic
4086 # The pool will strip this before sending to upstream
4087 # Check both mcp-session-id (direct client) and x-mcp-session-id (forwarded requests)
4088 request_headers_lower = {k.lower(): v for k, v in request_headers.items()}
4089 mcp_session_id = request_headers_lower.get("mcp-session-id") or request_headers_lower.get("x-mcp-session-id")
4090 if mcp_session_id:
4091 headers["x-mcp-session-id"] = mcp_session_id
4093 worker_id = str(os.getpid())
4094 session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id
4095 logger.debug(f"[AFFINITY] Worker {worker_id} | Session {session_short}... | Tool: {name} | Normalized MCP-Session-Id → x-mcp-session-id for pool affinity")
4097 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_PRE_INVOKE) and not skip_pre_invoke:
4098 # Use pre-created Pydantic model from Phase 2 (no ORM access)
4099 if tool_metadata:
4100 global_context.metadata[TOOL_METADATA] = tool_metadata
4101 pre_result, context_table = await self._plugin_manager.invoke_hook(
4102 ToolHookType.TOOL_PRE_INVOKE,
4103 payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)),
4104 global_context=global_context,
4105 local_contexts=context_table, # Pass context from previous hooks
4106 violations_as_exceptions=True,
4107 )
4108 if pre_result.modified_payload:
4109 payload = pre_result.modified_payload
4110 name = payload.name
4111 arguments = payload.args
4112 if payload.headers is not None:
4113 headers = payload.headers.model_dump()
4115 # Build the payload based on integration type
4116 payload = arguments.copy()
4118 # Handle URL path parameter substitution (using local variable)
4119 final_url = tool_url
4120 if "{" in tool_url and "}" in tool_url:
4121 # Extract path parameters from URL template and arguments
4122 url_params = re.findall(r"\{(\w+)\}", tool_url)
4123 url_substitutions = {}
4125 for param in url_params:
4126 if param in payload:
4127 url_substitutions[param] = payload.pop(param) # Remove from payload
4128 final_url = final_url.replace(f"{{{param}}}", str(url_substitutions[param]))
4129 else:
4130 raise ToolInvocationError(f"Required URL parameter '{param}' not found in arguments")
4132 # Use the tool's request_type rather than defaulting to POST (using local variable)
4133 method = tool_request_type.upper() if tool_request_type else "POST"
4134 with create_child_span("tool.gateway_call", {"tool.name": name, "tool.id": tool_id, "tool.integration_type": "REST"}):
4135 rest_start_time = time.time()
4136 try:
4137 if method == "GET":
4138 # For GET: Extract query params from URL and merge with input arguments
4139 # URL query params take precedence over input arguments on conflicts
4140 # to ensure stable behavior when tools have default params in their URL.
4141 parsed = urlparse(final_url)
4142 final_url = f"{parsed.scheme}://{parsed.netloc}{parsed.path}"
4143 query_params = {k: v[0] for k, v in parse_qs(parsed.query).items()}
4145 conflicts = set(payload.keys()) & set(query_params.keys())
4146 if conflicts:
4147 logger.warning(
4148 f"REST tool GET request has conflicting parameters between URL and input arguments. "
4149 f"URL query params will take precedence for: {', '.join(sorted(conflicts))}. "
4150 f"Tool: {name}"
4151 )
4153 payload.update(query_params)
4154 response = await asyncio.wait_for(self._http_client.get(final_url, params=payload, headers=headers), timeout=effective_timeout)
4155 else:
4156 # For POST/PUT/PATCH/DELETE: Preserve query params in URL, only send input args in body.
4157 # This is critical for signed URLs (Azure SAS, AWS presigned URLs, webhook signatures).
4158 response = await asyncio.wait_for(self._http_client.request(method, final_url, json=payload, headers=headers), timeout=effective_timeout)
4159 except (asyncio.TimeoutError, httpx.TimeoutException):
4160 rest_elapsed_ms = (time.time() - rest_start_time) * 1000
4161 structured_logger.log(
4162 level="WARNING",
4163 message=f"REST tool invocation timed out: {tool_name_computed}",
4164 component="tool_service",
4165 correlation_id=get_correlation_id(),
4166 duration_ms=rest_elapsed_ms,
4167 metadata={"event": "tool_timeout", "tool_name": tool_name_computed, "timeout_seconds": effective_timeout},
4168 )
4170 # Manually trigger circuit breaker (or other plugins) on timeout
4171 try:
4172 # First-Party
4173 from mcpgateway.services.metrics import tool_timeout_counter # pylint: disable=import-outside-toplevel
4175 tool_timeout_counter.labels(tool_name=name).inc()
4176 except Exception as exc:
4177 logger.debug(
4178 "Failed to increment tool_timeout_counter for %s: %s",
4179 name,
4180 exc,
4181 exc_info=True,
4182 )
4183 if self._plugin_manager:
4184 await self._run_timeout_post_invoke(name, effective_timeout, global_context, context_table)
4186 raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s")
4187 try:
4188 response.raise_for_status()
4189 except httpx.HTTPStatusError:
4190 # Non-2xx response — parse body (may be HTML, plain text, XML, etc.)
4191 try:
4192 result = response.json()
4193 except (json.JSONDecodeError, orjson.JSONDecodeError, UnicodeDecodeError, AttributeError) as e:
4194 result = _handle_json_parse_error(response, e, is_error_response=True)
4195 if "error" in result:
4196 error_val = result["error"]
4197 elif "response_text" in result:
4198 error_val = f"HTTP {response.status_code}: {result['response_text']}"
4199 else:
4200 error_val = f"HTTP {response.status_code}"
4201 tool_result = ToolResult(
4202 content=[TextContent(type="text", text=error_val if isinstance(error_val, str) else orjson.dumps(error_val).decode())],
4203 is_error=True,
4204 structured_content={"status_code": response.status_code},
4205 )
4206 # Don't mark as successful — success remains False
4208 # Handle 204 No Content responses that have no body
4209 if tool_result is not None and tool_result.is_error:
4210 pass # Already handled by HTTPStatusError above
4211 elif response.status_code == 204:
4212 tool_result = ToolResult(content=[TextContent(type="text", text="Request completed successfully (No Content)")])
4213 success = True
4214 elif response.status_code not in [200, 201, 202, 206]:
4215 # Non-standard 2xx codes (203, 205, 207, etc.) treated as errors
4216 try:
4217 result = response.json()
4218 except (json.JSONDecodeError, orjson.JSONDecodeError, UnicodeDecodeError, AttributeError) as e:
4219 result = _handle_json_parse_error(response, e, is_error_response=True)
4220 error_val = result["error"] if "error" in result else "Tool error encountered"
4221 tool_result = ToolResult(
4222 content=[TextContent(type="text", text=error_val if isinstance(error_val, str) else orjson.dumps(error_val).decode())],
4223 is_error=True,
4224 )
4225 # Don't mark as successful for error responses - success remains False
4226 else:
4227 try:
4228 result = response.json()
4229 except (json.JSONDecodeError, orjson.JSONDecodeError, UnicodeDecodeError, AttributeError) as e:
4230 result = _handle_json_parse_error(response, e, is_error_response=False)
4231 logger.debug(f"REST API tool response: {result}")
4232 filtered_response = extract_using_jq(result, tool_jsonpath_filter)
4233 # Check if extract_using_jq returned an error (list of TextContent objects)
4234 if isinstance(filtered_response, list) and len(filtered_response) > 0 and isinstance(filtered_response[0], TextContent):
4235 # Error case - use the TextContent directly
4236 tool_result = ToolResult(content=filtered_response, is_error=True)
4237 success = False
4238 else:
4239 # Success case - serialize the filtered response
4240 serialized = orjson.dumps(filtered_response, option=orjson.OPT_INDENT_2)
4241 tool_result = ToolResult(content=[TextContent(type="text", text=serialized.decode())])
4242 success = True
4243 # If output schema is present, validate and attach structured content
4244 if tool_output_schema:
4245 valid = self._extract_and_validate_structured_content(tool_for_validation, tool_result, candidate=filtered_response)
4246 success = bool(valid)
4247 elif tool_integration_type == "MCP":
4248 transport = tool_request_type.lower() if tool_request_type else "sse"
4250 # Handle OAuth authentication for the gateway (using local variables)
4251 # NOTE: Use has_gateway instead of gateway to avoid accessing detached ORM object
4252 if has_gateway and gateway_auth_type == "oauth" and isinstance(gateway_oauth_config, dict) and gateway_oauth_config:
4253 grant_type = gateway_oauth_config.get("grant_type", "client_credentials")
4255 if grant_type == "authorization_code":
4256 # For Authorization Code flow, try to get stored tokens
4257 # NOTE: Use fresh_db_session() since the original db was closed
4258 try:
4259 # First-Party
4260 from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel
4262 with fresh_db_session() as token_db:
4263 token_storage = TokenStorageService(token_db)
4265 # Get user-specific OAuth token
4266 if not app_user_email:
4267 raise ToolInvocationError(f"User authentication required for OAuth-protected gateway '{gateway_name}'. Please ensure you are authenticated.")
4269 access_token = await token_storage.get_user_token(gateway_id_str, app_user_email)
4271 if access_token:
4272 headers = {"Authorization": f"Bearer {access_token}"}
4273 else:
4274 # User hasn't authorized this gateway yet
4275 raise ToolInvocationError(f"Please authorize {gateway_name} first. Visit /oauth/authorize/{gateway_id_str} to complete OAuth flow.")
4276 except Exception as e:
4277 logger.error(f"Failed to obtain stored OAuth token for gateway {gateway_name}: {e}")
4278 raise ToolInvocationError(f"OAuth token retrieval failed for gateway: {str(e)}")
4279 else:
4280 # For Client Credentials flow, get token directly (no DB needed)
4281 try:
4282 access_token = await self.oauth_manager.get_access_token(gateway_oauth_config)
4283 headers = {"Authorization": f"Bearer {access_token}"}
4284 except Exception as e:
4285 logger.error(f"Failed to obtain OAuth access token for gateway {gateway_name}: {e}")
4286 raise ToolInvocationError(f"OAuth authentication failed for gateway: {str(e)}")
4287 else:
4288 headers = decode_auth(gateway_auth_value) if gateway_auth_value else {}
4290 # Use cached passthrough headers (no DB query needed)
4291 if request_headers:
4292 headers = compute_passthrough_headers_cached(
4293 request_headers, headers, passthrough_allowed, gateway_auth_type=gateway_auth_type, gateway_passthrough_headers=gateway_passthrough
4294 )
4295 # Read MCP-Session-Id from downstream client (MCP protocol header)
4296 # and normalize to x-mcp-session-id for our internal session affinity logic
4297 # The pool will strip this before sending to upstream
4298 # Check both mcp-session-id (direct client) and x-mcp-session-id (forwarded requests)
4299 request_headers_lower = {k.lower(): v for k, v in request_headers.items()}
4300 mcp_session_id = request_headers_lower.get("mcp-session-id") or request_headers_lower.get("x-mcp-session-id")
4301 if mcp_session_id:
4302 headers["x-mcp-session-id"] = mcp_session_id
4304 worker_id = str(os.getpid())
4305 session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id
4306 logger.debug(f"[AFFINITY] Worker {worker_id} | Session {session_short}... | Tool: {name} | Normalized MCP-Session-Id → x-mcp-session-id for pool affinity (MCP transport)")
4308 # mTLS client cert/key: resolve from payload, then override with runtime gateway if available
4309 client_cert_from_payload = gateway_payload.get("client_cert") if has_gateway else None
4310 client_key_from_payload = gateway_payload.get("client_key") if has_gateway else None
4312 # Resolve client cert/key: start from the payload values; truthy runtime gateway values override them
4313 gateway_client_cert = client_cert_from_payload
4314 gateway_client_key = client_key_from_payload
4315 if has_gateway and gateway is not None:
4316 runtime_gateway_client_cert = getattr(gateway, "client_cert", None)
4317 runtime_gateway_client_key = getattr(gateway, "client_key", None)
4318 if runtime_gateway_client_cert:
4319 gateway_client_cert = runtime_gateway_client_cert
4320 if runtime_gateway_client_key:
4321 gateway_client_key = runtime_gateway_client_key
4323 # Decrypt client_key if stored encrypted
4324 if gateway_client_key:
4325 try:
4326 # First-Party
4327 from mcpgateway.services.encryption_service import get_encryption_service # pylint: disable=import-outside-toplevel
4329 _enc = get_encryption_service(settings.auth_encryption_secret)
4330 gateway_client_key = _enc.decrypt_secret_or_plaintext(gateway_client_key)
4331 except Exception as _dec_exc:
4332 logger.debug("client_key decryption skipped, using as-is: %s", _dec_exc)
def create_ssl_context(
    ca_certificate: str,
    client_cert: str | None = None,
    client_key: str | None = None,
) -> ssl.SSLContext:
    """Return an SSL context for a CA certificate, with optional mTLS credentials.

    This is a thin wrapper: all work is delegated to ``get_cached_ssl_context``,
    so repeated calls with the same certificate material can reuse a context
    instead of building a new one.

    Args:
        ca_certificate: CA certificate in PEM format
        client_cert: Optional client cert path or PEM for mTLS
        client_key: Optional client key path or PEM for mTLS

    Returns:
        ssl.SSLContext: Configured SSL context
    """
    # Forward the mTLS material by keyword, in the same order as the delegate expects.
    mtls_material = {"client_cert": client_cert, "client_key": client_key}
    cached_context = get_cached_ssl_context(ca_certificate, **mtls_material)
    return cached_context
4353 # Capture mTLS client cert/key values for passing to nested function
4354 _client_cert_value = gateway_client_cert
4355 _client_key_value = gateway_client_key
def get_httpx_client_factory(
    headers: dict[str, str] | None = None,
    timeout: httpx.Timeout | None = None,
    auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
    """Build the httpx.AsyncClient used to reach the upstream gateway.

    A custom SSL context (carrying the captured mTLS client cert/key) is used only
    when all of the following hold: the gateway URL is not plain ``http://``, a
    gateway CA certificate is present, and that certificate is accepted. With
    ``settings.enable_ed25519_signing`` on, acceptance means ``validate_signature``
    returned a truthy value; with it off, the certificate is accepted as-is.
    In every other case the client verifies with ``get_default_verify()``.

    Args:
        headers: Optional headers for the client
        timeout: Optional timeout for the client; when falsy, one is derived from
            ``get_http_timeout`` using ``effective_timeout`` as the read timeout
        auth: Optional auth for the client

    Returns:
        httpx.AsyncClient: Configured HTTPX async client

    Raises:
        Exception: Nothing is raised here for a rejected signature (the default
            verification is used instead); errors can presumably still propagate
            from ``validate_signature`` or the helpers - TODO confirm
    """
    # All gateway/mTLS values below are plain locals captured from the enclosing
    # scope, not ORM objects.
    ca_cert_accepted = False
    if gateway_ca_cert:
        if settings.enable_ed25519_signing:
            signing_public_key = settings.ed25519_public_key
            ca_cert_accepted = validate_signature(gateway_ca_cert.encode(), gateway_ca_cert_sig, signing_public_key)
        else:
            ca_cert_accepted = True

    # First-Party
    from mcpgateway.services.http_client_service import get_default_verify, get_http_timeout  # pylint: disable=import-outside-toplevel

    # Plain HTTP gateways never get a custom SSL context, avoiding unnecessary SSL setup.
    targets_plain_http = bool(gateway_url) and gateway_url.lower().startswith("http://")
    ssl_context = None
    if not targets_plain_http and ca_cert_accepted and gateway_ca_cert:
        ssl_context = create_ssl_context(
            gateway_ca_cert,
            client_cert=_client_cert_value,
            client_key=_client_key_value,
        )

    # Unless the caller supplied a timeout, read operations wait at least as long
    # as the tool configuration (effective_timeout) requires.
    resolved_timeout = timeout or get_http_timeout(read_timeout=effective_timeout)
    verify_setting = ssl_context or get_default_verify()
    connection_limits = httpx.Limits(
        max_connections=settings.httpx_max_connections,
        max_keepalive_connections=settings.httpx_max_keepalive_connections,
        keepalive_expiry=settings.httpx_keepalive_expiry,
    )

    return httpx.AsyncClient(
        verify=verify_setting,
        follow_redirects=True,
        headers=headers,
        timeout=resolved_timeout,
        auth=auth,
        limits=connection_limits,
    )
4418 async def connect_to_sse_server(server_url: str, headers: dict = headers):
4419 """Connect to an MCP server running with SSE transport and call the tool.
Uses a pooled session when settings.mcp_session_pool_enabled is set and the pool
can be obtained; otherwise opens a per-call sse_client/ClientSession wrapped in
trace spans. Start, completion, timeout and failure are each logged through
structured_logger.
4421 Args:
4422 server_url: MCP Server SSE URL
4423 headers: HTTP headers to include in the request; the default is the enclosing scope's ``headers`` value, bound when this function is defined
4425 Returns:
4426 The value returned by ``call_tool`` on the pooled or per-call MCP session
4428 Raises:
4429 ToolInvocationError: Not raised directly in this function; presumably propagated from callees - TODO confirm.
4430 ToolTimeoutError: If asyncio.TimeoutError or httpx.TimeoutException is caught during the call.
4431 BaseException: Any other error, logged with a sanitized message and re-raised unchanged
4433 """
4434 # Get correlation ID for distributed tracing
4435 correlation_id = get_correlation_id()
4437 # NOTE: X-Correlation-ID is NOT added to headers for pooled sessions.
4438 # MCP SDK pins headers at transport creation, so adding per-request headers
4439 # would cause the first request's correlation ID to be reused for all
4440 # subsequent requests on the same pooled session. Correlation IDs are
4441 # still logged locally for tracing within the gateway.
4443 # Log MCP call start (using local variables)
4444 # Sanitize server_url to redact sensitive query params from logs
4445 server_url_sanitized = sanitize_url_for_logging(server_url, gateway_auth_query_params_decrypted)
4446 mcp_start_time = time.time()
4447 structured_logger.log(
4448 level="INFO",
4449 message=f"MCP tool call started: {tool_name_original}",
4450 component="tool_service",
4451 correlation_id=correlation_id,
4452 metadata={"event": "mcp_call_started", "tool_name": tool_name_original, "tool_id": tool_id, "server_url": server_url_sanitized, "transport": "sse"},
4453 )
4455 try:
4456 # Use session pool if enabled for 10-20x latency improvement
4457 tool_call_result = None
4458 use_pool = False
4459 pool = None
4460 if settings.mcp_session_pool_enabled:
4461 try:
4462 pool = get_mcp_session_pool()
4463 use_pool = True
4464 except RuntimeError:
4465 # Pool not initialized (e.g., in tests), fall back to per-call sessions
4466 pass
4468 if use_pool and pool is not None:
4469 # Pooled path: do NOT inject per-request trace headers
4470 # The pool reuses transports with pinned headers, so injecting
4471 # traceparent/X-Correlation-ID would cause the first request's
4472 # trace context to be replayed on later unrelated requests,
4473 # corrupting distributed traces and leaking correlation IDs.
4474 # Trade-off: pooled sessions gain 10-20x latency improvement
4475 # but lose distributed trace propagation to upstream servers.
4476 async with pool.session(
4477 url=server_url,
4478 headers=headers,
4479 transport_type=TransportType.SSE,
4480 httpx_client_factory=get_httpx_client_factory,
4481 user_identity=app_user_email,
4482 gateway_id=gateway_id_str,
4483 ) as pooled:
# Bound only the tool call itself by effective_timeout
4484 with anyio.fail_after(effective_timeout):
4485 tool_call_result = await pooled.session.call_tool(tool_name_original, arguments, meta=meta_data)
4486 else:
4487 # Fallback to per-call sessions when pool disabled or not initialized
4488 with create_span(
4489 "mcp.client.call",
4490 {
4491 "mcp.tool.name": tool_name_original,
4492 "contextforge.tool.id": tool_id,
4493 "contextforge.gateway_id": tool_gateway_id,
4494 "contextforge.runtime": "python",
4495 "contextforge.transport": "sse",
4496 "network.protocol.name": "mcp",
4497 "server.address": urlparse(server_url).hostname,
4498 "server.port": urlparse(server_url).port,
4499 "url.path": urlparse(server_url).path or "/",
4500 "url.full": server_url_sanitized,
4501 },
4502 ):
4503 # Non-pooled path: safe to add per-request headers.
4504 # Inject within the active client span so an upstream service
4505 # can attach beneath this span when it extracts traceparent.
# NOTE: request_headers is assigned in this function, so it is a local name here
4506 request_headers = inject_trace_context_headers(headers)
# The correlation header is added only when both the ID and the injected headers are truthy
4507 if correlation_id and request_headers:
4508 request_headers["X-Correlation-ID"] = correlation_id
4509 async with sse_client(url=server_url, headers=request_headers, httpx_client_factory=get_httpx_client_factory) as streams:
4510 async with ClientSession(*streams) as session:
4511 with create_span("mcp.client.initialize", {"contextforge.transport": "sse", "contextforge.runtime": "python"}):
4512 await session.initialize()
4513 with create_span(
4514 "mcp.client.request",
4515 {
4516 "mcp.tool.name": tool_name_original,
4517 "contextforge.tool.id": tool_id,
4518 "contextforge.gateway_id": tool_gateway_id,
4519 "contextforge.runtime": "python",
4520 },
4521 ):
4522 with anyio.fail_after(effective_timeout):
4523 tool_call_result = await session.call_tool(tool_name_original, arguments, meta=meta_data)
# Marker span with an empty body: it only records whether the upstream result was flagged as an error
4524 with create_span(
4525 "mcp.client.response",
4526 {
4527 "mcp.tool.name": tool_name_original,
4528 "contextforge.tool.id": tool_id,
4529 "contextforge.gateway_id": tool_gateway_id,
4530 "contextforge.runtime": "python",
4531 "upstream.response.success": not getattr(tool_call_result, "is_error", False) and not getattr(tool_call_result, "isError", False),
4532 },
4533 ):
4534 pass
4536 # Log successful MCP call
4537 mcp_duration_ms = (time.time() - mcp_start_time) * 1000
4538 structured_logger.log(
4539 level="INFO",
4540 message=f"MCP tool call completed: {tool_name_original}",
4541 component="tool_service",
4542 correlation_id=correlation_id,
4543 duration_ms=mcp_duration_ms,
4544 metadata={"event": "mcp_call_completed", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "sse", "success": True},
4545 )
4547 return tool_call_result
4548 except (asyncio.TimeoutError, httpx.TimeoutException):
4549 # Handle timeout specifically - log and raise ToolTimeoutError
4550 mcp_duration_ms = (time.time() - mcp_start_time) * 1000
4551 structured_logger.log(
4552 level="WARNING",
4553 message=f"MCP SSE tool invocation timed out: {tool_name_original}",
4554 component="tool_service",
4555 correlation_id=correlation_id,
4556 duration_ms=mcp_duration_ms,
4557 metadata={"event": "tool_timeout", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "sse", "timeout_seconds": effective_timeout},
4558 )
4560 # Record the timeout in metrics (best-effort; failures are only debug-logged), then manually trigger circuit breaker (or other plugins) on timeout
4561 try:
4562 # First-Party
4563 from mcpgateway.services.metrics import tool_timeout_counter # pylint: disable=import-outside-toplevel
4565 tool_timeout_counter.labels(tool_name=name).inc()
4566 except Exception as exc:
4567 logger.debug(
4568 "Failed to increment tool_timeout_counter for %s: %s",
4569 name,
4570 exc,
4571 exc_info=True,
4572 )
4574 if self._plugin_manager:
4575 await self._run_timeout_post_invoke(name, effective_timeout, global_context, context_table)
4577 raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s")
4578 except BaseException as e:
4579 # Extract root cause from ExceptionGroup (Python 3.11+)
4580 # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup
# Only the first nested exception is followed at each level; the root cause is used for logging only
4581 root_cause = e
4582 if isinstance(e, BaseExceptionGroup):
4583 while isinstance(root_cause, BaseExceptionGroup) and root_cause.exceptions:
4584 root_cause = root_cause.exceptions[0]
4585 # Log failed MCP call (using local variables)
4586 mcp_duration_ms = (time.time() - mcp_start_time) * 1000
4587 # Sanitize error message to prevent URL secrets from leaking in logs
4588 sanitized_error = sanitize_exception_message(str(root_cause), gateway_auth_query_params_decrypted)
4589 structured_logger.log(
4590 level="ERROR",
4591 message=f"MCP tool call failed: {tool_name_original}",
4592 component="tool_service",
4593 correlation_id=correlation_id,
4594 duration_ms=mcp_duration_ms,
4595 error_details={"error_type": type(root_cause).__name__, "error_message": sanitized_error},
4596 metadata={"event": "mcp_call_failed", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "sse"},
4597 )
# Re-raise the original exception (not the extracted root cause) so callers see it unchanged
4598 raise
4600 async def connect_to_streamablehttp_server(server_url: str, headers: dict = headers):
4601 """Connect to an MCP server running with Streamable HTTP transport.
4603 Args:
4604 server_url: MCP Server URL
4605 headers: HTTP headers to include in the request
4607 Returns:
4608 ToolResult: Result of tool call
4610 Raises:
4611 ToolInvocationError: If the tool invocation fails during execution.
4612 ToolTimeoutError: If the tool invocation times out.
4613 BaseException: On connection or communication errors
4614 """
4615 # Get correlation ID for distributed tracing
4616 correlation_id = get_correlation_id()
4618 # NOTE: X-Correlation-ID is NOT added to headers for pooled sessions.
4619 # MCP SDK pins headers at transport creation, so adding per-request headers
4620 # would cause the first request's correlation ID to be reused for all
4621 # subsequent requests on the same pooled session. Correlation IDs are
4622 # still logged locally for tracing within the gateway.
4624 # Log MCP call start (using local variables)
4625 # Sanitize server_url to redact sensitive query params from logs
4626 server_url_sanitized = sanitize_url_for_logging(server_url, gateway_auth_query_params_decrypted)
4627 mcp_start_time = time.time()
4628 structured_logger.log(
4629 level="INFO",
4630 message=f"MCP tool call started: {tool_name_original}",
4631 component="tool_service",
4632 correlation_id=correlation_id,
4633 metadata={"event": "mcp_call_started", "tool_name": tool_name_original, "tool_id": tool_id, "server_url": server_url_sanitized, "transport": "streamablehttp"},
4634 )
4636 try:
4637 # Use session pool if enabled for 10-20x latency improvement
4638 tool_call_result = None
4639 use_pool = False
4640 pool = None
4641 if settings.mcp_session_pool_enabled:
4642 try:
4643 pool = get_mcp_session_pool()
4644 use_pool = True
4645 except RuntimeError:
4646 # Pool not initialized (e.g., in tests), fall back to per-call sessions
4647 pass
4649 if use_pool and pool is not None:
4650 # Pooled path: do NOT inject per-request trace headers
4651 # The pool reuses transports with pinned headers, so injecting
4652 # traceparent/X-Correlation-ID would cause the first request's
4653 # trace context to be replayed on later unrelated requests,
4654 # corrupting distributed traces and leaking correlation IDs.
4655 # Trade-off: pooled sessions gain 10-20x latency improvement
4656 # but lose distributed trace propagation to upstream servers.
4657 # Determine transport type based on current transport setting
4658 pool_transport_type = TransportType.SSE if transport == "sse" else TransportType.STREAMABLE_HTTP
4659 async with pool.session(
4660 url=server_url,
4661 headers=headers,
4662 transport_type=pool_transport_type,
4663 httpx_client_factory=get_httpx_client_factory,
4664 user_identity=app_user_email,
4665 gateway_id=gateway_id_str,
4666 ) as pooled:
4667 with anyio.fail_after(effective_timeout):
4668 tool_call_result = await pooled.session.call_tool(tool_name_original, arguments, meta=meta_data)
4669 else:
4670 # Fallback to per-call sessions when pool disabled or not initialized
4671 with create_span(
4672 "mcp.client.call",
4673 {
4674 "mcp.tool.name": tool_name_original,
4675 "contextforge.tool.id": tool_id,
4676 "contextforge.gateway_id": tool_gateway_id,
4677 "contextforge.runtime": "python",
4678 "contextforge.transport": "streamablehttp",
4679 "network.protocol.name": "mcp",
4680 "server.address": urlparse(server_url).hostname,
4681 "server.port": urlparse(server_url).port,
4682 "url.path": urlparse(server_url).path or "/",
4683 "url.full": server_url_sanitized,
4684 },
4685 ):
4686 # Non-pooled path: safe to add per-request headers.
4687 # Inject within the active client span so an upstream service
4688 # can attach beneath this span when it extracts traceparent.
4689 request_headers = inject_trace_context_headers(headers)
4690 if correlation_id and request_headers:
4691 request_headers["X-Correlation-ID"] = correlation_id
4692 async with streamablehttp_client(url=server_url, headers=request_headers, httpx_client_factory=get_httpx_client_factory) as (
4693 read_stream,
4694 write_stream,
4695 _get_session_id,
4696 ):
4697 async with ClientSession(read_stream, write_stream) as session:
4698 with create_span("mcp.client.initialize", {"contextforge.transport": "streamablehttp", "contextforge.runtime": "python"}):
4699 await session.initialize()
4700 with create_span(
4701 "mcp.client.request",
4702 {
4703 "mcp.tool.name": tool_name_original,
4704 "contextforge.tool.id": tool_id,
4705 "contextforge.gateway_id": tool_gateway_id,
4706 "contextforge.runtime": "python",
4707 },
4708 ):
4709 with anyio.fail_after(effective_timeout):
4710 tool_call_result = await session.call_tool(tool_name_original, arguments, meta=meta_data)
4711 with create_span(
4712 "mcp.client.response",
4713 {
4714 "mcp.tool.name": tool_name_original,
4715 "contextforge.tool.id": tool_id,
4716 "contextforge.gateway_id": tool_gateway_id,
4717 "contextforge.runtime": "python",
4718 "upstream.response.success": not getattr(tool_call_result, "is_error", False) and not getattr(tool_call_result, "isError", False),
4719 },
4720 ):
4721 pass
4723 # Log successful MCP call
4724 mcp_duration_ms = (time.time() - mcp_start_time) * 1000
4725 structured_logger.log(
4726 level="INFO",
4727 message=f"MCP tool call completed: {tool_name_original}",
4728 component="tool_service",
4729 correlation_id=correlation_id,
4730 duration_ms=mcp_duration_ms,
4731 metadata={"event": "mcp_call_completed", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "streamablehttp", "success": True},
4732 )
4734 return tool_call_result
4735 except (asyncio.TimeoutError, httpx.TimeoutException):
4736 # Handle timeout specifically - log and raise ToolTimeoutError
4737 mcp_duration_ms = (time.time() - mcp_start_time) * 1000
4738 structured_logger.log(
4739 level="WARNING",
4740 message=f"MCP StreamableHTTP tool invocation timed out: {tool_name_original}",
4741 component="tool_service",
4742 correlation_id=correlation_id,
4743 duration_ms=mcp_duration_ms,
4744 metadata={"event": "tool_timeout", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "streamablehttp", "timeout_seconds": effective_timeout},
4745 )
4747 # Manually trigger circuit breaker (or other plugins) on timeout
4748 try:
4749 # First-Party
4750 from mcpgateway.services.metrics import tool_timeout_counter # pylint: disable=import-outside-toplevel
4752 tool_timeout_counter.labels(tool_name=name).inc()
4753 except Exception as exc:
4754 logger.debug(
4755 "Failed to increment tool_timeout_counter for %s: %s",
4756 name,
4757 exc,
4758 exc_info=True,
4759 )
4761 if self._plugin_manager:
4762 await self._run_timeout_post_invoke(name, effective_timeout, global_context, context_table)
4764 raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s")
4765 except BaseException as e:
4766 # Extract root cause from ExceptionGroup (Python 3.11+)
4767 # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup
4768 root_cause = e
4769 if isinstance(e, BaseExceptionGroup):
4770 while isinstance(root_cause, BaseExceptionGroup) and root_cause.exceptions:
4771 root_cause = root_cause.exceptions[0]
4772 # Log failed MCP call
4773 mcp_duration_ms = (time.time() - mcp_start_time) * 1000
4774 # Sanitize error message to prevent URL secrets from leaking in logs
4775 sanitized_error = sanitize_exception_message(str(root_cause), gateway_auth_query_params_decrypted)
4776 structured_logger.log(
4777 level="ERROR",
4778 message=f"MCP tool call failed: {tool_name_original}",
4779 component="tool_service",
4780 correlation_id=correlation_id,
4781 duration_ms=mcp_duration_ms,
4782 error_details={"error_type": type(root_cause).__name__, "error_message": sanitized_error},
4783 metadata={"event": "mcp_call_failed", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "streamablehttp"},
4784 )
4785 raise
4787 # REMOVED: Redundant gateway query - gateway already eager-loaded via joinedload
4788 # tool_gateway = db.execute(select(DbGateway).where(DbGateway.id == tool_gateway_id)...)
4790 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_PRE_INVOKE) and not skip_pre_invoke:
4791 # Use pre-created Pydantic models from Phase 2 (no ORM access)
4792 if tool_metadata:
4793 global_context.metadata[TOOL_METADATA] = tool_metadata
4794 if gateway_metadata:
4795 global_context.metadata[GATEWAY_METADATA] = gateway_metadata
4796 pre_result, context_table = await self._plugin_manager.invoke_hook(
4797 ToolHookType.TOOL_PRE_INVOKE,
4798 payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)),
4799 global_context=global_context,
4800 local_contexts=None,
4801 violations_as_exceptions=True,
4802 )
4803 if pre_result.modified_payload:
4804 payload = pre_result.modified_payload
4805 name = payload.name
4806 arguments = payload.args
4807 if payload.headers is not None:
4808 headers = payload.headers.model_dump()
4810 with create_child_span("tool.gateway_call", {"tool.name": name, "tool.id": tool_id, "tool.integration_type": "MCP"}):
4811 tool_call_result = ToolResult(content=[TextContent(text="", type="text")])
4812 if transport == "sse":
4813 tool_call_result = await connect_to_sse_server(gateway_url, headers=headers)
4814 elif transport == "streamablehttp":
4815 tool_call_result = await connect_to_streamablehttp_server(gateway_url, headers=headers)
4817 # In direct proxy mode, use the tool result as-is without splitting content
4818 if is_direct_proxy:
4819 tool_result = tool_call_result
4820 success = not getattr(tool_call_result, "is_error", False) and not getattr(tool_call_result, "isError", False)
4821 logger.debug(f"Direct proxy mode: using tool result as-is: {tool_result}")
4822 else:
4823 dump = tool_call_result.model_dump(by_alias=True, mode="json")
4824 logger.debug(f"Tool call result dump: {dump}")
4825 content = dump.get("content", [])
4826 # Accept both alias and pythonic names for structured content
4827 structured = dump.get("structuredContent") or dump.get("structured_content")
4828 filtered_response = extract_using_jq(content, tool_jsonpath_filter)
4830 is_err = getattr(tool_call_result, "is_error", None)
4831 if is_err is None:
4832 is_err = getattr(tool_call_result, "isError", False)
4833 tool_result = ToolResult(content=filtered_response, structured_content=structured, is_error=is_err, meta=getattr(tool_call_result, "meta", None))
4834 success = not is_err
4835 logger.debug(f"Final tool_result: {tool_result}")
4837 elif tool_integration_type == "A2A" and a2a_agent_endpoint_url:
4838 # A2A tool invocation using pre-extracted agent data (extracted in Phase 2 before db.close())
4839 headers = {"Content-Type": "application/json"}
4841 # Plugin hook: tool pre-invoke for A2A
4842 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_PRE_INVOKE) and not skip_pre_invoke:
4843 if tool_metadata:
4844 global_context.metadata[TOOL_METADATA] = tool_metadata
4845 pre_result, context_table = await self._plugin_manager.invoke_hook(
4846 ToolHookType.TOOL_PRE_INVOKE,
4847 payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)),
4848 global_context=global_context,
4849 local_contexts=context_table,
4850 violations_as_exceptions=True,
4851 )
4852 if pre_result.modified_payload:
4853 payload = pre_result.modified_payload
4854 name = payload.name
4855 arguments = payload.args
4856 if payload.headers is not None:
4857 headers = payload.headers.model_dump()
4859 # Build request data based on agent type
4860 endpoint_url = a2a_agent_endpoint_url
4861 if a2a_agent_type in ["generic", "jsonrpc"] or endpoint_url.endswith("/"):
4862 # JSONRPC agents: Convert flat query to nested message structure
4863 params = None
4864 if isinstance(arguments, dict) and "query" in arguments and isinstance(arguments["query"], str):
4865 message_id = f"admin-test-{int(time.time())}"
4866 # A2A v0.3.x: message.parts use "kind" (not "type").
4867 params = {
4868 "message": {
4869 "kind": "message",
4870 "messageId": message_id,
4871 "role": "user",
4872 "parts": [{"kind": "text", "text": arguments["query"]}],
4873 }
4874 }
4875 method = arguments.get("method", "message/send")
4876 else:
4877 params = arguments.get("params", arguments) if isinstance(arguments, dict) else arguments
4878 method = arguments.get("method", "message/send") if isinstance(arguments, dict) else "message/send"
4879 request_data = {"jsonrpc": "2.0", "method": method, "params": params, "id": 1}
4880 else:
4881 # Custom agents: Pass parameters directly
4882 params = arguments if isinstance(arguments, dict) else {}
4883 request_data = {"interaction_type": params.get("interaction_type", "query"), "parameters": params, "protocol_version": a2a_agent_protocol_version}
4885 # Add authentication
4886 if a2a_agent_auth_type in ("api_key", "basic", "bearer", "authheaders") and a2a_agent_auth_value:
4887 # Decrypt auth_value before using it
4888 if isinstance(a2a_agent_auth_value, str):
4889 try:
4890 auth_headers = decode_auth(a2a_agent_auth_value)
4891 headers.update(auth_headers)
4892 except Exception as e:
4893 logger.error(f"Failed to decrypt authentication for A2A agent '{a2a_agent_name}': {e}")
4894 raise ToolInvocationError(f"Failed to decrypt authentication for A2A agent '{a2a_agent_name}'")
4895 elif isinstance(a2a_agent_auth_value, dict):
4896 auth_headers = {str(k): str(v) for k, v in a2a_agent_auth_value.items()}
4897 headers.update(auth_headers)
4898 elif a2a_agent_auth_type == "query_param" and a2a_agent_auth_query_params:
4899 auth_query_params_decrypted: dict[str, str] = {}
4900 for param_key, encrypted_value in a2a_agent_auth_query_params.items():
4901 if encrypted_value:
4902 try:
4903 decrypted = decode_auth(encrypted_value)
4904 auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
4905 except Exception:
4906 logger.debug(f"Failed to decrypt query param for key '{param_key}'")
4907 if auth_query_params_decrypted:
4908 endpoint_url = apply_query_param_auth(endpoint_url, auth_query_params_decrypted)
4910 with create_child_span("tool.gateway_call", {"tool.name": name, "tool.id": tool_id, "tool.integration_type": "A2A"}):
4911 # Make HTTP request with timeout enforcement
4912 logger.info(f"Calling A2A agent '{a2a_agent_name}' at {endpoint_url}")
4913 a2a_start_time = time.time()
4914 try:
4915 http_response = await asyncio.wait_for(self._http_client.post(endpoint_url, json=request_data, headers=headers), timeout=effective_timeout)
4916 except (asyncio.TimeoutError, httpx.TimeoutException):
4917 a2a_elapsed_ms = (time.time() - a2a_start_time) * 1000
4918 structured_logger.log(
4919 level="WARNING",
4920 message=f"A2A tool invocation timed out: {name}",
4921 component="tool_service",
4922 correlation_id=get_correlation_id(),
4923 duration_ms=a2a_elapsed_ms,
4924 metadata={"event": "tool_timeout", "tool_name": name, "a2a_agent": a2a_agent_name, "timeout_seconds": effective_timeout},
4925 )
4927 # Increment timeout counter
4928 try:
4929 # First-Party
4930 from mcpgateway.services.metrics import tool_timeout_counter # pylint: disable=import-outside-toplevel
4932 tool_timeout_counter.labels(tool_name=name).inc()
4933 except Exception as exc:
4934 logger.debug("Failed to increment tool_timeout_counter for %s: %s", name, exc, exc_info=True)
4936 # Trigger circuit breaker on timeout
4937 if self._plugin_manager:
4938 await self._run_timeout_post_invoke(name, effective_timeout, global_context, context_table)
4940 raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s")
4942 if http_response.status_code == 200:
4943 response_data = http_response.json()
4944 if isinstance(response_data, dict) and "response" in response_data:
4945 val = response_data["response"]
4946 content = [TextContent(type="text", text=val if isinstance(val, str) else orjson.dumps(val).decode())]
4947 else:
4948 content = [TextContent(type="text", text=response_data if isinstance(response_data, str) else orjson.dumps(response_data).decode())]
4949 tool_result = ToolResult(content=content, is_error=False)
4950 success = True
4951 else:
4952 error_message = f"HTTP {http_response.status_code}: {http_response.text}"
4953 content = [TextContent(type="text", text=f"A2A agent error: {error_message}")]
4954 tool_result = ToolResult(content=content, is_error=True)
4955 else:
4956 tool_result = ToolResult(content=[TextContent(type="text", text="Invalid tool type")], is_error=True)
4958 with create_child_span("tool.post_process", {"tool.name": name, "tool.id": tool_id}):
4959 post_result = None
4960 # Plugin hook: tool post-invoke
4961 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE):
4962 post_result, _ = await self._plugin_manager.invoke_hook(
4963 ToolHookType.TOOL_POST_INVOKE,
4964 payload=ToolPostInvokePayload(name=name, result=tool_result.model_dump(by_alias=True)),
4965 global_context=global_context,
4966 local_contexts=context_table,
4967 violations_as_exceptions=True,
4968 )
4969 # Use modified payload if provided
4970 if post_result.modified_payload:
4971 # Reconstruct ToolResult from modified result
4972 modified_result = post_result.modified_payload.result
4973 if isinstance(modified_result, dict) and "content" in modified_result:
4974 # Safely obtain structured content using .get() to avoid KeyError when
4975 # plugins provide only the content without structured content fields.
4976 structured = modified_result.get("structuredContent") if "structuredContent" in modified_result else modified_result.get("structured_content")
4978 tool_result = ToolResult(content=modified_result["content"], structured_content=structured)
4979 else:
4980 # If result is not in expected format, convert it to text content
4981 try:
4982 tool_result = ToolResult(content=[TextContent(type="text", text=modified_result if isinstance(modified_result, str) else orjson.dumps(modified_result).decode())])
4983 except Exception:
4984 tool_result = ToolResult(content=[TextContent(type="text", text=str(modified_result))])
4986 # Retry: if the plugin requested a delayed retry and we haven't hit the gateway ceiling.
4987 # retry_attempt is 0-based (0 = original call). The condition allows retry_attempt
4988 # values 0..max_tool_retries-1, meaning up to max_tool_retries *retry* attempts on
4989 # top of the original call (total attempts = max_tool_retries + 1).
4990 if post_result is not None and post_result.retry_delay_ms > 0 and retry_attempt < settings.max_tool_retries:
4991 return await self._retry_tool_invocation(
4992 post_result.retry_delay_ms,
4993 retry_attempt,
4994 name,
4995 arguments,
4996 request_headers,
4997 app_user_email,
4998 user_email,
4999 token_teams,
5000 server_id,
5001 context_table,
5002 global_context,
5003 meta_data,
5004 skip_pre_invoke,
5005 "success",
5006 )
5008 return tool_result
5009 except (PluginError, PluginViolationError):
5010 raise
5011 except ToolTimeoutError as e:
5012 # ToolTimeoutError is raised by timeout handlers which already called tool_post_invoke.
5013 # Do NOT call post_invoke again — the retry_delay_ms signal is carried on the exception.
5014 error_message = str(e)
5015 if span:
5016 set_span_error(span, error_message)
5018 # Retry if the post-invoke hook (called by the timeout handler) requested it.
5019 if e.retry_delay_ms > 0 and retry_attempt < settings.max_tool_retries:
5020 return await self._retry_tool_invocation(
5021 e.retry_delay_ms,
5022 retry_attempt,
5023 name,
5024 arguments,
5025 request_headers,
5026 app_user_email,
5027 user_email,
5028 token_teams,
5029 server_id,
5030 context_table,
5031 global_context,
5032 meta_data,
5033 skip_pre_invoke,
5034 "timeout",
5035 )
5036 raise
5037 except BaseException as e:
5038 # Extract root cause from ExceptionGroup (Python 3.11+)
5039 # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup
5040 root_cause = e
5041 if isinstance(e, BaseExceptionGroup):
5042 while isinstance(root_cause, BaseExceptionGroup) and root_cause.exceptions:
5043 root_cause = root_cause.exceptions[0]
5044 error_message = str(root_cause)
5045 # Set span error status
5046 if span:
5047 set_span_error(span, error_message)
5049 # Notify plugins of the failure so circuit breaker / retry plugin can track it.
5050 # Capture the result so we can honour a retry_delay_ms signal from the retry plugin.
5051 # When the exception carries an HTTP status code (e.g. httpx.HTTPStatusError),
5052 # include it in structuredContent so the retry plugin can honour retry_on_status
5053 # instead of blindly retrying every exception.
5054 exc_post_result = None
5055 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE):
5056 try:
5057 exc_structured: Optional[Dict[str, Any]] = None
5058 if isinstance(root_cause, httpx.HTTPStatusError):
5059 exc_structured = {"status_code": root_cause.response.status_code}
5060 exception_error_result = ToolResult(content=[TextContent(type="text", text=f"Tool invocation failed: {error_message}")], is_error=True, structured_content=exc_structured)
5061 exc_post_result, _ = await self._plugin_manager.invoke_hook(
5062 ToolHookType.TOOL_POST_INVOKE,
5063 payload=ToolPostInvokePayload(name=name, result=exception_error_result.model_dump(by_alias=True)),
5064 global_context=global_context,
5065 local_contexts=context_table,
5066 violations_as_exceptions=False, # Don't let plugin errors mask the original exception
5067 )
5068 except Exception as plugin_exc:
5069 logger.debug("Failed to invoke post-invoke plugins on exception: %s", plugin_exc)
5071 # Retry if the plugin requested a delayed retry and we haven't hit the ceiling.
5072 # Same counting convention as the success path: retry_attempt is 0-based,
5073 # so this allows up to max_tool_retries retry attempts beyond the original call.
5074 if exc_post_result is not None and exc_post_result.retry_delay_ms > 0 and retry_attempt < settings.max_tool_retries:
5075 return await self._retry_tool_invocation(
5076 exc_post_result.retry_delay_ms,
5077 retry_attempt,
5078 name,
5079 arguments,
5080 request_headers,
5081 app_user_email,
5082 user_email,
5083 token_teams,
5084 server_id,
5085 context_table,
5086 global_context,
5087 meta_data,
5088 skip_pre_invoke,
5089 "exception",
5090 )
5092 raise ToolInvocationError(f"Tool invocation failed: {error_message}")
5093 finally:
5094 # Calculate duration
5095 duration_ms = (time.monotonic() - start_time) * 1000
5097 # End database span for observability_spans table
5098 # Use commit=False since fresh_db_session() handles commits on exit
5099 if db_span_id and observability_service and not db_span_ended:
5100 try:
5101 with fresh_db_session() as span_db:
5102 observability_service.end_span(
5103 db=span_db,
5104 span_id=db_span_id,
5105 status="ok" if success else "error",
5106 status_message=error_message if error_message else None,
5107 attributes={
5108 "success": success,
5109 "duration_ms": duration_ms,
5110 },
5111 commit=False,
5112 )
5113 db_span_ended = True
5114 logger.debug(f"✓ Ended tool.invoke span: {db_span_id}")
5115 except Exception as e:
5116 logger.warning(f"Failed to end observability span for tool invocation: {e}")
5118 # Add final span attributes for OpenTelemetry
5119 if span:
5120 set_span_attribute(span, "success", success)
5121 set_span_attribute(span, "duration.ms", duration_ms)
5122 if success and tool_result and is_output_capture_enabled("tool.invoke"):
5123 set_span_attribute(span, "langfuse.observation.output", serialize_trace_payload(tool_result))
5125 # ═══════════════════════════════════════════════════════════════════════════
5126 # PHASE 4: Record metrics via buffered service (batches writes for performance)
5127 # ═══════════════════════════════════════════════════════════════════════════
5128 # Only record metrics if tool_id is valid (skip for direct_proxy mode)
5129 if tool_id:
5130 try:
5131 metrics_buffer.record_tool_metric(
5132 tool_id=tool_id,
5133 start_time=start_time,
5134 success=success,
5135 error_message=error_message,
5136 )
5137 except Exception as metric_error:
5138 logger.warning(f"Failed to record tool metric: {metric_error}")
5140 # Record server metrics ONLY when invoked through a specific virtual server
5141 # When server_id is provided, it means the tool was called via a virtual server endpoint
5142 # Direct tool calls via /rpc should NOT populate server metrics
5143 if tool_id and server_id:
5144 try:
5145 # Record server metric only for the specific virtual server being accessed
5146 metrics_buffer.record_server_metric(
5147 server_id=server_id,
5148 start_time=start_time,
5149 success=success,
5150 error_message=error_message,
5151 )
5152 except Exception as metric_error:
5153 logger.warning(f"Failed to record server metric: {metric_error}")
5155 # Log structured message with performance tracking (using local variables)
5156 if success:
5157 structured_logger.info(
5158 f"Tool '{name}' invoked successfully",
5159 user_id=app_user_email,
5160 resource_type="tool",
5161 resource_id=tool_id,
5162 resource_action="invoke",
5163 duration_ms=duration_ms,
5164 custom_fields={"tool_name": name, "integration_type": tool_integration_type, "arguments_count": len(arguments) if arguments else 0},
5165 )
5166 else:
5167 structured_logger.error(
5168 f"Tool '{name}' invocation failed",
5169 error=Exception(error_message) if error_message else None,
5170 user_id=app_user_email,
5171 resource_type="tool",
5172 resource_id=tool_id,
5173 resource_action="invoke",
5174 duration_ms=duration_ms,
5175 custom_fields={"tool_name": name, "integration_type": tool_integration_type, "error_message": error_message},
5176 )
5178 # Track performance with threshold checking
5179 with perf_tracker.track_operation("tool_invocation", name):
5180 pass # Duration already captured above
@staticmethod
def _check_tool_name_conflict(db: Session, custom_name: str, visibility: str, tool_id: str, team_id: Optional[str] = None, owner_email: Optional[str] = None) -> None:
    """Reject a custom name that is already taken by a different tool in the same visibility scope.

    The lookup is scoped by *visibility*: ``public`` names are compared against
    all public tools, ``team`` names against tools of the same team, and
    ``private`` names against tools of the same owner. The tool identified by
    *tool_id* is always excluded from the search. If the scoping key for the
    requested visibility is missing (or the visibility is not one of the three
    values above), a warning is logged and no check is performed.

    Args:
        db: The SQLAlchemy database session used for the lookup.
        custom_name: The custom name that must be unique within the scope.
        visibility: Target visibility scope (``public``, ``team``, or ``private``).
        tool_id: ID of the tool being updated; it never conflicts with itself.
        team_id: Team whose tools are searched when *visibility* is ``team``.
        owner_email: Owner whose tools are searched when *visibility* is ``private``.

    Raises:
        ToolNameConflictError: If another tool in the target scope already uses *custom_name*.
    """
    # Build only the clauses that differ between scopes; the name match and the
    # self-exclusion are shared by all of them and added once below.
    if visibility == "public":
        scope_clauses = [DbTool.visibility == "public"]
    elif visibility == "team" and team_id:
        scope_clauses = [DbTool.visibility == "team", DbTool.team_id == team_id]
    elif visibility == "private" and owner_email:
        scope_clauses = [DbTool.visibility == "private", DbTool.owner_email == owner_email]
    else:
        # Without a scoping key there is nothing meaningful to compare against.
        missing_field = "team_id" if visibility == "team" else "owner_email"
        logger.warning("Skipping conflict check for tool %s: visibility=%r requires %s but none provided", tool_id, visibility, missing_field)
        return

    conflicting_tool = get_for_update(
        db,
        DbTool,
        where=and_(
            DbTool.custom_name == custom_name,
            *scope_clauses,
            DbTool.id != tool_id,
        ),
    )
    if not conflicting_tool:
        return

    raise ToolNameConflictError(
        conflicting_tool.custom_name,
        enabled=conflicting_tool.enabled,
        tool_id=conflicting_tool.id,
        visibility=conflicting_tool.visibility,
    )
5235 async def update_tool(
5236 self,
5237 db: Session,
5238 tool_id: str,
5239 tool_update: ToolUpdate,
5240 modified_by: Optional[str] = None,
5241 modified_from_ip: Optional[str] = None,
5242 modified_via: Optional[str] = None,
5243 modified_user_agent: Optional[str] = None,
5244 user_email: Optional[str] = None,
5245 ) -> ToolRead:
5246 """
5247 Update an existing tool.
5249 Args:
5250 db (Session): The SQLAlchemy database session.
5251 tool_id (str): The unique identifier of the tool.
5252 tool_update (ToolUpdate): Tool update schema with new data.
5253 modified_by (Optional[str]): Username who modified this tool.
5254 modified_from_ip (Optional[str]): IP address of modifier.
5255 modified_via (Optional[str]): Modification method (ui, api).
5256 modified_user_agent (Optional[str]): User agent of modification request.
5257 user_email (Optional[str]): Email of user performing update (for ownership check).
5259 Returns:
5260 The updated ToolRead object.
5262 Raises:
5263 ToolNotFoundError: If the tool is not found.
5264 PermissionError: If user doesn't own the tool.
5265 IntegrityError: If there is a database integrity error.
5266 ToolNameConflictError: If a tool with the same name already exists.
5267 ToolError: For other update errors.
5269 Examples:
5270 >>> from mcpgateway.services.tool_service import ToolService
5271 >>> from unittest.mock import MagicMock, AsyncMock
5272 >>> from mcpgateway.schemas import ToolRead
5273 >>> service = ToolService()
5274 >>> db = MagicMock()
5275 >>> tool = MagicMock()
5276 >>> db.get.return_value = tool
5277 >>> db.commit = MagicMock()
5278 >>> db.refresh = MagicMock()
5279 >>> db.execute.return_value.scalar_one_or_none.return_value = None
5280 >>> service._notify_tool_updated = AsyncMock()
5281 >>> service.convert_tool_to_read = MagicMock(return_value='tool_read')
5282 >>> ToolRead.model_validate = MagicMock(return_value='tool_read')
5283 >>> import asyncio
5284 >>> asyncio.run(service.update_tool(db, 'tool_id', MagicMock()))
5285 'tool_read'
5286 """
5287 try:
5288 tool = get_for_update(db, DbTool, tool_id)
5290 if not tool:
5291 raise ToolNotFoundError(f"Tool not found: {tool_id}")
5293 old_tool_name = tool.name
5294 old_gateway_id = tool.gateway_id
5296 # Check ownership if user_email provided
5297 if user_email:
5298 # First-Party
5299 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel
5301 permission_service = PermissionService(db)
5302 if not await permission_service.check_resource_ownership(user_email, tool):
5303 raise PermissionError("Only the owner can update this tool")
5305 # Track whether a name change occurred (before tool.name is mutated)
5306 name_is_changing = bool(tool_update.name and tool_update.name != tool.name)
5308 # Check for name change and ensure uniqueness
5309 if name_is_changing:
5310 # Always derive ownership fields from the DB record — never trust client-provided team_id/owner_email
5311 tool_visibility_ref = tool.visibility if tool_update.visibility is None else tool_update.visibility.lower()
5312 if tool_update.custom_name is not None:
5313 custom_name_ref = tool_update.custom_name
5314 elif tool.name == tool.custom_name:
5315 custom_name_ref = tool_update.name # custom_name will track the rename
5316 else:
5317 custom_name_ref = tool.custom_name # custom_name stays unchanged
5318 self._check_tool_name_conflict(db, custom_name_ref, tool_visibility_ref, tool.id, team_id=tool.team_id, owner_email=tool.owner_email)
5319 if tool_update.custom_name is None and tool.name == tool.custom_name:
5320 tool.custom_name = tool_update.name
5321 tool.name = tool_update.name
5323 # Check for conflicts when visibility changes without a name change
5324 if tool_update.visibility is not None and tool_update.visibility.lower() != tool.visibility and not name_is_changing:
5325 new_visibility = tool_update.visibility.lower()
5326 self._check_tool_name_conflict(db, tool.custom_name, new_visibility, tool.id, team_id=tool.team_id, owner_email=tool.owner_email)
5328 if tool_update.custom_name is not None:
5329 tool.custom_name = tool_update.custom_name
5330 if tool_update.displayName is not None:
5331 tool.display_name = tool_update.displayName
5332 if tool_update.url is not None:
5333 tool.url = str(tool_update.url)
5334 if tool_update.description is not None:
5335 tool.description = tool_update.description
5336 if tool_update.title is not None:
5337 tool.title = tool_update.title
5338 if tool_update.integration_type is not None:
5339 tool.integration_type = tool_update.integration_type
5340 if tool_update.request_type is not None:
5341 tool.request_type = tool_update.request_type
5342 if tool_update.headers is not None:
5343 tool.headers = _protect_tool_headers_for_storage(tool_update.headers, existing_headers=tool.headers)
5344 if tool_update.input_schema is not None:
5345 tool.input_schema = tool_update.input_schema
5346 if tool_update.output_schema is not None:
5347 tool.output_schema = tool_update.output_schema
5348 if tool_update.annotations is not None:
5349 tool.annotations = tool_update.annotations
5350 if tool_update.jsonpath_filter is not None:
5351 tool.jsonpath_filter = tool_update.jsonpath_filter
5352 if tool_update.visibility is not None:
5353 tool.visibility = tool_update.visibility
5355 if tool_update.auth is not None:
5356 if tool_update.auth.auth_type is not None:
5357 tool.auth_type = tool_update.auth.auth_type
5358 if tool_update.auth.auth_value is not None:
5359 tool.auth_value = tool_update.auth.auth_value
5361 # Update tags if provided
5362 if tool_update.tags is not None:
5363 tool.tags = tool_update.tags
5365 # Update modification metadata
5366 if modified_by is not None:
5367 tool.modified_by = modified_by
5368 if modified_from_ip is not None:
5369 tool.modified_from_ip = modified_from_ip
5370 if modified_via is not None:
5371 tool.modified_via = modified_via
5372 if modified_user_agent is not None:
5373 tool.modified_user_agent = modified_user_agent
5375 # Increment version
5376 if hasattr(tool, "version") and tool.version is not None:
5377 tool.version += 1
5378 else:
5379 tool.version = 1
5380 logger.info(f"Update tool: {tool.name} (output_schema: {tool.output_schema})")
5382 tool.updated_at = datetime.now(timezone.utc)
5383 db.commit()
5384 db.refresh(tool)
5385 await self._notify_tool_updated(tool)
5386 logger.info(f"Updated tool: {tool.name}")
5388 # Structured logging: Audit trail for tool update
5389 changes = []
5390 if tool_update.name:
5391 changes.append(f"name: {tool_update.name}")
5392 if tool_update.visibility:
5393 changes.append(f"visibility: {tool_update.visibility}")
5394 if tool_update.description:
5395 changes.append("description updated")
5397 audit_trail.log_action(
5398 user_id=user_email or modified_by or "system",
5399 action="update_tool",
5400 resource_type="tool",
5401 resource_id=tool.id,
5402 resource_name=tool.name,
5403 user_email=user_email,
5404 team_id=tool.team_id,
5405 client_ip=modified_from_ip,
5406 user_agent=modified_user_agent,
5407 new_values={
5408 "name": tool.name,
5409 "display_name": tool.display_name,
5410 "version": tool.version,
5411 },
5412 context={
5413 "modified_via": modified_via,
5414 "changes": ", ".join(changes) if changes else "metadata only",
5415 },
5416 db=db,
5417 )
5419 # Structured logging: Log successful tool update
5420 structured_logger.log(
5421 level="INFO",
5422 message="Tool updated successfully",
5423 event_type="tool_updated",
5424 component="tool_service",
5425 user_id=modified_by,
5426 user_email=user_email,
5427 team_id=tool.team_id,
5428 resource_type="tool",
5429 resource_id=tool.id,
5430 custom_fields={
5431 "tool_name": tool.name,
5432 "version": tool.version,
5433 },
5434 )
5436 # Invalidate cache after successful update
5437 cache = _get_registry_cache()
5438 await cache.invalidate_tools()
5439 tool_lookup_cache = _get_tool_lookup_cache()
5440 await tool_lookup_cache.invalidate(old_tool_name, gateway_id=str(old_gateway_id) if old_gateway_id else None)
5441 await tool_lookup_cache.invalidate(tool.name, gateway_id=str(tool.gateway_id) if tool.gateway_id else None)
5442 # Also invalidate tags cache since tool tags may have changed
5443 # First-Party
5444 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel
5446 await admin_stats_cache.invalidate_tags()
5448 return self.convert_tool_to_read(tool, requesting_user_email=getattr(tool, "owner_email", None))
5449 except PermissionError as pe:
5450 db.rollback()
5452 # Structured logging: Log permission error
5453 structured_logger.log(
5454 level="WARNING",
5455 message="Tool update failed due to permission error",
5456 event_type="tool_update_permission_denied",
5457 component="tool_service",
5458 user_email=user_email,
5459 resource_type="tool",
5460 resource_id=tool_id,
5461 error=pe,
5462 )
5463 raise
5464 except IntegrityError as ie:
5465 db.rollback()
5466 logger.error(f"IntegrityError during tool update: {ie}")
5468 # Structured logging: Log database integrity error
5469 structured_logger.log(
5470 level="ERROR",
5471 message="Tool update failed due to database integrity error",
5472 event_type="tool_update_failed",
5473 component="tool_service",
5474 user_id=modified_by,
5475 user_email=user_email,
5476 resource_type="tool",
5477 resource_id=tool_id,
5478 error=ie,
5479 )
5480 raise ie
5481 except ToolNotFoundError as tnfe:
5482 db.rollback()
5483 logger.error(f"Tool not found during update: {tnfe}")
5485 # Structured logging: Log not found error
5486 structured_logger.log(
5487 level="ERROR",
5488 message="Tool update failed - tool not found",
5489 event_type="tool_not_found",
5490 component="tool_service",
5491 user_email=user_email,
5492 resource_type="tool",
5493 resource_id=tool_id,
5494 error=tnfe,
5495 )
5496 raise tnfe
5497 except ToolNameConflictError as tnce:
5498 db.rollback()
5499 logger.error(f"Tool name conflict during update: {tnce}")
5501 # Structured logging: Log name conflict error
5502 structured_logger.log(
5503 level="WARNING",
5504 message="Tool update failed due to name conflict",
5505 event_type="tool_name_conflict",
5506 component="tool_service",
5507 user_id=modified_by,
5508 user_email=user_email,
5509 resource_type="tool",
5510 resource_id=tool_id,
5511 error=tnce,
5512 )
5513 raise tnce
5514 except Exception as ex:
5515 db.rollback()
5517 # Structured logging: Log generic tool update failure
5518 structured_logger.log(
5519 level="ERROR",
5520 message="Tool update failed",
5521 event_type="tool_update_failed",
5522 component="tool_service",
5523 user_id=modified_by,
5524 user_email=user_email,
5525 resource_type="tool",
5526 resource_id=tool_id,
5527 error=ex,
5528 )
5529 raise ToolError(f"Failed to update tool: {str(ex)}")
async def _notify_tool_updated(self, tool: DbTool) -> None:
    """
    Publish a ``tool_updated`` event describing the given tool.

    Args:
        tool: Tool updated; its id, name, url, description and enabled flag form the event payload.
    """
    payload = {
        "id": tool.id,
        "name": tool.name,
        "url": tool.url,
        "description": tool.description,
        "enabled": tool.enabled,
    }
    emitted_at = datetime.now(timezone.utc).isoformat()
    await self._publish_event({"type": "tool_updated", "data": payload, "timestamp": emitted_at})
async def _notify_tool_activated(self, tool: DbTool) -> None:
    """
    Publish a ``tool_activated`` event for the given tool.

    Args:
        tool: Tool activated; its id, name, enabled and reachable values are sent as the payload.
    """
    status = {
        "id": tool.id,
        "name": tool.name,
        "enabled": tool.enabled,
        "reachable": tool.reachable,
    }
    emitted_at = datetime.now(timezone.utc).isoformat()
    await self._publish_event({"type": "tool_activated", "data": status, "timestamp": emitted_at})
async def _notify_tool_deactivated(self, tool: DbTool) -> None:
    """
    Publish a ``tool_deactivated`` event for the given tool.

    Args:
        tool: Tool deactivated; its id, name, enabled and reachable values are sent as the payload.
    """
    status = {
        "id": tool.id,
        "name": tool.name,
        "enabled": tool.enabled,
        "reachable": tool.reachable,
    }
    emitted_at = datetime.now(timezone.utc).isoformat()
    await self._publish_event({"type": "tool_deactivated", "data": status, "timestamp": emitted_at})
async def _notify_tool_offline(self, tool: DbTool) -> None:
    """
    Publish a ``tool_offline`` event for the given tool.

    The payload always reports ``enabled=True`` and ``reachable=False``; only the
    id and name are read from the tool itself.

    Args:
        tool: Tool database object
    """
    payload = {"id": tool.id, "name": tool.name, "enabled": True, "reachable": False}
    emitted_at = datetime.now(timezone.utc).isoformat()
    await self._publish_event({"type": "tool_offline", "data": payload, "timestamp": emitted_at})
5592 async def _notify_tool_deleted(self, tool_info: Dict[str, Any]) -> None:
5593 """
5594 Notify subscribers of tool deletion.
5596 Args:
5597 tool_info: Dictionary on tool deleted
5598 """
5599 event = {
5600 "type": "tool_deleted",
5601 "data": tool_info,
5602 "timestamp": datetime.now(timezone.utc).isoformat(),
5603 }
5604 await self._publish_event(event)
async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
    """Relay tool events from the EventService subscription.

    Yields:
        Tool event messages, in the order the EventService produces them.
    """
    stream = self._event_service.subscribe_events()
    async for message in stream:
        yield message
async def _notify_tool_added(self, tool: DbTool) -> None:
    """
    Publish a ``tool_added`` event for the given tool.

    Args:
        tool: Tool added; its id, name, url, description and enabled flag form the event payload.
    """
    payload = {"id": tool.id, "name": tool.name, "url": tool.url, "description": tool.description, "enabled": tool.enabled}
    emitted_at = datetime.now(timezone.utc).isoformat()
    await self._publish_event({"type": "tool_added", "data": payload, "timestamp": emitted_at})
async def _notify_tool_removed(self, tool: DbTool) -> None:
    """
    Publish a ``tool_removed`` event (soft delete/deactivation) for the given tool.

    Args:
        tool: Tool removed; its id, name and enabled flag are sent as the payload.
    """
    payload = {
        "id": tool.id,
        "name": tool.name,
        "enabled": tool.enabled,
    }
    emitted_at = datetime.now(timezone.utc).isoformat()
    await self._publish_event({"type": "tool_removed", "data": payload, "timestamp": emitted_at})
5649 async def _publish_event(self, event: Dict[str, Any]) -> None:
5650 """
5651 Publish event to all subscribers via the EventService.
5653 Args:
5654 event: Event to publish
5655 """
5656 await self._event_service.publish_event(event)
async def _validate_tool_url(self, url: str) -> None:
    """Validate tool URL is accessible.

    Issues a GET request with the service's HTTP client and requires a
    non-error status (``raise_for_status``).

    Args:
        url: URL to validate.

    Raises:
        ToolValidationError: If the request fails or returns an error status.
            The underlying exception is attached as ``__cause__``.
    """
    try:
        response = await self._http_client.get(url)
        response.raise_for_status()
    except Exception as e:
        # Chain explicitly so the original transport/status error stays visible in tracebacks.
        raise ToolValidationError(f"Failed to validate tool URL: {str(e)}") from e
async def _check_tool_health(self, tool: DbTool) -> bool:
    """Probe the tool's URL with a GET request and report whether it succeeded.

    Args:
        tool: Tool to check.

    Returns:
        The response's ``is_success`` value, or False if the probe raised any exception.
    """
    healthy = False
    try:
        probe = await self._http_client.get(tool.url)
        healthy = probe.is_success
    except Exception:
        # Best-effort check: any failure simply counts as unhealthy.
        healthy = False
    return healthy
5688 # async def event_generator(self) -> AsyncGenerator[Dict[str, Any], None]:
5689 # """Generate tool events for SSE.
5691 # Yields:
5692 # Tool events.
5693 # """
5694 # queue: asyncio.Queue = asyncio.Queue()
5695 # self._event_subscribers.append(queue)
5696 # try:
5697 # while True:
5698 # event = await queue.get()
5699 # yield event
5700 # finally:
5701 # self._event_subscribers.remove(queue)
5703 # --- Metrics ---
async def aggregate_metrics(self, db: Session) -> ToolMetrics:
    """
    Aggregate invocation metrics across all tools.

    Returns the cached "tools" entry from the metrics cache when caching is
    enabled and an entry exists. Otherwise the combined raw + hourly-rollup
    aggregation is computed, stored in the cache (when enabled) and returned.

    Args:
        db: Database session

    Returns:
        ToolMetrics: Aggregated metrics computed from raw ToolMetric + ToolMetricsHourly.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> service = ToolService()
        >>> # Method exists and is callable
        >>> callable(service.aggregate_metrics)
        True
    """
    # First-Party
    from mcpgateway.cache.metrics_cache import is_cache_enabled, metrics_cache  # pylint: disable=import-outside-toplevel

    # Serve from cache when possible
    if is_cache_enabled():
        snapshot = metrics_cache.get("tools")
        if snapshot is not None:
            return ToolMetrics(**snapshot)

    # Cache miss (or cache disabled): run the combined raw + rollup query
    # First-Party
    from mcpgateway.services.metrics_query_service import aggregate_metrics_combined  # pylint: disable=import-outside-toplevel

    combined = aggregate_metrics_combined(db, "tool")
    copied_fields = (
        "total_executions",
        "successful_executions",
        "failed_executions",
        "failure_rate",
        "min_response_time",
        "max_response_time",
        "avg_response_time",
        "last_execution_time",
    )
    aggregated = ToolMetrics(**{field: getattr(combined, field) for field in copied_fields})

    # Store as a plain dict so the cached value stays serialization-friendly
    if is_cache_enabled():
        metrics_cache.set("tools", aggregated.model_dump())

    return aggregated
async def reset_metrics(self, db: Session, tool_id: Optional[int] = None) -> None:
    """
    Reset tool metrics by deleting raw and hourly rollup records.

    When ``tool_id`` is None, metrics for every tool are deleted. Any other
    value (including falsy ones such as ``0`` or ``""``) restricts the delete
    to rows whose ``tool_id`` matches, so a falsy identifier can no longer
    wipe all metrics by accident.

    Args:
        db: Database session
        tool_id: Optional tool ID to reset metrics for a specific tool

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from unittest.mock import MagicMock
        >>> service = ToolService()
        >>> db = MagicMock()
        >>> db.execute = MagicMock()
        >>> db.commit = MagicMock()
        >>> import asyncio
        >>> asyncio.run(service.reset_metrics(db))
    """
    raw_stmt = delete(ToolMetric)
    hourly_stmt = delete(ToolMetricsHourly)
    # Compare against None explicitly: a truthiness test would treat 0 / "" as "delete everything".
    if tool_id is not None:
        raw_stmt = raw_stmt.where(ToolMetric.tool_id == tool_id)
        hourly_stmt = hourly_stmt.where(ToolMetricsHourly.tool_id == tool_id)

    db.execute(raw_stmt)
    db.execute(hourly_stmt)
    db.commit()

    # Invalidate metrics cache
    # First-Party
    from mcpgateway.cache.metrics_cache import metrics_cache  # pylint: disable=import-outside-toplevel

    metrics_cache.invalidate("tools")
    metrics_cache.invalidate_prefix("top_tools:")
async def create_tool_from_a2a_agent(
    self,
    db: Session,
    agent: DbA2AAgent,
    created_by: Optional[str] = None,
    created_from_ip: Optional[str] = None,
    created_via: Optional[str] = None,
    created_user_agent: Optional[str] = None,
) -> DbTool:
    """Create (or reuse) the tool entry that exposes an A2A agent as a tool.

    The tool is named ``a2a_<agent.slug>``. If a tool with that original name
    already exists it is returned as-is; otherwise a new tool is registered via
    ``register_tool`` and the resulting database object is returned.

    Args:
        db: Database session.
        agent: A2A agent to create tool from.
        created_by: Username who created this tool.
        created_from_ip: IP address of creator.
        created_via: Creation method.
        created_user_agent: User agent of creation request.

    Returns:
        The created tool database object.

    Raises:
        ToolNameConflictError: If a tool with the same name already exists.
    """
    a2a_tool_name = f"a2a_{agent.slug}"

    # Reuse the tool when one is already registered under this original name
    lookup = select(DbTool).where(DbTool.original_name == a2a_tool_name)
    already_registered = db.execute(lookup).scalar_one_or_none()
    if already_registered:
        return already_registered

    logger.debug(f"agent.tags: {agent.tags} for agent: {agent.name} (ID: {agent.id})")

    def _tag_text(entry: Any) -> Any:
        """Reduce one agent tag (dict, object with ``label``, or anything else) to its label."""
        if isinstance(entry, dict):
            # Prefer 'label', fall back to 'id' or stringified dict
            return entry.get("label") or entry.get("id") or str(entry)
        if hasattr(entry, "label"):
            return entry.label
        return str(entry)

    # Normalized agent tags followed by the identifying A2A tags
    tool_tags: list[str] = [_tag_text(entry) for entry in agent.tags or []]
    tool_tags += ["a2a", "agent"]

    query_schema = {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "User query", "default": "Hello from ContextForge Admin UI test!"},
        },
        "required": ["query"],
    }
    a2a_annotations = {
        "title": f"A2A Agent: {agent.name}",
        "a2a_agent_id": agent.id,
        "a2a_agent_type": agent.agent_type,
    }

    tool_data = ToolCreate(
        name=a2a_tool_name,
        displayName=generate_display_name(agent.name),
        url=agent.endpoint_url,
        description=f"A2A Agent: {agent.description or agent.name}",
        integration_type="A2A",  # Special integration type for A2A agents
        request_type="POST",
        input_schema=query_schema,
        allow_auto=True,
        annotations=a2a_annotations,
        auth_type=agent.auth_type,
        auth_value=agent.auth_value,
        tags=tool_tags,
    )

    # Fall back to "public" when the agent has no visibility set, so the
    # A2A tool shows up in the Global Tools Tab
    effective_visibility = agent.visibility or "public"

    registered = await self.register_tool(
        db,
        tool_data,
        created_by=created_by,
        created_from_ip=created_from_ip,
        created_via=created_via or "a2a_integration",
        created_user_agent=created_user_agent,
        team_id=agent.team_id,
        owner_email=agent.owner_email,
        visibility=effective_visibility,
    )

    # Callers need the DbTool row (not the read schema) for relationship assignment
    return db.get(DbTool, registered.id)
async def update_tool_from_a2a_agent(
    self,
    db: Session,
    agent: DbA2AAgent,
    modified_by: Optional[str] = None,
    modified_from_ip: Optional[str] = None,
    modified_via: Optional[str] = None,
    modified_user_agent: Optional[str] = None,
) -> Optional[ToolRead]:
    """Synchronise the tool linked to an A2A agent with the agent's current state.

    If the agent has no ``tool_id`` nothing happens. If the referenced tool no
    longer exists, the agent's ``tool_id`` is cleared and committed. Otherwise
    the tool is updated through ``update_tool`` with the agent's name, URL,
    description, auth and tags, keeping the tool's existing visibility.

    Args:
        db: Database session.
        agent: Updated A2A agent.
        modified_by: Username who modified this tool.
        modified_from_ip: IP address of modifier.
        modified_via: Modification method.
        modified_user_agent: User agent of modification request.

    Returns:
        The updated tool, or None if no associated tool exists.
    """
    linked_tool_id = agent.tool_id
    if not linked_tool_id:
        logger.debug(f"No tool_id found for A2A agent {agent.id}, skipping tool update")
        return None

    tool = db.get(DbTool, linked_tool_id)
    if not tool:
        # Dangling reference: drop it so later syncs skip the lookup
        logger.warning(f"Tool {agent.tool_id} not found for A2A agent {agent.id}, resetting tool_id")
        agent.tool_id = None
        db.commit()
        return None

    def _tag_text(entry: Any) -> Any:
        """Reduce one agent tag (dict, object with ``label``, or anything else) to its label."""
        if isinstance(entry, dict):
            # Prefer 'label', fall back to 'id' or stringified dict
            return entry.get("label") or entry.get("id") or str(entry)
        if hasattr(entry, "label"):
            return entry.label
        return str(entry)

    # Normalized agent tags followed by the identifying A2A tags
    tool_tags: list[str] = [_tag_text(entry) for entry in agent.tags or []]
    tool_tags += ["a2a", "agent"]

    # Build the update from the agent's current state.
    # IMPORTANT: the tool's existing visibility is passed back explicitly so a
    # private/team tool is not unintentionally made public (ToolUpdate defaults to "public").
    # Note: team_id is not a field on ToolUpdate schema, so team assignment is preserved
    # implicitly by not changing visibility (team tools stay team-scoped)
    synced_name = f"a2a_{agent.slug}"
    tool_update = ToolUpdate(
        name=synced_name,
        custom_name=synced_name,  # Also set custom_name to ensure name update works
        displayName=generate_display_name(agent.name),
        url=agent.endpoint_url,
        description=f"A2A Agent: {agent.description or agent.name}",
        auth=AuthenticationValues(auth_type=agent.auth_type, auth_value=agent.auth_value) if agent.auth_type else None,
        tags=tool_tags,
        visibility=tool.visibility,  # Preserve existing visibility
    )

    updated = await self.update_tool(
        db=db,
        tool_id=tool.id,
        tool_update=tool_update,
        modified_by=modified_by,
        modified_from_ip=modified_from_ip,
        modified_via=modified_via or "a2a_sync",
        modified_user_agent=modified_user_agent,
    )
    return updated
async def delete_tool_from_a2a_agent(self, db: Session, agent: DbA2AAgent, user_email: Optional[str] = None, purge_metrics: bool = False) -> None:
    """Delete the tool associated with an A2A agent when the agent is deleted.

    Does nothing (beyond logging) when the agent has no ``tool_id`` or the
    referenced tool row cannot be found.

    Args:
        db: Database session.
        agent: The A2A agent being deleted.
        user_email: Email of user performing delete (for ownership check).
        purge_metrics: If True, delete raw + rollup metrics for this tool.
    """
    # Use the tool_id from the agent for efficient lookup
    if not agent.tool_id:
        logger.debug(f"No tool_id found for A2A agent {agent.id}, skipping tool deletion")
        return

    tool = db.get(DbTool, agent.tool_id)
    if not tool:
        logger.warning(f"Tool {agent.tool_id} not found for A2A agent {agent.id}")
        return

    # Capture the identifier up front: after delete_tool() the ORM instance refers
    # to a removed row, so the success log must not read attributes from it.
    deleted_tool_id = tool.id

    # Delete the tool
    await self.delete_tool(db=db, tool_id=deleted_tool_id, user_email=user_email, purge_metrics=purge_metrics)
    logger.info(f"Deleted tool {deleted_tool_id} associated with A2A agent {agent.id}")
async def _invoke_a2a_tool(self, db: Session, tool: DbTool, arguments: Dict[str, Any]) -> ToolResult:
    """Invoke an A2A agent through its corresponding tool.

    Resolves the agent referenced by the tool's ``a2a_agent_id`` annotation,
    detaches it from the session, commits and closes the session, and then
    forwards the call to ``_call_a2a_agent``. Any exception raised by that call
    is converted into an error ``ToolResult`` rather than propagated.

    Args:
        db: Database session. It is committed and closed here before the outbound call.
        tool: The tool record that represents the A2A agent.
        arguments: Tool arguments.

    Returns:
        Tool result from A2A agent invocation.

    Raises:
        ToolNotFoundError: If the tool carries no agent ID annotation, or the A2A agent is missing or disabled.
    """
    # Extract A2A agent ID from tool annotations. Guard against annotations being
    # None so the caller gets ToolNotFoundError instead of an AttributeError.
    agent_id = (tool.annotations or {}).get("a2a_agent_id")
    if not agent_id:
        raise ToolNotFoundError(f"A2A tool '{tool.name}' missing agent ID in annotations")

    # Get the A2A agent
    agent_query = select(DbA2AAgent).where(DbA2AAgent.id == agent_id)
    agent = db.execute(agent_query).scalar_one_or_none()

    if not agent:
        raise ToolNotFoundError(f"A2A agent not found for tool '{tool.name}' (agent ID: {agent_id})")

    if not agent.enabled:
        raise ToolNotFoundError(f"A2A agent '{agent.name}' is disabled")

    # Force-load all attributes needed by _call_a2a_agent before detaching
    # (accessing them ensures they're loaded into the object's __dict__)
    _ = (agent.name, agent.endpoint_url, agent.agent_type, agent.protocol_version, agent.auth_type, agent.auth_value, agent.auth_query_params)

    # Detach agent from session so its loaded data remains accessible after close
    db.expunge(agent)

    # CRITICAL: Release DB connection back to pool BEFORE making HTTP calls
    # This prevents "idle in transaction" connection pool exhaustion under load
    db.commit()
    db.close()

    try:
        # Make the A2A agent call (agent is now detached but data is loaded)
        response_data = await self._call_a2a_agent(agent, arguments)

        # Convert A2A response to MCP ToolResult format: unwrap a top-level
        # "response" key when present, and JSON-encode anything that is not already text.
        if isinstance(response_data, dict) and "response" in response_data:
            payload = response_data["response"]
        else:
            payload = response_data
        text = payload if isinstance(payload, str) else orjson.dumps(payload).decode()
        result = ToolResult(content=[TextContent(type="text", text=text)], is_error=False)

    except Exception as e:
        # Failures are reported to the caller as an error result, not raised.
        error_message = str(e)
        content = [TextContent(type="text", text=f"A2A agent error: {error_message}")]
        result = ToolResult(content=content, is_error=True)

    # Note: Metrics are recorded by the calling invoke_tool method, not here
    return result
async def _call_a2a_agent(self, agent: DbA2AAgent, parameters: Dict[str, Any]):
    """Call an A2A agent directly.

    Builds either a JSON-RPC 2.0 request (agent type ``generic``/``jsonrpc``, or
    an endpoint URL ending in ``/``) or the custom flat request format, applies
    the agent's configured authentication, and POSTs the request with the shared
    HTTP client.

    Args:
        agent: The A2A agent to call.
        parameters: Parameters for the interaction.

    Returns:
        Response from the A2A agent (decoded JSON body of an HTTP 200 reply).

    Raises:
        ToolInvocationError: If authentication decryption fails.
        Exception: If the call fails (any non-200 HTTP status).
    """
    logger.info(f"Calling A2A agent '{agent.name}' at {agent.endpoint_url} with arguments: {parameters}")

    # Build request data based on agent type
    if agent.agent_type in ["generic", "jsonrpc"] or agent.endpoint_url.endswith("/"):
        # JSONRPC agents: Convert flat query to nested message structure
        if isinstance(parameters, dict) and "query" in parameters and isinstance(parameters["query"], str):
            # Build the nested message object for JSONRPC protocol
            message_id = f"admin-test-{int(time.time())}"
            # A2A v0.3.x: message.parts use "kind" (not "type").
            params = {
                "message": {
                    "kind": "message",
                    "messageId": message_id,
                    "role": "user",
                    "parts": [{"kind": "text", "text": parameters["query"]}],
                }
            }
            method = parameters.get("method", "message/send")
        elif isinstance(parameters, dict):
            # Already in correct format or unknown, pass through
            params = parameters.get("params", parameters)
            method = parameters.get("method", "message/send")
        else:
            # Non-dict payloads have no "params"/"method" keys to read; forward
            # them unchanged with the default method instead of failing on .get().
            params = parameters
            method = "message/send"

        request_data = {"jsonrpc": "2.0", "method": method, "params": params, "id": 1}
        logger.info(f"invoke tool JSONRPC request_data prepared: {request_data}")
    else:
        # Custom agents: Pass parameters directly without JSONRPC message conversion
        # Custom agents expect flat fields like {"query": "...", "message": "..."}
        params = parameters if isinstance(parameters, dict) else {}
        logger.info(f"invoke tool Using custom A2A format for A2A agent '{params}'")
        request_data = {"interaction_type": params.get("interaction_type", "query"), "parameters": params, "protocol_version": agent.protocol_version}
        logger.info(f"invoke tool request_data prepared: {request_data}")

    # Make HTTP request to the agent endpoint using shared HTTP client
    # First-Party
    from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

    client = await get_http_client()
    headers = {"Content-Type": "application/json"}

    # Determine the endpoint URL (may be modified for query_param auth)
    endpoint_url = agent.endpoint_url

    # Add authentication if configured
    if agent.auth_type in ("api_key", "basic", "bearer", "authheaders") and agent.auth_value:
        # Decrypt auth_value and extract headers (matches a2a_service.py pattern)
        if isinstance(agent.auth_value, str):
            try:
                auth_headers = decode_auth(agent.auth_value)
                headers.update(auth_headers)
            except Exception as e:
                logger.error(f"Failed to decrypt authentication for A2A agent '{agent.name}': {e}")
                # Chain the cause so the decryption failure stays visible in tracebacks.
                raise ToolInvocationError(f"Failed to decrypt authentication for A2A agent '{agent.name}'") from e
        elif isinstance(agent.auth_value, dict):
            auth_headers = {str(k): str(v) for k, v in agent.auth_value.items()}
            headers.update(auth_headers)
    elif agent.auth_type == "query_param" and agent.auth_query_params:
        # Handle query parameter authentication (imports at top: decode_auth, apply_query_param_auth, sanitize_url_for_logging)
        auth_query_params_decrypted: dict[str, str] = {}
        for param_key, encrypted_value in agent.auth_query_params.items():
            if encrypted_value:
                try:
                    decrypted = decode_auth(encrypted_value)
                    auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
                except Exception:
                    # Best effort: a parameter that cannot be decrypted is skipped, not fatal.
                    logger.debug(f"Failed to decrypt query param for key '{param_key}'")
        if auth_query_params_decrypted:
            endpoint_url = apply_query_param_auth(endpoint_url, auth_query_params_decrypted)
            # Log sanitized URL to avoid credential leakage
            sanitized_url = sanitize_url_for_logging(endpoint_url, auth_query_params_decrypted)
            logger.debug(f"Applied query param auth to A2A agent endpoint: {sanitized_url}")

    http_response = await client.post(endpoint_url, json=request_data, headers=headers)

    if http_response.status_code == 200:
        return http_response.json()

    raise Exception(f"HTTP {http_response.status_code}: {http_response.text}")
# Lazy singleton - created on first access, not at module import time.
# This avoids instantiation when only exception classes are imported.
_tool_service_instance = None  # pylint: disable=invalid-name


def __getattr__(name: str):
    """Resolve the lazily created ``tool_service`` module attribute.

    Args:
        name: The attribute name being accessed.

    Returns:
        The tool_service singleton instance if name is "tool_service".

    Raises:
        AttributeError: If the attribute name is not "tool_service".
    """
    global _tool_service_instance  # pylint: disable=global-statement
    if name != "tool_service":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # First access builds the service; later accesses reuse the same instance.
    if _tool_service_instance is None:
        _tool_service_instance = ToolService()
    return _tool_service_instance