Coverage for mcpgateway/services/tool_service.py: 99%
1809 statements
« prev ^ index » next — coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/tool_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Tool Service Implementation.
8This module implements tool management and invocation according to the MCP specification.
9It handles:
10- Tool registration and validation
11- Tool invocation with schema validation
12- Tool federation across gateways
13- Event notifications for tool changes
14- Active/inactive tool management
15"""
17# Standard
18import asyncio
19import base64
20import binascii
21from datetime import datetime, timezone
22from functools import lru_cache
23import os
24import re
25import ssl
26import time
27from types import SimpleNamespace
28from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
29from urllib.parse import parse_qs, urlparse
30import uuid
32# Third-Party
33import httpx
34import jq
35import jsonschema
36from jsonschema import Draft4Validator, Draft6Validator, Draft7Validator, validators
37from mcp import ClientSession, types
38from mcp.client.sse import sse_client
39from mcp.client.streamable_http import streamablehttp_client
40import orjson
41from pydantic import ValidationError
42from sqlalchemy import and_, delete, desc, or_, select
43from sqlalchemy.exc import IntegrityError, OperationalError
44from sqlalchemy.orm import joinedload, selectinload, Session
46# First-Party
47from mcpgateway.cache.global_config_cache import global_config_cache
48from mcpgateway.common.models import Gateway as PydanticGateway
49from mcpgateway.common.models import TextContent
50from mcpgateway.common.models import Tool as PydanticTool
51from mcpgateway.common.models import ToolResult
52from mcpgateway.config import settings
53from mcpgateway.db import A2AAgent as DbA2AAgent
54from mcpgateway.db import fresh_db_session
55from mcpgateway.db import Gateway as DbGateway
56from mcpgateway.db import get_for_update, server_tool_association
57from mcpgateway.db import Tool as DbTool
58from mcpgateway.db import ToolMetric, ToolMetricsHourly
59from mcpgateway.observability import create_span
60from mcpgateway.plugins.framework import (
61 get_plugin_manager,
62 GlobalContext,
63 HttpHeaderPayload,
64 PluginContextTable,
65 PluginError,
66 PluginManager,
67 PluginViolationError,
68 ToolHookType,
69 ToolPostInvokePayload,
70 ToolPreInvokePayload,
71)
72from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA
73from mcpgateway.schemas import AuthenticationValues, ToolCreate, ToolMetrics, ToolRead, ToolUpdate, TopPerformer
74from mcpgateway.services.audit_trail_service import get_audit_trail_service
75from mcpgateway.services.base_service import BaseService
76from mcpgateway.services.event_service import EventService
77from mcpgateway.services.logging_service import LoggingService
78from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, TransportType
79from mcpgateway.services.metrics_buffer_service import get_metrics_buffer_service
80from mcpgateway.services.metrics_cleanup_service import delete_metrics_in_batches, pause_rollup_during_purge
81from mcpgateway.services.metrics_query_service import get_top_performers_combined
82from mcpgateway.services.oauth_manager import OAuthManager
83from mcpgateway.services.observability_service import current_trace_id, ObservabilityService
84from mcpgateway.services.performance_tracker import get_performance_tracker
85from mcpgateway.services.structured_logger import get_structured_logger
86from mcpgateway.services.team_management_service import TeamManagementService
87from mcpgateway.utils.correlation_id import get_correlation_id
88from mcpgateway.utils.create_slug import slugify
89from mcpgateway.utils.display_name import generate_display_name
90from mcpgateway.utils.gateway_access import build_gateway_auth_headers, check_gateway_access, extract_gateway_id_from_headers
91from mcpgateway.utils.metrics_common import build_top_performers
92from mcpgateway.utils.pagination import decode_cursor, encode_cursor, unified_paginate
93from mcpgateway.utils.passthrough_headers import compute_passthrough_headers_cached
94from mcpgateway.utils.retry_manager import ResilientHttpClient
95from mcpgateway.utils.services_auth import decode_auth, encode_auth
96from mcpgateway.utils.sqlalchemy_modifier import json_contains_tag_expr
97from mcpgateway.utils.ssl_context_cache import get_cached_ssl_context
98from mcpgateway.utils.url_auth import apply_query_param_auth, sanitize_exception_message, sanitize_url_for_logging
99from mcpgateway.utils.validate_signature import validate_signature
# Cache import (lazy to avoid circular dependencies)
# Module-level singletons, populated on first use by the lazy getters below.
_REGISTRY_CACHE = None
_TOOL_LOOKUP_CACHE = None
def _get_registry_cache():
    """Return the process-wide registry cache, importing it on first use.

    The import is deferred to call time so this module does not create a
    circular dependency with the cache package at import time.

    Returns:
        RegistryCache instance.
    """
    global _REGISTRY_CACHE  # pylint: disable=global-statement
    if _REGISTRY_CACHE is not None:
        return _REGISTRY_CACHE
    # First-Party
    from mcpgateway.cache.registry_cache import registry_cache  # pylint: disable=import-outside-toplevel

    _REGISTRY_CACHE = registry_cache
    return _REGISTRY_CACHE
def _get_tool_lookup_cache():
    """Return the process-wide tool lookup cache, importing it on first use.

    Deferred import avoids a circular dependency with the cache package.

    Returns:
        ToolLookupCache instance.
    """
    global _TOOL_LOOKUP_CACHE  # pylint: disable=global-statement
    if _TOOL_LOOKUP_CACHE is not None:
        return _TOOL_LOOKUP_CACHE
    # First-Party
    from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache  # pylint: disable=import-outside-toplevel

    _TOOL_LOOKUP_CACHE = tool_lookup_cache
    return _TOOL_LOOKUP_CACHE
# Initialize logging service first
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

# Initialize performance tracker, structured logger, audit trail, and metrics buffer for tool operations
perf_tracker = get_performance_tracker()
structured_logger = get_structured_logger("tool_service")
audit_trail = get_audit_trail_service()
metrics_buffer = get_metrics_buffer_service()
# Envelope key marking a header value as encrypted by this gateway (version 1 format).
_ENCRYPTED_TOOL_HEADER_VALUE_KEY = "_mcpgateway_encrypted_header_value_v1"
# Key under which the plaintext is stored inside the encrypted envelope.
_TOOL_HEADER_DATA_KEY = "data"
# Legacy envelope key retained so values written by older releases still decrypt.
_TOOL_HEADER_LEGACY_VALUE_KEY = "value"
# Header-name patterns whose values are treated as secrets and encrypted at rest.
_SENSITIVE_TOOL_HEADER_PATTERNS = (
    re.compile(r"^authorization$", re.IGNORECASE),
    re.compile(r"^proxy-authorization$", re.IGNORECASE),
    re.compile(r"^x-api-key$", re.IGNORECASE),
    re.compile(r"^api-key$", re.IGNORECASE),
    re.compile(r"^apikey$", re.IGNORECASE),
    # Keep broad-enough auth matching while avoiding operational noise from
    # non-secret tracing/idempotency headers (e.g. X-Correlation-Token).
    re.compile(r"^x-(?:auth|api|access|refresh|client|bearer|session|security)[-_]?(?:token|secret|key)$", re.IGNORECASE),
    re.compile(r"^(?:auth|api|access|refresh|client|bearer|session|security)[-_]?(?:token|secret|key)$", re.IGNORECASE),
)
def _is_sensitive_tool_header_name(name: str) -> bool:
    """Return whether a tool header name should be treated as sensitive.

    Args:
        name: Header name to evaluate.

    Returns:
        ``True`` when the header value should be protected.
    """
    candidate = str(name).strip().lower()
    for pattern in _SENSITIVE_TOOL_HEADER_PATTERNS:
        if pattern.match(candidate):
            return True
    return False
def _is_encrypted_tool_header_value(value: Any) -> bool:
    """Return whether a header value uses the encrypted envelope format.

    Args:
        value: Header value candidate.

    Returns:
        ``True`` when value is an encrypted envelope mapping.
    """
    if not isinstance(value, dict):
        return False
    return isinstance(value.get(_ENCRYPTED_TOOL_HEADER_VALUE_KEY), str)
def _encrypt_tool_header_value(value: Any, existing_value: Any = None) -> Any:
    """Encrypt a single sensitive tool header value.

    Args:
        value: Incoming header value from payload.
        existing_value: Existing stored value used for masked-value merges.

    Returns:
        Encrypted envelope, preserved existing value, or ``None`` when cleared.
    """
    # Cleared/unset values pass through untouched.
    if value is None or value == "":
        return value

    # The masked placeholder means "keep whatever is already stored".
    if value == settings.masked_auth_value:
        if _is_encrypted_tool_header_value(existing_value):
            return existing_value
        if existing_value in (None, ""):
            return None
        # Legacy plaintext stored value: upgrade it to an envelope now.
        return _encrypt_tool_header_value(existing_value, None)

    # Already an envelope: store as-is, do not double-encrypt.
    if _is_encrypted_tool_header_value(value):
        return value

    ciphertext = encode_auth({_TOOL_HEADER_DATA_KEY: str(value)})
    return {_ENCRYPTED_TOOL_HEADER_VALUE_KEY: ciphertext}
def _protect_tool_headers_for_storage(headers: Optional[Dict[str, Any]], existing_headers: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
    """Encrypt sensitive tool header values before persistence.

    Args:
        headers: Incoming tool headers payload.
        existing_headers: Existing stored headers used for masked-value merges.

    Returns:
        Header mapping with sensitive values protected for storage, or ``None``
        when the payload is absent or not a mapping.
    """
    # Covers both the None and non-dict cases of the incoming payload.
    if not isinstance(headers, dict):
        return None

    # Index existing values by normalized name for masked-value merges.
    existing_by_lower: Dict[str, Any] = {}
    if isinstance(existing_headers, dict):
        existing_by_lower = {str(key).strip().lower(): stored for key, stored in existing_headers.items()}

    result: Dict[str, Any] = {}
    for name, raw_value in headers.items():
        if not _is_sensitive_tool_header_name(name):
            result[name] = raw_value
            continue
        prior = existing_by_lower.get(str(name).strip().lower())
        result[name] = _encrypt_tool_header_value(raw_value, prior)
    return result
def _decrypt_tool_header_value(value: Any) -> Any:
    """Decrypt a single tool header envelope when possible.

    Args:
        value: Stored header value, possibly encrypted.

    Returns:
        Decrypted plain value when the envelope is valid, else the original value.
    """
    if not _is_encrypted_tool_header_value(value):
        return value

    encrypted_payload = value.get(_ENCRYPTED_TOOL_HEADER_VALUE_KEY)
    if not encrypted_payload:
        return value

    try:
        decoded = decode_auth(encrypted_payload)
    except Exception as exc:
        # Best effort: log and fall back to the stored representation.
        logger.warning("Failed to decrypt tool header value: %s", exc)
        return value

    if isinstance(decoded, dict):
        # Prefer the current envelope key, then the legacy one.
        for payload_key in (_TOOL_HEADER_DATA_KEY, _TOOL_HEADER_LEGACY_VALUE_KEY):
            if payload_key in decoded:
                return decoded[payload_key]
    return value
def _decrypt_tool_headers_for_runtime(headers: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Decrypt a tool header map for runtime outbound requests.

    Args:
        headers: Stored header mapping.

    Returns:
        Header mapping with encrypted values decrypted where possible;
        empty dict when the input is not a mapping.
    """
    if not isinstance(headers, dict):
        return {}
    decrypted: Dict[str, Any] = {}
    for name, stored in headers.items():
        decrypted[name] = _decrypt_tool_header_value(stored)
    return decrypted
@lru_cache(maxsize=256)
def _compile_jq_filter(jq_filter: str):
    """Cache compiled jq filter program.

    Compilation is comparatively expensive, so results are memoized per
    filter string in a bounded LRU (256 entries).

    Args:
        jq_filter: The jq filter string to compile.

    Returns:
        Compiled jq program object.

    Raises:
        ValueError: If the jq filter is invalid.
    """
    # pylint: disable=c-extension-no-member
    return jq.compile(jq_filter)
@lru_cache(maxsize=128)
def _get_validator_class_and_check(schema_json: str) -> Tuple[type, dict]:
    """Cache schema validation and validator class selection.

    Memoizes the expensive steps: deserializing the schema, picking the
    validator class from ``$schema``, and running ``check_schema``.

    Multiple JSON Schema drafts are supported by falling back to older
    validators when the auto-detected one rejects the schema. This handles
    schemas written against older drafts (e.g. Draft 4 style
    ``exclusiveMinimum: true``) that are invalid under newer drafts.

    Args:
        schema_json: Canonical JSON string of the schema (used as cache key).

    Returns:
        Tuple of (validator_class, schema_dict) ready for instantiation.
    """
    schema = orjson.loads(schema_json)

    # Try the $schema-detected validator first, then older drafts that may
    # accept legacy features (Draft 4/6 boolean exclusiveMinimum/Maximum).
    detected_cls = validators.validator_for(schema)
    for candidate_cls in (detected_cls, Draft7Validator, Draft6Validator, Draft4Validator):
        try:
            candidate_cls.check_schema(schema)
            return candidate_cls, schema
        except jsonschema.exceptions.SchemaError:
            continue

    # No validator accepts the schema: re-run the detected validator's check
    # so the caller gets a clear SchemaError during validation.
    detected_cls.check_schema(schema)
    return detected_cls, schema
def _canonicalize_schema(schema: dict) -> str:
    """Create a canonical JSON string of a schema for use as a cache key.

    Key order in dicts is not semantically meaningful in JSON Schema, so
    sorting keys makes equivalent schemas hash to the same cache entry.

    Args:
        schema: The JSON Schema dictionary.

    Returns:
        Canonical JSON string with sorted keys.
    """
    return orjson.dumps(schema, option=orjson.OPT_SORT_KEYS).decode()
def _validate_with_cached_schema(instance: Any, schema: dict) -> None:
    """Validate an instance against a schema using the cached validator class.

    A fresh validator instance is created per call for thread safety; the
    expensive class selection and schema check are reused from the cache.
    ``best_match`` preserves the error-selection semantics of
    ``jsonschema.validate()``.

    Args:
        instance: The data to validate.
        schema: The JSON Schema to validate against.

    Raises:
        error: The best matching ValidationError from jsonschema validation.
        jsonschema.exceptions.ValidationError: If validation fails.
        jsonschema.exceptions.SchemaError: If the schema itself is invalid.
    """
    cache_key = _canonicalize_schema(schema)
    validator_cls, checked_schema = _get_validator_class_and_check(cache_key)
    validator = validator_cls(checked_schema)
    best = jsonschema.exceptions.best_match(validator.iter_errors(instance))
    if best is not None:
        raise best
def extract_using_jq(data, jq_filter=""):
    """Apply a jq filter to JSON input and return the extracted values.

    Accepts a JSON string, dict, or list; compiled jq programs are cached
    for performance. An empty filter returns the input unchanged.

    Args:
        data (str, dict, list): The input JSON data.
        jq_filter (str): The jq filter string to extract the desired data.

    Returns:
        The result of applying the jq filter to the input data, or a
        single-element error list/TextContent on failure.

    Examples:
        >>> extract_using_jq('{"a": 1, "b": 2}', '.a')
        [1]
        >>> extract_using_jq({'a': 1, 'b': 2}, '.b')
        [2]
        >>> extract_using_jq('[{"a": 1}, {"a": 2}]', '.[].a')
        [1, 2]
        >>> extract_using_jq('not a json', '.a')
        ['Invalid JSON string provided.']
        >>> extract_using_jq({'a': 1}, '')
        {'a': 1}
    """
    if not jq_filter:
        return data

    if isinstance(data, str):
        # String input must be parseable JSON.
        try:
            data = orjson.loads(data)
        except orjson.JSONDecodeError:
            return ["Invalid JSON string provided."]
    elif not isinstance(data, (dict, list)):
        # Anything other than str/dict/list is rejected.
        return ["Input data must be a JSON string, dictionary, or list."]

    # Apply the jq filter using the cached compiled program.
    try:
        results = _compile_jq_filter(jq_filter).input(data).all()
    except Exception as e:
        return [TextContent(type="text", text="Error applying jsonpath filter: " + str(e))]

    if results == [None]:
        return [TextContent(type="text", text="Error applying jsonpath filter")]
    return results
class ToolError(Exception):
    """Base class for tool-related errors.

    All tool-service exceptions derive from this class so callers can catch
    the whole family with a single ``except ToolError`` clause.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolError
        >>> str(ToolError("Something went wrong"))
        'Something went wrong'
    """


class ToolNotFoundError(ToolError):
    """Raised when a requested tool is not found.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolNotFoundError
        >>> err = ToolNotFoundError("Tool xyz not found")
        >>> str(err)
        'Tool xyz not found'
        >>> isinstance(err, ToolError)
        True
    """


class ToolNameConflictError(ToolError):
    """Raised when a tool name conflicts with an existing (active or inactive) tool."""

    # Maps a visibility value to its human-readable label; anything else is "Public".
    _VISIBILITY_LABELS = {"team": "Team-level", "private": "Private"}

    def __init__(self, name: str, enabled: bool = True, tool_id: Optional[int] = None, visibility: str = "public"):
        """Initialize the error with tool information.

        Args:
            name: The conflicting tool name.
            enabled: Whether the existing tool is enabled or not.
            tool_id: ID of the existing tool if available.
            visibility: The visibility of the tool ("public", "team", or "private").

        Examples:
            >>> from mcpgateway.services.tool_service import ToolNameConflictError
            >>> err = ToolNameConflictError('test_tool', enabled=False, tool_id=123)
            >>> str(err)
            'Public Tool already exists with name: test_tool (currently inactive, ID: 123)'
            >>> (err.name, err.enabled, err.tool_id)
            ('test_tool', False, 123)
        """
        self.name = name
        self.enabled = enabled
        self.tool_id = tool_id
        vis_label = self._VISIBILITY_LABELS.get(visibility, "Public")
        message = f"{vis_label} Tool already exists with name: {name}"
        if not enabled:
            message += f" (currently inactive, ID: {tool_id})"
        super().__init__(message)


class ToolLockConflictError(ToolError):
    """Raised when a tool row is locked by another transaction."""


class ToolValidationError(ToolError):
    """Raised when tool validation fails.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolValidationError
        >>> str(ToolValidationError("Invalid tool configuration"))
        'Invalid tool configuration'
    """


class ToolInvocationError(ToolError):
    """Raised when tool invocation fails.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolInvocationError
        >>> str(ToolInvocationError("Tool execution failed"))
        'Tool execution failed'
        >>> isinstance(ToolInvocationError("x"), ToolError)
        True
    """


class ToolTimeoutError(ToolInvocationError):
    """Raised when tool invocation times out.

    Distinguishes timeout errors from other invocation errors. Timeout
    handlers call tool_post_invoke before raising this, so the generic
    exception handler must skip calling post_invoke again to avoid
    double-counting failures.
    """
class ToolService(BaseService):
    """Service for managing and invoking tools.

    Handles:
    - Tool registration and deregistration.
    - Tool invocation and validation.
    - Tool federation.
    - Event notifications.
    - Active/inactive tool management.
    """

    # ORM model used by BaseService visibility filtering for this service.
    _visibility_model_cls = DbTool
    def __init__(self) -> None:
        """Initialize the tool service.

        Sets up the event service, a resilient HTTP client honoring the
        configured federation timeout and SSL-verification settings, the
        optional plugin manager, and an OAuth manager.

        Examples:
            >>> from mcpgateway.services.tool_service import ToolService
            >>> service = ToolService()
            >>> isinstance(service._event_service, EventService)
            True
            >>> hasattr(service, '_http_client')
            True
        """
        self._event_service = EventService(channel_name="mcpgateway:tool_events")
        self._http_client = ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify})
        self._plugin_manager: PluginManager | None = get_plugin_manager()
        # Fall back to defaults when OAuth settings are absent from the config.
        self.oauth_manager = OAuthManager(
            request_timeout=int(settings.oauth_request_timeout if hasattr(settings, "oauth_request_timeout") else 30),
            max_retries=int(settings.oauth_max_retries if hasattr(settings, "oauth_max_retries") else 3),
        )
    async def initialize(self) -> None:
        """Initialize the service.

        Connects the underlying event service so tool change notifications
        can be published.

        Examples:
            >>> from mcpgateway.services.tool_service import ToolService
            >>> service = ToolService()
            >>> import asyncio
            >>> asyncio.run(service.initialize())  # Should log "Initializing tool service"
        """
        logger.info("Initializing tool service")
        await self._event_service.initialize()
    async def shutdown(self) -> None:
        """Shutdown the service.

        Closes the shared HTTP client first, then shuts down the event
        service, and finally logs completion.

        Examples:
            >>> from mcpgateway.services.tool_service import ToolService
            >>> service = ToolService()
            >>> import asyncio
            >>> asyncio.run(service.shutdown())  # Should log "Tool service shutdown complete"
        """
        await self._http_client.aclose()
        await self._event_service.shutdown()
        logger.info("Tool service shutdown complete")
604 async def get_top_tools(self, db: Session, limit: Optional[int] = 5, include_deleted: bool = False) -> List[TopPerformer]:
605 """Retrieve the top-performing tools based on execution count.
607 Queries the database to get tools with their metrics, ordered by the number of executions
608 in descending order. Returns a list of TopPerformer objects containing tool details and
609 performance metrics. Results are cached for performance.
611 Args:
612 db (Session): Database session for querying tool metrics.
613 limit (Optional[int]): Maximum number of tools to return. Defaults to 5.
614 include_deleted (bool): Whether to include deleted tools from rollups.
616 Returns:
617 List[TopPerformer]: A list of TopPerformer objects, each containing:
618 - id: Tool ID.
619 - name: Tool name.
620 - execution_count: Total number of executions.
621 - avg_response_time: Average response time in seconds, or None if no metrics.
622 - success_rate: Success rate percentage, or None if no metrics.
623 - last_execution: Timestamp of the last execution, or None if no metrics.
624 """
625 # Check cache first (if enabled)
626 # First-Party
627 from mcpgateway.cache.metrics_cache import is_cache_enabled, metrics_cache # pylint: disable=import-outside-toplevel
629 effective_limit = limit or 5
630 cache_key = f"top_tools:{effective_limit}:include_deleted={include_deleted}"
632 if is_cache_enabled():
633 cached = metrics_cache.get(cache_key)
634 if cached is not None:
635 return cached
637 # Use combined query that includes both raw metrics and rollup data
638 results = get_top_performers_combined(
639 db=db,
640 metric_type="tool",
641 entity_model=DbTool,
642 limit=effective_limit,
643 include_deleted=include_deleted,
644 )
645 top_performers = build_top_performers(results)
647 # Cache the result (if enabled)
648 if is_cache_enabled():
649 metrics_cache.set(cache_key, top_performers)
651 return top_performers
653 def _build_tool_cache_payload(self, tool: DbTool, gateway: Optional[DbGateway]) -> Dict[str, Any]:
654 """Build cache payload for tool lookup by name.
656 Args:
657 tool: Tool ORM instance.
658 gateway: Optional gateway ORM instance.
660 Returns:
661 Cache payload dict for tool lookup.
662 """
663 tool_payload = {
664 "id": str(tool.id),
665 "name": tool.name,
666 "original_name": tool.original_name,
667 "url": tool.url,
668 "description": tool.description,
669 "original_description": tool.original_description,
670 "integration_type": tool.integration_type,
671 "request_type": tool.request_type,
672 "headers": tool.headers or {},
673 "input_schema": tool.input_schema or {"type": "object", "properties": {}},
674 "output_schema": tool.output_schema,
675 "annotations": tool.annotations or {},
676 "auth_type": tool.auth_type,
677 "jsonpath_filter": tool.jsonpath_filter,
678 "custom_name": tool.custom_name,
679 "custom_name_slug": tool.custom_name_slug,
680 "display_name": tool.display_name,
681 "gateway_id": str(tool.gateway_id) if tool.gateway_id else None,
682 "enabled": bool(tool.enabled),
683 "reachable": bool(tool.reachable),
684 "tags": tool.tags or [],
685 "team_id": tool.team_id,
686 "owner_email": tool.owner_email,
687 "visibility": tool.visibility,
688 }
690 gateway_payload = None
691 if gateway:
692 gateway_payload = {
693 "id": str(gateway.id),
694 "name": gateway.name,
695 "url": gateway.url,
696 "description": gateway.description,
697 "slug": gateway.slug,
698 "transport": gateway.transport,
699 "capabilities": gateway.capabilities or {},
700 "passthrough_headers": gateway.passthrough_headers or [],
701 "auth_type": gateway.auth_type,
702 "ca_certificate": getattr(gateway, "ca_certificate", None),
703 "ca_certificate_sig": getattr(gateway, "ca_certificate_sig", None),
704 "enabled": bool(gateway.enabled),
705 "reachable": bool(gateway.reachable),
706 "team_id": gateway.team_id,
707 "owner_email": gateway.owner_email,
708 "visibility": gateway.visibility,
709 "tags": gateway.tags or [],
710 "gateway_mode": getattr(gateway, "gateway_mode", "cache"), # Gateway mode for direct proxy support
711 }
713 return {"status": "active", "tool": tool_payload, "gateway": gateway_payload}
715 def _pydantic_tool_from_payload(self, tool_payload: Dict[str, Any]) -> Optional[PydanticTool]:
716 """Build Pydantic tool metadata from cache payload.
718 Args:
719 tool_payload: Cached tool payload dict.
721 Returns:
722 Pydantic tool metadata or None if validation fails.
723 """
724 try:
725 return PydanticTool.model_validate(tool_payload)
726 except Exception as exc:
727 logger.debug("Failed to build PydanticTool from cache payload: %s", exc)
728 return None
730 def _pydantic_gateway_from_payload(self, gateway_payload: Dict[str, Any]) -> Optional[PydanticGateway]:
731 """Build Pydantic gateway metadata from cache payload.
733 Args:
734 gateway_payload: Cached gateway payload dict.
736 Returns:
737 Pydantic gateway metadata or None if validation fails.
738 """
739 try:
740 return PydanticGateway.model_validate(gateway_payload)
741 except Exception as exc:
742 logger.debug("Failed to build PydanticGateway from cache payload: %s", exc)
743 return None
745 async def _check_tool_access(
746 self,
747 db: Session,
748 tool_payload: Dict[str, Any],
749 user_email: Optional[str],
750 token_teams: Optional[List[str]],
751 ) -> bool:
752 """Check if user has access to a tool based on visibility rules.
754 Implements the same access control logic as list_tools() for consistency.
756 Access Rules:
757 - Public tools: Accessible by all authenticated users
758 - Team tools: Accessible by team members (team_id in user's teams)
759 - Private tools: Accessible only by owner (owner_email matches)
761 Args:
762 db: Database session for team membership lookup if needed.
763 tool_payload: Tool data dict with visibility, team_id, owner_email.
764 user_email: Email of the requesting user (None = unauthenticated).
765 token_teams: List of team IDs from token.
766 - None = unrestricted admin access
767 - [] = public-only token
768 - [...] = team-scoped token
770 Returns:
771 True if access is allowed, False otherwise.
772 """
773 visibility = tool_payload.get("visibility", "public")
774 tool_team_id = tool_payload.get("team_id")
775 tool_owner_email = tool_payload.get("owner_email")
777 # Public tools are accessible by everyone
778 if visibility == "public":
779 return True
781 # Admin bypass: token_teams=None AND user_email=None means unrestricted admin
782 # This happens when is_admin=True and no team scoping in token
783 if token_teams is None and user_email is None:
784 return True
786 # No user context (but not admin) = deny access to non-public tools
787 if not user_email:
788 return False
790 # Public-only tokens (empty teams array) can ONLY access public tools
791 is_public_only_token = token_teams is not None and len(token_teams) == 0
792 if is_public_only_token:
793 return False # Already checked public above
795 # Owner can access their own private tools
796 if visibility == "private" and tool_owner_email and tool_owner_email == user_email:
797 return True
799 # Team tools: check team membership (matches list_tools behavior)
800 if tool_team_id:
801 # Use token_teams if provided, otherwise look up from DB
802 if token_teams is not None:
803 team_ids = token_teams
804 else:
805 team_service = TeamManagementService(db)
806 user_teams = await team_service.get_user_teams(user_email)
807 team_ids = [team.id for team in user_teams]
809 # Team/public visibility allows access if user is in the team
810 if visibility in ["team", "public"] and tool_team_id in team_ids:
811 return True
813 return False
815 def convert_tool_to_read(
816 self,
817 tool: DbTool,
818 include_metrics: bool = False,
819 include_auth: bool = True,
820 requesting_user_email: Optional[str] = None,
821 requesting_user_is_admin: bool = False,
822 requesting_user_team_roles: Optional[Dict[str, str]] = None,
823 ) -> ToolRead:
824 """Converts a DbTool instance into a ToolRead model, including aggregated metrics and
825 new API gateway fields: request_type and authentication credentials (masked).
827 Args:
828 tool (DbTool): The ORM instance of the tool.
829 include_metrics (bool): Whether to include metrics in the result. Defaults to False.
830 include_auth (bool): Whether to decode and include auth details. Defaults to True.
831 When False, skips expensive AES-GCM decryption and returns minimal auth info.
832 requesting_user_email (Optional[str]): Email of the requesting user for header masking.
833 requesting_user_is_admin (bool): Whether the requester is an admin.
834 requesting_user_team_roles (Optional[Dict[str, str]]): {team_id: role} for the requester.
836 Returns:
837 ToolRead: The Pydantic model representing the tool, including aggregated metrics and new fields.
838 """
839 # NOTE: This serves two purposes:
840 # 1. It determines whether to decode auth (used later)
841 # 2. It forces the tool object to lazily evaluate (required before copy)
842 has_encrypted_auth = tool.auth_type and tool.auth_value
844 # Copy the dict from the tool
845 tool_dict = tool.__dict__.copy()
846 tool_dict.pop("_sa_instance_state", None)
848 # Compute metrics in a single pass (matches server/resource/prompt service pattern)
849 if include_metrics:
850 metrics = tool.metrics_summary # Single-pass computation
851 tool_dict["metrics"] = metrics
852 tool_dict["execution_count"] = metrics["total_executions"]
853 else:
854 tool_dict["metrics"] = None
855 tool_dict["execution_count"] = None
857 tool_dict["request_type"] = tool.request_type
858 tool_dict["annotations"] = tool.annotations or {}
860 # Only decode auth if include_auth=True AND we have encrypted credentials
861 if include_auth and has_encrypted_auth:
862 decoded_auth_value = decode_auth(tool.auth_value)
863 if tool.auth_type == "basic":
864 decoded_bytes = base64.b64decode(decoded_auth_value["Authorization"].split("Basic ")[1])
865 username, password = decoded_bytes.decode("utf-8").split(":")
866 tool_dict["auth"] = {
867 "auth_type": "basic",
868 "username": username,
869 "password": settings.masked_auth_value if password else None,
870 }
871 elif tool.auth_type == "bearer":
872 tool_dict["auth"] = {
873 "auth_type": "bearer",
874 "token": settings.masked_auth_value if decoded_auth_value["Authorization"] else None,
875 }
876 elif tool.auth_type == "authheaders":
877 # Support multi-header format (list of {key, value} dicts)
878 if decoded_auth_value:
879 # Convert decoded dict to list format for frontend
880 auth_headers = [
881 {
882 "key": key,
883 "value": settings.masked_auth_value if value else None,
884 }
885 for key, value in decoded_auth_value.items()
886 ]
887 # Also include legacy single-header fields for backward compatibility
888 first_key = next(iter(decoded_auth_value))
889 tool_dict["auth"] = {
890 "auth_type": "authheaders",
891 "authHeaders": auth_headers, # Multi-header format (masked)
892 "auth_header_key": first_key, # Legacy format
893 "auth_header_value": settings.masked_auth_value if decoded_auth_value[first_key] else None, # Legacy format
894 }
895 else:
896 tool_dict["auth"] = None
897 else:
898 tool_dict["auth"] = None
899 elif not include_auth and has_encrypted_auth:
900 # LIST VIEW: Minimal auth info without decryption
901 # Only show auth_type for tools that have encrypted credentials
902 tool_dict["auth"] = {"auth_type": tool.auth_type}
903 else:
904 # No encrypted auth (includes OAuth tools where auth_value=None)
905 # Behavior unchanged from current implementation
906 tool_dict["auth"] = None
908 tool_dict["name"] = tool.name
909 # Handle displayName with fallback and None checks
910 display_name = getattr(tool, "display_name", None)
911 custom_name = getattr(tool, "custom_name", tool.original_name)
912 tool_dict["displayName"] = display_name or custom_name
913 tool_dict["custom_name"] = custom_name
914 tool_dict["gateway_slug"] = getattr(tool, "gateway_slug", "") or ""
915 tool_dict["custom_name_slug"] = getattr(tool, "custom_name_slug", "") or ""
916 tool_dict["tags"] = getattr(tool, "tags", []) or []
917 tool_dict["team"] = getattr(tool, "team", None)
919 # Mask custom headers unless the requester is allowed to modify this tool.
920 # Safe default: if no requester context is provided, mask everything.
921 headers = tool_dict.get("headers")
922 if headers:
923 tool_dict["headers"] = _decrypt_tool_headers_for_runtime(headers)
924 headers = tool_dict["headers"]
925 can_view = requesting_user_is_admin
926 if not can_view and getattr(tool, "owner_email", None) == requesting_user_email:
927 can_view = True
928 if (
929 not can_view
930 and getattr(tool, "visibility", None) == "team"
931 and getattr(tool, "team_id", None) is not None
932 and requesting_user_team_roles
933 and requesting_user_team_roles.get(str(tool.team_id)) == "owner"
934 ):
935 can_view = True
936 if not can_view:
937 tool_dict["headers"] = {k: settings.masked_auth_value for k in headers}
939 return ToolRead.model_validate(tool_dict)
async def _record_tool_metric(self, db: Session, tool: DbTool, start_time: float, success: bool, error_message: Optional[str]) -> None:
    """
    Records a metric for a tool invocation.

    Calculates the response time from the provided monotonic start time and
    persists the metric details (success flag and optional error message) to
    the database with a single commit.

    Delegates to :meth:`_record_tool_metric_by_id` so the metric-row
    construction and commit logic live in exactly one place; the only
    difference between the two methods is whether the caller holds the ORM
    tool object or just its ID.

    Args:
        db (Session): The SQLAlchemy database session.
        tool (DbTool): The tool that was invoked.
        start_time (float): The monotonic start time of the invocation.
        success (bool): True if the invocation succeeded; otherwise, False.
        error_message (Optional[str]): The error message if the invocation failed, otherwise None.
    """
    # Behavior is identical to the previous inline implementation:
    # elapsed = time.monotonic() - start_time, then a ToolMetric row keyed by
    # tool.id is added and committed on the supplied session.
    self._record_tool_metric_by_id(
        db,
        tool_id=tool.id,
        start_time=start_time,
        success=success,
        error_message=error_message,
    )
def _record_tool_metric_by_id(
    self,
    db: Session,
    tool_id: str,
    start_time: float,
    success: bool,
    error_message: Optional[str],
) -> None:
    """Persist an invocation metric for a tool identified only by its ID.

    Intended for use with a fresh session obtained after the original request
    session has been released; it deliberately avoids touching the (possibly
    detached) ORM tool object and needs only the tool's ID.

    Args:
        db: A fresh database session (not the request session).
        tool_id: The UUID string of the tool.
        start_time: The monotonic start time of the invocation.
        success: True if the invocation succeeded; otherwise, False.
        error_message: The error message if the invocation failed, otherwise None.
    """
    # Elapsed time measured on the monotonic clock, immune to wall-clock jumps.
    elapsed = time.monotonic() - start_time
    db.add(
        ToolMetric(
            tool_id=tool_id,
            response_time=elapsed,
            is_success=success,
            error_message=error_message,
        )
    )
    db.commit()
def _record_tool_metric_sync(
    self,
    tool_id: str,
    start_time: float,
    success: bool,
    error_message: Optional[str],
) -> None:
    """Record a tool metric from a worker thread using a short-lived session.

    Opens its own database session, writes the metric via
    :meth:`_record_tool_metric_by_id`, and closes the session on exit.
    Designed to run inside ``asyncio.to_thread()`` so the synchronous commit
    never blocks the event loop.

    Args:
        tool_id: The UUID string of the tool.
        start_time: The monotonic start time of the invocation.
        success: True if the invocation succeeded; otherwise, False.
        error_message: The error message if the invocation failed, otherwise None.
    """
    # The context manager guarantees the session is closed even if the
    # metric write raises.
    with fresh_db_session() as session:
        self._record_tool_metric_by_id(
            session,
            tool_id=tool_id,
            start_time=start_time,
            success=success,
            error_message=error_message,
        )
def _extract_and_validate_structured_content(self, tool: DbTool, tool_result: "ToolResult", candidate: Optional[Any] = None) -> bool:
    """
    Extract structured content (if any) and validate it against ``tool.output_schema``.

    Args:
        tool: The tool with an optional output schema to validate against.
        tool_result: The tool result containing content to validate.
        candidate: Optional structured payload to validate. If not provided, will attempt
            to parse the first TextContent item as JSON.

    Behavior:
        - If ``candidate`` is provided it is used as the structured payload to validate.
        - Otherwise the method will try to parse the first ``TextContent`` item in
          ``tool_result.content`` as JSON and use that as the candidate.
        - If no output schema is declared on the tool the method returns True (nothing to validate).
        - On successful validation the parsed value is attached to ``tool_result.structured_content``.
          When structured content is present and valid callers may drop textual ``content`` in favour
          of the structured payload.
        - On validation failure the method sets ``tool_result.content`` to a single ``TextContent``
          containing a compact JSON object describing the validation error, sets
          ``tool_result.is_error = True`` and returns False.

    Returns:
        True when the structured content is valid or when no schema is declared.
        False when validation fails.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from mcpgateway.common.models import TextContent, ToolResult
        >>> import json
        >>> service = ToolService()
        >>> # No schema declared -> nothing to validate
        >>> tool = type("T", (object,), {"output_schema": None})()
        >>> r = ToolResult(content=[TextContent(type="text", text='{"a":1}')])
        >>> service._extract_and_validate_structured_content(tool, r)
        True

        >>> # Valid candidate provided -> attaches structured_content and returns True
        >>> tool = type(
        ...     "T",
        ...     (object,),
        ...     {"output_schema": {"type": "object", "properties": {"foo": {"type": "string"}}, "required": ["foo"]}},
        ... )()
        >>> r = ToolResult(content=[])
        >>> service._extract_and_validate_structured_content(tool, r, candidate={"foo": "bar"})
        True
        >>> r.structured_content == {"foo": "bar"}
        True

        >>> # Invalid candidate -> returns False, marks result as error and emits details
        >>> tool = type(
        ...     "T",
        ...     (object,),
        ...     {"output_schema": {"type": "object", "properties": {"foo": {"type": "string"}}, "required": ["foo"]}},
        ... )()
        >>> r = ToolResult(content=[])
        >>> ok = service._extract_and_validate_structured_content(tool, r, candidate={"foo": 123})
        >>> ok
        False
        >>> r.is_error
        True
        >>> details = orjson.loads(r.content[0].text)
        >>> "received" in details
        True
    """
    try:
        output_schema = getattr(tool, "output_schema", None)
        # Nothing to do if the tool doesn't declare a schema
        if not output_schema:
            return True

        structured: Optional[Any] = None
        # Prefer explicit candidate
        if candidate is not None:
            structured = candidate
        else:
            # Try to parse first TextContent text payload as JSON
            # NOTE(review): only dict-shaped items ({"type": "text", "text": ...})
            # are inspected here; TextContent model instances fall through
            # unparsed — presumably content has been serialized to dicts by the
            # time this runs. Confirm against callers.
            for c in getattr(tool_result, "content", []) or []:
                try:
                    if isinstance(c, dict) and "type" in c and c.get("type") == "text" and "text" in c:
                        structured = orjson.loads(c.get("text") or "null")
                        break
                except (orjson.JSONDecodeError, TypeError, ValueError):
                    # ignore JSON parse errors and continue
                    continue

        # If no structured data found, treat as valid (nothing to validate)
        if structured is None:
            return True

        # Try to normalize common wrapper shapes to match schema expectations
        schema_type = None
        try:
            if isinstance(output_schema, dict):
                schema_type = output_schema.get("type")
        except Exception:
            schema_type = None

        # Unwrap single-element list wrappers when schema expects object
        if isinstance(structured, list) and len(structured) == 1 and schema_type == "object":
            inner = structured[0]
            # If inner is a TextContent-like dict with 'text' JSON string, parse it
            if isinstance(inner, dict) and "text" in inner and "type" in inner and inner.get("type") == "text":
                try:
                    structured = orjson.loads(inner.get("text") or "null")
                except Exception:
                    # leave as-is if parsing fails
                    structured = inner
            else:
                structured = inner

        # Attach structured content
        # (done before validation so even an invalid payload remains inspectable)
        try:
            setattr(tool_result, "structured_content", structured)
        except Exception:
            logger.debug("Failed to set structured_content on ToolResult")

        # Validate using cached schema validator
        try:
            _validate_with_cached_schema(structured, output_schema)
            return True
        except jsonschema.exceptions.ValidationError as e:
            # Build a compact, machine-readable description of the failure and
            # replace the textual content with it so clients see why it failed.
            details = {
                "code": getattr(e, "validator", "validation_error"),
                "expected": e.schema.get("type") if isinstance(e.schema, dict) and "type" in e.schema else None,
                "received": type(e.instance).__name__.lower() if e.instance is not None else None,
                "path": list(e.absolute_path) if hasattr(e, "absolute_path") else list(e.path or []),
                "message": e.message,
            }
            try:
                tool_result.content = [TextContent(type="text", text=orjson.dumps(details).decode())]
            except Exception:
                # Fall back to repr when details are not JSON-serializable.
                tool_result.content = [TextContent(type="text", text=str(details))]
            tool_result.is_error = True
            logger.debug(f"structured_content validation failed for tool {getattr(tool, 'name', '<unknown>')}: {details}")
            return False
    except Exception as exc:  # pragma: no cover - defensive
        # Any unexpected error during extraction/validation is treated as a
        # validation failure rather than propagated to the invocation path.
        logger.error(f"Error extracting/validating structured_content: {exc}")
        return False
async def register_tool(
    self,
    db: Session,
    tool: ToolCreate,
    created_by: Optional[str] = None,
    created_from_ip: Optional[str] = None,
    created_via: Optional[str] = None,
    created_user_agent: Optional[str] = None,
    import_batch_id: Optional[str] = None,
    federation_source: Optional[str] = None,
    team_id: Optional[str] = None,
    owner_email: Optional[str] = None,
    visibility: Optional[str] = None,
) -> ToolRead:
    """Register a new tool with team support.

    Args:
        db: Database session.
        tool: Tool creation schema.
        created_by: Username who created this tool.
        created_from_ip: IP address of creator.
        created_via: Creation method (ui, api, import, federation).
        created_user_agent: User agent of creation request.
        import_batch_id: UUID for bulk import operations.
        federation_source: Source gateway for federated tools.
        team_id: Optional team ID to assign tool to.
        owner_email: Optional owner email for tool ownership.
        visibility: Tool visibility (private, team, public).

    Returns:
        Created tool information.

    Raises:
        IntegrityError: If there is a database integrity error.
        ToolNameConflictError: If a tool with the same name and visibility public exists.
        ToolError: For other tool registration errors.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from unittest.mock import MagicMock, AsyncMock
        >>> from mcpgateway.schemas import ToolRead
        >>> service = ToolService()
        >>> db = MagicMock()
        >>> tool = MagicMock()
        >>> tool.name = 'test'
        >>> db.execute.return_value.scalar_one_or_none.return_value = None
        >>> mock_gateway = MagicMock()
        >>> mock_gateway.name = 'test_gateway'
        >>> db.add = MagicMock()
        >>> db.commit = MagicMock()
        >>> def mock_refresh(obj):
        ...     obj.gateway = mock_gateway
        >>> db.refresh = MagicMock(side_effect=mock_refresh)
        >>> service._notify_tool_added = AsyncMock()
        >>> service.convert_tool_to_read = MagicMock(return_value='tool_read')
        >>> ToolRead.model_validate = MagicMock(return_value='tool_read')
        >>> import asyncio
        >>> asyncio.run(service.register_tool(db, tool))
        'tool_read'
    """
    try:
        # Split the optional auth object into the two columns stored on DbTool.
        if tool.auth is None:
            auth_type = None
            auth_value = None
        else:
            auth_type = tool.auth.auth_type
            auth_value = tool.auth.auth_value

        # Explicit parameters take precedence; otherwise fall back to the
        # values carried on the creation schema.
        if team_id is None:
            team_id = tool.team_id

        if owner_email is None:
            owner_email = tool.owner_email

        if visibility is None:
            visibility = tool.visibility or "public"
        # Check for existing tool with the same name and visibility
        if visibility.lower() == "public":
            # Check for existing public tool with the same name
            existing_tool = db.execute(select(DbTool).where(DbTool.name == tool.name, DbTool.visibility == "public")).scalar_one_or_none()  # pylint: disable=comparison-with-callable
            if existing_tool:
                raise ToolNameConflictError(existing_tool.name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility)
        elif visibility.lower() == "team" and team_id:
            # Check for existing team tool with the same name, team_id
            existing_tool = db.execute(
                select(DbTool).where(DbTool.name == tool.name, DbTool.visibility == "team", DbTool.team_id == team_id)  # pylint: disable=comparison-with-callable
            ).scalar_one_or_none()
            if existing_tool:
                raise ToolNameConflictError(existing_tool.name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility)
        # NOTE(review): private visibility gets no duplicate-name check here —
        # presumably intentional (names only need to be unique per scope).

        db_tool = DbTool(
            original_name=tool.name,
            custom_name=tool.name,
            custom_name_slug=slugify(tool.name),
            display_name=tool.displayName or tool.name,
            url=str(tool.url),
            description=tool.description,
            original_description=tool.description,
            integration_type=tool.integration_type,
            request_type=tool.request_type,
            headers=_protect_tool_headers_for_storage(tool.headers),
            input_schema=tool.input_schema,
            output_schema=tool.output_schema,
            annotations=tool.annotations,
            jsonpath_filter=tool.jsonpath_filter,
            auth_type=auth_type,
            auth_value=auth_value,
            gateway_id=tool.gateway_id,
            tags=tool.tags or [],
            # Metadata fields
            created_by=created_by,
            created_from_ip=created_from_ip,
            created_via=created_via,
            created_user_agent=created_user_agent,
            import_batch_id=import_batch_id,
            federation_source=federation_source,
            version=1,
            # Team scoping fields
            team_id=team_id,
            owner_email=owner_email or created_by,
            visibility=visibility,
            # passthrough REST tools fields (None for non-REST integrations)
            base_url=tool.base_url if tool.integration_type == "REST" else None,
            path_template=tool.path_template if tool.integration_type == "REST" else None,
            query_mapping=tool.query_mapping if tool.integration_type == "REST" else None,
            header_mapping=tool.header_mapping if tool.integration_type == "REST" else None,
            timeout_ms=tool.timeout_ms if tool.integration_type == "REST" else None,
            expose_passthrough=(tool.expose_passthrough if tool.integration_type == "REST" and tool.expose_passthrough is not None else True) if tool.integration_type == "REST" else None,
            allowlist=tool.allowlist if tool.integration_type == "REST" else None,
            plugin_chain_pre=tool.plugin_chain_pre if tool.integration_type == "REST" else None,
            plugin_chain_post=tool.plugin_chain_post if tool.integration_type == "REST" else None,
        )
        db.add(db_tool)
        db.commit()
        db.refresh(db_tool)
        await self._notify_tool_added(db_tool)

        # Structured logging: Audit trail for tool creation
        audit_trail.log_action(
            user_id=created_by or "system",
            action="create_tool",
            resource_type="tool",
            resource_id=db_tool.id,
            resource_name=db_tool.name,
            user_email=owner_email,
            team_id=team_id,
            client_ip=created_from_ip,
            user_agent=created_user_agent,
            new_values={
                "name": db_tool.name,
                "display_name": db_tool.display_name,
                "visibility": visibility,
                "integration_type": db_tool.integration_type,
            },
            context={
                "created_via": created_via,
                "import_batch_id": import_batch_id,
                "federation_source": federation_source,
            },
            db=db,
        )

        # Structured logging: Log successful tool creation
        structured_logger.log(
            level="INFO",
            message="Tool created successfully",
            event_type="tool_created",
            component="tool_service",
            user_id=created_by,
            user_email=owner_email,
            team_id=team_id,
            resource_type="tool",
            resource_id=db_tool.id,
            custom_fields={
                "tool_name": db_tool.name,
                "visibility": visibility,
                "integration_type": db_tool.integration_type,
            },
        )

        # Refresh db_tool after logging commits (they expire the session objects)
        db.refresh(db_tool)

        # Invalidate cache after successful creation
        cache = _get_registry_cache()
        await cache.invalidate_tools()
        tool_lookup_cache = _get_tool_lookup_cache()
        await tool_lookup_cache.invalidate(db_tool.name, gateway_id=str(db_tool.gateway_id) if db_tool.gateway_id else None)
        # Also invalidate tags cache since tool tags may have changed
        # First-Party
        from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

        await admin_stats_cache.invalidate_tags()

        return self.convert_tool_to_read(db_tool, requesting_user_email=getattr(db_tool, "owner_email", None))
    except IntegrityError as ie:
        db.rollback()
        logger.error(f"IntegrityError during tool registration: {ie}")

        # Structured logging: Log database integrity error
        structured_logger.log(
            level="ERROR",
            message="Tool creation failed due to database integrity error",
            event_type="tool_creation_failed",
            component="tool_service",
            user_id=created_by,
            user_email=owner_email,
            error=ie,
            custom_fields={
                "tool_name": tool.name,
            },
        )
        raise ie
    except ToolNameConflictError as tnce:
        db.rollback()
        logger.error(f"ToolNameConflictError during tool registration: {tnce}")

        # Structured logging: Log name conflict error
        structured_logger.log(
            level="WARNING",
            message="Tool creation failed due to name conflict",
            event_type="tool_name_conflict",
            component="tool_service",
            user_id=created_by,
            user_email=owner_email,
            custom_fields={
                "tool_name": tool.name,
                "visibility": visibility,
            },
        )
        raise tnce
    except Exception as e:
        db.rollback()

        # Structured logging: Log generic tool creation failure
        structured_logger.log(
            level="ERROR",
            message="Tool creation failed",
            event_type="tool_creation_failed",
            component="tool_service",
            user_id=created_by,
            user_email=owner_email,
            error=e,
            custom_fields={
                "tool_name": tool.name,
            },
        )
        raise ToolError(f"Failed to register tool: {str(e)}")
async def register_tools_bulk(
    self,
    db: Session,
    tools: List[ToolCreate],
    created_by: Optional[str] = None,
    created_from_ip: Optional[str] = None,
    created_via: Optional[str] = None,
    created_user_agent: Optional[str] = None,
    import_batch_id: Optional[str] = None,
    federation_source: Optional[str] = None,
    team_id: Optional[str] = None,
    owner_email: Optional[str] = None,
    visibility: Optional[str] = "public",
    conflict_strategy: str = "skip",
) -> Dict[str, Any]:
    """Register many tools at once, committing once per chunk.

    Compared with calling :meth:`register_tool` repeatedly, this batches
    inserts (``db.add_all``), detects name conflicts per chunk, and splits
    very large imports into chunks of 500 so SQLite parameter limits and
    memory stay bounded. Cache invalidation happens only for chunks that
    actually created or updated something.

    Args:
        db: Database session.
        tools: List of tool creation schemas.
        created_by: Username who created these tools.
        created_from_ip: IP address of creator.
        created_via: Creation method (ui, api, import, federation).
        created_user_agent: User agent of creation request.
        import_batch_id: UUID for bulk import operations.
        federation_source: Source gateway for federated tools.
        team_id: Team ID to assign the tools to.
        owner_email: Email of the user who owns these tools.
        visibility: Tool visibility level (private, team, public).
        conflict_strategy: How to handle conflicts (skip, update, rename, fail).

    Returns:
        Dict with counters ``created``, ``updated``, ``skipped``, ``failed``
        and an ``errors`` list of messages.

    Raises:
        ToolError: If bulk registration fails critically.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from unittest.mock import MagicMock
        >>> service = ToolService()
        >>> db = MagicMock()
        >>> tools = [MagicMock(), MagicMock()]
        >>> import asyncio
        >>> try:
        ...     result = asyncio.run(service.register_tools_bulk(db, tools))
        ... except Exception:
        ...     pass
    """
    if not tools:
        return {"created": 0, "updated": 0, "skipped": 0, "failed": 0, "errors": []}

    totals: Dict[str, Any] = {"created": 0, "updated": 0, "skipped": 0, "failed": 0, "errors": []}
    batch_size = 500

    for offset in range(0, len(tools), batch_size):
        batch = tools[offset : offset + batch_size]
        batch_stats = self._process_tool_chunk(
            db=db,
            chunk=batch,
            conflict_strategy=conflict_strategy,
            visibility=visibility,
            team_id=team_id,
            owner_email=owner_email,
            created_by=created_by,
            created_from_ip=created_from_ip,
            created_via=created_via,
            created_user_agent=created_user_agent,
            import_batch_id=import_batch_id,
            federation_source=federation_source,
        )

        # Fold this batch's results into the running totals.
        totals["errors"].extend(batch_stats["errors"])
        for counter in ("created", "updated", "skipped", "failed"):
            totals[counter] += batch_stats[counter]

        # Invalidate caches only when the batch actually changed something.
        if batch_stats["created"] or batch_stats["updated"]:
            registry_cache = _get_registry_cache()
            await registry_cache.invalidate_tools()

            lookup_cache = _get_tool_lookup_cache()
            # Map each tool name to its gateway id; a later entry without a
            # gateway keeps any gateway id recorded earlier for the same name.
            name_to_gateway: Dict[str, Optional[str]] = {}
            for item in batch:
                tool_name = getattr(item, "name", None)
                if not tool_name:
                    continue
                gw = getattr(item, "gateway_id", None)
                name_to_gateway[tool_name] = str(gw) if gw else name_to_gateway.get(tool_name)
            for lookup_name, gw_id in name_to_gateway.items():
                await lookup_cache.invalidate(lookup_name, gateway_id=gw_id)

            # Also invalidate tags cache since tool tags may have changed
            # First-Party
            from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

            await admin_stats_cache.invalidate_tags()

    return totals
def _process_tool_chunk(
    self,
    db: Session,
    chunk: List[ToolCreate],
    conflict_strategy: str,
    visibility: str,
    team_id: Optional[int],
    owner_email: Optional[str],
    created_by: str,
    created_from_ip: Optional[str],
    created_via: Optional[str],
    created_user_agent: Optional[str],
    import_batch_id: Optional[str],
    federation_source: Optional[str],
) -> dict:
    """Process one chunk of tools during a bulk import.

    Detects name conflicts for the whole chunk with a single query scoped by
    visibility, routes each tool through the per-item conflict handler, bulk
    adds new rows, and commits once. On any error the chunk is rolled back
    and every tool in it is counted as failed.

    Args:
        db: The SQLAlchemy database session.
        chunk: List of ToolCreate objects to process.
        conflict_strategy: Strategy for handling conflicts ("skip", "update", or "fail").
        visibility: Tool visibility level ("public", "team", or "private").
        team_id: Team ID for team-scoped tools.
        owner_email: Email of the tool owner.
        created_by: Email of the user creating the tools.
        created_from_ip: IP address of the request origin.
        created_via: Source of the creation (e.g., "api", "ui").
        created_user_agent: User agent string from the request.
        import_batch_id: Batch identifier for bulk imports.
        federation_source: Source identifier for federated tools.

    Returns:
        dict: Statistics dictionary with keys "created", "updated", "skipped", "failed", and "errors".
    """
    stats = {"created": 0, "updated": 0, "skipped": 0, "failed": 0, "errors": []}

    try:
        names = [entry.name for entry in chunk]
        scope = visibility.lower()

        # One conflict-detection query for the entire chunk, scoped the same
        # way single registration scopes its duplicate check.
        if scope == "public":
            conflict_query = select(DbTool).where(DbTool.name.in_(names), DbTool.visibility == "public")
        elif scope == "team" and team_id:
            conflict_query = select(DbTool).where(DbTool.name.in_(names), DbTool.visibility == "team", DbTool.team_id == team_id)
        else:
            # Private tools - check by owner
            conflict_query = select(DbTool).where(DbTool.name.in_(names), DbTool.visibility == "private", DbTool.owner_email == (owner_email or created_by))

        conflicts_by_name = {row.name: row for row in db.execute(conflict_query).scalars().all()}

        new_rows = []
        updated_rows = []

        for entry in chunk:
            outcome = self._process_single_tool_for_bulk(
                tool=entry,
                existing_tools_map=conflicts_by_name,
                conflict_strategy=conflict_strategy,
                visibility=visibility,
                team_id=team_id,
                owner_email=owner_email,
                created_by=created_by,
                created_from_ip=created_from_ip,
                created_via=created_via,
                created_user_agent=created_user_agent,
                import_batch_id=import_batch_id,
                federation_source=federation_source,
            )

            status = outcome["status"]
            if status == "add":
                new_rows.append(outcome["tool"])
                stats["created"] += 1
            elif status == "update":
                updated_rows.append(outcome["tool"])
                stats["updated"] += 1
            elif status == "skip":
                stats["skipped"] += 1
            elif status == "fail":
                stats["failed"] += 1
                stats["errors"].append(outcome["error"])

        if new_rows:
            db.add_all(new_rows)

        # One commit covers the whole chunk (new rows and in-place updates).
        db.commit()

        # Refresh created rows so IDs and server defaults are populated.
        for created_row in new_rows:
            db.refresh(created_row)

        # Single audit entry for the chunk when anything changed.
        if new_rows or updated_rows:
            audit_trail.log_action(
                user_id=created_by or "system",
                action="bulk_create_tools" if new_rows else "bulk_update_tools",
                resource_type="tool",
                resource_id=None,
                details={"count": len(new_rows) + len(updated_rows), "import_batch_id": import_batch_id},
                db=db,
            )

    except Exception as e:
        db.rollback()
        logger.error(f"Failed to process tool chunk: {str(e)}")
        stats["failed"] += len(chunk)
        stats["errors"].append(f"Chunk processing failed: {str(e)}")

    return stats
def _process_single_tool_for_bulk(
    self,
    tool: ToolCreate,
    existing_tools_map: dict,
    conflict_strategy: str,
    visibility: str,
    team_id: Optional[int],
    owner_email: Optional[str],
    created_by: str,
    created_from_ip: Optional[str],
    created_via: Optional[str],
    created_user_agent: Optional[str],
    import_batch_id: Optional[str],
    federation_source: Optional[str],
) -> dict:
    """Process a single tool for bulk import.

    Args:
        tool: ToolCreate object to process.
        existing_tools_map: Dictionary mapping tool names to existing DbTool objects.
        conflict_strategy: Strategy for handling conflicts ("skip", "update", or "fail").
        visibility: Tool visibility level ("public", "team", or "private").
        team_id: Team ID for team-scoped tools.
        owner_email: Email of the tool owner.
        created_by: Email of the user creating the tool.
        created_from_ip: IP address of the request origin.
        created_via: Source of the creation (e.g., "api", "ui").
        created_user_agent: User agent string from the request.
        import_batch_id: Batch identifier for bulk imports.
        federation_source: Source identifier for federated tools.

    Returns:
        dict: Result dictionary with "status" key ("add", "update", "skip", or "fail")
              and either "tool" (DbTool object) or "error" (error message).
    """
    try:
        # Extract auth information
        if tool.auth is None:
            auth_type = None
            auth_value = None
        else:
            auth_type = tool.auth.auth_type
            auth_value = tool.auth.auth_value

        # Use provided parameters or schema values
        tool_team_id = team_id if team_id is not None else getattr(tool, "team_id", None)
        tool_owner_email = owner_email or getattr(tool, "owner_email", None) or created_by
        tool_visibility = visibility if visibility is not None else getattr(tool, "visibility", "public")

        existing_tool = existing_tools_map.get(tool.name)

        if existing_tool:
            # Handle conflict based on strategy
            if conflict_strategy == "skip":
                return {"status": "skip"}
            if conflict_strategy == "update":
                # Update existing tool in place; the caller commits the chunk.
                existing_tool.display_name = tool.displayName or tool.name
                existing_tool.url = str(tool.url)
                existing_tool.description = tool.description
                # Preserve the first-ever description as the original.
                if getattr(existing_tool, "original_description", None) is None:
                    existing_tool.original_description = tool.description
                existing_tool.integration_type = tool.integration_type
                existing_tool.request_type = tool.request_type
                existing_tool.headers = _protect_tool_headers_for_storage(tool.headers, existing_headers=existing_tool.headers)
                existing_tool.input_schema = tool.input_schema
                existing_tool.output_schema = tool.output_schema
                existing_tool.annotations = tool.annotations
                existing_tool.jsonpath_filter = tool.jsonpath_filter
                existing_tool.auth_type = auth_type
                existing_tool.auth_value = auth_value
                existing_tool.tags = tool.tags or []
                # Modification metadata for the audit trail.
                existing_tool.modified_by = created_by
                existing_tool.modified_from_ip = created_from_ip
                existing_tool.modified_via = created_via
                existing_tool.modified_user_agent = created_user_agent
                existing_tool.updated_at = datetime.now(timezone.utc)
                existing_tool.version = (existing_tool.version or 1) + 1

                # Update REST-specific fields if applicable
                if tool.integration_type == "REST":
                    existing_tool.base_url = tool.base_url
                    existing_tool.path_template = tool.path_template
                    existing_tool.query_mapping = tool.query_mapping
                    existing_tool.header_mapping = tool.header_mapping
                    existing_tool.timeout_ms = tool.timeout_ms
                    existing_tool.expose_passthrough = tool.expose_passthrough if tool.expose_passthrough is not None else True
                    existing_tool.allowlist = tool.allowlist
                    existing_tool.plugin_chain_pre = tool.plugin_chain_pre
                    existing_tool.plugin_chain_post = tool.plugin_chain_post

                return {"status": "update", "tool": existing_tool}

            if conflict_strategy == "rename":
                # Create with renamed tool
                # NOTE(review): the suffix has one-second resolution, so two
                # renames of the same name within the same second would collide
                # — confirm whether that is acceptable for import volumes.
                new_name = f"{tool.name}_imported_{int(datetime.now().timestamp())}"
                db_tool = self._create_tool_object(
                    tool,
                    new_name,
                    auth_type,
                    auth_value,
                    tool_team_id,
                    tool_owner_email,
                    tool_visibility,
                    created_by,
                    created_from_ip,
                    created_via,
                    created_user_agent,
                    import_batch_id,
                    federation_source,
                )
                return {"status": "add", "tool": db_tool}

            if conflict_strategy == "fail":
                return {"status": "fail", "error": f"Tool name conflict: {tool.name}"}

        # Create new tool (no conflict, or an unrecognized strategy falls
        # through to plain creation)
        db_tool = self._create_tool_object(
            tool,
            tool.name,
            auth_type,
            auth_value,
            tool_team_id,
            tool_owner_email,
            tool_visibility,
            created_by,
            created_from_ip,
            created_via,
            created_user_agent,
            import_batch_id,
            federation_source,
        )
        return {"status": "add", "tool": db_tool}

    except Exception as e:
        # A single bad tool must not abort the chunk: report it as failed.
        logger.warning(f"Failed to process tool {tool.name} in bulk operation: {str(e)}")
        return {"status": "fail", "error": f"Failed to process tool {tool.name}: {str(e)}"}
1781 def _create_tool_object(
1782 self,
1783 tool: ToolCreate,
1784 name: str,
1785 auth_type: Optional[str],
1786 auth_value: Optional[str],
1787 tool_team_id: Optional[int],
1788 tool_owner_email: Optional[str],
1789 tool_visibility: str,
1790 created_by: str,
1791 created_from_ip: Optional[str],
1792 created_via: Optional[str],
1793 created_user_agent: Optional[str],
1794 import_batch_id: Optional[str],
1795 federation_source: Optional[str],
1796 ) -> DbTool:
1797 """Create a DbTool object from ToolCreate schema.
1799 Args:
1800 tool: ToolCreate schema object containing tool data.
1801 name: Name of the tool.
1802 auth_type: Authentication type for the tool.
1803 auth_value: Authentication value/credentials for the tool.
1804 tool_team_id: Team ID for team-scoped tools.
1805 tool_owner_email: Email of the tool owner.
1806 tool_visibility: Tool visibility level ("public", "team", or "private").
1807 created_by: Email of the user creating the tool.
1808 created_from_ip: IP address of the request origin.
1809 created_via: Source of the creation (e.g., "api", "ui").
1810 created_user_agent: User agent string from the request.
1811 import_batch_id: Batch identifier for bulk imports.
1812 federation_source: Source identifier for federated tools.
1814 Returns:
1815 DbTool: Database model instance ready to be added to the session.
1816 """
1817 return DbTool(
1818 original_name=name,
1819 custom_name=name,
1820 custom_name_slug=slugify(name),
1821 display_name=tool.displayName or name,
1822 url=str(tool.url),
1823 description=tool.description,
1824 original_description=tool.description,
1825 integration_type=tool.integration_type,
1826 request_type=tool.request_type,
1827 headers=_protect_tool_headers_for_storage(tool.headers),
1828 input_schema=tool.input_schema,
1829 output_schema=tool.output_schema,
1830 annotations=tool.annotations,
1831 jsonpath_filter=tool.jsonpath_filter,
1832 auth_type=auth_type,
1833 auth_value=auth_value,
1834 gateway_id=tool.gateway_id,
1835 tags=tool.tags or [],
1836 created_by=created_by,
1837 created_from_ip=created_from_ip,
1838 created_via=created_via,
1839 created_user_agent=created_user_agent,
1840 import_batch_id=import_batch_id,
1841 federation_source=federation_source,
1842 version=1,
1843 team_id=tool_team_id,
1844 owner_email=tool_owner_email,
1845 visibility=tool_visibility,
1846 base_url=tool.base_url if tool.integration_type == "REST" else None,
1847 path_template=tool.path_template if tool.integration_type == "REST" else None,
1848 query_mapping=tool.query_mapping if tool.integration_type == "REST" else None,
1849 header_mapping=tool.header_mapping if tool.integration_type == "REST" else None,
1850 timeout_ms=tool.timeout_ms if tool.integration_type == "REST" else None,
1851 expose_passthrough=((tool.expose_passthrough if tool.integration_type == "REST" and tool.expose_passthrough is not None else True) if tool.integration_type == "REST" else None),
1852 allowlist=tool.allowlist if tool.integration_type == "REST" else None,
1853 plugin_chain_pre=tool.plugin_chain_pre if tool.integration_type == "REST" else None,
1854 plugin_chain_post=tool.plugin_chain_post if tool.integration_type == "REST" else None,
1855 )
    async def list_tools(
        self,
        db: Session,
        include_inactive: bool = False,
        cursor: Optional[str] = None,
        tags: Optional[List[str]] = None,
        gateway_id: Optional[str] = None,
        limit: Optional[int] = None,
        page: Optional[int] = None,
        per_page: Optional[int] = None,
        user_email: Optional[str] = None,
        team_id: Optional[str] = None,
        visibility: Optional[str] = None,
        token_teams: Optional[List[str]] = None,
        _request_headers: Optional[Dict[str, str]] = None,
        requesting_user_email: Optional[str] = None,
        requesting_user_is_admin: bool = False,
        requesting_user_team_roles: Optional[Dict[str, str]] = None,
    ) -> Union[tuple[List[ToolRead], Optional[str]], Dict[str, Any]]:
        """
        Retrieve a list of registered tools from the database with pagination support.

        Supports two mutually exclusive pagination modes: cursor-based (returns a
        (tools, next_cursor) tuple) and page-based (returns a dict with "data",
        "pagination" and "links" keys when ``page`` is given).

        Args:
            db (Session): The SQLAlchemy database session.
            include_inactive (bool): If True, include inactive tools in the result.
                Defaults to False.
            cursor (Optional[str], optional): An opaque cursor token for pagination.
                Opaque base64-encoded string containing last item's ID.
            tags (Optional[List[str]]): Filter tools by tags. If provided, only tools with at least one matching tag will be returned.
            gateway_id (Optional[str]): Filter tools by gateway ID. Accepts the literal value 'null' to match NULL gateway_id.
            limit (Optional[int]): Maximum number of tools to return. Use 0 for all tools (no limit).
                If not specified, uses pagination_default_page_size.
            page: Page number for page-based pagination (1-indexed). Mutually exclusive with cursor.
            per_page: Items per page for page-based pagination. Defaults to pagination_default_page_size.
            user_email (Optional[str]): User email for team-based access control. If None, no access control is applied.
            team_id (Optional[str]): Filter by specific team ID. Requires user_email for access validation.
            visibility (Optional[str]): Filter by visibility (private, team, public).
            token_teams (Optional[List[str]]): Override DB team lookup with token's teams. Used for MCP/API token access
                where the token scope should be respected instead of the user's full team memberships.
            _request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through.
                Currently unused but kept for API consistency. Defaults to None.
            requesting_user_email (Optional[str]): Email of the requesting user for header masking.
            requesting_user_is_admin (bool): Whether the requester is an admin.
            requesting_user_team_roles (Optional[Dict[str, str]]): {team_id: role} for the requester.

        Returns:
            tuple[List[ToolRead], Optional[str]]: Tuple containing:
                - List of tools for current page
                - Next cursor token if more results exist, None otherwise

        Examples:
            >>> from mcpgateway.services.tool_service import ToolService
            >>> from unittest.mock import MagicMock
            >>> service = ToolService()
            >>> db = MagicMock()
            >>> tool_read = MagicMock()
            >>> service.convert_tool_to_read = MagicMock(return_value=tool_read)
            >>> db.execute.return_value.scalars.return_value.all.return_value = [MagicMock()]
            >>> import asyncio
            >>> tools, next_cursor = asyncio.run(service.list_tools(db))
            >>> isinstance(tools, list)
            True
        """
        # Check cache for first page only (cursor=None)
        # Skip caching when:
        #   - user_email is provided (team-filtered results are user-specific)
        #   - token_teams is set (scoped access, e.g., public-only or team-scoped tokens)
        #   - page-based pagination is used
        # This prevents cache poisoning where admin results could leak to public-only requests
        cache = _get_registry_cache()
        filters_hash = None
        # Only use the cache when using the real converter. In unit tests we often patch
        # convert_tool_to_read() to exercise error handling, and a warm cache would bypass it.
        try:
            converter_is_default = self.convert_tool_to_read.__func__ is ToolService.convert_tool_to_read  # type: ignore[attr-defined]
        except Exception:
            # Bound-method introspection can fail when convert_tool_to_read is
            # replaced by a plain Mock; treat that as "not the default converter".
            converter_is_default = False

        if cursor is None and user_email is None and token_teams is None and page is None and converter_is_default:
            # Tags are sorted so that equivalent filters hash to the same cache key.
            filters_hash = cache.hash_filters(include_inactive=include_inactive, tags=sorted(tags) if tags else None, gateway_id=gateway_id, limit=limit)
            cached = await cache.get("tools", filters_hash)
            if cached is not None:
                # Reconstruct ToolRead objects from cached dicts
                cached_tools = [ToolRead.model_validate(t) for t in cached["tools"]]
                return (cached_tools, cached.get("next_cursor"))

        # Build base query with ordering and eager load gateway + email_team to avoid N+1
        query = select(DbTool).options(joinedload(DbTool.gateway), joinedload(DbTool.email_team)).order_by(desc(DbTool.created_at), desc(DbTool.id))

        # Apply active/inactive filter
        if not include_inactive:
            query = query.where(DbTool.enabled)
        query = await self._apply_access_control(query, db, user_email, token_teams, team_id)

        if visibility:
            query = query.where(DbTool.visibility == visibility)

        # Add gateway_id filtering if provided
        if gateway_id:
            # The literal string 'null' (case-insensitive) selects gateway-less tools.
            if gateway_id.lower() == "null":
                query = query.where(DbTool.gateway_id.is_(None))
            else:
                query = query.where(DbTool.gateway_id == gateway_id)

        # Add tag filtering if tags are provided (supports both List[str] and List[Dict] formats)
        if tags:
            query = query.where(json_contains_tag_expr(db, DbTool.tags, tags, match_any=True))

        # Use unified pagination helper - handles both page and cursor pagination
        pag_result = await unified_paginate(
            db=db,
            query=query,
            page=page,
            per_page=per_page,
            cursor=cursor,
            limit=limit,
            base_url="/admin/tools",  # Used for page-based links
            query_params={"include_inactive": include_inactive} if include_inactive else {},
        )

        next_cursor = None
        # Extract rows based on pagination type: the helper returns a dict for
        # page-based pagination and a (rows, next_cursor) tuple otherwise.
        if page is not None:
            # Page-based: pag_result is a dict
            tools_db = pag_result["data"]
        else:
            # Cursor-based: pag_result is a tuple
            tools_db, next_cursor = pag_result

        db.commit()  # Release transaction to avoid idle-in-transaction

        # Convert to ToolRead (common for both pagination types)
        # Team names are loaded via joinedload(DbTool.email_team)
        result = []
        for s in tools_db:
            try:
                result.append(
                    self.convert_tool_to_read(
                        s,
                        include_metrics=False,
                        include_auth=False,
                        requesting_user_email=requesting_user_email,
                        requesting_user_is_admin=requesting_user_is_admin,
                        requesting_user_team_roles=requesting_user_team_roles,
                    )
                )
            except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
                logger.exception(f"Failed to convert tool {getattr(s, 'id', 'unknown')} ({getattr(s, 'name', 'unknown')}): {e}")
                # Continue with remaining tools instead of failing completely

        # Return appropriate format based on pagination type
        if page is not None:
            # Page-based format
            return {
                "data": result,
                "pagination": pag_result["pagination"],
                "links": pag_result["links"],
            }

        # Cursor-based format

        # Cache first page results - only for non-user-specific/non-scoped queries
        # Must match the same conditions as cache lookup to prevent cache poisoning
        if filters_hash is not None and cursor is None and user_email is None and token_teams is None and page is None and converter_is_default:
            try:
                cache_data = {"tools": [s.model_dump(mode="json") for s in result], "next_cursor": next_cursor}
                await cache.set("tools", cache_data, filters_hash)
            except AttributeError:
                pass  # Skip caching if result objects don't support model_dump (e.g., in doctests)

        return (result, next_cursor)
2029 async def list_server_tools(
2030 self,
2031 db: Session,
2032 server_id: str,
2033 include_inactive: bool = False,
2034 include_metrics: bool = False,
2035 cursor: Optional[str] = None,
2036 user_email: Optional[str] = None,
2037 token_teams: Optional[List[str]] = None,
2038 _request_headers: Optional[Dict[str, str]] = None,
2039 requesting_user_email: Optional[str] = None,
2040 requesting_user_is_admin: bool = False,
2041 requesting_user_team_roles: Optional[Dict[str, str]] = None,
2042 ) -> List[ToolRead]:
2043 """
2044 Retrieve a list of registered tools from the database.
2046 Args:
2047 db (Session): The SQLAlchemy database session.
2048 server_id (str): Server ID
2049 include_inactive (bool): If True, include inactive tools in the result.
2050 Defaults to False.
2051 include_metrics (bool): If True, all tool metrics included in result otherwise null.
2052 Defaults to False.
2053 cursor (Optional[str], optional): An opaque cursor token for pagination. Currently,
2054 this parameter is ignored. Defaults to None.
2055 user_email (Optional[str]): User email for visibility filtering. If None, no filtering applied.
2056 token_teams (Optional[List[str]]): Override DB team lookup with token's teams. Used for MCP/API
2057 token access where the token scope should be respected.
2058 _request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through.
2059 Currently unused but kept for API consistency. Defaults to None.
2060 requesting_user_email (Optional[str]): Email of the requesting user for header masking.
2061 requesting_user_is_admin (bool): Whether the requester is an admin.
2062 requesting_user_team_roles (Optional[Dict[str, str]]): {team_id: role} for the requester.
2064 Returns:
2065 List[ToolRead]: A list of registered tools represented as ToolRead objects.
2067 Examples:
2068 >>> from mcpgateway.services.tool_service import ToolService
2069 >>> from unittest.mock import MagicMock
2070 >>> service = ToolService()
2071 >>> db = MagicMock()
2072 >>> tool_read = MagicMock()
2073 >>> service.convert_tool_to_read = MagicMock(return_value=tool_read)
2074 >>> db.execute.return_value.scalars.return_value.all.return_value = [MagicMock()]
2075 >>> import asyncio
2076 >>> result = asyncio.run(service.list_server_tools(db, 'server1'))
2077 >>> isinstance(result, list)
2078 True
2079 """
2081 if include_metrics:
2082 query = (
2083 select(DbTool)
2084 .options(joinedload(DbTool.gateway), joinedload(DbTool.email_team))
2085 .options(selectinload(DbTool.metrics))
2086 .join(server_tool_association, DbTool.id == server_tool_association.c.tool_id)
2087 .where(server_tool_association.c.server_id == server_id)
2088 )
2089 else:
2090 query = (
2091 select(DbTool)
2092 .options(joinedload(DbTool.gateway), joinedload(DbTool.email_team))
2093 .join(server_tool_association, DbTool.id == server_tool_association.c.tool_id)
2094 .where(server_tool_association.c.server_id == server_id)
2095 )
2097 cursor = None # Placeholder for pagination; ignore for now
2098 logger.debug(f"Listing server tools for server_id={server_id} with include_inactive={include_inactive}, cursor={cursor}")
2100 if not include_inactive:
2101 query = query.where(DbTool.enabled)
2103 # Add visibility filtering if user context OR token_teams provided
2104 # This ensures unauthenticated requests with token_teams=[] only see public tools
2105 if user_email is not None or token_teams is not None: # empty-string user_email -> public-only filtering (secure default)
2106 # Use token_teams if provided (for MCP/API token access), otherwise look up from DB
2107 if token_teams is not None:
2108 team_ids = token_teams
2109 elif user_email:
2110 team_service = TeamManagementService(db)
2111 user_teams = await team_service.get_user_teams(user_email)
2112 team_ids = [team.id for team in user_teams]
2113 else:
2114 team_ids = []
2116 # Check if this is a public-only token (empty teams array)
2117 # Public-only tokens can ONLY see public resources - no owner access
2118 is_public_only_token = token_teams is not None and len(token_teams) == 0
2120 access_conditions = [
2121 DbTool.visibility == "public",
2122 ]
2123 # Only include owner access for non-public-only tokens with user_email
2124 if not is_public_only_token and user_email:
2125 access_conditions.append(DbTool.owner_email == user_email)
2126 if team_ids:
2127 access_conditions.append(and_(DbTool.team_id.in_(team_ids), DbTool.visibility.in_(["team", "public"])))
2128 query = query.where(or_(*access_conditions))
2130 # Execute the query - team names are loaded via joinedload(DbTool.email_team)
2131 tools = db.execute(query).scalars().all()
2133 db.commit() # Release transaction to avoid idle-in-transaction
2135 result = []
2136 for tool in tools:
2137 try:
2138 result.append(
2139 self.convert_tool_to_read(
2140 tool,
2141 include_metrics=include_metrics,
2142 include_auth=False,
2143 requesting_user_email=requesting_user_email,
2144 requesting_user_is_admin=requesting_user_is_admin,
2145 requesting_user_team_roles=requesting_user_team_roles,
2146 )
2147 )
2148 except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
2149 logger.exception(f"Failed to convert tool {getattr(tool, 'id', 'unknown')} ({getattr(tool, 'name', 'unknown')}): {e}")
2150 # Continue with remaining tools instead of failing completely
2152 return result
2154 async def list_tools_for_user(
2155 self,
2156 db: Session,
2157 user_email: str,
2158 team_id: Optional[str] = None,
2159 visibility: Optional[str] = None,
2160 include_inactive: bool = False,
2161 _skip: int = 0,
2162 _limit: int = 100,
2163 *,
2164 cursor: Optional[str] = None,
2165 gateway_id: Optional[str] = None,
2166 tags: Optional[List[str]] = None,
2167 limit: Optional[int] = None,
2168 ) -> tuple[List[ToolRead], Optional[str]]:
2169 """
2170 DEPRECATED: Use list_tools() with user_email parameter instead.
2172 List tools user has access to with team filtering and cursor pagination.
2174 This method is maintained for backward compatibility but is no longer used.
2175 New code should call list_tools() with user_email, team_id, and visibility parameters.
2177 Args:
2178 db: Database session
2179 user_email: Email of the user requesting tools
2180 team_id: Optional team ID to filter by specific team
2181 visibility: Optional visibility filter (private, team, public)
2182 include_inactive: Whether to include inactive tools
2183 _skip: Number of tools to skip for pagination (deprecated)
2184 _limit: Maximum number of tools to return (deprecated)
2185 cursor: Opaque cursor token for pagination
2186 gateway_id: Filter tools by gateway ID. Accepts literal 'null' for NULL gateway_id.
2187 tags: Filter tools by tags (match any)
2188 limit: Maximum number of tools to return. Use 0 for all tools (no limit).
2189 If not specified, uses pagination_default_page_size.
2191 Returns:
2192 tuple[List[ToolRead], Optional[str]]: Tools the user has access to and optional next_cursor
2193 """
2194 # Determine page size based on limit parameter
2195 # limit=None: use default, limit=0: no limit (all), limit>0: use specified (capped)
2196 if limit is None:
2197 page_size = settings.pagination_default_page_size
2198 elif limit == 0:
2199 page_size = None # No limit - fetch all
2200 else:
2201 page_size = min(limit, settings.pagination_max_page_size)
2203 # Decode cursor to get last_id if provided
2204 last_id = None
2205 if cursor:
2206 try:
2207 cursor_data = decode_cursor(cursor)
2208 last_id = cursor_data.get("id")
2209 logger.debug(f"Decoded cursor: last_id={last_id}")
2210 except ValueError as e:
2211 logger.warning(f"Invalid cursor, ignoring: {e}")
2213 # Build query following existing patterns from list_tools()
2214 team_service = TeamManagementService(db)
2215 user_teams = await team_service.get_user_teams(user_email)
2216 team_ids = [team.id for team in user_teams]
2218 # Eager load gateway and email_team to avoid N+1 when accessing gateway_slug and team name
2219 query = select(DbTool).options(joinedload(DbTool.gateway), joinedload(DbTool.email_team))
2221 # Apply active/inactive filter
2222 if not include_inactive:
2223 query = query.where(DbTool.enabled.is_(True))
2225 if team_id:
2226 if team_id not in team_ids:
2227 return ([], None) # No access to team
2229 access_conditions = [
2230 and_(DbTool.team_id == team_id, DbTool.visibility.in_(["team", "public"])),
2231 and_(DbTool.team_id == team_id, DbTool.owner_email == user_email),
2232 ]
2233 query = query.where(or_(*access_conditions))
2234 else:
2235 access_conditions = [
2236 DbTool.owner_email == user_email,
2237 DbTool.visibility == "public",
2238 ]
2239 if team_ids:
2240 access_conditions.append(and_(DbTool.team_id.in_(team_ids), DbTool.visibility.in_(["team", "public"])))
2242 query = query.where(or_(*access_conditions))
2244 # Apply visibility filter if specified
2245 if visibility:
2246 query = query.where(DbTool.visibility == visibility)
2248 if gateway_id:
2249 if gateway_id.lower() == "null":
2250 query = query.where(DbTool.gateway_id.is_(None))
2251 else:
2252 query = query.where(DbTool.gateway_id == gateway_id)
2254 if tags:
2255 query = query.where(json_contains_tag_expr(db, DbTool.tags, tags, match_any=True))
2257 # Apply cursor filter (WHERE id > last_id)
2258 if last_id:
2259 query = query.where(DbTool.id > last_id)
2261 # Execute query - team names are loaded via joinedload(DbTool.email_team)
2262 if page_size is not None:
2263 tools = db.execute(query.limit(page_size + 1)).scalars().all()
2264 else:
2265 tools = db.execute(query).scalars().all()
2267 db.commit() # Release transaction to avoid idle-in-transaction
2269 # Check if there are more results (only when paginating)
2270 has_more = page_size is not None and len(tools) > page_size
2271 if has_more:
2272 tools = tools[:page_size]
2274 # Convert to ToolRead objects
2275 result = []
2276 for tool in tools:
2277 try:
2278 result.append(self.convert_tool_to_read(tool, include_metrics=False, include_auth=False, requesting_user_email=user_email, requesting_user_is_admin=False))
2279 except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
2280 logger.exception(f"Failed to convert tool {getattr(tool, 'id', 'unknown')} ({getattr(tool, 'name', 'unknown')}): {e}")
2281 # Continue with remaining tools instead of failing completely
2283 next_cursor = None
2284 # Generate cursor if there are more results (cursor-based pagination)
2285 if has_more and tools:
2286 last_tool = tools[-1]
2287 next_cursor = encode_cursor({"created_at": last_tool.created_at.isoformat(), "id": last_tool.id})
2289 return (result, next_cursor)
2291 async def get_tool(
2292 self,
2293 db: Session,
2294 tool_id: str,
2295 requesting_user_email: Optional[str] = None,
2296 requesting_user_is_admin: bool = False,
2297 requesting_user_team_roles: Optional[Dict[str, str]] = None,
2298 ) -> ToolRead:
2299 """
2300 Retrieve a tool by its ID.
2302 Args:
2303 db (Session): The SQLAlchemy database session.
2304 tool_id (str): The unique identifier of the tool.
2305 requesting_user_email (Optional[str]): Email of the requesting user for header masking.
2306 requesting_user_is_admin (bool): Whether the requester is an admin.
2307 requesting_user_team_roles (Optional[Dict[str, str]]): {team_id: role} for the requester.
2309 Returns:
2310 ToolRead: The tool object.
2312 Raises:
2313 ToolNotFoundError: If the tool is not found.
2315 Examples:
2316 >>> from mcpgateway.services.tool_service import ToolService
2317 >>> from unittest.mock import MagicMock
2318 >>> service = ToolService()
2319 >>> db = MagicMock()
2320 >>> tool = MagicMock()
2321 >>> db.get.return_value = tool
2322 >>> service.convert_tool_to_read = MagicMock(return_value='tool_read')
2323 >>> import asyncio
2324 >>> asyncio.run(service.get_tool(db, 'tool_id'))
2325 'tool_read'
2326 """
2327 tool = db.get(DbTool, tool_id)
2328 if not tool:
2329 raise ToolNotFoundError(f"Tool not found: {tool_id}")
2331 tool_read = self.convert_tool_to_read(
2332 tool,
2333 requesting_user_email=requesting_user_email,
2334 requesting_user_is_admin=requesting_user_is_admin,
2335 requesting_user_team_roles=requesting_user_team_roles,
2336 )
2338 structured_logger.log(
2339 level="INFO",
2340 message="Tool retrieved successfully",
2341 event_type="tool_viewed",
2342 component="tool_service",
2343 team_id=getattr(tool, "team_id", None),
2344 resource_type="tool",
2345 resource_id=str(tool.id),
2346 custom_fields={
2347 "tool_name": tool.name,
2348 "include_metrics": bool(getattr(tool_read, "metrics", {})),
2349 },
2350 )
2352 return tool_read
    async def delete_tool(self, db: Session, tool_id: str, user_email: Optional[str] = None, purge_metrics: bool = False) -> None:
        """
        Delete a tool by its ID.

        Performs an atomic DELETE (guarding against concurrent deletion),
        optionally purges raw and rollup metrics first, then writes an audit
        record and invalidates every cache that can reference the tool.

        Args:
            db (Session): The SQLAlchemy database session.
            tool_id (str): The unique identifier of the tool.
            user_email (Optional[str]): Email of user performing delete (for ownership check).
            purge_metrics (bool): If True, delete raw + rollup metrics for this tool.

        Raises:
            ToolNotFoundError: If the tool is not found.
            PermissionError: If user doesn't own the tool.
            ToolError: For other deletion errors.

        Examples:
            >>> from mcpgateway.services.tool_service import ToolService
            >>> from unittest.mock import MagicMock, AsyncMock
            >>> service = ToolService()
            >>> db = MagicMock()
            >>> tool = MagicMock()
            >>> db.get.return_value = tool
            >>> db.delete = MagicMock()
            >>> db.commit = MagicMock()
            >>> service._notify_tool_deleted = AsyncMock()
            >>> import asyncio
            >>> asyncio.run(service.delete_tool(db, 'tool_id'))
        """
        try:
            tool = db.get(DbTool, tool_id)
            if not tool:
                raise ToolNotFoundError(f"Tool not found: {tool_id}")

            # Check ownership if user_email provided
            if user_email:
                # First-Party
                from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

                permission_service = PermissionService(db)
                if not await permission_service.check_resource_ownership(user_email, tool):
                    raise PermissionError("Only the owner can delete this tool")

            # Capture identifying fields before the row is deleted so the
            # notification/audit code below does not depend on the ORM instance.
            tool_info = {"id": tool.id, "name": tool.name}
            tool_name = tool.name
            tool_team_id = tool.team_id

            if purge_metrics:
                # Pause rollup aggregation while purging so partially-deleted
                # metrics are not re-aggregated mid-purge.
                with pause_rollup_during_purge(reason=f"purge_tool:{tool_id}"):
                    delete_metrics_in_batches(db, ToolMetric, ToolMetric.tool_id, tool_id)
                    delete_metrics_in_batches(db, ToolMetricsHourly, ToolMetricsHourly.tool_id, tool_id)

            # Use DELETE with rowcount check for database-agnostic atomic delete
            # (RETURNING is not supported on MySQL/MariaDB)
            stmt = delete(DbTool).where(DbTool.id == tool_id)
            result = db.execute(stmt)
            if result.rowcount == 0:
                # Tool was already deleted by another concurrent request
                raise ToolNotFoundError(f"Tool not found: {tool_id}")

            db.commit()
            await self._notify_tool_deleted(tool_info)
            logger.info(f"Permanently deleted tool: {tool_info['name']}")

            # Structured logging: Audit trail for tool deletion
            audit_trail.log_action(
                user_id=user_email or "system",
                action="delete_tool",
                resource_type="tool",
                resource_id=tool_info["id"],
                resource_name=tool_name,
                user_email=user_email,
                team_id=tool_team_id,
                old_values={
                    "name": tool_name,
                },
                db=db,
            )

            # Structured logging: Log successful tool deletion
            structured_logger.log(
                level="INFO",
                message="Tool deleted successfully",
                event_type="tool_deleted",
                component="tool_service",
                user_email=user_email,
                team_id=tool_team_id,
                resource_type="tool",
                resource_id=tool_info["id"],
                custom_fields={
                    "tool_name": tool_name,
                    "purge_metrics": purge_metrics,
                },
            )

            # Invalidate cache after successful deletion
            cache = _get_registry_cache()
            await cache.invalidate_tools()
            tool_lookup_cache = _get_tool_lookup_cache()
            # NOTE(review): reads tool.gateway_id after commit; assumes the
            # attribute is still loaded on the (now stale) instance — confirm
            # session expire_on_commit settings keep this safe.
            await tool_lookup_cache.invalidate(tool_name, gateway_id=str(tool.gateway_id) if tool.gateway_id else None)
            # Also invalidate tags cache since tool tags may have changed
            # First-Party
            from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

            await admin_stats_cache.invalidate_tags()
            # Invalidate top performers cache
            # First-Party
            from mcpgateway.cache.metrics_cache import metrics_cache  # pylint: disable=import-outside-toplevel

            metrics_cache.invalidate_prefix("top_tools:")
            metrics_cache.invalidate("tools")
        except PermissionError as pe:
            db.rollback()

            # Structured logging: Log permission error
            structured_logger.log(
                level="WARNING",
                message="Tool deletion failed due to permission error",
                event_type="tool_delete_permission_denied",
                component="tool_service",
                user_email=user_email,
                resource_type="tool",
                resource_id=tool_id,
                error=pe,
            )
            raise
        except Exception as e:
            db.rollback()

            # Structured logging: Log generic tool deletion failure
            structured_logger.log(
                level="ERROR",
                message="Tool deletion failed",
                event_type="tool_deletion_failed",
                component="tool_service",
                user_email=user_email,
                resource_type="tool",
                resource_id=tool_id,
                error=e,
            )
            raise ToolError(f"Failed to delete tool: {str(e)}")
2495 async def set_tool_state(self, db: Session, tool_id: str, activate: bool, reachable: bool, user_email: Optional[str] = None, skip_cache_invalidation: bool = False) -> ToolRead:
2496 """
2497 Set the activation status of a tool.
2499 Args:
2500 db (Session): The SQLAlchemy database session.
2501 tool_id (str): The unique identifier of the tool.
2502 activate (bool): True to activate, False to deactivate.
2503 reachable (bool): True if the tool is reachable.
2504 user_email: Optional[str] The email of the user to check if the user has permission to modify.
2505 skip_cache_invalidation: If True, skip cache invalidation (used for batch operations).
2507 Returns:
2508 ToolRead: The updated tool object.
2510 Raises:
2511 ToolNotFoundError: If the tool is not found.
2512 ToolLockConflictError: If the tool row is locked by another transaction.
2513 ToolError: For other errors.
2514 PermissionError: If user doesn't own the agent.
2516 Examples:
2517 >>> from mcpgateway.services.tool_service import ToolService
2518 >>> from unittest.mock import MagicMock, AsyncMock
2519 >>> from mcpgateway.schemas import ToolRead
2520 >>> service = ToolService()
2521 >>> db = MagicMock()
2522 >>> tool = MagicMock()
2523 >>> db.get.return_value = tool
2524 >>> db.commit = MagicMock()
2525 >>> db.refresh = MagicMock()
2526 >>> service._notify_tool_activated = AsyncMock()
2527 >>> service._notify_tool_deactivated = AsyncMock()
2528 >>> service.convert_tool_to_read = MagicMock(return_value='tool_read')
2529 >>> ToolRead.model_validate = MagicMock(return_value='tool_read')
2530 >>> import asyncio
2531 >>> asyncio.run(service.set_tool_state(db, 'tool_id', True, True))
2532 'tool_read'
2533 """
2534 try:
2535 # Use nowait=True to fail fast if row is locked, preventing lock contention under high load
2536 try:
2537 tool = get_for_update(db, DbTool, tool_id, nowait=True)
2538 except OperationalError as lock_err:
2539 # Row is locked by another transaction - fail fast with 409
2540 db.rollback()
2541 raise ToolLockConflictError(f"Tool {tool_id} is currently being modified by another request") from lock_err
2542 if not tool:
2543 raise ToolNotFoundError(f"Tool not found: {tool_id}")
2545 if user_email:
2546 # First-Party
2547 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel
2549 permission_service = PermissionService(db)
2550 if not await permission_service.check_resource_ownership(user_email, tool):
2551 raise PermissionError("Only the owner can activate the Tool" if activate else "Only the owner can deactivate the Tool")
2553 is_activated = is_reachable = False
2554 if tool.enabled != activate:
2555 tool.enabled = activate
2556 is_activated = True
2558 if tool.reachable != reachable:
2559 tool.reachable = reachable
2560 is_reachable = True
2562 if is_activated or is_reachable:
2563 tool.updated_at = datetime.now(timezone.utc)
2565 db.commit()
2566 db.refresh(tool)
2568 # Invalidate cache after status change (skip for batch operations)
2569 if not skip_cache_invalidation:
2570 cache = _get_registry_cache()
2571 await cache.invalidate_tools()
2572 tool_lookup_cache = _get_tool_lookup_cache()
2573 await tool_lookup_cache.invalidate(tool.name, gateway_id=str(tool.gateway_id) if tool.gateway_id else None)
2575 if not tool.enabled:
2576 # Inactive
2577 await self._notify_tool_deactivated(tool)
2578 elif tool.enabled and not tool.reachable:
2579 # Offline
2580 await self._notify_tool_offline(tool)
2581 else:
2582 # Active
2583 await self._notify_tool_activated(tool)
2585 logger.info(f"Tool: {tool.name} is {'enabled' if activate else 'disabled'}{' and accessible' if reachable else ' but inaccessible'}")
2587 # Structured logging: Audit trail for tool state change
2588 audit_trail.log_action(
2589 user_id=user_email or "system",
2590 action="set_tool_state",
2591 resource_type="tool",
2592 resource_id=tool.id,
2593 resource_name=tool.name,
2594 user_email=user_email,
2595 team_id=tool.team_id,
2596 new_values={
2597 "enabled": tool.enabled,
2598 "reachable": tool.reachable,
2599 },
2600 context={
2601 "action": "activate" if activate else "deactivate",
2602 },
2603 db=db,
2604 )
2606 # Structured logging: Log successful tool state change
2607 structured_logger.log(
2608 level="INFO",
2609 message=f"Tool {'activated' if activate else 'deactivated'} successfully",
2610 event_type="tool_state_changed",
2611 component="tool_service",
2612 user_email=user_email,
2613 team_id=tool.team_id,
2614 resource_type="tool",
2615 resource_id=tool.id,
2616 custom_fields={
2617 "tool_name": tool.name,
2618 "enabled": tool.enabled,
2619 "reachable": tool.reachable,
2620 },
2621 )
2623 return self.convert_tool_to_read(tool, requesting_user_email=getattr(tool, "owner_email", None))
2624 except PermissionError as e:
2625 # Structured logging: Log permission error
2626 structured_logger.log(
2627 level="WARNING",
2628 message="Tool state change failed due to permission error",
2629 event_type="tool_state_change_permission_denied",
2630 component="tool_service",
2631 user_email=user_email,
2632 resource_type="tool",
2633 resource_id=tool_id,
2634 error=e,
2635 )
2636 raise e
2637 except ToolLockConflictError:
2638 # Re-raise lock conflicts without wrapping - allows 409 response
2639 raise
2640 except ToolNotFoundError:
2641 # Re-raise not found without wrapping - allows 404 response
2642 raise
2643 except Exception as e:
2644 db.rollback()
2646 # Structured logging: Log generic tool state change failure
2647 structured_logger.log(
2648 level="ERROR",
2649 message="Tool state change failed",
2650 event_type="tool_state_change_failed",
2651 component="tool_service",
2652 user_email=user_email,
2653 resource_type="tool",
2654 resource_id=tool_id,
2655 error=e,
2656 )
2657 raise ToolError(f"Failed to set tool state: {str(e)}")
2659 async def invoke_tool_direct(
2660 self,
2661 gateway_id: str,
2662 name: str,
2663 arguments: Dict[str, Any],
2664 request_headers: Optional[Dict[str, str]] = None,
2665 meta_data: Optional[Dict[str, Any]] = None,
2666 user_email: Optional[str] = None,
2667 token_teams: Optional[List[str]] = None,
2668 ) -> types.CallToolResult:
2669 """
2670 Invoke a tool directly on a remote MCP gateway in direct_proxy mode.
2672 This bypasses all gateway processing (caching, plugins, validation) and forwards
2673 the tool call directly to the remote MCP server, returning the raw result.
2675 Args:
2676 gateway_id: Gateway ID to invoke the tool on.
2677 name: Name of tool to invoke.
2678 arguments: Tool arguments.
2679 request_headers: Headers from the request to pass through.
2680 meta_data: Optional metadata dictionary for additional context (e.g., request ID).
2681 user_email: Email of the requesting user for access control.
2682 token_teams: Team IDs from the user's token for access control.
2684 Returns:
2685 CallToolResult from the remote MCP server (as-is, no normalization).
2687 Raises:
2688 ToolNotFoundError: If gateway not found or access denied.
2689 ToolInvocationError: If invocation fails.
2690 """
2691 logger.info(f"Direct proxy tool invocation: {name} via gateway {gateway_id}")
2692 # Look up gateway
2693 # Use a fresh session for this lookup
2694 with fresh_db_session() as db:
2695 gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
2696 if not gateway:
2697 raise ToolNotFoundError(f"Gateway {gateway_id} not found")
2699 if getattr(gateway, "gateway_mode", "cache") != "direct_proxy" or not settings.mcpgateway_direct_proxy_enabled:
2700 raise ToolInvocationError(f"Gateway {gateway_id} is not in direct_proxy mode")
2702 # SECURITY: Defensive access check — callers should also check,
2703 # but enforce here to prevent RBAC bypass if called from a new context.
2704 if not await check_gateway_access(db, gateway, user_email, token_teams):
2705 raise ToolNotFoundError(f"Tool not found: {name}")
2707 # Prepare headers with gateway auth
2708 headers = build_gateway_auth_headers(gateway)
2710 # Forward passthrough headers if configured
2711 if gateway.passthrough_headers and request_headers:
2712 for header_name in gateway.passthrough_headers:
2713 header_value = request_headers.get(header_name.lower()) or request_headers.get(header_name)
2714 if header_value:
2715 headers[header_name] = header_value
2717 gateway_url = gateway.url
2719 # Resolve the original (unprefixed) tool name for the remote server.
2720 # Tools registered via gateways are stored as "{gateway_slug}{separator}{slugified_name}",
2721 # but the remote server only knows the original name (e.g. "get_system_time" not "get-system-time").
2722 # Look up the tool's original_name from the DB; fall back to the prefixed name if not found
2723 # (e.g. when calling a tool that exists on the remote but hasn't been cached locally).
2724 remote_name = name
2725 tool_row = db.execute(select(DbTool).where(DbTool.name == name, DbTool.gateway_id == gateway_id)).scalar_one_or_none()
2726 if tool_row and tool_row.original_name:
2727 remote_name = tool_row.original_name
2728 else:
2729 # Fallback: strip the slug prefix (best-effort for tools not yet in DB)
2730 gateway_slug = getattr(gateway, "slug", None) or ""
2731 if gateway_slug:
2732 prefix = f"{gateway_slug}{settings.gateway_tool_name_separator}"
2733 if name.startswith(prefix):
2734 remote_name = name[len(prefix) :]
2736 # Use MCP SDK to connect and call tool
2737 try:
2738 async with streamablehttp_client(url=gateway_url, headers=headers, timeout=settings.mcpgateway_direct_proxy_timeout) as (read_stream, write_stream, _get_session_id):
2739 async with ClientSession(read_stream, write_stream) as session:
2740 await session.initialize()
2742 # Call tool with meta if provided
2743 if meta_data:
2744 logger.debug(f"Forwarding _meta to remote gateway: {meta_data}")
2745 tool_result = await session.call_tool(name=remote_name, arguments=arguments, meta=meta_data)
2746 else:
2747 tool_result = await session.call_tool(name=remote_name, arguments=arguments)
2749 logger.info(f"[INVOKE TOOL] Using direct_proxy mode for gateway {gateway.id} (from X-Context-Forge-Gateway-Id header). Meta Attached: {meta_data is not None}")
2750 return tool_result
2751 except Exception as e:
2752 logger.exception(f"Direct proxy tool invocation failed for {name}: {e}")
2753 raise ToolInvocationError(f"Direct proxy tool invocation failed: {str(e)}")
2755 async def invoke_tool(
2756 self,
2757 db: Session,
2758 name: str,
2759 arguments: Dict[str, Any],
2760 request_headers: Optional[Dict[str, str]] = None,
2761 app_user_email: Optional[str] = None,
2762 user_email: Optional[str] = None,
2763 token_teams: Optional[List[str]] = None,
2764 server_id: Optional[str] = None,
2765 plugin_context_table: Optional[PluginContextTable] = None,
2766 plugin_global_context: Optional[GlobalContext] = None,
2767 meta_data: Optional[Dict[str, Any]] = None,
2768 ) -> ToolResult:
2769 """
2770 Invoke a registered tool and record execution metrics.
2772 Args:
2773 db: Database session.
2774 name: Name of tool to invoke.
2775 arguments: Tool arguments.
2776 request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through.
2777 Defaults to None.
2778 app_user_email (Optional[str], optional): ContextForge user email for OAuth token retrieval.
2779 Required for OAuth-protected gateways.
2780 user_email (Optional[str], optional): User email for authorization checks.
2781 None = unauthenticated request.
2782 token_teams (Optional[List[str]], optional): Team IDs from JWT token for authorization.
2783 None = unrestricted admin, [] = public-only, [...] = team-scoped.
2784 server_id (Optional[str], optional): Virtual server ID for server scoping enforcement.
2785 If provided, tool must be attached to this server.
2786 plugin_context_table: Optional plugin context table from previous hooks for cross-hook state sharing.
2787 plugin_global_context: Optional global context from middleware for consistency across hooks.
2788 meta_data: Optional metadata dictionary for additional context (e.g., request ID).
2790 Returns:
2791 Tool invocation result.
2793 Raises:
2794 ToolNotFoundError: If tool not found or access denied.
2795 ToolInvocationError: If invocation fails.
2796 ToolTimeoutError: If tool invocation times out.
2797 PluginViolationError: If plugin blocks tool invocation.
2798 PluginError: If encounters issue with plugin
2800 Examples:
2801 >>> # Note: This method requires extensive mocking of SQLAlchemy models,
2802 >>> # database relationships, and caching infrastructure, which is not
2803 >>> # suitable for doctests. See tests/unit/mcpgateway/services/test_tool_service.py
2804 >>> pass # doctest: +SKIP
2805 """
2806 # pylint: disable=comparison-with-callable
2807 logger.info(f"Invoking tool: {name} with arguments: {arguments.keys() if arguments else None} and headers: {request_headers.keys() if request_headers else None}, server_id={server_id}")
2808 # ═══════════════════════════════════════════════════════════════════════════
2809 # PHASE 1: Check for X-Context-Forge-Gateway-Id header for direct_proxy mode (no DB lookup)
2810 # ═══════════════════════════════════════════════════════════════════════════
2811 gateway_id_from_header = extract_gateway_id_from_headers(request_headers)
2813 # If X-Context-Forge-Gateway-Id header is present, check if gateway is in direct_proxy mode
2814 is_direct_proxy = False
2815 tool = None
2816 gateway = None
2817 tool_payload: Dict[str, Any] = {}
2818 gateway_payload: Optional[Dict[str, Any]] = None
2820 if gateway_id_from_header:
2821 # Look up gateway to check if it's in direct_proxy mode
2822 gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id_from_header)).scalar_one_or_none()
2823 if gateway and gateway.gateway_mode == "direct_proxy" and settings.mcpgateway_direct_proxy_enabled:
2824 # SECURITY: Check gateway access before allowing direct proxy
2825 # This prevents RBAC bypass where any authenticated user could invoke tools
2826 # on any gateway just by knowing the gateway ID
2827 if not await check_gateway_access(db, gateway, user_email, token_teams):
2828 logger.warning(f"Access denied to gateway {gateway_id_from_header} in direct_proxy mode for user {user_email}")
2829 raise ToolNotFoundError(f"Tool not found: {name}")
2831 is_direct_proxy = True
2832 # Build minimal gateway payload for direct proxy (no tool lookup needed)
2833 gateway_payload = {
2834 "id": str(gateway.id),
2835 "name": gateway.name,
2836 "url": gateway.url,
2837 "auth_type": gateway.auth_type,
2838 # DbGateway.auth_value is JSON (dict); downstream code expects an encoded str.
2839 "auth_value": encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value,
2840 "auth_query_params": gateway.auth_query_params,
2841 "oauth_config": gateway.oauth_config,
2842 "ca_certificate": gateway.ca_certificate,
2843 "ca_certificate_sig": gateway.ca_certificate_sig,
2844 "passthrough_headers": gateway.passthrough_headers,
2845 "gateway_mode": gateway.gateway_mode,
2846 }
2847 # Create minimal tool payload for direct proxy (no DB tool needed)
2848 tool_payload = {
2849 "id": None, # No tool ID in direct proxy mode
2850 "name": name,
2851 "original_name": name,
2852 "enabled": True,
2853 "reachable": True,
2854 "integration_type": "MCP",
2855 "request_type": "streamablehttp", # Default to streamablehttp
2856 "gateway_id": str(gateway.id),
2857 }
2858 logger.info(f"Direct proxy mode via X-Context-Forge-Gateway-Id header: passing tool '{name}' directly to remote MCP server at {gateway.url}")
2859 elif gateway:
2860 logger.debug(f"Gateway {gateway_id_from_header} found but not in direct_proxy mode (mode: {gateway.gateway_mode}), using normal lookup")
2861 else:
2862 logger.warning(f"Gateway {gateway_id_from_header} specified in X-Context-Forge-Gateway-Id header not found")
2864 # Normal mode: look up tool in database/cache
2865 if not is_direct_proxy:
2866 tool_lookup_cache = _get_tool_lookup_cache()
2867 cached_payload = await tool_lookup_cache.get(name) if tool_lookup_cache.enabled else None
2869 if cached_payload:
2870 status = cached_payload.get("status", "active")
2871 if status == "missing":
2872 raise ToolNotFoundError(f"Tool not found: {name}")
2873 if status == "inactive":
2874 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive")
2875 if status == "offline":
2876 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.")
2877 tool_payload = cached_payload.get("tool") or {}
2878 gateway_payload = cached_payload.get("gateway")
2880 if not tool_payload:
2881 # Eager load tool WITH gateway in single query to prevent lazy load N+1
2882 # Use a single query to avoid a race between separate enabled/inactive lookups.
2883 # Use scalars().all() instead of scalar_one_or_none() to handle duplicate
2884 # tool names across teams without crashing on MultipleResultsFound.
2885 tools = db.execute(select(DbTool).options(joinedload(DbTool.gateway)).where(DbTool.name == name)).scalars().all()
2887 if not tools:
2888 raise ToolNotFoundError(f"Tool not found: {name}")
2890 multiple_found = len(tools) > 1
2891 if not multiple_found:
2892 tool = tools[0]
2893 else:
2894 # Multiple tools found with same name — filter by access using
2895 # _check_tool_access (same rules as list_tools) and prioritize.
2896 # Priority (lower is better): team (0) > private (1) > public (2)
2897 visibility_priority = {"team": 0, "private": 1, "public": 2}
2898 accessible_tools: list[tuple[int, Any]] = []
2899 for t in tools:
2900 tool_dict = {"visibility": t.visibility, "team_id": t.team_id, "owner_email": t.owner_email}
2901 if await self._check_tool_access(db, tool_dict, user_email, token_teams):
2902 priority = visibility_priority.get(t.visibility, 99)
2903 accessible_tools.append((priority, t))
2905 if not accessible_tools:
2906 raise ToolNotFoundError(f"Tool not found: {name}")
2908 accessible_tools.sort(key=lambda x: x[0])
2910 # Check for ambiguity at the highest priority level
2911 best_priority = accessible_tools[0][0]
2912 best_tools = [t for p, t in accessible_tools if p == best_priority]
2914 if len(best_tools) > 1:
2915 raise ToolInvocationError(f"Multiple tools found with name '{name}' at same priority level. Tool name is ambiguous.")
2917 tool = best_tools[0]
2919 if not tool.enabled:
2920 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive")
2922 if not tool.reachable:
2923 await tool_lookup_cache.set_negative(name, "offline")
2924 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.")
2926 gateway = tool.gateway
2927 cache_payload = self._build_tool_cache_payload(tool, gateway)
2928 tool_payload = cache_payload.get("tool") or {}
2929 gateway_payload = cache_payload.get("gateway")
2930 # Skip caching when multiple tools share a name — resolution is
2931 # user-dependent, so a cached result could be wrong for other users.
2932 if not multiple_found:
2933 await tool_lookup_cache.set(name, cache_payload, gateway_id=tool_payload.get("gateway_id"))
2935 if tool_payload.get("enabled") is False:
2936 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive")
2937 if tool_payload.get("reachable") is False:
2938 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.")
2940 # ═══════════════════════════════════════════════════════════════════════════
2941 # SECURITY: Check tool access based on visibility and team membership
2942 # Skip these checks for direct_proxy mode (no tool in database)
2943 # ═══════════════════════════════════════════════════════════════════════════
2944 if not is_direct_proxy:
2945 if not await self._check_tool_access(db, tool_payload, user_email, token_teams):
2946 # Don't reveal tool existence - return generic "not found"
2947 raise ToolNotFoundError(f"Tool not found: {name}")
2949 # ═══════════════════════════════════════════════════════════════════════════
2950 # SECURITY: Enforce server scoping if server_id is provided
2951 # Tool must be attached to the specified virtual server
2952 # ═══════════════════════════════════════════════════════════════════════════
2953 if server_id:
2954 tool_id_for_check = tool_payload.get("id")
2955 if not tool_id_for_check:
2956 # Cannot verify server membership without tool ID - deny access
2957 # This should not happen with properly cached tools, but fail safe
2958 logger.warning(f"Tool '{name}' has no ID in payload, cannot verify server membership")
2959 raise ToolNotFoundError(f"Tool not found: {name}")
2961 server_match = db.execute(
2962 select(server_tool_association.c.tool_id).where(
2963 server_tool_association.c.server_id == server_id,
2964 server_tool_association.c.tool_id == tool_id_for_check,
2965 )
2966 ).first()
2967 if not server_match:
2968 raise ToolNotFoundError(f"Tool not found: {name}")
2970 # Extract A2A-related data from annotations (will be used after db.close() if A2A tool)
2971 tool_annotations = tool_payload.get("annotations") or {}
2972 tool_integration_type = tool_payload.get("integration_type")
2974 # Get passthrough headers from in-memory cache (Issue #1715)
2975 # This eliminates 42,000+ redundant DB queries under load
2976 passthrough_allowed = global_config_cache.get_passthrough_headers(db, settings.default_passthrough_headers)
2978 # Access gateway now (already eager-loaded) to prevent later lazy load
2979 if tool is not None:
2980 gateway = tool.gateway
2982 # ═══════════════════════════════════════════════════════════════════════════
2983 # PHASE 2: Extract all needed data to local variables before network I/O
2984 # This allows us to release the DB session before making HTTP calls
2985 # ═══════════════════════════════════════════════════════════════════════════
2986 tool_id = tool_payload.get("id") or (str(tool.id) if tool else "")
2987 tool_name_original = tool_payload.get("original_name") or tool_payload.get("name") or name
2988 tool_name_computed = tool_payload.get("name") or name
2989 tool_url = tool_payload.get("url")
2990 tool_integration_type = tool_payload.get("integration_type")
2991 tool_request_type = tool_payload.get("request_type")
2992 tool_headers = _decrypt_tool_headers_for_runtime(tool_payload.get("headers") or {})
2993 tool_auth_type = tool_payload.get("auth_type")
2994 tool_auth_value = tool_payload.get("auth_value")
2995 if tool is not None:
2996 runtime_tool_auth_value = getattr(tool, "auth_value", None)
2997 if isinstance(runtime_tool_auth_value, str):
2998 tool_auth_value = runtime_tool_auth_value
2999 if not isinstance(tool_auth_value, str):
3000 tool_auth_value = None
3001 tool_jsonpath_filter = tool_payload.get("jsonpath_filter")
3002 tool_output_schema = tool_payload.get("output_schema")
3003 tool_oauth_config = tool_payload.get("oauth_config") if isinstance(tool_payload.get("oauth_config"), dict) else None
3004 if tool is not None:
3005 runtime_tool_oauth_config = getattr(tool, "oauth_config", None)
3006 if isinstance(runtime_tool_oauth_config, dict):
3007 tool_oauth_config = runtime_tool_oauth_config
3008 tool_gateway_id = tool_payload.get("gateway_id")
3010 # Get effective timeout: per-tool timeout_ms (in seconds) or global fallback
3011 # timeout_ms is stored in milliseconds, convert to seconds
3012 tool_timeout_ms = tool_payload.get("timeout_ms")
3013 effective_timeout = (tool_timeout_ms / 1000) if tool_timeout_ms else settings.tool_timeout
3015 # Save gateway existence as local boolean BEFORE db.close()
3016 # to avoid checking ORM object truthiness after session is closed
3017 has_gateway = gateway_payload is not None
3018 gateway_url = gateway_payload.get("url") if has_gateway else None
3019 gateway_name = gateway_payload.get("name") if has_gateway else None
3020 gateway_auth_type = gateway_payload.get("auth_type") if has_gateway else None
3021 gateway_auth_value = gateway_payload.get("auth_value") if has_gateway and isinstance(gateway_payload.get("auth_value"), str) else None
3022 gateway_auth_query_params = gateway_payload.get("auth_query_params") if has_gateway and isinstance(gateway_payload.get("auth_query_params"), dict) else None
3023 gateway_oauth_config = gateway_payload.get("oauth_config") if has_gateway and isinstance(gateway_payload.get("oauth_config"), dict) else None
3024 if has_gateway and gateway is not None:
3025 runtime_gateway_auth_value = getattr(gateway, "auth_value", None)
3026 if isinstance(runtime_gateway_auth_value, dict):
3027 gateway_auth_value = encode_auth(runtime_gateway_auth_value)
3028 elif isinstance(runtime_gateway_auth_value, str):
3029 gateway_auth_value = runtime_gateway_auth_value
3030 runtime_gateway_query_params = getattr(gateway, "auth_query_params", None)
3031 if isinstance(runtime_gateway_query_params, dict):
3032 gateway_auth_query_params = runtime_gateway_query_params
3033 runtime_gateway_oauth_config = getattr(gateway, "oauth_config", None)
3034 if isinstance(runtime_gateway_oauth_config, dict):
3035 gateway_oauth_config = runtime_gateway_oauth_config
3036 gateway_ca_cert = gateway_payload.get("ca_certificate") if has_gateway else None
3037 gateway_ca_cert_sig = gateway_payload.get("ca_certificate_sig") if has_gateway else None
3038 gateway_passthrough = gateway_payload.get("passthrough_headers") if has_gateway else None
3039 gateway_id_str = gateway_payload.get("id") if has_gateway else None
3041 # Cache payload intentionally excludes sensitive auth material. For cache hits
3042 # (tool is None), hydrate auth-related fields from DB only when needed.
3043 if tool is None:
3044 requires_tool_auth_hydration = tool_auth_type in {"basic", "bearer", "authheaders", "oauth"}
3045 requires_gateway_auth_hydration = has_gateway and gateway_auth_type in {"basic", "bearer", "authheaders", "oauth", "query_param"}
3046 if requires_tool_auth_hydration or requires_gateway_auth_hydration:
3047 tool_id_for_hydration = tool_payload.get("id")
3048 if tool_id_for_hydration:
3049 tool_auth_row = db.execute(select(DbTool).options(joinedload(DbTool.gateway)).where(DbTool.id == tool_id_for_hydration)).scalar_one_or_none()
3050 if tool_auth_row:
3051 hydrated_tool_auth_value = getattr(tool_auth_row, "auth_value", None)
3052 if isinstance(hydrated_tool_auth_value, str):
3053 tool_auth_value = hydrated_tool_auth_value
3054 hydrated_tool_oauth_config = getattr(tool_auth_row, "oauth_config", None)
3055 if isinstance(hydrated_tool_oauth_config, dict):
3056 tool_oauth_config = hydrated_tool_oauth_config
3057 if has_gateway and tool_auth_row.gateway:
3058 hydrated_gateway_auth_value = getattr(tool_auth_row.gateway, "auth_value", None)
3059 if isinstance(hydrated_gateway_auth_value, dict):
3060 gateway_auth_value = encode_auth(hydrated_gateway_auth_value)
3061 elif isinstance(hydrated_gateway_auth_value, str):
3062 gateway_auth_value = hydrated_gateway_auth_value
3063 hydrated_gateway_query_params = getattr(tool_auth_row.gateway, "auth_query_params", None)
3064 if isinstance(hydrated_gateway_query_params, dict):
3065 gateway_auth_query_params = hydrated_gateway_query_params
3066 hydrated_gateway_oauth_config = getattr(tool_auth_row.gateway, "oauth_config", None)
3067 if isinstance(hydrated_gateway_oauth_config, dict):
3068 gateway_oauth_config = hydrated_gateway_oauth_config
3070 # Decrypt and apply query param auth to URL if applicable
3071 gateway_auth_query_params_decrypted: Optional[Dict[str, str]] = None
3072 if gateway_auth_type == "query_param" and gateway_auth_query_params:
3073 # Decrypt the query param values
3074 gateway_auth_query_params_decrypted = {}
3075 for param_key, encrypted_value in gateway_auth_query_params.items():
3076 if encrypted_value:
3077 try:
3078 decrypted = decode_auth(encrypted_value)
3079 gateway_auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
3080 except Exception: # noqa: S110 - intentionally skip failed decryptions
3081 # Silently skip params that fail decryption (may be corrupted or use old key)
3082 logger.debug(f"Failed to decrypt query param '{param_key}' for tool invocation")
3083 # Apply query params to gateway URL
3084 if gateway_auth_query_params_decrypted and gateway_url:
3085 gateway_url = apply_query_param_auth(gateway_url, gateway_auth_query_params_decrypted)
3087 # Create Pydantic models for plugins BEFORE HTTP calls (use ORM objects while still valid)
3088 # This prevents lazy loading during HTTP calls
3089 tool_metadata: Optional[PydanticTool] = None
3090 gateway_metadata: Optional[PydanticGateway] = None
3091 if self._plugin_manager:
3092 if tool is not None:
3093 tool_metadata = PydanticTool.model_validate(tool)
3094 if has_gateway and gateway is not None:
3095 gateway_metadata = PydanticGateway.model_validate(gateway)
3096 else:
3097 tool_metadata = self._pydantic_tool_from_payload(tool_payload)
3098 if has_gateway and gateway_payload:
3099 gateway_metadata = self._pydantic_gateway_from_payload(gateway_payload)
3101 tool_for_validation = tool if tool is not None else SimpleNamespace(output_schema=tool_output_schema, name=tool_name_computed)
3103 # ═══════════════════════════════════════════════════════════════════════════
3104 # A2A Agent Data Extraction (must happen before db.close())
3105 # Extract all A2A agent data to local variables so HTTP call can happen after db.close()
3106 # ═══════════════════════════════════════════════════════════════════════════
3107 a2a_agent_name: Optional[str] = None
3108 a2a_agent_endpoint_url: Optional[str] = None
3109 a2a_agent_type: Optional[str] = None
3110 a2a_agent_protocol_version: Optional[str] = None
3111 a2a_agent_auth_type: Optional[str] = None
3112 a2a_agent_auth_value: Optional[str] = None
3113 a2a_agent_auth_query_params: Optional[Dict[str, str]] = None
3115 if tool_integration_type == "A2A" and "a2a_agent_id" in tool_annotations:
3116 a2a_agent_id = tool_annotations.get("a2a_agent_id")
3117 if not a2a_agent_id:
3118 raise ToolNotFoundError(f"A2A tool '{name}' missing agent ID in annotations")
3120 # Query for the A2A agent
3121 agent_query = select(DbA2AAgent).where(DbA2AAgent.id == a2a_agent_id)
3122 a2a_agent = db.execute(agent_query).scalar_one_or_none()
3124 if not a2a_agent:
3125 raise ToolNotFoundError(f"A2A agent not found for tool '{name}' (agent ID: {a2a_agent_id})")
3127 if not a2a_agent.enabled:
3128 raise ToolNotFoundError(f"A2A agent '{a2a_agent.name}' is disabled")
3130 # Extract all needed data to local variables before db.close()
3131 a2a_agent_name = a2a_agent.name
3132 a2a_agent_endpoint_url = a2a_agent.endpoint_url
3133 a2a_agent_type = a2a_agent.agent_type
3134 a2a_agent_protocol_version = a2a_agent.protocol_version
3135 a2a_agent_auth_type = a2a_agent.auth_type
3136 a2a_agent_auth_value = a2a_agent.auth_value
3137 a2a_agent_auth_query_params = a2a_agent.auth_query_params
3139 # ═══════════════════════════════════════════════════════════════════════════
3140 # CRITICAL: Release DB connection back to pool BEFORE making HTTP calls
3141 # This prevents connection pool exhaustion during slow upstream requests.
3142 # All needed data has been extracted to local variables above.
3143 # The session will be closed again by FastAPI's get_db() finally block (safe no-op).
3144 # ═══════════════════════════════════════════════════════════════════════════
3145 db.commit() # End read-only transaction cleanly (commit not rollback to avoid inflating rollback stats)
3146 db.close()
3148 # Plugin hook: tool pre-invoke
3149 # Use existing context_table from previous hooks if available
3150 context_table = plugin_context_table
3152 # Reuse existing global_context from middleware or create new one
3153 # IMPORTANT: Use local variables (tool_gateway_id) instead of ORM object access
3154 if plugin_global_context:
3155 global_context = plugin_global_context
3156 # Update server_id using local variable (not ORM access)
3157 if tool_gateway_id and isinstance(tool_gateway_id, str):
3158 global_context.server_id = tool_gateway_id
3159 # Propagate user email to global context for plugin access
3160 if not plugin_global_context.user and app_user_email and isinstance(app_user_email, str):
3161 global_context.user = app_user_email
3162 else:
3163 # Create new context (fallback when middleware didn't run)
3164 # Use correlation ID from context if available, otherwise generate new one
3165 request_id = get_correlation_id() or uuid.uuid4().hex
3166 context_server_id = tool_gateway_id if tool_gateway_id and isinstance(tool_gateway_id, str) else "unknown"
3167 global_context = GlobalContext(request_id=request_id, server_id=context_server_id, tenant_id=None, user=app_user_email)
3169 start_time = time.monotonic()
3170 success = False
3171 error_message = None
3173 # Get trace_id from context for database span creation
3174 trace_id = current_trace_id.get()
3175 db_span_id = None
3176 db_span_ended = False
3177 observability_service = ObservabilityService() if trace_id else None
3179 # Create database span for observability_spans table
3180 if trace_id and observability_service:
3181 try:
3182 # Re-open database session for span creation (original was closed at line 2285)
3183 # Use commit=False since fresh_db_session() handles commits on exit
3184 with fresh_db_session() as span_db:
3185 db_span_id = observability_service.start_span(
3186 db=span_db,
3187 trace_id=trace_id,
3188 name="tool.invoke",
3189 kind="client",
3190 resource_type="tool",
3191 resource_name=name,
3192 resource_id=tool_id,
3193 attributes={
3194 "tool.name": name,
3195 "tool.id": tool_id,
3196 "tool.integration_type": tool_integration_type,
3197 "tool.gateway_id": tool_gateway_id,
3198 "arguments_count": len(arguments) if arguments else 0,
3199 "has_headers": bool(request_headers),
3200 },
3201 commit=False,
3202 )
3203 logger.debug(f"✓ Created tool.invoke span: {db_span_id} for tool: {name}")
3204 except Exception as e:
3205 logger.warning(f"Failed to start observability span for tool invocation: {e}")
3206 db_span_id = None
3208 # Create a trace span for OpenTelemetry export (Jaeger, Zipkin, etc.)
3209 with create_span(
3210 "tool.invoke",
3211 {
3212 "tool.name": name,
3213 "tool.id": tool_id,
3214 "tool.integration_type": tool_integration_type,
3215 "tool.gateway_id": tool_gateway_id,
3216 "arguments_count": len(arguments) if arguments else 0,
3217 "has_headers": bool(request_headers),
3218 },
3219 ) as span:
3220 try:
3221 # Get combined headers for the tool including base headers, auth, and passthrough headers
3222 headers = tool_headers.copy()
3223 if tool_integration_type == "REST":
3224 # Handle OAuth authentication for REST tools
3225 if tool_auth_type == "oauth" and isinstance(tool_oauth_config, dict) and tool_oauth_config:
3226 try:
3227 access_token = await self.oauth_manager.get_access_token(tool_oauth_config)
3228 headers["Authorization"] = f"Bearer {access_token}"
3229 except Exception as e:
3230 logger.error(f"Failed to obtain OAuth access token for tool {tool_name_computed}: {e}")
3231 raise ToolInvocationError(f"OAuth authentication failed: {str(e)}")
3232 else:
3233 credentials = decode_auth(tool_auth_value) if tool_auth_value else {}
3234 # Filter out empty header names/values to avoid "Illegal header name" errors
3235 filtered_credentials = {k: v for k, v in credentials.items() if k and v}
3236 headers.update(filtered_credentials)
3238 # Use cached passthrough headers (no DB query needed)
3239 if request_headers:
3240 headers = compute_passthrough_headers_cached(
3241 request_headers,
3242 headers,
3243 passthrough_allowed,
3244 gateway_auth_type=None,
3245 gateway_passthrough_headers=None, # REST tools don't use gateway auth here
3246 )
3247 # Read MCP-Session-Id from downstream client (MCP protocol header)
3248 # and normalize to x-mcp-session-id for our internal session affinity logic
3249 # The pool will strip this before sending to upstream
3250 # Check both mcp-session-id (direct client) and x-mcp-session-id (forwarded requests)
3251 request_headers_lower = {k.lower(): v for k, v in request_headers.items()}
3252 mcp_session_id = request_headers_lower.get("mcp-session-id") or request_headers_lower.get("x-mcp-session-id")
3253 if mcp_session_id:
3254 headers["x-mcp-session-id"] = mcp_session_id
3256 worker_id = str(os.getpid())
3257 session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id
3258 logger.debug(f"[AFFINITY] Worker {worker_id} | Session {session_short}... | Tool: {name} | Normalized MCP-Session-Id → x-mcp-session-id for pool affinity")
3260 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_PRE_INVOKE):
3261 # Use pre-created Pydantic model from Phase 2 (no ORM access)
3262 if tool_metadata:
3263 global_context.metadata[TOOL_METADATA] = tool_metadata
3264 pre_result, context_table = await self._plugin_manager.invoke_hook(
3265 ToolHookType.TOOL_PRE_INVOKE,
3266 payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)),
3267 global_context=global_context,
3268 local_contexts=context_table, # Pass context from previous hooks
3269 violations_as_exceptions=True,
3270 )
3271 if pre_result.modified_payload:
3272 payload = pre_result.modified_payload
3273 name = payload.name
3274 arguments = payload.args
3275 if payload.headers is not None:
3276 headers = payload.headers.model_dump()
3278 # Build the payload based on integration type
3279 payload = arguments.copy()
3281 # Handle URL path parameter substitution (using local variable)
3282 final_url = tool_url
3283 if "{" in tool_url and "}" in tool_url:
3284 # Extract path parameters from URL template and arguments
3285 url_params = re.findall(r"\{(\w+)\}", tool_url)
3286 url_substitutions = {}
3288 for param in url_params:
3289 if param in payload:
3290 url_substitutions[param] = payload.pop(param) # Remove from payload
3291 final_url = final_url.replace(f"{{{param}}}", str(url_substitutions[param]))
3292 else:
3293 raise ToolInvocationError(f"Required URL parameter '{param}' not found in arguments")
3295 # --- Extract query params from URL ---
3296 parsed = urlparse(final_url)
3297 final_url = f"{parsed.scheme}://{parsed.netloc}{parsed.path}"
3299 query_params = {k: v[0] for k, v in parse_qs(parsed.query).items()}
3301 # Merge leftover payload + query params
3302 payload.update(query_params)
3304 # Use the tool's request_type rather than defaulting to POST (using local variable)
3305 method = tool_request_type.upper() if tool_request_type else "POST"
3306 rest_start_time = time.time()
3307 try:
3308 if method == "GET":
3309 response = await asyncio.wait_for(self._http_client.get(final_url, params=payload, headers=headers), timeout=effective_timeout)
3310 else:
3311 response = await asyncio.wait_for(self._http_client.request(method, final_url, json=payload, headers=headers), timeout=effective_timeout)
3312 except (asyncio.TimeoutError, httpx.TimeoutException):
3313 rest_elapsed_ms = (time.time() - rest_start_time) * 1000
3314 structured_logger.log(
3315 level="WARNING",
3316 message=f"REST tool invocation timed out: {tool_name_computed}",
3317 component="tool_service",
3318 correlation_id=get_correlation_id(),
3319 duration_ms=rest_elapsed_ms,
3320 metadata={"event": "tool_timeout", "tool_name": tool_name_computed, "timeout_seconds": effective_timeout},
3321 )
3323 # Manually trigger circuit breaker (or other plugins) on timeout
3324 try:
3325 # First-Party
3326 from mcpgateway.services.metrics import tool_timeout_counter # pylint: disable=import-outside-toplevel
3328 tool_timeout_counter.labels(tool_name=name).inc()
3329 except Exception as exc:
3330 logger.debug(
3331 "Failed to increment tool_timeout_counter for %s: %s",
3332 name,
3333 exc,
3334 exc_info=True,
3335 )
3337 if self._plugin_manager:
3338 if context_table:
3339 for ctx in context_table.values():
3340 ctx.set_state("cb_timeout_failure", True)
3342 if self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE):
3343 timeout_error_result = ToolResult(content=[TextContent(type="text", text=f"Tool invocation timed out after {effective_timeout}s")], is_error=True)
3344 await self._plugin_manager.invoke_hook(
3345 ToolHookType.TOOL_POST_INVOKE,
3346 payload=ToolPostInvokePayload(name=name, result=timeout_error_result.model_dump(by_alias=True)),
3347 global_context=global_context,
3348 local_contexts=context_table,
3349 violations_as_exceptions=False,
3350 )
3352 raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s")
3353 response.raise_for_status()
3355 # Handle 204 No Content responses that have no body
3356 if response.status_code == 204:
3357 tool_result = ToolResult(content=[TextContent(type="text", text="Request completed successfully (No Content)")])
3358 success = True
3359 elif response.status_code not in [200, 201, 202, 206]:
3360 try:
3361 result = response.json()
3362 except orjson.JSONDecodeError:
3363 result = {"response_text": response.text} if response.text else {}
3364 error_val = result["error"] if "error" in result else "Tool error encountered"
3365 tool_result = ToolResult(
3366 content=[TextContent(type="text", text=error_val if isinstance(error_val, str) else orjson.dumps(error_val).decode())],
3367 is_error=True,
3368 )
3369 # Don't mark as successful for error responses - success remains False
3370 else:
3371 try:
3372 result = response.json()
3373 except orjson.JSONDecodeError:
3374 result = {"response_text": response.text} if response.text else {}
3375 logger.debug(f"REST API tool response: {result}")
3376 filtered_response = extract_using_jq(result, tool_jsonpath_filter)
3377 # Check if extract_using_jq returned an error (list of TextContent objects)
3378 if isinstance(filtered_response, list) and len(filtered_response) > 0 and isinstance(filtered_response[0], TextContent):
3379 # Error case - use the TextContent directly
3380 tool_result = ToolResult(content=filtered_response, is_error=True)
3381 success = False
3382 else:
3383 # Success case - serialize the filtered response
3384 serialized = orjson.dumps(filtered_response, option=orjson.OPT_INDENT_2)
3385 tool_result = ToolResult(content=[TextContent(type="text", text=serialized.decode())])
3386 success = True
3387 # If output schema is present, validate and attach structured content
3388 if tool_output_schema:
3389 valid = self._extract_and_validate_structured_content(tool_for_validation, tool_result, candidate=filtered_response)
3390 success = bool(valid)
3391 elif tool_integration_type == "MCP":
3392 transport = tool_request_type.lower() if tool_request_type else "sse"
3394 # Handle OAuth authentication for the gateway (using local variables)
3395 # NOTE: Use has_gateway instead of gateway to avoid accessing detached ORM object
3396 if has_gateway and gateway_auth_type == "oauth" and isinstance(gateway_oauth_config, dict) and gateway_oauth_config:
3397 grant_type = gateway_oauth_config.get("grant_type", "client_credentials")
3399 if grant_type == "authorization_code":
3400 # For Authorization Code flow, try to get stored tokens
3401 # NOTE: Use fresh_db_session() since the original db was closed
3402 try:
3403 # First-Party
3404 from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel
3406 with fresh_db_session() as token_db:
3407 token_storage = TokenStorageService(token_db)
3409 # Get user-specific OAuth token
3410 if not app_user_email:
3411 raise ToolInvocationError(f"User authentication required for OAuth-protected gateway '{gateway_name}'. Please ensure you are authenticated.")
3413 access_token = await token_storage.get_user_token(gateway_id_str, app_user_email)
3415 if access_token:
3416 headers = {"Authorization": f"Bearer {access_token}"}
3417 else:
3418 # User hasn't authorized this gateway yet
3419 raise ToolInvocationError(f"Please authorize {gateway_name} first. Visit /oauth/authorize/{gateway_id_str} to complete OAuth flow.")
3420 except Exception as e:
3421 logger.error(f"Failed to obtain stored OAuth token for gateway {gateway_name}: {e}")
3422 raise ToolInvocationError(f"OAuth token retrieval failed for gateway: {str(e)}")
3423 else:
3424 # For Client Credentials flow, get token directly (no DB needed)
3425 try:
3426 access_token = await self.oauth_manager.get_access_token(gateway_oauth_config)
3427 headers = {"Authorization": f"Bearer {access_token}"}
3428 except Exception as e:
3429 logger.error(f"Failed to obtain OAuth access token for gateway {gateway_name}: {e}")
3430 raise ToolInvocationError(f"OAuth authentication failed for gateway: {str(e)}")
3431 else:
3432 headers = decode_auth(gateway_auth_value) if gateway_auth_value else {}
3434 # Use cached passthrough headers (no DB query needed)
3435 if request_headers:
3436 headers = compute_passthrough_headers_cached(
3437 request_headers, headers, passthrough_allowed, gateway_auth_type=gateway_auth_type, gateway_passthrough_headers=gateway_passthrough
3438 )
3439 # Read MCP-Session-Id from downstream client (MCP protocol header)
3440 # and normalize to x-mcp-session-id for our internal session affinity logic
3441 # The pool will strip this before sending to upstream
3442 # Check both mcp-session-id (direct client) and x-mcp-session-id (forwarded requests)
3443 request_headers_lower = {k.lower(): v for k, v in request_headers.items()}
3444 mcp_session_id = request_headers_lower.get("mcp-session-id") or request_headers_lower.get("x-mcp-session-id")
3445 if mcp_session_id:
3446 headers["x-mcp-session-id"] = mcp_session_id
3448 worker_id = str(os.getpid())
3449 session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id
3450 logger.debug(f"[AFFINITY] Worker {worker_id} | Session {session_short}... | Tool: {name} | Normalized MCP-Session-Id → x-mcp-session-id for pool affinity (MCP transport)")
                    def create_ssl_context(ca_certificate: str) -> ssl.SSLContext:
                        """Create an SSL context with the provided CA certificate.

                        Delegates to the module-level cache (get_cached_ssl_context) so
                        repeated calls with the same PEM payload reuse one already-built
                        context instead of re-parsing the certificate each time.

                        Args:
                            ca_certificate: CA certificate in PEM format

                        Returns:
                            ssl.SSLContext: Configured SSL context
                        """
                        return get_cached_ssl_context(ca_certificate)
3465 def get_httpx_client_factory(
3466 headers: dict[str, str] | None = None,
3467 timeout: httpx.Timeout | None = None,
3468 auth: httpx.Auth | None = None,
3469 ) -> httpx.AsyncClient:
3470 """Factory function to create httpx.AsyncClient with optional CA certificate.
3472 Args:
3473 headers: Optional headers for the client
3474 timeout: Optional timeout for the client
3475 auth: Optional auth for the client
3477 Returns:
3478 httpx.AsyncClient: Configured HTTPX async client
3480 Raises:
3481 Exception: If CA certificate signature is invalid
3482 """
3483 # Use local variables instead of ORM objects (captured from outer scope)
3484 valid = False
3485 if gateway_ca_cert:
3486 if settings.enable_ed25519_signing:
3487 public_key_pem = settings.ed25519_public_key
3488 valid = validate_signature(gateway_ca_cert.encode(), gateway_ca_cert_sig, public_key_pem)
3489 else:
3490 valid = True
3491 # First-Party
3492 from mcpgateway.services.http_client_service import get_default_verify, get_http_timeout # pylint: disable=import-outside-toplevel
3494 if valid:
3495 ctx = create_ssl_context(gateway_ca_cert)
3496 else:
3497 ctx = None
3499 # Use effective_timeout for read operations if not explicitly overridden by caller
3500 # This ensures the underlying client waits at least as long as the tool configuration requires
3501 factory_timeout = timeout if timeout else get_http_timeout(read_timeout=effective_timeout)
3503 return httpx.AsyncClient(
3504 verify=ctx if ctx else get_default_verify(),
3505 follow_redirects=True,
3506 headers=headers,
3507 timeout=factory_timeout,
3508 auth=auth,
3509 limits=httpx.Limits(
3510 max_connections=settings.httpx_max_connections,
3511 max_keepalive_connections=settings.httpx_max_keepalive_connections,
3512 keepalive_expiry=settings.httpx_keepalive_expiry,
3513 ),
3514 )
                    async def connect_to_sse_server(server_url: str, headers: dict = headers):
                        """Connect to an MCP server running with SSE transport and call the tool.

                        Uses the shared MCP session pool when enabled and initialized
                        (lower per-call latency); otherwise falls back to a one-shot
                        sse_client/ClientSession pair.

                        NOTE(review): the ``headers=headers`` default binds the outer
                        ``headers`` dict at definition time, and the non-pooled path below
                        mutates that same dict in place (adds X-Correlation-ID). Safe here
                        because the dict is rebuilt per invocation — confirm before reusing
                        this helper elsewhere.

                        Args:
                            server_url: MCP Server SSE URL
                            headers: HTTP headers to include in the request

                        Returns:
                            ToolResult: Result of tool call

                        Raises:
                            ToolTimeoutError: If the tool invocation times out.
                            BaseException: On connection or communication errors

                        """
                        # Get correlation ID for distributed tracing
                        correlation_id = get_correlation_id()

                        # NOTE: X-Correlation-ID is NOT added to headers for pooled sessions.
                        # MCP SDK pins headers at transport creation, so adding per-request headers
                        # would cause the first request's correlation ID to be reused for all
                        # subsequent requests on the same pooled session. Correlation IDs are
                        # still logged locally for tracing within the gateway.

                        # Log MCP call start (using local variables)
                        # Sanitize server_url to redact sensitive query params from logs
                        server_url_sanitized = sanitize_url_for_logging(server_url, gateway_auth_query_params_decrypted)
                        mcp_start_time = time.time()
                        structured_logger.log(
                            level="INFO",
                            message=f"MCP tool call started: {tool_name_original}",
                            component="tool_service",
                            correlation_id=correlation_id,
                            metadata={"event": "mcp_call_started", "tool_name": tool_name_original, "tool_id": tool_id, "server_url": server_url_sanitized, "transport": "sse"},
                        )

                        try:
                            # Use session pool if enabled for 10-20x latency improvement
                            use_pool = False
                            pool = None
                            if settings.mcp_session_pool_enabled:
                                try:
                                    pool = get_mcp_session_pool()
                                    use_pool = True
                                except RuntimeError:
                                    # Pool not initialized (e.g., in tests), fall back to per-call sessions
                                    pass

                            if use_pool and pool is not None:
                                # Pooled path: do NOT add per-request headers (they would be pinned)
                                async with pool.session(
                                    url=server_url,
                                    headers=headers,
                                    transport_type=TransportType.SSE,
                                    httpx_client_factory=get_httpx_client_factory,
                                    user_identity=app_user_email,
                                    gateway_id=gateway_id_str,
                                ) as pooled:
                                    tool_call_result = await asyncio.wait_for(pooled.session.call_tool(tool_name_original, arguments, meta=meta_data), timeout=effective_timeout)
                            else:
                                # Non-pooled path: safe to add per-request headers
                                if correlation_id and headers:
                                    headers["X-Correlation-ID"] = correlation_id
                                # Fallback to per-call sessions when pool disabled or not initialized
                                async with sse_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as streams:
                                    async with ClientSession(*streams) as session:
                                        await session.initialize()
                                        tool_call_result = await asyncio.wait_for(session.call_tool(tool_name_original, arguments, meta=meta_data), timeout=effective_timeout)

                            # Log successful MCP call
                            mcp_duration_ms = (time.time() - mcp_start_time) * 1000
                            structured_logger.log(
                                level="INFO",
                                message=f"MCP tool call completed: {tool_name_original}",
                                component="tool_service",
                                correlation_id=correlation_id,
                                duration_ms=mcp_duration_ms,
                                metadata={"event": "mcp_call_completed", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "sse", "success": True},
                            )

                            return tool_call_result
                        except (asyncio.TimeoutError, httpx.TimeoutException):
                            # Handle timeout specifically - log, fire post-invoke hooks so
                            # plugins (e.g. circuit breaker) observe the failure, then raise.
                            mcp_duration_ms = (time.time() - mcp_start_time) * 1000
                            structured_logger.log(
                                level="WARNING",
                                message=f"MCP SSE tool invocation timed out: {tool_name_original}",
                                component="tool_service",
                                correlation_id=correlation_id,
                                duration_ms=mcp_duration_ms,
                                metadata={"event": "tool_timeout", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "sse", "timeout_seconds": effective_timeout},
                            )

                            # Manually trigger circuit breaker (or other plugins) on timeout
                            try:
                                # First-Party
                                from mcpgateway.services.metrics import tool_timeout_counter  # pylint: disable=import-outside-toplevel

                                tool_timeout_counter.labels(tool_name=name).inc()
                            except Exception as exc:
                                # Metrics are best-effort; never let a counter failure mask the timeout.
                                logger.debug(
                                    "Failed to increment tool_timeout_counter for %s: %s",
                                    name,
                                    exc,
                                    exc_info=True,
                                )

                            if self._plugin_manager:
                                if context_table:
                                    for ctx in context_table.values():
                                        ctx.set_state("cb_timeout_failure", True)

                                if self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE):
                                    timeout_error_result = ToolResult(content=[TextContent(type="text", text=f"Tool invocation timed out after {effective_timeout}s")], is_error=True)
                                    await self._plugin_manager.invoke_hook(
                                        ToolHookType.TOOL_POST_INVOKE,
                                        payload=ToolPostInvokePayload(name=name, result=timeout_error_result.model_dump(by_alias=True)),
                                        global_context=global_context,
                                        local_contexts=context_table,
                                        violations_as_exceptions=False,
                                    )

                            raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s")
                        except BaseException as e:
                            # Extract root cause from ExceptionGroup (Python 3.11+)
                            # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup.
                            # Only the first leaf of each nested group is surfaced in the log.
                            root_cause = e
                            if isinstance(e, BaseExceptionGroup):
                                while isinstance(root_cause, BaseExceptionGroup) and root_cause.exceptions:
                                    root_cause = root_cause.exceptions[0]
                            # Log failed MCP call (using local variables)
                            mcp_duration_ms = (time.time() - mcp_start_time) * 1000
                            # Sanitize error message to prevent URL secrets from leaking in logs
                            sanitized_error = sanitize_exception_message(str(root_cause), gateway_auth_query_params_decrypted)
                            structured_logger.log(
                                level="ERROR",
                                message=f"MCP tool call failed: {tool_name_original}",
                                component="tool_service",
                                correlation_id=correlation_id,
                                duration_ms=mcp_duration_ms,
                                error_details={"error_type": type(root_cause).__name__, "error_message": sanitized_error},
                                metadata={"event": "mcp_call_failed", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "sse"},
                            )
                            raise
                    async def connect_to_streamablehttp_server(server_url: str, headers: dict = headers):
                        """Connect to an MCP server running with Streamable HTTP transport and call the tool.

                        Mirrors connect_to_sse_server: uses the shared MCP session pool
                        when enabled and initialized, otherwise falls back to a one-shot
                        streamablehttp_client/ClientSession pair.

                        NOTE(review): the ``headers=headers`` default binds the outer
                        ``headers`` dict at definition time, and the non-pooled path below
                        mutates that same dict in place (adds X-Correlation-ID). Safe here
                        because the dict is rebuilt per invocation — confirm before reusing
                        this helper elsewhere.

                        Args:
                            server_url: MCP Server URL
                            headers: HTTP headers to include in the request

                        Returns:
                            ToolResult: Result of tool call

                        Raises:
                            ToolTimeoutError: If the tool invocation times out.
                            BaseException: On connection or communication errors
                        """
                        # Get correlation ID for distributed tracing
                        correlation_id = get_correlation_id()

                        # NOTE: X-Correlation-ID is NOT added to headers for pooled sessions.
                        # MCP SDK pins headers at transport creation, so adding per-request headers
                        # would cause the first request's correlation ID to be reused for all
                        # subsequent requests on the same pooled session. Correlation IDs are
                        # still logged locally for tracing within the gateway.

                        # Log MCP call start (using local variables)
                        # Sanitize server_url to redact sensitive query params from logs
                        server_url_sanitized = sanitize_url_for_logging(server_url, gateway_auth_query_params_decrypted)
                        mcp_start_time = time.time()
                        structured_logger.log(
                            level="INFO",
                            message=f"MCP tool call started: {tool_name_original}",
                            component="tool_service",
                            correlation_id=correlation_id,
                            metadata={"event": "mcp_call_started", "tool_name": tool_name_original, "tool_id": tool_id, "server_url": server_url_sanitized, "transport": "streamablehttp"},
                        )

                        try:
                            # Use session pool if enabled for 10-20x latency improvement
                            use_pool = False
                            pool = None
                            if settings.mcp_session_pool_enabled:
                                try:
                                    pool = get_mcp_session_pool()
                                    use_pool = True
                                except RuntimeError:
                                    # Pool not initialized (e.g., in tests), fall back to per-call sessions
                                    pass

                            if use_pool and pool is not None:
                                # Pooled path: do NOT add per-request headers (they would be pinned)
                                # Determine transport type based on current transport setting
                                # NOTE(review): this helper is only invoked when transport ==
                                # "streamablehttp" (see the dispatch below the definitions),
                                # so the SSE arm here appears defensive — confirm before removing.
                                pool_transport_type = TransportType.SSE if transport == "sse" else TransportType.STREAMABLE_HTTP
                                async with pool.session(
                                    url=server_url,
                                    headers=headers,
                                    transport_type=pool_transport_type,
                                    httpx_client_factory=get_httpx_client_factory,
                                    user_identity=app_user_email,
                                    gateway_id=gateway_id_str,
                                ) as pooled:
                                    tool_call_result = await asyncio.wait_for(pooled.session.call_tool(tool_name_original, arguments, meta=meta_data), timeout=effective_timeout)
                            else:
                                # Non-pooled path: safe to add per-request headers
                                if correlation_id and headers:
                                    headers["X-Correlation-ID"] = correlation_id

                                # Fallback to per-call sessions when pool disabled or not initialized
                                async with streamablehttp_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as (read_stream, write_stream, _get_session_id):
                                    async with ClientSession(read_stream, write_stream) as session:
                                        await session.initialize()
                                        tool_call_result = await asyncio.wait_for(session.call_tool(tool_name_original, arguments, meta=meta_data), timeout=effective_timeout)

                            # Log successful MCP call
                            mcp_duration_ms = (time.time() - mcp_start_time) * 1000
                            structured_logger.log(
                                level="INFO",
                                message=f"MCP tool call completed: {tool_name_original}",
                                component="tool_service",
                                correlation_id=correlation_id,
                                duration_ms=mcp_duration_ms,
                                metadata={"event": "mcp_call_completed", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "streamablehttp", "success": True},
                            )

                            return tool_call_result
                        except (asyncio.TimeoutError, httpx.TimeoutException):
                            # Handle timeout specifically - log, fire post-invoke hooks so
                            # plugins (e.g. circuit breaker) observe the failure, then raise.
                            mcp_duration_ms = (time.time() - mcp_start_time) * 1000
                            structured_logger.log(
                                level="WARNING",
                                message=f"MCP StreamableHTTP tool invocation timed out: {tool_name_original}",
                                component="tool_service",
                                correlation_id=correlation_id,
                                duration_ms=mcp_duration_ms,
                                metadata={"event": "tool_timeout", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "streamablehttp", "timeout_seconds": effective_timeout},
                            )

                            # Manually trigger circuit breaker (or other plugins) on timeout
                            try:
                                # First-Party
                                from mcpgateway.services.metrics import tool_timeout_counter  # pylint: disable=import-outside-toplevel

                                tool_timeout_counter.labels(tool_name=name).inc()
                            except Exception as exc:
                                # Metrics are best-effort; never let a counter failure mask the timeout.
                                logger.debug(
                                    "Failed to increment tool_timeout_counter for %s: %s",
                                    name,
                                    exc,
                                    exc_info=True,
                                )

                            if self._plugin_manager:
                                if context_table:
                                    for ctx in context_table.values():
                                        ctx.set_state("cb_timeout_failure", True)

                                if self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE):
                                    timeout_error_result = ToolResult(content=[TextContent(type="text", text=f"Tool invocation timed out after {effective_timeout}s")], is_error=True)
                                    await self._plugin_manager.invoke_hook(
                                        ToolHookType.TOOL_POST_INVOKE,
                                        payload=ToolPostInvokePayload(name=name, result=timeout_error_result.model_dump(by_alias=True)),
                                        global_context=global_context,
                                        local_contexts=context_table,
                                        violations_as_exceptions=False,
                                    )

                            raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s")
                        except BaseException as e:
                            # Extract root cause from ExceptionGroup (Python 3.11+)
                            # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup.
                            # Only the first leaf of each nested group is surfaced in the log.
                            root_cause = e
                            if isinstance(e, BaseExceptionGroup):
                                while isinstance(root_cause, BaseExceptionGroup) and root_cause.exceptions:
                                    root_cause = root_cause.exceptions[0]
                            # Log failed MCP call
                            mcp_duration_ms = (time.time() - mcp_start_time) * 1000
                            # Sanitize error message to prevent URL secrets from leaking in logs
                            sanitized_error = sanitize_exception_message(str(root_cause), gateway_auth_query_params_decrypted)
                            structured_logger.log(
                                level="ERROR",
                                message=f"MCP tool call failed: {tool_name_original}",
                                component="tool_service",
                                correlation_id=correlation_id,
                                duration_ms=mcp_duration_ms,
                                error_details={"error_type": type(root_cause).__name__, "error_message": sanitized_error},
                                metadata={"event": "mcp_call_failed", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "streamablehttp"},
                            )
                            raise
3810 # REMOVED: Redundant gateway query - gateway already eager-loaded via joinedload
3811 # tool_gateway = db.execute(select(DbGateway).where(DbGateway.id == tool_gateway_id)...)
3813 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_PRE_INVOKE):
3814 # Use pre-created Pydantic models from Phase 2 (no ORM access)
3815 if tool_metadata:
3816 global_context.metadata[TOOL_METADATA] = tool_metadata
3817 if gateway_metadata:
3818 global_context.metadata[GATEWAY_METADATA] = gateway_metadata
3819 pre_result, context_table = await self._plugin_manager.invoke_hook(
3820 ToolHookType.TOOL_PRE_INVOKE,
3821 payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)),
3822 global_context=global_context,
3823 local_contexts=None,
3824 violations_as_exceptions=True,
3825 )
3826 if pre_result.modified_payload:
3827 payload = pre_result.modified_payload
3828 name = payload.name
3829 arguments = payload.args
3830 if payload.headers is not None:
3831 headers = payload.headers.model_dump()
3833 tool_call_result = ToolResult(content=[TextContent(text="", type="text")])
3834 if transport == "sse":
3835 tool_call_result = await connect_to_sse_server(gateway_url, headers=headers)
3836 elif transport == "streamablehttp":
3837 tool_call_result = await connect_to_streamablehttp_server(gateway_url, headers=headers)
3839 # In direct proxy mode, use the tool result as-is without splitting content
3840 if is_direct_proxy:
3841 tool_result = tool_call_result
3842 success = not getattr(tool_call_result, "is_error", False) and not getattr(tool_call_result, "isError", False)
3843 logger.debug(f"Direct proxy mode: using tool result as-is: {tool_result}")
3844 else:
3845 dump = tool_call_result.model_dump(by_alias=True, mode="json")
3846 logger.debug(f"Tool call result dump: {dump}")
3847 content = dump.get("content", [])
3848 # Accept both alias and pythonic names for structured content
3849 structured = dump.get("structuredContent") or dump.get("structured_content")
3850 filtered_response = extract_using_jq(content, tool_jsonpath_filter)
3852 is_err = getattr(tool_call_result, "is_error", None)
3853 if is_err is None:
3854 is_err = getattr(tool_call_result, "isError", False)
3855 tool_result = ToolResult(content=filtered_response, structured_content=structured, is_error=is_err, meta=getattr(tool_call_result, "meta", None))
3856 success = not is_err
3857 logger.debug(f"Final tool_result: {tool_result}")
3859 elif tool_integration_type == "A2A" and a2a_agent_endpoint_url:
3860 # A2A tool invocation using pre-extracted agent data (extracted in Phase 2 before db.close())
3861 headers = {"Content-Type": "application/json"}
3863 # Plugin hook: tool pre-invoke for A2A
3864 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_PRE_INVOKE):
3865 if tool_metadata:
3866 global_context.metadata[TOOL_METADATA] = tool_metadata
3867 pre_result, context_table = await self._plugin_manager.invoke_hook(
3868 ToolHookType.TOOL_PRE_INVOKE,
3869 payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)),
3870 global_context=global_context,
3871 local_contexts=context_table,
3872 violations_as_exceptions=True,
3873 )
3874 if pre_result.modified_payload:
3875 payload = pre_result.modified_payload
3876 name = payload.name
3877 arguments = payload.args
3878 if payload.headers is not None:
3879 headers = payload.headers.model_dump()
3881 # Build request data based on agent type
3882 endpoint_url = a2a_agent_endpoint_url
3883 if a2a_agent_type in ["generic", "jsonrpc"] or endpoint_url.endswith("/"):
3884 # JSONRPC agents: Convert flat query to nested message structure
3885 params = None
3886 if isinstance(arguments, dict) and "query" in arguments and isinstance(arguments["query"], str):
3887 message_id = f"admin-test-{int(time.time())}"
3888 # A2A v0.3.x: message.parts use "kind" (not "type").
3889 params = {
3890 "message": {
3891 "kind": "message",
3892 "messageId": message_id,
3893 "role": "user",
3894 "parts": [{"kind": "text", "text": arguments["query"]}],
3895 }
3896 }
3897 method = arguments.get("method", "message/send")
3898 else:
3899 params = arguments.get("params", arguments) if isinstance(arguments, dict) else arguments
3900 method = arguments.get("method", "message/send") if isinstance(arguments, dict) else "message/send"
3901 request_data = {"jsonrpc": "2.0", "method": method, "params": params, "id": 1}
3902 else:
3903 # Custom agents: Pass parameters directly
3904 params = arguments if isinstance(arguments, dict) else {}
3905 request_data = {"interaction_type": params.get("interaction_type", "query"), "parameters": params, "protocol_version": a2a_agent_protocol_version}
3907 # Add authentication
3908 if a2a_agent_auth_type == "api_key" and a2a_agent_auth_value:
3909 headers["Authorization"] = f"Bearer {a2a_agent_auth_value}"
3910 elif a2a_agent_auth_type == "bearer" and a2a_agent_auth_value:
3911 headers["Authorization"] = f"Bearer {a2a_agent_auth_value}"
3912 elif a2a_agent_auth_type == "query_param" and a2a_agent_auth_query_params:
3913 auth_query_params_decrypted: dict[str, str] = {}
3914 for param_key, encrypted_value in a2a_agent_auth_query_params.items():
3915 if encrypted_value:
3916 try:
3917 decrypted = decode_auth(encrypted_value)
3918 auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
3919 except Exception:
3920 logger.debug(f"Failed to decrypt query param for key '{param_key}'")
3921 if auth_query_params_decrypted:
3922 endpoint_url = apply_query_param_auth(endpoint_url, auth_query_params_decrypted)
3924 # Make HTTP request with timeout enforcement
3925 logger.info(f"Calling A2A agent '{a2a_agent_name}' at {endpoint_url}")
3926 a2a_start_time = time.time()
3927 try:
3928 http_response = await asyncio.wait_for(self._http_client.post(endpoint_url, json=request_data, headers=headers), timeout=effective_timeout)
3929 except (asyncio.TimeoutError, httpx.TimeoutException):
3930 a2a_elapsed_ms = (time.time() - a2a_start_time) * 1000
3931 structured_logger.log(
3932 level="WARNING",
3933 message=f"A2A tool invocation timed out: {name}",
3934 component="tool_service",
3935 correlation_id=get_correlation_id(),
3936 duration_ms=a2a_elapsed_ms,
3937 metadata={"event": "tool_timeout", "tool_name": name, "a2a_agent": a2a_agent_name, "timeout_seconds": effective_timeout},
3938 )
3940 # Increment timeout counter
3941 try:
3942 # First-Party
3943 from mcpgateway.services.metrics import tool_timeout_counter # pylint: disable=import-outside-toplevel
3945 tool_timeout_counter.labels(tool_name=name).inc()
3946 except Exception as exc:
3947 logger.debug("Failed to increment tool_timeout_counter for %s: %s", name, exc, exc_info=True)
3949 # Trigger circuit breaker on timeout
3950 if self._plugin_manager:
3951 if context_table:
3952 for ctx in context_table.values():
3953 ctx.set_state("cb_timeout_failure", True)
3955 if self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE):
3956 timeout_error_result = ToolResult(content=[TextContent(type="text", text=f"Tool invocation timed out after {effective_timeout}s")], is_error=True)
3957 await self._plugin_manager.invoke_hook(
3958 ToolHookType.TOOL_POST_INVOKE,
3959 payload=ToolPostInvokePayload(name=name, result=timeout_error_result.model_dump(by_alias=True)),
3960 global_context=global_context,
3961 local_contexts=context_table,
3962 violations_as_exceptions=False,
3963 )
3965 raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s")
3967 if http_response.status_code == 200:
3968 response_data = http_response.json()
3969 if isinstance(response_data, dict) and "response" in response_data:
3970 val = response_data["response"]
3971 content = [TextContent(type="text", text=val if isinstance(val, str) else orjson.dumps(val).decode())]
3972 else:
3973 content = [TextContent(type="text", text=response_data if isinstance(response_data, str) else orjson.dumps(response_data).decode())]
3974 tool_result = ToolResult(content=content, is_error=False)
3975 success = True
3976 else:
3977 error_message = f"HTTP {http_response.status_code}: {http_response.text}"
3978 content = [TextContent(type="text", text=f"A2A agent error: {error_message}")]
3979 tool_result = ToolResult(content=content, is_error=True)
3980 else:
3981 tool_result = ToolResult(content=[TextContent(type="text", text="Invalid tool type")], is_error=True)
3983 # Plugin hook: tool post-invoke
3984 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE):
3985 post_result, _ = await self._plugin_manager.invoke_hook(
3986 ToolHookType.TOOL_POST_INVOKE,
3987 payload=ToolPostInvokePayload(name=name, result=tool_result.model_dump(by_alias=True)),
3988 global_context=global_context,
3989 local_contexts=context_table,
3990 violations_as_exceptions=True,
3991 )
3992 # Use modified payload if provided
3993 if post_result.modified_payload:
3994 # Reconstruct ToolResult from modified result
3995 modified_result = post_result.modified_payload.result
3996 if isinstance(modified_result, dict) and "content" in modified_result:
3997 # Safely obtain structured content using .get() to avoid KeyError when
3998 # plugins provide only the content without structured content fields.
3999 structured = modified_result.get("structuredContent") if "structuredContent" in modified_result else modified_result.get("structured_content")
4001 tool_result = ToolResult(content=modified_result["content"], structured_content=structured)
4002 else:
4003 # If result is not in expected format, convert it to text content
4004 try:
4005 tool_result = ToolResult(content=[TextContent(type="text", text=modified_result if isinstance(modified_result, str) else orjson.dumps(modified_result).decode())])
4006 except Exception:
4007 tool_result = ToolResult(content=[TextContent(type="text", text=str(modified_result))])
4009 return tool_result
4010 except (PluginError, PluginViolationError):
4011 raise
4012 except ToolTimeoutError as e:
4013 # ToolTimeoutError is raised by timeout handlers which already called tool_post_invoke
4014 # Re-raise without calling post_invoke again to avoid double-counting failures
4015 # But DO set error_message and span attributes for observability
4016 error_message = str(e)
4017 if span:
4018 span.set_attribute("error", True)
4019 span.set_attribute("error.message", error_message)
4020 raise
4021 except BaseException as e:
4022 # Extract root cause from ExceptionGroup (Python 3.11+)
4023 # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup
4024 root_cause = e
4025 if isinstance(e, BaseExceptionGroup):
4026 while isinstance(root_cause, BaseExceptionGroup) and root_cause.exceptions:
4027 root_cause = root_cause.exceptions[0]
4028 error_message = str(root_cause)
4029 # Set span error status
4030 if span:
4031 span.set_attribute("error", True)
4032 span.set_attribute("error.message", error_message)
4034 # Notify plugins of the failure so circuit breaker can track it
4035 # This ensures HTTP 4xx/5xx errors and MCP failures are counted
4036 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE):
4037 try:
4038 exception_error_result = ToolResult(content=[TextContent(type="text", text=f"Tool invocation failed: {error_message}")], is_error=True)
4039 await self._plugin_manager.invoke_hook(
4040 ToolHookType.TOOL_POST_INVOKE,
4041 payload=ToolPostInvokePayload(name=name, result=exception_error_result.model_dump(by_alias=True)),
4042 global_context=global_context,
4043 local_contexts=context_table,
4044 violations_as_exceptions=False, # Don't let plugin errors mask the original exception
4045 )
4046 except Exception as plugin_exc:
4047 logger.debug("Failed to invoke post-invoke plugins on exception: %s", plugin_exc)
4049 raise ToolInvocationError(f"Tool invocation failed: {error_message}")
4050 finally:
4051 # Calculate duration
4052 duration_ms = (time.monotonic() - start_time) * 1000
4054 # End database span for observability_spans table
4055 # Use commit=False since fresh_db_session() handles commits on exit
4056 if db_span_id and observability_service and not db_span_ended:
4057 try:
4058 with fresh_db_session() as span_db:
4059 observability_service.end_span(
4060 db=span_db,
4061 span_id=db_span_id,
4062 status="ok" if success else "error",
4063 status_message=error_message if error_message else None,
4064 attributes={
4065 "success": success,
4066 "duration_ms": duration_ms,
4067 },
4068 commit=False,
4069 )
4070 db_span_ended = True
4071 logger.debug(f"✓ Ended tool.invoke span: {db_span_id}")
4072 except Exception as e:
4073 logger.warning(f"Failed to end observability span for tool invocation: {e}")
4075 # Add final span attributes for OpenTelemetry
4076 if span:
4077 span.set_attribute("success", success)
4078 span.set_attribute("duration.ms", duration_ms)
4080 # ═══════════════════════════════════════════════════════════════════════════
4081 # PHASE 4: Record metrics via buffered service (batches writes for performance)
4082 # ═══════════════════════════════════════════════════════════════════════════
4083 # Only record metrics if tool_id is valid (skip for direct_proxy mode)
4084 if tool_id:
4085 try:
4086 metrics_buffer.record_tool_metric(
4087 tool_id=tool_id,
4088 start_time=start_time,
4089 success=success,
4090 error_message=error_message,
4091 )
4092 except Exception as metric_error:
4093 logger.warning(f"Failed to record tool metric: {metric_error}")
4095 # Record server metrics ONLY when invoked through a specific virtual server
4096 # When server_id is provided, it means the tool was called via a virtual server endpoint
4097 # Direct tool calls via /rpc should NOT populate server metrics
4098 if tool_id and server_id:
4099 try:
4100 # Record server metric only for the specific virtual server being accessed
4101 metrics_buffer.record_server_metric(
4102 server_id=server_id,
4103 start_time=start_time,
4104 success=success,
4105 error_message=error_message,
4106 )
4107 except Exception as metric_error:
4108 logger.warning(f"Failed to record server metric: {metric_error}")
4110 # Log structured message with performance tracking (using local variables)
4111 if success:
4112 structured_logger.info(
4113 f"Tool '{name}' invoked successfully",
4114 user_id=app_user_email,
4115 resource_type="tool",
4116 resource_id=tool_id,
4117 resource_action="invoke",
4118 duration_ms=duration_ms,
4119 custom_fields={"tool_name": name, "integration_type": tool_integration_type, "arguments_count": len(arguments) if arguments else 0},
4120 )
4121 else:
4122 structured_logger.error(
4123 f"Tool '{name}' invocation failed",
4124 error=Exception(error_message) if error_message else None,
4125 user_id=app_user_email,
4126 resource_type="tool",
4127 resource_id=tool_id,
4128 resource_action="invoke",
4129 duration_ms=duration_ms,
4130 custom_fields={"tool_name": name, "integration_type": tool_integration_type, "error_message": error_message},
4131 )
4133 # Track performance with threshold checking
4134 with perf_tracker.track_operation("tool_invocation", name):
4135 pass # Duration already captured above
4137 @staticmethod
4138 def _check_tool_name_conflict(db: Session, custom_name: str, visibility: str, tool_id: str, team_id: Optional[str] = None, owner_email: Optional[str] = None) -> None:
4139 """Raise ToolNameConflictError if another tool with the same name exists in the target visibility scope.
4141 Args:
4142 db: The SQLAlchemy database session.
4143 custom_name: The custom name to check for conflicts.
4144 visibility: The target visibility scope (``public``, ``team``, or ``private``).
4145 tool_id: The ID of the tool being updated (excluded from the conflict search).
4146 team_id: Required when *visibility* is ``team``; scopes the uniqueness check to this team.
4147 owner_email: Required when *visibility* is ``private``; scopes the uniqueness check to this owner.
4149 Raises:
4150 ToolNameConflictError: If a conflicting tool already exists in the target scope.
4151 """
4152 if visibility == "public":
4153 existing_tool = get_for_update(
4154 db,
4155 DbTool,
4156 where=and_(
4157 DbTool.custom_name == custom_name,
4158 DbTool.visibility == "public",
4159 DbTool.id != tool_id,
4160 ),
4161 )
4162 elif visibility == "team" and team_id:
4163 existing_tool = get_for_update(
4164 db,
4165 DbTool,
4166 where=and_(
4167 DbTool.custom_name == custom_name,
4168 DbTool.visibility == "team",
4169 DbTool.team_id == team_id,
4170 DbTool.id != tool_id,
4171 ),
4172 )
4173 elif visibility == "private" and owner_email:
4174 existing_tool = get_for_update(
4175 db,
4176 DbTool,
4177 where=and_(
4178 DbTool.custom_name == custom_name,
4179 DbTool.visibility == "private",
4180 DbTool.owner_email == owner_email,
4181 DbTool.id != tool_id,
4182 ),
4183 )
4184 else:
4185 logger.warning("Skipping conflict check for tool %s: visibility=%r requires %s but none provided", tool_id, visibility, "team_id" if visibility == "team" else "owner_email")
4186 return
4187 if existing_tool:
4188 raise ToolNameConflictError(existing_tool.custom_name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility)
    async def update_tool(
        self,
        db: Session,
        tool_id: str,
        tool_update: ToolUpdate,
        modified_by: Optional[str] = None,
        modified_from_ip: Optional[str] = None,
        modified_via: Optional[str] = None,
        modified_user_agent: Optional[str] = None,
        user_email: Optional[str] = None,
    ) -> ToolRead:
        """
        Update an existing tool.

        Args:
            db (Session): The SQLAlchemy database session.
            tool_id (str): The unique identifier of the tool.
            tool_update (ToolUpdate): Tool update schema with new data.
            modified_by (Optional[str]): Username who modified this tool.
            modified_from_ip (Optional[str]): IP address of modifier.
            modified_via (Optional[str]): Modification method (ui, api).
            modified_user_agent (Optional[str]): User agent of modification request.
            user_email (Optional[str]): Email of user performing update (for ownership check).

        Returns:
            The updated ToolRead object.

        Raises:
            ToolNotFoundError: If the tool is not found.
            PermissionError: If user doesn't own the tool.
            IntegrityError: If there is a database integrity error.
            ToolNameConflictError: If a tool with the same name already exists.
            ToolError: For other update errors.

        Examples:
            >>> from mcpgateway.services.tool_service import ToolService
            >>> from unittest.mock import MagicMock, AsyncMock
            >>> from mcpgateway.schemas import ToolRead
            >>> service = ToolService()
            >>> db = MagicMock()
            >>> tool = MagicMock()
            >>> db.get.return_value = tool
            >>> db.commit = MagicMock()
            >>> db.refresh = MagicMock()
            >>> db.execute.return_value.scalar_one_or_none.return_value = None
            >>> service._notify_tool_updated = AsyncMock()
            >>> service.convert_tool_to_read = MagicMock(return_value='tool_read')
            >>> ToolRead.model_validate = MagicMock(return_value='tool_read')
            >>> import asyncio
            >>> asyncio.run(service.update_tool(db, 'tool_id', MagicMock()))
            'tool_read'
        """
        try:
            # Row-level locked fetch so concurrent updates of the same tool serialize.
            tool = get_for_update(db, DbTool, tool_id)

            if not tool:
                raise ToolNotFoundError(f"Tool not found: {tool_id}")

            # Capture the pre-update identity; needed later to invalidate the
            # lookup-cache entry keyed on the OLD name/gateway after commit.
            old_tool_name = tool.name
            old_gateway_id = tool.gateway_id

            # Check ownership if user_email provided
            if user_email:
                # First-Party
                from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

                permission_service = PermissionService(db)
                if not await permission_service.check_resource_ownership(user_email, tool):
                    raise PermissionError("Only the owner can update this tool")

            # Track whether a name change occurred (before tool.name is mutated)
            name_is_changing = bool(tool_update.name and tool_update.name != tool.name)

            # Check for name change and ensure uniqueness
            if name_is_changing:
                # Always derive ownership fields from the DB record — never trust client-provided team_id/owner_email
                tool_visibility_ref = tool.visibility if tool_update.visibility is None else tool_update.visibility.lower()
                if tool_update.custom_name is not None:
                    custom_name_ref = tool_update.custom_name
                elif tool.name == tool.custom_name:
                    custom_name_ref = tool_update.name  # custom_name will track the rename
                else:
                    custom_name_ref = tool.custom_name  # custom_name stays unchanged
                self._check_tool_name_conflict(db, custom_name_ref, tool_visibility_ref, tool.id, team_id=tool.team_id, owner_email=tool.owner_email)
                # When custom_name mirrored the old name, keep them in sync through the rename.
                if tool_update.custom_name is None and tool.name == tool.custom_name:
                    tool.custom_name = tool_update.name
                tool.name = tool_update.name

            # Check for conflicts when visibility changes without a name change
            if tool_update.visibility is not None and tool_update.visibility.lower() != tool.visibility and not name_is_changing:
                new_visibility = tool_update.visibility.lower()
                self._check_tool_name_conflict(db, tool.custom_name, new_visibility, tool.id, team_id=tool.team_id, owner_email=tool.owner_email)

            # Partial update: only fields explicitly provided (non-None) are written.
            if tool_update.custom_name is not None:
                tool.custom_name = tool_update.custom_name
            if tool_update.displayName is not None:
                tool.display_name = tool_update.displayName
            if tool_update.url is not None:
                tool.url = str(tool_update.url)
            if tool_update.description is not None:
                tool.description = tool_update.description
            if tool_update.integration_type is not None:
                tool.integration_type = tool_update.integration_type
            if tool_update.request_type is not None:
                tool.request_type = tool_update.request_type
            if tool_update.headers is not None:
                # Headers pass through protection (masking/merge with existing) before storage.
                tool.headers = _protect_tool_headers_for_storage(tool_update.headers, existing_headers=tool.headers)
            if tool_update.input_schema is not None:
                tool.input_schema = tool_update.input_schema
            if tool_update.output_schema is not None:
                tool.output_schema = tool_update.output_schema
            if tool_update.annotations is not None:
                tool.annotations = tool_update.annotations
            if tool_update.jsonpath_filter is not None:
                tool.jsonpath_filter = tool_update.jsonpath_filter
            if tool_update.visibility is not None:
                # NOTE(review): stored as provided, while the conflict checks above use
                # .lower() — confirm visibility is normalized upstream (schema validator).
                tool.visibility = tool_update.visibility

            if tool_update.auth is not None:
                if tool_update.auth.auth_type is not None:
                    tool.auth_type = tool_update.auth.auth_type
                if tool_update.auth.auth_value is not None:
                    tool.auth_value = tool_update.auth.auth_value

            # Update tags if provided
            if tool_update.tags is not None:
                tool.tags = tool_update.tags

            # Update modification metadata
            if modified_by is not None:
                tool.modified_by = modified_by
            if modified_from_ip is not None:
                tool.modified_from_ip = modified_from_ip
            if modified_via is not None:
                tool.modified_via = modified_via
            if modified_user_agent is not None:
                tool.modified_user_agent = modified_user_agent

            # Increment version
            if hasattr(tool, "version") and tool.version is not None:
                tool.version += 1
            else:
                tool.version = 1
            logger.info(f"Update tool: {tool.name} (output_schema: {tool.output_schema})")

            tool.updated_at = datetime.now(timezone.utc)
            db.commit()
            db.refresh(tool)
            # Broadcast the change to event subscribers only after a successful commit.
            await self._notify_tool_updated(tool)
            logger.info(f"Updated tool: {tool.name}")

            # Structured logging: Audit trail for tool update
            changes = []
            if tool_update.name:
                changes.append(f"name: {tool_update.name}")
            if tool_update.visibility:
                changes.append(f"visibility: {tool_update.visibility}")
            if tool_update.description:
                changes.append("description updated")

            audit_trail.log_action(
                user_id=user_email or modified_by or "system",
                action="update_tool",
                resource_type="tool",
                resource_id=tool.id,
                resource_name=tool.name,
                user_email=user_email,
                team_id=tool.team_id,
                client_ip=modified_from_ip,
                user_agent=modified_user_agent,
                new_values={
                    "name": tool.name,
                    "display_name": tool.display_name,
                    "version": tool.version,
                },
                context={
                    "modified_via": modified_via,
                    "changes": ", ".join(changes) if changes else "metadata only",
                },
                db=db,
            )

            # Structured logging: Log successful tool update
            structured_logger.log(
                level="INFO",
                message="Tool updated successfully",
                event_type="tool_updated",
                component="tool_service",
                user_id=modified_by,
                user_email=user_email,
                team_id=tool.team_id,
                resource_type="tool",
                resource_id=tool.id,
                custom_fields={
                    "tool_name": tool.name,
                    "version": tool.version,
                },
            )

            # Invalidate cache after successful update
            cache = _get_registry_cache()
            await cache.invalidate_tools()
            tool_lookup_cache = _get_tool_lookup_cache()
            # Invalidate under BOTH the old and new name in case of a rename.
            await tool_lookup_cache.invalidate(old_tool_name, gateway_id=str(old_gateway_id) if old_gateway_id else None)
            await tool_lookup_cache.invalidate(tool.name, gateway_id=str(tool.gateway_id) if tool.gateway_id else None)
            # Also invalidate tags cache since tool tags may have changed
            # First-Party
            from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

            await admin_stats_cache.invalidate_tags()

            return self.convert_tool_to_read(tool, requesting_user_email=getattr(tool, "owner_email", None))
        except PermissionError as pe:
            db.rollback()

            # Structured logging: Log permission error
            structured_logger.log(
                level="WARNING",
                message="Tool update failed due to permission error",
                event_type="tool_update_permission_denied",
                component="tool_service",
                user_email=user_email,
                resource_type="tool",
                resource_id=tool_id,
                error=pe,
            )
            raise
        except IntegrityError as ie:
            db.rollback()
            logger.error(f"IntegrityError during tool update: {ie}")

            # Structured logging: Log database integrity error
            structured_logger.log(
                level="ERROR",
                message="Tool update failed due to database integrity error",
                event_type="tool_update_failed",
                component="tool_service",
                user_id=modified_by,
                user_email=user_email,
                resource_type="tool",
                resource_id=tool_id,
                error=ie,
            )
            raise ie
        except ToolNotFoundError as tnfe:
            db.rollback()
            logger.error(f"Tool not found during update: {tnfe}")

            # Structured logging: Log not found error
            structured_logger.log(
                level="ERROR",
                message="Tool update failed - tool not found",
                event_type="tool_not_found",
                component="tool_service",
                user_email=user_email,
                resource_type="tool",
                resource_id=tool_id,
                error=tnfe,
            )
            raise tnfe
        except ToolNameConflictError as tnce:
            db.rollback()
            logger.error(f"Tool name conflict during update: {tnce}")

            # Structured logging: Log name conflict error
            structured_logger.log(
                level="WARNING",
                message="Tool update failed due to name conflict",
                event_type="tool_name_conflict",
                component="tool_service",
                user_id=modified_by,
                user_email=user_email,
                resource_type="tool",
                resource_id=tool_id,
                error=tnce,
            )
            raise tnce
        except Exception as ex:
            db.rollback()

            # Structured logging: Log generic tool update failure
            structured_logger.log(
                level="ERROR",
                message="Tool update failed",
                event_type="tool_update_failed",
                component="tool_service",
                user_id=modified_by,
                user_email=user_email,
                resource_type="tool",
                resource_id=tool_id,
                error=ex,
            )
            raise ToolError(f"Failed to update tool: {str(ex)}")
4484 async def _notify_tool_updated(self, tool: DbTool) -> None:
4485 """
4486 Notify subscribers of tool update.
4488 Args:
4489 tool: Tool updated
4490 """
4491 event = {
4492 "type": "tool_updated",
4493 "data": {"id": tool.id, "name": tool.name, "url": tool.url, "description": tool.description, "enabled": tool.enabled},
4494 "timestamp": datetime.now(timezone.utc).isoformat(),
4495 }
4496 await self._publish_event(event)
4498 async def _notify_tool_activated(self, tool: DbTool) -> None:
4499 """
4500 Notify subscribers of tool activation.
4502 Args:
4503 tool: Tool activated
4504 """
4505 event = {
4506 "type": "tool_activated",
4507 "data": {"id": tool.id, "name": tool.name, "enabled": tool.enabled, "reachable": tool.reachable},
4508 "timestamp": datetime.now(timezone.utc).isoformat(),
4509 }
4510 await self._publish_event(event)
4512 async def _notify_tool_deactivated(self, tool: DbTool) -> None:
4513 """
4514 Notify subscribers of tool deactivation.
4516 Args:
4517 tool: Tool deactivated
4518 """
4519 event = {
4520 "type": "tool_deactivated",
4521 "data": {"id": tool.id, "name": tool.name, "enabled": tool.enabled, "reachable": tool.reachable},
4522 "timestamp": datetime.now(timezone.utc).isoformat(),
4523 }
4524 await self._publish_event(event)
4526 async def _notify_tool_offline(self, tool: DbTool) -> None:
4527 """
4528 Notify subscribers that tool is offline.
4530 Args:
4531 tool: Tool database object
4532 """
4533 event = {
4534 "type": "tool_offline",
4535 "data": {
4536 "id": tool.id,
4537 "name": tool.name,
4538 "enabled": True,
4539 "reachable": False,
4540 },
4541 "timestamp": datetime.now(timezone.utc).isoformat(),
4542 }
4543 await self._publish_event(event)
4545 async def _notify_tool_deleted(self, tool_info: Dict[str, Any]) -> None:
4546 """
4547 Notify subscribers of tool deletion.
4549 Args:
4550 tool_info: Dictionary on tool deleted
4551 """
4552 event = {
4553 "type": "tool_deleted",
4554 "data": tool_info,
4555 "timestamp": datetime.now(timezone.utc).isoformat(),
4556 }
4557 await self._publish_event(event)
4559 async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
4560 """Subscribe to tool events via the EventService.
4562 Yields:
4563 Tool event messages.
4564 """
4565 async for event in self._event_service.subscribe_events():
4566 yield event
4568 async def _notify_tool_added(self, tool: DbTool) -> None:
4569 """
4570 Notify subscribers of tool addition.
4572 Args:
4573 tool: Tool added
4574 """
4575 event = {
4576 "type": "tool_added",
4577 "data": {
4578 "id": tool.id,
4579 "name": tool.name,
4580 "url": tool.url,
4581 "description": tool.description,
4582 "enabled": tool.enabled,
4583 },
4584 "timestamp": datetime.now(timezone.utc).isoformat(),
4585 }
4586 await self._publish_event(event)
4588 async def _notify_tool_removed(self, tool: DbTool) -> None:
4589 """
4590 Notify subscribers of tool removal (soft delete/deactivation).
4592 Args:
4593 tool: Tool removed
4594 """
4595 event = {
4596 "type": "tool_removed",
4597 "data": {"id": tool.id, "name": tool.name, "enabled": tool.enabled},
4598 "timestamp": datetime.now(timezone.utc).isoformat(),
4599 }
4600 await self._publish_event(event)
4602 async def _publish_event(self, event: Dict[str, Any]) -> None:
4603 """
4604 Publish event to all subscribers via the EventService.
4606 Args:
4607 event: Event to publish
4608 """
4609 await self._event_service.publish_event(event)
4611 async def _validate_tool_url(self, url: str) -> None:
4612 """Validate tool URL is accessible.
4614 Args:
4615 url: URL to validate.
4617 Raises:
4618 ToolValidationError: If URL validation fails.
4619 """
4620 try:
4621 response = await self._http_client.get(url)
4622 response.raise_for_status()
4623 except Exception as e:
4624 raise ToolValidationError(f"Failed to validate tool URL: {str(e)}")
4626 async def _check_tool_health(self, tool: DbTool) -> bool:
4627 """Check if tool endpoint is healthy.
4629 Args:
4630 tool: Tool to check.
4632 Returns:
4633 True if tool is healthy.
4634 """
4635 try:
4636 response = await self._http_client.get(tool.url)
4637 return response.is_success
4638 except Exception:
4639 return False
4641 # async def event_generator(self) -> AsyncGenerator[Dict[str, Any], None]:
4642 # """Generate tool events for SSE.
4644 # Yields:
4645 # Tool events.
4646 # """
4647 # queue: asyncio.Queue = asyncio.Queue()
4648 # self._event_subscribers.append(queue)
4649 # try:
4650 # while True:
4651 # event = await queue.get()
4652 # yield event
4653 # finally:
4654 # self._event_subscribers.remove(queue)
4656 # --- Metrics ---
4657 async def aggregate_metrics(self, db: Session) -> ToolMetrics:
4658 """
4659 Aggregate metrics for all tool invocations across all tools.
4661 Combines recent raw metrics (within retention period) with historical
4662 hourly rollups for complete historical coverage. Uses in-memory caching
4663 (10s TTL) to reduce database load under high request rates.
4665 Args:
4666 db: Database session
4668 Returns:
4669 ToolMetrics: Aggregated metrics computed from raw ToolMetric + ToolMetricsHourly.
4671 Examples:
4672 >>> from mcpgateway.services.tool_service import ToolService
4673 >>> service = ToolService()
4674 >>> # Method exists and is callable
4675 >>> callable(service.aggregate_metrics)
4676 True
4677 """
4678 # Check cache first (if enabled)
4679 # First-Party
4680 from mcpgateway.cache.metrics_cache import is_cache_enabled, metrics_cache # pylint: disable=import-outside-toplevel
4682 if is_cache_enabled():
4683 cached = metrics_cache.get("tools")
4684 if cached is not None:
4685 return ToolMetrics(**cached)
4687 # Use combined raw + rollup query for full historical coverage
4688 # First-Party
4689 from mcpgateway.services.metrics_query_service import aggregate_metrics_combined # pylint: disable=import-outside-toplevel
4691 result = aggregate_metrics_combined(db, "tool")
4692 metrics = ToolMetrics(
4693 total_executions=result.total_executions,
4694 successful_executions=result.successful_executions,
4695 failed_executions=result.failed_executions,
4696 failure_rate=result.failure_rate,
4697 min_response_time=result.min_response_time,
4698 max_response_time=result.max_response_time,
4699 avg_response_time=result.avg_response_time,
4700 last_execution_time=result.last_execution_time,
4701 )
4703 # Cache the result as dict for serialization compatibility (if enabled)
4704 if is_cache_enabled():
4705 metrics_cache.set("tools", metrics.model_dump())
4707 return metrics
4709 async def reset_metrics(self, db: Session, tool_id: Optional[int] = None) -> None:
4710 """
4711 Reset all tool metrics by deleting raw and hourly rollup records.
4713 Args:
4714 db: Database session
4715 tool_id: Optional tool ID to reset metrics for a specific tool
4717 Examples:
4718 >>> from mcpgateway.services.tool_service import ToolService
4719 >>> from unittest.mock import MagicMock
4720 >>> service = ToolService()
4721 >>> db = MagicMock()
4722 >>> db.execute = MagicMock()
4723 >>> db.commit = MagicMock()
4724 >>> import asyncio
4725 >>> asyncio.run(service.reset_metrics(db))
4726 """
4728 if tool_id:
4729 db.execute(delete(ToolMetric).where(ToolMetric.tool_id == tool_id))
4730 db.execute(delete(ToolMetricsHourly).where(ToolMetricsHourly.tool_id == tool_id))
4731 else:
4732 db.execute(delete(ToolMetric))
4733 db.execute(delete(ToolMetricsHourly))
4734 db.commit()
4736 # Invalidate metrics cache
4737 # First-Party
4738 from mcpgateway.cache.metrics_cache import metrics_cache # pylint: disable=import-outside-toplevel
4740 metrics_cache.invalidate("tools")
4741 metrics_cache.invalidate_prefix("top_tools:")
4743 async def create_tool_from_a2a_agent(
4744 self,
4745 db: Session,
4746 agent: DbA2AAgent,
4747 created_by: Optional[str] = None,
4748 created_from_ip: Optional[str] = None,
4749 created_via: Optional[str] = None,
4750 created_user_agent: Optional[str] = None,
4751 ) -> DbTool:
4752 """Create a tool entry from an A2A agent for virtual server integration.
4754 Args:
4755 db: Database session.
4756 agent: A2A agent to create tool from.
4757 created_by: Username who created this tool.
4758 created_from_ip: IP address of creator.
4759 created_via: Creation method.
4760 created_user_agent: User agent of creation request.
4762 Returns:
4763 The created tool database object.
4765 Raises:
4766 ToolNameConflictError: If a tool with the same name already exists.
4767 """
4768 # Check if tool already exists for this agent
4769 tool_name = f"a2a_{agent.slug}"
4770 existing_query = select(DbTool).where(DbTool.original_name == tool_name)
4771 existing_tool = db.execute(existing_query).scalar_one_or_none()
4773 if existing_tool:
4774 # Tool already exists, return it
4775 return existing_tool
4777 # Create tool entry for the A2A agent
4778 logger.debug(f"agent.tags: {agent.tags} for agent: {agent.name} (ID: {agent.id})")
4780 # Normalize tags: if agent.tags contains dicts like {'id':..,'label':..},
4781 # extract the human-friendly label. If tags are already strings, keep them.
4782 normalized_tags: list[str] = []
4783 for t in agent.tags or []:
4784 if isinstance(t, dict):
4785 # Prefer 'label', fall back to 'id' or stringified dict
4786 normalized_tags.append(t.get("label") or t.get("id") or str(t))
4787 elif hasattr(t, "label"):
4788 normalized_tags.append(getattr(t, "label"))
4789 else:
4790 normalized_tags.append(str(t))
4792 # Ensure we include identifying A2A tags
4793 normalized_tags = normalized_tags + ["a2a", "agent"]
4795 tool_data = ToolCreate(
4796 name=tool_name,
4797 displayName=generate_display_name(agent.name),
4798 url=agent.endpoint_url,
4799 description=f"A2A Agent: {agent.description or agent.name}",
4800 integration_type="A2A", # Special integration type for A2A agents
4801 request_type="POST",
4802 input_schema={
4803 "type": "object",
4804 "properties": {
4805 "query": {"type": "string", "description": "User query", "default": "Hello from ContextForge Admin UI test!"},
4806 },
4807 "required": ["query"],
4808 },
4809 allow_auto=True,
4810 annotations={
4811 "title": f"A2A Agent: {agent.name}",
4812 "a2a_agent_id": agent.id,
4813 "a2a_agent_type": agent.agent_type,
4814 },
4815 auth_type=agent.auth_type,
4816 auth_value=agent.auth_value,
4817 tags=normalized_tags,
4818 )
4820 # Default to "public" visibility if agent visibility is not set
4821 # This ensures A2A tools are visible in the Global Tools Tab
4822 tool_visibility = agent.visibility or "public"
4824 tool_read = await self.register_tool(
4825 db,
4826 tool_data,
4827 created_by=created_by,
4828 created_from_ip=created_from_ip,
4829 created_via=created_via or "a2a_integration",
4830 created_user_agent=created_user_agent,
4831 team_id=agent.team_id,
4832 owner_email=agent.owner_email,
4833 visibility=tool_visibility,
4834 )
4836 # Return the DbTool object for relationship assignment
4837 tool_db = db.get(DbTool, tool_read.id)
4838 return tool_db
4840 async def update_tool_from_a2a_agent(
4841 self,
4842 db: Session,
4843 agent: DbA2AAgent,
4844 modified_by: Optional[str] = None,
4845 modified_from_ip: Optional[str] = None,
4846 modified_via: Optional[str] = None,
4847 modified_user_agent: Optional[str] = None,
4848 ) -> Optional[ToolRead]:
4849 """Update the tool associated with an A2A agent when the agent is updated.
4851 Args:
4852 db: Database session.
4853 agent: Updated A2A agent.
4854 modified_by: Username who modified this tool.
4855 modified_from_ip: IP address of modifier.
4856 modified_via: Modification method.
4857 modified_user_agent: User agent of modification request.
4859 Returns:
4860 The updated tool, or None if no associated tool exists.
4861 """
4862 # Use the tool_id from the agent for efficient lookup
4863 if not agent.tool_id:
4864 logger.debug(f"No tool_id found for A2A agent {agent.id}, skipping tool update")
4865 return None
4867 tool = db.get(DbTool, agent.tool_id)
4868 if not tool:
4869 logger.warning(f"Tool {agent.tool_id} not found for A2A agent {agent.id}, resetting tool_id")
4870 agent.tool_id = None
4871 db.commit()
4872 return None
4874 # Normalize tags: if agent.tags contains dicts like {'id':..,'label':..},
4875 # extract the human-friendly label. If tags are already strings, keep them.
4876 normalized_tags: list[str] = []
4877 for t in agent.tags or []:
4878 if isinstance(t, dict):
4879 # Prefer 'label', fall back to 'id' or stringified dict
4880 normalized_tags.append(t.get("label") or t.get("id") or str(t))
4881 elif hasattr(t, "label"):
4882 normalized_tags.append(getattr(t, "label"))
4883 else:
4884 normalized_tags.append(str(t))
4886 # Ensure we include identifying A2A tags
4887 normalized_tags = normalized_tags + ["a2a", "agent"]
4889 # Prepare update data matching the agent's current state
4890 # IMPORTANT: Preserve the existing tool's visibility to avoid unintentionally
4891 # making private/team tools public (ToolUpdate defaults to "public")
4892 # Note: team_id is not a field on ToolUpdate schema, so team assignment is preserved
4893 # implicitly by not changing visibility (team tools stay team-scoped)
4894 new_tool_name = f"a2a_{agent.slug}"
4895 tool_update = ToolUpdate(
4896 name=new_tool_name,
4897 custom_name=new_tool_name, # Also set custom_name to ensure name update works
4898 displayName=generate_display_name(agent.name),
4899 url=agent.endpoint_url,
4900 description=f"A2A Agent: {agent.description or agent.name}",
4901 auth=AuthenticationValues(auth_type=agent.auth_type, auth_value=agent.auth_value) if agent.auth_type else None,
4902 tags=normalized_tags,
4903 visibility=tool.visibility, # Preserve existing visibility
4904 )
4906 # Update the tool
4907 return await self.update_tool(
4908 db=db,
4909 tool_id=tool.id,
4910 tool_update=tool_update,
4911 modified_by=modified_by,
4912 modified_from_ip=modified_from_ip,
4913 modified_via=modified_via or "a2a_sync",
4914 modified_user_agent=modified_user_agent,
4915 )
4917 async def delete_tool_from_a2a_agent(self, db: Session, agent: DbA2AAgent, user_email: Optional[str] = None, purge_metrics: bool = False) -> None:
4918 """Delete the tool associated with an A2A agent when the agent is deleted.
4920 Args:
4921 db: Database session.
4922 agent: The A2A agent being deleted.
4923 user_email: Email of user performing delete (for ownership check).
4924 purge_metrics: If True, delete raw + rollup metrics for this tool.
4925 """
4926 # Use the tool_id from the agent for efficient lookup
4927 if not agent.tool_id:
4928 logger.debug(f"No tool_id found for A2A agent {agent.id}, skipping tool deletion")
4929 return
4931 tool = db.get(DbTool, agent.tool_id)
4932 if not tool:
4933 logger.warning(f"Tool {agent.tool_id} not found for A2A agent {agent.id}")
4934 return
4936 # Delete the tool
4937 await self.delete_tool(db=db, tool_id=tool.id, user_email=user_email, purge_metrics=purge_metrics)
4938 logger.info(f"Deleted tool {tool.id} associated with A2A agent {agent.id}")
4940 async def _invoke_a2a_tool(self, db: Session, tool: DbTool, arguments: Dict[str, Any]) -> ToolResult:
4941 """Invoke an A2A agent through its corresponding tool.
4943 Args:
4944 db: Database session.
4945 tool: The tool record that represents the A2A agent.
4946 arguments: Tool arguments.
4948 Returns:
4949 Tool result from A2A agent invocation.
4951 Raises:
4952 ToolNotFoundError: If the A2A agent is not found.
4953 """
4955 # Extract A2A agent ID from tool annotations
4956 agent_id = tool.annotations.get("a2a_agent_id")
4957 if not agent_id:
4958 raise ToolNotFoundError(f"A2A tool '{tool.name}' missing agent ID in annotations")
4960 # Get the A2A agent
4961 agent_query = select(DbA2AAgent).where(DbA2AAgent.id == agent_id)
4962 agent = db.execute(agent_query).scalar_one_or_none()
4964 if not agent:
4965 raise ToolNotFoundError(f"A2A agent not found for tool '{tool.name}' (agent ID: {agent_id})")
4967 if not agent.enabled:
4968 raise ToolNotFoundError(f"A2A agent '{agent.name}' is disabled")
4970 # Force-load all attributes needed by _call_a2a_agent before detaching
4971 # (accessing them ensures they're loaded into the object's __dict__)
4972 _ = (agent.name, agent.endpoint_url, agent.agent_type, agent.protocol_version, agent.auth_type, agent.auth_value, agent.auth_query_params)
4974 # Detach agent from session so its loaded data remains accessible after close
4975 db.expunge(agent)
4977 # CRITICAL: Release DB connection back to pool BEFORE making HTTP calls
4978 # This prevents "idle in transaction" connection pool exhaustion under load
4979 db.commit()
4980 db.close()
4982 # Prepare parameters for A2A invocation
4983 try:
4984 # Make the A2A agent call (agent is now detached but data is loaded)
4985 response_data = await self._call_a2a_agent(agent, arguments)
4987 # Convert A2A response to MCP ToolResult format
4988 if isinstance(response_data, dict) and "response" in response_data:
4989 val = response_data["response"]
4990 content = [TextContent(type="text", text=val if isinstance(val, str) else orjson.dumps(val).decode())]
4991 else:
4992 content = [TextContent(type="text", text=response_data if isinstance(response_data, str) else orjson.dumps(response_data).decode())]
4994 result = ToolResult(content=content, is_error=False)
4996 except Exception as e:
4997 error_message = str(e)
4998 content = [TextContent(type="text", text=f"A2A agent error: {error_message}")]
4999 result = ToolResult(content=content, is_error=True)
5001 # Note: Metrics are recorded by the calling invoke_tool method, not here
5002 return result
5004 async def _call_a2a_agent(self, agent: DbA2AAgent, parameters: Dict[str, Any]):
5005 """Call an A2A agent directly.
5007 Args:
5008 agent: The A2A agent to call.
5009 parameters: Parameters for the interaction.
5011 Returns:
5012 Response from the A2A agent.
5014 Raises:
5015 Exception: If the call fails.
5016 """
5017 logger.info(f"Calling A2A agent '{agent.name}' at {agent.endpoint_url} with arguments: {parameters}")
5019 # Build request data based on agent type
5020 if agent.agent_type in ["generic", "jsonrpc"] or agent.endpoint_url.endswith("/"):
5021 # JSONRPC agents: Convert flat query to nested message structure
5022 params = None
5023 if isinstance(parameters, dict) and "query" in parameters and isinstance(parameters["query"], str):
5024 # Build the nested message object for JSONRPC protocol
5025 message_id = f"admin-test-{int(time.time())}"
5026 # A2A v0.3.x: message.parts use "kind" (not "type").
5027 params = {
5028 "message": {
5029 "kind": "message",
5030 "messageId": message_id,
5031 "role": "user",
5032 "parts": [{"kind": "text", "text": parameters["query"]}],
5033 }
5034 }
5035 method = parameters.get("method", "message/send")
5036 else:
5037 # Already in correct format or unknown, pass through
5038 params = parameters.get("params", parameters)
5039 method = parameters.get("method", "message/send")
5041 try:
5042 request_data = {"jsonrpc": "2.0", "method": method, "params": params, "id": 1}
5043 logger.info(f"invoke tool JSONRPC request_data prepared: {request_data}")
5044 except Exception as e:
5045 logger.error(f"Error preparing JSONRPC request data: {e}")
5046 raise
5047 else:
5048 # Custom agents: Pass parameters directly without JSONRPC message conversion
5049 # Custom agents expect flat fields like {"query": "...", "message": "..."}
5050 params = parameters if isinstance(parameters, dict) else {}
5051 logger.info(f"invoke tool Using custom A2A format for A2A agent '{params}'")
5052 request_data = {"interaction_type": params.get("interaction_type", "query"), "parameters": params, "protocol_version": agent.protocol_version}
5053 logger.info(f"invoke tool request_data prepared: {request_data}")
5054 # Make HTTP request to the agent endpoint using shared HTTP client
5055 # First-Party
5056 from mcpgateway.services.http_client_service import get_http_client # pylint: disable=import-outside-toplevel
5058 client = await get_http_client()
5059 headers = {"Content-Type": "application/json"}
5061 # Determine the endpoint URL (may be modified for query_param auth)
5062 endpoint_url = agent.endpoint_url
5064 # Add authentication if configured
5065 if agent.auth_type == "api_key" and agent.auth_value:
5066 headers["Authorization"] = f"Bearer {agent.auth_value}"
5067 elif agent.auth_type == "bearer" and agent.auth_value:
5068 headers["Authorization"] = f"Bearer {agent.auth_value}"
5069 elif agent.auth_type == "query_param" and agent.auth_query_params:
5070 # Handle query parameter authentication (imports at top: decode_auth, apply_query_param_auth, sanitize_url_for_logging)
5071 auth_query_params_decrypted: dict[str, str] = {}
5072 for param_key, encrypted_value in agent.auth_query_params.items():
5073 if encrypted_value:
5074 try:
5075 decrypted = decode_auth(encrypted_value)
5076 auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
5077 except Exception:
5078 logger.debug(f"Failed to decrypt query param for key '{param_key}'")
5079 if auth_query_params_decrypted:
5080 endpoint_url = apply_query_param_auth(endpoint_url, auth_query_params_decrypted)
5081 # Log sanitized URL to avoid credential leakage
5082 sanitized_url = sanitize_url_for_logging(endpoint_url, auth_query_params_decrypted)
5083 logger.debug(f"Applied query param auth to A2A agent endpoint: {sanitized_url}")
5085 http_response = await client.post(endpoint_url, json=request_data, headers=headers)
5087 if http_response.status_code == 200:
5088 return http_response.json()
5090 raise Exception(f"HTTP {http_response.status_code}: {http_response.text}")
# Lazy singleton - created on first access, not at module import time.
# This avoids instantiation when only exception classes are imported.
_tool_service_instance = None  # pylint: disable=invalid-name


def __getattr__(name: str):
    """Module-level __getattr__ for lazy singleton creation.

    Args:
        name: The attribute name being accessed.

    Returns:
        The tool_service singleton instance if name is "tool_service".

    Raises:
        AttributeError: If the attribute name is not "tool_service".
    """
    global _tool_service_instance  # pylint: disable=global-statement
    # Guard clause: only "tool_service" is lazily resolved here.
    if name != "tool_service":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    if _tool_service_instance is None:
        _tool_service_instance = ToolService()
    return _tool_service_instance