Coverage for mcpgateway / services / tool_service.py: 99%
1611 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/tool_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Tool Service Implementation.
8This module implements tool management and invocation according to the MCP specification.
9It handles:
10- Tool registration and validation
11- Tool invocation with schema validation
12- Tool federation across gateways
13- Event notifications for tool changes
14- Active/inactive tool management
15"""
17# Standard
18import asyncio
19import base64
20import binascii
21from datetime import datetime, timezone
22from functools import lru_cache
23import os
24import re
25import ssl
26import time
27from types import SimpleNamespace
28from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
29from urllib.parse import parse_qs, urlparse
30import uuid
32# Third-Party
33import httpx
34import jq
35import jsonschema
36from jsonschema import Draft4Validator, Draft6Validator, Draft7Validator, validators
37from mcp import ClientSession
38from mcp.client.sse import sse_client
39from mcp.client.streamable_http import streamablehttp_client
40import orjson
41from pydantic import ValidationError
42from sqlalchemy import and_, delete, desc, or_, select
43from sqlalchemy.exc import IntegrityError, OperationalError
44from sqlalchemy.orm import joinedload, selectinload, Session
46# First-Party
47from mcpgateway.cache.global_config_cache import global_config_cache
48from mcpgateway.common.models import Gateway as PydanticGateway
49from mcpgateway.common.models import TextContent
50from mcpgateway.common.models import Tool as PydanticTool
51from mcpgateway.common.models import ToolResult
52from mcpgateway.config import settings
53from mcpgateway.db import A2AAgent as DbA2AAgent
54from mcpgateway.db import fresh_db_session
55from mcpgateway.db import Gateway as DbGateway
56from mcpgateway.db import get_for_update, server_tool_association
57from mcpgateway.db import Tool as DbTool
58from mcpgateway.db import ToolMetric, ToolMetricsHourly
59from mcpgateway.observability import create_span
60from mcpgateway.plugins.framework import (
61 GlobalContext,
62 HttpHeaderPayload,
63 PluginContextTable,
64 PluginError,
65 PluginManager,
66 PluginViolationError,
67 ToolHookType,
68 ToolPostInvokePayload,
69 ToolPreInvokePayload,
70)
71from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA
72from mcpgateway.schemas import AuthenticationValues, ToolCreate, ToolRead, ToolUpdate, TopPerformer
73from mcpgateway.services.audit_trail_service import get_audit_trail_service
74from mcpgateway.services.event_service import EventService
75from mcpgateway.services.logging_service import LoggingService
76from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, TransportType
77from mcpgateway.services.metrics_cleanup_service import delete_metrics_in_batches, pause_rollup_during_purge
78from mcpgateway.services.metrics_query_service import get_top_performers_combined
79from mcpgateway.services.oauth_manager import OAuthManager
80from mcpgateway.services.observability_service import current_trace_id, ObservabilityService
81from mcpgateway.services.performance_tracker import get_performance_tracker
82from mcpgateway.services.structured_logger import get_structured_logger
83from mcpgateway.services.team_management_service import TeamManagementService
84from mcpgateway.utils.correlation_id import get_correlation_id
85from mcpgateway.utils.create_slug import slugify
86from mcpgateway.utils.display_name import generate_display_name
87from mcpgateway.utils.metrics_common import build_top_performers
88from mcpgateway.utils.pagination import decode_cursor, encode_cursor, unified_paginate
89from mcpgateway.utils.passthrough_headers import compute_passthrough_headers_cached
90from mcpgateway.utils.retry_manager import ResilientHttpClient
91from mcpgateway.utils.services_auth import decode_auth
92from mcpgateway.utils.sqlalchemy_modifier import json_contains_tag_expr
93from mcpgateway.utils.ssl_context_cache import get_cached_ssl_context
94from mcpgateway.utils.url_auth import apply_query_param_auth, sanitize_exception_message, sanitize_url_for_logging
95from mcpgateway.utils.validate_signature import validate_signature
# Cache import (lazy to avoid circular dependencies)
# Module-level singletons populated on first use by the _get_*_cache()
# helpers below; deferring the imports avoids a circular import at startup.
_REGISTRY_CACHE = None
_TOOL_LOOKUP_CACHE = None
def _get_registry_cache():
    """Return the process-wide registry cache, importing it on first use.

    The import is deferred to call time so this module can be imported
    without triggering a circular dependency on the cache package.

    Returns:
        RegistryCache instance.
    """
    global _REGISTRY_CACHE  # pylint: disable=global-statement
    if _REGISTRY_CACHE is not None:
        return _REGISTRY_CACHE
    # First-Party
    from mcpgateway.cache.registry_cache import registry_cache  # pylint: disable=import-outside-toplevel

    _REGISTRY_CACHE = registry_cache
    return _REGISTRY_CACHE
def _get_tool_lookup_cache():
    """Return the process-wide tool lookup cache, importing it on first use.

    Mirrors :func:`_get_registry_cache`: the import is deferred to call
    time to avoid a circular dependency on the cache package.

    Returns:
        ToolLookupCache instance.
    """
    global _TOOL_LOOKUP_CACHE  # pylint: disable=global-statement
    if _TOOL_LOOKUP_CACHE is not None:
        return _TOOL_LOOKUP_CACHE
    # First-Party
    from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache  # pylint: disable=import-outside-toplevel

    _TOOL_LOOKUP_CACHE = tool_lookup_cache
    return _TOOL_LOOKUP_CACHE
# Initialize logging service first
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

# Initialize performance tracker, structured logger, and audit trail for tool operations
perf_tracker = get_performance_tracker()
structured_logger = get_structured_logger("tool_service")
audit_trail = get_audit_trail_service()
@lru_cache(maxsize=256)
def _compile_jq_filter(jq_filter: str):
    """Cache compiled jq filter program.

    jq compilation is comparatively expensive, so results are memoized
    keyed by the filter string (bounded LRU of 256 entries).

    Args:
        jq_filter: The jq filter string to compile.

    Returns:
        Compiled jq program object.

    Raises:
        ValueError: If the jq filter is invalid.
    """
    # pylint: disable=c-extension-no-member
    return jq.compile(jq_filter)
@lru_cache(maxsize=128)
def _get_validator_class_and_check(schema_json: str) -> Tuple[type, dict]:
    """Resolve and memoize the JSON Schema validator class for a schema.

    Caches the expensive steps: deserializing the schema, selecting a
    validator class from ``$schema``, and checking the schema itself.

    Multiple drafts are supported via fallback validators when the
    auto-detected class rejects the schema. This accommodates schemas
    using older draft features (e.g. Draft 4 style ``exclusiveMinimum:
    true``) that are invalid under newer drafts.

    Args:
        schema_json: Canonical JSON string of the schema (used as cache key).

    Returns:
        Tuple of (validator_class, schema_dict) ready for instantiation.
    """
    schema = orjson.loads(schema_json)

    # Try the $schema-detected class first, then progressively older drafts
    # that may accept legacy features (boolean exclusiveMinimum/Maximum).
    detected_cls = validators.validator_for(schema)
    for candidate_cls in (detected_cls, Draft7Validator, Draft6Validator, Draft4Validator):
        try:
            candidate_cls.check_schema(schema)
        except jsonschema.exceptions.SchemaError:
            continue
        return candidate_cls, schema

    # No validator accepted the schema: re-run the detected validator's
    # check so the caller gets a clear SchemaError during validation.
    detected_cls.check_schema(schema)
    return detected_cls, schema
def _canonicalize_schema(schema: dict) -> str:
    """Serialize a schema to a deterministic JSON string for cache keying.

    Keys are sorted so semantically-equal schemas with different key order
    map to the same cache entry.

    Args:
        schema: The JSON Schema dictionary.

    Returns:
        Canonical JSON string with sorted keys.
    """
    canonical_bytes = orjson.dumps(schema, option=orjson.OPT_SORT_KEYS)
    return canonical_bytes.decode()
def _validate_with_cached_schema(instance: Any, schema: dict) -> None:
    # noqa: DAR401
    """Validate an instance against a schema using cached validator metadata.

    A fresh validator instance is created per call for thread safety; the
    validator class selection and schema check are reused from the cache.
    ``best_match`` preserves :func:`jsonschema.validate` error-selection
    semantics.

    Args:
        instance: The data to validate.
        schema: The JSON Schema to validate against.

    Raises:
        jsonschema.exceptions.ValidationError: If validation fails.
    """
    cache_key = _canonicalize_schema(schema)
    validator_cls, checked_schema = _get_validator_class_and_check(cache_key)
    # Fresh instance per call keeps this thread safe; instantiation is cheap.
    validator = validator_cls(checked_schema)
    first_error = jsonschema.exceptions.best_match(validator.iter_errors(instance))
    if first_error is not None:
        raise first_error
def extract_using_jq(data, jq_filter=""):
    """Apply a jq filter to JSON input given as a string, dict, or list.

    Compiled jq programs are cached, so repeated filters are cheap.

    Args:
        data (str, dict, list): The input JSON data. Can be a string, dict, or list.
        jq_filter (str): The jq filter string to extract the desired data.

    Returns:
        The result of applying the jq filter to the input data.

    Examples:
        >>> extract_using_jq('{"a": 1, "b": 2}', '.a')
        [1]
        >>> extract_using_jq({'a': 1, 'b': 2}, '.b')
        [2]
        >>> extract_using_jq('[{"a": 1}, {"a": 2}]', '.[].a')
        [1, 2]
        >>> extract_using_jq('not a json', '.a')
        ['Invalid JSON string provided.']
        >>> extract_using_jq({'a': 1}, '')
        {'a': 1}
    """
    # Empty filter: pass the input through untouched.
    if jq_filter == "":
        return data

    if isinstance(data, str):
        # String input must first be parsed into a JSON value.
        try:
            data = orjson.loads(data)
        except orjson.JSONDecodeError:
            return ["Invalid JSON string provided."]
    elif not isinstance(data, (dict, list)):
        # Anything else (int, None, object, ...) is unsupported input.
        return ["Input data must be a JSON string, dictionary, or list."]

    # Run the (cached) compiled program over the parsed value.
    try:
        result = _compile_jq_filter(jq_filter).input(data).all()
        if result == [None]:
            result = "Error applying jsonpath filter"
    except Exception as e:
        return "Error applying jsonpath filter: " + str(e)

    return result
class ToolError(Exception):
    """Base exception for all tool-related failures in this module.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolError
        >>> err = ToolError("Something went wrong")
        >>> str(err)
        'Something went wrong'
    """
class ToolNotFoundError(ToolError):
    """Raised when a requested tool cannot be located.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolNotFoundError
        >>> err = ToolNotFoundError("Tool xyz not found")
        >>> str(err)
        'Tool xyz not found'
        >>> isinstance(err, ToolError)
        True
    """
class ToolNameConflictError(ToolError):
    """Raised when a tool name collides with an existing (active or inactive) tool."""

    def __init__(self, name: str, enabled: bool = True, tool_id: Optional[int] = None, visibility: str = "public"):
        """Build the conflict error message from the existing tool's details.

        Args:
            name: The conflicting tool name.
            enabled: Whether the existing tool is enabled or not.
            tool_id: ID of the existing tool if available.
            visibility: The visibility of the tool ("public" or "team").

        Examples:
            >>> from mcpgateway.services.tool_service import ToolNameConflictError
            >>> err = ToolNameConflictError('test_tool', enabled=False, tool_id=123)
            >>> str(err)
            'Public Tool already exists with name: test_tool (currently inactive, ID: 123)'
            >>> err.name
            'test_tool'
            >>> err.enabled
            False
            >>> err.tool_id
            123
        """
        self.name = name
        self.enabled = enabled
        self.tool_id = tool_id
        vis_label = "Team-level" if visibility == "team" else "Public"
        # Inactive tools still block the name; point the caller at the row.
        suffix = "" if enabled else f" (currently inactive, ID: {tool_id})"
        super().__init__(f"{vis_label} Tool already exists with name: {name}{suffix}")
class ToolLockConflictError(ToolError):
    """Raised when another transaction currently holds the lock on a tool row."""
class ToolValidationError(ToolError):
    """Raised when a tool fails validation.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolValidationError
        >>> err = ToolValidationError("Invalid tool configuration")
        >>> str(err)
        'Invalid tool configuration'
        >>> isinstance(err, ToolError)
        True
    """
class ToolInvocationError(ToolError):
    """Raised when invoking a tool fails.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolInvocationError
        >>> err = ToolInvocationError("Tool execution failed")
        >>> str(err)
        'Tool execution failed'
        >>> isinstance(err, ToolError)
        True
        >>> detailed_err = ToolInvocationError("Network timeout after 30 seconds")
        >>> "timeout" in str(detailed_err)
        True
    """
class ToolTimeoutError(ToolInvocationError):
    """Raised when a tool invocation exceeds its time budget.

    Distinguishes timeouts from other invocation failures. Timeout handlers
    call tool_post_invoke before raising this, so the generic exception
    handler must skip calling post_invoke again to avoid double-counting
    failures.
    """
400class ToolService:
401 """Service for managing and invoking tools.
403 Handles:
404 - Tool registration and deregistration.
405 - Tool invocation and validation.
406 - Tool federation.
407 - Event notifications.
408 - Active/inactive tool management.
409 """
411 def __init__(self) -> None:
412 """Initialize the tool service.
414 Examples:
415 >>> from mcpgateway.services.tool_service import ToolService
416 >>> service = ToolService()
417 >>> isinstance(service._event_service, EventService)
418 True
419 >>> hasattr(service, '_http_client')
420 True
421 """
422 self._event_service = EventService(channel_name="mcpgateway:tool_events")
423 self._http_client = ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify})
424 # Initialize plugin manager with env overrides to ease testing
425 env_flag = os.getenv("PLUGINS_ENABLED")
426 if env_flag is not None:
427 env_enabled = env_flag.strip().lower() in {"1", "true", "yes", "on"}
428 plugins_enabled = env_enabled
429 else:
430 plugins_enabled = settings.plugins_enabled
431 config_file = os.getenv("PLUGIN_CONFIG_FILE", getattr(settings, "plugin_config_file", "plugins/config.yaml"))
432 self._plugin_manager: PluginManager | None = PluginManager(config_file) if plugins_enabled else None
433 self.oauth_manager = OAuthManager(
434 request_timeout=int(settings.oauth_request_timeout if hasattr(settings, "oauth_request_timeout") else 30),
435 max_retries=int(settings.oauth_max_retries if hasattr(settings, "oauth_max_retries") else 3),
436 )
438 async def initialize(self) -> None:
439 """Initialize the service.
441 Examples:
442 >>> from mcpgateway.services.tool_service import ToolService
443 >>> service = ToolService()
444 >>> import asyncio
445 >>> asyncio.run(service.initialize()) # Should log "Initializing tool service"
446 """
447 logger.info("Initializing tool service")
448 await self._event_service.initialize()
450 async def shutdown(self) -> None:
451 """Shutdown the service.
453 Examples:
454 >>> from mcpgateway.services.tool_service import ToolService
455 >>> service = ToolService()
456 >>> import asyncio
457 >>> asyncio.run(service.shutdown()) # Should log "Tool service shutdown complete"
458 """
459 await self._http_client.aclose()
460 await self._event_service.shutdown()
461 logger.info("Tool service shutdown complete")
463 async def get_top_tools(self, db: Session, limit: Optional[int] = 5, include_deleted: bool = False) -> List[TopPerformer]:
464 """Retrieve the top-performing tools based on execution count.
466 Queries the database to get tools with their metrics, ordered by the number of executions
467 in descending order. Returns a list of TopPerformer objects containing tool details and
468 performance metrics. Results are cached for performance.
470 Args:
471 db (Session): Database session for querying tool metrics.
472 limit (Optional[int]): Maximum number of tools to return. Defaults to 5.
473 include_deleted (bool): Whether to include deleted tools from rollups.
475 Returns:
476 List[TopPerformer]: A list of TopPerformer objects, each containing:
477 - id: Tool ID.
478 - name: Tool name.
479 - execution_count: Total number of executions.
480 - avg_response_time: Average response time in seconds, or None if no metrics.
481 - success_rate: Success rate percentage, or None if no metrics.
482 - last_execution: Timestamp of the last execution, or None if no metrics.
483 """
484 # Check cache first (if enabled)
485 # First-Party
486 from mcpgateway.cache.metrics_cache import is_cache_enabled, metrics_cache # pylint: disable=import-outside-toplevel
488 effective_limit = limit or 5
489 cache_key = f"top_tools:{effective_limit}:include_deleted={include_deleted}"
491 if is_cache_enabled():
492 cached = metrics_cache.get(cache_key)
493 if cached is not None:
494 return cached
496 # Use combined query that includes both raw metrics and rollup data
497 results = get_top_performers_combined(
498 db=db,
499 metric_type="tool",
500 entity_model=DbTool,
501 limit=effective_limit,
502 include_deleted=include_deleted,
503 )
504 top_performers = build_top_performers(results)
506 # Cache the result (if enabled)
507 if is_cache_enabled():
508 metrics_cache.set(cache_key, top_performers)
510 return top_performers
512 def _build_tool_cache_payload(self, tool: DbTool, gateway: Optional[DbGateway]) -> Dict[str, Any]:
513 """Build cache payload for tool lookup by name.
515 Args:
516 tool: Tool ORM instance.
517 gateway: Optional gateway ORM instance.
519 Returns:
520 Cache payload dict for tool lookup.
521 """
522 tool_payload = {
523 "id": str(tool.id),
524 "name": tool.name,
525 "original_name": tool.original_name,
526 "url": tool.url,
527 "description": tool.description,
528 "integration_type": tool.integration_type,
529 "request_type": tool.request_type,
530 "headers": tool.headers or {},
531 "input_schema": tool.input_schema or {"type": "object", "properties": {}},
532 "output_schema": tool.output_schema,
533 "annotations": tool.annotations or {},
534 "auth_type": tool.auth_type,
535 "auth_value": tool.auth_value,
536 "oauth_config": getattr(tool, "oauth_config", None),
537 "jsonpath_filter": tool.jsonpath_filter,
538 "custom_name": tool.custom_name,
539 "custom_name_slug": tool.custom_name_slug,
540 "display_name": tool.display_name,
541 "gateway_id": str(tool.gateway_id) if tool.gateway_id else None,
542 "enabled": bool(tool.enabled),
543 "reachable": bool(tool.reachable),
544 "tags": tool.tags or [],
545 "team_id": tool.team_id,
546 "owner_email": tool.owner_email,
547 "visibility": tool.visibility,
548 }
550 gateway_payload = None
551 if gateway:
552 gateway_payload = {
553 "id": str(gateway.id),
554 "name": gateway.name,
555 "url": gateway.url,
556 "description": gateway.description,
557 "slug": gateway.slug,
558 "transport": gateway.transport,
559 "capabilities": gateway.capabilities or {},
560 "passthrough_headers": gateway.passthrough_headers or [],
561 "auth_type": gateway.auth_type,
562 "auth_value": gateway.auth_value,
563 "auth_query_params": getattr(gateway, "auth_query_params", None), # Query param auth
564 "oauth_config": getattr(gateway, "oauth_config", None),
565 "ca_certificate": getattr(gateway, "ca_certificate", None),
566 "ca_certificate_sig": getattr(gateway, "ca_certificate_sig", None),
567 "enabled": bool(gateway.enabled),
568 "reachable": bool(gateway.reachable),
569 "team_id": gateway.team_id,
570 "owner_email": gateway.owner_email,
571 "visibility": gateway.visibility,
572 "tags": gateway.tags or [],
573 }
575 return {"status": "active", "tool": tool_payload, "gateway": gateway_payload}
577 def _pydantic_tool_from_payload(self, tool_payload: Dict[str, Any]) -> Optional[PydanticTool]:
578 """Build Pydantic tool metadata from cache payload.
580 Args:
581 tool_payload: Cached tool payload dict.
583 Returns:
584 Pydantic tool metadata or None if validation fails.
585 """
586 try:
587 return PydanticTool.model_validate(tool_payload)
588 except Exception as exc:
589 logger.debug("Failed to build PydanticTool from cache payload: %s", exc)
590 return None
592 def _pydantic_gateway_from_payload(self, gateway_payload: Dict[str, Any]) -> Optional[PydanticGateway]:
593 """Build Pydantic gateway metadata from cache payload.
595 Args:
596 gateway_payload: Cached gateway payload dict.
598 Returns:
599 Pydantic gateway metadata or None if validation fails.
600 """
601 try:
602 return PydanticGateway.model_validate(gateway_payload)
603 except Exception as exc:
604 logger.debug("Failed to build PydanticGateway from cache payload: %s", exc)
605 return None
607 async def _check_tool_access(
608 self,
609 db: Session,
610 tool_payload: Dict[str, Any],
611 user_email: Optional[str],
612 token_teams: Optional[List[str]],
613 ) -> bool:
614 """Check if user has access to a tool based on visibility rules.
616 Implements the same access control logic as list_tools() for consistency.
618 Access Rules:
619 - Public tools: Accessible by all authenticated users
620 - Team tools: Accessible by team members (team_id in user's teams)
621 - Private tools: Accessible only by owner (owner_email matches)
623 Args:
624 db: Database session for team membership lookup if needed.
625 tool_payload: Tool data dict with visibility, team_id, owner_email.
626 user_email: Email of the requesting user (None = unauthenticated).
627 token_teams: List of team IDs from token.
628 - None = unrestricted admin access
629 - [] = public-only token
630 - [...] = team-scoped token
632 Returns:
633 True if access is allowed, False otherwise.
634 """
635 visibility = tool_payload.get("visibility", "public")
636 tool_team_id = tool_payload.get("team_id")
637 tool_owner_email = tool_payload.get("owner_email")
639 # Public tools are accessible by everyone
640 if visibility == "public":
641 return True
643 # Admin bypass: token_teams=None AND user_email=None means unrestricted admin
644 # This happens when is_admin=True and no team scoping in token
645 if token_teams is None and user_email is None:
646 return True
648 # No user context (but not admin) = deny access to non-public tools
649 if not user_email:
650 return False
652 # Public-only tokens (empty teams array) can ONLY access public tools
653 is_public_only_token = token_teams is not None and len(token_teams) == 0
654 if is_public_only_token:
655 return False # Already checked public above
657 # Owner can always access their own tools
658 if tool_owner_email and tool_owner_email == user_email:
659 return True
661 # Team tools: check team membership (matches list_tools behavior)
662 if tool_team_id:
663 # Use token_teams if provided, otherwise look up from DB
664 if token_teams is not None:
665 team_ids = token_teams
666 else:
667 team_service = TeamManagementService(db)
668 user_teams = await team_service.get_user_teams(user_email)
669 team_ids = [team.id for team in user_teams]
671 # Team/public visibility allows access if user is in the team
672 if visibility in ["team", "public"] and tool_team_id in team_ids:
673 return True
675 return False
677 def convert_tool_to_read(
678 self,
679 tool: DbTool,
680 include_metrics: bool = False,
681 include_auth: bool = True,
682 requesting_user_email: Optional[str] = None,
683 requesting_user_is_admin: bool = False,
684 requesting_user_team_roles: Optional[Dict[str, str]] = None,
685 ) -> ToolRead:
686 """Converts a DbTool instance into a ToolRead model, including aggregated metrics and
687 new API gateway fields: request_type and authentication credentials (masked).
689 Args:
690 tool (DbTool): The ORM instance of the tool.
691 include_metrics (bool): Whether to include metrics in the result. Defaults to False.
692 include_auth (bool): Whether to decode and include auth details. Defaults to True.
693 When False, skips expensive AES-GCM decryption and returns minimal auth info.
694 requesting_user_email (Optional[str]): Email of the requesting user for header masking.
695 requesting_user_is_admin (bool): Whether the requester is an admin.
696 requesting_user_team_roles (Optional[Dict[str, str]]): {team_id: role} for the requester.
698 Returns:
699 ToolRead: The Pydantic model representing the tool, including aggregated metrics and new fields.
700 """
701 # NOTE: This serves two purposes:
702 # 1. It determines whether to decode auth (used later)
703 # 2. It forces the tool object to lazily evaluate (required before copy)
704 has_encrypted_auth = tool.auth_type and tool.auth_value
706 # Copy the dict from the tool
707 tool_dict = tool.__dict__.copy()
708 tool_dict.pop("_sa_instance_state", None)
710 # Compute metrics in a single pass (matches server/resource/prompt service pattern)
711 if include_metrics:
712 metrics = tool.metrics_summary # Single-pass computation
713 tool_dict["metrics"] = metrics
714 tool_dict["execution_count"] = metrics["total_executions"]
715 else:
716 tool_dict["metrics"] = None
717 tool_dict["execution_count"] = None
719 tool_dict["request_type"] = tool.request_type
720 tool_dict["annotations"] = tool.annotations or {}
722 # Only decode auth if include_auth=True AND we have encrypted credentials
723 if include_auth and has_encrypted_auth:
724 decoded_auth_value = decode_auth(tool.auth_value)
725 if tool.auth_type == "basic":
726 decoded_bytes = base64.b64decode(decoded_auth_value["Authorization"].split("Basic ")[1])
727 username, password = decoded_bytes.decode("utf-8").split(":")
728 tool_dict["auth"] = {
729 "auth_type": "basic",
730 "username": username,
731 "password": settings.masked_auth_value if password else None,
732 }
733 elif tool.auth_type == "bearer":
734 tool_dict["auth"] = {
735 "auth_type": "bearer",
736 "token": settings.masked_auth_value if decoded_auth_value["Authorization"] else None,
737 }
738 elif tool.auth_type == "authheaders":
739 # Get first key
740 if decoded_auth_value:
741 first_key = next(iter(decoded_auth_value))
742 tool_dict["auth"] = {
743 "auth_type": "authheaders",
744 "auth_header_key": first_key,
745 "auth_header_value": settings.masked_auth_value if decoded_auth_value[first_key] else None,
746 }
747 else:
748 tool_dict["auth"] = None
749 else:
750 tool_dict["auth"] = None
751 elif not include_auth and has_encrypted_auth:
752 # LIST VIEW: Minimal auth info without decryption
753 # Only show auth_type for tools that have encrypted credentials
754 tool_dict["auth"] = {"auth_type": tool.auth_type}
755 else:
756 # No encrypted auth (includes OAuth tools where auth_value=None)
757 # Behavior unchanged from current implementation
758 tool_dict["auth"] = None
760 tool_dict["name"] = tool.name
761 # Handle displayName with fallback and None checks
762 display_name = getattr(tool, "display_name", None)
763 custom_name = getattr(tool, "custom_name", tool.original_name)
764 tool_dict["displayName"] = display_name or custom_name
765 tool_dict["custom_name"] = custom_name
766 tool_dict["gateway_slug"] = getattr(tool, "gateway_slug", "") or ""
767 tool_dict["custom_name_slug"] = getattr(tool, "custom_name_slug", "") or ""
768 tool_dict["tags"] = getattr(tool, "tags", []) or []
769 tool_dict["team"] = getattr(tool, "team", None)
771 # Mask custom headers unless the requester is allowed to modify this tool.
772 # Safe default: if no requester context is provided, mask everything.
773 headers = tool_dict.get("headers")
774 if headers:
775 can_view = requesting_user_is_admin
776 if not can_view and getattr(tool, "owner_email", None) == requesting_user_email:
777 can_view = True
778 if (
779 not can_view
780 and getattr(tool, "visibility", None) == "team"
781 and getattr(tool, "team_id", None) is not None
782 and requesting_user_team_roles
783 and requesting_user_team_roles.get(str(tool.team_id)) == "owner"
784 ):
785 can_view = True
786 if not can_view:
787 tool_dict["headers"] = {k: settings.masked_auth_value for k in headers}
789 return ToolRead.model_validate(tool_dict)
791 async def _record_tool_metric(self, db: Session, tool: DbTool, start_time: float, success: bool, error_message: Optional[str]) -> None:
792 """
793 Records a metric for a tool invocation.
795 This function calculates the response time using the provided start time and records
796 the metric details (including whether the invocation was successful and any error message)
797 into the database. The metric is then committed to the database.
799 Args:
800 db (Session): The SQLAlchemy database session.
801 tool (DbTool): The tool that was invoked.
802 start_time (float): The monotonic start time of the invocation.
803 success (bool): True if the invocation succeeded; otherwise, False.
804 error_message (Optional[str]): The error message if the invocation failed, otherwise None.
805 """
806 end_time = time.monotonic()
807 response_time = end_time - start_time
808 metric = ToolMetric(
809 tool_id=tool.id,
810 response_time=response_time,
811 is_success=success,
812 error_message=error_message,
813 )
814 db.add(metric)
815 db.commit()
817 def _record_tool_metric_by_id(
818 self,
819 db: Session,
820 tool_id: str,
821 start_time: float,
822 success: bool,
823 error_message: Optional[str],
824 ) -> None:
825 """Record tool metric using tool ID instead of ORM object.
827 This method is designed to be used with a fresh database session after the main
828 request session has been released. It avoids requiring the ORM tool object,
829 which may have been detached from the session.
831 Args:
832 db: A fresh database session (not the request session).
833 tool_id: The UUID string of the tool.
834 start_time: The monotonic start time of the invocation.
835 success: True if the invocation succeeded; otherwise, False.
836 error_message: The error message if the invocation failed, otherwise None.
837 """
838 end_time = time.monotonic()
839 response_time = end_time - start_time
840 metric = ToolMetric(
841 tool_id=tool_id,
842 response_time=response_time,
843 is_success=success,
844 error_message=error_message,
845 )
846 db.add(metric)
847 db.commit()
849 def _record_tool_metric_sync(
850 self,
851 tool_id: str,
852 start_time: float,
853 success: bool,
854 error_message: Optional[str],
855 ) -> None:
856 """Synchronous helper to record tool metrics with its own session.
858 This method creates a fresh database session, records the metric, and closes
859 the session. Designed to be called via asyncio.to_thread() to avoid blocking
860 the event loop.
862 Args:
863 tool_id: The UUID string of the tool.
864 start_time: The monotonic start time of the invocation.
865 success: True if the invocation succeeded; otherwise, False.
866 error_message: The error message if the invocation failed, otherwise None.
867 """
868 with fresh_db_session() as db_metrics:
869 self._record_tool_metric_by_id(
870 db_metrics,
871 tool_id=tool_id,
872 start_time=start_time,
873 success=success,
874 error_message=error_message,
875 )
    def _extract_and_validate_structured_content(self, tool: DbTool, tool_result: "ToolResult", candidate: Optional[Any] = None) -> bool:
        """
        Extract structured content (if any) and validate it against ``tool.output_schema``.

        Args:
            tool: The tool with an optional output schema to validate against.
            tool_result: The tool result containing content to validate.
            candidate: Optional structured payload to validate. If not provided, will attempt
                to parse the first TextContent item as JSON.

        Behavior:
            - If ``candidate`` is provided it is used as the structured payload to validate.
            - Otherwise the method will try to parse the first ``TextContent`` item in
              ``tool_result.content`` as JSON and use that as the candidate.
            - If no output schema is declared on the tool the method returns True (nothing to validate).
            - On successful validation the parsed value is attached to ``tool_result.structured_content``.
              When structured content is present and valid callers may drop textual ``content`` in favour
              of the structured payload.
            - On validation failure the method sets ``tool_result.content`` to a single ``TextContent``
              containing a compact JSON object describing the validation error, sets
              ``tool_result.is_error = True`` and returns False.

        Returns:
            True when the structured content is valid or when no schema is declared.
            False when validation fails.

        Examples:
            >>> from mcpgateway.services.tool_service import ToolService
            >>> from mcpgateway.common.models import TextContent, ToolResult
            >>> import json
            >>> service = ToolService()
            >>> # No schema declared -> nothing to validate
            >>> tool = type("T", (object,), {"output_schema": None})()
            >>> r = ToolResult(content=[TextContent(type="text", text='{"a":1}')])
            >>> service._extract_and_validate_structured_content(tool, r)
            True

            >>> # Valid candidate provided -> attaches structured_content and returns True
            >>> tool = type(
            ...     "T",
            ...     (object,),
            ...     {"output_schema": {"type": "object", "properties": {"foo": {"type": "string"}}, "required": ["foo"]}},
            ... )()
            >>> r = ToolResult(content=[])
            >>> service._extract_and_validate_structured_content(tool, r, candidate={"foo": "bar"})
            True
            >>> r.structured_content == {"foo": "bar"}
            True

            >>> # Invalid candidate -> returns False, marks result as error and emits details
            >>> tool = type(
            ...     "T",
            ...     (object,),
            ...     {"output_schema": {"type": "object", "properties": {"foo": {"type": "string"}}, "required": ["foo"]}},
            ... )()
            >>> r = ToolResult(content=[])
            >>> ok = service._extract_and_validate_structured_content(tool, r, candidate={"foo": 123})
            >>> ok
            False
            >>> r.is_error
            True
            >>> details = orjson.loads(r.content[0].text)
            >>> "received" in details
            True
        """
        try:
            output_schema = getattr(tool, "output_schema", None)
            # Nothing to do if the tool doesn't declare a schema
            if not output_schema:
                return True

            structured: Optional[Any] = None
            # Prefer explicit candidate over re-parsing the textual content
            if candidate is not None:
                structured = candidate
            else:
                # Try to parse first TextContent text payload as JSON.
                # NOTE(review): only dict-shaped items (``{"type": "text", "text": ...}``)
                # are matched here; TextContent model instances would be skipped —
                # confirm content has been serialized to dicts by this point.
                for c in getattr(tool_result, "content", []) or []:
                    try:
                        if isinstance(c, dict) and "type" in c and c.get("type") == "text" and "text" in c:
                            structured = orjson.loads(c.get("text") or "null")
                            break
                    except (orjson.JSONDecodeError, TypeError, ValueError):
                        # ignore JSON parse errors and continue
                        continue

            # If no structured data found, treat as valid (nothing to validate)
            if structured is None:
                return True

            # Try to normalize common wrapper shapes to match schema expectations
            schema_type = None
            try:
                if isinstance(output_schema, dict):
                    schema_type = output_schema.get("type")
            except Exception:
                schema_type = None

            # Unwrap single-element list wrappers when schema expects object
            if isinstance(structured, list) and len(structured) == 1 and schema_type == "object":
                inner = structured[0]
                # If inner is a TextContent-like dict with 'text' JSON string, parse it
                if isinstance(inner, dict) and "text" in inner and "type" in inner and inner.get("type") == "text":
                    try:
                        structured = orjson.loads(inner.get("text") or "null")
                    except Exception:
                        # leave as-is if parsing fails
                        structured = inner
                else:
                    structured = inner

            # Attach structured content before validation so callers can inspect it
            # even when validation subsequently fails.
            try:
                setattr(tool_result, "structured_content", structured)
            except Exception:
                logger.debug("Failed to set structured_content on ToolResult")

            # Validate using cached schema validator
            try:
                _validate_with_cached_schema(structured, output_schema)
                return True
            except jsonschema.exceptions.ValidationError as e:
                # Build a compact, machine-readable description of the failure.
                details = {
                    "code": getattr(e, "validator", "validation_error"),
                    "expected": e.schema.get("type") if isinstance(e.schema, dict) and "type" in e.schema else None,
                    "received": type(e.instance).__name__.lower() if e.instance is not None else None,
                    "path": list(e.absolute_path) if hasattr(e, "absolute_path") else list(e.path or []),
                    "message": e.message,
                }
                # Replace the textual content with the error details; fall back to
                # str() if the details are not JSON-serializable.
                try:
                    tool_result.content = [TextContent(type="text", text=orjson.dumps(details).decode())]
                except Exception:
                    tool_result.content = [TextContent(type="text", text=str(details))]
                tool_result.is_error = True
                logger.debug(f"structured_content validation failed for tool {getattr(tool, 'name', '<unknown>')}: {details}")
                return False
        except Exception as exc:  # pragma: no cover - defensive
            # Never let validation plumbing crash an invocation; report as invalid.
            logger.error(f"Error extracting/validating structured_content: {exc}")
            return False
1017 async def register_tool(
1018 self,
1019 db: Session,
1020 tool: ToolCreate,
1021 created_by: Optional[str] = None,
1022 created_from_ip: Optional[str] = None,
1023 created_via: Optional[str] = None,
1024 created_user_agent: Optional[str] = None,
1025 import_batch_id: Optional[str] = None,
1026 federation_source: Optional[str] = None,
1027 team_id: Optional[str] = None,
1028 owner_email: Optional[str] = None,
1029 visibility: str = None,
1030 ) -> ToolRead:
1031 """Register a new tool with team support.
1033 Args:
1034 db: Database session.
1035 tool: Tool creation schema.
1036 created_by: Username who created this tool.
1037 created_from_ip: IP address of creator.
1038 created_via: Creation method (ui, api, import, federation).
1039 created_user_agent: User agent of creation request.
1040 import_batch_id: UUID for bulk import operations.
1041 federation_source: Source gateway for federated tools.
1042 team_id: Optional team ID to assign tool to.
1043 owner_email: Optional owner email for tool ownership.
1044 visibility: Tool visibility (private, team, public).
1046 Returns:
1047 Created tool information.
1049 Raises:
1050 IntegrityError: If there is a database integrity error.
1051 ToolNameConflictError: If a tool with the same name and visibility public exists.
1052 ToolError: For other tool registration errors.
1054 Examples:
1055 >>> from mcpgateway.services.tool_service import ToolService
1056 >>> from unittest.mock import MagicMock, AsyncMock
1057 >>> from mcpgateway.schemas import ToolRead
1058 >>> service = ToolService()
1059 >>> db = MagicMock()
1060 >>> tool = MagicMock()
1061 >>> tool.name = 'test'
1062 >>> db.execute.return_value.scalar_one_or_none.return_value = None
1063 >>> mock_gateway = MagicMock()
1064 >>> mock_gateway.name = 'test_gateway'
1065 >>> db.add = MagicMock()
1066 >>> db.commit = MagicMock()
1067 >>> def mock_refresh(obj):
1068 ... obj.gateway = mock_gateway
1069 >>> db.refresh = MagicMock(side_effect=mock_refresh)
1070 >>> service._notify_tool_added = AsyncMock()
1071 >>> service.convert_tool_to_read = MagicMock(return_value='tool_read')
1072 >>> ToolRead.model_validate = MagicMock(return_value='tool_read')
1073 >>> import asyncio
1074 >>> asyncio.run(service.register_tool(db, tool))
1075 'tool_read'
1076 """
1077 try:
1078 if tool.auth is None:
1079 auth_type = None
1080 auth_value = None
1081 else:
1082 auth_type = tool.auth.auth_type
1083 auth_value = tool.auth.auth_value
1085 if team_id is None:
1086 team_id = tool.team_id
1088 if owner_email is None:
1089 owner_email = tool.owner_email
1091 if visibility is None:
1092 visibility = tool.visibility or "public"
1093 # Check for existing tool with the same name and visibility
1094 if visibility.lower() == "public":
1095 # Check for existing public tool with the same name
1096 existing_tool = db.execute(select(DbTool).where(DbTool.name == tool.name, DbTool.visibility == "public")).scalar_one_or_none()
1097 if existing_tool:
1098 raise ToolNameConflictError(existing_tool.name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility)
1099 elif visibility.lower() == "team" and team_id:
1100 # Check for existing team tool with the same name, team_id
1101 existing_tool = db.execute(
1102 select(DbTool).where(DbTool.name == tool.name, DbTool.visibility == "team", DbTool.team_id == team_id) # pylint: disable=comparison-with-callable
1103 ).scalar_one_or_none()
1104 if existing_tool:
1105 raise ToolNameConflictError(existing_tool.name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility)
1107 db_tool = DbTool(
1108 original_name=tool.name,
1109 custom_name=tool.name,
1110 custom_name_slug=slugify(tool.name),
1111 display_name=tool.displayName or tool.name,
1112 url=str(tool.url),
1113 description=tool.description,
1114 integration_type=tool.integration_type,
1115 request_type=tool.request_type,
1116 headers=tool.headers,
1117 input_schema=tool.input_schema,
1118 output_schema=tool.output_schema,
1119 annotations=tool.annotations,
1120 jsonpath_filter=tool.jsonpath_filter,
1121 auth_type=auth_type,
1122 auth_value=auth_value,
1123 gateway_id=tool.gateway_id,
1124 tags=tool.tags or [],
1125 # Metadata fields
1126 created_by=created_by,
1127 created_from_ip=created_from_ip,
1128 created_via=created_via,
1129 created_user_agent=created_user_agent,
1130 import_batch_id=import_batch_id,
1131 federation_source=federation_source,
1132 version=1,
1133 # Team scoping fields
1134 team_id=team_id,
1135 owner_email=owner_email or created_by,
1136 visibility=visibility,
1137 # passthrough REST tools fields
1138 base_url=tool.base_url if tool.integration_type == "REST" else None,
1139 path_template=tool.path_template if tool.integration_type == "REST" else None,
1140 query_mapping=tool.query_mapping if tool.integration_type == "REST" else None,
1141 header_mapping=tool.header_mapping if tool.integration_type == "REST" else None,
1142 timeout_ms=tool.timeout_ms if tool.integration_type == "REST" else None,
1143 expose_passthrough=(tool.expose_passthrough if tool.integration_type == "REST" and tool.expose_passthrough is not None else True) if tool.integration_type == "REST" else None,
1144 allowlist=tool.allowlist if tool.integration_type == "REST" else None,
1145 plugin_chain_pre=tool.plugin_chain_pre if tool.integration_type == "REST" else None,
1146 plugin_chain_post=tool.plugin_chain_post if tool.integration_type == "REST" else None,
1147 )
1148 db.add(db_tool)
1149 db.commit()
1150 db.refresh(db_tool)
1151 await self._notify_tool_added(db_tool)
1153 # Structured logging: Audit trail for tool creation
1154 audit_trail.log_action(
1155 user_id=created_by or "system",
1156 action="create_tool",
1157 resource_type="tool",
1158 resource_id=db_tool.id,
1159 resource_name=db_tool.name,
1160 user_email=owner_email,
1161 team_id=team_id,
1162 client_ip=created_from_ip,
1163 user_agent=created_user_agent,
1164 new_values={
1165 "name": db_tool.name,
1166 "display_name": db_tool.display_name,
1167 "visibility": visibility,
1168 "integration_type": db_tool.integration_type,
1169 },
1170 context={
1171 "created_via": created_via,
1172 "import_batch_id": import_batch_id,
1173 "federation_source": federation_source,
1174 },
1175 db=db,
1176 )
1178 # Structured logging: Log successful tool creation
1179 structured_logger.log(
1180 level="INFO",
1181 message="Tool created successfully",
1182 event_type="tool_created",
1183 component="tool_service",
1184 user_id=created_by,
1185 user_email=owner_email,
1186 team_id=team_id,
1187 resource_type="tool",
1188 resource_id=db_tool.id,
1189 custom_fields={
1190 "tool_name": db_tool.name,
1191 "visibility": visibility,
1192 "integration_type": db_tool.integration_type,
1193 },
1194 db=db,
1195 )
1197 # Refresh db_tool after logging commits (they expire the session objects)
1198 db.refresh(db_tool)
1200 # Invalidate cache after successful creation
1201 cache = _get_registry_cache()
1202 await cache.invalidate_tools()
1203 tool_lookup_cache = _get_tool_lookup_cache()
1204 await tool_lookup_cache.invalidate(db_tool.name, gateway_id=str(db_tool.gateway_id) if db_tool.gateway_id else None)
1205 # Also invalidate tags cache since tool tags may have changed
1206 # First-Party
1207 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel
1209 await admin_stats_cache.invalidate_tags()
1211 return self.convert_tool_to_read(db_tool, requesting_user_email=getattr(db_tool, "owner_email", None))
1212 except IntegrityError as ie:
1213 db.rollback()
1214 logger.error(f"IntegrityError during tool registration: {ie}")
1216 # Structured logging: Log database integrity error
1217 structured_logger.log(
1218 level="ERROR",
1219 message="Tool creation failed due to database integrity error",
1220 event_type="tool_creation_failed",
1221 component="tool_service",
1222 user_id=created_by,
1223 user_email=owner_email,
1224 error=ie,
1225 custom_fields={
1226 "tool_name": tool.name,
1227 },
1228 db=db,
1229 )
1230 raise ie
1231 except ToolNameConflictError as tnce:
1232 db.rollback()
1233 logger.error(f"ToolNameConflictError during tool registration: {tnce}")
1235 # Structured logging: Log name conflict error
1236 structured_logger.log(
1237 level="WARNING",
1238 message="Tool creation failed due to name conflict",
1239 event_type="tool_name_conflict",
1240 component="tool_service",
1241 user_id=created_by,
1242 user_email=owner_email,
1243 custom_fields={
1244 "tool_name": tool.name,
1245 "visibility": visibility,
1246 },
1247 db=db,
1248 )
1249 raise tnce
1250 except Exception as e:
1251 db.rollback()
1253 # Structured logging: Log generic tool creation failure
1254 structured_logger.log(
1255 level="ERROR",
1256 message="Tool creation failed",
1257 event_type="tool_creation_failed",
1258 component="tool_service",
1259 user_id=created_by,
1260 user_email=owner_email,
1261 error=e,
1262 custom_fields={
1263 "tool_name": tool.name,
1264 },
1265 db=db,
1266 )
1267 raise ToolError(f"Failed to register tool: {str(e)}")
1269 async def register_tools_bulk(
1270 self,
1271 db: Session,
1272 tools: List[ToolCreate],
1273 created_by: Optional[str] = None,
1274 created_from_ip: Optional[str] = None,
1275 created_via: Optional[str] = None,
1276 created_user_agent: Optional[str] = None,
1277 import_batch_id: Optional[str] = None,
1278 federation_source: Optional[str] = None,
1279 team_id: Optional[str] = None,
1280 owner_email: Optional[str] = None,
1281 visibility: Optional[str] = "public",
1282 conflict_strategy: str = "skip",
1283 ) -> Dict[str, Any]:
1284 """Register multiple tools in bulk with a single commit.
1286 This method provides significant performance improvements over individual
1287 tool registration by:
1288 - Using db.add_all() instead of individual db.add() calls
1289 - Performing a single commit for all tools
1290 - Batch conflict detection
1291 - Chunking for very large imports (>500 items)
1293 Args:
1294 db: Database session
1295 tools: List of tool creation schemas
1296 created_by: Username who created these tools
1297 created_from_ip: IP address of creator
1298 created_via: Creation method (ui, api, import, federation)
1299 created_user_agent: User agent of creation request
1300 import_batch_id: UUID for bulk import operations
1301 federation_source: Source gateway for federated tools
1302 team_id: Team ID to assign the tools to
1303 owner_email: Email of the user who owns these tools
1304 visibility: Tool visibility level (private, team, public)
1305 conflict_strategy: How to handle conflicts (skip, update, rename, fail)
1307 Returns:
1308 Dict with statistics:
1309 - created: Number of tools created
1310 - updated: Number of tools updated
1311 - skipped: Number of tools skipped
1312 - failed: Number of tools that failed
1313 - errors: List of error messages
1315 Raises:
1316 ToolError: If bulk registration fails critically
1318 Examples:
1319 >>> from mcpgateway.services.tool_service import ToolService
1320 >>> from unittest.mock import MagicMock
1321 >>> service = ToolService()
1322 >>> db = MagicMock()
1323 >>> tools = [MagicMock(), MagicMock()]
1324 >>> import asyncio
1325 >>> try:
1326 ... result = asyncio.run(service.register_tools_bulk(db, tools))
1327 ... except Exception:
1328 ... pass
1329 """
1330 if not tools:
1331 return {"created": 0, "updated": 0, "skipped": 0, "failed": 0, "errors": []}
1333 stats = {"created": 0, "updated": 0, "skipped": 0, "failed": 0, "errors": []}
1335 # Process in chunks to avoid memory issues and SQLite parameter limits
1336 chunk_size = 500
1338 for chunk_start in range(0, len(tools), chunk_size):
1339 chunk = tools[chunk_start : chunk_start + chunk_size]
1340 chunk_stats = self._process_tool_chunk(
1341 db=db,
1342 chunk=chunk,
1343 conflict_strategy=conflict_strategy,
1344 visibility=visibility,
1345 team_id=team_id,
1346 owner_email=owner_email,
1347 created_by=created_by,
1348 created_from_ip=created_from_ip,
1349 created_via=created_via,
1350 created_user_agent=created_user_agent,
1351 import_batch_id=import_batch_id,
1352 federation_source=federation_source,
1353 )
1355 # Aggregate stats
1356 for key, value in chunk_stats.items():
1357 if key == "errors":
1358 stats[key].extend(value)
1359 else:
1360 stats[key] += value
1362 if chunk_stats["created"] or chunk_stats["updated"]:
1363 cache = _get_registry_cache()
1364 await cache.invalidate_tools()
1365 tool_lookup_cache = _get_tool_lookup_cache()
1366 tool_name_map: Dict[str, Optional[str]] = {}
1367 for tool in chunk:
1368 name = getattr(tool, "name", None)
1369 if not name:
1370 continue
1371 gateway_id = getattr(tool, "gateway_id", None)
1372 tool_name_map[name] = str(gateway_id) if gateway_id else tool_name_map.get(name)
1373 for tool_name, gateway_id in tool_name_map.items():
1374 await tool_lookup_cache.invalidate(tool_name, gateway_id=gateway_id)
1375 # Also invalidate tags cache since tool tags may have changed
1376 # First-Party
1377 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel
1379 await admin_stats_cache.invalidate_tags()
1381 return stats
1383 def _process_tool_chunk(
1384 self,
1385 db: Session,
1386 chunk: List[ToolCreate],
1387 conflict_strategy: str,
1388 visibility: str,
1389 team_id: Optional[int],
1390 owner_email: Optional[str],
1391 created_by: str,
1392 created_from_ip: Optional[str],
1393 created_via: Optional[str],
1394 created_user_agent: Optional[str],
1395 import_batch_id: Optional[str],
1396 federation_source: Optional[str],
1397 ) -> dict:
1398 """Process a chunk of tools for bulk import.
1400 Args:
1401 db: The SQLAlchemy database session.
1402 chunk: List of ToolCreate objects to process.
1403 conflict_strategy: Strategy for handling conflicts ("skip", "update", or "fail").
1404 visibility: Tool visibility level ("public", "team", or "private").
1405 team_id: Team ID for team-scoped tools.
1406 owner_email: Email of the tool owner.
1407 created_by: Email of the user creating the tools.
1408 created_from_ip: IP address of the request origin.
1409 created_via: Source of the creation (e.g., "api", "ui").
1410 created_user_agent: User agent string from the request.
1411 import_batch_id: Batch identifier for bulk imports.
1412 federation_source: Source identifier for federated tools.
1414 Returns:
1415 dict: Statistics dictionary with keys "created", "updated", "skipped", "failed", and "errors".
1416 """
1417 stats = {"created": 0, "updated": 0, "skipped": 0, "failed": 0, "errors": []}
1419 try:
1420 # Batch check for existing tools to detect conflicts
1421 tool_names = [tool.name for tool in chunk]
1423 if visibility.lower() == "public":
1424 existing_tools_query = select(DbTool).where(DbTool.name.in_(tool_names), DbTool.visibility == "public")
1425 elif visibility.lower() == "team" and team_id:
1426 existing_tools_query = select(DbTool).where(DbTool.name.in_(tool_names), DbTool.visibility == "team", DbTool.team_id == team_id)
1427 else:
1428 # Private tools - check by owner
1429 existing_tools_query = select(DbTool).where(DbTool.name.in_(tool_names), DbTool.visibility == "private", DbTool.owner_email == (owner_email or created_by))
1431 existing_tools = db.execute(existing_tools_query).scalars().all()
1432 existing_tools_map = {tool.name: tool for tool in existing_tools}
1434 tools_to_add = []
1435 tools_to_update = []
1437 for tool in chunk:
1438 result = self._process_single_tool_for_bulk(
1439 tool=tool,
1440 existing_tools_map=existing_tools_map,
1441 conflict_strategy=conflict_strategy,
1442 visibility=visibility,
1443 team_id=team_id,
1444 owner_email=owner_email,
1445 created_by=created_by,
1446 created_from_ip=created_from_ip,
1447 created_via=created_via,
1448 created_user_agent=created_user_agent,
1449 import_batch_id=import_batch_id,
1450 federation_source=federation_source,
1451 )
1453 if result["status"] == "add":
1454 tools_to_add.append(result["tool"])
1455 stats["created"] += 1
1456 elif result["status"] == "update":
1457 tools_to_update.append(result["tool"])
1458 stats["updated"] += 1
1459 elif result["status"] == "skip":
1460 stats["skipped"] += 1
1461 elif result["status"] == "fail":
1462 stats["failed"] += 1
1463 stats["errors"].append(result["error"])
1465 # Bulk add new tools
1466 if tools_to_add:
1467 db.add_all(tools_to_add)
1469 # Commit the chunk
1470 db.commit()
1472 # Refresh tools for notifications and audit trail
1473 for db_tool in tools_to_add:
1474 db.refresh(db_tool)
1475 # Notify subscribers (sync call in async context handled by caller)
1477 # Log bulk audit trail entry
1478 if tools_to_add or tools_to_update:
1479 audit_trail.log_action(
1480 user_id=created_by or "system",
1481 action="bulk_create_tools" if tools_to_add else "bulk_update_tools",
1482 resource_type="tool",
1483 resource_id=None,
1484 details={"count": len(tools_to_add) + len(tools_to_update), "import_batch_id": import_batch_id},
1485 db=db,
1486 )
1488 except Exception as e:
1489 db.rollback()
1490 logger.error(f"Failed to process tool chunk: {str(e)}")
1491 stats["failed"] += len(chunk)
1492 stats["errors"].append(f"Chunk processing failed: {str(e)}")
1494 return stats
1496 def _process_single_tool_for_bulk(
1497 self,
1498 tool: ToolCreate,
1499 existing_tools_map: dict,
1500 conflict_strategy: str,
1501 visibility: str,
1502 team_id: Optional[int],
1503 owner_email: Optional[str],
1504 created_by: str,
1505 created_from_ip: Optional[str],
1506 created_via: Optional[str],
1507 created_user_agent: Optional[str],
1508 import_batch_id: Optional[str],
1509 federation_source: Optional[str],
1510 ) -> dict:
1511 """Process a single tool for bulk import.
1513 Args:
1514 tool: ToolCreate object to process.
1515 existing_tools_map: Dictionary mapping tool names to existing DbTool objects.
1516 conflict_strategy: Strategy for handling conflicts ("skip", "update", or "fail").
1517 visibility: Tool visibility level ("public", "team", or "private").
1518 team_id: Team ID for team-scoped tools.
1519 owner_email: Email of the tool owner.
1520 created_by: Email of the user creating the tool.
1521 created_from_ip: IP address of the request origin.
1522 created_via: Source of the creation (e.g., "api", "ui").
1523 created_user_agent: User agent string from the request.
1524 import_batch_id: Batch identifier for bulk imports.
1525 federation_source: Source identifier for federated tools.
1527 Returns:
1528 dict: Result dictionary with "status" key ("add", "update", "skip", or "fail")
1529 and either "tool" (DbTool object) or "error" (error message).
1530 """
1531 try:
1532 # Extract auth information
1533 if tool.auth is None:
1534 auth_type = None
1535 auth_value = None
1536 else:
1537 auth_type = tool.auth.auth_type
1538 auth_value = tool.auth.auth_value
1540 # Use provided parameters or schema values
1541 tool_team_id = team_id if team_id is not None else getattr(tool, "team_id", None)
1542 tool_owner_email = owner_email or getattr(tool, "owner_email", None) or created_by
1543 tool_visibility = visibility if visibility is not None else getattr(tool, "visibility", "public")
1545 existing_tool = existing_tools_map.get(tool.name)
1547 if existing_tool:
1548 # Handle conflict based on strategy
1549 if conflict_strategy == "skip":
1550 return {"status": "skip"}
1551 if conflict_strategy == "update":
1552 # Update existing tool
1553 existing_tool.display_name = tool.displayName or tool.name
1554 existing_tool.url = str(tool.url)
1555 existing_tool.description = tool.description
1556 existing_tool.integration_type = tool.integration_type
1557 existing_tool.request_type = tool.request_type
1558 existing_tool.headers = tool.headers
1559 existing_tool.input_schema = tool.input_schema
1560 existing_tool.output_schema = tool.output_schema
1561 existing_tool.annotations = tool.annotations
1562 existing_tool.jsonpath_filter = tool.jsonpath_filter
1563 existing_tool.auth_type = auth_type
1564 existing_tool.auth_value = auth_value
1565 existing_tool.tags = tool.tags or []
1566 existing_tool.modified_by = created_by
1567 existing_tool.modified_from_ip = created_from_ip
1568 existing_tool.modified_via = created_via
1569 existing_tool.modified_user_agent = created_user_agent
1570 existing_tool.updated_at = datetime.now(timezone.utc)
1571 existing_tool.version = (existing_tool.version or 1) + 1
1573 # Update REST-specific fields if applicable
1574 if tool.integration_type == "REST":
1575 existing_tool.base_url = tool.base_url
1576 existing_tool.path_template = tool.path_template
1577 existing_tool.query_mapping = tool.query_mapping
1578 existing_tool.header_mapping = tool.header_mapping
1579 existing_tool.timeout_ms = tool.timeout_ms
1580 existing_tool.expose_passthrough = tool.expose_passthrough if tool.expose_passthrough is not None else True
1581 existing_tool.allowlist = tool.allowlist
1582 existing_tool.plugin_chain_pre = tool.plugin_chain_pre
1583 existing_tool.plugin_chain_post = tool.plugin_chain_post
1585 return {"status": "update", "tool": existing_tool}
1587 if conflict_strategy == "rename":
1588 # Create with renamed tool
1589 new_name = f"{tool.name}_imported_{int(datetime.now().timestamp())}"
1590 db_tool = self._create_tool_object(
1591 tool,
1592 new_name,
1593 auth_type,
1594 auth_value,
1595 tool_team_id,
1596 tool_owner_email,
1597 tool_visibility,
1598 created_by,
1599 created_from_ip,
1600 created_via,
1601 created_user_agent,
1602 import_batch_id,
1603 federation_source,
1604 )
1605 return {"status": "add", "tool": db_tool}
1607 if conflict_strategy == "fail":
1608 return {"status": "fail", "error": f"Tool name conflict: {tool.name}"}
1610 # Create new tool
1611 db_tool = self._create_tool_object(
1612 tool,
1613 tool.name,
1614 auth_type,
1615 auth_value,
1616 tool_team_id,
1617 tool_owner_email,
1618 tool_visibility,
1619 created_by,
1620 created_from_ip,
1621 created_via,
1622 created_user_agent,
1623 import_batch_id,
1624 federation_source,
1625 )
1626 return {"status": "add", "tool": db_tool}
1628 except Exception as e:
1629 logger.warning(f"Failed to process tool {tool.name} in bulk operation: {str(e)}")
1630 return {"status": "fail", "error": f"Failed to process tool {tool.name}: {str(e)}"}
1632 def _create_tool_object(
1633 self,
1634 tool: ToolCreate,
1635 name: str,
1636 auth_type: Optional[str],
1637 auth_value: Optional[str],
1638 tool_team_id: Optional[int],
1639 tool_owner_email: Optional[str],
1640 tool_visibility: str,
1641 created_by: str,
1642 created_from_ip: Optional[str],
1643 created_via: Optional[str],
1644 created_user_agent: Optional[str],
1645 import_batch_id: Optional[str],
1646 federation_source: Optional[str],
1647 ) -> DbTool:
1648 """Create a DbTool object from ToolCreate schema.
1650 Args:
1651 tool: ToolCreate schema object containing tool data.
1652 name: Name of the tool.
1653 auth_type: Authentication type for the tool.
1654 auth_value: Authentication value/credentials for the tool.
1655 tool_team_id: Team ID for team-scoped tools.
1656 tool_owner_email: Email of the tool owner.
1657 tool_visibility: Tool visibility level ("public", "team", or "private").
1658 created_by: Email of the user creating the tool.
1659 created_from_ip: IP address of the request origin.
1660 created_via: Source of the creation (e.g., "api", "ui").
1661 created_user_agent: User agent string from the request.
1662 import_batch_id: Batch identifier for bulk imports.
1663 federation_source: Source identifier for federated tools.
1665 Returns:
1666 DbTool: Database model instance ready to be added to the session.
1667 """
1668 return DbTool(
1669 original_name=name,
1670 custom_name=name,
1671 custom_name_slug=slugify(name),
1672 display_name=tool.displayName or name,
1673 url=str(tool.url),
1674 description=tool.description,
1675 integration_type=tool.integration_type,
1676 request_type=tool.request_type,
1677 headers=tool.headers,
1678 input_schema=tool.input_schema,
1679 output_schema=tool.output_schema,
1680 annotations=tool.annotations,
1681 jsonpath_filter=tool.jsonpath_filter,
1682 auth_type=auth_type,
1683 auth_value=auth_value,
1684 gateway_id=tool.gateway_id,
1685 tags=tool.tags or [],
1686 created_by=created_by,
1687 created_from_ip=created_from_ip,
1688 created_via=created_via,
1689 created_user_agent=created_user_agent,
1690 import_batch_id=import_batch_id,
1691 federation_source=federation_source,
1692 version=1,
1693 team_id=tool_team_id,
1694 owner_email=tool_owner_email,
1695 visibility=tool_visibility,
1696 base_url=tool.base_url if tool.integration_type == "REST" else None,
1697 path_template=tool.path_template if tool.integration_type == "REST" else None,
1698 query_mapping=tool.query_mapping if tool.integration_type == "REST" else None,
1699 header_mapping=tool.header_mapping if tool.integration_type == "REST" else None,
1700 timeout_ms=tool.timeout_ms if tool.integration_type == "REST" else None,
1701 expose_passthrough=((tool.expose_passthrough if tool.integration_type == "REST" and tool.expose_passthrough is not None else True) if tool.integration_type == "REST" else None),
1702 allowlist=tool.allowlist if tool.integration_type == "REST" else None,
1703 plugin_chain_pre=tool.plugin_chain_pre if tool.integration_type == "REST" else None,
1704 plugin_chain_post=tool.plugin_chain_post if tool.integration_type == "REST" else None,
1705 )
    async def list_tools(
        self,
        db: Session,
        include_inactive: bool = False,
        cursor: Optional[str] = None,
        tags: Optional[List[str]] = None,
        gateway_id: Optional[str] = None,
        limit: Optional[int] = None,
        page: Optional[int] = None,
        per_page: Optional[int] = None,
        user_email: Optional[str] = None,
        team_id: Optional[str] = None,
        visibility: Optional[str] = None,
        token_teams: Optional[List[str]] = None,
        _request_headers: Optional[Dict[str, str]] = None,
        requesting_user_email: Optional[str] = None,
        requesting_user_is_admin: bool = False,
        requesting_user_team_roles: Optional[Dict[str, str]] = None,
    ) -> Union[tuple[List[ToolRead], Optional[str]], Dict[str, Any]]:
        """
        Retrieve a list of registered tools from the database with pagination support.

        Args:
            db (Session): The SQLAlchemy database session.
            include_inactive (bool): If True, include inactive tools in the result.
                Defaults to False.
            cursor (Optional[str], optional): An opaque cursor token for pagination.
                Opaque base64-encoded string containing last item's ID.
            tags (Optional[List[str]]): Filter tools by tags. If provided, only tools with at least one matching tag will be returned.
            gateway_id (Optional[str]): Filter tools by gateway ID. Accepts the literal value 'null' to match NULL gateway_id.
            limit (Optional[int]): Maximum number of tools to return. Use 0 for all tools (no limit).
                If not specified, uses pagination_default_page_size.
            page: Page number for page-based pagination (1-indexed). Mutually exclusive with cursor.
            per_page: Items per page for page-based pagination. Defaults to pagination_default_page_size.
            user_email (Optional[str]): User email for team-based access control. If None, no access control is applied.
            team_id (Optional[str]): Filter by specific team ID. Requires user_email for access validation.
            visibility (Optional[str]): Filter by visibility (private, team, public).
            token_teams (Optional[List[str]]): Override DB team lookup with token's teams. Used for MCP/API token access
                where the token scope should be respected instead of the user's full team memberships.
            _request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through.
                Currently unused but kept for API consistency. Defaults to None.
            requesting_user_email (Optional[str]): Email of the requesting user for header masking.
            requesting_user_is_admin (bool): Whether the requester is an admin.
            requesting_user_team_roles (Optional[Dict[str, str]]): {team_id: role} for the requester.

        Returns:
            Union[tuple[List[ToolRead], Optional[str]], Dict[str, Any]]:
                When ``page`` is None (cursor mode), a tuple containing:
                - List of tools for current page
                - Next cursor token if more results exist, None otherwise
                When ``page`` is provided, a dict with ``data``, ``pagination`` and ``links`` keys.

        Examples:
            >>> from mcpgateway.services.tool_service import ToolService
            >>> from unittest.mock import MagicMock
            >>> service = ToolService()
            >>> db = MagicMock()
            >>> tool_read = MagicMock()
            >>> service.convert_tool_to_read = MagicMock(return_value=tool_read)
            >>> db.execute.return_value.scalars.return_value.all.return_value = [MagicMock()]
            >>> import asyncio
            >>> tools, next_cursor = asyncio.run(service.list_tools(db))
            >>> isinstance(tools, list)
            True
        """
        # Check cache for first page only (cursor=None)
        # Skip caching when:
        # - user_email is provided (team-filtered results are user-specific)
        # - token_teams is set (scoped access, e.g., public-only or team-scoped tokens)
        # - page-based pagination is used
        # This prevents cache poisoning where admin results could leak to public-only requests
        cache = _get_registry_cache()
        filters_hash = None
        # Only use the cache when using the real converter. In unit tests we often patch
        # convert_tool_to_read() to exercise error handling, and a warm cache would bypass it.
        try:
            converter_is_default = self.convert_tool_to_read.__func__ is ToolService.convert_tool_to_read  # type: ignore[attr-defined]
        except Exception:
            converter_is_default = False

        if cursor is None and user_email is None and token_teams is None and page is None and converter_is_default:
            # tags are sorted so that logically-equal filter sets hash identically
            filters_hash = cache.hash_filters(include_inactive=include_inactive, tags=sorted(tags) if tags else None, gateway_id=gateway_id, limit=limit)
            cached = await cache.get("tools", filters_hash)
            if cached is not None:
                # Reconstruct ToolRead objects from cached dicts
                cached_tools = [ToolRead.model_validate(t) for t in cached["tools"]]
                return (cached_tools, cached.get("next_cursor"))

        # Build base query with ordering and eager load gateway + email_team to avoid N+1
        query = select(DbTool).options(joinedload(DbTool.gateway), joinedload(DbTool.email_team)).order_by(desc(DbTool.created_at), desc(DbTool.id))

        # Apply active/inactive filter
        if not include_inactive:
            query = query.where(DbTool.enabled)
        # Apply team-based access control if user_email is provided OR token_teams is explicitly set
        # This ensures unauthenticated requests with token_teams=[] only see public tools
        if user_email is not None or token_teams is not None:  # empty-string user_email -> public-only filtering (secure default)
            # Use token_teams if provided (for MCP/API token access), otherwise look up from DB
            if token_teams is not None:
                team_ids = token_teams
            elif user_email:
                team_service = TeamManagementService(db)
                user_teams = await team_service.get_user_teams(user_email)
                team_ids = [team.id for team in user_teams]
            else:
                team_ids = []

            # Check if this is a public-only token (empty teams array)
            # Public-only tokens can ONLY see public resources - no owner access
            is_public_only_token = token_teams is not None and len(token_teams) == 0

            if team_id:
                # User requesting specific team - verify access
                if team_id not in team_ids:
                    return ([], None)
                access_conditions = [
                    and_(DbTool.team_id == team_id, DbTool.visibility.in_(["team", "public"])),
                ]
                # Only include owner access for non-public-only tokens
                if not is_public_only_token and user_email:
                    access_conditions.append(and_(DbTool.team_id == team_id, DbTool.owner_email == user_email))
                query = query.where(or_(*access_conditions))
            else:
                # General access: public tools + team tools (+ owner tools if not public-only token)
                access_conditions = [
                    DbTool.visibility == "public",
                ]
                # Only include owner access for non-public-only tokens with user_email
                if not is_public_only_token and user_email:
                    access_conditions.append(DbTool.owner_email == user_email)
                if team_ids:
                    access_conditions.append(and_(DbTool.team_id.in_(team_ids), DbTool.visibility.in_(["team", "public"])))
                query = query.where(or_(*access_conditions))

        if visibility:
            query = query.where(DbTool.visibility == visibility)

        # Add gateway_id filtering if provided
        if gateway_id:
            if gateway_id.lower() == "null":
                query = query.where(DbTool.gateway_id.is_(None))
            else:
                query = query.where(DbTool.gateway_id == gateway_id)

        # Add tag filtering if tags are provided (supports both List[str] and List[Dict] formats)
        if tags:
            query = query.where(json_contains_tag_expr(db, DbTool.tags, tags, match_any=True))

        # Use unified pagination helper - handles both page and cursor pagination
        pag_result = await unified_paginate(
            db=db,
            query=query,
            page=page,
            per_page=per_page,
            cursor=cursor,
            limit=limit,
            base_url="/admin/tools",  # Used for page-based links
            query_params={"include_inactive": include_inactive} if include_inactive else {},
        )

        next_cursor = None
        # Extract servers based on pagination type
        if page is not None:
            # Page-based: pag_result is a dict
            tools_db = pag_result["data"]
        else:
            # Cursor-based: pag_result is a tuple
            tools_db, next_cursor = pag_result

        db.commit()  # Release transaction to avoid idle-in-transaction

        # Convert to ToolRead (common for both pagination types)
        # Team names are loaded via joinedload(DbTool.email_team)
        result = []
        for s in tools_db:
            try:
                result.append(
                    self.convert_tool_to_read(
                        s,
                        include_metrics=False,
                        include_auth=False,
                        requesting_user_email=requesting_user_email,
                        requesting_user_is_admin=requesting_user_is_admin,
                        requesting_user_team_roles=requesting_user_team_roles,
                    )
                )
            except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
                logger.exception(f"Failed to convert tool {getattr(s, 'id', 'unknown')} ({getattr(s, 'name', 'unknown')}): {e}")
                # Continue with remaining tools instead of failing completely

        # Return appropriate format based on pagination type
        if page is not None:
            # Page-based format
            return {
                "data": result,
                "pagination": pag_result["pagination"],
                "links": pag_result["links"],
            }

        # Cursor-based format

        # Cache first page results - only for non-user-specific/non-scoped queries
        # Must match the same conditions as cache lookup to prevent cache poisoning
        if filters_hash is not None and cursor is None and user_email is None and token_teams is None and page is None and converter_is_default:
            try:
                cache_data = {"tools": [s.model_dump(mode="json") for s in result], "next_cursor": next_cursor}
                await cache.set("tools", cache_data, filters_hash)
            except AttributeError:
                pass  # Skip caching if result objects don't support model_dump (e.g., in doctests)

        return (result, next_cursor)
1917 async def list_server_tools(
1918 self,
1919 db: Session,
1920 server_id: str,
1921 include_inactive: bool = False,
1922 include_metrics: bool = False,
1923 cursor: Optional[str] = None,
1924 user_email: Optional[str] = None,
1925 token_teams: Optional[List[str]] = None,
1926 _request_headers: Optional[Dict[str, str]] = None,
1927 requesting_user_email: Optional[str] = None,
1928 requesting_user_is_admin: bool = False,
1929 requesting_user_team_roles: Optional[Dict[str, str]] = None,
1930 ) -> List[ToolRead]:
1931 """
1932 Retrieve a list of registered tools from the database.
1934 Args:
1935 db (Session): The SQLAlchemy database session.
1936 server_id (str): Server ID
1937 include_inactive (bool): If True, include inactive tools in the result.
1938 Defaults to False.
1939 include_metrics (bool): If True, all tool metrics included in result otherwise null.
1940 Defaults to False.
1941 cursor (Optional[str], optional): An opaque cursor token for pagination. Currently,
1942 this parameter is ignored. Defaults to None.
1943 user_email (Optional[str]): User email for visibility filtering. If None, no filtering applied.
1944 token_teams (Optional[List[str]]): Override DB team lookup with token's teams. Used for MCP/API
1945 token access where the token scope should be respected.
1946 _request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through.
1947 Currently unused but kept for API consistency. Defaults to None.
1948 requesting_user_email (Optional[str]): Email of the requesting user for header masking.
1949 requesting_user_is_admin (bool): Whether the requester is an admin.
1950 requesting_user_team_roles (Optional[Dict[str, str]]): {team_id: role} for the requester.
1952 Returns:
1953 List[ToolRead]: A list of registered tools represented as ToolRead objects.
1955 Examples:
1956 >>> from mcpgateway.services.tool_service import ToolService
1957 >>> from unittest.mock import MagicMock
1958 >>> service = ToolService()
1959 >>> db = MagicMock()
1960 >>> tool_read = MagicMock()
1961 >>> service.convert_tool_to_read = MagicMock(return_value=tool_read)
1962 >>> db.execute.return_value.scalars.return_value.all.return_value = [MagicMock()]
1963 >>> import asyncio
1964 >>> result = asyncio.run(service.list_server_tools(db, 'server1'))
1965 >>> isinstance(result, list)
1966 True
1967 """
1969 if include_metrics:
1970 query = (
1971 select(DbTool)
1972 .options(joinedload(DbTool.gateway), joinedload(DbTool.email_team))
1973 .options(selectinload(DbTool.metrics))
1974 .join(server_tool_association, DbTool.id == server_tool_association.c.tool_id)
1975 .where(server_tool_association.c.server_id == server_id)
1976 )
1977 else:
1978 query = (
1979 select(DbTool)
1980 .options(joinedload(DbTool.gateway), joinedload(DbTool.email_team))
1981 .join(server_tool_association, DbTool.id == server_tool_association.c.tool_id)
1982 .where(server_tool_association.c.server_id == server_id)
1983 )
1985 cursor = None # Placeholder for pagination; ignore for now
1986 logger.debug(f"Listing server tools for server_id={server_id} with include_inactive={include_inactive}, cursor={cursor}")
1988 if not include_inactive:
1989 query = query.where(DbTool.enabled)
1991 # Add visibility filtering if user context OR token_teams provided
1992 # This ensures unauthenticated requests with token_teams=[] only see public tools
1993 if user_email is not None or token_teams is not None: # empty-string user_email -> public-only filtering (secure default)
1994 # Use token_teams if provided (for MCP/API token access), otherwise look up from DB
1995 if token_teams is not None:
1996 team_ids = token_teams
1997 elif user_email:
1998 team_service = TeamManagementService(db)
1999 user_teams = await team_service.get_user_teams(user_email)
2000 team_ids = [team.id for team in user_teams]
2001 else:
2002 team_ids = []
2004 # Check if this is a public-only token (empty teams array)
2005 # Public-only tokens can ONLY see public resources - no owner access
2006 is_public_only_token = token_teams is not None and len(token_teams) == 0
2008 access_conditions = [
2009 DbTool.visibility == "public",
2010 ]
2011 # Only include owner access for non-public-only tokens with user_email
2012 if not is_public_only_token and user_email:
2013 access_conditions.append(DbTool.owner_email == user_email)
2014 if team_ids:
2015 access_conditions.append(and_(DbTool.team_id.in_(team_ids), DbTool.visibility.in_(["team", "public"])))
2016 query = query.where(or_(*access_conditions))
2018 # Execute the query - team names are loaded via joinedload(DbTool.email_team)
2019 tools = db.execute(query).scalars().all()
2021 db.commit() # Release transaction to avoid idle-in-transaction
2023 result = []
2024 for tool in tools:
2025 try:
2026 result.append(
2027 self.convert_tool_to_read(
2028 tool,
2029 include_metrics=include_metrics,
2030 include_auth=False,
2031 requesting_user_email=requesting_user_email,
2032 requesting_user_is_admin=requesting_user_is_admin,
2033 requesting_user_team_roles=requesting_user_team_roles,
2034 )
2035 )
2036 except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
2037 logger.exception(f"Failed to convert tool {getattr(tool, 'id', 'unknown')} ({getattr(tool, 'name', 'unknown')}): {e}")
2038 # Continue with remaining tools instead of failing completely
2040 return result
2042 async def list_tools_for_user(
2043 self,
2044 db: Session,
2045 user_email: str,
2046 team_id: Optional[str] = None,
2047 visibility: Optional[str] = None,
2048 include_inactive: bool = False,
2049 _skip: int = 0,
2050 _limit: int = 100,
2051 *,
2052 cursor: Optional[str] = None,
2053 gateway_id: Optional[str] = None,
2054 tags: Optional[List[str]] = None,
2055 limit: Optional[int] = None,
2056 ) -> tuple[List[ToolRead], Optional[str]]:
2057 """
2058 DEPRECATED: Use list_tools() with user_email parameter instead.
2060 List tools user has access to with team filtering and cursor pagination.
2062 This method is maintained for backward compatibility but is no longer used.
2063 New code should call list_tools() with user_email, team_id, and visibility parameters.
2065 Args:
2066 db: Database session
2067 user_email: Email of the user requesting tools
2068 team_id: Optional team ID to filter by specific team
2069 visibility: Optional visibility filter (private, team, public)
2070 include_inactive: Whether to include inactive tools
2071 _skip: Number of tools to skip for pagination (deprecated)
2072 _limit: Maximum number of tools to return (deprecated)
2073 cursor: Opaque cursor token for pagination
2074 gateway_id: Filter tools by gateway ID. Accepts literal 'null' for NULL gateway_id.
2075 tags: Filter tools by tags (match any)
2076 limit: Maximum number of tools to return. Use 0 for all tools (no limit).
2077 If not specified, uses pagination_default_page_size.
2079 Returns:
2080 tuple[List[ToolRead], Optional[str]]: Tools the user has access to and optional next_cursor
2081 """
2082 # Determine page size based on limit parameter
2083 # limit=None: use default, limit=0: no limit (all), limit>0: use specified (capped)
2084 if limit is None:
2085 page_size = settings.pagination_default_page_size
2086 elif limit == 0:
2087 page_size = None # No limit - fetch all
2088 else:
2089 page_size = min(limit, settings.pagination_max_page_size)
2091 # Decode cursor to get last_id if provided
2092 last_id = None
2093 if cursor:
2094 try:
2095 cursor_data = decode_cursor(cursor)
2096 last_id = cursor_data.get("id")
2097 logger.debug(f"Decoded cursor: last_id={last_id}")
2098 except ValueError as e:
2099 logger.warning(f"Invalid cursor, ignoring: {e}")
2101 # Build query following existing patterns from list_tools()
2102 team_service = TeamManagementService(db)
2103 user_teams = await team_service.get_user_teams(user_email)
2104 team_ids = [team.id for team in user_teams]
2106 # Eager load gateway and email_team to avoid N+1 when accessing gateway_slug and team name
2107 query = select(DbTool).options(joinedload(DbTool.gateway), joinedload(DbTool.email_team))
2109 # Apply active/inactive filter
2110 if not include_inactive:
2111 query = query.where(DbTool.enabled.is_(True))
2113 if team_id:
2114 if team_id not in team_ids:
2115 return ([], None) # No access to team
2117 access_conditions = [
2118 and_(DbTool.team_id == team_id, DbTool.visibility.in_(["team", "public"])),
2119 and_(DbTool.team_id == team_id, DbTool.owner_email == user_email),
2120 ]
2121 query = query.where(or_(*access_conditions))
2122 else:
2123 access_conditions = [
2124 DbTool.owner_email == user_email,
2125 DbTool.visibility == "public",
2126 ]
2127 if team_ids:
2128 access_conditions.append(and_(DbTool.team_id.in_(team_ids), DbTool.visibility.in_(["team", "public"])))
2130 query = query.where(or_(*access_conditions))
2132 # Apply visibility filter if specified
2133 if visibility:
2134 query = query.where(DbTool.visibility == visibility)
2136 if gateway_id:
2137 if gateway_id.lower() == "null":
2138 query = query.where(DbTool.gateway_id.is_(None))
2139 else:
2140 query = query.where(DbTool.gateway_id == gateway_id)
2142 if tags:
2143 query = query.where(json_contains_tag_expr(db, DbTool.tags, tags, match_any=True))
2145 # Apply cursor filter (WHERE id > last_id)
2146 if last_id:
2147 query = query.where(DbTool.id > last_id)
2149 # Execute query - team names are loaded via joinedload(DbTool.email_team)
2150 if page_size is not None:
2151 tools = db.execute(query.limit(page_size + 1)).scalars().all()
2152 else:
2153 tools = db.execute(query).scalars().all()
2155 db.commit() # Release transaction to avoid idle-in-transaction
2157 # Check if there are more results (only when paginating)
2158 has_more = page_size is not None and len(tools) > page_size
2159 if has_more:
2160 tools = tools[:page_size]
2162 # Convert to ToolRead objects
2163 result = []
2164 for tool in tools:
2165 try:
2166 result.append(self.convert_tool_to_read(tool, include_metrics=False, include_auth=False, requesting_user_email=user_email, requesting_user_is_admin=False))
2167 except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
2168 logger.exception(f"Failed to convert tool {getattr(tool, 'id', 'unknown')} ({getattr(tool, 'name', 'unknown')}): {e}")
2169 # Continue with remaining tools instead of failing completely
2171 next_cursor = None
2172 # Generate cursor if there are more results (cursor-based pagination)
2173 if has_more and tools:
2174 last_tool = tools[-1]
2175 next_cursor = encode_cursor({"created_at": last_tool.created_at.isoformat(), "id": last_tool.id})
2177 return (result, next_cursor)
2179 async def get_tool(
2180 self,
2181 db: Session,
2182 tool_id: str,
2183 requesting_user_email: Optional[str] = None,
2184 requesting_user_is_admin: bool = False,
2185 requesting_user_team_roles: Optional[Dict[str, str]] = None,
2186 ) -> ToolRead:
2187 """
2188 Retrieve a tool by its ID.
2190 Args:
2191 db (Session): The SQLAlchemy database session.
2192 tool_id (str): The unique identifier of the tool.
2193 requesting_user_email (Optional[str]): Email of the requesting user for header masking.
2194 requesting_user_is_admin (bool): Whether the requester is an admin.
2195 requesting_user_team_roles (Optional[Dict[str, str]]): {team_id: role} for the requester.
2197 Returns:
2198 ToolRead: The tool object.
2200 Raises:
2201 ToolNotFoundError: If the tool is not found.
2203 Examples:
2204 >>> from mcpgateway.services.tool_service import ToolService
2205 >>> from unittest.mock import MagicMock
2206 >>> service = ToolService()
2207 >>> db = MagicMock()
2208 >>> tool = MagicMock()
2209 >>> db.get.return_value = tool
2210 >>> service.convert_tool_to_read = MagicMock(return_value='tool_read')
2211 >>> import asyncio
2212 >>> asyncio.run(service.get_tool(db, 'tool_id'))
2213 'tool_read'
2214 """
2215 tool = db.get(DbTool, tool_id)
2216 if not tool:
2217 raise ToolNotFoundError(f"Tool not found: {tool_id}")
2219 tool_read = self.convert_tool_to_read(
2220 tool,
2221 requesting_user_email=requesting_user_email,
2222 requesting_user_is_admin=requesting_user_is_admin,
2223 requesting_user_team_roles=requesting_user_team_roles,
2224 )
2226 structured_logger.log(
2227 level="INFO",
2228 message="Tool retrieved successfully",
2229 event_type="tool_viewed",
2230 component="tool_service",
2231 team_id=getattr(tool, "team_id", None),
2232 resource_type="tool",
2233 resource_id=str(tool.id),
2234 custom_fields={
2235 "tool_name": tool.name,
2236 "include_metrics": bool(getattr(tool_read, "metrics", {})),
2237 },
2238 db=db,
2239 )
2241 return tool_read
2243 async def delete_tool(self, db: Session, tool_id: str, user_email: Optional[str] = None, purge_metrics: bool = False) -> None:
2244 """
2245 Delete a tool by its ID.
2247 Args:
2248 db (Session): The SQLAlchemy database session.
2249 tool_id (str): The unique identifier of the tool.
2250 user_email (Optional[str]): Email of user performing delete (for ownership check).
2251 purge_metrics (bool): If True, delete raw + rollup metrics for this tool.
2253 Raises:
2254 ToolNotFoundError: If the tool is not found.
2255 PermissionError: If user doesn't own the tool.
2256 ToolError: For other deletion errors.
2258 Examples:
2259 >>> from mcpgateway.services.tool_service import ToolService
2260 >>> from unittest.mock import MagicMock, AsyncMock
2261 >>> service = ToolService()
2262 >>> db = MagicMock()
2263 >>> tool = MagicMock()
2264 >>> db.get.return_value = tool
2265 >>> db.delete = MagicMock()
2266 >>> db.commit = MagicMock()
2267 >>> service._notify_tool_deleted = AsyncMock()
2268 >>> import asyncio
2269 >>> asyncio.run(service.delete_tool(db, 'tool_id'))
2270 """
2271 try:
2272 tool = db.get(DbTool, tool_id)
2273 if not tool:
2274 raise ToolNotFoundError(f"Tool not found: {tool_id}")
2276 # Check ownership if user_email provided
2277 if user_email:
2278 # First-Party
2279 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel
2281 permission_service = PermissionService(db)
2282 if not await permission_service.check_resource_ownership(user_email, tool):
2283 raise PermissionError("Only the owner can delete this tool")
2285 tool_info = {"id": tool.id, "name": tool.name}
2286 tool_name = tool.name
2287 tool_team_id = tool.team_id
2289 if purge_metrics:
2290 with pause_rollup_during_purge(reason=f"purge_tool:{tool_id}"):
2291 delete_metrics_in_batches(db, ToolMetric, ToolMetric.tool_id, tool_id)
2292 delete_metrics_in_batches(db, ToolMetricsHourly, ToolMetricsHourly.tool_id, tool_id)
2294 # Use DELETE with rowcount check for database-agnostic atomic delete
2295 # (RETURNING is not supported on MySQL/MariaDB)
2296 stmt = delete(DbTool).where(DbTool.id == tool_id)
2297 result = db.execute(stmt)
2298 if result.rowcount == 0:
2299 # Tool was already deleted by another concurrent request
2300 raise ToolNotFoundError(f"Tool not found: {tool_id}")
2302 db.commit()
2303 await self._notify_tool_deleted(tool_info)
2304 logger.info(f"Permanently deleted tool: {tool_info['name']}")
2306 # Structured logging: Audit trail for tool deletion
2307 audit_trail.log_action(
2308 user_id=user_email or "system",
2309 action="delete_tool",
2310 resource_type="tool",
2311 resource_id=tool_info["id"],
2312 resource_name=tool_name,
2313 user_email=user_email,
2314 team_id=tool_team_id,
2315 old_values={
2316 "name": tool_name,
2317 },
2318 db=db,
2319 )
2321 # Structured logging: Log successful tool deletion
2322 structured_logger.log(
2323 level="INFO",
2324 message="Tool deleted successfully",
2325 event_type="tool_deleted",
2326 component="tool_service",
2327 user_email=user_email,
2328 team_id=tool_team_id,
2329 resource_type="tool",
2330 resource_id=tool_info["id"],
2331 custom_fields={
2332 "tool_name": tool_name,
2333 "purge_metrics": purge_metrics,
2334 },
2335 db=db,
2336 )
2338 # Invalidate cache after successful deletion
2339 cache = _get_registry_cache()
2340 await cache.invalidate_tools()
2341 tool_lookup_cache = _get_tool_lookup_cache()
2342 await tool_lookup_cache.invalidate(tool_name, gateway_id=str(tool.gateway_id) if tool.gateway_id else None)
2343 # Also invalidate tags cache since tool tags may have changed
2344 # First-Party
2345 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel
2347 await admin_stats_cache.invalidate_tags()
2348 # Invalidate top performers cache
2349 # First-Party
2350 from mcpgateway.cache.metrics_cache import metrics_cache # pylint: disable=import-outside-toplevel
2352 metrics_cache.invalidate_prefix("top_tools:")
2353 metrics_cache.invalidate("tools")
2354 except PermissionError as pe:
2355 db.rollback()
2357 # Structured logging: Log permission error
2358 structured_logger.log(
2359 level="WARNING",
2360 message="Tool deletion failed due to permission error",
2361 event_type="tool_delete_permission_denied",
2362 component="tool_service",
2363 user_email=user_email,
2364 resource_type="tool",
2365 resource_id=tool_id,
2366 error=pe,
2367 db=db,
2368 )
2369 raise
2370 except Exception as e:
2371 db.rollback()
2373 # Structured logging: Log generic tool deletion failure
2374 structured_logger.log(
2375 level="ERROR",
2376 message="Tool deletion failed",
2377 event_type="tool_deletion_failed",
2378 component="tool_service",
2379 user_email=user_email,
2380 resource_type="tool",
2381 resource_id=tool_id,
2382 error=e,
2383 db=db,
2384 )
2385 raise ToolError(f"Failed to delete tool: {str(e)}")
    async def set_tool_state(self, db: Session, tool_id: str, activate: bool, reachable: bool, user_email: Optional[str] = None, skip_cache_invalidation: bool = False) -> ToolRead:
        """
        Set the activation status of a tool.

        Args:
            db (Session): The SQLAlchemy database session.
            tool_id (str): The unique identifier of the tool.
            activate (bool): True to activate, False to deactivate.
            reachable (bool): True if the tool is reachable.
            user_email (Optional[str]): The email of the user to check if the user has permission to modify.
            skip_cache_invalidation (bool): If True, skip cache invalidation (used for batch operations).

        Returns:
            ToolRead: The updated tool object.

        Raises:
            ToolNotFoundError: If the tool is not found.
            ToolLockConflictError: If the tool row is locked by another transaction.
            ToolError: For other errors.
            PermissionError: If user doesn't own the tool.

        Examples:
            >>> from mcpgateway.services.tool_service import ToolService
            >>> from unittest.mock import MagicMock, AsyncMock
            >>> from mcpgateway.schemas import ToolRead
            >>> service = ToolService()
            >>> db = MagicMock()
            >>> tool = MagicMock()
            >>> db.get.return_value = tool
            >>> db.commit = MagicMock()
            >>> db.refresh = MagicMock()
            >>> service._notify_tool_activated = AsyncMock()
            >>> service._notify_tool_deactivated = AsyncMock()
            >>> service.convert_tool_to_read = MagicMock(return_value='tool_read')
            >>> ToolRead.model_validate = MagicMock(return_value='tool_read')
            >>> import asyncio
            >>> asyncio.run(service.set_tool_state(db, 'tool_id', True, True))
            'tool_read'
        """
        try:
            # Use nowait=True to fail fast if row is locked, preventing lock contention under high load
            try:
                tool = get_for_update(db, DbTool, tool_id, nowait=True)
            except OperationalError as lock_err:
                # Row is locked by another transaction - fail fast with 409
                db.rollback()
                raise ToolLockConflictError(f"Tool {tool_id} is currently being modified by another request") from lock_err
            if not tool:
                raise ToolNotFoundError(f"Tool not found: {tool_id}")

            if user_email:
                # First-Party
                from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

                permission_service = PermissionService(db)
                if not await permission_service.check_resource_ownership(user_email, tool):
                    raise PermissionError("Only the owner can activate the Tool" if activate else "Only the owner can deactivate the Tool")

            # Track which fields actually changed so updated_at is only bumped on a real transition
            is_activated = is_reachable = False
            if tool.enabled != activate:
                tool.enabled = activate
                is_activated = True

            if tool.reachable != reachable:
                tool.reachable = reachable
                is_reachable = True

            if is_activated or is_reachable:
                tool.updated_at = datetime.now(timezone.utc)

            # Commit even when nothing changed: releases the row lock taken by get_for_update
            db.commit()
            db.refresh(tool)

            # Invalidate cache after status change (skip for batch operations)
            if not skip_cache_invalidation:
                cache = _get_registry_cache()
                await cache.invalidate_tools()
                tool_lookup_cache = _get_tool_lookup_cache()
                await tool_lookup_cache.invalidate(tool.name, gateway_id=str(tool.gateway_id) if tool.gateway_id else None)

            # Notify subscribers of the resulting state: inactive, online-but-unreachable, or active
            if not tool.enabled:
                # Inactive
                await self._notify_tool_deactivated(tool)
            elif tool.enabled and not tool.reachable:
                # Offline
                await self._notify_tool_offline(tool)
            else:
                # Active
                await self._notify_tool_activated(tool)

            logger.info(f"Tool: {tool.name} is {'enabled' if activate else 'disabled'}{' and accessible' if reachable else ' but inaccessible'}")

            # Structured logging: Audit trail for tool state change
            audit_trail.log_action(
                user_id=user_email or "system",
                action="set_tool_state",
                resource_type="tool",
                resource_id=tool.id,
                resource_name=tool.name,
                user_email=user_email,
                team_id=tool.team_id,
                new_values={
                    "enabled": tool.enabled,
                    "reachable": tool.reachable,
                },
                context={
                    "action": "activate" if activate else "deactivate",
                },
                db=db,
            )

            # Structured logging: Log successful tool state change
            structured_logger.log(
                level="INFO",
                message=f"Tool {'activated' if activate else 'deactivated'} successfully",
                event_type="tool_state_changed",
                component="tool_service",
                user_email=user_email,
                team_id=tool.team_id,
                resource_type="tool",
                resource_id=tool.id,
                custom_fields={
                    "tool_name": tool.name,
                    "enabled": tool.enabled,
                    "reachable": tool.reachable,
                },
                db=db,
            )

            return self.convert_tool_to_read(tool, requesting_user_email=getattr(tool, "owner_email", None))
        except PermissionError as e:
            # Structured logging: Log permission error
            structured_logger.log(
                level="WARNING",
                message="Tool state change failed due to permission error",
                event_type="tool_state_change_permission_denied",
                component="tool_service",
                user_email=user_email,
                resource_type="tool",
                resource_id=tool_id,
                error=e,
                db=db,
            )
            raise e
        except ToolLockConflictError:
            # Re-raise lock conflicts without wrapping - allows 409 response
            raise
        except ToolNotFoundError:
            # Re-raise not found without wrapping - allows 404 response
            raise
        except Exception as e:
            db.rollback()

            # Structured logging: Log generic tool state change failure
            structured_logger.log(
                level="ERROR",
                message="Tool state change failed",
                event_type="tool_state_change_failed",
                component="tool_service",
                user_email=user_email,
                resource_type="tool",
                resource_id=tool_id,
                error=e,
                db=db,
            )
            raise ToolError(f"Failed to set tool state: {str(e)}")
2554 async def invoke_tool(
2555 self,
2556 db: Session,
2557 name: str,
2558 arguments: Dict[str, Any],
2559 request_headers: Optional[Dict[str, str]] = None,
2560 app_user_email: Optional[str] = None,
2561 user_email: Optional[str] = None,
2562 token_teams: Optional[List[str]] = None,
2563 server_id: Optional[str] = None,
2564 plugin_context_table: Optional[PluginContextTable] = None,
2565 plugin_global_context: Optional[GlobalContext] = None,
2566 meta_data: Optional[Dict[str, Any]] = None,
2567 ) -> ToolResult:
2568 """
2569 Invoke a registered tool and record execution metrics.
2571 Args:
2572 db: Database session.
2573 name: Name of tool to invoke.
2574 arguments: Tool arguments.
2575 request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through.
2576 Defaults to None.
2577 app_user_email (Optional[str], optional): MCP Gateway user email for OAuth token retrieval.
2578 Required for OAuth-protected gateways.
2579 user_email (Optional[str], optional): User email for authorization checks.
2580 None = unauthenticated request.
2581 token_teams (Optional[List[str]], optional): Team IDs from JWT token for authorization.
2582 None = unrestricted admin, [] = public-only, [...] = team-scoped.
2583 server_id (Optional[str], optional): Virtual server ID for server scoping enforcement.
2584 If provided, tool must be attached to this server.
2585 plugin_context_table: Optional plugin context table from previous hooks for cross-hook state sharing.
2586 plugin_global_context: Optional global context from middleware for consistency across hooks.
2587 meta_data: Optional metadata dictionary for additional context (e.g., request ID).
2589 Returns:
2590 Tool invocation result.
2592 Raises:
2593 ToolNotFoundError: If tool not found or access denied.
2594 ToolInvocationError: If invocation fails.
2595 ToolTimeoutError: If tool invocation times out.
2596 PluginViolationError: If plugin blocks tool invocation.
2597 PluginError: If encounters issue with plugin
2599 Examples:
2600 >>> # Note: This method requires extensive mocking of SQLAlchemy models,
2601 >>> # database relationships, and caching infrastructure, which is not
2602 >>> # suitable for doctests. See tests/unit/mcpgateway/services/test_tool_service.py
2603 >>> pass # doctest: +SKIP
2604 """
2605 # pylint: disable=comparison-with-callable
2606 logger.info(f"Invoking tool: {name} with arguments: {arguments.keys() if arguments else None} and headers: {request_headers.keys() if request_headers else None}")
2608 # ═══════════════════════════════════════════════════════════════════════════
2609 # PHASE 1: Fetch all required data with eager loading to minimize DB queries
2610 # ═══════════════════════════════════════════════════════════════════════════
2611 tool = None
2612 gateway = None
2613 tool_payload: Dict[str, Any] = {}
2614 gateway_payload: Optional[Dict[str, Any]] = None
2616 tool_lookup_cache = _get_tool_lookup_cache()
2617 cached_payload = await tool_lookup_cache.get(name) if tool_lookup_cache.enabled else None
2618 if cached_payload:
2619 status = cached_payload.get("status", "active")
2620 if status == "missing":
2621 raise ToolNotFoundError(f"Tool not found: {name}")
2622 if status == "inactive":
2623 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive")
2624 if status == "offline":
2625 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.")
2626 tool_payload = cached_payload.get("tool") or {}
2627 gateway_payload = cached_payload.get("gateway")
2629 if not tool_payload:
2630 # Eager load tool WITH gateway in single query to prevent lazy load N+1
2631 # Use a single query to avoid a race between separate enabled/inactive lookups.
2632 tool = db.execute(select(DbTool).options(joinedload(DbTool.gateway)).where(DbTool.name == name)).scalar_one_or_none()
2633 if not tool:
2634 raise ToolNotFoundError(f"Tool not found: {name}")
2635 if not tool.enabled:
2636 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive")
2638 if not tool.reachable:
2639 await tool_lookup_cache.set_negative(name, "offline")
2640 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.")
2642 gateway = tool.gateway
2643 cache_payload = self._build_tool_cache_payload(tool, gateway)
2644 tool_payload = cache_payload.get("tool") or {}
2645 gateway_payload = cache_payload.get("gateway")
2646 await tool_lookup_cache.set(name, cache_payload, gateway_id=tool_payload.get("gateway_id"))
2648 if tool_payload.get("enabled") is False:
2649 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive")
2650 if tool_payload.get("reachable") is False:
2651 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.")
2653 # ═══════════════════════════════════════════════════════════════════════════
2654 # SECURITY: Check tool access based on visibility and team membership
2655 # This enforces the same access control rules as list_tools()
2656 # ═══════════════════════════════════════════════════════════════════════════
2657 if not await self._check_tool_access(db, tool_payload, user_email, token_teams):
2658 # Don't reveal tool existence - return generic "not found"
2659 raise ToolNotFoundError(f"Tool not found: {name}")
2661 # ═══════════════════════════════════════════════════════════════════════════
2662 # SECURITY: Enforce server scoping if server_id is provided
2663 # Tool must be attached to the specified virtual server
2664 # ═══════════════════════════════════════════════════════════════════════════
2665 if server_id:
2666 tool_id_for_check = tool_payload.get("id")
2667 if not tool_id_for_check:
2668 # Cannot verify server membership without tool ID - deny access
2669 # This should not happen with properly cached tools, but fail safe
2670 logger.warning(f"Tool '{name}' has no ID in payload, cannot verify server membership")
2671 raise ToolNotFoundError(f"Tool not found: {name}")
2673 server_match = db.execute(
2674 select(server_tool_association.c.tool_id).where(
2675 server_tool_association.c.server_id == server_id,
2676 server_tool_association.c.tool_id == tool_id_for_check,
2677 )
2678 ).first()
2679 if not server_match:
2680 raise ToolNotFoundError(f"Tool not found: {name}")
2682 # Extract A2A-related data from annotations (will be used after db.close() if A2A tool)
2683 tool_annotations = tool_payload.get("annotations") or {}
2684 tool_integration_type = tool_payload.get("integration_type")
2686 # Get passthrough headers from in-memory cache (Issue #1715)
2687 # This eliminates 42,000+ redundant DB queries under load
2688 passthrough_allowed = global_config_cache.get_passthrough_headers(db, settings.default_passthrough_headers)
2690 # Access gateway now (already eager-loaded) to prevent later lazy load
2691 if tool is not None:
2692 gateway = tool.gateway
2694 # ═══════════════════════════════════════════════════════════════════════════
2695 # PHASE 2: Extract all needed data to local variables before network I/O
2696 # This allows us to release the DB session before making HTTP calls
2697 # ═══════════════════════════════════════════════════════════════════════════
2698 tool_id = tool_payload.get("id") or (str(tool.id) if tool else "")
2699 tool_name_original = tool_payload.get("original_name") or tool_payload.get("name") or name
2700 tool_name_computed = tool_payload.get("name") or name
2701 tool_url = tool_payload.get("url")
2702 tool_integration_type = tool_payload.get("integration_type")
2703 tool_request_type = tool_payload.get("request_type")
2704 tool_headers = dict(tool_payload.get("headers") or {})
2705 tool_auth_type = tool_payload.get("auth_type")
2706 tool_auth_value = tool_payload.get("auth_value")
2707 tool_jsonpath_filter = tool_payload.get("jsonpath_filter")
2708 tool_output_schema = tool_payload.get("output_schema")
2709 tool_oauth_config = tool_payload.get("oauth_config")
2710 tool_gateway_id = tool_payload.get("gateway_id")
2712 # Get effective timeout: per-tool timeout_ms (in seconds) or global fallback
2713 # timeout_ms is stored in milliseconds, convert to seconds
2714 tool_timeout_ms = tool_payload.get("timeout_ms")
2715 effective_timeout = (tool_timeout_ms / 1000) if tool_timeout_ms else settings.tool_timeout
2717 # Save gateway existence as local boolean BEFORE db.close()
2718 # to avoid checking ORM object truthiness after session is closed
2719 has_gateway = gateway_payload is not None
2720 gateway_url = gateway_payload.get("url") if has_gateway else None
2721 gateway_name = gateway_payload.get("name") if has_gateway else None
2722 gateway_auth_type = gateway_payload.get("auth_type") if has_gateway else None
2723 gateway_auth_value = gateway_payload.get("auth_value") if has_gateway else None
2724 gateway_auth_query_params = gateway_payload.get("auth_query_params") if has_gateway else None
2725 gateway_oauth_config = gateway_payload.get("oauth_config") if has_gateway else None
2726 gateway_ca_cert = gateway_payload.get("ca_certificate") if has_gateway else None
2727 gateway_ca_cert_sig = gateway_payload.get("ca_certificate_sig") if has_gateway else None
2728 gateway_passthrough = gateway_payload.get("passthrough_headers") if has_gateway else None
2729 gateway_id_str = gateway_payload.get("id") if has_gateway else None
2731 # Decrypt and apply query param auth to URL if applicable
2732 gateway_auth_query_params_decrypted: Optional[Dict[str, str]] = None
2733 if gateway_auth_type == "query_param" and gateway_auth_query_params:
2734 # Decrypt the query param values
2735 gateway_auth_query_params_decrypted = {}
2736 for param_key, encrypted_value in gateway_auth_query_params.items():
2737 if encrypted_value:
2738 try:
2739 decrypted = decode_auth(encrypted_value)
2740 gateway_auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
2741 except Exception: # noqa: S110 - intentionally skip failed decryptions
2742 # Silently skip params that fail decryption (may be corrupted or use old key)
2743 logger.debug(f"Failed to decrypt query param '{param_key}' for tool invocation")
2744 # Apply query params to gateway URL
2745 if gateway_auth_query_params_decrypted and gateway_url:
2746 gateway_url = apply_query_param_auth(gateway_url, gateway_auth_query_params_decrypted)
2748 # Create Pydantic models for plugins BEFORE HTTP calls (use ORM objects while still valid)
2749 # This prevents lazy loading during HTTP calls
2750 tool_metadata: Optional[PydanticTool] = None
2751 gateway_metadata: Optional[PydanticGateway] = None
2752 if self._plugin_manager:
2753 if tool is not None:
2754 tool_metadata = PydanticTool.model_validate(tool)
2755 if has_gateway and gateway is not None:
2756 gateway_metadata = PydanticGateway.model_validate(gateway)
2757 else:
2758 tool_metadata = self._pydantic_tool_from_payload(tool_payload)
2759 if has_gateway and gateway_payload:
2760 gateway_metadata = self._pydantic_gateway_from_payload(gateway_payload)
2762 tool_for_validation = tool if tool is not None else SimpleNamespace(output_schema=tool_output_schema, name=tool_name_computed)
2764 # ═══════════════════════════════════════════════════════════════════════════
2765 # A2A Agent Data Extraction (must happen before db.close())
2766 # Extract all A2A agent data to local variables so HTTP call can happen after db.close()
2767 # ═══════════════════════════════════════════════════════════════════════════
2768 a2a_agent_name: Optional[str] = None
2769 a2a_agent_endpoint_url: Optional[str] = None
2770 a2a_agent_type: Optional[str] = None
2771 a2a_agent_protocol_version: Optional[str] = None
2772 a2a_agent_auth_type: Optional[str] = None
2773 a2a_agent_auth_value: Optional[str] = None
2774 a2a_agent_auth_query_params: Optional[Dict[str, str]] = None
2776 if tool_integration_type == "A2A" and "a2a_agent_id" in tool_annotations:
2777 a2a_agent_id = tool_annotations.get("a2a_agent_id")
2778 if not a2a_agent_id:
2779 raise ToolNotFoundError(f"A2A tool '{name}' missing agent ID in annotations")
2781 # Query for the A2A agent
2782 agent_query = select(DbA2AAgent).where(DbA2AAgent.id == a2a_agent_id)
2783 a2a_agent = db.execute(agent_query).scalar_one_or_none()
2785 if not a2a_agent:
2786 raise ToolNotFoundError(f"A2A agent not found for tool '{name}' (agent ID: {a2a_agent_id})")
2788 if not a2a_agent.enabled:
2789 raise ToolNotFoundError(f"A2A agent '{a2a_agent.name}' is disabled")
2791 # Extract all needed data to local variables before db.close()
2792 a2a_agent_name = a2a_agent.name
2793 a2a_agent_endpoint_url = a2a_agent.endpoint_url
2794 a2a_agent_type = a2a_agent.agent_type
2795 a2a_agent_protocol_version = a2a_agent.protocol_version
2796 a2a_agent_auth_type = a2a_agent.auth_type
2797 a2a_agent_auth_value = a2a_agent.auth_value
2798 a2a_agent_auth_query_params = a2a_agent.auth_query_params
2800 # ═══════════════════════════════════════════════════════════════════════════
2801 # CRITICAL: Release DB connection back to pool BEFORE making HTTP calls
2802 # This prevents connection pool exhaustion during slow upstream requests.
2803 # All needed data has been extracted to local variables above.
2804 # The session will be closed again by FastAPI's get_db() finally block (safe no-op).
2805 # ═══════════════════════════════════════════════════════════════════════════
2806 db.commit() # End read-only transaction cleanly (commit not rollback to avoid inflating rollback stats)
2807 db.close()
2809 # Plugin hook: tool pre-invoke
2810 # Use existing context_table from previous hooks if available
2811 context_table = plugin_context_table
2813 # Reuse existing global_context from middleware or create new one
2814 # IMPORTANT: Use local variables (tool_gateway_id) instead of ORM object access
2815 if plugin_global_context:
2816 global_context = plugin_global_context
2817 # Update server_id using local variable (not ORM access)
2818 if tool_gateway_id and isinstance(tool_gateway_id, str):
2819 global_context.server_id = tool_gateway_id
2820 # Propagate user email to global context for plugin access
2821 if not plugin_global_context.user and app_user_email and isinstance(app_user_email, str):
2822 global_context.user = app_user_email
2823 else:
2824 # Create new context (fallback when middleware didn't run)
2825 # Use correlation ID from context if available, otherwise generate new one
2826 request_id = get_correlation_id() or uuid.uuid4().hex
2827 server_id = tool_gateway_id if tool_gateway_id and isinstance(tool_gateway_id, str) else "unknown"
2828 global_context = GlobalContext(request_id=request_id, server_id=server_id, tenant_id=None, user=app_user_email)
2830 start_time = time.monotonic()
2831 success = False
2832 error_message = None
2834 # Get trace_id from context for database span creation
2835 trace_id = current_trace_id.get()
2836 db_span_id = None
2837 db_span_ended = False
2838 observability_service = ObservabilityService() if trace_id else None
2840 # Create database span for observability_spans table
2841 if trace_id and observability_service:
2842 try:
2843 # Re-open database session for span creation (original was closed at line 2285)
2844 # Use commit=False since fresh_db_session() handles commits on exit
2845 with fresh_db_session() as span_db:
2846 db_span_id = observability_service.start_span(
2847 db=span_db,
2848 trace_id=trace_id,
2849 name="tool.invoke",
2850 kind="client",
2851 resource_type="tool",
2852 resource_name=name,
2853 resource_id=tool_id,
2854 attributes={
2855 "tool.name": name,
2856 "tool.id": tool_id,
2857 "tool.integration_type": tool_integration_type,
2858 "tool.gateway_id": tool_gateway_id,
2859 "arguments_count": len(arguments) if arguments else 0,
2860 "has_headers": bool(request_headers),
2861 },
2862 commit=False,
2863 )
2864 logger.debug(f"✓ Created tool.invoke span: {db_span_id} for tool: {name}")
2865 except Exception as e:
2866 logger.warning(f"Failed to start observability span for tool invocation: {e}")
2867 db_span_id = None
2869 # Create a trace span for OpenTelemetry export (Jaeger, Zipkin, etc.)
2870 with create_span(
2871 "tool.invoke",
2872 {
2873 "tool.name": name,
2874 "tool.id": tool_id,
2875 "tool.integration_type": tool_integration_type,
2876 "tool.gateway_id": tool_gateway_id,
2877 "arguments_count": len(arguments) if arguments else 0,
2878 "has_headers": bool(request_headers),
2879 },
2880 ) as span:
2881 try:
2882 # Get combined headers for the tool including base headers, auth, and passthrough headers
2883 headers = tool_headers.copy()
2884 if tool_integration_type == "REST":
2885 # Handle OAuth authentication for REST tools
2886 if tool_auth_type == "oauth" and tool_oauth_config:
2887 try:
2888 access_token = await self.oauth_manager.get_access_token(tool_oauth_config)
2889 headers["Authorization"] = f"Bearer {access_token}"
2890 except Exception as e:
2891 logger.error(f"Failed to obtain OAuth access token for tool {tool_name_computed}: {e}")
2892 raise ToolInvocationError(f"OAuth authentication failed: {str(e)}")
2893 else:
2894 credentials = decode_auth(tool_auth_value)
2895 # Filter out empty header names/values to avoid "Illegal header name" errors
2896 filtered_credentials = {k: v for k, v in credentials.items() if k and v}
2897 headers.update(filtered_credentials)
2899 # Use cached passthrough headers (no DB query needed)
2900 if request_headers:
2901 headers = compute_passthrough_headers_cached(
2902 request_headers,
2903 headers,
2904 passthrough_allowed,
2905 gateway_auth_type=None,
2906 gateway_passthrough_headers=None, # REST tools don't use gateway auth here
2907 )
2908 # Read MCP-Session-Id from downstream client (MCP protocol header)
2909 # and normalize to x-mcp-session-id for our internal session affinity logic
2910 # The pool will strip this before sending to upstream
2911 # Check both mcp-session-id (direct client) and x-mcp-session-id (forwarded requests)
2912 request_headers_lower = {k.lower(): v for k, v in request_headers.items()}
2913 mcp_session_id = request_headers_lower.get("mcp-session-id") or request_headers_lower.get("x-mcp-session-id")
2914 if mcp_session_id:
2915 headers["x-mcp-session-id"] = mcp_session_id
2917 worker_id = str(os.getpid())
2918 session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id
2919 logger.debug(f"[AFFINITY] Worker {worker_id} | Session {session_short}... | Tool: {name} | Normalized MCP-Session-Id → x-mcp-session-id for pool affinity")
2921 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_PRE_INVOKE):
2922 # Use pre-created Pydantic model from Phase 2 (no ORM access)
2923 if tool_metadata:
2924 global_context.metadata[TOOL_METADATA] = tool_metadata
2925 pre_result, context_table = await self._plugin_manager.invoke_hook(
2926 ToolHookType.TOOL_PRE_INVOKE,
2927 payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)),
2928 global_context=global_context,
2929 local_contexts=context_table, # Pass context from previous hooks
2930 violations_as_exceptions=True,
2931 )
2932 if pre_result.modified_payload:
2933 payload = pre_result.modified_payload
2934 name = payload.name
2935 arguments = payload.args
2936 if payload.headers is not None:
2937 headers = payload.headers.model_dump()
2939 # Build the payload based on integration type
2940 payload = arguments.copy()
2942 # Handle URL path parameter substitution (using local variable)
2943 final_url = tool_url
2944 if "{" in tool_url and "}" in tool_url:
2945 # Extract path parameters from URL template and arguments
2946 url_params = re.findall(r"\{(\w+)\}", tool_url)
2947 url_substitutions = {}
2949 for param in url_params:
2950 if param in payload:
2951 url_substitutions[param] = payload.pop(param) # Remove from payload
2952 final_url = final_url.replace(f"{{{param}}}", str(url_substitutions[param]))
2953 else:
2954 raise ToolInvocationError(f"Required URL parameter '{param}' not found in arguments")
2956 # --- Extract query params from URL ---
2957 parsed = urlparse(final_url)
2958 final_url = f"{parsed.scheme}://{parsed.netloc}{parsed.path}"
2960 query_params = {k: v[0] for k, v in parse_qs(parsed.query).items()}
2962 # Merge leftover payload + query params
2963 payload.update(query_params)
2965 # Use the tool's request_type rather than defaulting to POST (using local variable)
2966 method = tool_request_type.upper() if tool_request_type else "POST"
2967 rest_start_time = time.time()
2968 try:
2969 if method == "GET":
2970 response = await asyncio.wait_for(self._http_client.get(final_url, params=payload, headers=headers), timeout=effective_timeout)
2971 else:
2972 response = await asyncio.wait_for(self._http_client.request(method, final_url, json=payload, headers=headers), timeout=effective_timeout)
2973 except (asyncio.TimeoutError, httpx.TimeoutException):
2974 rest_elapsed_ms = (time.time() - rest_start_time) * 1000
2975 structured_logger.log(
2976 level="WARNING",
2977 message=f"REST tool invocation timed out: {tool_name_computed}",
2978 component="tool_service",
2979 correlation_id=get_correlation_id(),
2980 duration_ms=rest_elapsed_ms,
2981 metadata={"event": "tool_timeout", "tool_name": tool_name_computed, "timeout_seconds": effective_timeout},
2982 )
2984 # Manually trigger circuit breaker (or other plugins) on timeout
2985 try:
2986 # First-Party
2987 from mcpgateway.services.metrics import tool_timeout_counter # pylint: disable=import-outside-toplevel
2989 tool_timeout_counter.labels(tool_name=name).inc()
2990 except Exception as exc:
2991 logger.debug(
2992 "Failed to increment tool_timeout_counter for %s: %s",
2993 name,
2994 exc,
2995 exc_info=True,
2996 )
2998 if self._plugin_manager:
2999 if context_table:
3000 for ctx in context_table.values():
3001 ctx.set_state("cb_timeout_failure", True)
3003 if self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE):
3004 timeout_error_result = ToolResult(content=[TextContent(type="text", text=f"Tool invocation timed out after {effective_timeout}s")], is_error=True)
3005 await self._plugin_manager.invoke_hook(
3006 ToolHookType.TOOL_POST_INVOKE,
3007 payload=ToolPostInvokePayload(name=name, result=timeout_error_result.model_dump(by_alias=True)),
3008 global_context=global_context,
3009 local_contexts=context_table,
3010 violations_as_exceptions=False,
3011 )
3013 raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s")
3014 response.raise_for_status()
3016 # Handle 204 No Content responses that have no body
3017 if response.status_code == 204:
3018 tool_result = ToolResult(content=[TextContent(type="text", text="Request completed successfully (No Content)")])
3019 success = True
3020 elif response.status_code not in [200, 201, 202, 206]:
3021 try:
3022 result = response.json()
3023 except orjson.JSONDecodeError:
3024 result = {"response_text": response.text} if response.text else {}
3025 tool_result = ToolResult(
3026 content=[TextContent(type="text", text=str(result["error"]) if "error" in result else "Tool error encountered")],
3027 is_error=True,
3028 )
3029 # Don't mark as successful for error responses - success remains False
3030 else:
3031 try:
3032 result = response.json()
3033 except orjson.JSONDecodeError:
3034 result = {"response_text": response.text} if response.text else {}
3035 logger.debug(f"REST API tool response: {result}")
3036 filtered_response = extract_using_jq(result, tool_jsonpath_filter)
3037 tool_result = ToolResult(content=[TextContent(type="text", text=orjson.dumps(filtered_response, option=orjson.OPT_INDENT_2).decode())])
3038 success = True
3039 # If output schema is present, validate and attach structured content
3040 if tool_output_schema:
3041 valid = self._extract_and_validate_structured_content(tool_for_validation, tool_result, candidate=filtered_response)
3042 success = bool(valid)
3043 elif tool_integration_type == "MCP":
3044 transport = tool_request_type.lower() if tool_request_type else "sse"
3046 # Handle OAuth authentication for the gateway (using local variables)
3047 # NOTE: Use has_gateway instead of gateway to avoid accessing detached ORM object
3048 if has_gateway and gateway_auth_type == "oauth" and gateway_oauth_config:
3049 grant_type = gateway_oauth_config.get("grant_type", "client_credentials")
3051 if grant_type == "authorization_code":
3052 # For Authorization Code flow, try to get stored tokens
3053 # NOTE: Use fresh_db_session() since the original db was closed
3054 try:
3055 # First-Party
3056 from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel
3058 with fresh_db_session() as token_db:
3059 token_storage = TokenStorageService(token_db)
3061 # Get user-specific OAuth token
3062 if not app_user_email:
3063 raise ToolInvocationError(f"User authentication required for OAuth-protected gateway '{gateway_name}'. Please ensure you are authenticated.")
3065 access_token = await token_storage.get_user_token(gateway_id_str, app_user_email)
3067 if access_token:
3068 headers = {"Authorization": f"Bearer {access_token}"}
3069 else:
3070 # User hasn't authorized this gateway yet
3071 raise ToolInvocationError(f"Please authorize {gateway_name} first. Visit /oauth/authorize/{gateway_id_str} to complete OAuth flow.")
3072 except Exception as e:
3073 logger.error(f"Failed to obtain stored OAuth token for gateway {gateway_name}: {e}")
3074 raise ToolInvocationError(f"OAuth token retrieval failed for gateway: {str(e)}")
3075 else:
3076 # For Client Credentials flow, get token directly (no DB needed)
3077 try:
3078 access_token = await self.oauth_manager.get_access_token(gateway_oauth_config)
3079 headers = {"Authorization": f"Bearer {access_token}"}
3080 except Exception as e:
3081 logger.error(f"Failed to obtain OAuth access token for gateway {gateway_name}: {e}")
3082 raise ToolInvocationError(f"OAuth authentication failed for gateway: {str(e)}")
3083 else:
3084 headers = decode_auth(gateway_auth_value)
3086 # Use cached passthrough headers (no DB query needed)
3087 if request_headers:
3088 headers = compute_passthrough_headers_cached(
3089 request_headers, headers, passthrough_allowed, gateway_auth_type=gateway_auth_type, gateway_passthrough_headers=gateway_passthrough
3090 )
3091 # Read MCP-Session-Id from downstream client (MCP protocol header)
3092 # and normalize to x-mcp-session-id for our internal session affinity logic
3093 # The pool will strip this before sending to upstream
3094 # Check both mcp-session-id (direct client) and x-mcp-session-id (forwarded requests)
3095 request_headers_lower = {k.lower(): v for k, v in request_headers.items()}
3096 mcp_session_id = request_headers_lower.get("mcp-session-id") or request_headers_lower.get("x-mcp-session-id")
3097 if mcp_session_id:
3098 headers["x-mcp-session-id"] = mcp_session_id
3100 worker_id = str(os.getpid())
3101 session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id
3102 logger.debug(f"[AFFINITY] Worker {worker_id} | Session {session_short}... | Tool: {name} | Normalized MCP-Session-Id → x-mcp-session-id for pool affinity (MCP transport)")
def create_ssl_context(ca_certificate: str) -> ssl.SSLContext:
    """Return an SSL context that trusts the supplied CA certificate.

    Delegates to the module-level cache so that repeated invocations with
    the same PEM text reuse one context instead of rebuilding it.

    Args:
        ca_certificate: CA certificate in PEM format.

    Returns:
        ssl.SSLContext: Cached SSL context configured for this certificate.
    """
    return get_cached_ssl_context(ca_certificate)
def get_httpx_client_factory(
    headers: dict[str, str] | None = None,
    timeout: httpx.Timeout | None = None,
    auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
    """Build an ``httpx.AsyncClient``, optionally pinned to the gateway CA certificate.

    The gateway CA certificate (captured from the enclosing invocation scope)
    is trusted only when Ed25519 signing is disabled, or when its signature
    verifies against the configured public key; otherwise the default
    verification settings are used.

    Args:
        headers: Optional default headers for the client.
        timeout: Optional timeout override for the client.
        auth: Optional auth object for the client.

    Returns:
        httpx.AsyncClient: Configured HTTPX async client.
    """
    # First-Party
    from mcpgateway.services.http_client_service import get_default_verify, get_http_timeout  # pylint: disable=import-outside-toplevel

    # Decide whether the captured gateway CA certificate may be trusted.
    # (Local variables from the enclosing scope are used instead of ORM objects.)
    cert_trusted = False
    if gateway_ca_cert:
        if settings.enable_ed25519_signing:
            cert_trusted = validate_signature(gateway_ca_cert.encode(), gateway_ca_cert_sig, settings.ed25519_public_key)
        else:
            cert_trusted = True

    ssl_ctx = create_ssl_context(gateway_ca_cert) if cert_trusted else None

    # Honor an explicit caller-supplied timeout; otherwise make the client's
    # read timeout at least as long as the tool's configured effective timeout.
    client_timeout = timeout if timeout else get_http_timeout(read_timeout=effective_timeout)

    return httpx.AsyncClient(
        verify=ssl_ctx if ssl_ctx else get_default_verify(),
        follow_redirects=True,
        headers=headers,
        timeout=client_timeout,
        auth=auth,
        limits=httpx.Limits(
            max_connections=settings.httpx_max_connections,
            max_keepalive_connections=settings.httpx_max_keepalive_connections,
            keepalive_expiry=settings.httpx_keepalive_expiry,
        ),
    )
3168 async def connect_to_sse_server(server_url: str, headers: dict = headers):
3169 """Connect to an MCP server running with SSE transport.
3171 Args:
3172 server_url: MCP Server SSE URL
3173 headers: HTTP headers to include in the request
3175 Returns:
3176 ToolResult: Result of tool call
3178 Raises:
3179 ToolInvocationError: If the tool invocation fails during execution.
3180 ToolTimeoutError: If the tool invocation times out.
3181 BaseException: On connection or communication errors
3183 """
3184 # Get correlation ID for distributed tracing
3185 correlation_id = get_correlation_id()
3187 # NOTE: X-Correlation-ID is NOT added to headers for pooled sessions.
3188 # MCP SDK pins headers at transport creation, so adding per-request headers
3189 # would cause the first request's correlation ID to be reused for all
3190 # subsequent requests on the same pooled session. Correlation IDs are
3191 # still logged locally for tracing within the gateway.
3193 # Log MCP call start (using local variables)
3194 # Sanitize server_url to redact sensitive query params from logs
3195 server_url_sanitized = sanitize_url_for_logging(server_url, gateway_auth_query_params_decrypted)
3196 mcp_start_time = time.time()
3197 structured_logger.log(
3198 level="INFO",
3199 message=f"MCP tool call started: {tool_name_original}",
3200 component="tool_service",
3201 correlation_id=correlation_id,
3202 metadata={"event": "mcp_call_started", "tool_name": tool_name_original, "tool_id": tool_id, "server_url": server_url_sanitized, "transport": "sse"},
3203 )
3205 try:
3206 # Use session pool if enabled for 10-20x latency improvement
3207 use_pool = False
3208 pool = None
3209 if settings.mcp_session_pool_enabled:
3210 try:
3211 pool = get_mcp_session_pool()
3212 use_pool = True
3213 except RuntimeError:
3214 # Pool not initialized (e.g., in tests), fall back to per-call sessions
3215 pass
3217 if use_pool and pool is not None:
3218 # Pooled path: do NOT add per-request headers (they would be pinned)
3219 async with pool.session(
3220 url=server_url,
3221 headers=headers,
3222 transport_type=TransportType.SSE,
3223 httpx_client_factory=get_httpx_client_factory,
3224 user_identity=app_user_email,
3225 gateway_id=gateway_id_str,
3226 ) as pooled:
3227 tool_call_result = await asyncio.wait_for(pooled.session.call_tool(tool_name_original, arguments, meta=meta_data), timeout=effective_timeout)
3228 else:
3229 # Non-pooled path: safe to add per-request headers
3230 if correlation_id and headers:
3231 headers["X-Correlation-ID"] = correlation_id
3232 # Fallback to per-call sessions when pool disabled or not initialized
3233 async with sse_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as streams:
3234 async with ClientSession(*streams) as session:
3235 await session.initialize()
3236 tool_call_result = await asyncio.wait_for(session.call_tool(tool_name_original, arguments, meta=meta_data), timeout=effective_timeout)
3238 # Log successful MCP call
3239 mcp_duration_ms = (time.time() - mcp_start_time) * 1000
3240 structured_logger.log(
3241 level="INFO",
3242 message=f"MCP tool call completed: {tool_name_original}",
3243 component="tool_service",
3244 correlation_id=correlation_id,
3245 duration_ms=mcp_duration_ms,
3246 metadata={"event": "mcp_call_completed", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "sse", "success": True},
3247 )
3249 return tool_call_result
3250 except (asyncio.TimeoutError, httpx.TimeoutException):
3251 # Handle timeout specifically - log and raise ToolInvocationError
3252 mcp_duration_ms = (time.time() - mcp_start_time) * 1000
3253 structured_logger.log(
3254 level="WARNING",
3255 message=f"MCP SSE tool invocation timed out: {tool_name_original}",
3256 component="tool_service",
3257 correlation_id=correlation_id,
3258 duration_ms=mcp_duration_ms,
3259 metadata={"event": "tool_timeout", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "sse", "timeout_seconds": effective_timeout},
3260 )
3262 # Manually trigger circuit breaker (or other plugins) on timeout
3263 try:
3264 # First-Party
3265 from mcpgateway.services.metrics import tool_timeout_counter # pylint: disable=import-outside-toplevel
3267 tool_timeout_counter.labels(tool_name=name).inc()
3268 except Exception as exc:
3269 logger.debug(
3270 "Failed to increment tool_timeout_counter for %s: %s",
3271 name,
3272 exc,
3273 exc_info=True,
3274 )
3276 if self._plugin_manager: 3276 ↛ 3291line 3276 didn't jump to line 3291 because the condition on line 3276 was always true
3277 if context_table: 3277 ↛ 3281line 3277 didn't jump to line 3281 because the condition on line 3277 was always true
3278 for ctx in context_table.values():
3279 ctx.set_state("cb_timeout_failure", True)
3281 if self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE): 3281 ↛ 3291line 3281 didn't jump to line 3291 because the condition on line 3281 was always true
3282 timeout_error_result = ToolResult(content=[TextContent(type="text", text=f"Tool invocation timed out after {effective_timeout}s")], is_error=True)
3283 await self._plugin_manager.invoke_hook(
3284 ToolHookType.TOOL_POST_INVOKE,
3285 payload=ToolPostInvokePayload(name=name, result=timeout_error_result.model_dump(by_alias=True)),
3286 global_context=global_context,
3287 local_contexts=context_table,
3288 violations_as_exceptions=False,
3289 )
3291 raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s")
3292 except BaseException as e:
3293 # Extract root cause from ExceptionGroup (Python 3.11+)
3294 # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup
3295 root_cause = e
3296 if isinstance(e, BaseExceptionGroup): 3296 ↛ 3300line 3296 didn't jump to line 3300 because the condition on line 3296 was always true
3297 while isinstance(root_cause, BaseExceptionGroup) and root_cause.exceptions:
3298 root_cause = root_cause.exceptions[0]
3299 # Log failed MCP call (using local variables)
3300 mcp_duration_ms = (time.time() - mcp_start_time) * 1000
3301 # Sanitize error message to prevent URL secrets from leaking in logs
3302 sanitized_error = sanitize_exception_message(str(root_cause), gateway_auth_query_params_decrypted)
3303 structured_logger.log(
3304 level="ERROR",
3305 message=f"MCP tool call failed: {tool_name_original}",
3306 component="tool_service",
3307 correlation_id=correlation_id,
3308 duration_ms=mcp_duration_ms,
3309 error_details={"error_type": type(root_cause).__name__, "error_message": sanitized_error},
3310 metadata={"event": "mcp_call_failed", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "sse"},
3311 )
3312 raise
async def connect_to_streamablehttp_server(server_url: str, headers: dict = headers):
    """Connect to an MCP server running with Streamable HTTP transport.

    Invokes the tool (closure variables ``tool_name_original``, ``arguments``,
    ``meta_data``) over a Streamable HTTP session, preferring the shared MCP
    session pool when enabled and falling back to a per-call session otherwise.

    NOTE: the ``headers`` default binds the enclosing scope's ``headers``
    value at definition time (intentional closure capture); callers may still
    pass an explicit headers dict.

    Args:
        server_url: MCP Server URL
        headers: HTTP headers to include in the request

    Returns:
        ToolResult: Result of tool call

    Raises:
        ToolInvocationError: If the tool invocation fails during execution.
        ToolTimeoutError: If the tool invocation times out.
        BaseException: On connection or communication errors
    """
    # Get correlation ID for distributed tracing
    correlation_id = get_correlation_id()

    # NOTE: X-Correlation-ID is NOT added to headers for pooled sessions.
    # MCP SDK pins headers at transport creation, so adding per-request headers
    # would cause the first request's correlation ID to be reused for all
    # subsequent requests on the same pooled session. Correlation IDs are
    # still logged locally for tracing within the gateway.

    # Log MCP call start (using local variables)
    # Sanitize server_url to redact sensitive query params from logs
    server_url_sanitized = sanitize_url_for_logging(server_url, gateway_auth_query_params_decrypted)
    mcp_start_time = time.time()
    structured_logger.log(
        level="INFO",
        message=f"MCP tool call started: {tool_name_original}",
        component="tool_service",
        correlation_id=correlation_id,
        metadata={"event": "mcp_call_started", "tool_name": tool_name_original, "tool_id": tool_id, "server_url": server_url_sanitized, "transport": "streamablehttp"},
    )

    try:
        # Use session pool if enabled for 10-20x latency improvement
        use_pool = False
        pool = None
        if settings.mcp_session_pool_enabled:
            try:
                pool = get_mcp_session_pool()
                use_pool = True
            except RuntimeError:
                # Pool not initialized (e.g., in tests), fall back to per-call sessions
                pass

        if use_pool and pool is not None:
            # Pooled path: do NOT add per-request headers (they would be pinned)
            # Determine transport type based on current transport setting
            pool_transport_type = TransportType.SSE if transport == "sse" else TransportType.STREAMABLE_HTTP
            async with pool.session(
                url=server_url,
                headers=headers,
                transport_type=pool_transport_type,
                httpx_client_factory=get_httpx_client_factory,
                user_identity=app_user_email,
                gateway_id=gateway_id_str,
            ) as pooled:
                # asyncio.wait_for enforces the per-tool timeout on the actual call
                tool_call_result = await asyncio.wait_for(pooled.session.call_tool(tool_name_original, arguments, meta=meta_data), timeout=effective_timeout)
        else:
            # Non-pooled path: safe to add per-request headers
            if correlation_id and headers:
                headers["X-Correlation-ID"] = correlation_id
            # Fallback to per-call sessions when pool disabled or not initialized
            async with streamablehttp_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as (read_stream, write_stream, _get_session_id):
                async with ClientSession(read_stream, write_stream) as session:
                    await session.initialize()
                    tool_call_result = await asyncio.wait_for(session.call_tool(tool_name_original, arguments, meta=meta_data), timeout=effective_timeout)

        # Log successful MCP call
        mcp_duration_ms = (time.time() - mcp_start_time) * 1000
        structured_logger.log(
            level="INFO",
            message=f"MCP tool call completed: {tool_name_original}",
            component="tool_service",
            correlation_id=correlation_id,
            duration_ms=mcp_duration_ms,
            metadata={"event": "mcp_call_completed", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "streamablehttp", "success": True},
        )

        return tool_call_result
    except (asyncio.TimeoutError, httpx.TimeoutException):
        # Handle timeout specifically - log and raise ToolTimeoutError
        mcp_duration_ms = (time.time() - mcp_start_time) * 1000
        structured_logger.log(
            level="WARNING",
            message=f"MCP StreamableHTTP tool invocation timed out: {tool_name_original}",
            component="tool_service",
            correlation_id=correlation_id,
            duration_ms=mcp_duration_ms,
            metadata={"event": "tool_timeout", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "streamablehttp", "timeout_seconds": effective_timeout},
        )

        # Manually trigger circuit breaker (or other plugins) on timeout
        try:
            # First-Party
            from mcpgateway.services.metrics import tool_timeout_counter  # pylint: disable=import-outside-toplevel

            tool_timeout_counter.labels(tool_name=name).inc()
        except Exception as exc:
            # Best-effort metric increment; never let it mask the timeout itself
            logger.debug(
                "Failed to increment tool_timeout_counter for %s: %s",
                name,
                exc,
                exc_info=True,
            )

        if self._plugin_manager:
            if context_table:
                # Mark every plugin context so circuit-breaker state sees the failure
                ctx.set_state("cb_timeout_failure", True) if False else None  # noqa: see loop below
            if context_table:
                for ctx in context_table.values():
                    ctx.set_state("cb_timeout_failure", True)

            if self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE):
                # Synthesize an error ToolResult so post-invoke hooks observe the timeout
                timeout_error_result = ToolResult(content=[TextContent(type="text", text=f"Tool invocation timed out after {effective_timeout}s")], is_error=True)
                await self._plugin_manager.invoke_hook(
                    ToolHookType.TOOL_POST_INVOKE,
                    payload=ToolPostInvokePayload(name=name, result=timeout_error_result.model_dump(by_alias=True)),
                    global_context=global_context,
                    local_contexts=context_table,
                    violations_as_exceptions=False,
                )

        raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s")
    except BaseException as e:
        # Extract root cause from ExceptionGroup (Python 3.11+)
        # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup
        root_cause = e
        if isinstance(e, BaseExceptionGroup):
            while isinstance(root_cause, BaseExceptionGroup) and root_cause.exceptions:
                root_cause = root_cause.exceptions[0]
        # Log failed MCP call
        mcp_duration_ms = (time.time() - mcp_start_time) * 1000
        # Sanitize error message to prevent URL secrets from leaking in logs
        sanitized_error = sanitize_exception_message(str(root_cause), gateway_auth_query_params_decrypted)
        structured_logger.log(
            level="ERROR",
            message=f"MCP tool call failed: {tool_name_original}",
            component="tool_service",
            correlation_id=correlation_id,
            duration_ms=mcp_duration_ms,
            error_details={"error_type": type(root_cause).__name__, "error_message": sanitized_error},
            metadata={"event": "mcp_call_failed", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "streamablehttp"},
        )
        raise
3461 # REMOVED: Redundant gateway query - gateway already eager-loaded via joinedload
3462 # tool_gateway = db.execute(select(DbGateway).where(DbGateway.id == tool_gateway_id)...)
3464 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_PRE_INVOKE):
3465 # Use pre-created Pydantic models from Phase 2 (no ORM access)
3466 if tool_metadata:
3467 global_context.metadata[TOOL_METADATA] = tool_metadata
3468 if gateway_metadata:
3469 global_context.metadata[GATEWAY_METADATA] = gateway_metadata
3470 pre_result, context_table = await self._plugin_manager.invoke_hook(
3471 ToolHookType.TOOL_PRE_INVOKE,
3472 payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)),
3473 global_context=global_context,
3474 local_contexts=None,
3475 violations_as_exceptions=True,
3476 )
3477 if pre_result.modified_payload:
3478 payload = pre_result.modified_payload
3479 name = payload.name
3480 arguments = payload.args
3481 if payload.headers is not None:
3482 headers = payload.headers.model_dump()
3484 tool_call_result = ToolResult(content=[TextContent(text="", type="text")])
3485 if transport == "sse":
3486 tool_call_result = await connect_to_sse_server(gateway_url, headers=headers)
3487 elif transport == "streamablehttp":
3488 tool_call_result = await connect_to_streamablehttp_server(gateway_url, headers=headers)
3489 dump = tool_call_result.model_dump(by_alias=True, mode="json")
3490 logger.debug(f"Tool call result dump: {dump}")
3491 content = dump.get("content", [])
3492 # Accept both alias and pythonic names for structured content
3493 structured = dump.get("structuredContent") or dump.get("structured_content")
3494 filtered_response = extract_using_jq(content, tool_jsonpath_filter)
3496 is_err = getattr(tool_call_result, "is_error", None)
3497 if is_err is None:
3498 is_err = getattr(tool_call_result, "isError", False)
3499 tool_result = ToolResult(content=filtered_response, structured_content=structured, is_error=is_err, meta=getattr(tool_call_result, "meta", None))
3500 success = not is_err
3501 logger.debug(f"Final tool_result: {tool_result}")
3502 elif tool_integration_type == "A2A" and a2a_agent_endpoint_url:
3503 # A2A tool invocation using pre-extracted agent data (extracted in Phase 2 before db.close())
3504 headers = {"Content-Type": "application/json"}
3506 # Plugin hook: tool pre-invoke for A2A
3507 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_PRE_INVOKE):
3508 if tool_metadata:
3509 global_context.metadata[TOOL_METADATA] = tool_metadata
3510 pre_result, context_table = await self._plugin_manager.invoke_hook(
3511 ToolHookType.TOOL_PRE_INVOKE,
3512 payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)),
3513 global_context=global_context,
3514 local_contexts=context_table,
3515 violations_as_exceptions=True,
3516 )
3517 if pre_result.modified_payload: 3517 ↛ 3525line 3517 didn't jump to line 3525 because the condition on line 3517 was always true
3518 payload = pre_result.modified_payload
3519 name = payload.name
3520 arguments = payload.args
3521 if payload.headers is not None: 3521 ↛ 3525line 3521 didn't jump to line 3525 because the condition on line 3521 was always true
3522 headers = payload.headers.model_dump()
3524 # Build request data based on agent type
3525 endpoint_url = a2a_agent_endpoint_url
3526 if a2a_agent_type in ["generic", "jsonrpc"] or endpoint_url.endswith("/"):
3527 # JSONRPC agents: Convert flat query to nested message structure
3528 params = None
3529 if isinstance(arguments, dict) and "query" in arguments and isinstance(arguments["query"], str):
3530 message_id = f"admin-test-{int(time.time())}"
3531 # A2A v0.3.x: message.parts use "kind" (not "type").
3532 params = {
3533 "message": {
3534 "kind": "message",
3535 "messageId": message_id,
3536 "role": "user",
3537 "parts": [{"kind": "text", "text": arguments["query"]}],
3538 }
3539 }
3540 method = arguments.get("method", "message/send")
3541 else:
3542 params = arguments.get("params", arguments) if isinstance(arguments, dict) else arguments
3543 method = arguments.get("method", "message/send") if isinstance(arguments, dict) else "message/send"
3544 request_data = {"jsonrpc": "2.0", "method": method, "params": params, "id": 1}
3545 else:
3546 # Custom agents: Pass parameters directly
3547 params = arguments if isinstance(arguments, dict) else {}
3548 request_data = {"interaction_type": params.get("interaction_type", "query"), "parameters": params, "protocol_version": a2a_agent_protocol_version}
3550 # Add authentication
3551 if a2a_agent_auth_type == "api_key" and a2a_agent_auth_value:
3552 headers["Authorization"] = f"Bearer {a2a_agent_auth_value}"
3553 elif a2a_agent_auth_type == "bearer" and a2a_agent_auth_value:
3554 headers["Authorization"] = f"Bearer {a2a_agent_auth_value}"
3555 elif a2a_agent_auth_type == "query_param" and a2a_agent_auth_query_params:
3556 auth_query_params_decrypted: dict[str, str] = {}
3557 for param_key, encrypted_value in a2a_agent_auth_query_params.items():
3558 if encrypted_value:
3559 try:
3560 decrypted = decode_auth(encrypted_value)
3561 auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
3562 except Exception:
3563 logger.debug(f"Failed to decrypt query param for key '{param_key}'")
3564 if auth_query_params_decrypted:
3565 endpoint_url = apply_query_param_auth(endpoint_url, auth_query_params_decrypted)
3567 # Make HTTP request with timeout enforcement
3568 logger.info(f"Calling A2A agent '{a2a_agent_name}' at {endpoint_url}")
3569 a2a_start_time = time.time()
3570 try:
3571 http_response = await asyncio.wait_for(self._http_client.post(endpoint_url, json=request_data, headers=headers), timeout=effective_timeout)
3572 except (asyncio.TimeoutError, httpx.TimeoutException):
3573 a2a_elapsed_ms = (time.time() - a2a_start_time) * 1000
3574 structured_logger.log(
3575 level="WARNING",
3576 message=f"A2A tool invocation timed out: {name}",
3577 component="tool_service",
3578 correlation_id=get_correlation_id(),
3579 duration_ms=a2a_elapsed_ms,
3580 metadata={"event": "tool_timeout", "tool_name": name, "a2a_agent": a2a_agent_name, "timeout_seconds": effective_timeout},
3581 )
3583 # Increment timeout counter
3584 try:
3585 # First-Party
3586 from mcpgateway.services.metrics import tool_timeout_counter # pylint: disable=import-outside-toplevel
3588 tool_timeout_counter.labels(tool_name=name).inc()
3589 except Exception as exc:
3590 logger.debug("Failed to increment tool_timeout_counter for %s: %s", name, exc, exc_info=True)
3592 # Trigger circuit breaker on timeout
3593 if self._plugin_manager:
3594 if context_table: 3594 ↛ 3598line 3594 didn't jump to line 3598 because the condition on line 3594 was always true
3595 for ctx in context_table.values():
3596 ctx.set_state("cb_timeout_failure", True)
3598 if self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE): 3598 ↛ 3608line 3598 didn't jump to line 3608 because the condition on line 3598 was always true
3599 timeout_error_result = ToolResult(content=[TextContent(type="text", text=f"Tool invocation timed out after {effective_timeout}s")], is_error=True)
3600 await self._plugin_manager.invoke_hook(
3601 ToolHookType.TOOL_POST_INVOKE,
3602 payload=ToolPostInvokePayload(name=name, result=timeout_error_result.model_dump(by_alias=True)),
3603 global_context=global_context,
3604 local_contexts=context_table,
3605 violations_as_exceptions=False,
3606 )
3608 raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s")
3610 if http_response.status_code == 200:
3611 response_data = http_response.json()
3612 if isinstance(response_data, dict) and "response" in response_data:
3613 content = [TextContent(type="text", text=str(response_data["response"]))]
3614 else:
3615 content = [TextContent(type="text", text=str(response_data))]
3616 tool_result = ToolResult(content=content, is_error=False)
3617 success = True
3618 else:
3619 error_message = f"HTTP {http_response.status_code}: {http_response.text}"
3620 content = [TextContent(type="text", text=f"A2A agent error: {error_message}")]
3621 tool_result = ToolResult(content=content, is_error=True)
3622 else:
3623 tool_result = ToolResult(content=[TextContent(type="text", text="Invalid tool type")], is_error=True)
3625 # Plugin hook: tool post-invoke
3626 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE):
3627 post_result, _ = await self._plugin_manager.invoke_hook(
3628 ToolHookType.TOOL_POST_INVOKE,
3629 payload=ToolPostInvokePayload(name=name, result=tool_result.model_dump(by_alias=True)),
3630 global_context=global_context,
3631 local_contexts=context_table,
3632 violations_as_exceptions=True,
3633 )
3634 # Use modified payload if provided
3635 if post_result.modified_payload:
3636 # Reconstruct ToolResult from modified result
3637 modified_result = post_result.modified_payload.result
3638 if isinstance(modified_result, dict) and "content" in modified_result:
3639 # Safely obtain structured content using .get() to avoid KeyError when
3640 # plugins provide only the content without structured content fields.
3641 structured = modified_result.get("structuredContent") if "structuredContent" in modified_result else modified_result.get("structured_content")
3643 tool_result = ToolResult(content=modified_result["content"], structured_content=structured)
3644 else:
3645 # If result is not in expected format, convert it to text content
3646 tool_result = ToolResult(content=[TextContent(type="text", text=str(modified_result))])
3648 return tool_result
3649 except (PluginError, PluginViolationError):
3650 raise
3651 except ToolTimeoutError as e:
3652 # ToolTimeoutError is raised by timeout handlers which already called tool_post_invoke
3653 # Re-raise without calling post_invoke again to avoid double-counting failures
3654 # But DO set error_message and span attributes for observability
3655 error_message = str(e)
3656 if span:
3657 span.set_attribute("error", True)
3658 span.set_attribute("error.message", error_message)
3659 raise
3660 except BaseException as e:
3661 # Extract root cause from ExceptionGroup (Python 3.11+)
3662 # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup
3663 root_cause = e
3664 if isinstance(e, BaseExceptionGroup):
3665 while isinstance(root_cause, BaseExceptionGroup) and root_cause.exceptions:
3666 root_cause = root_cause.exceptions[0]
3667 error_message = str(root_cause)
3668 # Set span error status
3669 if span:
3670 span.set_attribute("error", True)
3671 span.set_attribute("error.message", error_message)
3673 # Notify plugins of the failure so circuit breaker can track it
3674 # This ensures HTTP 4xx/5xx errors and MCP failures are counted
3675 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE):
3676 try:
3677 exception_error_result = ToolResult(content=[TextContent(type="text", text=f"Tool invocation failed: {error_message}")], is_error=True)
3678 await self._plugin_manager.invoke_hook(
3679 ToolHookType.TOOL_POST_INVOKE,
3680 payload=ToolPostInvokePayload(name=name, result=exception_error_result.model_dump(by_alias=True)),
3681 global_context=global_context,
3682 local_contexts=context_table,
3683 violations_as_exceptions=False, # Don't let plugin errors mask the original exception
3684 )
3685 except Exception as plugin_exc:
3686 logger.debug("Failed to invoke post-invoke plugins on exception: %s", plugin_exc)
3688 raise ToolInvocationError(f"Tool invocation failed: {error_message}")
3689 finally:
3690 # Calculate duration
3691 duration_ms = (time.monotonic() - start_time) * 1000
3693 # End database span for observability_spans table
3694 # Use commit=False since fresh_db_session() handles commits on exit
3695 if db_span_id and observability_service and not db_span_ended:
3696 try:
3697 with fresh_db_session() as span_db:
3698 observability_service.end_span(
3699 db=span_db,
3700 span_id=db_span_id,
3701 status="ok" if success else "error",
3702 status_message=error_message if error_message else None,
3703 attributes={
3704 "success": success,
3705 "duration_ms": duration_ms,
3706 },
3707 commit=False,
3708 )
3709 db_span_ended = True
3710 logger.debug(f"✓ Ended tool.invoke span: {db_span_id}")
3711 except Exception as e:
3712 logger.warning(f"Failed to end observability span for tool invocation: {e}")
3714 # Add final span attributes for OpenTelemetry
3715 if span:
3716 span.set_attribute("success", success)
3717 span.set_attribute("duration.ms", duration_ms)
3719 # ═══════════════════════════════════════════════════════════════════════════
3720 # PHASE 4: Record metrics via buffered service (batches writes for performance)
3721 # ═══════════════════════════════════════════════════════════════════════════
3722 try:
3723 # First-Party
3724 from mcpgateway.services.metrics_buffer_service import get_metrics_buffer_service # pylint: disable=import-outside-toplevel
3726 metrics_buffer = get_metrics_buffer_service()
3727 metrics_buffer.record_tool_metric(
3728 tool_id=tool_id,
3729 start_time=start_time,
3730 success=success,
3731 error_message=error_message,
3732 )
3733 except Exception as metric_error:
3734 logger.warning(f"Failed to record tool metric: {metric_error}")
3736 # Log structured message with performance tracking (using local variables)
3737 if success:
3738 structured_logger.info(
3739 f"Tool '{name}' invoked successfully",
3740 user_id=app_user_email,
3741 resource_type="tool",
3742 resource_id=tool_id,
3743 resource_action="invoke",
3744 duration_ms=duration_ms,
3745 custom_fields={"tool_name": name, "integration_type": tool_integration_type, "arguments_count": len(arguments) if arguments else 0},
3746 )
3747 else:
3748 structured_logger.error(
3749 f"Tool '{name}' invocation failed",
3750 error=Exception(error_message) if error_message else None,
3751 user_id=app_user_email,
3752 resource_type="tool",
3753 resource_id=tool_id,
3754 resource_action="invoke",
3755 duration_ms=duration_ms,
3756 custom_fields={"tool_name": name, "integration_type": tool_integration_type, "error_message": error_message},
3757 )
3759 # Track performance with threshold checking
3760 with perf_tracker.track_operation("tool_invocation", name):
3761 pass # Duration already captured above
3763 async def update_tool(
3764 self,
3765 db: Session,
3766 tool_id: str,
3767 tool_update: ToolUpdate,
3768 modified_by: Optional[str] = None,
3769 modified_from_ip: Optional[str] = None,
3770 modified_via: Optional[str] = None,
3771 modified_user_agent: Optional[str] = None,
3772 user_email: Optional[str] = None,
3773 ) -> ToolRead:
3774 """
3775 Update an existing tool.
3777 Args:
3778 db (Session): The SQLAlchemy database session.
3779 tool_id (str): The unique identifier of the tool.
3780 tool_update (ToolUpdate): Tool update schema with new data.
3781 modified_by (Optional[str]): Username who modified this tool.
3782 modified_from_ip (Optional[str]): IP address of modifier.
3783 modified_via (Optional[str]): Modification method (ui, api).
3784 modified_user_agent (Optional[str]): User agent of modification request.
3785 user_email (Optional[str]): Email of user performing update (for ownership check).
3787 Returns:
3788 The updated ToolRead object.
3790 Raises:
3791 ToolNotFoundError: If the tool is not found.
3792 PermissionError: If user doesn't own the tool.
3793 IntegrityError: If there is a database integrity error.
3794 ToolNameConflictError: If a tool with the same name already exists.
3795 ToolError: For other update errors.
3797 Examples:
3798 >>> from mcpgateway.services.tool_service import ToolService
3799 >>> from unittest.mock import MagicMock, AsyncMock
3800 >>> from mcpgateway.schemas import ToolRead
3801 >>> service = ToolService()
3802 >>> db = MagicMock()
3803 >>> tool = MagicMock()
3804 >>> db.get.return_value = tool
3805 >>> db.commit = MagicMock()
3806 >>> db.refresh = MagicMock()
3807 >>> db.execute.return_value.scalar_one_or_none.return_value = None
3808 >>> service._notify_tool_updated = AsyncMock()
3809 >>> service.convert_tool_to_read = MagicMock(return_value='tool_read')
3810 >>> ToolRead.model_validate = MagicMock(return_value='tool_read')
3811 >>> import asyncio
3812 >>> asyncio.run(service.update_tool(db, 'tool_id', MagicMock()))
3813 'tool_read'
3814 """
3815 try:
3816 tool = get_for_update(db, DbTool, tool_id)
3818 if not tool:
3819 raise ToolNotFoundError(f"Tool not found: {tool_id}")
3821 old_tool_name = tool.name
3822 old_gateway_id = tool.gateway_id
3824 # Check ownership if user_email provided
3825 if user_email:
3826 # First-Party
3827 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel
3829 permission_service = PermissionService(db)
3830 if not await permission_service.check_resource_ownership(user_email, tool):
3831 raise PermissionError("Only the owner can update this tool")
3833 # Check for name change and ensure uniqueness
3834 if tool_update.name and tool_update.name != tool.name:
3835 # Check for existing tool with the same name and visibility
3836 if tool_update.visibility.lower() == "public":
3837 # Check for existing public tool with the same name (row-locked)
3838 existing_tool = get_for_update(
3839 db,
3840 DbTool,
3841 where=and_(
3842 DbTool.custom_name == tool_update.custom_name,
3843 DbTool.visibility == "public",
3844 DbTool.id != tool.id,
3845 ),
3846 )
3847 if existing_tool:
3848 raise ToolNameConflictError(existing_tool.custom_name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility)
3849 elif tool_update.visibility.lower() == "team" and tool_update.team_id:
3850 # Check for existing team tool with the same name
3851 existing_tool = get_for_update(
3852 db,
3853 DbTool,
3854 where=and_(
3855 DbTool.custom_name == tool_update.custom_name,
3856 DbTool.visibility == "team",
3857 DbTool.team_id == tool_update.team_id,
3858 DbTool.id != tool.id,
3859 ),
3860 )
3861 if existing_tool:
3862 raise ToolNameConflictError(existing_tool.custom_name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility)
3863 if tool_update.custom_name is None and tool.name == tool.custom_name:
3864 tool.custom_name = tool_update.name
3865 tool.name = tool_update.name
3867 if tool_update.custom_name is not None:
3868 tool.custom_name = tool_update.custom_name
3869 if tool_update.displayName is not None:
3870 tool.display_name = tool_update.displayName
3871 if tool_update.url is not None:
3872 tool.url = str(tool_update.url)
3873 if tool_update.description is not None:
3874 tool.description = tool_update.description
3875 if tool_update.integration_type is not None:
3876 tool.integration_type = tool_update.integration_type
3877 if tool_update.request_type is not None:
3878 tool.request_type = tool_update.request_type
3879 if tool_update.headers is not None:
3880 tool.headers = tool_update.headers
3881 if tool_update.input_schema is not None:
3882 tool.input_schema = tool_update.input_schema
3883 if tool_update.output_schema is not None:
3884 tool.output_schema = tool_update.output_schema
3885 if tool_update.annotations is not None:
3886 tool.annotations = tool_update.annotations
3887 if tool_update.jsonpath_filter is not None:
3888 tool.jsonpath_filter = tool_update.jsonpath_filter
3889 if tool_update.visibility is not None:
3890 tool.visibility = tool_update.visibility
3892 if tool_update.auth is not None:
3893 if tool_update.auth.auth_type is not None:
3894 tool.auth_type = tool_update.auth.auth_type
3895 if tool_update.auth.auth_value is not None:
3896 tool.auth_value = tool_update.auth.auth_value
3897 else:
3898 tool.auth_type = None
3900 # Update tags if provided
3901 if tool_update.tags is not None:
3902 tool.tags = tool_update.tags
3904 # Update modification metadata
3905 if modified_by is not None:
3906 tool.modified_by = modified_by
3907 if modified_from_ip is not None:
3908 tool.modified_from_ip = modified_from_ip
3909 if modified_via is not None:
3910 tool.modified_via = modified_via
3911 if modified_user_agent is not None:
3912 tool.modified_user_agent = modified_user_agent
3914 # Increment version
3915 if hasattr(tool, "version") and tool.version is not None:
3916 tool.version += 1
3917 else:
3918 tool.version = 1
3919 logger.info(f"Update tool: {tool.name} (output_schema: {tool.output_schema})")
3921 tool.updated_at = datetime.now(timezone.utc)
3922 db.commit()
3923 db.refresh(tool)
3924 await self._notify_tool_updated(tool)
3925 logger.info(f"Updated tool: {tool.name}")
3927 # Structured logging: Audit trail for tool update
3928 changes = []
3929 if tool_update.name:
3930 changes.append(f"name: {tool_update.name}")
3931 if tool_update.visibility:
3932 changes.append(f"visibility: {tool_update.visibility}")
3933 if tool_update.description:
3934 changes.append("description updated")
3936 audit_trail.log_action(
3937 user_id=user_email or modified_by or "system",
3938 action="update_tool",
3939 resource_type="tool",
3940 resource_id=tool.id,
3941 resource_name=tool.name,
3942 user_email=user_email,
3943 team_id=tool.team_id,
3944 client_ip=modified_from_ip,
3945 user_agent=modified_user_agent,
3946 new_values={
3947 "name": tool.name,
3948 "display_name": tool.display_name,
3949 "version": tool.version,
3950 },
3951 context={
3952 "modified_via": modified_via,
3953 "changes": ", ".join(changes) if changes else "metadata only",
3954 },
3955 db=db,
3956 )
3958 # Structured logging: Log successful tool update
3959 structured_logger.log(
3960 level="INFO",
3961 message="Tool updated successfully",
3962 event_type="tool_updated",
3963 component="tool_service",
3964 user_id=modified_by,
3965 user_email=user_email,
3966 team_id=tool.team_id,
3967 resource_type="tool",
3968 resource_id=tool.id,
3969 custom_fields={
3970 "tool_name": tool.name,
3971 "version": tool.version,
3972 },
3973 db=db,
3974 )
3976 # Invalidate cache after successful update
3977 cache = _get_registry_cache()
3978 await cache.invalidate_tools()
3979 tool_lookup_cache = _get_tool_lookup_cache()
3980 await tool_lookup_cache.invalidate(old_tool_name, gateway_id=str(old_gateway_id) if old_gateway_id else None)
3981 await tool_lookup_cache.invalidate(tool.name, gateway_id=str(tool.gateway_id) if tool.gateway_id else None)
3982 # Also invalidate tags cache since tool tags may have changed
3983 # First-Party
3984 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel
3986 await admin_stats_cache.invalidate_tags()
3988 return self.convert_tool_to_read(tool, requesting_user_email=getattr(tool, "owner_email", None))
3989 except PermissionError as pe:
3990 db.rollback()
3992 # Structured logging: Log permission error
3993 structured_logger.log(
3994 level="WARNING",
3995 message="Tool update failed due to permission error",
3996 event_type="tool_update_permission_denied",
3997 component="tool_service",
3998 user_email=user_email,
3999 resource_type="tool",
4000 resource_id=tool_id,
4001 error=pe,
4002 db=db,
4003 )
4004 raise
4005 except IntegrityError as ie:
4006 db.rollback()
4007 logger.error(f"IntegrityError during tool update: {ie}")
4009 # Structured logging: Log database integrity error
4010 structured_logger.log(
4011 level="ERROR",
4012 message="Tool update failed due to database integrity error",
4013 event_type="tool_update_failed",
4014 component="tool_service",
4015 user_id=modified_by,
4016 user_email=user_email,
4017 resource_type="tool",
4018 resource_id=tool_id,
4019 error=ie,
4020 db=db,
4021 )
4022 raise ie
4023 except ToolNotFoundError as tnfe:
4024 db.rollback()
4025 logger.error(f"Tool not found during update: {tnfe}")
4027 # Structured logging: Log not found error
4028 structured_logger.log(
4029 level="ERROR",
4030 message="Tool update failed - tool not found",
4031 event_type="tool_not_found",
4032 component="tool_service",
4033 user_email=user_email,
4034 resource_type="tool",
4035 resource_id=tool_id,
4036 error=tnfe,
4037 db=db,
4038 )
4039 raise tnfe
4040 except ToolNameConflictError as tnce:
4041 db.rollback()
4042 logger.error(f"Tool name conflict during update: {tnce}")
4044 # Structured logging: Log name conflict error
4045 structured_logger.log(
4046 level="WARNING",
4047 message="Tool update failed due to name conflict",
4048 event_type="tool_name_conflict",
4049 component="tool_service",
4050 user_id=modified_by,
4051 user_email=user_email,
4052 resource_type="tool",
4053 resource_id=tool_id,
4054 error=tnce,
4055 db=db,
4056 )
4057 raise tnce
4058 except Exception as ex:
4059 db.rollback()
4061 # Structured logging: Log generic tool update failure
4062 structured_logger.log(
4063 level="ERROR",
4064 message="Tool update failed",
4065 event_type="tool_update_failed",
4066 component="tool_service",
4067 user_id=modified_by,
4068 user_email=user_email,
4069 resource_type="tool",
4070 resource_id=tool_id,
4071 error=ex,
4072 db=db,
4073 )
4074 raise ToolError(f"Failed to update tool: {str(ex)}")
4076 async def _notify_tool_updated(self, tool: DbTool) -> None:
4077 """
4078 Notify subscribers of tool update.
4080 Args:
4081 tool: Tool updated
4082 """
4083 event = {
4084 "type": "tool_updated",
4085 "data": {"id": tool.id, "name": tool.name, "url": tool.url, "description": tool.description, "enabled": tool.enabled},
4086 "timestamp": datetime.now(timezone.utc).isoformat(),
4087 }
4088 await self._publish_event(event)
4090 async def _notify_tool_activated(self, tool: DbTool) -> None:
4091 """
4092 Notify subscribers of tool activation.
4094 Args:
4095 tool: Tool activated
4096 """
4097 event = {
4098 "type": "tool_activated",
4099 "data": {"id": tool.id, "name": tool.name, "enabled": tool.enabled, "reachable": tool.reachable},
4100 "timestamp": datetime.now(timezone.utc).isoformat(),
4101 }
4102 await self._publish_event(event)
4104 async def _notify_tool_deactivated(self, tool: DbTool) -> None:
4105 """
4106 Notify subscribers of tool deactivation.
4108 Args:
4109 tool: Tool deactivated
4110 """
4111 event = {
4112 "type": "tool_deactivated",
4113 "data": {"id": tool.id, "name": tool.name, "enabled": tool.enabled, "reachable": tool.reachable},
4114 "timestamp": datetime.now(timezone.utc).isoformat(),
4115 }
4116 await self._publish_event(event)
4118 async def _notify_tool_offline(self, tool: DbTool) -> None:
4119 """
4120 Notify subscribers that tool is offline.
4122 Args:
4123 tool: Tool database object
4124 """
4125 event = {
4126 "type": "tool_offline",
4127 "data": {
4128 "id": tool.id,
4129 "name": tool.name,
4130 "enabled": True,
4131 "reachable": False,
4132 },
4133 "timestamp": datetime.now(timezone.utc).isoformat(),
4134 }
4135 await self._publish_event(event)
4137 async def _notify_tool_deleted(self, tool_info: Dict[str, Any]) -> None:
4138 """
4139 Notify subscribers of tool deletion.
4141 Args:
4142 tool_info: Dictionary on tool deleted
4143 """
4144 event = {
4145 "type": "tool_deleted",
4146 "data": tool_info,
4147 "timestamp": datetime.now(timezone.utc).isoformat(),
4148 }
4149 await self._publish_event(event)
4151 async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
4152 """Subscribe to tool events via the EventService.
4154 Yields:
4155 Tool event messages.
4156 """
4157 async for event in self._event_service.subscribe_events():
4158 yield event
4160 async def _notify_tool_added(self, tool: DbTool) -> None:
4161 """
4162 Notify subscribers of tool addition.
4164 Args:
4165 tool: Tool added
4166 """
4167 event = {
4168 "type": "tool_added",
4169 "data": {
4170 "id": tool.id,
4171 "name": tool.name,
4172 "url": tool.url,
4173 "description": tool.description,
4174 "enabled": tool.enabled,
4175 },
4176 "timestamp": datetime.now(timezone.utc).isoformat(),
4177 }
4178 await self._publish_event(event)
4180 async def _notify_tool_removed(self, tool: DbTool) -> None:
4181 """
4182 Notify subscribers of tool removal (soft delete/deactivation).
4184 Args:
4185 tool: Tool removed
4186 """
4187 event = {
4188 "type": "tool_removed",
4189 "data": {"id": tool.id, "name": tool.name, "enabled": tool.enabled},
4190 "timestamp": datetime.now(timezone.utc).isoformat(),
4191 }
4192 await self._publish_event(event)
    async def _publish_event(self, event: Dict[str, Any]) -> None:
        """
        Publish event to all subscribers via the EventService.

        Args:
            event: Event to publish
        """
        # Fan-out to subscribers is fully delegated to the shared EventService.
        await self._event_service.publish_event(event)
4203 async def _validate_tool_url(self, url: str) -> None:
4204 """Validate tool URL is accessible.
4206 Args:
4207 url: URL to validate.
4209 Raises:
4210 ToolValidationError: If URL validation fails.
4211 """
4212 try:
4213 response = await self._http_client.get(url)
4214 response.raise_for_status()
4215 except Exception as e:
4216 raise ToolValidationError(f"Failed to validate tool URL: {str(e)}")
4218 async def _check_tool_health(self, tool: DbTool) -> bool:
4219 """Check if tool endpoint is healthy.
4221 Args:
4222 tool: Tool to check.
4224 Returns:
4225 True if tool is healthy.
4226 """
4227 try:
4228 response = await self._http_client.get(tool.url)
4229 return response.is_success
4230 except Exception:
4231 return False
4233 # async def event_generator(self) -> AsyncGenerator[Dict[str, Any], None]:
4234 # """Generate tool events for SSE.
4236 # Yields:
4237 # Tool events.
4238 # """
4239 # queue: asyncio.Queue = asyncio.Queue()
4240 # self._event_subscribers.append(queue)
4241 # try:
4242 # while True:
4243 # event = await queue.get()
4244 # yield event
4245 # finally:
4246 # self._event_subscribers.remove(queue)
4248 # --- Metrics ---
4249 async def aggregate_metrics(self, db: Session) -> Dict[str, Any]:
4250 """
4251 Aggregate metrics for all tool invocations across all tools.
4253 Combines recent raw metrics (within retention period) with historical
4254 hourly rollups for complete historical coverage. Uses in-memory caching
4255 (10s TTL) to reduce database load under high request rates.
4257 Args:
4258 db: Database session
4260 Returns:
4261 Aggregated metrics computed from raw ToolMetric + ToolMetricsHourly.
4263 Examples:
4264 >>> from mcpgateway.services.tool_service import ToolService
4265 >>> service = ToolService()
4266 >>> # Method exists and is callable
4267 >>> callable(service.aggregate_metrics)
4268 True
4269 """
4270 # Check cache first (if enabled)
4271 # First-Party
4272 from mcpgateway.cache.metrics_cache import is_cache_enabled, metrics_cache # pylint: disable=import-outside-toplevel
4274 if is_cache_enabled():
4275 cached = metrics_cache.get("tools")
4276 if cached is not None:
4277 return cached
4279 # Use combined raw + rollup query for full historical coverage
4280 # First-Party
4281 from mcpgateway.services.metrics_query_service import aggregate_metrics_combined # pylint: disable=import-outside-toplevel
4283 result = aggregate_metrics_combined(db, "tool")
4284 metrics = result.to_dict()
4286 # Cache the result (if enabled)
4287 if is_cache_enabled():
4288 metrics_cache.set("tools", metrics)
4290 return metrics
4292 async def reset_metrics(self, db: Session, tool_id: Optional[int] = None) -> None:
4293 """
4294 Reset all tool metrics by deleting raw and hourly rollup records.
4296 Args:
4297 db: Database session
4298 tool_id: Optional tool ID to reset metrics for a specific tool
4300 Examples:
4301 >>> from mcpgateway.services.tool_service import ToolService
4302 >>> from unittest.mock import MagicMock
4303 >>> service = ToolService()
4304 >>> db = MagicMock()
4305 >>> db.execute = MagicMock()
4306 >>> db.commit = MagicMock()
4307 >>> import asyncio
4308 >>> asyncio.run(service.reset_metrics(db))
4309 """
4311 if tool_id:
4312 db.execute(delete(ToolMetric).where(ToolMetric.tool_id == tool_id))
4313 db.execute(delete(ToolMetricsHourly).where(ToolMetricsHourly.tool_id == tool_id))
4314 else:
4315 db.execute(delete(ToolMetric))
4316 db.execute(delete(ToolMetricsHourly))
4317 db.commit()
4319 # Invalidate metrics cache
4320 # First-Party
4321 from mcpgateway.cache.metrics_cache import metrics_cache # pylint: disable=import-outside-toplevel
4323 metrics_cache.invalidate("tools")
4324 metrics_cache.invalidate_prefix("top_tools:")
    async def create_tool_from_a2a_agent(
        self,
        db: Session,
        agent: DbA2AAgent,
        created_by: Optional[str] = None,
        created_from_ip: Optional[str] = None,
        created_via: Optional[str] = None,
        created_user_agent: Optional[str] = None,
    ) -> DbTool:
        """Create a tool entry from an A2A agent for virtual server integration.

        Idempotent: if a tool named ``a2a_<slug>`` already exists it is
        returned unchanged instead of creating a duplicate.

        Args:
            db: Database session.
            agent: A2A agent to create tool from.
            created_by: Username who created this tool.
            created_from_ip: IP address of creator.
            created_via: Creation method.
            created_user_agent: User agent of creation request.

        Returns:
            The created tool database object.

        Raises:
            ToolNameConflictError: If a tool with the same name already exists.
        """
        # Check if tool already exists for this agent
        tool_name = f"a2a_{agent.slug}"
        existing_query = select(DbTool).where(DbTool.original_name == tool_name)
        existing_tool = db.execute(existing_query).scalar_one_or_none()

        if existing_tool:
            # Tool already exists, return it
            return existing_tool

        # Create tool entry for the A2A agent
        logger.debug(f"agent.tags: {agent.tags} for agent: {agent.name} (ID: {agent.id})")

        # Normalize tags: if agent.tags contains dicts like {'id':..,'label':..},
        # extract the human-friendly label. If tags are already strings, keep them.
        normalized_tags: list[str] = []
        for t in agent.tags or []:
            if isinstance(t, dict):
                # Prefer 'label', fall back to 'id' or stringified dict
                normalized_tags.append(t.get("label") or t.get("id") or str(t))
            elif hasattr(t, "label"):
                # Tag-like object (e.g. a Pydantic model) exposing .label
                normalized_tags.append(getattr(t, "label"))
            else:
                normalized_tags.append(str(t))

        # Ensure we include identifying A2A tags
        normalized_tags = normalized_tags + ["a2a", "agent"]

        tool_data = ToolCreate(
            name=tool_name,
            displayName=generate_display_name(agent.name),
            url=agent.endpoint_url,
            description=f"A2A Agent: {agent.description or agent.name}",
            integration_type="A2A",  # Special integration type for A2A agents
            request_type="POST",
            # Minimal schema: A2A agents are invoked with a single free-text query.
            input_schema={
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "User query", "default": "Hello from MCP Gateway Admin UI test!"},
                },
                "required": ["query"],
            },
            allow_auto=True,
            # Annotations carry the back-reference used later to resolve the agent
            # (see _invoke_a2a_tool, which reads "a2a_agent_id").
            annotations={
                "title": f"A2A Agent: {agent.name}",
                "a2a_agent_id": agent.id,
                "a2a_agent_type": agent.agent_type,
            },
            auth_type=agent.auth_type,
            auth_value=agent.auth_value,
            tags=normalized_tags,
        )

        # Default to "public" visibility if agent visibility is not set
        # This ensures A2A tools are visible in the Global Tools Tab
        tool_visibility = agent.visibility or "public"

        tool_read = await self.register_tool(
            db,
            tool_data,
            created_by=created_by,
            created_from_ip=created_from_ip,
            created_via=created_via or "a2a_integration",
            created_user_agent=created_user_agent,
            team_id=agent.team_id,
            owner_email=agent.owner_email,
            visibility=tool_visibility,
        )

        # Return the DbTool object for relationship assignment
        # NOTE(review): db.get may return None if the row disappeared between
        # register_tool and this lookup; callers appear to assume it exists — verify.
        tool_db = db.get(DbTool, tool_read.id)
        return tool_db
4423 async def update_tool_from_a2a_agent(
4424 self,
4425 db: Session,
4426 agent: DbA2AAgent,
4427 modified_by: Optional[str] = None,
4428 modified_from_ip: Optional[str] = None,
4429 modified_via: Optional[str] = None,
4430 modified_user_agent: Optional[str] = None,
4431 ) -> Optional[ToolRead]:
4432 """Update the tool associated with an A2A agent when the agent is updated.
4434 Args:
4435 db: Database session.
4436 agent: Updated A2A agent.
4437 modified_by: Username who modified this tool.
4438 modified_from_ip: IP address of modifier.
4439 modified_via: Modification method.
4440 modified_user_agent: User agent of modification request.
4442 Returns:
4443 The updated tool, or None if no associated tool exists.
4444 """
4445 # Use the tool_id from the agent for efficient lookup
4446 if not agent.tool_id:
4447 logger.debug(f"No tool_id found for A2A agent {agent.id}, skipping tool update")
4448 return None
4450 tool = db.get(DbTool, agent.tool_id)
4451 if not tool:
4452 logger.warning(f"Tool {agent.tool_id} not found for A2A agent {agent.id}, resetting tool_id")
4453 agent.tool_id = None
4454 db.commit()
4455 return None
4457 # Normalize tags: if agent.tags contains dicts like {'id':..,'label':..},
4458 # extract the human-friendly label. If tags are already strings, keep them.
4459 normalized_tags: list[str] = []
4460 for t in agent.tags or []:
4461 if isinstance(t, dict):
4462 # Prefer 'label', fall back to 'id' or stringified dict
4463 normalized_tags.append(t.get("label") or t.get("id") or str(t))
4464 elif hasattr(t, "label"):
4465 normalized_tags.append(getattr(t, "label"))
4466 else:
4467 normalized_tags.append(str(t))
4469 # Ensure we include identifying A2A tags
4470 normalized_tags = normalized_tags + ["a2a", "agent"]
4472 # Prepare update data matching the agent's current state
4473 # IMPORTANT: Preserve the existing tool's visibility to avoid unintentionally
4474 # making private/team tools public (ToolUpdate defaults to "public")
4475 # Note: team_id is not a field on ToolUpdate schema, so team assignment is preserved
4476 # implicitly by not changing visibility (team tools stay team-scoped)
4477 new_tool_name = f"a2a_{agent.slug}"
4478 tool_update = ToolUpdate(
4479 name=new_tool_name,
4480 custom_name=new_tool_name, # Also set custom_name to ensure name update works
4481 displayName=generate_display_name(agent.name),
4482 url=agent.endpoint_url,
4483 description=f"A2A Agent: {agent.description or agent.name}",
4484 auth=AuthenticationValues(auth_type=agent.auth_type, auth_value=agent.auth_value) if agent.auth_type else None,
4485 tags=normalized_tags,
4486 visibility=tool.visibility, # Preserve existing visibility
4487 )
4489 # Update the tool
4490 return await self.update_tool(
4491 db=db,
4492 tool_id=tool.id,
4493 tool_update=tool_update,
4494 modified_by=modified_by,
4495 modified_from_ip=modified_from_ip,
4496 modified_via=modified_via or "a2a_sync",
4497 modified_user_agent=modified_user_agent,
4498 )
4500 async def delete_tool_from_a2a_agent(self, db: Session, agent: DbA2AAgent, user_email: Optional[str] = None, purge_metrics: bool = False) -> None:
4501 """Delete the tool associated with an A2A agent when the agent is deleted.
4503 Args:
4504 db: Database session.
4505 agent: The A2A agent being deleted.
4506 user_email: Email of user performing delete (for ownership check).
4507 purge_metrics: If True, delete raw + rollup metrics for this tool.
4508 """
4509 # Use the tool_id from the agent for efficient lookup
4510 if not agent.tool_id:
4511 logger.debug(f"No tool_id found for A2A agent {agent.id}, skipping tool deletion")
4512 return
4514 tool = db.get(DbTool, agent.tool_id)
4515 if not tool:
4516 logger.warning(f"Tool {agent.tool_id} not found for A2A agent {agent.id}")
4517 return
4519 # Delete the tool
4520 await self.delete_tool(db=db, tool_id=tool.id, user_email=user_email, purge_metrics=purge_metrics)
4521 logger.info(f"Deleted tool {tool.id} associated with A2A agent {agent.id}")
4523 async def _invoke_a2a_tool(self, db: Session, tool: DbTool, arguments: Dict[str, Any]) -> ToolResult:
4524 """Invoke an A2A agent through its corresponding tool.
4526 Args:
4527 db: Database session.
4528 tool: The tool record that represents the A2A agent.
4529 arguments: Tool arguments.
4531 Returns:
4532 Tool result from A2A agent invocation.
4534 Raises:
4535 ToolNotFoundError: If the A2A agent is not found.
4536 """
4538 # Extract A2A agent ID from tool annotations
4539 agent_id = tool.annotations.get("a2a_agent_id")
4540 if not agent_id:
4541 raise ToolNotFoundError(f"A2A tool '{tool.name}' missing agent ID in annotations")
4543 # Get the A2A agent
4544 agent_query = select(DbA2AAgent).where(DbA2AAgent.id == agent_id)
4545 agent = db.execute(agent_query).scalar_one_or_none()
4547 if not agent:
4548 raise ToolNotFoundError(f"A2A agent not found for tool '{tool.name}' (agent ID: {agent_id})")
4550 if not agent.enabled:
4551 raise ToolNotFoundError(f"A2A agent '{agent.name}' is disabled")
4553 # Force-load all attributes needed by _call_a2a_agent before detaching
4554 # (accessing them ensures they're loaded into the object's __dict__)
4555 _ = (agent.name, agent.endpoint_url, agent.agent_type, agent.protocol_version, agent.auth_type, agent.auth_value, agent.auth_query_params)
4557 # Detach agent from session so its loaded data remains accessible after close
4558 db.expunge(agent)
4560 # CRITICAL: Release DB connection back to pool BEFORE making HTTP calls
4561 # This prevents "idle in transaction" connection pool exhaustion under load
4562 db.commit()
4563 db.close()
4565 # Prepare parameters for A2A invocation
4566 try:
4567 # Make the A2A agent call (agent is now detached but data is loaded)
4568 response_data = await self._call_a2a_agent(agent, arguments)
4570 # Convert A2A response to MCP ToolResult format
4571 if isinstance(response_data, dict) and "response" in response_data:
4572 content = [TextContent(type="text", text=str(response_data["response"]))]
4573 else:
4574 content = [TextContent(type="text", text=str(response_data))]
4576 result = ToolResult(content=content, is_error=False)
4578 except Exception as e:
4579 error_message = str(e)
4580 content = [TextContent(type="text", text=f"A2A agent error: {error_message}")]
4581 result = ToolResult(content=content, is_error=True)
4583 # Note: Metrics are recorded by the calling invoke_tool method, not here
4584 return result
    async def _call_a2a_agent(self, agent: DbA2AAgent, parameters: Dict[str, Any]):
        """Call an A2A agent directly.

        Builds a protocol-appropriate request body (JSONRPC for
        generic/jsonrpc agents, a flat custom payload otherwise), applies the
        agent's configured authentication, and POSTs to the agent endpoint.

        Args:
            agent: The A2A agent to call.
            parameters: Parameters for the interaction.

        Returns:
            Response from the A2A agent.

        Raises:
            Exception: If the call fails.
        """
        logger.info(f"Calling A2A agent '{agent.name}' at {agent.endpoint_url} with arguments: {parameters}")

        # Build request data based on agent type
        # NOTE(review): a trailing "/" on the endpoint URL also routes into the
        # JSONRPC branch — presumably a heuristic for standard A2A servers; confirm.
        if agent.agent_type in ["generic", "jsonrpc"] or agent.endpoint_url.endswith("/"):
            # JSONRPC agents: Convert flat query to nested message structure
            params = None
            if isinstance(parameters, dict) and "query" in parameters and isinstance(parameters["query"], str):
                # Build the nested message object for JSONRPC protocol
                message_id = f"admin-test-{int(time.time())}"
                # A2A v0.3.x: message.parts use "kind" (not "type").
                params = {
                    "message": {
                        "kind": "message",
                        "messageId": message_id,
                        "role": "user",
                        "parts": [{"kind": "text", "text": parameters["query"]}],
                    }
                }
                method = parameters.get("method", "message/send")
            else:
                # Already in correct format or unknown, pass through
                # NOTE(review): assumes parameters is a dict here; a non-dict would
                # raise AttributeError on .get — confirm callers always pass dicts.
                params = parameters.get("params", parameters)
                method = parameters.get("method", "message/send")

            try:
                request_data = {"jsonrpc": "2.0", "method": method, "params": params, "id": 1}
                logger.info(f"invoke tool JSONRPC request_data prepared: {request_data}")
            except Exception as e:
                logger.error(f"Error preparing JSONRPC request data: {e}")
                raise
        else:
            # Custom agents: Pass parameters directly without JSONRPC message conversion
            # Custom agents expect flat fields like {"query": "...", "message": "..."}
            params = parameters if isinstance(parameters, dict) else {}
            logger.info(f"invoke tool Using custom A2A format for A2A agent '{params}'")
            request_data = {"interaction_type": params.get("interaction_type", "query"), "parameters": params, "protocol_version": agent.protocol_version}
            logger.info(f"invoke tool request_data prepared: {request_data}")

        # Make HTTP request to the agent endpoint using shared HTTP client
        # First-Party
        from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

        client = await get_http_client()
        headers = {"Content-Type": "application/json"}

        # Determine the endpoint URL (may be modified for query_param auth)
        endpoint_url = agent.endpoint_url

        # Add authentication if configured
        if agent.auth_type == "api_key" and agent.auth_value:
            # api_key and bearer are sent identically as a Bearer header.
            headers["Authorization"] = f"Bearer {agent.auth_value}"
        elif agent.auth_type == "bearer" and agent.auth_value:
            headers["Authorization"] = f"Bearer {agent.auth_value}"
        elif agent.auth_type == "query_param" and agent.auth_query_params:
            # Handle query parameter authentication (imports at top: decode_auth, apply_query_param_auth, sanitize_url_for_logging)
            auth_query_params_decrypted: dict[str, str] = {}
            for param_key, encrypted_value in agent.auth_query_params.items():
                if encrypted_value:
                    try:
                        decrypted = decode_auth(encrypted_value)
                        auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
                    except Exception:
                        # Best-effort: a param that fails to decrypt is skipped, not fatal.
                        logger.debug(f"Failed to decrypt query param for key '{param_key}'")
            if auth_query_params_decrypted:
                endpoint_url = apply_query_param_auth(endpoint_url, auth_query_params_decrypted)
                # Log sanitized URL to avoid credential leakage
                sanitized_url = sanitize_url_for_logging(endpoint_url, auth_query_params_decrypted)
                logger.debug(f"Applied query param auth to A2A agent endpoint: {sanitized_url}")

        http_response = await client.post(endpoint_url, json=request_data, headers=headers)

        if http_response.status_code == 200:
            return http_response.json()

        # Any non-200 status is surfaced to the caller as a plain Exception.
        raise Exception(f"HTTP {http_response.status_code}: {http_response.text}")
# Lazy singleton - created on first access, not at module import time.
# This avoids instantiation when only exception classes are imported.
_tool_service_instance = None  # pylint: disable=invalid-name


def __getattr__(name: str):
    """Module-level __getattr__ for lazy singleton creation.

    Args:
        name: The attribute name being accessed.

    Returns:
        The tool_service singleton instance if name is "tool_service".

    Raises:
        AttributeError: If the attribute name is not "tool_service".
    """
    global _tool_service_instance  # pylint: disable=global-statement
    # Guard clause: anything other than the singleton name is a normal miss.
    if name != "tool_service":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    if _tool_service_instance is None:
        _tool_service_instance = ToolService()
    return _tool_service_instance