Coverage for mcpgateway / services / tool_service.py: 98%

2225 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/tool_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Tool Service Implementation. 

8This module implements tool management and invocation according to the MCP specification. 

9It handles: 

10- Tool registration and validation 

11- Tool invocation with schema validation 

12- Tool federation across gateways 

13- Event notifications for tool changes 

14- Active/inactive tool management 

15""" 

16 

17# Standard 

18import asyncio 

19import base64 

20import binascii 

21from datetime import datetime, timezone 

22from functools import lru_cache 

23import json # NOTE: httpx uses stdlib json, not orjson, so response.json() raises json.JSONDecodeError 

24import os 

25import re 

26import ssl 

27import time 

28from types import SimpleNamespace 

29from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union 

30from urllib.parse import parse_qs, urlparse 

31import uuid 

32 

33# Third-Party 

34import anyio 

35import httpx 

36import jq 

37import jsonschema 

38from jsonschema import Draft4Validator, Draft6Validator, Draft7Validator, validators 

39from mcp import ClientSession, types 

40from mcp.client.sse import sse_client 

41from mcp.client.streamable_http import streamablehttp_client 

42import orjson 

43from pydantic import ValidationError 

44from sqlalchemy import and_, delete, desc, or_, select 

45from sqlalchemy.exc import IntegrityError, OperationalError 

46from sqlalchemy.orm import joinedload, selectinload, Session 

47 

48# First-Party 

49from mcpgateway.cache.global_config_cache import global_config_cache 

50from mcpgateway.common.models import Gateway as PydanticGateway 

51from mcpgateway.common.models import TextContent 

52from mcpgateway.common.models import Tool as PydanticTool 

53from mcpgateway.common.models import ToolResult 

54from mcpgateway.common.validators import SecurityValidator 

55from mcpgateway.config import settings 

56from mcpgateway.db import A2AAgent as DbA2AAgent 

57from mcpgateway.db import fresh_db_session 

58from mcpgateway.db import Gateway as DbGateway 

59from mcpgateway.db import get_for_update, server_tool_association 

60from mcpgateway.db import Tool as DbTool 

61from mcpgateway.db import ToolMetric, ToolMetricsHourly 

62from mcpgateway.observability import create_child_span, create_span, inject_trace_context_headers, set_span_attribute, set_span_error 

63from mcpgateway.plugins.framework import ( 

64 get_plugin_manager, 

65 GlobalContext, 

66 HttpHeaderPayload, 

67 PluginContextTable, 

68 PluginError, 

69 PluginManager, 

70 PluginViolationError, 

71 ToolHookType, 

72 ToolPostInvokePayload, 

73 ToolPreInvokePayload, 

74) 

75from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA 

76from mcpgateway.schemas import AuthenticationValues, ToolCreate, ToolMetrics, ToolRead, ToolUpdate, TopPerformer 

77from mcpgateway.services.audit_trail_service import get_audit_trail_service 

78from mcpgateway.services.base_service import BaseService 

79from mcpgateway.services.event_service import EventService 

80from mcpgateway.services.logging_service import LoggingService 

81from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, TransportType 

82from mcpgateway.services.metrics_buffer_service import get_metrics_buffer_service 

83from mcpgateway.services.metrics_cleanup_service import delete_metrics_in_batches, pause_rollup_during_purge 

84from mcpgateway.services.metrics_query_service import get_top_performers_combined 

85from mcpgateway.services.oauth_manager import OAuthManager 

86from mcpgateway.services.observability_service import current_trace_id, ObservabilityService 

87from mcpgateway.services.performance_tracker import get_performance_tracker 

88from mcpgateway.services.structured_logger import get_structured_logger 

89from mcpgateway.services.team_management_service import TeamManagementService 

90from mcpgateway.utils.correlation_id import get_correlation_id 

91from mcpgateway.utils.create_slug import slugify 

92from mcpgateway.utils.display_name import generate_display_name 

93from mcpgateway.utils.gateway_access import build_gateway_auth_headers, check_gateway_access, extract_gateway_id_from_headers 

94from mcpgateway.utils.metrics_common import build_top_performers 

95from mcpgateway.utils.pagination import decode_cursor, encode_cursor, unified_paginate 

96from mcpgateway.utils.passthrough_headers import compute_passthrough_headers_cached 

97from mcpgateway.utils.retry_manager import ResilientHttpClient 

98from mcpgateway.utils.services_auth import decode_auth, encode_auth 

99from mcpgateway.utils.sqlalchemy_modifier import json_contains_tag_expr 

100from mcpgateway.utils.ssl_context_cache import get_cached_ssl_context 

101from mcpgateway.utils.trace_context import format_trace_team_scope 

102from mcpgateway.utils.trace_redaction import is_input_capture_enabled, is_output_capture_enabled, serialize_trace_payload 

103from mcpgateway.utils.url_auth import apply_query_param_auth, sanitize_exception_message, sanitize_url_for_logging 

104from mcpgateway.utils.validate_signature import validate_signature 

105 

106# Cache import (lazy to avoid circular dependencies) 

107_REGISTRY_CACHE = None 

108_TOOL_LOOKUP_CACHE = None 

109 

110 

111def _get_registry_cache(): 

112 """Get registry cache singleton lazily. 

113 

114 Returns: 

115 RegistryCache instance. 

116 """ 

117 global _REGISTRY_CACHE # pylint: disable=global-statement 

118 if _REGISTRY_CACHE is None: 

119 # First-Party 

120 from mcpgateway.cache.registry_cache import registry_cache # pylint: disable=import-outside-toplevel 

121 

122 _REGISTRY_CACHE = registry_cache 

123 return _REGISTRY_CACHE 

124 

125 

126def _get_tool_lookup_cache(): 

127 """Get tool lookup cache singleton lazily. 

128 

129 Returns: 

130 ToolLookupCache instance. 

131 """ 

132 global _TOOL_LOOKUP_CACHE # pylint: disable=global-statement 

133 if _TOOL_LOOKUP_CACHE is None: 

134 # First-Party 

135 from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache # pylint: disable=import-outside-toplevel 

136 

137 _TOOL_LOOKUP_CACHE = tool_lookup_cache 

138 return _TOOL_LOOKUP_CACHE 

139 

140 

141# Initialize logging service first 

logging_service = LoggingService()
# Module logger shared by the helper functions and ToolService below.
logger = logging_service.get_logger(__name__)

# Initialize performance tracker, structured logger, audit trail, and metrics buffer for tool operations.
# NOTE(review): these come from get_* accessors and are presumably shared
# process-wide instances — confirm in each service module.
perf_tracker = get_performance_tracker()
structured_logger = get_structured_logger("tool_service")
audit_trail = get_audit_trail_service()
metrics_buffer = get_metrics_buffer_service()

150 

# Key of the single-entry mapping that marks a stored header value as an
# encrypted envelope: {_ENCRYPTED_TOOL_HEADER_VALUE_KEY: "<encode_auth output>"}.
_ENCRYPTED_TOOL_HEADER_VALUE_KEY = "_mcpgateway_encrypted_header_value_v1"
# Key under which the plaintext header value is placed before encode_auth().
_TOOL_HEADER_DATA_KEY = "data"
# Legacy key that is also accepted when decoding an envelope payload.
_TOOL_HEADER_LEGACY_VALUE_KEY = "value"
# Header names whose values are encrypted before persistence. Names are
# stripped and lower-cased by _is_sensitive_tool_header_name before matching.
_SENSITIVE_TOOL_HEADER_PATTERNS = (
    re.compile(r"^authorization$", re.IGNORECASE),
    re.compile(r"^proxy-authorization$", re.IGNORECASE),
    re.compile(r"^x-api-key$", re.IGNORECASE),
    re.compile(r"^api-key$", re.IGNORECASE),
    re.compile(r"^apikey$", re.IGNORECASE),
    # Keep broad-enough auth matching while avoiding operational noise from
    # non-secret tracing/idempotency headers (e.g. X-Correlation-Token).
    # Matches e.g. x-auth-token, x-api-secret, x-access_key.
    re.compile(r"^x-(?:auth|api|access|refresh|client|bearer|session|security)[-_]?(?:token|secret|key)$", re.IGNORECASE),
    # Same families without the "x-" prefix, e.g. auth-token, apikey, session_secret.
    re.compile(r"^(?:auth|api|access|refresh|client|bearer|session|security)[-_]?(?:token|secret|key)$", re.IGNORECASE),
)

165 

166 

167def _is_sensitive_tool_header_name(name: str) -> bool: 

168 """Return whether a tool header name should be treated as sensitive. 

169 

170 Args: 

171 name: Header name to evaluate. 

172 

173 Returns: 

174 ``True`` when header value should be protected. 

175 """ 

176 normalized_name = str(name).strip().lower() 

177 return any(pattern.match(normalized_name) for pattern in _SENSITIVE_TOOL_HEADER_PATTERNS) 

178 

179 

180def _is_encrypted_tool_header_value(value: Any) -> bool: 

181 """Return whether a header value uses encrypted envelope format. 

182 

183 Args: 

184 value: Header value candidate. 

185 

186 Returns: 

187 ``True`` when value is an encrypted envelope mapping. 

188 """ 

189 return isinstance(value, dict) and isinstance(value.get(_ENCRYPTED_TOOL_HEADER_VALUE_KEY), str) 

190 

191 

def _encrypt_tool_header_value(value: Any, existing_value: Any = None) -> Any:
    """Turn one sensitive header value into its encrypted storage form.

    Behaviour by input:

    * ``None`` / ``""`` are returned unchanged (header cleared).
    * ``settings.masked_auth_value`` keeps the previously stored value,
      encrypting it first if it was still plaintext, or yields ``None`` when
      nothing was stored.
    * An existing encrypted envelope is returned as-is (no double encryption).
    * Anything else is stringified and wrapped in a new envelope.

    Args:
        value: Incoming header value from payload.
        existing_value: Existing stored value used for masked-value merges.

    Returns:
        Encrypted envelope, preserved existing value, or ``None`` when cleared.
    """
    if value is None or value == "":
        return value

    if value == settings.masked_auth_value:
        # The incoming value is the mask placeholder, so fall back to what is stored.
        if _is_encrypted_tool_header_value(existing_value):
            return existing_value
        if existing_value in (None, ""):
            return None
        # The stored value is still plaintext, so encrypt it now.
        return _encrypt_tool_header_value(existing_value, None)

    if _is_encrypted_tool_header_value(value):
        return value

    plaintext_payload = {_TOOL_HEADER_DATA_KEY: str(value)}
    return {_ENCRYPTED_TOOL_HEADER_VALUE_KEY: encode_auth(plaintext_payload)}

217 

218 

219def _protect_tool_headers_for_storage(headers: Optional[Dict[str, Any]], existing_headers: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]: 

220 """Encrypt sensitive tool header values before persistence. 

221 

222 Args: 

223 headers: Incoming tool headers payload. 

224 existing_headers: Existing stored headers used for masked-value merges. 

225 

226 Returns: 

227 Header mapping with sensitive values protected for storage, or ``None``. 

228 """ 

229 if headers is None: 

230 return None 

231 if not isinstance(headers, dict): 

232 return None 

233 

234 existing_by_lower: Dict[str, Any] = {} 

235 if isinstance(existing_headers, dict): 

236 for key, existing_value in existing_headers.items(): 

237 existing_by_lower[str(key).strip().lower()] = existing_value 

238 

239 protected: Dict[str, Any] = {} 

240 for key, value in headers.items(): 

241 if _is_sensitive_tool_header_name(key): 

242 existing_value = existing_by_lower.get(str(key).strip().lower()) 

243 protected[key] = _encrypt_tool_header_value(value, existing_value) 

244 else: 

245 protected[key] = value 

246 return protected 

247 

248 

def _decrypt_tool_header_value(value: Any) -> Any:
    """Unwrap an encrypted header envelope back into its plaintext value.

    Non-envelope values, empty envelopes, undecodable payloads, and payloads
    without a recognised key are all returned unchanged. Decoding failures are
    logged as warnings rather than raised.

    Args:
        value: Stored header value, possibly encrypted.

    Returns:
        Decrypted plain value when envelope is valid, else original value.
    """
    if not _is_encrypted_tool_header_value(value):
        return value

    token = value.get(_ENCRYPTED_TOOL_HEADER_VALUE_KEY)
    if not token:
        return value

    try:
        payload = decode_auth(token)
        if isinstance(payload, dict):
            # Prefer the current key and fall back to the legacy one.
            for field in (_TOOL_HEADER_DATA_KEY, _TOOL_HEADER_LEGACY_VALUE_KEY):
                if field in payload:
                    return payload[field]
    except Exception as exc:
        logger.warning("Failed to decrypt tool header value: %s", exc)
    return value

275 

276 

277def _decrypt_tool_headers_for_runtime(headers: Optional[Dict[str, Any]]) -> Dict[str, Any]: 

278 """Decrypt tool header map for runtime outbound requests. 

279 

280 Args: 

281 headers: Stored header mapping. 

282 

283 Returns: 

284 Header mapping with encrypted values decrypted where possible. 

285 """ 

286 if not isinstance(headers, dict): 

287 return {} 

288 return {key: _decrypt_tool_header_value(value) for key, value in headers.items()} 

289 

290 

def _handle_json_parse_error(response, error, is_error_response: bool = False) -> dict:
    """Build a fallback result when a response body is not valid JSON.

    A warning is logged in every case. The raw body is returned under
    ``response_text``, capped at ``settings.rest_response_text_max_length``
    characters so that large or sensitive bodies are not echoed in full.

    Args:
        response: The HTTP response object with .text attribute
        error: The exception that was raised during JSON parsing
        is_error_response: If True, logs as "error response", else "response"

    Returns:
        ``{"response_text": <possibly truncated body>}``, or
        ``{"error": "Empty response body"}`` when the body is empty/None.
    """
    label = "error response" if is_error_response else "response"
    body = response.text
    if not body:
        logger.warning(f"Failed to parse JSON {label}: {error}. Response body was empty.")
        return {"error": "Empty response body"}

    limit = settings.rest_response_text_max_length
    full_length = len(body)
    if full_length > limit:
        logger.warning(f"Failed to parse JSON {label}: {error}. Response truncated from {full_length} to {limit} characters.")
        return {"response_text": body[:limit]}

    logger.warning(f"Failed to parse JSON {label}: {error}")
    return {"response_text": body}

316 

317 

@lru_cache(maxsize=256)
def _compile_jq_filter(jq_filter: str):
    """Compile a jq program, memoising up to 256 distinct filter strings.

    Args:
        jq_filter: The jq filter string to compile.

    Returns:
        Compiled jq program object.

    Raises:
        ValueError: If the jq filter is invalid.
    """
    # pylint: disable=c-extension-no-member
    compiled_program = jq.compile(jq_filter)
    return compiled_program

333 

334 

@lru_cache(maxsize=128)
def _get_validator_class_and_check(schema_json: str) -> Tuple[type, dict]:
    """Cache schema validation and validator class selection.

    This caches the expensive operations:
    1. Deserializing the schema
    2. Selecting the appropriate validator class based on $schema
    3. Checking the schema is valid

    Supports multiple JSON Schema drafts by using fallback validators when the
    auto-detected validator fails. This handles schemas using older draft features
    (e.g., Draft 4 style exclusiveMinimum: true) that are invalid in newer drafts.

    Only successful results are cached (``lru_cache`` does not memoise
    exceptions), so each class is checked at most once per call.

    Args:
        schema_json: Canonical JSON string of the schema (used as cache key).

    Returns:
        Tuple of (validator_class, schema_dict) ready for instantiation.

    Raises:
        jsonschema.exceptions.SchemaError: If neither the auto-detected validator
            nor any fallback draft accepts the schema. The error raised is the
            one reported by the auto-detected validator.
    """
    schema = orjson.loads(schema_json)

    # First try auto-detection based on $schema
    detected_cls = validators.validator_for(schema)
    try:
        detected_cls.check_schema(schema)
        return detected_cls, schema
    except jsonschema.exceptions.SchemaError as exc:
        # Keep a reference: the ``as`` name is unbound when the except block ends.
        detected_error = exc

    # Fallback: try older drafts that may accept schemas with legacy features
    # (e.g., Draft 4/6 style boolean exclusiveMinimum/exclusiveMaximum)
    for fallback_cls in (Draft7Validator, Draft6Validator, Draft4Validator):
        if fallback_cls is detected_cls:
            continue  # Already rejected above; checking again cannot succeed.
        try:
            fallback_cls.check_schema(schema)
        except jsonschema.exceptions.SchemaError:
            continue
        return fallback_cls, schema

    # No validator accepts the schema. Surface the auto-detected validator's
    # error instead of re-running its check a second time.
    raise detected_error

377 

378 

def _canonicalize_schema(schema: dict) -> str:
    """Serialize *schema* deterministically so equal schemas share a cache key.

    Args:
        schema: The JSON Schema dictionary.

    Returns:
        Canonical JSON string with sorted keys.
    """
    canonical_bytes = orjson.dumps(schema, option=orjson.OPT_SORT_KEYS)
    return canonical_bytes.decode()

389 

390 

def _validate_with_cached_schema(instance: Any, schema: dict) -> None:
    """Validate *instance* against *schema*, reusing the cached validator class.

    The validator class selection and schema check are served from the LRU
    cache keyed by the canonical schema JSON. A new validator object is built
    per call so instances are never shared between callers. The reported error
    is chosen with ``best_match``, mirroring ``jsonschema.validate()``.

    Args:
        instance: The data to validate.
        schema: The JSON Schema to validate against.

    Raises:
        error: The best matching ValidationError from jsonschema validation.
        jsonschema.exceptions.ValidationError: If validation fails.
        jsonschema.exceptions.SchemaError: If the schema itself is invalid.
    """
    cache_key = _canonicalize_schema(schema)
    validator_cls, checked_schema = _get_validator_class_and_check(cache_key)
    error = jsonschema.exceptions.best_match(validator_cls(checked_schema).iter_errors(instance))
    if error is None:
        return
    raise error

415 

416 

def extract_using_jq(data, jq_filter=""):
    """
    Extracts data from a given input (string, dict, or list) using a jq filter string.

    Uses cached compiled jq programs for performance. The filter is stripped of
    surrounding whitespace before it is checked and compiled, so the program
    that runs is exactly the one that was validated, and equivalent filters
    share one cache entry. A missing, blank, or email-looking filter returns
    *data* unchanged.

    Args:
        data (str, dict, list): The input JSON data. Can be a string, dict, or list.
        jq_filter (str): The jq filter string to extract the desired data.

    Returns:
        The result of applying the jq filter to the input data.

    Examples:
        >>> extract_using_jq('{"a": 1, "b": 2}', '.a')
        [1]
        >>> extract_using_jq({'a': 1, 'b': 2}, '.b')
        [2]
        >>> extract_using_jq('[{"a": 1}, {"a": 2}]', '.[].a')
        [1, 2]
        >>> extract_using_jq('not a json', '.a')
        ['Invalid JSON string provided.']
        >>> extract_using_jq({'a': 1}, '')
        {'a': 1}
        >>> extract_using_jq({'a': 1}, '   ')
        {'a': 1}
    """
    # A missing or whitespace-only filter means "no filtering".
    jq_filter_str = str(jq_filter).strip() if jq_filter else ""
    if not jq_filter_str:
        return data

    # Check if it looks like an email address (common mistake when jsonpath_filter
    # field contains corrupted data). Intentionally simple regex to avoid false
    # positives with valid jq expressions like .foo|.bar
    if re.match(r"^[^.\[\]|]+@[^.\[\]|]+\.[^.\[\]|]+$", jq_filter_str):
        logger.warning(f"Invalid jq filter (email address): {jq_filter_str}. Treating as empty filter.")
        return data

    if isinstance(data, str):
        # If the input is a string, parse it as JSON
        try:
            data = orjson.loads(data)
        except orjson.JSONDecodeError:
            return ["Invalid JSON string provided."]
    elif not isinstance(data, (dict, list)):
        return ["Input data must be a JSON string, dictionary, or list."]

    # Apply the normalized jq filter using the cached compiled program
    try:
        program = _compile_jq_filter(jq_filter_str)
        result = program.input(data).all()
        if result == [None]:
            return [TextContent(type="text", text="Error applying jsonpath filter")]
    except Exception as e:
        message = "Error applying jsonpath filter: " + str(e)
        return [TextContent(type="text", text=message)]

    return result

481 

482 

class ToolError(Exception):
    """Common base class for the tool-service exceptions defined in this module.

    Catching ``ToolError`` handles any of its subclasses (not found, name
    conflict, lock conflict, validation, invocation, timeout).

    Examples:
        >>> from mcpgateway.services.tool_service import ToolError
        >>> err = ToolError("Something went wrong")
        >>> str(err)
        'Something went wrong'
    """

492 

493 

class ToolNotFoundError(ToolError):
    """Signals that a requested tool could not be located.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolNotFoundError
        >>> err = ToolNotFoundError("Tool xyz not found")
        >>> str(err)
        'Tool xyz not found'
        >>> isinstance(err, ToolError)
        True
    """

505 

506 

class ToolNameConflictError(ToolError):
    """Raised when a tool name conflicts with existing (active or inactive) tool."""

    def __init__(self, name: str, enabled: bool = True, tool_id: Optional[int] = None, visibility: str = "public"):
        """Initialize the error with tool information.

        Args:
            name: The conflicting tool name.
            enabled: Whether the existing tool is enabled or not.
            tool_id: ID of the existing tool if available.
            visibility: Visibility of the existing tool. ``"team"`` and
                ``"private"`` get their own labels in the message; any other
                value is reported as public.

        Examples:
            >>> from mcpgateway.services.tool_service import ToolNameConflictError
            >>> err = ToolNameConflictError('test_tool', enabled=False, tool_id=123)
            >>> str(err)
            'Public Tool already exists with name: test_tool (currently inactive, ID: 123)'
            >>> err.name
            'test_tool'
            >>> err.enabled
            False
            >>> err.tool_id
            123
            >>> err.visibility
            'public'
            >>> str(ToolNameConflictError('test_tool', visibility='team'))
            'Team-level Tool already exists with name: test_tool'
            >>> str(ToolNameConflictError('test_tool', visibility='private'))
            'Private Tool already exists with name: test_tool'
        """
        self.name = name
        self.enabled = enabled
        self.tool_id = tool_id
        # Exposed like the other constructor arguments so handlers can inspect it.
        self.visibility = visibility
        if visibility == "team":
            vis_label = "Team-level"
        elif visibility == "private":
            vis_label = "Private"
        else:
            vis_label = "Public"
        message = f"{vis_label} Tool already exists with name: {name}"
        if not enabled:
            # Point the caller at the inactive tool so it can be re-activated instead.
            message += f" (currently inactive, ID: {tool_id})"
        super().__init__(message)

544 

545 

class ToolLockConflictError(ToolError):
    """Signals that another transaction currently holds the lock on a tool row.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolLockConflictError
        >>> isinstance(ToolLockConflictError("locked"), ToolError)
        True
    """

548 

549 

class ToolValidationError(ToolError):
    """Signals that a tool definition or its configuration failed validation.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolValidationError
        >>> err = ToolValidationError("Invalid tool configuration")
        >>> str(err)
        'Invalid tool configuration'
        >>> isinstance(err, ToolError)
        True
    """

561 

562 

class ToolInvocationError(ToolError):
    """Signals that calling a tool failed.

    ``ToolTimeoutError`` specialises this for timeouts.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolInvocationError
        >>> err = ToolInvocationError("Tool execution failed")
        >>> str(err)
        'Tool execution failed'
        >>> isinstance(err, ToolError)
        True
        >>> # The message carries whatever detail the caller supplies
        >>> detailed_err = ToolInvocationError("Network timeout after 30 seconds")
        >>> "timeout" in str(detailed_err)
        True
    """

580 

581 

class ToolTimeoutError(ToolInvocationError):
    """Invocation error raised specifically when a tool call times out.

    Kept distinct from other invocation errors because the timeout handlers
    already call tool_post_invoke before raising it. The generic exception
    handler should therefore not call post_invoke again, which would
    double-count the failure.

    Attributes:
        retry_delay_ms: Delay in milliseconds requested by the retry plugin.
            ``0`` (default) means no retry. The timeout handler sets it after
            running the post-invoke hook, so the outer catch block can act on
            the retry signal without invoking post_invoke a second time.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolTimeoutError
        >>> err = ToolTimeoutError("timed out", retry_delay_ms=250)
        >>> (str(err), err.retry_delay_ms)
        ('timed out', 250)
        >>> ToolTimeoutError("timed out").retry_delay_ms
        0
    """

    def __init__(self, message: str, retry_delay_ms: int = 0) -> None:
        """Store the message and the retry delay reported by the post-invoke hook.

        Args:
            message: Human-readable error description.
            retry_delay_ms: Milliseconds the gateway should wait before retrying.
        """
        self.retry_delay_ms = retry_delay_ms
        super().__init__(message)

605 

606 

607class ToolService(BaseService): 

608 """Service for managing and invoking tools. 

609 

610 Handles: 

611 - Tool registration and deregistration. 

612 - Tool invocation and validation. 

613 - Tool federation. 

614 - Event notifications. 

615 - Active/inactive tool management. 

616 """ 

617 

    # ORM model class consulted by BaseService — presumably for its shared
    # visibility/access filtering helpers; confirm in base_service.
    _visibility_model_cls = DbTool

619 

620 def __init__(self) -> None: 

621 """Initialize the tool service. 

622 

623 Examples: 

624 >>> from mcpgateway.services.tool_service import ToolService 

625 >>> service = ToolService() 

626 >>> isinstance(service._event_service, EventService) 

627 True 

628 >>> hasattr(service, '_http_client') 

629 True 

630 """ 

631 self._event_service = EventService(channel_name="mcpgateway:tool_events") 

632 self._http_client = ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify}) 

633 self._plugin_manager: PluginManager | None = get_plugin_manager() 

634 self.oauth_manager = OAuthManager( 

635 request_timeout=int(settings.oauth_request_timeout if hasattr(settings, "oauth_request_timeout") else 30), 

636 max_retries=int(settings.oauth_max_retries if hasattr(settings, "oauth_max_retries") else 3), 

637 ) 

638 

639 async def initialize(self) -> None: 

640 """Initialize the service. 

641 

642 Examples: 

643 >>> from mcpgateway.services.tool_service import ToolService 

644 >>> service = ToolService() 

645 >>> import asyncio 

646 >>> asyncio.run(service.initialize()) # Should log "Initializing tool service" 

647 """ 

648 logger.info("Initializing tool service") 

649 await self._event_service.initialize() 

650 

651 async def shutdown(self) -> None: 

652 """Shutdown the service. 

653 

654 Examples: 

655 >>> from mcpgateway.services.tool_service import ToolService 

656 >>> service = ToolService() 

657 >>> import asyncio 

658 >>> asyncio.run(service.shutdown()) # Should log "Tool service shutdown complete" 

659 """ 

660 await self._http_client.aclose() 

661 await self._event_service.shutdown() 

662 logger.info("Tool service shutdown complete") 

663 

664 async def get_top_tools(self, db: Session, limit: Optional[int] = 5, include_deleted: bool = False) -> List[TopPerformer]: 

665 """Retrieve the top-performing tools based on execution count. 

666 

667 Queries the database to get tools with their metrics, ordered by the number of executions 

668 in descending order. Returns a list of TopPerformer objects containing tool details and 

669 performance metrics. Results are cached for performance. 

670 

671 Args: 

672 db (Session): Database session for querying tool metrics. 

673 limit (Optional[int]): Maximum number of tools to return. Defaults to 5. 

674 include_deleted (bool): Whether to include deleted tools from rollups. 

675 

676 Returns: 

677 List[TopPerformer]: A list of TopPerformer objects, each containing: 

678 - id: Tool ID. 

679 - name: Tool name. 

680 - execution_count: Total number of executions. 

681 - avg_response_time: Average response time in seconds, or None if no metrics. 

682 - success_rate: Success rate percentage, or None if no metrics. 

683 - last_execution: Timestamp of the last execution, or None if no metrics. 

684 """ 

685 # Check cache first (if enabled) 

686 # First-Party 

687 from mcpgateway.cache.metrics_cache import is_cache_enabled, metrics_cache # pylint: disable=import-outside-toplevel 

688 

689 effective_limit = limit or 5 

690 cache_key = f"top_tools:{effective_limit}:include_deleted={include_deleted}" 

691 

692 if is_cache_enabled(): 

693 cached = metrics_cache.get(cache_key) 

694 if cached is not None: 

695 return cached 

696 

697 # Use combined query that includes both raw metrics and rollup data 

698 results = get_top_performers_combined( 

699 db=db, 

700 metric_type="tool", 

701 entity_model=DbTool, 

702 limit=effective_limit, 

703 include_deleted=include_deleted, 

704 ) 

705 top_performers = build_top_performers(results) 

706 

707 # Cache the result (if enabled) 

708 if is_cache_enabled(): 

709 metrics_cache.set(cache_key, top_performers) 

710 

711 return top_performers 

712 

713 def _build_tool_cache_payload(self, tool: DbTool, gateway: Optional[DbGateway]) -> Dict[str, Any]: 

714 """Build cache payload for tool lookup by name. 

715 

716 Args: 

717 tool: Tool ORM instance. 

718 gateway: Optional gateway ORM instance. 

719 

720 Returns: 

721 Cache payload dict for tool lookup. 

722 """ 

723 tool_payload = { 

724 "id": str(tool.id), 

725 "name": tool.name, 

726 "original_name": tool.original_name, 

727 "url": tool.url, 

728 "description": tool.description, 

729 "original_description": tool.original_description, 

730 "integration_type": tool.integration_type, 

731 "request_type": tool.request_type, 

732 "headers": tool.headers or {}, 

733 "input_schema": tool.input_schema or {"type": "object", "properties": {}}, 

734 "output_schema": tool.output_schema, 

735 "annotations": tool.annotations or {}, 

736 "auth_type": tool.auth_type, 

737 "jsonpath_filter": tool.jsonpath_filter, 

738 "custom_name": tool.custom_name, 

739 "custom_name_slug": tool.custom_name_slug, 

740 "display_name": tool.display_name, 

741 "gateway_id": str(tool.gateway_id) if tool.gateway_id else None, 

742 "enabled": bool(tool.enabled), 

743 "reachable": bool(tool.reachable), 

744 "tags": tool.tags or [], 

745 "team_id": tool.team_id, 

746 "owner_email": tool.owner_email, 

747 "visibility": tool.visibility, 

748 } 

749 

750 gateway_payload = None 

751 if gateway: 

752 gateway_payload = { 

753 "id": str(gateway.id), 

754 "name": gateway.name, 

755 "url": gateway.url, 

756 "description": gateway.description, 

757 "slug": gateway.slug, 

758 "transport": gateway.transport, 

759 "capabilities": gateway.capabilities or {}, 

760 "passthrough_headers": gateway.passthrough_headers or [], 

761 "auth_type": gateway.auth_type, 

762 "ca_certificate": getattr(gateway, "ca_certificate", None), 

763 "ca_certificate_sig": getattr(gateway, "ca_certificate_sig", None), 

764 "enabled": bool(gateway.enabled), 

765 "reachable": bool(gateway.reachable), 

766 "team_id": gateway.team_id, 

767 "owner_email": gateway.owner_email, 

768 "visibility": gateway.visibility, 

769 "tags": gateway.tags or [], 

770 "gateway_mode": getattr(gateway, "gateway_mode", "cache"), # Gateway mode for direct proxy support 

771 "client_cert": getattr(gateway, "client_cert", None), 

772 "client_key": getattr(gateway, "client_key", None), 

773 } 

774 

775 return {"status": "active", "tool": tool_payload, "gateway": gateway_payload} 

776 

777 def _pydantic_tool_from_payload(self, tool_payload: Dict[str, Any]) -> Optional[PydanticTool]: 

778 """Build Pydantic tool metadata from cache payload. 

779 

780 Args: 

781 tool_payload: Cached tool payload dict. 

782 

783 Returns: 

784 Pydantic tool metadata or None if validation fails. 

785 """ 

786 try: 

787 return PydanticTool.model_validate(tool_payload) 

788 except Exception as exc: 

789 logger.debug("Failed to build PydanticTool from cache payload: %s", exc) 

790 return None 

791 

792 def _pydantic_gateway_from_payload(self, gateway_payload: Dict[str, Any]) -> Optional[PydanticGateway]: 

793 """Build Pydantic gateway metadata from cache payload. 

794 

795 Args: 

796 gateway_payload: Cached gateway payload dict. 

797 

798 Returns: 

799 Pydantic gateway metadata or None if validation fails. 

800 """ 

801 try: 

802 return PydanticGateway.model_validate(gateway_payload) 

803 except Exception as exc: 

804 logger.debug("Failed to build PydanticGateway from cache payload: %s", exc) 

805 return None 

806 

807 async def _check_tool_access( 

808 self, 

809 db: Session, 

810 tool_payload: Dict[str, Any], 

811 user_email: Optional[str], 

812 token_teams: Optional[List[str]], 

813 ) -> bool: 

814 """Check if user has access to a tool based on visibility rules. 

815 

816 Implements the same access control logic as list_tools() for consistency. 

817 

818 Access Rules: 

819 - Public tools: Accessible by all authenticated users 

820 - Team tools: Accessible by team members (team_id in user's teams) 

821 - Private tools: Accessible only by owner (owner_email matches) 

822 

823 Args: 

824 db: Database session for team membership lookup if needed. 

825 tool_payload: Tool data dict with visibility, team_id, owner_email. 

826 user_email: Email of the requesting user (None = unauthenticated). 

827 token_teams: List of team IDs from token. 

828 - None = unrestricted admin access 

829 - [] = public-only token 

830 - [...] = team-scoped token 

831 

832 Returns: 

833 True if access is allowed, False otherwise. 

834 """ 

835 visibility = tool_payload.get("visibility", "public") 

836 tool_team_id = tool_payload.get("team_id") 

837 tool_owner_email = tool_payload.get("owner_email") 

838 

839 # Public tools are accessible by everyone 

840 if visibility == "public": 

841 return True 

842 

843 # Admin bypass: token_teams=None AND user_email=None means unrestricted admin 

844 # This happens when is_admin=True and no team scoping in token 

845 if token_teams is None and user_email is None: 

846 return True 

847 

848 # No user context (but not admin) = deny access to non-public tools 

849 if not user_email: 

850 return False 

851 

852 # Public-only tokens (empty teams array) can ONLY access public tools 

853 is_public_only_token = token_teams is not None and len(token_teams) == 0 

854 if is_public_only_token: 

855 return False # Already checked public above 

856 

857 # Owner can access their own private tools 

858 if visibility == "private" and tool_owner_email and tool_owner_email == user_email: 

859 return True 

860 

861 # Team tools: check team membership (matches list_tools behavior) 

862 if tool_team_id: 

863 # Use token_teams if provided, otherwise look up from DB 

864 if token_teams is not None: 

865 team_ids = token_teams 

866 else: 

867 team_service = TeamManagementService(db) 

868 user_teams = await team_service.get_user_teams(user_email) 

869 team_ids = [team.id for team in user_teams] 

870 

871 # Team/public visibility allows access if user is in the team 

872 if visibility in ["team", "public"] and tool_team_id in team_ids: 

873 return True 

874 

875 return False 

876 

def convert_tool_to_read(
    self,
    tool: DbTool,
    include_metrics: bool = False,
    include_auth: bool = True,
    requesting_user_email: Optional[str] = None,
    requesting_user_is_admin: bool = False,
    requesting_user_team_roles: Optional[Dict[str, str]] = None,
) -> ToolRead:
    """Converts a DbTool instance into a ToolRead model, including aggregated metrics and
    new API gateway fields: request_type and authentication credentials (masked).

    Args:
        tool (DbTool): The ORM instance of the tool.
        include_metrics (bool): Whether to include metrics in the result. Defaults to False.
        include_auth (bool): Whether to decode and include auth details. Defaults to True.
            When False, skips expensive AES-GCM decryption and returns minimal auth info.
        requesting_user_email (Optional[str]): Email of the requesting user for header masking.
        requesting_user_is_admin (bool): Whether the requester is an admin.
        requesting_user_team_roles (Optional[Dict[str, str]]): {team_id: role} for the requester.

    Returns:
        ToolRead: The Pydantic model representing the tool, including aggregated metrics and new fields.
    """
    # NOTE: This serves two purposes:
    # 1. It determines whether to decode auth (used later)
    # 2. It forces the tool object to lazily evaluate (required before copy)
    has_encrypted_auth = tool.auth_type and tool.auth_value

    # Copy the dict from the tool
    tool_dict = tool.__dict__.copy()
    tool_dict.pop("_sa_instance_state", None)

    # Compute metrics in a single pass (matches server/resource/prompt service pattern)
    if include_metrics:
        metrics = tool.metrics_summary  # Single-pass computation
        tool_dict["metrics"] = metrics
        tool_dict["execution_count"] = metrics["total_executions"]
    else:
        tool_dict["metrics"] = None
        tool_dict["execution_count"] = None

    tool_dict["request_type"] = tool.request_type
    tool_dict["annotations"] = tool.annotations or {}

    # Only decode auth if include_auth=True AND we have encrypted credentials
    if include_auth and has_encrypted_auth:
        decoded_auth_value = decode_auth(tool.auth_value)
        if tool.auth_type == "basic":
            decoded_bytes = base64.b64decode(decoded_auth_value["Authorization"].split("Basic ")[1])
            # RFC 7617: the user-id cannot contain ":" but the password may, so split only once.
            username, password = decoded_bytes.decode("utf-8").split(":", 1)
            tool_dict["auth"] = {
                "auth_type": "basic",
                "username": username,
                "password": settings.masked_auth_value if password else None,
            }
        elif tool.auth_type == "bearer":
            tool_dict["auth"] = {
                "auth_type": "bearer",
                "token": settings.masked_auth_value if decoded_auth_value["Authorization"] else None,
            }
        elif tool.auth_type == "authheaders":
            # Support multi-header format (list of {key, value} dicts)
            if decoded_auth_value:
                # Convert decoded dict to list format for frontend
                auth_headers = [
                    {
                        "key": key,
                        "value": settings.masked_auth_value if value else None,
                    }
                    for key, value in decoded_auth_value.items()
                ]
                # Also include legacy single-header fields for backward compatibility
                first_key = next(iter(decoded_auth_value))
                tool_dict["auth"] = {
                    "auth_type": "authheaders",
                    "authHeaders": auth_headers,  # Multi-header format (masked)
                    "auth_header_key": first_key,  # Legacy format
                    "auth_header_value": settings.masked_auth_value if decoded_auth_value[first_key] else None,  # Legacy format
                }
            else:
                tool_dict["auth"] = None
        else:
            tool_dict["auth"] = None
    elif not include_auth and has_encrypted_auth:
        # LIST VIEW: Minimal auth info without decryption
        # Only show auth_type for tools that have encrypted credentials
        tool_dict["auth"] = {"auth_type": tool.auth_type}
    else:
        # No encrypted auth (includes OAuth tools where auth_value=None)
        tool_dict["auth"] = None

    tool_dict["name"] = tool.name
    # Handle displayName with fallback and None checks
    display_name = getattr(tool, "display_name", None)
    custom_name = getattr(tool, "custom_name", tool.original_name)
    tool_dict["displayName"] = display_name or custom_name
    tool_dict["custom_name"] = custom_name
    tool_dict["gateway_slug"] = getattr(tool, "gateway_slug", "") or ""
    tool_dict["custom_name_slug"] = getattr(tool, "custom_name_slug", "") or ""
    tool_dict["tags"] = getattr(tool, "tags", []) or []
    tool_dict["team"] = getattr(tool, "team", None)

    # Mask custom headers unless the requester is allowed to modify this tool.
    # Safe default: if no requester context is provided, mask everything.
    headers = tool_dict.get("headers")
    if headers:
        headers = _decrypt_tool_headers_for_runtime(headers)
        tool_dict["headers"] = headers
        can_view = requesting_user_is_admin
        # Require a real requester identity for the owner check; otherwise an ownerless
        # tool compared against a missing requester (None == None) would leak headers.
        if not can_view and requesting_user_email and getattr(tool, "owner_email", None) == requesting_user_email:
            can_view = True
        if (
            not can_view
            and getattr(tool, "visibility", None) == "team"
            and getattr(tool, "team_id", None) is not None
            and requesting_user_team_roles
            and requesting_user_team_roles.get(str(tool.team_id)) == "owner"
        ):
            can_view = True
        if not can_view:
            tool_dict["headers"] = {k: settings.masked_auth_value for k in headers}

    return ToolRead.model_validate(tool_dict)

1002 

async def _record_tool_metric(self, db: Session, tool: DbTool, start_time: float, success: bool, error_message: Optional[str]) -> None:
    """
    Persist one invocation metric row for ``tool`` and commit it.

    The response time is the ``time.monotonic()`` delta since ``start_time``.
    A ``ToolMetric`` row is added to ``db`` and the session is committed
    immediately.

    Args:
        db (Session): The SQLAlchemy database session.
        tool (DbTool): The tool that was invoked; only its ``id`` is read.
        start_time (float): The monotonic start time of the invocation.
        success (bool): True if the invocation succeeded; otherwise, False.
        error_message (Optional[str]): The error message if the invocation failed, otherwise None.
    """
    elapsed = time.monotonic() - start_time
    db.add(
        ToolMetric(
            tool_id=tool.id,
            response_time=elapsed,
            is_success=success,
            error_message=error_message,
        )
    )
    db.commit()

1028 

def _record_tool_metric_by_id(
    self,
    db: Session,
    tool_id: str,
    start_time: float,
    success: bool,
    error_message: Optional[str],
) -> None:
    """Persist one invocation metric row keyed by tool ID and commit it.

    Unlike ``_record_tool_metric`` this takes the tool's ID rather than the ORM
    object, so it can run on a session other than the one that loaded the tool.

    Args:
        db: A fresh database session (not the request session).
        tool_id: The UUID string of the tool.
        start_time: The monotonic start time of the invocation.
        success: True if the invocation succeeded; otherwise, False.
        error_message: The error message if the invocation failed, otherwise None.
    """
    elapsed = time.monotonic() - start_time
    db.add(
        ToolMetric(
            tool_id=tool_id,
            response_time=elapsed,
            is_success=success,
            error_message=error_message,
        )
    )
    db.commit()

1060 

def _record_tool_metric_sync(
    self,
    tool_id: str,
    start_time: float,
    success: bool,
    error_message: Optional[str],
) -> None:
    """Record a tool metric on a dedicated, short-lived database session.

    Opens a session via ``fresh_db_session()`` and delegates to
    ``_record_tool_metric_by_id``. The session is released when the ``with``
    block exits. Per the original design this is meant to run via
    ``asyncio.to_thread()`` so the blocking DB work stays off the event loop.

    Args:
        tool_id: The UUID string of the tool.
        start_time: The monotonic start time of the invocation.
        success: True if the invocation succeeded; otherwise, False.
        error_message: The error message if the invocation failed, otherwise None.
    """
    with fresh_db_session() as metrics_session:
        self._record_tool_metric_by_id(metrics_session, tool_id=tool_id, start_time=start_time, success=success, error_message=error_message)

1088 

def _extract_and_validate_structured_content(self, tool: DbTool, tool_result: "ToolResult", candidate: Optional[Any] = None) -> bool:
    """
    Extract structured content (if any) and validate it against ``tool.output_schema``.

    Args:
        tool: The tool with an optional output schema to validate against.
        tool_result: The tool result containing content to validate.
        candidate: Optional structured payload to validate. If not provided, the method
            scans ``tool_result.content`` for a dict-shaped text item to parse as JSON.

    Behavior:
        - If no output schema is declared on the tool the method returns True (nothing to validate).
        - If ``candidate`` is provided it is used as the structured payload to validate.
        - Otherwise the method scans ``tool_result.content`` for the first plain ``dict`` item of
          the form ``{"type": "text", "text": ...}`` whose text parses as JSON (empty text is
          treated as ``null``) and uses the parsed value as the candidate. Items that are not
          plain dicts (e.g. ``TextContent`` model instances) are not inspected by this scan.
        - If no structured payload is found the method returns True.
        - A single-element list is unwrapped when the schema's top-level ``type`` is ``"object"``
          (parsing the element's ``text`` as JSON if it is a text-content dict).
        - Only on successful validation is the value attached to ``tool_result.structured_content``.
          When structured content is present and valid callers may drop textual ``content`` in favour
          of the structured payload.
        - On validation failure the method sets ``tool_result.content`` to a single ``TextContent``
          containing a compact JSON object describing the validation error, sets
          ``tool_result.is_error = True``, leaves ``structured_content`` untouched and returns False.

    Returns:
        True when the structured content is valid or when no schema is declared.
        False when validation fails.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from mcpgateway.common.models import TextContent, ToolResult
        >>> import json
        >>> service = ToolService()
        >>> # No schema declared -> nothing to validate
        >>> tool = type("T", (object,), {"output_schema": None})()
        >>> r = ToolResult(content=[TextContent(type="text", text='{"a":1}')])
        >>> service._extract_and_validate_structured_content(tool, r)
        True

        >>> # Valid candidate provided -> attaches structured_content and returns True
        >>> tool = type(
        ...     "T",
        ...     (object,),
        ...     {"output_schema": {"type": "object", "properties": {"foo": {"type": "string"}}, "required": ["foo"]}},
        ... )()
        >>> r = ToolResult(content=[])
        >>> service._extract_and_validate_structured_content(tool, r, candidate={"foo": "bar"})
        True
        >>> r.structured_content == {"foo": "bar"}
        True

        >>> # Invalid candidate -> returns False, marks result as error and emits details
        >>> tool = type(
        ...     "T",
        ...     (object,),
        ...     {"output_schema": {"type": "object", "properties": {"foo": {"type": "string"}}, "required": ["foo"]}},
        ... )()
        >>> r = ToolResult(content=[])
        >>> ok = service._extract_and_validate_structured_content(tool, r, candidate={"foo": 123})
        >>> ok
        False
        >>> r.is_error
        True
        >>> details = orjson.loads(r.content[0].text)
        >>> "received" in details
        True
    """
    try:
        output_schema = getattr(tool, "output_schema", None)
        # Nothing to do if the tool doesn't declare a schema
        if not output_schema:
            return True

        structured: Optional[Any] = None
        # Prefer explicit candidate
        if candidate is not None:
            structured = candidate
        else:
            # Try to parse the first dict-shaped text item as JSON
            for c in getattr(tool_result, "content", []) or []:
                try:
                    if isinstance(c, dict) and "type" in c and c.get("type") == "text" and "text" in c:
                        structured = orjson.loads(c.get("text") or "null")
                        break
                except (orjson.JSONDecodeError, TypeError, ValueError):
                    # ignore JSON parse errors and continue
                    continue

        # If no structured data found, treat as valid (nothing to validate)
        if structured is None:
            return True

        # Try to normalize common wrapper shapes to match schema expectations
        schema_type = output_schema.get("type") if isinstance(output_schema, dict) else None

        # Unwrap single-element list wrappers when schema expects object
        if isinstance(structured, list) and len(structured) == 1 and schema_type == "object":
            inner = structured[0]
            # If inner is a TextContent-like dict with 'text' JSON string, parse it
            if isinstance(inner, dict) and "text" in inner and "type" in inner and inner.get("type") == "text":
                try:
                    structured = orjson.loads(inner.get("text") or "null")
                except Exception:
                    # leave as-is if parsing fails
                    structured = inner
            else:
                structured = inner

        # Validate using cached schema validator
        try:
            _validate_with_cached_schema(structured, output_schema)
        except jsonschema.exceptions.ValidationError as e:
            details = {
                "code": getattr(e, "validator", "validation_error"),
                "expected": e.schema.get("type") if isinstance(e.schema, dict) and "type" in e.schema else None,
                "received": type(e.instance).__name__.lower() if e.instance is not None else None,
                "path": list(e.absolute_path) if hasattr(e, "absolute_path") else list(e.path or []),
                "message": e.message,
            }
            try:
                tool_result.content = [TextContent(type="text", text=orjson.dumps(details).decode())]
            except Exception:
                tool_result.content = [TextContent(type="text", text=str(details))]
            tool_result.is_error = True
            logger.debug(f"structured_content validation failed for tool {getattr(tool, 'name', '<unknown>')}: {details}")
            return False

        # Attach only after validation succeeds so an invalid payload is never
        # exposed as structured content alongside the error result.
        try:
            setattr(tool_result, "structured_content", structured)
        except Exception:
            logger.debug("Failed to set structured_content on ToolResult")
        return True
    except Exception as exc:  # pragma: no cover - defensive
        logger.error(f"Error extracting/validating structured_content: {exc}")
        return False

1228 

async def register_tool(
    self,
    db: Session,
    tool: ToolCreate,
    created_by: Optional[str] = None,
    created_from_ip: Optional[str] = None,
    created_via: Optional[str] = None,
    created_user_agent: Optional[str] = None,
    import_batch_id: Optional[str] = None,
    federation_source: Optional[str] = None,
    team_id: Optional[str] = None,
    owner_email: Optional[str] = None,
    visibility: Optional[str] = None,
) -> ToolRead:
    """Register a new tool with team support.

    Args:
        db: Database session.
        tool: Tool creation schema.
        created_by: Username who created this tool.
        created_from_ip: IP address of creator.
        created_via: Creation method (ui, api, import, federation).
        created_user_agent: User agent of creation request.
        import_batch_id: UUID for bulk import operations.
        federation_source: Source gateway for federated tools.
        team_id: Optional team ID to assign tool to (defaults to ``tool.team_id``).
        owner_email: Optional owner email for tool ownership (defaults to ``tool.owner_email``).
        visibility: Tool visibility (private, team, public). Defaults to
            ``tool.visibility`` or ``"public"``.

    Returns:
        Created tool information.

    Raises:
        IntegrityError: If there is a database integrity error.
        ToolNameConflictError: If a public tool with the same name exists, or (for team
            visibility with a team_id) a team tool with the same name exists in that team.
        ToolError: For other tool registration errors (the original exception is chained as the cause).

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from unittest.mock import MagicMock, AsyncMock
        >>> from mcpgateway.schemas import ToolRead
        >>> service = ToolService()
        >>> db = MagicMock()
        >>> tool = MagicMock()
        >>> tool.name = 'test'
        >>> db.execute.return_value.scalar_one_or_none.return_value = None
        >>> mock_gateway = MagicMock()
        >>> mock_gateway.name = 'test_gateway'
        >>> db.add = MagicMock()
        >>> db.commit = MagicMock()
        >>> def mock_refresh(obj):
        ...     obj.gateway = mock_gateway
        >>> db.refresh = MagicMock(side_effect=mock_refresh)
        >>> service._notify_tool_added = AsyncMock()
        >>> service.convert_tool_to_read = MagicMock(return_value='tool_read')
        >>> ToolRead.model_validate = MagicMock(return_value='tool_read')
        >>> import asyncio
        >>> asyncio.run(service.register_tool(db, tool))
        'tool_read'
    """
    try:
        if tool.auth is None:
            auth_type = None
            auth_value = None
        else:
            auth_type = tool.auth.auth_type
            auth_value = tool.auth.auth_value

        if team_id is None:
            team_id = tool.team_id

        if owner_email is None:
            owner_email = tool.owner_email

        if visibility is None:
            visibility = tool.visibility or "public"
        # Check for existing tool with the same name and visibility
        if visibility.lower() == "public":
            # Check for existing public tool with the same name
            existing_tool = db.execute(select(DbTool).where(DbTool.name == tool.name, DbTool.visibility == "public")).scalar_one_or_none()  # pylint: disable=comparison-with-callable
            if existing_tool:
                raise ToolNameConflictError(existing_tool.name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility)
        elif visibility.lower() == "team" and team_id:
            # Check for existing team tool with the same name, team_id
            existing_tool = db.execute(
                select(DbTool).where(DbTool.name == tool.name, DbTool.visibility == "team", DbTool.team_id == team_id)  # pylint: disable=comparison-with-callable
            ).scalar_one_or_none()
            if existing_tool:
                raise ToolNameConflictError(existing_tool.name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility)

        db_tool = DbTool(
            original_name=tool.name,
            custom_name=tool.name,
            custom_name_slug=slugify(tool.name),
            display_name=tool.displayName or tool.name,
            title=tool.title,
            url=str(tool.url),
            description=tool.description,
            original_description=tool.description,
            integration_type=tool.integration_type,
            request_type=tool.request_type,
            headers=_protect_tool_headers_for_storage(tool.headers),
            input_schema=tool.input_schema,
            output_schema=tool.output_schema,
            annotations=tool.annotations,
            jsonpath_filter=tool.jsonpath_filter,
            auth_type=auth_type,
            auth_value=auth_value,
            gateway_id=tool.gateway_id,
            tags=tool.tags or [],
            # Metadata fields
            created_by=created_by,
            created_from_ip=created_from_ip,
            created_via=created_via,
            created_user_agent=created_user_agent,
            import_batch_id=import_batch_id,
            federation_source=federation_source,
            version=1,
            # Team scoping fields
            team_id=team_id,
            owner_email=owner_email or created_by,
            visibility=visibility,
            # passthrough REST tools fields
            base_url=tool.base_url if tool.integration_type == "REST" else None,
            path_template=tool.path_template if tool.integration_type == "REST" else None,
            query_mapping=tool.query_mapping if tool.integration_type == "REST" else None,
            header_mapping=tool.header_mapping if tool.integration_type == "REST" else None,
            timeout_ms=tool.timeout_ms if tool.integration_type == "REST" else None,
            expose_passthrough=(tool.expose_passthrough if tool.integration_type == "REST" and tool.expose_passthrough is not None else True) if tool.integration_type == "REST" else None,
            allowlist=tool.allowlist if tool.integration_type == "REST" else None,
            plugin_chain_pre=tool.plugin_chain_pre if tool.integration_type == "REST" else None,
            plugin_chain_post=tool.plugin_chain_post if tool.integration_type == "REST" else None,
        )
        db.add(db_tool)
        db.commit()
        db.refresh(db_tool)
        await self._notify_tool_added(db_tool)

        # Structured logging: Audit trail for tool creation
        audit_trail.log_action(
            user_id=created_by or "system",
            action="create_tool",
            resource_type="tool",
            resource_id=db_tool.id,
            resource_name=db_tool.name,
            user_email=owner_email,
            team_id=team_id,
            client_ip=created_from_ip,
            user_agent=created_user_agent,
            new_values={
                "name": db_tool.name,
                "display_name": db_tool.display_name,
                "visibility": visibility,
                "integration_type": db_tool.integration_type,
            },
            context={
                "created_via": created_via,
                "import_batch_id": import_batch_id,
                "federation_source": federation_source,
            },
            db=db,
        )

        # Structured logging: Log successful tool creation
        structured_logger.log(
            level="INFO",
            message="Tool created successfully",
            event_type="tool_created",
            component="tool_service",
            user_id=created_by,
            user_email=owner_email,
            team_id=team_id,
            resource_type="tool",
            resource_id=db_tool.id,
            custom_fields={
                "tool_name": db_tool.name,
                "visibility": visibility,
                "integration_type": db_tool.integration_type,
            },
        )

        # Refresh db_tool after logging commits (they expire the session objects)
        db.refresh(db_tool)

        # Invalidate cache after successful creation
        cache = _get_registry_cache()
        await cache.invalidate_tools()
        tool_lookup_cache = _get_tool_lookup_cache()
        await tool_lookup_cache.invalidate(db_tool.name, gateway_id=str(db_tool.gateway_id) if db_tool.gateway_id else None)
        # Also invalidate tags cache since tool tags may have changed
        # First-Party
        from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

        await admin_stats_cache.invalidate_tags()

        return self.convert_tool_to_read(db_tool, requesting_user_email=getattr(db_tool, "owner_email", None))
    except IntegrityError as ie:
        db.rollback()
        logger.error(f"IntegrityError during tool registration: {ie}")

        # Structured logging: Log database integrity error
        structured_logger.log(
            level="ERROR",
            message="Tool creation failed due to database integrity error",
            event_type="tool_creation_failed",
            component="tool_service",
            user_id=created_by,
            user_email=owner_email,
            error=ie,
            custom_fields={
                "tool_name": tool.name,
            },
        )
        raise
    except ToolNameConflictError as tnce:
        db.rollback()
        logger.error(f"ToolNameConflictError during tool registration: {tnce}")

        # Structured logging: Log name conflict error
        structured_logger.log(
            level="WARNING",
            message="Tool creation failed due to name conflict",
            event_type="tool_name_conflict",
            component="tool_service",
            user_id=created_by,
            user_email=owner_email,
            custom_fields={
                "tool_name": tool.name,
                "visibility": visibility,
            },
        )
        raise
    except Exception as e:
        db.rollback()

        # Structured logging: Log generic tool creation failure
        structured_logger.log(
            level="ERROR",
            message="Tool creation failed",
            event_type="tool_creation_failed",
            component="tool_service",
            user_id=created_by,
            user_email=owner_email,
            error=e,
            custom_fields={
                "tool_name": tool.name,
            },
        )
        # Chain the original exception so the root cause survives in tracebacks.
        raise ToolError(f"Failed to register tool: {str(e)}") from e

1478 

async def register_tools_bulk(
    self,
    db: Session,
    tools: List[ToolCreate],
    created_by: Optional[str] = None,
    created_from_ip: Optional[str] = None,
    created_via: Optional[str] = None,
    created_user_agent: Optional[str] = None,
    import_batch_id: Optional[str] = None,
    federation_source: Optional[str] = None,
    team_id: Optional[str] = None,
    owner_email: Optional[str] = None,
    visibility: Optional[str] = "public",
    conflict_strategy: str = "skip",
) -> Dict[str, Any]:
    """Register multiple tools in bulk with a single commit.

    This method provides significant performance improvements over individual
    tool registration by:
    - Using db.add_all() instead of individual db.add() calls
    - Performing a single commit for all tools
    - Batch conflict detection
    - Chunking for very large imports (>500 items)

    The per-chunk work is delegated to ``_process_tool_chunk``; this method
    aggregates the per-chunk statistics and, for every chunk that created or
    updated tools, invalidates the registry, tool-lookup and tag caches.

    Args:
        db: Database session
        tools: List of tool creation schemas
        created_by: Username who created these tools
        created_from_ip: IP address of creator
        created_via: Creation method (ui, api, import, federation)
        created_user_agent: User agent of creation request
        import_batch_id: UUID for bulk import operations
        federation_source: Source gateway for federated tools
        team_id: Team ID to assign the tools to
        owner_email: Email of the user who owns these tools
        visibility: Tool visibility level (private, team, public)
        conflict_strategy: How to handle conflicts (skip, update, rename, fail)

    Returns:
        Dict with statistics:
        - created: Number of tools created
        - updated: Number of tools updated
        - skipped: Number of tools skipped
        - failed: Number of tools that failed
        - errors: List of error messages

    Raises:
        ToolError: If bulk registration fails critically

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from unittest.mock import MagicMock
        >>> service = ToolService()
        >>> db = MagicMock()
        >>> tools = [MagicMock(), MagicMock()]
        >>> import asyncio
        >>> try:
        ...     result = asyncio.run(service.register_tools_bulk(db, tools))
        ... except Exception:
        ...     pass
    """
    stats: Dict[str, Any] = {"created": 0, "updated": 0, "skipped": 0, "failed": 0, "errors": []}
    if not tools:
        return stats

    # Process in chunks to avoid memory issues and SQLite parameter limits
    chunk_size = 500

    for chunk_start in range(0, len(tools), chunk_size):
        chunk = tools[chunk_start : chunk_start + chunk_size]
        chunk_stats = self._process_tool_chunk(
            db=db,
            chunk=chunk,
            conflict_strategy=conflict_strategy,
            visibility=visibility,
            team_id=team_id,
            owner_email=owner_email,
            created_by=created_by,
            created_from_ip=created_from_ip,
            created_via=created_via,
            created_user_agent=created_user_agent,
            import_batch_id=import_batch_id,
            federation_source=federation_source,
        )

        # Aggregate stats
        for key, value in chunk_stats.items():
            if key == "errors":
                stats[key].extend(value)
            else:
                stats[key] += value

        if chunk_stats["created"] or chunk_stats["updated"]:
            cache = _get_registry_cache()
            await cache.invalidate_tools()
            tool_lookup_cache = _get_tool_lookup_cache()
            # Collect every distinct (name, gateway_id) pair in insertion order. Keying by
            # name alone would drop entries for same-named tools on different gateways
            # and leave their lookup-cache entries stale.
            invalidation_targets: Dict[Tuple[str, Optional[str]], None] = {}
            for tool in chunk:
                name = getattr(tool, "name", None)
                if not name:
                    continue
                gateway_id = getattr(tool, "gateway_id", None)
                invalidation_targets[(name, str(gateway_id) if gateway_id else None)] = None
            for tool_name, gateway_id in invalidation_targets:
                await tool_lookup_cache.invalidate(tool_name, gateway_id=gateway_id)
            # Also invalidate tags cache since tool tags may have changed
            # First-Party
            from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

            await admin_stats_cache.invalidate_tags()

    return stats

1592 

def _process_tool_chunk(
    self,
    db: Session,
    chunk: List[ToolCreate],
    conflict_strategy: str,
    visibility: str,
    team_id: Optional[int],
    owner_email: Optional[str],
    created_by: str,
    created_from_ip: Optional[str],
    created_via: Optional[str],
    created_user_agent: Optional[str],
    import_batch_id: Optional[str],
    federation_source: Optional[str],
) -> dict:
    """Process a chunk of tools for bulk import inside one database transaction.

    Existing tools whose names collide with tools in ``chunk`` are looked up in the
    scope selected by ``visibility`` (public tools, the given team's tools, or the
    owner's private tools). Each tool is classified by
    ``_process_single_tool_for_bulk``, new tools are added, and the whole chunk is
    committed at once.

    Args:
        db: The SQLAlchemy database session.
        chunk: List of ToolCreate objects to process.
        conflict_strategy: Strategy for handling conflicts ("skip", "update", "rename", or "fail").
        visibility: Tool visibility level ("public", "team", or "private").
        team_id: Team ID for team-scoped tools.
        owner_email: Email of the tool owner.
        created_by: Email of the user creating the tools.
        created_from_ip: IP address of the request origin.
        created_via: Source of the creation (e.g., "api", "ui").
        created_user_agent: User agent string from the request.
        import_batch_id: Batch identifier for bulk imports.
        federation_source: Source identifier for federated tools.

    Returns:
        dict: Statistics dictionary with keys "created", "updated", "skipped", "failed", and "errors".
        If the chunk transaction fails, it is rolled back, "created"/"updated"/"skipped" are
        reset to 0, and "failed" equals ``len(chunk)``. The counts therefore never describe
        rows that were not persisted.
    """
    stats = {"created": 0, "updated": 0, "skipped": 0, "failed": 0, "errors": []}
    tools_to_add = []
    tools_to_update = []

    try:
        # Batch check for existing tools to detect conflicts
        tool_names = [tool.name for tool in chunk]

        if visibility.lower() == "public":
            existing_tools_query = select(DbTool).where(DbTool.name.in_(tool_names), DbTool.visibility == "public")
        elif visibility.lower() == "team" and team_id:
            existing_tools_query = select(DbTool).where(DbTool.name.in_(tool_names), DbTool.visibility == "team", DbTool.team_id == team_id)
        else:
            # Private tools - check by owner
            existing_tools_query = select(DbTool).where(DbTool.name.in_(tool_names), DbTool.visibility == "private", DbTool.owner_email == (owner_email or created_by))

        existing_tools = db.execute(existing_tools_query).scalars().all()
        existing_tools_map = {tool.name: tool for tool in existing_tools}

        for tool in chunk:
            result = self._process_single_tool_for_bulk(
                tool=tool,
                existing_tools_map=existing_tools_map,
                conflict_strategy=conflict_strategy,
                visibility=visibility,
                team_id=team_id,
                owner_email=owner_email,
                created_by=created_by,
                created_from_ip=created_from_ip,
                created_via=created_via,
                created_user_agent=created_user_agent,
                import_batch_id=import_batch_id,
                federation_source=federation_source,
            )

            if result["status"] == "add":
                tools_to_add.append(result["tool"])
                stats["created"] += 1
            elif result["status"] == "update":
                tools_to_update.append(result["tool"])
                stats["updated"] += 1
            elif result["status"] == "skip":
                stats["skipped"] += 1
            elif result["status"] == "fail":
                stats["failed"] += 1
                stats["errors"].append(result["error"])

        # Bulk add new tools; updated tools are already attached to the session.
        if tools_to_add:
            db.add_all(tools_to_add)

        # Commit the chunk
        db.commit()
    except Exception as e:
        db.rollback()
        logger.error(f"Failed to process tool chunk: {str(e)}")
        # The transaction was rolled back, so nothing from this chunk was persisted.
        # Discard the optimistic per-tool counters gathered above so the caller does
        # not see created/updated tools for a failed chunk.
        stats["created"] = 0
        stats["updated"] = 0
        stats["skipped"] = 0
        stats["failed"] = len(chunk)
        stats["errors"].append(f"Chunk processing failed: {str(e)}")
        return stats

    # Everything below runs only after a successful commit. A failure here must not
    # relabel already-persisted tools as failed, so it is logged instead. The rollback
    # leaves the session usable for the next chunk.
    try:
        # Refresh tools for notifications and audit trail
        # NOTE(review): subscriber notification is presumably handled by the caller; confirm.
        for db_tool in tools_to_add:
            db.refresh(db_tool)

        # Log bulk audit trail entry
        if tools_to_add or tools_to_update:
            audit_trail.log_action(
                user_id=created_by or "system",
                action="bulk_create_tools" if tools_to_add else "bulk_update_tools",
                resource_type="tool",
                resource_id=None,
                details={"count": len(tools_to_add) + len(tools_to_update), "import_batch_id": import_batch_id},
                db=db,
            )
    except Exception as e:
        db.rollback()
        logger.warning(f"Tool chunk committed but post-commit refresh/audit logging failed: {str(e)}")

    return stats

1705 

1706 def _process_single_tool_for_bulk( 

1707 self, 

1708 tool: ToolCreate, 

1709 existing_tools_map: dict, 

1710 conflict_strategy: str, 

1711 visibility: str, 

1712 team_id: Optional[int], 

1713 owner_email: Optional[str], 

1714 created_by: str, 

1715 created_from_ip: Optional[str], 

1716 created_via: Optional[str], 

1717 created_user_agent: Optional[str], 

1718 import_batch_id: Optional[str], 

1719 federation_source: Optional[str], 

1720 ) -> dict: 

1721 """Process a single tool for bulk import. 

1722 

1723 Args: 

1724 tool: ToolCreate object to process. 

1725 existing_tools_map: Dictionary mapping tool names to existing DbTool objects. 

1726 conflict_strategy: Strategy for handling conflicts ("skip", "update", or "fail"). 

1727 visibility: Tool visibility level ("public", "team", or "private"). 

1728 team_id: Team ID for team-scoped tools. 

1729 owner_email: Email of the tool owner. 

1730 created_by: Email of the user creating the tool. 

1731 created_from_ip: IP address of the request origin. 

1732 created_via: Source of the creation (e.g., "api", "ui"). 

1733 created_user_agent: User agent string from the request. 

1734 import_batch_id: Batch identifier for bulk imports. 

1735 federation_source: Source identifier for federated tools. 

1736 

1737 Returns: 

1738 dict: Result dictionary with "status" key ("add", "update", "skip", or "fail") 

1739 and either "tool" (DbTool object) or "error" (error message). 

1740 """ 

1741 try: 

1742 # Extract auth information 

1743 if tool.auth is None: 

1744 auth_type = None 

1745 auth_value = None 

1746 else: 

1747 auth_type = tool.auth.auth_type 

1748 auth_value = tool.auth.auth_value 

1749 

1750 # Use provided parameters or schema values 

1751 tool_team_id = team_id if team_id is not None else getattr(tool, "team_id", None) 

1752 tool_owner_email = owner_email or getattr(tool, "owner_email", None) or created_by 

1753 tool_visibility = visibility if visibility is not None else (getattr(tool, "visibility", None) or "public") 

1754 

1755 existing_tool = existing_tools_map.get(tool.name) 

1756 

1757 if existing_tool: 

1758 # Handle conflict based on strategy 

1759 if conflict_strategy == "skip": 

1760 return {"status": "skip"} 

1761 if conflict_strategy == "update": 

1762 # Update existing tool 

1763 existing_tool.display_name = tool.displayName or tool.name 

1764 existing_tool.title = tool.title 

1765 existing_tool.url = str(tool.url) 

1766 existing_tool.description = tool.description 

1767 if getattr(existing_tool, "original_description", None) is None: 

1768 existing_tool.original_description = tool.description 

1769 existing_tool.integration_type = tool.integration_type 

1770 existing_tool.request_type = tool.request_type 

1771 existing_tool.headers = _protect_tool_headers_for_storage(tool.headers, existing_headers=existing_tool.headers) 

1772 existing_tool.input_schema = tool.input_schema 

1773 existing_tool.output_schema = tool.output_schema 

1774 existing_tool.annotations = tool.annotations 

1775 existing_tool.jsonpath_filter = tool.jsonpath_filter 

1776 existing_tool.auth_type = auth_type 

1777 existing_tool.auth_value = auth_value 

1778 existing_tool.tags = tool.tags or [] 

1779 existing_tool.modified_by = created_by 

1780 existing_tool.modified_from_ip = created_from_ip 

1781 existing_tool.modified_via = created_via 

1782 existing_tool.modified_user_agent = created_user_agent 

1783 existing_tool.updated_at = datetime.now(timezone.utc) 

1784 existing_tool.version = (existing_tool.version or 1) + 1 

1785 

1786 # Update REST-specific fields if applicable 

1787 if tool.integration_type == "REST": 

1788 existing_tool.base_url = tool.base_url 

1789 existing_tool.path_template = tool.path_template 

1790 existing_tool.query_mapping = tool.query_mapping 

1791 existing_tool.header_mapping = tool.header_mapping 

1792 existing_tool.timeout_ms = tool.timeout_ms 

1793 existing_tool.expose_passthrough = tool.expose_passthrough if tool.expose_passthrough is not None else True 

1794 existing_tool.allowlist = tool.allowlist 

1795 existing_tool.plugin_chain_pre = tool.plugin_chain_pre 

1796 existing_tool.plugin_chain_post = tool.plugin_chain_post 

1797 

1798 return {"status": "update", "tool": existing_tool} 

1799 

1800 if conflict_strategy == "rename": 

1801 # Create with renamed tool 

1802 new_name = f"{tool.name}_imported_{int(datetime.now().timestamp())}" 

1803 db_tool = self._create_tool_object( 

1804 tool, 

1805 new_name, 

1806 auth_type, 

1807 auth_value, 

1808 tool_team_id, 

1809 tool_owner_email, 

1810 tool_visibility, 

1811 created_by, 

1812 created_from_ip, 

1813 created_via, 

1814 created_user_agent, 

1815 import_batch_id, 

1816 federation_source, 

1817 ) 

1818 return {"status": "add", "tool": db_tool} 

1819 

1820 if conflict_strategy == "fail": 

1821 return {"status": "fail", "error": f"Tool name conflict: {tool.name}"} 

1822 

1823 # Create new tool 

1824 db_tool = self._create_tool_object( 

1825 tool, 

1826 tool.name, 

1827 auth_type, 

1828 auth_value, 

1829 tool_team_id, 

1830 tool_owner_email, 

1831 tool_visibility, 

1832 created_by, 

1833 created_from_ip, 

1834 created_via, 

1835 created_user_agent, 

1836 import_batch_id, 

1837 federation_source, 

1838 ) 

1839 return {"status": "add", "tool": db_tool} 

1840 

1841 except Exception as e: 

1842 logger.warning(f"Failed to process tool {tool.name} in bulk operation: {str(e)}") 

1843 return {"status": "fail", "error": f"Failed to process tool {tool.name}: {str(e)}"} 

1844 

def _create_tool_object(
    self,
    tool: ToolCreate,
    name: str,
    auth_type: Optional[str],
    auth_value: Optional[str],
    tool_team_id: Optional[int],
    tool_owner_email: Optional[str],
    tool_visibility: str,
    created_by: str,
    created_from_ip: Optional[str],
    created_via: Optional[str],
    created_user_agent: Optional[str],
    import_batch_id: Optional[str],
    federation_source: Optional[str],
) -> DbTool:
    """Build an unsaved DbTool instance from a ToolCreate schema.

    ``name`` is used as original name, custom name, slug source, and display-name
    fallback, so a renamed import can reuse the schema unchanged. REST-only columns
    are copied from the schema for REST tools and left as None otherwise.

    Args:
        tool: ToolCreate schema object containing tool data.
        name: Name of the tool.
        auth_type: Authentication type for the tool.
        auth_value: Authentication value/credentials for the tool.
        tool_team_id: Team ID for team-scoped tools.
        tool_owner_email: Email of the tool owner.
        tool_visibility: Tool visibility level ("public", "team", or "private").
        created_by: Email of the user creating the tool.
        created_from_ip: IP address of the request origin.
        created_via: Source of the creation (e.g., "api", "ui").
        created_user_agent: User agent string from the request.
        import_batch_id: Batch identifier for bulk imports.
        federation_source: Source identifier for federated tools.

    Returns:
        DbTool: Database model instance, not yet added to any session.
    """
    if tool.integration_type == "REST":
        rest_columns = {
            "base_url": tool.base_url,
            "path_template": tool.path_template,
            "query_mapping": tool.query_mapping,
            "header_mapping": tool.header_mapping,
            "timeout_ms": tool.timeout_ms,
            # Passthrough is enabled unless the schema explicitly sets it.
            "expose_passthrough": True if tool.expose_passthrough is None else tool.expose_passthrough,
            "allowlist": tool.allowlist,
            "plugin_chain_pre": tool.plugin_chain_pre,
            "plugin_chain_post": tool.plugin_chain_post,
        }
    else:
        rest_columns = dict.fromkeys(
            (
                "base_url",
                "path_template",
                "query_mapping",
                "header_mapping",
                "timeout_ms",
                "expose_passthrough",
                "allowlist",
                "plugin_chain_pre",
                "plugin_chain_post",
            )
        )

    display = tool.displayName or name
    slug = slugify(name)
    stored_headers = _protect_tool_headers_for_storage(tool.headers)

    return DbTool(
        original_name=name,
        custom_name=name,
        custom_name_slug=slug,
        display_name=display,
        title=tool.title,
        url=str(tool.url),
        description=tool.description,
        original_description=tool.description,
        integration_type=tool.integration_type,
        request_type=tool.request_type,
        headers=stored_headers,
        input_schema=tool.input_schema,
        output_schema=tool.output_schema,
        annotations=tool.annotations,
        jsonpath_filter=tool.jsonpath_filter,
        auth_type=auth_type,
        auth_value=auth_value,
        gateway_id=tool.gateway_id,
        tags=tool.tags or [],
        created_by=created_by,
        created_from_ip=created_from_ip,
        created_via=created_via,
        created_user_agent=created_user_agent,
        import_batch_id=import_batch_id,
        federation_source=federation_source,
        version=1,
        team_id=tool_team_id,
        owner_email=tool_owner_email,
        visibility=tool_visibility,
        **rest_columns,
    )

1921 

async def list_tools(
    self,
    db: Session,
    include_inactive: bool = False,
    cursor: Optional[str] = None,
    tags: Optional[List[str]] = None,
    gateway_id: Optional[str] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    per_page: Optional[int] = None,
    user_email: Optional[str] = None,
    team_id: Optional[str] = None,
    visibility: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
    _request_headers: Optional[Dict[str, str]] = None,
    requesting_user_email: Optional[str] = None,
    requesting_user_is_admin: bool = False,
    requesting_user_team_roles: Optional[Dict[str, str]] = None,
) -> Union[tuple[List[ToolRead], Optional[str]], Dict[str, Any]]:
    """
    Retrieve a list of registered tools from the database with pagination support.

    The first cursor page of unscoped requests (no cursor, page, user_email,
    token_teams, or team_id) is served from and stored in the registry cache.
    All other requests always hit the database.

    Args:
        db (Session): The SQLAlchemy database session.
        include_inactive (bool): If True, include inactive tools in the result.
            Defaults to False.
        cursor (Optional[str], optional): An opaque cursor token for pagination.
            Opaque base64-encoded string containing last item's ID.
        tags (Optional[List[str]]): Filter tools by tags. If provided, only tools with at least one matching tag will be returned.
        gateway_id (Optional[str]): Filter tools by gateway ID. Accepts the literal value 'null' to match NULL gateway_id.
        limit (Optional[int]): Maximum number of tools to return. Use 0 for all tools (no limit).
            If not specified, uses pagination_default_page_size.
        page: Page number for page-based pagination (1-indexed). Mutually exclusive with cursor.
        per_page: Items per page for page-based pagination. Defaults to pagination_default_page_size.
        user_email (Optional[str]): User email for team-based access control. If None, no access control is applied.
        team_id (Optional[str]): Filter by specific team ID. Requires user_email for access validation.
        visibility (Optional[str]): Filter by visibility (private, team, public).
        token_teams (Optional[List[str]]): Override DB team lookup with token's teams. Used for MCP/API token access
            where the token scope should be respected instead of the user's full team memberships.
        _request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through.
            Currently unused but kept for API consistency. Defaults to None.
        requesting_user_email (Optional[str]): Email of the requesting user for header masking.
        requesting_user_is_admin (bool): Whether the requester is an admin.
        requesting_user_team_roles (Optional[Dict[str, str]]): {team_id: role} for the requester.

    Returns:
        Union[tuple[List[ToolRead], Optional[str]], Dict[str, Any]]:
            - Cursor mode (``page`` is None): a tuple of the tools for the current page
              and the next cursor token (None when no more results exist).
            - Page mode (``page`` given): a dict with "data" (list of ToolRead),
              "pagination", and "links".

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from unittest.mock import MagicMock
        >>> service = ToolService()
        >>> db = MagicMock()
        >>> tool_read = MagicMock()
        >>> service.convert_tool_to_read = MagicMock(return_value=tool_read)
        >>> db.execute.return_value.scalars.return_value.all.return_value = [MagicMock()]
        >>> import asyncio
        >>> tools, next_cursor = asyncio.run(service.list_tools(db))
        >>> isinstance(tools, list)
        True
    """
    with create_span(
        "tool.list",
        {
            "include_inactive": include_inactive,
            "tags.count": len(tags) if tags else 0,
            "gateway_id": gateway_id,
            "limit": limit,
            "page": page,
            "per_page": per_page,
            "user.email": user_email,
            "team.scope": format_trace_team_scope(token_teams) if token_teams is not None else None,
            "team.filter": team_id,
            "visibility": visibility,
        },
    ):
        cache = _get_registry_cache()
        filters_hash = None
        # Only use the cache when using the real converter. In unit tests we often patch
        # convert_tool_to_read() to exercise error handling, and a warm cache would bypass it.
        try:
            converter_is_default = self.convert_tool_to_read.__func__ is ToolService.convert_tool_to_read  # type: ignore[attr-defined]
        except Exception:
            converter_is_default = False

        # Decide cacheability once, so the lookup and the store below cannot drift apart.
        # Skip caching when:
        # - user_email is provided (team-filtered results are user-specific)
        # - token_teams is set (scoped access, e.g., public-only or team-scoped tokens)
        # - team_id is set (forwarded to _apply_access_control but not part of the cache key)
        # - cursor/page pagination beyond the first cursor page is requested
        # This prevents cache poisoning where admin results could leak to scoped requests.
        # NOTE(review): requesting_user_* (used for header masking in convert_tool_to_read) is
        # not part of the cache key either; confirm masking is identical for all unscoped callers.
        use_cache = cursor is None and page is None and user_email is None and token_teams is None and team_id is None and converter_is_default

        if use_cache:
            # Include visibility in the cache hash so admin requests that include
            # an explicit visibility filter don't get served stale results from
            # a previously cached unfiltered admin request.
            filters_hash = cache.hash_filters(
                include_inactive=include_inactive,
                tags=sorted(tags) if tags else None,
                gateway_id=gateway_id,
                limit=limit,
                visibility=visibility,
            )
            cached = await cache.get("tools", filters_hash)
            if cached is not None:
                # Reconstruct ToolRead objects from cached dicts
                cached_tools = [ToolRead.model_validate(t) for t in cached["tools"]]
                return (cached_tools, cached.get("next_cursor"))

        # Build base query with ordering and eager load gateway + email_team to avoid N+1
        query = select(DbTool).options(joinedload(DbTool.gateway), joinedload(DbTool.email_team)).order_by(desc(DbTool.created_at), desc(DbTool.id))

        # Apply active/inactive filter
        if not include_inactive:
            query = query.where(DbTool.enabled)
        query = await self._apply_access_control(query, db, user_email, token_teams, team_id)

        if visibility:
            query = query.where(DbTool.visibility == visibility)

        # Add gateway_id filtering if provided
        if gateway_id:
            if gateway_id.lower() == "null":
                query = query.where(DbTool.gateway_id.is_(None))
            else:
                query = query.where(DbTool.gateway_id == gateway_id)

        # Add tag filtering if tags are provided (supports both List[str] and List[Dict] formats)
        if tags:
            query = query.where(json_contains_tag_expr(db, DbTool.tags, tags, match_any=True))

        # Use unified pagination helper - handles both page and cursor pagination
        pag_result = await unified_paginate(
            db=db,
            query=query,
            page=page,
            per_page=per_page,
            cursor=cursor,
            limit=limit,
            base_url="/admin/tools",  # Used for page-based links
            query_params={"include_inactive": include_inactive} if include_inactive else {},
        )

        next_cursor = None
        # Extract tools based on pagination type
        if page is not None:
            # Page-based: pag_result is a dict
            tools_db = pag_result["data"]
        else:
            # Cursor-based: pag_result is a tuple
            tools_db, next_cursor = pag_result

        db.commit()  # Release transaction to avoid idle-in-transaction

        # Convert to ToolRead (common for both pagination types)
        # Team names are loaded via joinedload(DbTool.email_team)
        result = []
        for s in tools_db:
            try:
                result.append(
                    self.convert_tool_to_read(
                        s,
                        include_metrics=False,
                        include_auth=False,
                        requesting_user_email=requesting_user_email,
                        requesting_user_is_admin=requesting_user_is_admin,
                        requesting_user_team_roles=requesting_user_team_roles,
                    )
                )
            except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
                logger.exception(f"Failed to convert tool {getattr(s, 'id', 'unknown')} ({getattr(s, 'name', 'unknown')}): {e}")
                # Continue with remaining tools instead of failing completely

        # Return appropriate format based on pagination type
        if page is not None:
            # Page-based format
            return {
                "data": result,
                "pagination": pag_result["pagination"],
                "links": pag_result["links"],
            }

        # Cursor-based format: cache first-page results only when the lookup above was allowed
        if use_cache and filters_hash is not None:
            try:
                cache_data = {"tools": [s.model_dump(mode="json") for s in result], "next_cursor": next_cursor}
                await cache.set("tools", cache_data, filters_hash)
            except AttributeError:
                pass  # Skip caching if result objects don't support model_dump (e.g., in doctests)

        return (result, next_cursor)

2117 

async def list_server_tools(
    self,
    db: Session,
    server_id: str,
    include_inactive: bool = False,
    include_metrics: bool = False,
    cursor: Optional[str] = None,
    user_email: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
    _request_headers: Optional[Dict[str, str]] = None,
    requesting_user_email: Optional[str] = None,
    requesting_user_is_admin: bool = False,
    requesting_user_team_roles: Optional[Dict[str, str]] = None,
) -> List[ToolRead]:
    """
    List the tools attached to a virtual server, applying visibility rules.

    Args:
        db (Session): The SQLAlchemy database session.
        server_id (str): Server ID whose associated tools are listed.
        include_inactive (bool): Include disabled tools when True. Defaults to False.
        include_metrics (bool): Eager-load raw and hourly metrics and include them in the
            result when True; otherwise metrics are null. Defaults to False.
        cursor (Optional[str], optional): Accepted for API symmetry but ignored; this
            listing is not paginated. Defaults to None.
        user_email (Optional[str]): User email for visibility filtering. If both this and
            token_teams are None, no visibility filtering is applied. An empty string
            yields public-only filtering.
        token_teams (Optional[List[str]]): Team scope from the caller's token. When given, it
            replaces the DB team lookup. An empty list restricts results to public tools.
        _request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through.
            Currently unused but kept for API consistency. Defaults to None.
        requesting_user_email (Optional[str]): Email of the requesting user for header masking.
        requesting_user_is_admin (bool): Whether the requester is an admin.
        requesting_user_team_roles (Optional[Dict[str, str]]): {team_id: role} for the requester.

    Returns:
        List[ToolRead]: Tools that converted successfully. Conversion failures are logged and skipped.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from unittest.mock import MagicMock
        >>> service = ToolService()
        >>> db = MagicMock()
        >>> tool_read = MagicMock()
        >>> service.convert_tool_to_read = MagicMock(return_value=tool_read)
        >>> db.execute.return_value.scalars.return_value.all.return_value = [MagicMock()]
        >>> import asyncio
        >>> result = asyncio.run(service.list_server_tools(db, 'server1'))
        >>> isinstance(result, list)
        True
    """
    span_attributes = {
        "server_id": server_id,
        "include_inactive": include_inactive,
        "include_metrics": include_metrics,
        "user.email": user_email,
        "team.scope": format_trace_team_scope(token_teams) if token_teams is not None else None,
    }
    with create_span("tool.list", span_attributes):
        # Gateway and owning team are always eager-loaded; metric collections only on request.
        stmt = select(DbTool).options(joinedload(DbTool.gateway), joinedload(DbTool.email_team))
        if include_metrics:
            stmt = stmt.options(selectinload(DbTool.metrics)).options(selectinload(DbTool.metrics_hourly))
        stmt = stmt.join(server_tool_association, DbTool.id == server_tool_association.c.tool_id).where(server_tool_association.c.server_id == server_id)

        # Pagination is not implemented here; any incoming cursor is discarded.
        cursor = None
        logger.debug(f"Listing server tools for server_id={server_id} with include_inactive={include_inactive}, cursor={cursor}")

        if not include_inactive:
            stmt = stmt.where(DbTool.enabled)

        # Any user context or token scope switches on visibility filtering.
        # An empty-string user_email deliberately falls back to public-only.
        if user_email is not None or token_teams is not None:
            if token_teams is None and user_email:
                memberships = await TeamManagementService(db).get_user_teams(user_email)
                visible_team_ids = [membership.id for membership in memberships]
            else:
                visible_team_ids = token_teams if token_teams is not None else []

            # A token carrying an empty team list is public-only: no owner access either.
            public_only = token_teams is not None and not token_teams

            allowed = [DbTool.visibility == "public"]
            if user_email and not public_only:
                allowed.append(DbTool.owner_email == user_email)
            if visible_team_ids:
                allowed.append(and_(DbTool.team_id.in_(visible_team_ids), DbTool.visibility.in_(["team", "public"])))
            stmt = stmt.where(or_(*allowed))

        # Team names come from the joinedload on DbTool.email_team
        db_tools = db.execute(stmt).scalars().all()

        db.commit()  # Release transaction to avoid idle-in-transaction

        converted: List[ToolRead] = []
        for db_tool in db_tools:
            try:
                read_model = self.convert_tool_to_read(
                    db_tool,
                    include_metrics=include_metrics,
                    include_auth=False,
                    requesting_user_email=requesting_user_email,
                    requesting_user_is_admin=requesting_user_is_admin,
                    requesting_user_team_roles=requesting_user_team_roles,
                )
            except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
                # One malformed tool must not hide the rest of the listing.
                logger.exception(f"Failed to convert tool {getattr(db_tool, 'id', 'unknown')} ({getattr(db_tool, 'name', 'unknown')}): {e}")
                continue
            converted.append(read_model)

        return converted

2253 

async def list_server_mcp_tool_definitions(
    self,
    db: Session,
    server_id: str,
    *,
    include_inactive: bool = False,
    user_email: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """Return server-scoped MCP tool definitions without building full ToolRead models.

    This is a hot-path helper for the internal Rust -> Python seam. It keeps
    auth and visibility semantics aligned with ``list_server_tools`` while
    avoiding the heavier ``ToolRead`` conversion that is only needed for the
    admin/API surfaces.

    Args:
        db: Active database session.
        server_id: Virtual server identifier used to scope the tool listing.
        include_inactive: Whether disabled tools should be included.
        user_email: Requester email for owner- and team-scoped visibility checks. When
            ``token_teams`` is None, a non-empty email resolves the user's team
            memberships from the database, as ``list_server_tools`` does.
        token_teams: Normalized team scope from the caller token. An empty list
            restricts results to public tools.

    Returns:
        A list of MCP-compatible tool definition dictionaries with keys "name",
        "description", "inputSchema", "annotations", and, when set, "outputSchema".
    """
    with create_span(
        "tool.list",
        {
            "server_id": server_id,
            "include_inactive": include_inactive,
            "user.email": user_email,
            "team.scope": format_trace_team_scope(token_teams) if token_teams is not None else None,
            "mcp.definition_mode": True,
        },
    ):
        name_column = DbTool.__table__.c.name
        query = (
            select(
                name_column.label("name"),
                DbTool.description.label("description"),
                DbTool.input_schema.label("input_schema"),
                DbTool.output_schema.label("output_schema"),
                DbTool.annotations.label("annotations"),
                DbTool.owner_email.label("owner_email"),
                DbTool.team_id.label("team_id"),
                DbTool.visibility.label("visibility"),
            )
            .join(server_tool_association, DbTool.id == server_tool_association.c.tool_id)
            .where(server_tool_association.c.server_id == server_id)
        )

        if not include_inactive:
            query = query.where(DbTool.enabled)

        if user_email is not None or token_teams is not None:
            # Same team resolution as list_server_tools: an explicit token scope wins;
            # otherwise a non-empty user_email loads the user's team memberships. The old
            # code used [] here, which hid team-visible tools from such callers.
            if token_teams is not None:
                team_ids = token_teams
            elif user_email:
                team_service = TeamManagementService(db)
                user_teams = await team_service.get_user_teams(user_email)
                team_ids = [team.id for team in user_teams]
            else:
                team_ids = []

            # Public-only tokens (empty teams array) get no owner access
            is_public_only_token = token_teams is not None and len(token_teams) == 0

            access_conditions = [DbTool.visibility == "public"]
            if not is_public_only_token and user_email:
                access_conditions.append(DbTool.owner_email == user_email)
            if team_ids:
                access_conditions.append(and_(DbTool.team_id.in_(team_ids), DbTool.visibility.in_(["team", "public"])))
            query = query.where(or_(*access_conditions))

        rows = db.execute(query).mappings().all()
        db.commit()  # Release transaction to avoid idle-in-transaction

        result: List[Dict[str, Any]] = []
        for row in rows:
            payload: Dict[str, Any] = {
                "name": row["name"],
                "description": row["description"],
                # MCP clients expect an object schema even when none is stored
                "inputSchema": row["input_schema"] or {"type": "object", "properties": {}},
                "annotations": row["annotations"] or {},
            }
            if row["output_schema"] is not None:
                payload["outputSchema"] = row["output_schema"]
            result.append(payload)

        return result

2336 

async def list_tools_for_user(
    self,
    db: Session,
    user_email: str,
    team_id: Optional[str] = None,
    visibility: Optional[str] = None,
    include_inactive: bool = False,
    _skip: int = 0,
    _limit: int = 100,
    *,
    cursor: Optional[str] = None,
    gateway_id: Optional[str] = None,
    tags: Optional[List[str]] = None,
    limit: Optional[int] = None,
) -> tuple[List[ToolRead], Optional[str]]:
    """
    DEPRECATED: Use list_tools() with user_email parameter instead.

    List tools user has access to with team filtering and cursor pagination.

    This method is maintained for backward compatibility but is no longer used.
    New code should call list_tools() with user_email, team_id, and visibility parameters.

    Pagination is keyset-based on ``DbTool.id``. Rows are returned in ascending id
    order. The cursor carries the id of the last row of the previous page, and the
    next page selects ``id > last_id``.

    Args:
        db: Database session
        user_email: Email of the user requesting tools
        team_id: Optional team ID to filter by specific team
        visibility: Optional visibility filter (private, team, public)
        include_inactive: Whether to include inactive tools
        _skip: Number of tools to skip for pagination (deprecated, ignored)
        _limit: Maximum number of tools to return (deprecated, ignored)
        cursor: Opaque cursor token for pagination
        gateway_id: Filter tools by gateway ID. Accepts literal 'null' for NULL gateway_id.
        tags: Filter tools by tags (match any)
        limit: Maximum number of tools to return. Use 0 for all tools (no limit).
            If not specified, uses pagination_default_page_size.

    Returns:
        tuple[List[ToolRead], Optional[str]]: Tools the user has access to and optional next_cursor
    """
    # Determine page size based on limit parameter
    # limit=None: use default, limit=0: no limit (all), limit>0: use specified (capped)
    if limit is None:
        page_size = settings.pagination_default_page_size
    elif limit == 0:
        page_size = None  # No limit - fetch all
    else:
        page_size = min(limit, settings.pagination_max_page_size)

    # Decode cursor to get last_id if provided; an invalid cursor restarts from the first page.
    last_id = None
    if cursor:
        try:
            cursor_data = decode_cursor(cursor)
            last_id = cursor_data.get("id")
            logger.debug(f"Decoded cursor: last_id={last_id}")
        except ValueError as e:
            logger.warning(f"Invalid cursor, ignoring: {e}")

    # Build query following existing patterns from list_tools()
    team_service = TeamManagementService(db)
    user_teams = await team_service.get_user_teams(user_email)
    team_ids = [team.id for team in user_teams]

    # Eager load gateway and email_team to avoid N+1 when accessing gateway_slug and team name
    query = select(DbTool).options(joinedload(DbTool.gateway), joinedload(DbTool.email_team))

    # Apply active/inactive filter
    if not include_inactive:
        query = query.where(DbTool.enabled.is_(True))

    if team_id:
        if team_id not in team_ids:
            return ([], None)  # No access to team

        access_conditions = [
            and_(DbTool.team_id == team_id, DbTool.visibility.in_(["team", "public"])),
            and_(DbTool.team_id == team_id, DbTool.owner_email == user_email),
        ]
        query = query.where(or_(*access_conditions))
    else:
        access_conditions = [
            DbTool.owner_email == user_email,
            DbTool.visibility == "public",
        ]
        if team_ids:
            access_conditions.append(and_(DbTool.team_id.in_(team_ids), DbTool.visibility.in_(["team", "public"])))

        query = query.where(or_(*access_conditions))

    # Apply visibility filter if specified
    if visibility:
        query = query.where(DbTool.visibility == visibility)

    if gateway_id:
        if gateway_id.lower() == "null":
            query = query.where(DbTool.gateway_id.is_(None))
        else:
            query = query.where(DbTool.gateway_id == gateway_id)

    if tags:
        query = query.where(json_contains_tag_expr(db, DbTool.tags, tags, match_any=True))

    # Apply cursor filter (WHERE id > last_id)
    if last_id:
        query = query.where(DbTool.id > last_id)

    # Keyset pagination requires a deterministic order that matches the cursor
    # predicate. Without ORDER BY the database may return any page_size+1 rows in any
    # order, so tools[-1] would not be the largest id and later pages could skip or
    # repeat tools.
    query = query.order_by(DbTool.id)

    # Execute query - team names are loaded via joinedload(DbTool.email_team)
    if page_size is not None:
        # Fetch one extra row to detect whether another page exists.
        tools = db.execute(query.limit(page_size + 1)).scalars().all()
    else:
        tools = db.execute(query).scalars().all()

    db.commit()  # Release transaction to avoid idle-in-transaction

    # Check if there are more results (only when paginating)
    has_more = page_size is not None and len(tools) > page_size
    if has_more:
        tools = tools[:page_size]

    # Convert to ToolRead objects
    result = []
    for tool in tools:
        try:
            result.append(self.convert_tool_to_read(tool, include_metrics=False, include_auth=False, requesting_user_email=user_email, requesting_user_is_admin=False))
        except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
            logger.exception(f"Failed to convert tool {getattr(tool, 'id', 'unknown')} ({getattr(tool, 'name', 'unknown')}): {e}")
            # Continue with remaining tools instead of failing completely

    next_cursor = None
    # Generate cursor if there are more results (cursor-based pagination).
    # Uses the last DB row (not the last converted row) so a conversion failure
    # does not cause the next page to re-read rows.
    if has_more and tools:
        last_tool = tools[-1]
        next_cursor = encode_cursor({"created_at": last_tool.created_at.isoformat(), "id": last_tool.id})

    return (result, next_cursor)

2473 

async def get_tool(
    self,
    db: Session,
    tool_id: str,
    requesting_user_email: Optional[str] = None,
    requesting_user_is_admin: bool = False,
    requesting_user_team_roles: Optional[Dict[str, str]] = None,
) -> ToolRead:
    """
    Look up one tool by primary key and return its read representation.

    The requester details are forwarded to ``convert_tool_to_read`` (used there for
    header masking). A ``tool_viewed`` structured log event is emitted on success.

    Args:
        db (Session): The SQLAlchemy database session.
        tool_id (str): The unique identifier of the tool.
        requesting_user_email (Optional[str]): Email of the requesting user for header masking.
        requesting_user_is_admin (bool): Whether the requester is an admin.
        requesting_user_team_roles (Optional[Dict[str, str]]): {team_id: role} for the requester.

    Returns:
        ToolRead: The tool object.

    Raises:
        ToolNotFoundError: If no tool row exists for ``tool_id``.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from unittest.mock import MagicMock
        >>> service = ToolService()
        >>> db = MagicMock()
        >>> tool = MagicMock()
        >>> db.get.return_value = tool
        >>> service.convert_tool_to_read = MagicMock(return_value='tool_read')
        >>> import asyncio
        >>> asyncio.run(service.get_tool(db, 'tool_id'))
        'tool_read'
    """
    db_tool = db.get(DbTool, tool_id)
    if not db_tool:
        raise ToolNotFoundError(f"Tool not found: {tool_id}")

    converted = self.convert_tool_to_read(
        db_tool,
        requesting_user_email=requesting_user_email,
        requesting_user_is_admin=requesting_user_is_admin,
        requesting_user_team_roles=requesting_user_team_roles,
    )

    # Extra context attached to the structured log record.
    log_fields = {
        "tool_name": db_tool.name,
        "include_metrics": bool(getattr(converted, "metrics", {})),
    }
    structured_logger.log(
        level="INFO",
        message="Tool retrieved successfully",
        event_type="tool_viewed",
        component="tool_service",
        team_id=getattr(db_tool, "team_id", None),
        resource_type="tool",
        resource_id=str(db_tool.id),
        custom_fields=log_fields,
    )

    return converted

2536 

async def delete_tool(self, db: Session, tool_id: str, user_email: Optional[str] = None, purge_metrics: bool = False) -> None:
    """
    Delete a tool by its ID.

    Optionally purges the tool's raw and hourly-rollup metrics first. Then it
    performs an atomic ``DELETE`` with a rowcount check, commits, and emits a
    deletion notification, an audit-trail record, and a structured log event.
    Finally it invalidates the registry, tool-lookup, tags, and metrics caches.

    Args:
        db (Session): The SQLAlchemy database session.
        tool_id (str): The unique identifier of the tool.
        user_email (Optional[str]): Email of user performing delete (for ownership check).
        purge_metrics (bool): If True, delete raw + rollup metrics for this tool.

    Raises:
        ToolNotFoundError: If the tool is not found (including when a concurrent
            request deleted it first). Propagated unwrapped so callers can map it to 404.
        PermissionError: If user doesn't own the tool.
        ToolError: For other deletion errors.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from unittest.mock import MagicMock, AsyncMock
        >>> service = ToolService()
        >>> db = MagicMock()
        >>> tool = MagicMock()
        >>> db.get.return_value = tool
        >>> db.delete = MagicMock()
        >>> db.commit = MagicMock()
        >>> service._notify_tool_deleted = AsyncMock()
        >>> import asyncio
        >>> asyncio.run(service.delete_tool(db, 'tool_id'))
    """
    try:
        tool = db.get(DbTool, tool_id)
        if not tool:
            raise ToolNotFoundError(f"Tool not found: {tool_id}")

        # Check ownership if user_email provided
        if user_email:
            # First-Party
            from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

            permission_service = PermissionService(db)
            if not await permission_service.check_resource_ownership(user_email, tool):
                raise PermissionError("Only the owner can delete this tool")

        # Snapshot every attribute needed after the row is deleted and committed,
        # so nothing below has to read from an ORM instance whose row no longer exists.
        tool_info = {"id": tool.id, "name": tool.name}
        tool_name = tool.name
        tool_team_id = tool.team_id
        tool_gateway_id = str(tool.gateway_id) if tool.gateway_id else None

        if purge_metrics:
            with pause_rollup_during_purge(reason=f"purge_tool:{tool_id}"):
                delete_metrics_in_batches(db, ToolMetric, ToolMetric.tool_id, tool_id)
                delete_metrics_in_batches(db, ToolMetricsHourly, ToolMetricsHourly.tool_id, tool_id)

        # Use DELETE with rowcount check for database-agnostic atomic delete
        stmt = delete(DbTool).where(DbTool.id == tool_id)
        result = db.execute(stmt)
        if result.rowcount == 0:
            # Tool was already deleted by another concurrent request
            raise ToolNotFoundError(f"Tool not found: {tool_id}")

        db.commit()
        await self._notify_tool_deleted(tool_info)
        logger.info(f"Permanently deleted tool: {tool_info['name']}")

        # Structured logging: Audit trail for tool deletion
        audit_trail.log_action(
            user_id=user_email or "system",
            action="delete_tool",
            resource_type="tool",
            resource_id=tool_info["id"],
            resource_name=tool_name,
            user_email=user_email,
            team_id=tool_team_id,
            old_values={
                "name": tool_name,
            },
            db=db,
        )

        # Structured logging: Log successful tool deletion
        structured_logger.log(
            level="INFO",
            message="Tool deleted successfully",
            event_type="tool_deleted",
            component="tool_service",
            user_email=user_email,
            team_id=tool_team_id,
            resource_type="tool",
            resource_id=tool_info["id"],
            custom_fields={
                "tool_name": tool_name,
                "purge_metrics": purge_metrics,
            },
        )

        # Invalidate cache after successful deletion
        cache = _get_registry_cache()
        await cache.invalidate_tools()
        tool_lookup_cache = _get_tool_lookup_cache()
        await tool_lookup_cache.invalidate(tool_name, gateway_id=tool_gateway_id)
        # Also invalidate tags cache since tool tags may have changed
        # First-Party
        from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

        await admin_stats_cache.invalidate_tags()
        # Invalidate top performers cache
        # First-Party
        from mcpgateway.cache.metrics_cache import metrics_cache  # pylint: disable=import-outside-toplevel

        metrics_cache.invalidate_prefix("top_tools:")
        metrics_cache.invalidate("tools")
    except PermissionError as pe:
        db.rollback()

        # Structured logging: Log permission error
        structured_logger.log(
            level="WARNING",
            message="Tool deletion failed due to permission error",
            event_type="tool_delete_permission_denied",
            component="tool_service",
            user_email=user_email,
            resource_type="tool",
            resource_id=tool_id,
            error=pe,
        )
        raise
    except ToolNotFoundError:
        # Re-raise not found without wrapping - allows 404 response (matches set_tool_state).
        # Rollback discards any uncommitted metric purge from this attempt.
        db.rollback()
        raise
    except Exception as e:
        db.rollback()

        # Structured logging: Log generic tool deletion failure
        structured_logger.log(
            level="ERROR",
            message="Tool deletion failed",
            event_type="tool_deletion_failed",
            component="tool_service",
            user_email=user_email,
            resource_type="tool",
            resource_id=tool_id,
            error=e,
        )
        raise ToolError(f"Failed to delete tool: {str(e)}") from e

2676 

async def set_tool_state(self, db: Session, tool_id: str, activate: bool, reachable: bool, user_email: Optional[str] = None, skip_cache_invalidation: bool = False) -> ToolRead:
    """
    Set the activation status of a tool.

    The tool row is locked with ``SELECT ... FOR UPDATE NOWAIT``. ``enabled`` and
    ``reachable`` are updated if they differ, and the change is committed. Caches are
    then invalidated (unless skipped), and an activated/offline/deactivated
    notification is emitted based on the resulting state. On every failure path the
    transaction is rolled back, so the row lock is not held.

    Args:
        db (Session): The SQLAlchemy database session.
        tool_id (str): The unique identifier of the tool.
        activate (bool): True to activate, False to deactivate.
        reachable (bool): True if the tool is reachable.
        user_email: Optional[str] The email of the user to check if the user has permission to modify.
        skip_cache_invalidation: If True, skip cache invalidation (used for batch operations).

    Returns:
        ToolRead: The updated tool object.

    Raises:
        ToolNotFoundError: If the tool is not found.
        ToolLockConflictError: If the tool row is locked by another transaction.
        ToolError: For other errors.
        PermissionError: If user doesn't own the tool.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from unittest.mock import MagicMock, AsyncMock
        >>> from mcpgateway.schemas import ToolRead
        >>> service = ToolService()
        >>> db = MagicMock()
        >>> tool = MagicMock()
        >>> db.get.return_value = tool
        >>> db.commit = MagicMock()
        >>> db.refresh = MagicMock()
        >>> service._notify_tool_activated = AsyncMock()
        >>> service._notify_tool_deactivated = AsyncMock()
        >>> service.convert_tool_to_read = MagicMock(return_value='tool_read')
        >>> ToolRead.model_validate = MagicMock(return_value='tool_read')
        >>> import asyncio
        >>> asyncio.run(service.set_tool_state(db, 'tool_id', True, True))
        'tool_read'
    """
    try:
        # Use nowait=True to fail fast if row is locked, preventing lock contention under high load
        try:
            tool = get_for_update(db, DbTool, tool_id, nowait=True)
        except OperationalError as lock_err:
            # Row is locked by another transaction - fail fast with 409
            db.rollback()
            raise ToolLockConflictError(f"Tool {tool_id} is currently being modified by another request") from lock_err
        if not tool:
            raise ToolNotFoundError(f"Tool not found: {tool_id}")

        if user_email:
            # First-Party
            from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

            permission_service = PermissionService(db)
            if not await permission_service.check_resource_ownership(user_email, tool):
                raise PermissionError("Only the owner can activate the Tool" if activate else "Only the owner can deactivate the Tool")

        # Only touch columns (and updated_at) that actually change.
        is_activated = is_reachable = False
        if tool.enabled != activate:
            tool.enabled = activate
            is_activated = True

        if tool.reachable != reachable:
            tool.reachable = reachable
            is_reachable = True

        if is_activated or is_reachable:
            tool.updated_at = datetime.now(timezone.utc)

            db.commit()
            db.refresh(tool)

            # Invalidate cache after status change (skip for batch operations)
            if not skip_cache_invalidation:
                cache = _get_registry_cache()
                await cache.invalidate_tools()
                tool_lookup_cache = _get_tool_lookup_cache()
                await tool_lookup_cache.invalidate(tool.name, gateway_id=str(tool.gateway_id) if tool.gateway_id else None)

            if not tool.enabled:
                # Inactive
                await self._notify_tool_deactivated(tool)
            elif tool.enabled and not tool.reachable:
                # Offline
                await self._notify_tool_offline(tool)
            else:
                # Active
                await self._notify_tool_activated(tool)

        logger.info(f"Tool: {tool.name} is {'enabled' if activate else 'disabled'}{' and accessible' if reachable else ' but inaccessible'}")

        # Structured logging: Audit trail for tool state change
        audit_trail.log_action(
            user_id=user_email or "system",
            action="set_tool_state",
            resource_type="tool",
            resource_id=tool.id,
            resource_name=tool.name,
            user_email=user_email,
            team_id=tool.team_id,
            new_values={
                "enabled": tool.enabled,
                "reachable": tool.reachable,
            },
            context={
                "action": "activate" if activate else "deactivate",
            },
            db=db,
        )

        # Structured logging: Log successful tool state change
        structured_logger.log(
            level="INFO",
            message=f"Tool {'activated' if activate else 'deactivated'} successfully",
            event_type="tool_state_changed",
            component="tool_service",
            user_email=user_email,
            team_id=tool.team_id,
            resource_type="tool",
            resource_id=tool.id,
            custom_fields={
                "tool_name": tool.name,
                "enabled": tool.enabled,
                "reachable": tool.reachable,
            },
        )

        # NOTE(review): masking is evaluated as if the tool owner were the requester —
        # confirm this is intended for all callers of set_tool_state.
        return self.convert_tool_to_read(tool, requesting_user_email=getattr(tool, "owner_email", None))
    except PermissionError as e:
        # Release the FOR UPDATE lock taken above instead of leaving the
        # transaction open until the session is closed.
        db.rollback()

        # Structured logging: Log permission error
        structured_logger.log(
            level="WARNING",
            message="Tool state change failed due to permission error",
            event_type="tool_state_change_permission_denied",
            component="tool_service",
            user_email=user_email,
            resource_type="tool",
            resource_id=tool_id,
            error=e,
        )
        raise
    except ToolLockConflictError:
        # Re-raise lock conflicts without wrapping - allows 409 response
        # (the inner handler has already rolled back).
        raise
    except ToolNotFoundError:
        # Re-raise not found without wrapping - allows 404 response
        db.rollback()
        raise
    except Exception as e:
        db.rollback()

        # Structured logging: Log generic tool state change failure
        structured_logger.log(
            level="ERROR",
            message="Tool state change failed",
            event_type="tool_state_change_failed",
            component="tool_service",
            user_email=user_email,
            resource_type="tool",
            resource_id=tool_id,
            error=e,
        )
        raise ToolError(f"Failed to set tool state: {str(e)}") from e

2840 

async def invoke_tool_direct(
    self,
    gateway_id: str,
    name: str,
    arguments: Dict[str, Any],
    request_headers: Optional[Dict[str, str]] = None,
    meta_data: Optional[Dict[str, Any]] = None,
    user_email: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
) -> types.CallToolResult:
    """
    Invoke a tool directly on a remote MCP gateway in direct_proxy mode.

    This bypasses all gateway processing (caching, plugins, validation) and forwards
    the tool call directly to the remote MCP server, returning the raw result.

    All database work happens up front inside a short-lived session:
    - gateway lookup;
    - mode and access checks;
    - auth and passthrough headers;
    - resolution of the original tool name.

    That session is left before connecting to the remote server, so no DB
    connection or transaction is held open for the duration of the remote call.

    Args:
        gateway_id: Gateway ID to invoke the tool on.
        name: Name of tool to invoke.
        arguments: Tool arguments.
        request_headers: Headers from the request to pass through.
        meta_data: Optional metadata dictionary for additional context (e.g., request ID).
        user_email: Email of the requesting user for access control.
        token_teams: Team IDs from the user's token for access control.

    Returns:
        CallToolResult from the remote MCP server (as-is, no normalization).

    Raises:
        ToolNotFoundError: If gateway not found or access denied.
        ToolInvocationError: If the gateway is not in direct_proxy mode or invocation fails.
    """
    logger.info(f"Direct proxy tool invocation: {name} via gateway {SecurityValidator.sanitize_log_message(gateway_id)}")
    # Look up gateway
    # Use a fresh session for this lookup; everything needed later is copied out of it.
    with fresh_db_session() as db:
        gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
        if not gateway:
            raise ToolNotFoundError(f"Gateway {gateway_id} not found")

        if getattr(gateway, "gateway_mode", "cache") != "direct_proxy" or not settings.mcpgateway_direct_proxy_enabled:
            raise ToolInvocationError(f"Gateway {gateway_id} is not in direct_proxy mode")

        # SECURITY: Defensive access check — callers should also check,
        # but enforce here to prevent RBAC bypass if called from a new context.
        if not await check_gateway_access(db, gateway, user_email, token_teams):
            raise ToolNotFoundError(f"Tool not found: {name}")

        # Prepare headers with gateway auth
        headers = build_gateway_auth_headers(gateway)

        # Forward passthrough headers if configured
        if gateway.passthrough_headers and request_headers:
            for header_name in gateway.passthrough_headers:
                header_value = request_headers.get(header_name.lower()) or request_headers.get(header_name)
                if header_value:
                    headers[header_name] = header_value

        # Copy plain values out of the ORM object so nothing below touches it
        # after the session has been released.
        gateway_url = gateway.url
        gateway_pk = gateway.id

        # Resolve the original (unprefixed) tool name for the remote server.
        # Tools registered via gateways are stored as "{gateway_slug}{separator}{slugified_name}",
        # but the remote server only knows the original name (e.g. "get_system_time" not "get-system-time").
        # Look up the tool's original_name from the DB; fall back to the prefixed name if not found
        # (e.g. when calling a tool that exists on the remote but hasn't been cached locally).
        remote_name = name
        tool_row = db.execute(select(DbTool).where(DbTool.name == name, DbTool.gateway_id == gateway_id)).scalar_one_or_none()
        if tool_row and tool_row.original_name:
            remote_name = tool_row.original_name
        else:
            # Fallback: strip the slug prefix (best-effort for tools not yet in DB)
            gateway_slug = getattr(gateway, "slug", None) or ""
            if gateway_slug:
                prefix = f"{gateway_slug}{settings.gateway_tool_name_separator}"
                if name.startswith(prefix):
                    remote_name = name[len(prefix) :]

    gateway_id_str = str(gateway_pk)
    parsed_url = urlparse(gateway_url)

    # Use MCP SDK to connect and call tool
    try:
        with create_span(
            "mcp.client.call",
            {
                "mcp.tool.name": remote_name,
                "contextforge.gateway_id": gateway_id_str,
                "contextforge.runtime": "python",
                "contextforge.transport": "streamablehttp",
                "network.protocol.name": "mcp",
                "server.address": parsed_url.hostname,
                "server.port": parsed_url.port,
                "url.path": parsed_url.path or "/",
                "url.full": sanitize_url_for_logging(gateway_url, {}),
            },
        ):
            traced_headers = inject_trace_context_headers(headers)
            async with streamablehttp_client(url=gateway_url, headers=traced_headers, timeout=settings.mcpgateway_direct_proxy_timeout) as (read_stream, write_stream, _get_session_id):
                async with ClientSession(read_stream, write_stream) as session:
                    with create_span("mcp.client.initialize", {"contextforge.transport": "streamablehttp", "contextforge.runtime": "python"}):
                        await session.initialize()

                    with create_span(
                        "mcp.client.request",
                        {
                            "mcp.tool.name": remote_name,
                            "contextforge.gateway_id": gateway_id_str,
                            "contextforge.runtime": "python",
                        },
                    ):
                        # Call tool with meta if provided
                        if meta_data:
                            logger.debug(f"Forwarding _meta to remote gateway: {meta_data}")
                            tool_result = await session.call_tool(name=remote_name, arguments=arguments, meta=meta_data)
                        else:
                            tool_result = await session.call_tool(name=remote_name, arguments=arguments)
                    # Zero-duration span used only to record the upstream success flag.
                    with create_span(
                        "mcp.client.response",
                        {
                            "mcp.tool.name": remote_name,
                            "contextforge.gateway_id": gateway_id_str,
                            "contextforge.runtime": "python",
                            "upstream.response.success": not getattr(tool_result, "is_error", False) and not getattr(tool_result, "isError", False),
                        },
                    ):
                        pass

        logger.info(
            f"[INVOKE TOOL] Using direct_proxy mode for gateway {SecurityValidator.sanitize_log_message(gateway_pk)} (from X-Context-Forge-Gateway-Id header). Meta Attached: {meta_data is not None}"
        )
        return tool_result
    except Exception as e:
        logger.exception(f"Direct proxy tool invocation failed for {name}: {e}")
        raise ToolInvocationError(f"Direct proxy tool invocation failed: {str(e)}") from e

2972 

2973 async def prepare_rust_mcp_tool_execution( 

2974 self, 

2975 db: Session, 

2976 name: str, 

2977 arguments: Optional[Dict[str, Any]] = None, 

2978 request_headers: Optional[Dict[str, str]] = None, 

2979 app_user_email: Optional[str] = None, 

2980 user_email: Optional[str] = None, 

2981 token_teams: Optional[List[str]] = None, 

2982 server_id: Optional[str] = None, 

2983 plugin_global_context: Optional[GlobalContext] = None, 

2984 plugin_context_table: Optional[PluginContextTable] = None, 

2985 ) -> Dict[str, Any]: 

2986 """Build a narrow MCP execution plan for the Rust runtime hot path. 

2987 

2988 This reuses Python's existing auth, scoping, and secret-handling logic, 

2989 but stops before the actual upstream MCP call. The Rust runtime can then 

2990 execute the call directly for the simple streamable HTTP MCP cases that 

2991 dominate load tests, while Python remains the authority for policy. 

2992 

2993 When tool_pre_invoke hooks are registered, they are executed during plan 

2994 resolution and their modifications (cleaned args, injected headers) are 

2995 returned in the plan for the Rust runtime to apply. 

2996 

2997 Args: 

2998 db: Active database session. 

2999 name: Tool name requested by the caller. 

3000 arguments: Tool call arguments from the JSON-RPC params (passed to pre-invoke hooks). 

3001 request_headers: Incoming request headers used for passthrough/auth decisions. 

3002 app_user_email: OAuth application user email, when present. 

3003 user_email: Effective requester email after auth normalization. 

3004 token_teams: Normalized team scope from the caller token. 

3005 server_id: Optional virtual server identifier restricting tool access. 

3006 plugin_global_context: Optional global context from middleware for hook continuity. 

3007 plugin_context_table: Optional context table from prior hooks for state sharing. 

3008 

3009 Returns: 

3010 A Rust execution plan dictionary, or a fallback descriptor when direct 

3011 Rust execution is not eligible. 

3012 

3013 Raises: 

3014 ToolNotFoundError: If the requested tool is not visible or invocable. 

3015 ToolInvocationError: If gateway auth preparation fails or the tool name is ambiguous. 

3016 """ 

3017 has_pre_invoke = self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_PRE_INVOKE) 

3018 has_post_invoke = self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE) 

3019 

3020 gateway_id_from_header = extract_gateway_id_from_headers(request_headers) 

3021 is_direct_proxy = False 

3022 tool = None 

3023 gateway = None 

3024 tool_selected_from_server_scope = False 

3025 tool_payload: Dict[str, Any] = {} 

3026 gateway_payload: Optional[Dict[str, Any]] = None 

3027 if gateway_id_from_header: 

3028 gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id_from_header)).scalar_one_or_none() 

3029 if gateway and gateway.gateway_mode == "direct_proxy" and settings.mcpgateway_direct_proxy_enabled: 

3030 if not await check_gateway_access(db, gateway, user_email, token_teams): 

3031 raise ToolNotFoundError(f"Tool not found: {name}") 

3032 is_direct_proxy = True 

3033 gateway_payload = { 

3034 "id": str(gateway.id), 

3035 "name": gateway.name, 

3036 "url": gateway.url, 

3037 "auth_type": gateway.auth_type, 

3038 "auth_value": encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value, 

3039 "auth_query_params": gateway.auth_query_params, 

3040 "oauth_config": gateway.oauth_config, 

3041 "ca_certificate": gateway.ca_certificate, 

3042 "ca_certificate_sig": gateway.ca_certificate_sig, 

3043 "passthrough_headers": gateway.passthrough_headers, 

3044 "gateway_mode": gateway.gateway_mode, 

3045 } 

3046 tool_payload = { 

3047 "id": None, 

3048 "name": name, 

3049 "original_name": name, 

3050 "enabled": True, 

3051 "reachable": True, 

3052 "integration_type": "MCP", 

3053 "request_type": "streamablehttp", 

3054 "gateway_id": str(gateway.id), 

3055 } 

3056 

3057 if not is_direct_proxy: 

3058 tool_lookup_cache = _get_tool_lookup_cache() 

3059 cached_payload = await tool_lookup_cache.get(name) if tool_lookup_cache.enabled else None 

3060 

3061 if cached_payload: 

3062 status = cached_payload.get("status", "active") 

3063 if status == "missing": 

3064 raise ToolNotFoundError(f"Tool not found: {name}") 

3065 if status == "inactive": 

3066 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive") 

3067 if status == "offline": 

3068 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.") 

3069 tool_payload = cached_payload.get("tool") or {} 

3070 gateway_payload = cached_payload.get("gateway") 

3071 

3072 if not tool_payload: 

3073 tools = self._load_invocable_tools(db, name, server_id=server_id) 

3074 tool_selected_from_server_scope = bool(server_id) 

3075 

3076 if not tools: 

3077 raise ToolNotFoundError(f"Tool not found: {name}") 

3078 

3079 multiple_found = len(tools) > 1 

3080 if not multiple_found: 

3081 tool = tools[0] 

3082 else: 

3083 visibility_priority = {"team": 0, "private": 1, "public": 2} 

3084 accessible_tools: list[tuple[int, Any]] = [] 

3085 for candidate in tools: 

3086 tool_dict = {"visibility": candidate.visibility, "team_id": candidate.team_id, "owner_email": candidate.owner_email} 

3087 if await self._check_tool_access(db, tool_dict, user_email, token_teams): 

3088 priority = visibility_priority.get(candidate.visibility, 99) 

3089 accessible_tools.append((priority, candidate)) 

3090 

3091 if not accessible_tools: 

3092 raise ToolNotFoundError(f"Tool not found: {name}") 

3093 

3094 accessible_tools.sort(key=lambda item: item[0]) 

3095 best_priority = accessible_tools[0][0] 

3096 best_tools = [candidate for priority, candidate in accessible_tools if priority == best_priority] 

3097 if len(best_tools) > 1: 

3098 raise ToolInvocationError(f"Multiple tools found with name '{name}' at same priority level. Tool name is ambiguous.") 

3099 tool = best_tools[0] 

3100 

3101 if not tool.enabled: 

3102 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive") 

3103 

3104 if not tool.reachable: 

3105 await tool_lookup_cache.set_negative(name, "offline") 

3106 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.") 

3107 

3108 gateway = tool.gateway 

3109 cache_payload = self._build_tool_cache_payload(tool, gateway) 

3110 tool_payload = cache_payload.get("tool") or {} 

3111 gateway_payload = cache_payload.get("gateway") 

3112 if not multiple_found: 

3113 await tool_lookup_cache.set(name, cache_payload, gateway_id=tool_payload.get("gateway_id")) 

3114 

3115 if tool_payload.get("enabled") is False: 

3116 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive") 

3117 if tool_payload.get("reachable") is False: 

3118 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.") 

3119 

3120 if is_direct_proxy: 

3121 return {"eligible": False, "fallbackReason": "direct-proxy"} 

3122 

3123 if not await self._check_tool_access(db, tool_payload, user_email, token_teams): 

3124 raise ToolNotFoundError(f"Tool not found: {name}") 

3125 

3126 if server_id and not tool_selected_from_server_scope: 

3127 tool_id_for_check = tool_payload.get("id") 

3128 if not tool_id_for_check: 

3129 raise ToolNotFoundError(f"Tool not found: {name}") 

3130 server_match = db.execute( 

3131 select(server_tool_association.c.tool_id).where( 

3132 server_tool_association.c.server_id == server_id, 

3133 server_tool_association.c.tool_id == tool_id_for_check, 

3134 ) 

3135 ).first() 

3136 if not server_match: 

3137 raise ToolNotFoundError(f"Tool not found: {name}") 

3138 

3139 tool_integration_type = tool_payload.get("integration_type") 

3140 if tool_integration_type != "MCP": 

3141 return {"eligible": False, "fallbackReason": f"unsupported-integration:{tool_integration_type or 'unknown'}"} 

3142 

3143 tool_request_type = tool_payload.get("request_type") 

3144 transport = tool_request_type.lower() if tool_request_type else "sse" 

3145 if transport not in {"streamablehttp", "sse"}: 

3146 return {"eligible": False, "fallbackReason": f"unsupported-transport:{transport}"} 

3147 

3148 tool_jsonpath_filter = tool_payload.get("jsonpath_filter") 

3149 if tool_jsonpath_filter: 

3150 return {"eligible": False, "fallbackReason": "jsonpath-filter-configured"} 

3151 

3152 passthrough_allowed = global_config_cache.get_passthrough_headers(db, settings.default_passthrough_headers) 

3153 

3154 if tool is not None: 

3155 gateway = tool.gateway 

3156 

3157 tool_name_original = tool_payload.get("original_name") or tool_payload.get("name") or name 

3158 tool_id = tool_payload.get("id") 

3159 tool_gateway_id = tool_payload.get("gateway_id") 

3160 tool_timeout_ms = tool_payload.get("timeout_ms") 

3161 effective_timeout = (tool_timeout_ms / 1000) if tool_timeout_ms else settings.tool_timeout 

3162 

3163 has_gateway = gateway_payload is not None 

3164 gateway_url = gateway_payload.get("url") if has_gateway else None 

3165 gateway_name = gateway_payload.get("name") if has_gateway else None 

3166 gateway_auth_type = gateway_payload.get("auth_type") if has_gateway else None 

3167 gateway_auth_value = gateway_payload.get("auth_value") if has_gateway and isinstance(gateway_payload.get("auth_value"), str) else None 

3168 gateway_auth_query_params = gateway_payload.get("auth_query_params") if has_gateway and isinstance(gateway_payload.get("auth_query_params"), dict) else None 

3169 gateway_oauth_config = gateway_payload.get("oauth_config") if has_gateway and isinstance(gateway_payload.get("oauth_config"), dict) else None 

3170 if has_gateway and gateway is not None: 

3171 runtime_gateway_auth_value = getattr(gateway, "auth_value", None) 

3172 if isinstance(runtime_gateway_auth_value, dict): 

3173 gateway_auth_value = encode_auth(runtime_gateway_auth_value) 

3174 elif isinstance(runtime_gateway_auth_value, str): 

3175 gateway_auth_value = runtime_gateway_auth_value 

3176 runtime_gateway_query_params = getattr(gateway, "auth_query_params", None) 

3177 if isinstance(runtime_gateway_query_params, dict): 

3178 gateway_auth_query_params = runtime_gateway_query_params 

3179 runtime_gateway_oauth_config = getattr(gateway, "oauth_config", None) 

3180 if isinstance(runtime_gateway_oauth_config, dict): 

3181 gateway_oauth_config = runtime_gateway_oauth_config 

3182 gateway_ca_cert = gateway_payload.get("ca_certificate") if has_gateway else None 

3183 gateway_id_str = gateway_payload.get("id") if has_gateway else None 

3184 

3185 if tool is None and has_gateway: 

3186 requires_gateway_auth_hydration = gateway_auth_type in {"basic", "bearer", "authheaders", "oauth", "query_param"} 

3187 if requires_gateway_auth_hydration: 

3188 tool_id_for_hydration = tool_payload.get("id") 

3189 if tool_id_for_hydration: 

3190 tool_auth_row = db.execute(select(DbTool).options(joinedload(DbTool.gateway)).where(DbTool.id == tool_id_for_hydration)).scalar_one_or_none() 

3191 if tool_auth_row and tool_auth_row.gateway: 

3192 hydrated_gateway_auth_value = getattr(tool_auth_row.gateway, "auth_value", None) 

3193 if isinstance(hydrated_gateway_auth_value, dict): 

3194 gateway_auth_value = encode_auth(hydrated_gateway_auth_value) 

3195 elif isinstance(hydrated_gateway_auth_value, str): 

3196 gateway_auth_value = hydrated_gateway_auth_value 

3197 hydrated_gateway_query_params = getattr(tool_auth_row.gateway, "auth_query_params", None) 

3198 if isinstance(hydrated_gateway_query_params, dict): 

3199 gateway_auth_query_params = hydrated_gateway_query_params 

3200 hydrated_gateway_oauth_config = getattr(tool_auth_row.gateway, "oauth_config", None) 

3201 if isinstance(hydrated_gateway_oauth_config, dict): 

3202 gateway_oauth_config = hydrated_gateway_oauth_config 

3203 

3204 gateway_auth_query_params_decrypted: Optional[Dict[str, str]] = None 

3205 if gateway_auth_type == "query_param" and gateway_auth_query_params: 

3206 gateway_auth_query_params_decrypted = {} 

3207 for param_key, encrypted_value in gateway_auth_query_params.items(): 

3208 if encrypted_value: 

3209 try: 

3210 decrypted = decode_auth(encrypted_value) 

3211 gateway_auth_query_params_decrypted[param_key] = decrypted.get(param_key, "") 

3212 except Exception: # noqa: S110 

3213 logger.debug(f"Failed to decrypt query param '{param_key}' for Rust MCP tool execution plan") 

3214 if gateway_auth_query_params_decrypted and gateway_url: 

3215 gateway_url = apply_query_param_auth(gateway_url, gateway_auth_query_params_decrypted) 

3216 

3217 if gateway_ca_cert: 

3218 return {"eligible": False, "fallbackReason": "custom-ca-certificate"} 

3219 

3220 if not gateway_url: 

3221 return {"eligible": False, "fallbackReason": "missing-gateway-url"} 

3222 

3223 if has_gateway and gateway_auth_type == "oauth" and isinstance(gateway_oauth_config, dict) and gateway_oauth_config: 

3224 grant_type = gateway_oauth_config.get("grant_type", "client_credentials") 

3225 if grant_type == "authorization_code": 

3226 try: 

3227 # First-Party 

3228 from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel 

3229 

3230 with fresh_db_session() as token_db: 

3231 token_storage = TokenStorageService(token_db) 

3232 if not app_user_email: 

3233 raise ToolInvocationError(f"User authentication required for OAuth-protected gateway '{gateway_name}'. Please ensure you are authenticated.") 

3234 access_token = await token_storage.get_user_token(gateway_id_str, app_user_email) 

3235 

3236 if access_token: 

3237 headers = {"Authorization": f"Bearer {access_token}"} 

3238 else: 

3239 raise ToolInvocationError(f"Please authorize {gateway_name} first. Visit /oauth/authorize/{gateway_id_str} to complete OAuth flow.") 

3240 except Exception as e: 

3241 logger.error(f"Failed to obtain stored OAuth token for gateway {gateway_name}: {e}") 

3242 raise ToolInvocationError(f"OAuth token retrieval failed for gateway: {str(e)}") 

3243 else: 

3244 try: 

3245 access_token = await self.oauth_manager.get_access_token(gateway_oauth_config) 

3246 headers = {"Authorization": f"Bearer {access_token}"} 

3247 except Exception as e: 

3248 logger.error(f"Failed to obtain OAuth access token for gateway {gateway_name}: {e}") 

3249 raise ToolInvocationError(f"OAuth authentication failed for gateway: {str(e)}") 

3250 else: 

3251 headers = decode_auth(gateway_auth_value) if gateway_auth_value else {} 

3252 

3253 if request_headers: 

3254 headers = compute_passthrough_headers_cached( 

3255 request_headers, 

3256 headers, 

3257 passthrough_allowed, 

3258 gateway_auth_type=gateway_auth_type, 

3259 gateway_passthrough_headers=gateway_payload.get("passthrough_headers") if has_gateway else None, 

3260 ) 

3261 

3262 runtime_headers = {str(header_name): str(header_value) for header_name, header_value in headers.items() if header_name and header_value} 

3263 

3264 hook_global_context = None 

3265 if has_pre_invoke or has_post_invoke: 

3266 hook_global_context = self._build_rust_tool_hook_global_context( 

3267 app_user_email=app_user_email, 

3268 server_id=server_id, 

3269 tool_gateway_id=tool_gateway_id, 

3270 plugin_global_context=plugin_global_context, 

3271 tool_payload=tool_payload, 

3272 gateway_payload=gateway_payload, 

3273 ) 

3274 

3275 native_post_invoke_retry_policy = None 

3276 if has_post_invoke: 

3277 native_post_invoke_retry_policy, requires_python_fallback = self._build_rust_native_tool_post_invoke_retry_policy(name, hook_global_context) 

3278 if requires_python_fallback: 

3279 return {"eligible": False, "fallbackReason": "post-invoke-hooks-configured"} 

3280 

3281 # Run tool_pre_invoke hooks so that plugins (e.g. wxo_connections) can 

3282 # inject credentials and clean arguments before the Rust direct call. 

3283 modified_args = arguments 

3284 if has_pre_invoke and arguments is not None: 

3285 pre_result, _ = await self._plugin_manager.invoke_hook( 

3286 ToolHookType.TOOL_PRE_INVOKE, 

3287 payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=dict(runtime_headers))), 

3288 global_context=hook_global_context, 

3289 local_contexts=plugin_context_table, 

3290 violations_as_exceptions=True, 

3291 ) 

3292 if pre_result.modified_payload: 

3293 modified_args = pre_result.modified_payload.args 

3294 if pre_result.modified_payload.name and pre_result.modified_payload.name != name: 

3295 tool_name_original = pre_result.modified_payload.name 

3296 if pre_result.modified_payload.headers is not None: 

3297 plugin_headers = pre_result.modified_payload.headers.root if hasattr(pre_result.modified_payload.headers, "root") else {} 

3298 for hk, hv in plugin_headers.items(): 

3299 if hk and hv: 

3300 runtime_headers[str(hk).lower()] = str(hv) 

3301 

3302 runtime_headers = inject_trace_context_headers(runtime_headers) 

3303 

3304 plan: Dict[str, Any] = { 

3305 "eligible": True, 

3306 "transport": transport, 

3307 "serverUrl": gateway_url, 

3308 "remoteToolName": tool_name_original, 

3309 "headers": runtime_headers, 

3310 "timeoutMs": int(effective_timeout * 1000), 

3311 "gatewayId": tool_gateway_id, 

3312 "toolName": name, 

3313 "toolId": tool_id or None, 

3314 "serverId": server_id, 

3315 } 

3316 if native_post_invoke_retry_policy is not None: 

3317 plan["postInvokeRetryPolicy"] = native_post_invoke_retry_policy 

3318 if has_pre_invoke: 

3319 plan["hasPreInvokeHooks"] = True 

3320 if modified_args is not None: 

3321 plan["modifiedArgs"] = modified_args 

3322 return plan 

3323 

def _build_rust_tool_hook_global_context(
    self,
    *,
    app_user_email: Optional[str],
    server_id: Optional[str],
    tool_gateway_id: Optional[str],
    plugin_global_context: Optional[GlobalContext],
    tool_payload: Optional[Dict[str, Any]],
    gateway_payload: Optional[Dict[str, Any]],
) -> GlobalContext:
    """Produce the plugin ``GlobalContext`` used while resolving a Rust-direct tool plan.

    When a middleware-supplied context is passed in, that same object is
    reused and mutated in place: its ``server_id`` is replaced by the tool's
    gateway id (when one is known), and its ``user`` is filled in only if it
    was empty. Otherwise a fresh context is created, keyed by the current
    correlation id (or a random hex id when none is set).

    In both cases, tool and gateway metadata models derived from the payloads
    are attached under ``TOOL_METADATA`` / ``GATEWAY_METADATA``.

    Args:
        app_user_email: Effective authenticated user for plugin context.
        server_id: Explicit virtual server scope from the request; only used
            when building a new context and no gateway id is available.
        tool_gateway_id: Resolved tool gateway id.
        plugin_global_context: Existing middleware context if available.
        tool_payload: Resolved tool payload.
        gateway_payload: Resolved gateway payload.

    Returns:
        The (reused or newly created) GlobalContext with metadata attached.
    """
    # A usable gateway id is a non-empty string; anything else is ignored.
    gateway_scope = tool_gateway_id if (tool_gateway_id and isinstance(tool_gateway_id, str)) else None

    if not plugin_global_context:
        context = GlobalContext(
            request_id=get_correlation_id() or uuid.uuid4().hex,
            server_id=gateway_scope or server_id,
            tenant_id=None,
            user=app_user_email,
        )
    else:
        context = plugin_global_context
        if gateway_scope:
            context.server_id = gateway_scope
        # Never overwrite a user already established by the middleware.
        if not context.user and app_user_email and isinstance(app_user_email, str):
            context.user = app_user_email

    metadata_entries = (
        (TOOL_METADATA, self._pydantic_tool_from_payload(tool_payload) if tool_payload else None),
        (GATEWAY_METADATA, self._pydantic_gateway_from_payload(gateway_payload) if gateway_payload else None),
    )
    for metadata_key, metadata_model in metadata_entries:
        if metadata_model:
            context.metadata[metadata_key] = metadata_model
    return context

3365 

def _build_rust_native_tool_post_invoke_retry_policy(
    self,
    tool_name: str,
    hook_global_context: Optional[GlobalContext],
) -> Tuple[Optional[Dict[str, Any]], bool]:
    """Return a native Rust retry policy when the active post-invoke hooks allow it.

    The Rust runtime only supports native post-invoke execution for the
    default retry-with-backoff plugin. Any other active `tool_post_invoke`
    hook must still force the call back to Python to preserve plugin semantics.

    The retry plugin's ``RetryConfig`` model lives outside the ``mcpgateway``
    package. It is therefore imported only once the retry plugin is confirmed
    to be the sole active hook. If that import fails, the call falls back to
    Python instead of raising.

    Args:
        tool_name: Requested tool name.
        hook_global_context: Resolved plugin context for condition matching.

    Returns:
        Tuple of `(policy, requires_python_fallback)`:
        - ``(None, False)`` when no post-invoke hook is active.
        - ``(policy_dict, False)`` when the retry plugin alone is active and
          its effective config can be expressed natively.
        - ``(None, True)`` otherwise.
    """
    if not self._plugin_manager or not self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE):
        return (None, False)

    # First-Party
    from mcpgateway.plugins.framework import PluginMode  # pylint: disable=import-outside-toplevel
    from mcpgateway.plugins.framework.utils import payload_matches  # pylint: disable=import-outside-toplevel

    global_context = hook_global_context or GlobalContext(request_id=get_correlation_id() or uuid.uuid4().hex)
    # Probe payload used only to evaluate plugin conditions (the result is not known yet).
    payload = ToolPostInvokePayload(name=tool_name, result={})
    hook_refs = self._plugin_manager._registry.get_hook_refs_for_hook(hook_type=ToolHookType.TOOL_POST_INVOKE)  # pylint: disable=protected-access

    active_hook_refs = []
    for hook_ref in hook_refs:
        if hook_ref.plugin_ref.mode == PluginMode.DISABLED:
            continue
        if hook_ref.plugin_ref.conditions and not payload_matches(payload, ToolHookType.TOOL_POST_INVOKE, hook_ref.plugin_ref.conditions, global_context):
            continue
        active_hook_refs.append(hook_ref)

    if not active_hook_refs:
        return (None, False)

    if len(active_hook_refs) != 1 or active_hook_refs[0].plugin_ref.name != "RetryWithBackoffPlugin":
        return (None, True)

    # Import lazily: only needed once the retry plugin is confirmed as the sole
    # active hook, and a missing module must degrade to the Python path rather
    # than failing plan resolution.
    try:
        # Third-Party/Local
        from plugins.retry_with_backoff.retry_with_backoff import RetryConfig  # pylint: disable=import-outside-toplevel
    except ImportError as exc:
        logger.debug(f"RetryWithBackoffPlugin config model unavailable ({exc}); falling back to Python post-invoke path")
        return (None, True)

    retry_hook = active_hook_refs[0]
    ceiling = settings.max_tool_retries

    def _clamp_retries(cfg: Any) -> Any:
        """Cap ``cfg.max_retries`` at the gateway-wide ``settings.max_tool_retries`` ceiling."""
        if cfg.max_retries > ceiling:
            return cfg.model_copy(update={"max_retries": ceiling})
        return cfg

    effective_cfg = _clamp_retries(RetryConfig(**(retry_hook.plugin_ref.plugin.config.config or {})))

    overrides = effective_cfg.tool_overrides.get(tool_name)
    if overrides:
        merged_cfg = effective_cfg.model_dump()
        merged_cfg.update(overrides)
        # The per-tool overrides are flattened into one config; drop the nested map
        # so the rebuilt model does not carry it forward.
        merged_cfg.pop("tool_overrides", None)
        effective_cfg = _clamp_retries(RetryConfig(**merged_cfg))

    # Text-content inspection is not part of the native policy emitted below,
    # so that configuration must stay on the Python path.
    if effective_cfg.check_text_content:
        return (None, True)

    return (
        {
            "kind": "retry_with_backoff",
            "maxRetries": int(effective_cfg.max_retries),
            "backoffBaseMs": int(effective_cfg.backoff_base_ms),
            "maxBackoffMs": int(effective_cfg.max_backoff_ms),
            "retryOnStatus": list(effective_cfg.retry_on_status),
            "jitter": bool(effective_cfg.jitter),
        },
        False,
    )

3441 

def _load_invocable_tools(self, db: Session, name: str, server_id: Optional[str] = None) -> List[DbTool]:
    """Fetch every tool row named ``name`` that could be invoked.

    Each row is loaded together with its gateway (via ``joinedload``) so that
    later access to ``tool.gateway`` does not trigger a lazy load. When
    ``server_id`` is truthy, the rows are additionally restricted to tools
    attached to that virtual server through ``server_tool_association``.

    Args:
        db: Active database session.
        name: Tool name to resolve.
        server_id: Optional virtual server identifier used to constrain results.

    Returns:
        All matching tool ORM rows (possibly empty, possibly more than one).
    """
    base_stmt = select(DbTool).options(joinedload(DbTool.gateway)).where(DbTool.name == name)
    if not server_id:
        return db.execute(base_stmt).scalars().all()

    server_scoped_stmt = base_stmt.join(
        server_tool_association,
        DbTool.id == server_tool_association.c.tool_id,
    ).where(server_tool_association.c.server_id == server_id)
    return db.execute(server_scoped_stmt).scalars().all()

3457 

3458 # ------------------------------------------------------------------ 

3459 # Retry helpers (used by invoke_tool) 

3460 # ------------------------------------------------------------------ 

3461 

async def _run_timeout_post_invoke(
    self,
    name: str,
    effective_timeout: float,
    global_context: Any,
    context_table: Any,
) -> None:
    """Give post-invoke plugins a chance to react to a tool timeout.

    Every plugin-local context is first flagged with
    ``cb_timeout_failure=True``. The ``cb_`` prefix suggests a circuit-breaker
    consumer; confirm this against the plugin.

    If a plugin manager with ``tool_post_invoke`` hooks is configured, those
    hooks then receive a synthetic error ``ToolResult`` describing the
    timeout. When the hook result asks for a delayed retry
    (``retry_delay_ms > 0``), a ``ToolTimeoutError`` carrying that delay is
    raised here. Otherwise this returns ``None`` and the caller raises its
    own plain ``ToolTimeoutError``.

    Args:
        name: Tool name.
        effective_timeout: Timeout duration in seconds.
        global_context: Plugin global context for cross-hook state.
        context_table: Plugin local context table for per-plugin state.

    Raises:
        ToolTimeoutError: When the retry plugin requests a delayed retry.
    """
    for local_ctx in context_table.values() if context_table else ():
        local_ctx.set_state("cb_timeout_failure", True)

    manager = self._plugin_manager
    if not manager or not manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE):
        return

    timeout_message = f"Tool invocation timed out after {effective_timeout}s"
    synthetic_error = ToolResult(content=[TextContent(type="text", text=timeout_message)], is_error=True)
    post_result, _ = await manager.invoke_hook(
        ToolHookType.TOOL_POST_INVOKE,
        payload=ToolPostInvokePayload(name=name, result=synthetic_error.model_dump(by_alias=True)),
        global_context=global_context,
        local_contexts=context_table,
        violations_as_exceptions=False,
    )
    if post_result and post_result.retry_delay_ms > 0:
        raise ToolTimeoutError(timeout_message, retry_delay_ms=post_result.retry_delay_ms)

3504 

async def _retry_tool_invocation(
    self,
    delay_ms: int,
    retry_attempt: int,
    name: str,
    arguments: Dict[str, Any],
    request_headers: Any,
    app_user_email: Optional[str],
    user_email: Optional[str],
    token_teams: Optional[List[str]],
    server_id: Optional[str],
    context_table: Any,
    global_context: Any,
    meta_data: Optional[Dict[str, Any]],
    skip_pre_invoke: bool,
    path_label: str,
) -> "ToolResult":
    """Wait out a plugin-requested backoff, then re-enter ``invoke_tool``.

    The wait is a plain ``asyncio.sleep``. If the calling task is cancelled
    during the wait, ``CancelledError`` propagates and no retry is attempted.

    The retried call runs on a brand-new DB session (``fresh_db_session``). It
    reuses the caller's plugin contexts and sets the attempt counter to
    ``retry_attempt + 1``.

    Args:
        delay_ms: Backoff delay in milliseconds before retrying.
        retry_attempt: Current zero-based retry counter.
        name: Tool name to re-invoke.
        arguments: Tool arguments to forward.
        request_headers: Original request headers.
        app_user_email: ContextForge user email for OAuth.
        user_email: User email for authorization.
        token_teams: Team IDs from JWT token.
        server_id: Virtual server ID for scoping.
        context_table: Plugin local context table.
        global_context: Plugin global context.
        meta_data: Optional metadata dictionary.
        skip_pre_invoke: Whether to skip pre-invoke hooks.
        path_label: Label for log messages (success/timeout/exception).

    Returns:
        ToolResult from the retried invocation.
    """
    next_attempt = retry_attempt + 1
    logger.debug(
        "tool_service: retry requested (%s) for tool=%s attempt=%d/%d delay_ms=%d",
        path_label,
        name,
        next_attempt,
        settings.max_tool_retries,
        delay_ms,
    )
    await asyncio.sleep(delay_ms / 1000)

    forwarded_kwargs = {
        "name": name,
        "arguments": arguments,
        "request_headers": request_headers,
        "app_user_email": app_user_email,
        "user_email": user_email,
        "token_teams": token_teams,
        "server_id": server_id,
        "plugin_context_table": context_table,
        "plugin_global_context": global_context,
        "meta_data": meta_data,
        "skip_pre_invoke": skip_pre_invoke,
        "retry_attempt": next_attempt,
    }
    with fresh_db_session() as retry_db:
        return await self.invoke_tool(db=retry_db, **forwarded_kwargs)

3572 

3573 async def invoke_tool( 

3574 self, 

3575 db: Session, 

3576 name: str, 

3577 arguments: Dict[str, Any], 

3578 request_headers: Optional[Dict[str, str]] = None, 

3579 app_user_email: Optional[str] = None, 

3580 user_email: Optional[str] = None, 

3581 token_teams: Optional[List[str]] = None, 

3582 server_id: Optional[str] = None, 

3583 plugin_context_table: Optional[PluginContextTable] = None, 

3584 plugin_global_context: Optional[GlobalContext] = None, 

3585 meta_data: Optional[Dict[str, Any]] = None, 

3586 skip_pre_invoke: bool = False, 

3587 retry_attempt: int = 0, 

3588 ) -> ToolResult: 

3589 """ 

3590 Invoke a registered tool and record execution metrics. 

3591 

3592 Args: 

3593 db: Database session. 

3594 name: Name of tool to invoke. 

3595 arguments: Tool arguments. 

3596 request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through. 

3597 Defaults to None. 

3598 app_user_email (Optional[str], optional): ContextForge user email for OAuth token retrieval. 

3599 Required for OAuth-protected gateways. 

3600 user_email (Optional[str], optional): User email for authorization checks. 

3601 None = unauthenticated request. 

3602 token_teams (Optional[List[str]], optional): Team IDs from JWT token for authorization. 

3603 None = unrestricted admin, [] = public-only, [...] = team-scoped. 

3604 server_id (Optional[str], optional): Virtual server ID for server scoping enforcement. 

3605 If provided, tool must be attached to this server. 

3606 plugin_context_table: Optional plugin context table from previous hooks for cross-hook state sharing. 

3607 plugin_global_context: Optional global context from middleware for consistency across hooks. 

3608 meta_data: Optional metadata dictionary for additional context (e.g., request ID). 

3609 skip_pre_invoke: When True, skip TOOL_PRE_INVOKE hooks (used by trusted Rust fallback path). 

3610 retry_attempt: Zero-based retry counter; 0 = original call. Incremented by the retry 

3611 loop and compared against ``settings.max_tool_retries``. 

3612 

3613 Returns: 

3614 Tool invocation result. 

3615 

3616 Raises: 

3617 ToolNotFoundError: If tool not found or access denied. 

3618 ToolInvocationError: If invocation fails or A2A authentication decryption fails. 

3619 ToolTimeoutError: If tool invocation times out. 

3620 PluginViolationError: If plugin blocks tool invocation. 

3621 PluginError: If encounters issue with plugin. 

3622 

3623 Examples: 

3624 >>> # Note: This method requires extensive mocking of SQLAlchemy models, 

3625 >>> # database relationships, and caching infrastructure, which is not 

3626 >>> # suitable for doctests. See tests/unit/mcpgateway/services/test_tool_service.py 

3627 >>> pass # doctest: +SKIP 

3628 """ 

3629 # pylint: disable=comparison-with-callable 

3630 logger.info(f"Invoking tool: {name} with arguments: {arguments.keys() if arguments else None} and headers: {request_headers.keys() if request_headers else None}, server_id={server_id}") 

3631 # ═══════════════════════════════════════════════════════════════════════════ 

3632 # PHASE 1: Check for X-Context-Forge-Gateway-Id header for direct_proxy mode (no DB lookup) 

3633 # ═══════════════════════════════════════════════════════════════════════════ 

3634 gateway_id_from_header = extract_gateway_id_from_headers(request_headers) 

3635 

3636 # If X-Context-Forge-Gateway-Id header is present, check if gateway is in direct_proxy mode 

3637 is_direct_proxy = False 

3638 tool = None 

3639 gateway = None 

3640 tool_payload: Dict[str, Any] = {} 

3641 gateway_payload: Optional[Dict[str, Any]] = None 

3642 

3643 if gateway_id_from_header: 

3644 # Look up gateway to check if it's in direct_proxy mode 

3645 gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id_from_header)).scalar_one_or_none() 

3646 if gateway and gateway.gateway_mode == "direct_proxy" and settings.mcpgateway_direct_proxy_enabled: 

3647 # SECURITY: Check gateway access before allowing direct proxy 

3648 # This prevents RBAC bypass where any authenticated user could invoke tools 

3649 # on any gateway just by knowing the gateway ID 

3650 if not await check_gateway_access(db, gateway, user_email, token_teams): 

3651 logger.warning(f"Access denied to gateway {gateway_id_from_header} in direct_proxy mode for user {SecurityValidator.sanitize_log_message(user_email)}") 

3652 raise ToolNotFoundError(f"Tool not found: {name}") 

3653 

3654 is_direct_proxy = True 

3655 # Build minimal gateway payload for direct proxy (no tool lookup needed) 

3656 gateway_payload = { 

3657 "id": str(gateway.id), 

3658 "name": gateway.name, 

3659 "url": gateway.url, 

3660 "auth_type": gateway.auth_type, 

3661 # DbGateway.auth_value is JSON (dict); downstream code expects an encoded str. 

3662 "auth_value": encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value, 

3663 "auth_query_params": gateway.auth_query_params, 

3664 "oauth_config": gateway.oauth_config, 

3665 "ca_certificate": gateway.ca_certificate, 

3666 "ca_certificate_sig": gateway.ca_certificate_sig, 

3667 "passthrough_headers": gateway.passthrough_headers, 

3668 "gateway_mode": gateway.gateway_mode, 

3669 } 

3670 # Create minimal tool payload for direct proxy (no DB tool needed) 

3671 tool_payload = { 

3672 "id": None, # No tool ID in direct proxy mode 

3673 "name": name, 

3674 "original_name": name, 

3675 "enabled": True, 

3676 "reachable": True, 

3677 "integration_type": "MCP", 

3678 "request_type": "streamablehttp", # Default to streamablehttp 

3679 "gateway_id": str(gateway.id), 

3680 } 

3681 logger.info(f"Direct proxy mode via X-Context-Forge-Gateway-Id header: passing tool '{name}' directly to remote MCP server at {gateway.url}") 

3682 elif gateway: 

3683 logger.debug(f"Gateway {gateway_id_from_header} found but not in direct_proxy mode (mode: {gateway.gateway_mode}), using normal lookup") 

3684 else: 

3685 logger.warning(f"Gateway {gateway_id_from_header} specified in X-Context-Forge-Gateway-Id header not found") 

3686 

3687 # Normal mode: look up tool in database/cache 

3688 if not is_direct_proxy: 

3689 tool_lookup_cache = _get_tool_lookup_cache() 

3690 cached_payload = await tool_lookup_cache.get(name) if tool_lookup_cache.enabled else None 

3691 

3692 if cached_payload: 

3693 status = cached_payload.get("status", "active") 

3694 if status == "missing": 

3695 raise ToolNotFoundError(f"Tool not found: {name}") 

3696 if status == "inactive": 

3697 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive") 

3698 if status == "offline": 

3699 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.") 

3700 tool_payload = cached_payload.get("tool") or {} 

3701 gateway_payload = cached_payload.get("gateway") 

3702 

3703 if not tool_payload: 

3704 # Eager load tool WITH gateway in single query to prevent lazy load N+1 

3705 # Use a single query to avoid a race between separate enabled/inactive lookups. 

3706 # Use scalars().all() instead of scalar_one_or_none() to handle duplicate 

3707 # tool names across teams without crashing on MultipleResultsFound. 

3708 tools = self._load_invocable_tools(db, name, server_id=server_id) 

3709 

3710 if not tools: 

3711 raise ToolNotFoundError(f"Tool not found: {name}") 

3712 

3713 multiple_found = len(tools) > 1 

3714 if not multiple_found: 

3715 tool = tools[0] 

3716 else: 

3717 # Multiple tools found with same name — filter by access using 

3718 # _check_tool_access (same rules as list_tools) and prioritize. 

3719 # Priority (lower is better): team (0) > private (1) > public (2) 

3720 visibility_priority = {"team": 0, "private": 1, "public": 2} 

3721 accessible_tools: list[tuple[int, Any]] = [] 

3722 for t in tools: 

3723 tool_dict = {"visibility": t.visibility, "team_id": t.team_id, "owner_email": t.owner_email} 

3724 if await self._check_tool_access(db, tool_dict, user_email, token_teams): 

3725 priority = visibility_priority.get(t.visibility, 99) 

3726 accessible_tools.append((priority, t)) 

3727 

3728 if not accessible_tools: 

3729 raise ToolNotFoundError(f"Tool not found: {name}") 

3730 

3731 accessible_tools.sort(key=lambda x: x[0]) 

3732 

3733 # Check for ambiguity at the highest priority level 

3734 best_priority = accessible_tools[0][0] 

3735 best_tools = [t for p, t in accessible_tools if p == best_priority] 

3736 

3737 if len(best_tools) > 1: 

3738 raise ToolInvocationError(f"Multiple tools found with name '{name}' at same priority level. Tool name is ambiguous.") 

3739 

3740 tool = best_tools[0] 

3741 

3742 if not tool.enabled: 

3743 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive") 

3744 

3745 if not tool.reachable: 

3746 await tool_lookup_cache.set_negative(name, "offline") 

3747 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.") 

3748 

3749 gateway = tool.gateway 

3750 cache_payload = self._build_tool_cache_payload(tool, gateway) 

3751 tool_payload = cache_payload.get("tool") or {} 

3752 gateway_payload = cache_payload.get("gateway") 

3753 # Skip caching when multiple tools share a name — resolution is 

3754 # user-dependent, so a cached result could be wrong for other users. 

3755 if not multiple_found: 

3756 await tool_lookup_cache.set(name, cache_payload, gateway_id=tool_payload.get("gateway_id")) 

3757 

3758 if tool_payload.get("enabled") is False: 

3759 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive") 

3760 if tool_payload.get("reachable") is False: 

3761 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.") 

3762 

3763 # ═══════════════════════════════════════════════════════════════════════════ 

3764 # SECURITY: Check tool access based on visibility and team membership 

3765 # Skip these checks for direct_proxy mode (no tool in database) 

3766 # ═══════════════════════════════════════════════════════════════════════════ 

3767 if not is_direct_proxy: 

3768 if not await self._check_tool_access(db, tool_payload, user_email, token_teams): 

3769 # Don't reveal tool existence - return generic "not found" 

3770 raise ToolNotFoundError(f"Tool not found: {name}") 

3771 

3772 # ═══════════════════════════════════════════════════════════════════════════ 

3773 # SECURITY: Enforce server scoping if server_id is provided 

3774 # Tool must be attached to the specified virtual server 

3775 # ═══════════════════════════════════════════════════════════════════════════ 

3776 if server_id: 

3777 tool_id_for_check = tool_payload.get("id") 

3778 if not tool_id_for_check: 

3779 # Cannot verify server membership without tool ID - deny access 

3780 # This should not happen with properly cached tools, but fail safe 

3781 logger.warning(f"Tool '{name}' has no ID in payload, cannot verify server membership") 

3782 raise ToolNotFoundError(f"Tool not found: {name}") 

3783 

3784 server_match = db.execute( 

3785 select(server_tool_association.c.tool_id).where( 

3786 server_tool_association.c.server_id == server_id, 

3787 server_tool_association.c.tool_id == tool_id_for_check, 

3788 ) 

3789 ).first() 

3790 if not server_match: 

3791 raise ToolNotFoundError(f"Tool not found: {name}") 

3792 

3793 # Extract A2A-related data from annotations (will be used after db.close() if A2A tool) 

3794 tool_annotations = tool_payload.get("annotations") or {} 

3795 tool_integration_type = tool_payload.get("integration_type") 

3796 

3797 # Get passthrough headers from in-memory cache (Issue #1715) 

3798 # This eliminates 42,000+ redundant DB queries under load 

3799 passthrough_allowed = global_config_cache.get_passthrough_headers(db, settings.default_passthrough_headers) 

3800 

3801 # Access gateway now (already eager-loaded) to prevent later lazy load 

3802 if tool is not None: 

3803 gateway = tool.gateway 

3804 

3805 # ═══════════════════════════════════════════════════════════════════════════ 

3806 # PHASE 2: Extract all needed data to local variables before network I/O 

3807 # This allows us to release the DB session before making HTTP calls 

3808 # ═══════════════════════════════════════════════════════════════════════════ 

3809 tool_id = tool_payload.get("id") or (str(tool.id) if tool else "") 

3810 tool_name_original = tool_payload.get("original_name") or tool_payload.get("name") or name 

3811 tool_name_computed = tool_payload.get("name") or name 

3812 tool_url = tool_payload.get("url") 

3813 tool_integration_type = tool_payload.get("integration_type") 

3814 tool_request_type = tool_payload.get("request_type") 

3815 tool_headers = _decrypt_tool_headers_for_runtime(tool_payload.get("headers") or {}) 

3816 tool_auth_type = tool_payload.get("auth_type") 

3817 tool_auth_value = tool_payload.get("auth_value") 

3818 if tool is not None: 

3819 runtime_tool_auth_value = getattr(tool, "auth_value", None) 

3820 if isinstance(runtime_tool_auth_value, str): 

3821 tool_auth_value = runtime_tool_auth_value 

3822 if not isinstance(tool_auth_value, str): 

3823 tool_auth_value = None 

3824 tool_jsonpath_filter = tool_payload.get("jsonpath_filter") 

3825 tool_output_schema = tool_payload.get("output_schema") 

3826 tool_oauth_config = tool_payload.get("oauth_config") if isinstance(tool_payload.get("oauth_config"), dict) else None 

3827 if tool is not None: 

3828 runtime_tool_oauth_config = getattr(tool, "oauth_config", None) 

3829 if isinstance(runtime_tool_oauth_config, dict): 

3830 tool_oauth_config = runtime_tool_oauth_config 

3831 tool_gateway_id = tool_payload.get("gateway_id") 

3832 

3833 # Get effective timeout in seconds: per-tool timeout_ms if set, otherwise the global fallback

3834 # timeout_ms is stored in milliseconds, convert to seconds 

3835 tool_timeout_ms = tool_payload.get("timeout_ms") 

3836 effective_timeout = (tool_timeout_ms / 1000) if tool_timeout_ms else settings.tool_timeout 

3837 

3838 # Save gateway existence as local boolean BEFORE db.close() 

3839 # to avoid checking ORM object truthiness after session is closed 

3840 has_gateway = gateway_payload is not None 

3841 gateway_url = gateway_payload.get("url") if has_gateway else None 

3842 gateway_name = gateway_payload.get("name") if has_gateway else None 

3843 gateway_auth_type = gateway_payload.get("auth_type") if has_gateway else None 

3844 gateway_auth_value = gateway_payload.get("auth_value") if has_gateway and isinstance(gateway_payload.get("auth_value"), str) else None 

3845 gateway_auth_query_params = gateway_payload.get("auth_query_params") if has_gateway and isinstance(gateway_payload.get("auth_query_params"), dict) else None 

3846 gateway_oauth_config = gateway_payload.get("oauth_config") if has_gateway and isinstance(gateway_payload.get("oauth_config"), dict) else None 

3847 if has_gateway and gateway is not None: 

3848 runtime_gateway_auth_value = getattr(gateway, "auth_value", None) 

3849 if isinstance(runtime_gateway_auth_value, dict): 

3850 gateway_auth_value = encode_auth(runtime_gateway_auth_value) 

3851 elif isinstance(runtime_gateway_auth_value, str): 

3852 gateway_auth_value = runtime_gateway_auth_value 

3853 runtime_gateway_query_params = getattr(gateway, "auth_query_params", None) 

3854 if isinstance(runtime_gateway_query_params, dict): 

3855 gateway_auth_query_params = runtime_gateway_query_params 

3856 runtime_gateway_oauth_config = getattr(gateway, "oauth_config", None) 

3857 if isinstance(runtime_gateway_oauth_config, dict): 

3858 gateway_oauth_config = runtime_gateway_oauth_config 

3859 gateway_ca_cert = gateway_payload.get("ca_certificate") if has_gateway else None 

3860 gateway_ca_cert_sig = gateway_payload.get("ca_certificate_sig") if has_gateway else None 

3861 gateway_passthrough = gateway_payload.get("passthrough_headers") if has_gateway else None 

3862 gateway_id_str = gateway_payload.get("id") if has_gateway else None 

3863 

3864 # Cache payload intentionally excludes sensitive auth material. For cache hits 

3865 # (tool is None), hydrate auth-related fields from DB only when needed. 

3866 if tool is None: 

3867 requires_tool_auth_hydration = tool_auth_type in {"basic", "bearer", "authheaders", "oauth"} 

3868 requires_gateway_auth_hydration = has_gateway and gateway_auth_type in {"basic", "bearer", "authheaders", "oauth", "query_param"} 

3869 if requires_tool_auth_hydration or requires_gateway_auth_hydration: 

3870 tool_id_for_hydration = tool_payload.get("id") 

3871 if tool_id_for_hydration: 

3872 tool_auth_row = db.execute(select(DbTool).options(joinedload(DbTool.gateway)).where(DbTool.id == tool_id_for_hydration)).scalar_one_or_none() 

3873 if tool_auth_row: 

3874 hydrated_tool_auth_value = getattr(tool_auth_row, "auth_value", None) 

3875 if isinstance(hydrated_tool_auth_value, str): 

3876 tool_auth_value = hydrated_tool_auth_value 

3877 hydrated_tool_oauth_config = getattr(tool_auth_row, "oauth_config", None) 

3878 if isinstance(hydrated_tool_oauth_config, dict): 

3879 tool_oauth_config = hydrated_tool_oauth_config 

3880 if has_gateway and tool_auth_row.gateway: 

3881 hydrated_gateway_auth_value = getattr(tool_auth_row.gateway, "auth_value", None) 

3882 if isinstance(hydrated_gateway_auth_value, dict): 

3883 gateway_auth_value = encode_auth(hydrated_gateway_auth_value) 

3884 elif isinstance(hydrated_gateway_auth_value, str): 

3885 gateway_auth_value = hydrated_gateway_auth_value 

3886 hydrated_gateway_query_params = getattr(tool_auth_row.gateway, "auth_query_params", None) 

3887 if isinstance(hydrated_gateway_query_params, dict): 

3888 gateway_auth_query_params = hydrated_gateway_query_params 

3889 hydrated_gateway_oauth_config = getattr(tool_auth_row.gateway, "oauth_config", None) 

3890 if isinstance(hydrated_gateway_oauth_config, dict): 

3891 gateway_oauth_config = hydrated_gateway_oauth_config 

3892 

3893 # Decrypt and apply query param auth to URL if applicable 

3894 gateway_auth_query_params_decrypted: Optional[Dict[str, str]] = None 

3895 if gateway_auth_type == "query_param" and gateway_auth_query_params: 

3896 # Decrypt the query param values 

3897 gateway_auth_query_params_decrypted = {} 

3898 for param_key, encrypted_value in gateway_auth_query_params.items(): 

3899 if encrypted_value: 

3900 try: 

3901 decrypted = decode_auth(encrypted_value) 

3902 gateway_auth_query_params_decrypted[param_key] = decrypted.get(param_key, "") 

3903 except Exception: # noqa: S110 - intentionally skip failed decryptions 

3904 # Silently skip params that fail decryption (may be corrupted or use old key) 

3905 logger.debug(f"Failed to decrypt query param '{param_key}' for tool invocation") 

3906 # Apply query params to gateway URL 

3907 if gateway_auth_query_params_decrypted and gateway_url: 

3908 gateway_url = apply_query_param_auth(gateway_url, gateway_auth_query_params_decrypted) 

3909 

3910 # Create Pydantic models for plugins BEFORE HTTP calls (use ORM objects while still valid) 

3911 # This prevents lazy loading during HTTP calls 

3912 tool_metadata: Optional[PydanticTool] = None 

3913 gateway_metadata: Optional[PydanticGateway] = None 

3914 if self._plugin_manager: 

3915 if tool is not None: 

3916 tool_metadata = PydanticTool.model_validate(tool) 

3917 if has_gateway and gateway is not None: 

3918 gateway_metadata = PydanticGateway.model_validate(gateway) 

3919 else: 

3920 tool_metadata = self._pydantic_tool_from_payload(tool_payload) 

3921 if has_gateway and gateway_payload: 

3922 gateway_metadata = self._pydantic_gateway_from_payload(gateway_payload) 

3923 

3924 tool_for_validation = tool if tool is not None else SimpleNamespace(output_schema=tool_output_schema, name=tool_name_computed) 

3925 

3926 # ═══════════════════════════════════════════════════════════════════════════ 

3927 # A2A Agent Data Extraction (must happen before db.close()) 

3928 # Extract all A2A agent data to local variables so HTTP call can happen after db.close() 

3929 # ═══════════════════════════════════════════════════════════════════════════ 

3930 a2a_agent_name: Optional[str] = None 

3931 a2a_agent_endpoint_url: Optional[str] = None 

3932 a2a_agent_type: Optional[str] = None 

3933 a2a_agent_protocol_version: Optional[str] = None 

3934 a2a_agent_auth_type: Optional[str] = None 

3935 a2a_agent_auth_value: Optional[str] = None 

3936 a2a_agent_auth_query_params: Optional[Dict[str, str]] = None 

3937 

3938 if tool_integration_type == "A2A" and "a2a_agent_id" in tool_annotations: 

3939 a2a_agent_id = tool_annotations.get("a2a_agent_id") 

3940 if not a2a_agent_id: 

3941 raise ToolNotFoundError(f"A2A tool '{name}' missing agent ID in annotations") 

3942 

3943 # Query for the A2A agent 

3944 agent_query = select(DbA2AAgent).where(DbA2AAgent.id == a2a_agent_id) 

3945 a2a_agent = db.execute(agent_query).scalar_one_or_none() 

3946 

3947 if not a2a_agent: 

3948 raise ToolNotFoundError(f"A2A agent not found for tool '{name}' (agent ID: {a2a_agent_id})") 

3949 

3950 if not a2a_agent.enabled: 

3951 raise ToolNotFoundError(f"A2A agent '{a2a_agent.name}' is disabled") 

3952 

3953 # Extract all needed data to local variables before db.close() 

3954 a2a_agent_name = a2a_agent.name 

3955 a2a_agent_endpoint_url = a2a_agent.endpoint_url 

3956 a2a_agent_type = a2a_agent.agent_type 

3957 a2a_agent_protocol_version = a2a_agent.protocol_version 

3958 a2a_agent_auth_type = a2a_agent.auth_type 

3959 a2a_agent_auth_value = a2a_agent.auth_value 

3960 a2a_agent_auth_query_params = a2a_agent.auth_query_params 

3961 

3962 # ═══════════════════════════════════════════════════════════════════════════ 

3963 # CRITICAL: Release DB connection back to pool BEFORE making HTTP calls 

3964 # This prevents connection pool exhaustion during slow upstream requests. 

3965 # All needed data has been extracted to local variables above. 

3966 # The session will be closed again by FastAPI's get_db() finally block (safe no-op). 

3967 # ═══════════════════════════════════════════════════════════════════════════ 

3968 db.commit() # End read-only transaction cleanly (commit not rollback to avoid inflating rollback stats) 

3969 db.close() 

3970 

3971 # Plugin hook: tool pre-invoke 

3972 # Use existing context_table from previous hooks if available 

3973 context_table = plugin_context_table 

3974 

3975 # Reuse existing global_context from middleware or create new one 

3976 # IMPORTANT: Use local variables (tool_gateway_id) instead of ORM object access 

3977 if plugin_global_context: 

3978 global_context = plugin_global_context 

3979 # Update server_id using local variable (not ORM access) 

3980 if tool_gateway_id and isinstance(tool_gateway_id, str): 

3981 global_context.server_id = tool_gateway_id 

3982 # Propagate user email to global context for plugin access 

3983 if not plugin_global_context.user and app_user_email and isinstance(app_user_email, str): 

3984 global_context.user = app_user_email 

3985 else: 

3986 # Create new context (fallback when middleware didn't run) 

3987 # Use correlation ID from context if available, otherwise generate new one 

3988 request_id = get_correlation_id() or uuid.uuid4().hex 

3989 context_server_id = tool_gateway_id if tool_gateway_id and isinstance(tool_gateway_id, str) else "unknown" 

3990 global_context = GlobalContext(request_id=request_id, server_id=context_server_id, tenant_id=None, user=app_user_email) 

3991 

3992 start_time = time.monotonic() 

3993 success = False 

3994 error_message = None 

3995 tool_result: Optional[ToolResult] = None 

3996 tool_team_scope = format_trace_team_scope(token_teams) 

3997 

3998 # Get trace_id from context for database span creation 

3999 trace_id = current_trace_id.get() 

4000 db_span_id = None 

4001 db_span_ended = False 

4002 observability_service = ObservabilityService() if trace_id else None 

4003 

4004 # Create database span for observability_spans table 

4005 if trace_id and observability_service: 

4006 try: 

4007 # Re-open a database session for span creation (the original session was closed above, before network I/O)

4008 # Use commit=False since fresh_db_session() handles commits on exit 

4009 with fresh_db_session() as span_db: 

4010 db_span_id = observability_service.start_span( 

4011 db=span_db, 

4012 trace_id=trace_id, 

4013 name="tool.invoke", 

4014 kind="client", 

4015 resource_type="tool", 

4016 resource_name=name, 

4017 resource_id=tool_id, 

4018 attributes={ 

4019 "tool.name": name, 

4020 "tool.id": tool_id, 

4021 "tool.integration_type": tool_integration_type, 

4022 "tool.gateway_id": tool_gateway_id, 

4023 "arguments_count": len(arguments) if arguments else 0, 

4024 "has_headers": bool(request_headers), 

4025 }, 

4026 commit=False, 

4027 ) 

4028 logger.debug(f"✓ Created tool.invoke span: {db_span_id} for tool: {name}") 

4029 except Exception as e: 

4030 logger.warning(f"Failed to start observability span for tool invocation: {e}") 

4031 db_span_id = None 

4032 

4033 # Create a trace span for OpenTelemetry export (Jaeger, Zipkin, etc.) 

4034 span_attributes = { 

4035 "tool.name": name, 

4036 "tool.id": tool_id, 

4037 "tool.integration_type": tool_integration_type, 

4038 "tool.gateway_id": tool_gateway_id, 

4039 "arguments_count": len(arguments) if arguments else 0, 

4040 "has_headers": bool(request_headers), 

4041 "user.email": user_email or app_user_email or "anonymous", 

4042 "team.scope": tool_team_scope, 

4043 "server_id": server_id, 

4044 } 

4045 if is_input_capture_enabled("tool.invoke"): 

4046 span_attributes["langfuse.observation.input"] = serialize_trace_payload(arguments or {}) 

4047 

4048 with create_span("tool.invoke", span_attributes) as span: 

4049 try: 

4050 # Create a lightweight lookup child span so Langfuse shows the invoke breakdown. 

4051 with create_child_span( 

4052 "tool.lookup", 

4053 { 

4054 "tool.name": name, 

4055 "tool.id": tool_id, 

4056 "tool.integration_type": tool_integration_type, 

4057 }, 

4058 ): 

4059 headers = tool_headers.copy() 

4060 if tool_integration_type == "REST": 

4061 # Handle OAuth authentication for REST tools 

4062 if tool_auth_type == "oauth" and isinstance(tool_oauth_config, dict) and tool_oauth_config: 

4063 try: 

4064 access_token = await self.oauth_manager.get_access_token(tool_oauth_config) 

4065 headers["Authorization"] = f"Bearer {access_token}" 

4066 except Exception as e: 

4067 logger.error(f"Failed to obtain OAuth access token for tool {tool_name_computed}: {e}") 

4068 raise ToolInvocationError(f"OAuth authentication failed: {str(e)}") 

4069 else: 

4070 credentials = decode_auth(tool_auth_value) if tool_auth_value else {} 

4071 # Filter out empty header names/values to avoid "Illegal header name" errors 

4072 filtered_credentials = {k: v for k, v in credentials.items() if k and v} 

4073 headers.update(filtered_credentials) 

4074 

4075 # Use cached passthrough headers (no DB query needed) 

4076 if request_headers: 

4077 headers = compute_passthrough_headers_cached( 

4078 request_headers, 

4079 headers, 

4080 passthrough_allowed, 

4081 gateway_auth_type=None, 

4082 gateway_passthrough_headers=None, # REST tools don't use gateway auth here 

4083 ) 

4084 # Read MCP-Session-Id from downstream client (MCP protocol header) 

4085 # and normalize to x-mcp-session-id for our internal session affinity logic 

4086 # The pool will strip this before sending to upstream 

4087 # Check both mcp-session-id (direct client) and x-mcp-session-id (forwarded requests) 

4088 request_headers_lower = {k.lower(): v for k, v in request_headers.items()} 

4089 mcp_session_id = request_headers_lower.get("mcp-session-id") or request_headers_lower.get("x-mcp-session-id") 

4090 if mcp_session_id: 

4091 headers["x-mcp-session-id"] = mcp_session_id 

4092 

4093 worker_id = str(os.getpid()) 

4094 session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id 

4095 logger.debug(f"[AFFINITY] Worker {worker_id} | Session {session_short}... | Tool: {name} | Normalized MCP-Session-Id → x-mcp-session-id for pool affinity") 

4096 

4097 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_PRE_INVOKE) and not skip_pre_invoke: 

4098 # Use pre-created Pydantic model from Phase 2 (no ORM access) 

4099 if tool_metadata: 

4100 global_context.metadata[TOOL_METADATA] = tool_metadata 

4101 pre_result, context_table = await self._plugin_manager.invoke_hook( 

4102 ToolHookType.TOOL_PRE_INVOKE, 

4103 payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)), 

4104 global_context=global_context, 

4105 local_contexts=context_table, # Pass context from previous hooks 

4106 violations_as_exceptions=True, 

4107 ) 

4108 if pre_result.modified_payload: 

4109 payload = pre_result.modified_payload 

4110 name = payload.name 

4111 arguments = payload.args 

4112 if payload.headers is not None: 

4113 headers = payload.headers.model_dump() 

4114 

4115 # Build the payload based on integration type 

4116 payload = arguments.copy() 

4117 

4118 # Handle URL path parameter substitution (using local variable) 

4119 final_url = tool_url 

4120 if "{" in tool_url and "}" in tool_url: 

4121 # Extract path parameters from URL template and arguments 

4122 url_params = re.findall(r"\{(\w+)\}", tool_url) 

4123 url_substitutions = {} 

4124 

4125 for param in url_params: 

4126 if param in payload: 

4127 url_substitutions[param] = payload.pop(param) # Remove from payload 

4128 final_url = final_url.replace(f"{{{param}}}", str(url_substitutions[param])) 

4129 else: 

4130 raise ToolInvocationError(f"Required URL parameter '{param}' not found in arguments") 

4131 

4132 # Use the tool's request_type rather than defaulting to POST (using local variable) 

4133 method = tool_request_type.upper() if tool_request_type else "POST" 

4134 with create_child_span("tool.gateway_call", {"tool.name": name, "tool.id": tool_id, "tool.integration_type": "REST"}): 

4135 rest_start_time = time.time() 

4136 try: 

4137 if method == "GET": 

4138 # For GET: Extract query params from URL and merge with input arguments 

4139 # URL query params take precedence over input arguments on conflicts 

4140 # to ensure stable behavior when tools have default params in their URL. 

4141 parsed = urlparse(final_url) 

4142 final_url = f"{parsed.scheme}://{parsed.netloc}{parsed.path}" 

4143 query_params = {k: v[0] for k, v in parse_qs(parsed.query).items()} 

4144 

4145 conflicts = set(payload.keys()) & set(query_params.keys()) 

4146 if conflicts: 

4147 logger.warning( 

4148 f"REST tool GET request has conflicting parameters between URL and input arguments. " 

4149 f"URL query params will take precedence for: {', '.join(sorted(conflicts))}. " 

4150 f"Tool: {name}" 

4151 ) 

4152 

4153 payload.update(query_params) 

4154 response = await asyncio.wait_for(self._http_client.get(final_url, params=payload, headers=headers), timeout=effective_timeout) 

4155 else: 

4156 # For POST/PUT/PATCH/DELETE: Preserve query params in URL, only send input args in body. 

4157 # This is critical for signed URLs (Azure SAS, AWS presigned URLs, webhook signatures). 

4158 response = await asyncio.wait_for(self._http_client.request(method, final_url, json=payload, headers=headers), timeout=effective_timeout) 

4159 except (asyncio.TimeoutError, httpx.TimeoutException): 

4160 rest_elapsed_ms = (time.time() - rest_start_time) * 1000 

4161 structured_logger.log( 

4162 level="WARNING", 

4163 message=f"REST tool invocation timed out: {tool_name_computed}", 

4164 component="tool_service", 

4165 correlation_id=get_correlation_id(), 

4166 duration_ms=rest_elapsed_ms, 

4167 metadata={"event": "tool_timeout", "tool_name": tool_name_computed, "timeout_seconds": effective_timeout}, 

4168 ) 

4169 

4170 # Manually trigger circuit breaker (or other plugins) on timeout 

4171 try: 

4172 # First-Party 

4173 from mcpgateway.services.metrics import tool_timeout_counter # pylint: disable=import-outside-toplevel 

4174 

4175 tool_timeout_counter.labels(tool_name=name).inc() 

4176 except Exception as exc: 

4177 logger.debug( 

4178 "Failed to increment tool_timeout_counter for %s: %s", 

4179 name, 

4180 exc, 

4181 exc_info=True, 

4182 ) 

4183 if self._plugin_manager: 

4184 await self._run_timeout_post_invoke(name, effective_timeout, global_context, context_table) 

4185 

4186 raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s") 

4187 try: 

4188 response.raise_for_status() 

4189 except httpx.HTTPStatusError: 

4190 # Non-2xx response — parse body (may be HTML, plain text, XML, etc.) 

4191 try: 

4192 result = response.json() 

4193 except (json.JSONDecodeError, orjson.JSONDecodeError, UnicodeDecodeError, AttributeError) as e: 

4194 result = _handle_json_parse_error(response, e, is_error_response=True) 

4195 if "error" in result: 

4196 error_val = result["error"] 

4197 elif "response_text" in result: 

4198 error_val = f"HTTP {response.status_code}: {result['response_text']}" 

4199 else: 

4200 error_val = f"HTTP {response.status_code}" 

4201 tool_result = ToolResult( 

4202 content=[TextContent(type="text", text=error_val if isinstance(error_val, str) else orjson.dumps(error_val).decode())], 

4203 is_error=True, 

4204 structured_content={"status_code": response.status_code}, 

4205 ) 

4206 # Don't mark as successful — success remains False 

4207 

4208 # Handle 204 No Content responses that have no body 

4209 if tool_result is not None and tool_result.is_error: 

4210 pass # Already handled by HTTPStatusError above 

4211 elif response.status_code == 204: 

4212 tool_result = ToolResult(content=[TextContent(type="text", text="Request completed successfully (No Content)")]) 

4213 success = True 

4214 elif response.status_code not in [200, 201, 202, 206]: 

4215 # Non-standard 2xx codes (203, 205, 207, etc.) treated as errors 

4216 try: 

4217 result = response.json() 

4218 except (json.JSONDecodeError, orjson.JSONDecodeError, UnicodeDecodeError, AttributeError) as e: 

4219 result = _handle_json_parse_error(response, e, is_error_response=True) 

4220 error_val = result["error"] if "error" in result else "Tool error encountered" 

4221 tool_result = ToolResult( 

4222 content=[TextContent(type="text", text=error_val if isinstance(error_val, str) else orjson.dumps(error_val).decode())], 

4223 is_error=True, 

4224 ) 

4225 # Don't mark as successful for error responses - success remains False 

4226 else: 

4227 try: 

4228 result = response.json() 

4229 except (json.JSONDecodeError, orjson.JSONDecodeError, UnicodeDecodeError, AttributeError) as e: 

4230 result = _handle_json_parse_error(response, e, is_error_response=False) 

4231 logger.debug(f"REST API tool response: {result}") 

4232 filtered_response = extract_using_jq(result, tool_jsonpath_filter) 

4233 # Check if extract_using_jq returned an error (list of TextContent objects) 

4234 if isinstance(filtered_response, list) and len(filtered_response) > 0 and isinstance(filtered_response[0], TextContent): 

4235 # Error case - use the TextContent directly 

4236 tool_result = ToolResult(content=filtered_response, is_error=True) 

4237 success = False 

4238 else: 

4239 # Success case - serialize the filtered response 

4240 serialized = orjson.dumps(filtered_response, option=orjson.OPT_INDENT_2) 

4241 tool_result = ToolResult(content=[TextContent(type="text", text=serialized.decode())]) 

4242 success = True 

4243 # If output schema is present, validate and attach structured content 

4244 if tool_output_schema: 

4245 valid = self._extract_and_validate_structured_content(tool_for_validation, tool_result, candidate=filtered_response) 

4246 success = bool(valid) 

4247 elif tool_integration_type == "MCP": 

4248 transport = tool_request_type.lower() if tool_request_type else "sse" 

4249 

4250 # Handle OAuth authentication for the gateway (using local variables) 

4251 # NOTE: Use has_gateway instead of gateway to avoid accessing detached ORM object 

4252 if has_gateway and gateway_auth_type == "oauth" and isinstance(gateway_oauth_config, dict) and gateway_oauth_config: 

4253 grant_type = gateway_oauth_config.get("grant_type", "client_credentials") 

4254 

4255 if grant_type == "authorization_code": 

4256 # For Authorization Code flow, try to get stored tokens 

4257 # NOTE: Use fresh_db_session() since the original db was closed 

4258 try: 

4259 # First-Party 

4260 from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel 

4261 

4262 with fresh_db_session() as token_db: 

4263 token_storage = TokenStorageService(token_db) 

4264 

4265 # Get user-specific OAuth token 

4266 if not app_user_email: 

4267 raise ToolInvocationError(f"User authentication required for OAuth-protected gateway '{gateway_name}'. Please ensure you are authenticated.") 

4268 

4269 access_token = await token_storage.get_user_token(gateway_id_str, app_user_email) 

4270 

4271 if access_token: 

4272 headers = {"Authorization": f"Bearer {access_token}"} 

4273 else: 

4274 # User hasn't authorized this gateway yet 

4275 raise ToolInvocationError(f"Please authorize {gateway_name} first. Visit /oauth/authorize/{gateway_id_str} to complete OAuth flow.") 

4276 except Exception as e: 

4277 logger.error(f"Failed to obtain stored OAuth token for gateway {gateway_name}: {e}") 

4278 raise ToolInvocationError(f"OAuth token retrieval failed for gateway: {str(e)}") 

4279 else: 

4280 # For Client Credentials flow, get token directly (no DB needed) 

4281 try: 

4282 access_token = await self.oauth_manager.get_access_token(gateway_oauth_config) 

4283 headers = {"Authorization": f"Bearer {access_token}"} 

4284 except Exception as e: 

4285 logger.error(f"Failed to obtain OAuth access token for gateway {gateway_name}: {e}") 

4286 raise ToolInvocationError(f"OAuth authentication failed for gateway: {str(e)}") 

4287 else: 

4288 headers = decode_auth(gateway_auth_value) if gateway_auth_value else {} 

4289 

4290 # Use cached passthrough headers (no DB query needed) 

4291 if request_headers: 

4292 headers = compute_passthrough_headers_cached( 

4293 request_headers, headers, passthrough_allowed, gateway_auth_type=gateway_auth_type, gateway_passthrough_headers=gateway_passthrough 

4294 ) 

4295 # Read MCP-Session-Id from downstream client (MCP protocol header) 

4296 # and normalize to x-mcp-session-id for our internal session affinity logic 

4297 # The pool will strip this before sending to upstream 

4298 # Check both mcp-session-id (direct client) and x-mcp-session-id (forwarded requests) 

4299 request_headers_lower = {k.lower(): v for k, v in request_headers.items()} 

4300 mcp_session_id = request_headers_lower.get("mcp-session-id") or request_headers_lower.get("x-mcp-session-id") 

4301 if mcp_session_id: 

4302 headers["x-mcp-session-id"] = mcp_session_id 

4303 

4304 worker_id = str(os.getpid()) 

4305 session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id 

4306 logger.debug(f"[AFFINITY] Worker {worker_id} | Session {session_short}... | Tool: {name} | Normalized MCP-Session-Id → x-mcp-session-id for pool affinity (MCP transport)") 

4307 

4308 # mTLS client cert/key: resolve from payload, then override with runtime gateway if available 

4309 client_cert_from_payload = gateway_payload.get("client_cert") if has_gateway else None 

4310 client_key_from_payload = gateway_payload.get("client_key") if has_gateway else None 

4311 

4312 # Resolve client cert/key: start from payload values; runtime gateway values, when present, override them

4313 gateway_client_cert = client_cert_from_payload 

4314 gateway_client_key = client_key_from_payload 

4315 if has_gateway and gateway is not None: 

4316 runtime_gateway_client_cert = getattr(gateway, "client_cert", None) 

4317 runtime_gateway_client_key = getattr(gateway, "client_key", None) 

4318 if runtime_gateway_client_cert: 

4319 gateway_client_cert = runtime_gateway_client_cert 

4320 if runtime_gateway_client_key: 

4321 gateway_client_key = runtime_gateway_client_key 

4322 

4323 # Decrypt client_key if stored encrypted 

4324 if gateway_client_key: 

4325 try: 

4326 # First-Party 

4327 from mcpgateway.services.encryption_service import get_encryption_service # pylint: disable=import-outside-toplevel 

4328 

4329 _enc = get_encryption_service(settings.auth_encryption_secret) 

4330 gateway_client_key = _enc.decrypt_secret_or_plaintext(gateway_client_key) 

4331 except Exception as _dec_exc: 

4332 logger.debug("client_key decryption skipped, using as-is: %s", _dec_exc) 

4333 

def create_ssl_context(
    ca_certificate: str,
    client_cert: str | None = None,
    client_key: str | None = None,
) -> ssl.SSLContext:
    """Build (or reuse) an SSL context for the gateway CA and optional mTLS pair.

    The work is delegated to ``get_cached_ssl_context`` so that repeated calls
    with the same certificate material reuse a context instead of rebuilding it.

    Args:
        ca_certificate: CA certificate in PEM format
        client_cert: Optional client certificate (path or PEM) for mTLS
        client_key: Optional client private key (path or PEM) for mTLS

    Returns:
        ssl.SSLContext: SSL context configured for the given certificate material
    """
    mtls_material = {"client_cert": client_cert, "client_key": client_key}
    return get_cached_ssl_context(ca_certificate, **mtls_material)

4352 

4353 # Capture mTLS client cert/key values for passing to nested function 

4354 _client_cert_value = gateway_client_cert 

4355 _client_key_value = gateway_client_key 

4356 

def get_httpx_client_factory(
    headers: dict[str, str] | None = None,
    timeout: httpx.Timeout | None = None,
    auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
    """Factory function to create httpx.AsyncClient with optional CA certificate.

    TLS verification is selected as follows:

    * ``http://`` gateway URLs: no custom SSL context and no signature check.
    * Gateway CA certificate present and trusted (its Ed25519 signature verifies
      when ``settings.enable_ed25519_signing`` is on, or signing is off): an SSL
      context built from that CA plus the optional mTLS client cert/key.
    * Otherwise: ``get_default_verify()``.

    A CA certificate whose signature does not verify is not used and does not
    raise here; a warning is logged and default verification applies instead.
    Errors raised by ``validate_signature`` or SSL context creation, if any,
    are not caught by this function.

    Args:
        headers: Optional headers for the client
        timeout: Optional timeout for the client; when falsy, a timeout whose read
            timeout equals the tool's effective timeout is used
        auth: Optional auth for the client

    Returns:
        httpx.AsyncClient: Configured HTTPX async client
    """
    # First-Party
    from mcpgateway.services.http_client_service import get_default_verify, get_http_timeout  # pylint: disable=import-outside-toplevel

    # Plain-HTTP gateways never negotiate TLS, so skip signature checks and SSL setup.
    is_plain_http = bool(gateway_url) and gateway_url.lower().startswith("http://")

    # Only local variables captured from the enclosing scope are used here (no ORM access).
    ctx = None
    if gateway_ca_cert and not is_plain_http:
        if settings.enable_ed25519_signing:
            ca_trusted = validate_signature(gateway_ca_cert.encode(), gateway_ca_cert_sig, settings.ed25519_public_key)
            if not ca_trusted:
                # Do not use a CA whose signature fails to verify, but make the
                # fallback visible instead of silently switching trust stores.
                logger.warning(
                    "CA certificate signature validation failed for gateway '%s'; ignoring the custom CA and using default TLS verification",
                    gateway_name,
                )
        else:
            ca_trusted = True
        if ca_trusted:
            ctx = create_ssl_context(
                gateway_ca_cert,
                client_cert=_client_cert_value,
                client_key=_client_key_value,
            )

    # Use effective_timeout for read operations if not explicitly overridden by caller.
    # This ensures the underlying client waits at least as long as the tool configuration requires.
    factory_timeout = timeout if timeout else get_http_timeout(read_timeout=effective_timeout)

    return httpx.AsyncClient(
        verify=ctx if ctx is not None else get_default_verify(),
        follow_redirects=True,
        headers=headers,
        timeout=factory_timeout,
        auth=auth,
        limits=httpx.Limits(
            max_connections=settings.httpx_max_connections,
            max_keepalive_connections=settings.httpx_max_keepalive_connections,
            keepalive_expiry=settings.httpx_keepalive_expiry,
        ),
    )

4417 

4418 async def connect_to_sse_server(server_url: str, headers: dict = headers): 

4419 """Connect to an MCP server running with SSE transport. 

4420 

4421 Args: 

4422 server_url: MCP Server SSE URL 

4423 headers: HTTP headers to include in the request 

4424 

4425 Returns: 

4426 ToolResult: Result of tool call 

4427 

4428 Raises: 

4429 ToolInvocationError: If the tool invocation fails during execution. 

4430 ToolTimeoutError: If the tool invocation times out. 

4431 BaseException: On connection or communication errors 

4432 

4433 """ 

4434 # Get correlation ID for distributed tracing 

4435 correlation_id = get_correlation_id() 

4436 

4437 # NOTE: X-Correlation-ID is NOT added to headers for pooled sessions. 

4438 # MCP SDK pins headers at transport creation, so adding per-request headers 

4439 # would cause the first request's correlation ID to be reused for all 

4440 # subsequent requests on the same pooled session. Correlation IDs are 

4441 # still logged locally for tracing within the gateway. 

4442 

4443 # Log MCP call start (using local variables) 

4444 # Sanitize server_url to redact sensitive query params from logs 

4445 server_url_sanitized = sanitize_url_for_logging(server_url, gateway_auth_query_params_decrypted) 

4446 mcp_start_time = time.time() 

4447 structured_logger.log( 

4448 level="INFO", 

4449 message=f"MCP tool call started: {tool_name_original}", 

4450 component="tool_service", 

4451 correlation_id=correlation_id, 

4452 metadata={"event": "mcp_call_started", "tool_name": tool_name_original, "tool_id": tool_id, "server_url": server_url_sanitized, "transport": "sse"}, 

4453 ) 

4454 

4455 try: 

4456 # Use session pool if enabled for 10-20x latency improvement 

4457 tool_call_result = None 

4458 use_pool = False 

4459 pool = None 

4460 if settings.mcp_session_pool_enabled: 

4461 try: 

4462 pool = get_mcp_session_pool() 

4463 use_pool = True 

4464 except RuntimeError: 

4465 # Pool not initialized (e.g., in tests), fall back to per-call sessions 

4466 pass 

4467 

4468 if use_pool and pool is not None: 

4469 # Pooled path: do NOT inject per-request trace headers 

4470 # The pool reuses transports with pinned headers, so injecting 

4471 # traceparent/X-Correlation-ID would cause the first request's 

4472 # trace context to be replayed on later unrelated requests, 

4473 # corrupting distributed traces and leaking correlation IDs. 

4474 # Trade-off: pooled sessions gain 10-20x latency improvement 

4475 # but lose distributed trace propagation to upstream servers. 

4476 async with pool.session( 

4477 url=server_url, 

4478 headers=headers, 

4479 transport_type=TransportType.SSE, 

4480 httpx_client_factory=get_httpx_client_factory, 

4481 user_identity=app_user_email, 

4482 gateway_id=gateway_id_str, 

4483 ) as pooled: 

4484 with anyio.fail_after(effective_timeout): 

4485 tool_call_result = await pooled.session.call_tool(tool_name_original, arguments, meta=meta_data) 

4486 else: 

4487 # Fallback to per-call sessions when pool disabled or not initialized 

4488 with create_span( 

4489 "mcp.client.call", 

4490 { 

4491 "mcp.tool.name": tool_name_original, 

4492 "contextforge.tool.id": tool_id, 

4493 "contextforge.gateway_id": tool_gateway_id, 

4494 "contextforge.runtime": "python", 

4495 "contextforge.transport": "sse", 

4496 "network.protocol.name": "mcp", 

4497 "server.address": urlparse(server_url).hostname, 

4498 "server.port": urlparse(server_url).port, 

4499 "url.path": urlparse(server_url).path or "/", 

4500 "url.full": server_url_sanitized, 

4501 }, 

4502 ): 

4503 # Non-pooled path: safe to add per-request headers. 

4504 # Inject within the active client span so an upstream service 

4505 # can attach beneath this span when it extracts traceparent. 

4506 request_headers = inject_trace_context_headers(headers) 

4507 if correlation_id and request_headers: 

4508 request_headers["X-Correlation-ID"] = correlation_id 

4509 async with sse_client(url=server_url, headers=request_headers, httpx_client_factory=get_httpx_client_factory) as streams: 

4510 async with ClientSession(*streams) as session: 

4511 with create_span("mcp.client.initialize", {"contextforge.transport": "sse", "contextforge.runtime": "python"}): 

4512 await session.initialize() 

4513 with create_span( 

4514 "mcp.client.request", 

4515 { 

4516 "mcp.tool.name": tool_name_original, 

4517 "contextforge.tool.id": tool_id, 

4518 "contextforge.gateway_id": tool_gateway_id, 

4519 "contextforge.runtime": "python", 

4520 }, 

4521 ): 

4522 with anyio.fail_after(effective_timeout): 

4523 tool_call_result = await session.call_tool(tool_name_original, arguments, meta=meta_data) 

4524 with create_span( 

4525 "mcp.client.response", 

4526 { 

4527 "mcp.tool.name": tool_name_original, 

4528 "contextforge.tool.id": tool_id, 

4529 "contextforge.gateway_id": tool_gateway_id, 

4530 "contextforge.runtime": "python", 

4531 "upstream.response.success": not getattr(tool_call_result, "is_error", False) and not getattr(tool_call_result, "isError", False), 

4532 }, 

4533 ): 

4534 pass 

4535 

4536 # Log successful MCP call 

4537 mcp_duration_ms = (time.time() - mcp_start_time) * 1000 

4538 structured_logger.log( 

4539 level="INFO", 

4540 message=f"MCP tool call completed: {tool_name_original}", 

4541 component="tool_service", 

4542 correlation_id=correlation_id, 

4543 duration_ms=mcp_duration_ms, 

4544 metadata={"event": "mcp_call_completed", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "sse", "success": True}, 

4545 ) 

4546 

4547 return tool_call_result 

4548 except (asyncio.TimeoutError, httpx.TimeoutException): 

4549 # Handle timeout specifically - log and raise ToolInvocationError 

4550 mcp_duration_ms = (time.time() - mcp_start_time) * 1000 

4551 structured_logger.log( 

4552 level="WARNING", 

4553 message=f"MCP SSE tool invocation timed out: {tool_name_original}", 

4554 component="tool_service", 

4555 correlation_id=correlation_id, 

4556 duration_ms=mcp_duration_ms, 

4557 metadata={"event": "tool_timeout", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "sse", "timeout_seconds": effective_timeout}, 

4558 ) 

4559 

4560 # Manually trigger circuit breaker (or other plugins) on timeout 

4561 try: 

4562 # First-Party 

4563 from mcpgateway.services.metrics import tool_timeout_counter # pylint: disable=import-outside-toplevel 

4564 

4565 tool_timeout_counter.labels(tool_name=name).inc() 

4566 except Exception as exc: 

4567 logger.debug( 

4568 "Failed to increment tool_timeout_counter for %s: %s", 

4569 name, 

4570 exc, 

4571 exc_info=True, 

4572 ) 

4573 

4574 if self._plugin_manager: 

4575 await self._run_timeout_post_invoke(name, effective_timeout, global_context, context_table) 

4576 

4577 raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s") 

4578 except BaseException as e: 

4579 # Extract root cause from ExceptionGroup (Python 3.11+) 

4580 # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup 

4581 root_cause = e 

4582 if isinstance(e, BaseExceptionGroup): 

4583 while isinstance(root_cause, BaseExceptionGroup) and root_cause.exceptions: 

4584 root_cause = root_cause.exceptions[0] 

4585 # Log failed MCP call (using local variables) 

4586 mcp_duration_ms = (time.time() - mcp_start_time) * 1000 

4587 # Sanitize error message to prevent URL secrets from leaking in logs 

4588 sanitized_error = sanitize_exception_message(str(root_cause), gateway_auth_query_params_decrypted) 

4589 structured_logger.log( 

4590 level="ERROR", 

4591 message=f"MCP tool call failed: {tool_name_original}", 

4592 component="tool_service", 

4593 correlation_id=correlation_id, 

4594 duration_ms=mcp_duration_ms, 

4595 error_details={"error_type": type(root_cause).__name__, "error_message": sanitized_error}, 

4596 metadata={"event": "mcp_call_failed", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "sse"}, 

4597 ) 

4598 raise 

4599 

4600 async def connect_to_streamablehttp_server(server_url: str, headers: dict = headers): 

4601 """Connect to an MCP server running with Streamable HTTP transport. 

4602 

4603 Args: 

4604 server_url: MCP Server URL 

4605 headers: HTTP headers to include in the request 

4606 

4607 Returns: 

4608 ToolResult: Result of tool call 

4609 

4610 Raises: 

4611 ToolInvocationError: If the tool invocation fails during execution. 

4612 ToolTimeoutError: If the tool invocation times out. 

4613 BaseException: On connection or communication errors 

4614 """ 

4615 # Get correlation ID for distributed tracing 

4616 correlation_id = get_correlation_id() 

4617 

4618 # NOTE: X-Correlation-ID is NOT added to headers for pooled sessions. 

4619 # MCP SDK pins headers at transport creation, so adding per-request headers 

4620 # would cause the first request's correlation ID to be reused for all 

4621 # subsequent requests on the same pooled session. Correlation IDs are 

4622 # still logged locally for tracing within the gateway. 

4623 

4624 # Log MCP call start (using local variables) 

4625 # Sanitize server_url to redact sensitive query params from logs 

4626 server_url_sanitized = sanitize_url_for_logging(server_url, gateway_auth_query_params_decrypted) 

4627 mcp_start_time = time.time() 

4628 structured_logger.log( 

4629 level="INFO", 

4630 message=f"MCP tool call started: {tool_name_original}", 

4631 component="tool_service", 

4632 correlation_id=correlation_id, 

4633 metadata={"event": "mcp_call_started", "tool_name": tool_name_original, "tool_id": tool_id, "server_url": server_url_sanitized, "transport": "streamablehttp"}, 

4634 ) 

4635 

4636 try: 

4637 # Use session pool if enabled for 10-20x latency improvement 

4638 tool_call_result = None 

4639 use_pool = False 

4640 pool = None 

4641 if settings.mcp_session_pool_enabled: 

4642 try: 

4643 pool = get_mcp_session_pool() 

4644 use_pool = True 

4645 except RuntimeError: 

4646 # Pool not initialized (e.g., in tests), fall back to per-call sessions 

4647 pass 

4648 

4649 if use_pool and pool is not None: 

4650 # Pooled path: do NOT inject per-request trace headers 

4651 # The pool reuses transports with pinned headers, so injecting 

4652 # traceparent/X-Correlation-ID would cause the first request's 

4653 # trace context to be replayed on later unrelated requests, 

4654 # corrupting distributed traces and leaking correlation IDs. 

4655 # Trade-off: pooled sessions gain 10-20x latency improvement 

4656 # but lose distributed trace propagation to upstream servers. 

4657 # Determine transport type based on current transport setting 

4658 pool_transport_type = TransportType.SSE if transport == "sse" else TransportType.STREAMABLE_HTTP 

4659 async with pool.session( 

4660 url=server_url, 

4661 headers=headers, 

4662 transport_type=pool_transport_type, 

4663 httpx_client_factory=get_httpx_client_factory, 

4664 user_identity=app_user_email, 

4665 gateway_id=gateway_id_str, 

4666 ) as pooled: 

4667 with anyio.fail_after(effective_timeout): 

4668 tool_call_result = await pooled.session.call_tool(tool_name_original, arguments, meta=meta_data) 

4669 else: 

4670 # Fallback to per-call sessions when pool disabled or not initialized 

4671 with create_span( 

4672 "mcp.client.call", 

4673 { 

4674 "mcp.tool.name": tool_name_original, 

4675 "contextforge.tool.id": tool_id, 

4676 "contextforge.gateway_id": tool_gateway_id, 

4677 "contextforge.runtime": "python", 

4678 "contextforge.transport": "streamablehttp", 

4679 "network.protocol.name": "mcp", 

4680 "server.address": urlparse(server_url).hostname, 

4681 "server.port": urlparse(server_url).port, 

4682 "url.path": urlparse(server_url).path or "/", 

4683 "url.full": server_url_sanitized, 

4684 }, 

4685 ): 

4686 # Non-pooled path: safe to add per-request headers. 

4687 # Inject within the active client span so an upstream service 

4688 # can attach beneath this span when it extracts traceparent. 

4689 request_headers = inject_trace_context_headers(headers) 

4690 if correlation_id and request_headers: 

4691 request_headers["X-Correlation-ID"] = correlation_id 

4692 async with streamablehttp_client(url=server_url, headers=request_headers, httpx_client_factory=get_httpx_client_factory) as ( 

4693 read_stream, 

4694 write_stream, 

4695 _get_session_id, 

4696 ): 

4697 async with ClientSession(read_stream, write_stream) as session: 

4698 with create_span("mcp.client.initialize", {"contextforge.transport": "streamablehttp", "contextforge.runtime": "python"}): 

4699 await session.initialize() 

4700 with create_span( 

4701 "mcp.client.request", 

4702 { 

4703 "mcp.tool.name": tool_name_original, 

4704 "contextforge.tool.id": tool_id, 

4705 "contextforge.gateway_id": tool_gateway_id, 

4706 "contextforge.runtime": "python", 

4707 }, 

4708 ): 

4709 with anyio.fail_after(effective_timeout): 

4710 tool_call_result = await session.call_tool(tool_name_original, arguments, meta=meta_data) 

4711 with create_span( 

4712 "mcp.client.response", 

4713 { 

4714 "mcp.tool.name": tool_name_original, 

4715 "contextforge.tool.id": tool_id, 

4716 "contextforge.gateway_id": tool_gateway_id, 

4717 "contextforge.runtime": "python", 

4718 "upstream.response.success": not getattr(tool_call_result, "is_error", False) and not getattr(tool_call_result, "isError", False), 

4719 }, 

4720 ): 

4721 pass 

4722 

4723 # Log successful MCP call 

4724 mcp_duration_ms = (time.time() - mcp_start_time) * 1000 

4725 structured_logger.log( 

4726 level="INFO", 

4727 message=f"MCP tool call completed: {tool_name_original}", 

4728 component="tool_service", 

4729 correlation_id=correlation_id, 

4730 duration_ms=mcp_duration_ms, 

4731 metadata={"event": "mcp_call_completed", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "streamablehttp", "success": True}, 

4732 ) 

4733 

4734 return tool_call_result 

4735 except (asyncio.TimeoutError, httpx.TimeoutException): 

4736 # Handle timeout specifically - log and raise ToolInvocationError 

4737 mcp_duration_ms = (time.time() - mcp_start_time) * 1000 

4738 structured_logger.log( 

4739 level="WARNING", 

4740 message=f"MCP StreamableHTTP tool invocation timed out: {tool_name_original}", 

4741 component="tool_service", 

4742 correlation_id=correlation_id, 

4743 duration_ms=mcp_duration_ms, 

4744 metadata={"event": "tool_timeout", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "streamablehttp", "timeout_seconds": effective_timeout}, 

4745 ) 

4746 

4747 # Manually trigger circuit breaker (or other plugins) on timeout 

4748 try: 

4749 # First-Party 

4750 from mcpgateway.services.metrics import tool_timeout_counter # pylint: disable=import-outside-toplevel 

4751 

4752 tool_timeout_counter.labels(tool_name=name).inc() 

4753 except Exception as exc: 

4754 logger.debug( 

4755 "Failed to increment tool_timeout_counter for %s: %s", 

4756 name, 

4757 exc, 

4758 exc_info=True, 

4759 ) 

4760 

4761 if self._plugin_manager: 

4762 await self._run_timeout_post_invoke(name, effective_timeout, global_context, context_table) 

4763 

4764 raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s") 

4765 except BaseException as e: 

4766 # Extract root cause from ExceptionGroup (Python 3.11+) 

4767 # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup 

4768 root_cause = e 

4769 if isinstance(e, BaseExceptionGroup): 

4770 while isinstance(root_cause, BaseExceptionGroup) and root_cause.exceptions: 

4771 root_cause = root_cause.exceptions[0] 

4772 # Log failed MCP call 

4773 mcp_duration_ms = (time.time() - mcp_start_time) * 1000 

4774 # Sanitize error message to prevent URL secrets from leaking in logs 

4775 sanitized_error = sanitize_exception_message(str(root_cause), gateway_auth_query_params_decrypted) 

4776 structured_logger.log( 

4777 level="ERROR", 

4778 message=f"MCP tool call failed: {tool_name_original}", 

4779 component="tool_service", 

4780 correlation_id=correlation_id, 

4781 duration_ms=mcp_duration_ms, 

4782 error_details={"error_type": type(root_cause).__name__, "error_message": sanitized_error}, 

4783 metadata={"event": "mcp_call_failed", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "streamablehttp"}, 

4784 ) 

4785 raise 

4786 

4787 # REMOVED: Redundant gateway query - gateway already eager-loaded via joinedload 

4788 # tool_gateway = db.execute(select(DbGateway).where(DbGateway.id == tool_gateway_id)...) 

4789 

4790 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_PRE_INVOKE) and not skip_pre_invoke: 

4791 # Use pre-created Pydantic models from Phase 2 (no ORM access) 

4792 if tool_metadata: 

4793 global_context.metadata[TOOL_METADATA] = tool_metadata 

4794 if gateway_metadata: 

4795 global_context.metadata[GATEWAY_METADATA] = gateway_metadata 

4796 pre_result, context_table = await self._plugin_manager.invoke_hook( 

4797 ToolHookType.TOOL_PRE_INVOKE, 

4798 payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)), 

4799 global_context=global_context, 

4800 local_contexts=None, 

4801 violations_as_exceptions=True, 

4802 ) 

4803 if pre_result.modified_payload: 

4804 payload = pre_result.modified_payload 

4805 name = payload.name 

4806 arguments = payload.args 

4807 if payload.headers is not None: 

4808 headers = payload.headers.model_dump() 

4809 

4810 with create_child_span("tool.gateway_call", {"tool.name": name, "tool.id": tool_id, "tool.integration_type": "MCP"}): 

4811 tool_call_result = ToolResult(content=[TextContent(text="", type="text")]) 

4812 if transport == "sse": 

4813 tool_call_result = await connect_to_sse_server(gateway_url, headers=headers) 

4814 elif transport == "streamablehttp": 

4815 tool_call_result = await connect_to_streamablehttp_server(gateway_url, headers=headers) 

4816 

4817 # In direct proxy mode, use the tool result as-is without splitting content 

4818 if is_direct_proxy: 

4819 tool_result = tool_call_result 

4820 success = not getattr(tool_call_result, "is_error", False) and not getattr(tool_call_result, "isError", False) 

4821 logger.debug(f"Direct proxy mode: using tool result as-is: {tool_result}") 

4822 else: 

4823 dump = tool_call_result.model_dump(by_alias=True, mode="json") 

4824 logger.debug(f"Tool call result dump: {dump}") 

4825 content = dump.get("content", []) 

4826 # Accept both alias and pythonic names for structured content 

4827 structured = dump.get("structuredContent") or dump.get("structured_content") 

4828 filtered_response = extract_using_jq(content, tool_jsonpath_filter) 

4829 

4830 is_err = getattr(tool_call_result, "is_error", None) 

4831 if is_err is None: 

4832 is_err = getattr(tool_call_result, "isError", False) 

4833 tool_result = ToolResult(content=filtered_response, structured_content=structured, is_error=is_err, meta=getattr(tool_call_result, "meta", None)) 

4834 success = not is_err 

4835 logger.debug(f"Final tool_result: {tool_result}") 

4836 

4837 elif tool_integration_type == "A2A" and a2a_agent_endpoint_url: 

4838 # A2A tool invocation using pre-extracted agent data (extracted in Phase 2 before db.close()) 

4839 headers = {"Content-Type": "application/json"} 

4840 

4841 # Plugin hook: tool pre-invoke for A2A 

4842 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_PRE_INVOKE) and not skip_pre_invoke: 

4843 if tool_metadata: 

4844 global_context.metadata[TOOL_METADATA] = tool_metadata 

4845 pre_result, context_table = await self._plugin_manager.invoke_hook( 

4846 ToolHookType.TOOL_PRE_INVOKE, 

4847 payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)), 

4848 global_context=global_context, 

4849 local_contexts=context_table, 

4850 violations_as_exceptions=True, 

4851 ) 

4852 if pre_result.modified_payload: 

4853 payload = pre_result.modified_payload 

4854 name = payload.name 

4855 arguments = payload.args 

4856 if payload.headers is not None: 

4857 headers = payload.headers.model_dump() 

4858 

4859 # Build request data based on agent type 

4860 endpoint_url = a2a_agent_endpoint_url 

4861 if a2a_agent_type in ["generic", "jsonrpc"] or endpoint_url.endswith("/"): 

4862 # JSONRPC agents: Convert flat query to nested message structure 

4863 params = None 

4864 if isinstance(arguments, dict) and "query" in arguments and isinstance(arguments["query"], str): 

4865 message_id = f"admin-test-{int(time.time())}" 

4866 # A2A v0.3.x: message.parts use "kind" (not "type"). 

4867 params = { 

4868 "message": { 

4869 "kind": "message", 

4870 "messageId": message_id, 

4871 "role": "user", 

4872 "parts": [{"kind": "text", "text": arguments["query"]}], 

4873 } 

4874 } 

4875 method = arguments.get("method", "message/send") 

4876 else: 

4877 params = arguments.get("params", arguments) if isinstance(arguments, dict) else arguments 

4878 method = arguments.get("method", "message/send") if isinstance(arguments, dict) else "message/send" 

4879 request_data = {"jsonrpc": "2.0", "method": method, "params": params, "id": 1} 

4880 else: 

4881 # Custom agents: Pass parameters directly 

4882 params = arguments if isinstance(arguments, dict) else {} 

4883 request_data = {"interaction_type": params.get("interaction_type", "query"), "parameters": params, "protocol_version": a2a_agent_protocol_version} 

4884 

4885 # Add authentication 

4886 if a2a_agent_auth_type in ("api_key", "basic", "bearer", "authheaders") and a2a_agent_auth_value: 

4887 # Decrypt auth_value before using it 

4888 if isinstance(a2a_agent_auth_value, str): 

4889 try: 

4890 auth_headers = decode_auth(a2a_agent_auth_value) 

4891 headers.update(auth_headers) 

4892 except Exception as e: 

4893 logger.error(f"Failed to decrypt authentication for A2A agent '{a2a_agent_name}': {e}") 

4894 raise ToolInvocationError(f"Failed to decrypt authentication for A2A agent '{a2a_agent_name}'") 

4895 elif isinstance(a2a_agent_auth_value, dict): 

4896 auth_headers = {str(k): str(v) for k, v in a2a_agent_auth_value.items()} 

4897 headers.update(auth_headers) 

4898 elif a2a_agent_auth_type == "query_param" and a2a_agent_auth_query_params: 

4899 auth_query_params_decrypted: dict[str, str] = {} 

4900 for param_key, encrypted_value in a2a_agent_auth_query_params.items(): 

4901 if encrypted_value: 

4902 try: 

4903 decrypted = decode_auth(encrypted_value) 

4904 auth_query_params_decrypted[param_key] = decrypted.get(param_key, "") 

4905 except Exception: 

4906 logger.debug(f"Failed to decrypt query param for key '{param_key}'") 

4907 if auth_query_params_decrypted: 

4908 endpoint_url = apply_query_param_auth(endpoint_url, auth_query_params_decrypted) 

4909 

4910 with create_child_span("tool.gateway_call", {"tool.name": name, "tool.id": tool_id, "tool.integration_type": "A2A"}): 

4911 # Make HTTP request with timeout enforcement 

4912 logger.info(f"Calling A2A agent '{a2a_agent_name}' at {endpoint_url}") 

4913 a2a_start_time = time.time() 

4914 try: 

4915 http_response = await asyncio.wait_for(self._http_client.post(endpoint_url, json=request_data, headers=headers), timeout=effective_timeout) 

4916 except (asyncio.TimeoutError, httpx.TimeoutException): 

4917 a2a_elapsed_ms = (time.time() - a2a_start_time) * 1000 

4918 structured_logger.log( 

4919 level="WARNING", 

4920 message=f"A2A tool invocation timed out: {name}", 

4921 component="tool_service", 

4922 correlation_id=get_correlation_id(), 

4923 duration_ms=a2a_elapsed_ms, 

4924 metadata={"event": "tool_timeout", "tool_name": name, "a2a_agent": a2a_agent_name, "timeout_seconds": effective_timeout}, 

4925 ) 

4926 

4927 # Increment timeout counter 

4928 try: 

4929 # First-Party 

4930 from mcpgateway.services.metrics import tool_timeout_counter # pylint: disable=import-outside-toplevel 

4931 

4932 tool_timeout_counter.labels(tool_name=name).inc() 

4933 except Exception as exc: 

4934 logger.debug("Failed to increment tool_timeout_counter for %s: %s", name, exc, exc_info=True) 

4935 

4936 # Trigger circuit breaker on timeout 

4937 if self._plugin_manager: 

4938 await self._run_timeout_post_invoke(name, effective_timeout, global_context, context_table) 

4939 

4940 raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s") 

4941 

4942 if http_response.status_code == 200: 

4943 response_data = http_response.json() 

4944 if isinstance(response_data, dict) and "response" in response_data: 

4945 val = response_data["response"] 

4946 content = [TextContent(type="text", text=val if isinstance(val, str) else orjson.dumps(val).decode())] 

4947 else: 

4948 content = [TextContent(type="text", text=response_data if isinstance(response_data, str) else orjson.dumps(response_data).decode())] 

4949 tool_result = ToolResult(content=content, is_error=False) 

4950 success = True 

4951 else: 

4952 error_message = f"HTTP {http_response.status_code}: {http_response.text}" 

4953 content = [TextContent(type="text", text=f"A2A agent error: {error_message}")] 

4954 tool_result = ToolResult(content=content, is_error=True) 

4955 else: 

4956 tool_result = ToolResult(content=[TextContent(type="text", text="Invalid tool type")], is_error=True) 

4957 

4958 with create_child_span("tool.post_process", {"tool.name": name, "tool.id": tool_id}): 

4959 post_result = None 

4960 # Plugin hook: tool post-invoke 

4961 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE): 

4962 post_result, _ = await self._plugin_manager.invoke_hook( 

4963 ToolHookType.TOOL_POST_INVOKE, 

4964 payload=ToolPostInvokePayload(name=name, result=tool_result.model_dump(by_alias=True)), 

4965 global_context=global_context, 

4966 local_contexts=context_table, 

4967 violations_as_exceptions=True, 

4968 ) 

4969 # Use modified payload if provided 

4970 if post_result.modified_payload: 

4971 # Reconstruct ToolResult from modified result 

4972 modified_result = post_result.modified_payload.result 

4973 if isinstance(modified_result, dict) and "content" in modified_result: 

4974 # Safely obtain structured content using .get() to avoid KeyError when 

4975 # plugins provide only the content without structured content fields. 

4976 structured = modified_result.get("structuredContent") if "structuredContent" in modified_result else modified_result.get("structured_content") 

4977 

4978 tool_result = ToolResult(content=modified_result["content"], structured_content=structured) 

4979 else: 

4980 # If result is not in expected format, convert it to text content 

4981 try: 

4982 tool_result = ToolResult(content=[TextContent(type="text", text=modified_result if isinstance(modified_result, str) else orjson.dumps(modified_result).decode())]) 

4983 except Exception: 

4984 tool_result = ToolResult(content=[TextContent(type="text", text=str(modified_result))]) 

4985 

4986 # Retry: if the plugin requested a delayed retry and we haven't hit the gateway ceiling. 

4987 # retry_attempt is 0-based (0 = original call). The condition allows retry_attempt 

4988 # values 0..max_tool_retries-1, meaning up to max_tool_retries *retry* attempts on 

4989 # top of the original call (total attempts = max_tool_retries + 1). 

4990 if post_result is not None and post_result.retry_delay_ms > 0 and retry_attempt < settings.max_tool_retries: 

4991 return await self._retry_tool_invocation( 

4992 post_result.retry_delay_ms, 

4993 retry_attempt, 

4994 name, 

4995 arguments, 

4996 request_headers, 

4997 app_user_email, 

4998 user_email, 

4999 token_teams, 

5000 server_id, 

5001 context_table, 

5002 global_context, 

5003 meta_data, 

5004 skip_pre_invoke, 

5005 "success", 

5006 ) 

5007 

5008 return tool_result 

5009 except (PluginError, PluginViolationError): 

5010 raise 

5011 except ToolTimeoutError as e: 

5012 # ToolTimeoutError is raised by timeout handlers which already called tool_post_invoke. 

5013 # Do NOT call post_invoke again — the retry_delay_ms signal is carried on the exception. 

5014 error_message = str(e) 

5015 if span: 

5016 set_span_error(span, error_message) 

5017 

5018 # Retry if the post-invoke hook (called by the timeout handler) requested it. 

5019 if e.retry_delay_ms > 0 and retry_attempt < settings.max_tool_retries: 

5020 return await self._retry_tool_invocation( 

5021 e.retry_delay_ms, 

5022 retry_attempt, 

5023 name, 

5024 arguments, 

5025 request_headers, 

5026 app_user_email, 

5027 user_email, 

5028 token_teams, 

5029 server_id, 

5030 context_table, 

5031 global_context, 

5032 meta_data, 

5033 skip_pre_invoke, 

5034 "timeout", 

5035 ) 

5036 raise 

5037 except BaseException as e: 

5038 # Extract root cause from ExceptionGroup (Python 3.11+) 

5039 # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup 

5040 root_cause = e 

5041 if isinstance(e, BaseExceptionGroup): 

5042 while isinstance(root_cause, BaseExceptionGroup) and root_cause.exceptions: 

5043 root_cause = root_cause.exceptions[0] 

5044 error_message = str(root_cause) 

5045 # Set span error status 

5046 if span: 

5047 set_span_error(span, error_message) 

5048 

5049 # Notify plugins of the failure so circuit breaker / retry plugin can track it. 

5050 # Capture the result so we can honour a retry_delay_ms signal from the retry plugin. 

5051 # When the exception carries an HTTP status code (e.g. httpx.HTTPStatusError), 

5052 # include it in structuredContent so the retry plugin can honour retry_on_status 

5053 # instead of blindly retrying every exception. 

5054 exc_post_result = None 

5055 if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE): 

5056 try: 

5057 exc_structured: Optional[Dict[str, Any]] = None 

5058 if isinstance(root_cause, httpx.HTTPStatusError): 

5059 exc_structured = {"status_code": root_cause.response.status_code} 

5060 exception_error_result = ToolResult(content=[TextContent(type="text", text=f"Tool invocation failed: {error_message}")], is_error=True, structured_content=exc_structured) 

5061 exc_post_result, _ = await self._plugin_manager.invoke_hook( 

5062 ToolHookType.TOOL_POST_INVOKE, 

5063 payload=ToolPostInvokePayload(name=name, result=exception_error_result.model_dump(by_alias=True)), 

5064 global_context=global_context, 

5065 local_contexts=context_table, 

5066 violations_as_exceptions=False, # Don't let plugin errors mask the original exception 

5067 ) 

5068 except Exception as plugin_exc: 

5069 logger.debug("Failed to invoke post-invoke plugins on exception: %s", plugin_exc) 

5070 

5071 # Retry if the plugin requested a delayed retry and we haven't hit the ceiling. 

5072 # Same counting convention as the success path: retry_attempt is 0-based, 

5073 # so this allows up to max_tool_retries retry attempts beyond the original call. 

5074 if exc_post_result is not None and exc_post_result.retry_delay_ms > 0 and retry_attempt < settings.max_tool_retries: 

5075 return await self._retry_tool_invocation( 

5076 exc_post_result.retry_delay_ms, 

5077 retry_attempt, 

5078 name, 

5079 arguments, 

5080 request_headers, 

5081 app_user_email, 

5082 user_email, 

5083 token_teams, 

5084 server_id, 

5085 context_table, 

5086 global_context, 

5087 meta_data, 

5088 skip_pre_invoke, 

5089 "exception", 

5090 ) 

5091 

5092 raise ToolInvocationError(f"Tool invocation failed: {error_message}") 

5093 finally: 

5094 # Calculate duration 

5095 duration_ms = (time.monotonic() - start_time) * 1000 

5096 

5097 # End database span for observability_spans table 

5098 # Use commit=False since fresh_db_session() handles commits on exit 

5099 if db_span_id and observability_service and not db_span_ended: 

5100 try: 

5101 with fresh_db_session() as span_db: 

5102 observability_service.end_span( 

5103 db=span_db, 

5104 span_id=db_span_id, 

5105 status="ok" if success else "error", 

5106 status_message=error_message if error_message else None, 

5107 attributes={ 

5108 "success": success, 

5109 "duration_ms": duration_ms, 

5110 }, 

5111 commit=False, 

5112 ) 

5113 db_span_ended = True 

5114 logger.debug(f"✓ Ended tool.invoke span: {db_span_id}") 

5115 except Exception as e: 

5116 logger.warning(f"Failed to end observability span for tool invocation: {e}") 

5117 

5118 # Add final span attributes for OpenTelemetry 

5119 if span: 

5120 set_span_attribute(span, "success", success) 

5121 set_span_attribute(span, "duration.ms", duration_ms) 

5122 if success and tool_result and is_output_capture_enabled("tool.invoke"): 

5123 set_span_attribute(span, "langfuse.observation.output", serialize_trace_payload(tool_result)) 

5124 

5125 # ═══════════════════════════════════════════════════════════════════════════ 

5126 # PHASE 4: Record metrics via buffered service (batches writes for performance) 

5127 # ═══════════════════════════════════════════════════════════════════════════ 

5128 # Only record metrics if tool_id is valid (skip for direct_proxy mode) 

5129 if tool_id: 

5130 try: 

5131 metrics_buffer.record_tool_metric( 

5132 tool_id=tool_id, 

5133 start_time=start_time, 

5134 success=success, 

5135 error_message=error_message, 

5136 ) 

5137 except Exception as metric_error: 

5138 logger.warning(f"Failed to record tool metric: {metric_error}") 

5139 

5140 # Record server metrics ONLY when invoked through a specific virtual server 

5141 # When server_id is provided, it means the tool was called via a virtual server endpoint 

5142 # Direct tool calls via /rpc should NOT populate server metrics 

5143 if tool_id and server_id: 

5144 try: 

5145 # Record server metric only for the specific virtual server being accessed 

5146 metrics_buffer.record_server_metric( 

5147 server_id=server_id, 

5148 start_time=start_time, 

5149 success=success, 

5150 error_message=error_message, 

5151 ) 

5152 except Exception as metric_error: 

5153 logger.warning(f"Failed to record server metric: {metric_error}") 

5154 

5155 # Log structured message with performance tracking (using local variables) 

5156 if success: 

5157 structured_logger.info( 

5158 f"Tool '{name}' invoked successfully", 

5159 user_id=app_user_email, 

5160 resource_type="tool", 

5161 resource_id=tool_id, 

5162 resource_action="invoke", 

5163 duration_ms=duration_ms, 

5164 custom_fields={"tool_name": name, "integration_type": tool_integration_type, "arguments_count": len(arguments) if arguments else 0}, 

5165 ) 

5166 else: 

5167 structured_logger.error( 

5168 f"Tool '{name}' invocation failed", 

5169 error=Exception(error_message) if error_message else None, 

5170 user_id=app_user_email, 

5171 resource_type="tool", 

5172 resource_id=tool_id, 

5173 resource_action="invoke", 

5174 duration_ms=duration_ms, 

5175 custom_fields={"tool_name": name, "integration_type": tool_integration_type, "error_message": error_message}, 

5176 ) 

5177 

5178 # Track performance with threshold checking 

5179 with perf_tracker.track_operation("tool_invocation", name): 

5180 pass # Duration already captured above 

5181 

@staticmethod
def _check_tool_name_conflict(db: Session, custom_name: str, visibility: str, tool_id: str, team_id: Optional[str] = None, owner_email: Optional[str] = None) -> None:
    """Raise ToolNameConflictError if another tool with the same name exists in the target visibility scope.

    Uniqueness is scoped per visibility:

    - ``public`` names are unique among public tools.
    - ``team`` names are unique among that team's team-visible tools.
    - ``private`` names are unique among that owner's private tools.

    The tool being updated (*tool_id*) is always excluded. The lookup goes through ``get_for_update``
    (presumably a row-locking SELECT — confirm against its definition).

    The check is skipped with a warning, rather than raising, in two cases:

    - the scope qualifier that *visibility* needs (``team_id`` / ``owner_email``) is missing;
    - *visibility* is not one of the three known scopes.

    Args:
        db: The SQLAlchemy database session.
        custom_name: The custom name to check for conflicts.
        visibility: The target visibility scope (``public``, ``team``, or ``private``).
        tool_id: The ID of the tool being updated (excluded from the conflict search).
        team_id: Required when *visibility* is ``team``; scopes the uniqueness check to this team.
        owner_email: Required when *visibility* is ``private``; scopes the uniqueness check to this owner.

    Raises:
        ToolNameConflictError: If a conflicting tool already exists in the target scope.
    """
    if visibility == "public":
        existing_tool = get_for_update(
            db,
            DbTool,
            where=and_(
                DbTool.custom_name == custom_name,
                DbTool.visibility == "public",
                DbTool.id != tool_id,
            ),
        )
    elif visibility == "team" and team_id:
        existing_tool = get_for_update(
            db,
            DbTool,
            where=and_(
                DbTool.custom_name == custom_name,
                DbTool.visibility == "team",
                DbTool.team_id == team_id,
                DbTool.id != tool_id,
            ),
        )
    elif visibility == "private" and owner_email:
        existing_tool = get_for_update(
            db,
            DbTool,
            where=and_(
                DbTool.custom_name == custom_name,
                DbTool.visibility == "private",
                DbTool.owner_email == owner_email,
                DbTool.id != tool_id,
            ),
        )
    elif visibility in ("team", "private"):
        # Known scope, but the qualifier that defines the scope was not supplied.
        logger.warning("Skipping conflict check for tool %s: visibility=%r requires %s but none provided", tool_id, visibility, "team_id" if visibility == "team" else "owner_email")
        return
    else:
        # Unknown scope: there is no uniqueness rule to apply, so report that accurately.
        logger.warning("Skipping conflict check for tool %s: unrecognized visibility=%r", tool_id, visibility)
        return
    if existing_tool:
        raise ToolNameConflictError(existing_tool.custom_name, enabled=existing_tool.enabled, tool_id=existing_tool.id, visibility=existing_tool.visibility)

5234 

async def update_tool(
    self,
    db: Session,
    tool_id: str,
    tool_update: ToolUpdate,
    modified_by: Optional[str] = None,
    modified_from_ip: Optional[str] = None,
    modified_via: Optional[str] = None,
    modified_user_agent: Optional[str] = None,
    user_email: Optional[str] = None,
) -> ToolRead:
    """
    Update an existing tool.

    Only fields that are not ``None`` on *tool_update* are applied.

    Before any field is mutated, the tool's post-update ``custom_name`` and ``visibility`` are
    computed. If the name, the effective ``custom_name`` or the visibility changes, those
    post-update values are checked for conflicts in the target scope. This covers:

    - a rename (where ``custom_name`` may track the new name);
    - a ``custom_name``-only edit;
    - a move to another visibility.

    None of these can silently collide with an existing tool. ``visibility`` is stored
    lower-cased, the same form used for the conflict check.

    Args:
        db (Session): The SQLAlchemy database session.
        tool_id (str): The unique identifier of the tool.
        tool_update (ToolUpdate): Tool update schema with new data.
        modified_by (Optional[str]): Username who modified this tool.
        modified_from_ip (Optional[str]): IP address of modifier.
        modified_via (Optional[str]): Modification method (ui, api).
        modified_user_agent (Optional[str]): User agent of modification request.
        user_email (Optional[str]): Email of user performing update (for ownership check).

    Returns:
        The updated ToolRead object.

    Raises:
        ToolNotFoundError: If the tool is not found.
        PermissionError: If user doesn't own the tool.
        IntegrityError: If there is a database integrity error.
        ToolNameConflictError: If a tool with the same name already exists.
        ToolError: For other update errors.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from unittest.mock import MagicMock, AsyncMock
        >>> from mcpgateway.schemas import ToolRead
        >>> service = ToolService()
        >>> db = MagicMock()
        >>> tool = MagicMock()
        >>> db.get.return_value = tool
        >>> db.commit = MagicMock()
        >>> db.refresh = MagicMock()
        >>> db.execute.return_value.scalar_one_or_none.return_value = None
        >>> service._notify_tool_updated = AsyncMock()
        >>> service.convert_tool_to_read = MagicMock(return_value='tool_read')
        >>> ToolRead.model_validate = MagicMock(return_value='tool_read')
        >>> import asyncio
        >>> asyncio.run(service.update_tool(db, 'tool_id', MagicMock()))
        'tool_read'
    """
    try:
        tool = get_for_update(db, DbTool, tool_id)

        if not tool:
            raise ToolNotFoundError(f"Tool not found: {tool_id}")

        old_tool_name = tool.name
        old_gateway_id = tool.gateway_id

        # Check ownership if user_email provided
        if user_email:
            # First-Party
            from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

            permission_service = PermissionService(db)
            if not await permission_service.check_resource_ownership(user_email, tool):
                raise PermissionError("Only the owner can update this tool")

        # Work out the post-update identity BEFORE mutating anything, so the conflict
        # check sees exactly the custom_name/visibility pair that will be persisted.
        name_is_changing = bool(tool_update.name and tool_update.name != tool.name)
        new_visibility = tool.visibility if tool_update.visibility is None else tool_update.visibility.lower()
        if tool_update.custom_name is not None:
            new_custom_name = tool_update.custom_name
        elif name_is_changing and tool.name == tool.custom_name:
            new_custom_name = tool_update.name  # custom_name tracks the rename
        else:
            new_custom_name = tool.custom_name  # custom_name stays unchanged

        visibility_is_changing = new_visibility != tool.visibility
        custom_name_is_changing = new_custom_name != tool.custom_name
        if name_is_changing or visibility_is_changing or custom_name_is_changing:
            # Always derive ownership fields from the DB record — never trust client-provided team_id/owner_email
            self._check_tool_name_conflict(db, new_custom_name, new_visibility, tool.id, team_id=tool.team_id, owner_email=tool.owner_email)

        if name_is_changing:
            tool.name = tool_update.name
        if custom_name_is_changing:
            tool.custom_name = new_custom_name
        if tool_update.displayName is not None:
            tool.display_name = tool_update.displayName
        if tool_update.url is not None:
            tool.url = str(tool_update.url)
        if tool_update.description is not None:
            tool.description = tool_update.description
        if tool_update.title is not None:
            tool.title = tool_update.title
        if tool_update.integration_type is not None:
            tool.integration_type = tool_update.integration_type
        if tool_update.request_type is not None:
            tool.request_type = tool_update.request_type
        if tool_update.headers is not None:
            tool.headers = _protect_tool_headers_for_storage(tool_update.headers, existing_headers=tool.headers)
        if tool_update.input_schema is not None:
            tool.input_schema = tool_update.input_schema
        if tool_update.output_schema is not None:
            tool.output_schema = tool_update.output_schema
        if tool_update.annotations is not None:
            tool.annotations = tool_update.annotations
        if tool_update.jsonpath_filter is not None:
            tool.jsonpath_filter = tool_update.jsonpath_filter
        if tool_update.visibility is not None:
            # Store the normalised form so later scope queries (e.g. visibility == "public") match.
            tool.visibility = new_visibility

        if tool_update.auth is not None:
            if tool_update.auth.auth_type is not None:
                tool.auth_type = tool_update.auth.auth_type
            if tool_update.auth.auth_value is not None:
                tool.auth_value = tool_update.auth.auth_value

        # Update tags if provided
        if tool_update.tags is not None:
            tool.tags = tool_update.tags

        # Update modification metadata
        if modified_by is not None:
            tool.modified_by = modified_by
        if modified_from_ip is not None:
            tool.modified_from_ip = modified_from_ip
        if modified_via is not None:
            tool.modified_via = modified_via
        if modified_user_agent is not None:
            tool.modified_user_agent = modified_user_agent

        # Increment version
        if hasattr(tool, "version") and tool.version is not None:
            tool.version += 1
        else:
            tool.version = 1
        logger.info(f"Update tool: {tool.name} (output_schema: {tool.output_schema})")

        tool.updated_at = datetime.now(timezone.utc)
        db.commit()
        db.refresh(tool)
        await self._notify_tool_updated(tool)
        logger.info(f"Updated tool: {tool.name}")

        # Structured logging: Audit trail for tool update
        changes = []
        if tool_update.name:
            changes.append(f"name: {tool_update.name}")
        if tool_update.visibility:
            changes.append(f"visibility: {tool_update.visibility}")
        if tool_update.description:
            changes.append("description updated")

        audit_trail.log_action(
            user_id=user_email or modified_by or "system",
            action="update_tool",
            resource_type="tool",
            resource_id=tool.id,
            resource_name=tool.name,
            user_email=user_email,
            team_id=tool.team_id,
            client_ip=modified_from_ip,
            user_agent=modified_user_agent,
            new_values={
                "name": tool.name,
                "display_name": tool.display_name,
                "version": tool.version,
            },
            context={
                "modified_via": modified_via,
                "changes": ", ".join(changes) if changes else "metadata only",
            },
            db=db,
        )

        # Structured logging: Log successful tool update
        structured_logger.log(
            level="INFO",
            message="Tool updated successfully",
            event_type="tool_updated",
            component="tool_service",
            user_id=modified_by,
            user_email=user_email,
            team_id=tool.team_id,
            resource_type="tool",
            resource_id=tool.id,
            custom_fields={
                "tool_name": tool.name,
                "version": tool.version,
            },
        )

        # Invalidate cache after successful update (both the old and the new lookup keys)
        cache = _get_registry_cache()
        await cache.invalidate_tools()
        tool_lookup_cache = _get_tool_lookup_cache()
        await tool_lookup_cache.invalidate(old_tool_name, gateway_id=str(old_gateway_id) if old_gateway_id else None)
        await tool_lookup_cache.invalidate(tool.name, gateway_id=str(tool.gateway_id) if tool.gateway_id else None)
        # Also invalidate tags cache since tool tags may have changed
        # First-Party
        from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

        await admin_stats_cache.invalidate_tags()

        return self.convert_tool_to_read(tool, requesting_user_email=getattr(tool, "owner_email", None))
    except PermissionError as pe:
        db.rollback()

        # Structured logging: Log permission error
        structured_logger.log(
            level="WARNING",
            message="Tool update failed due to permission error",
            event_type="tool_update_permission_denied",
            component="tool_service",
            user_email=user_email,
            resource_type="tool",
            resource_id=tool_id,
            error=pe,
        )
        raise
    except IntegrityError as ie:
        db.rollback()
        logger.error(f"IntegrityError during tool update: {ie}")

        # Structured logging: Log database integrity error
        structured_logger.log(
            level="ERROR",
            message="Tool update failed due to database integrity error",
            event_type="tool_update_failed",
            component="tool_service",
            user_id=modified_by,
            user_email=user_email,
            resource_type="tool",
            resource_id=tool_id,
            error=ie,
        )
        raise
    except ToolNotFoundError as tnfe:
        db.rollback()
        logger.error(f"Tool not found during update: {tnfe}")

        # Structured logging: Log not found error
        structured_logger.log(
            level="ERROR",
            message="Tool update failed - tool not found",
            event_type="tool_not_found",
            component="tool_service",
            user_email=user_email,
            resource_type="tool",
            resource_id=tool_id,
            error=tnfe,
        )
        raise
    except ToolNameConflictError as tnce:
        db.rollback()
        logger.error(f"Tool name conflict during update: {tnce}")

        # Structured logging: Log name conflict error
        structured_logger.log(
            level="WARNING",
            message="Tool update failed due to name conflict",
            event_type="tool_name_conflict",
            component="tool_service",
            user_id=modified_by,
            user_email=user_email,
            resource_type="tool",
            resource_id=tool_id,
            error=tnce,
        )
        raise
    except Exception as ex:
        db.rollback()

        # Structured logging: Log generic tool update failure
        structured_logger.log(
            level="ERROR",
            message="Tool update failed",
            event_type="tool_update_failed",
            component="tool_service",
            user_id=modified_by,
            user_email=user_email,
            resource_type="tool",
            resource_id=tool_id,
            error=ex,
        )
        # Chain the original exception so the root cause survives in tracebacks.
        raise ToolError(f"Failed to update tool: {str(ex)}") from ex

5530 

async def _notify_tool_updated(self, tool: DbTool) -> None:
    """
    Publish a ``tool_updated`` event carrying the tool's current core fields.

    Args:
        tool: The tool record that was just updated.
    """
    summary = {
        "id": tool.id,
        "name": tool.name,
        "url": tool.url,
        "description": tool.description,
        "enabled": tool.enabled,
    }
    when = datetime.now(timezone.utc).isoformat()
    await self._publish_event({"type": "tool_updated", "data": summary, "timestamp": when})

5544 

async def _notify_tool_activated(self, tool: DbTool) -> None:
    """
    Publish a ``tool_activated`` event for *tool*.

    Args:
        tool: The tool that has just been activated.
    """
    state = {"id": tool.id, "name": tool.name, "enabled": tool.enabled, "reachable": tool.reachable}
    await self._publish_event(
        {
            "type": "tool_activated",
            "data": state,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
    )

5558 

async def _notify_tool_deactivated(self, tool: DbTool) -> None:
    """
    Publish a ``tool_deactivated`` event for *tool*.

    Args:
        tool: The tool that has just been deactivated.
    """
    state = dict(id=tool.id, name=tool.name, enabled=tool.enabled, reachable=tool.reachable)
    stamp = datetime.now(timezone.utc).isoformat()
    message = {"type": "tool_deactivated", "data": state, "timestamp": stamp}
    await self._publish_event(message)

5572 

async def _notify_tool_offline(self, tool: DbTool) -> None:
    """
    Publish a ``tool_offline`` event for *tool*.

    Only ``id`` and ``name`` are read from the record. The payload always reports
    ``enabled=True`` and ``reachable=False``, whatever the record's current flags are.

    Args:
        tool: Tool database object
    """
    identity = {"id": tool.id, "name": tool.name}
    status = {"enabled": True, "reachable": False}
    await self._publish_event(
        {
            "type": "tool_offline",
            "data": {**identity, **status},
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
    )

5591 

async def _notify_tool_deleted(self, tool_info: Dict[str, Any]) -> None:
    """
    Publish a ``tool_deleted`` event.

    The caller-supplied *tool_info* dict is forwarded as the event's ``data``
    as-is (same object, not a copy).

    Args:
        tool_info: Dictionary describing the deleted tool.
    """
    stamp = datetime.now(timezone.utc).isoformat()
    await self._publish_event(dict(type="tool_deleted", data=tool_info, timestamp=stamp))

5605 

async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
    """Relay tool events from the EventService to the caller.

    Yields:
        Tool event messages, in the order the EventService delivers them.
    """
    stream = self._event_service.subscribe_events()
    async for message in stream:
        yield message

5614 

async def _notify_tool_added(self, tool: DbTool) -> None:
    """
    Publish a ``tool_added`` event describing a newly registered tool.

    Args:
        tool: The tool that was added.
    """
    exported = ("id", "name", "url", "description", "enabled")
    data = {attr: getattr(tool, attr) for attr in exported}
    await self._publish_event(
        {
            "type": "tool_added",
            "data": data,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
    )

5634 

async def _notify_tool_removed(self, tool: DbTool) -> None:
    """
    Publish a ``tool_removed`` event (soft delete / deactivation).

    Args:
        tool: The tool that was removed.
    """
    removed = dict(id=tool.id, name=tool.name, enabled=tool.enabled)
    stamp = datetime.now(timezone.utc).isoformat()
    await self._publish_event({"type": "tool_removed", "data": removed, "timestamp": stamp})

5648 

async def _publish_event(self, event: Dict[str, Any]) -> None:
    """
    Hand *event* to the EventService for delivery to subscribers.

    Args:
        event: Event to publish
    """
    publisher = self._event_service
    await publisher.publish_event(event)

5657 

async def _validate_tool_url(self, url: str) -> None:
    """Validate tool URL is accessible.

    Issues a GET through the service's HTTP client and calls ``raise_for_status``
    on the response. Any failure, whether transport error or rejected status, is
    re-raised as ``ToolValidationError`` with the original exception chained as
    ``__cause__``.

    Args:
        url: URL to validate.

    Raises:
        ToolValidationError: If URL validation fails.
    """
    try:
        response = await self._http_client.get(url)
        response.raise_for_status()
    except Exception as e:
        # Chain the original error so callers and logs keep the underlying cause.
        raise ToolValidationError(f"Failed to validate tool URL: {str(e)}") from e

5672 

async def _check_tool_health(self, tool: DbTool) -> bool:
    """Report whether the tool's endpoint answers a GET successfully.

    Any exception raised while requesting or inspecting the response is treated
    as "unhealthy" rather than propagated.

    Args:
        tool: Tool to check.

    Returns:
        The response's ``is_success`` flag, or False if anything raised.
    """
    try:
        healthy = (await self._http_client.get(tool.url)).is_success
    except Exception:
        healthy = False
    return healthy

5687 

5688 # async def event_generator(self) -> AsyncGenerator[Dict[str, Any], None]: 

5689 # """Generate tool events for SSE. 

5690 

5691 # Yields: 

5692 # Tool events. 

5693 # """ 

5694 # queue: asyncio.Queue = asyncio.Queue() 

5695 # self._event_subscribers.append(queue) 

5696 # try: 

5697 # while True: 

5698 # event = await queue.get() 

5699 # yield event 

5700 # finally: 

5701 # self._event_subscribers.remove(queue) 

5702 

5703 # --- Metrics --- 

async def aggregate_metrics(self, db: Session) -> ToolMetrics:
    """
    Aggregate metrics for all tool invocations across all tools.

    Raw and hourly-rollup figures are combined by ``aggregate_metrics_combined``.
    When the metrics cache is enabled, a cached entry under ``"tools"`` is
    returned directly. A freshly computed result is stored back as a plain dict.

    Args:
        db: Database session

    Returns:
        ToolMetrics: Aggregated metrics computed from raw ToolMetric + ToolMetricsHourly.

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> service = ToolService()
        >>> # Method exists and is callable
        >>> callable(service.aggregate_metrics)
        True
    """
    # First-Party
    from mcpgateway.cache.metrics_cache import is_cache_enabled, metrics_cache  # pylint: disable=import-outside-toplevel

    # Fast path: serve a previously cached aggregate.
    if is_cache_enabled():
        cached_payload = metrics_cache.get("tools")
        if cached_payload is not None:
            return ToolMetrics(**cached_payload)

    # First-Party
    from mcpgateway.services.metrics_query_service import aggregate_metrics_combined  # pylint: disable=import-outside-toplevel

    combined = aggregate_metrics_combined(db, "tool")
    metric_fields = (
        "total_executions",
        "successful_executions",
        "failed_executions",
        "failure_rate",
        "min_response_time",
        "max_response_time",
        "avg_response_time",
        "last_execution_time",
    )
    metrics = ToolMetrics(**{field: getattr(combined, field) for field in metric_fields})

    # Store as a dict so the cached value is serialization-friendly.
    if is_cache_enabled():
        metrics_cache.set("tools", metrics.model_dump())

    return metrics

5755 

async def reset_metrics(self, db: Session, tool_id: Optional[str] = None) -> None:
    """
    Reset tool metrics by deleting raw and hourly rollup records.

    When *tool_id* is given, only that tool's ``ToolMetric`` and ``ToolMetricsHourly``
    rows are deleted. When it is ``None``, every tool's rows are deleted.

    The test is ``is not None``, not truthiness. A supplied-but-falsy id (``""`` or
    ``0``) therefore scopes the delete to that id instead of escalating into a
    global wipe.

    The "tools" metrics cache entry and all "top_tools:" entries are invalidated
    afterwards.

    Args:
        db: Database session
        tool_id: Optional tool ID to reset metrics for a specific tool

    Examples:
        >>> from mcpgateway.services.tool_service import ToolService
        >>> from unittest.mock import MagicMock
        >>> service = ToolService()
        >>> db = MagicMock()
        >>> db.execute = MagicMock()
        >>> db.commit = MagicMock()
        >>> import asyncio
        >>> asyncio.run(service.reset_metrics(db))
    """

    if tool_id is not None:
        db.execute(delete(ToolMetric).where(ToolMetric.tool_id == tool_id))
        db.execute(delete(ToolMetricsHourly).where(ToolMetricsHourly.tool_id == tool_id))
    else:
        db.execute(delete(ToolMetric))
        db.execute(delete(ToolMetricsHourly))
    db.commit()

    # Invalidate metrics cache
    # First-Party
    from mcpgateway.cache.metrics_cache import metrics_cache  # pylint: disable=import-outside-toplevel

    metrics_cache.invalidate("tools")
    metrics_cache.invalidate_prefix("top_tools:")

5789 

async def create_tool_from_a2a_agent(
    self,
    db: Session,
    agent: DbA2AAgent,
    created_by: Optional[str] = None,
    created_from_ip: Optional[str] = None,
    created_via: Optional[str] = None,
    created_user_agent: Optional[str] = None,
) -> DbTool:
    """Create a tool entry from an A2A agent for virtual server integration.

    If a tool whose ``original_name`` is ``a2a_<agent.slug>`` already exists,
    it is returned as-is and nothing new is registered.

    Args:
        db: Database session.
        agent: A2A agent to create tool from.
        created_by: Username who created this tool.
        created_from_ip: IP address of creator.
        created_via: Creation method.
        created_user_agent: User agent of creation request.

    Returns:
        The created tool database object.

    Raises:
        ToolNameConflictError: If a tool with the same name already exists.
    """
    # Check if tool already exists for this agent
    tool_name = f"a2a_{agent.slug}"
    existing_query = select(DbTool).where(DbTool.original_name == tool_name)
    existing_tool = db.execute(existing_query).scalar_one_or_none()

    if existing_tool:
        # Tool already exists, return it
        return existing_tool

    # Create tool entry for the A2A agent
    logger.debug(f"agent.tags: {agent.tags} for agent: {agent.name} (ID: {agent.id})")

    # Normalize tags: if agent.tags contains dicts like {'id':..,'label':..},
    # extract the human-friendly label. If tags are already strings, keep them.
    normalized_tags: list[str] = []
    for t in agent.tags or []:
        if isinstance(t, dict):
            # Prefer 'label', fall back to 'id' or stringified dict
            normalized_tags.append(t.get("label") or t.get("id") or str(t))
        elif hasattr(t, "label"):
            normalized_tags.append(t.label)
        else:
            normalized_tags.append(str(t))

    # Ensure we include identifying A2A tags, without duplicating ones the agent already carries.
    for marker in ("a2a", "agent"):
        if marker not in normalized_tags:
            normalized_tags.append(marker)

    tool_data = ToolCreate(
        name=tool_name,
        displayName=generate_display_name(agent.name),
        url=agent.endpoint_url,
        description=f"A2A Agent: {agent.description or agent.name}",
        integration_type="A2A",  # Special integration type for A2A agents
        request_type="POST",
        input_schema={
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "User query", "default": "Hello from ContextForge Admin UI test!"},
            },
            "required": ["query"],
        },
        allow_auto=True,
        annotations={
            "title": f"A2A Agent: {agent.name}",
            "a2a_agent_id": agent.id,
            "a2a_agent_type": agent.agent_type,
        },
        auth_type=agent.auth_type,
        auth_value=agent.auth_value,
        tags=normalized_tags,
    )

    # Default to "public" visibility if agent visibility is not set
    # This ensures A2A tools are visible in the Global Tools Tab
    tool_visibility = agent.visibility or "public"

    tool_read = await self.register_tool(
        db,
        tool_data,
        created_by=created_by,
        created_from_ip=created_from_ip,
        created_via=created_via or "a2a_integration",
        created_user_agent=created_user_agent,
        team_id=agent.team_id,
        owner_email=agent.owner_email,
        visibility=tool_visibility,
    )

    # Return the DbTool object for relationship assignment
    tool_db = db.get(DbTool, tool_read.id)
    return tool_db

5886 

async def update_tool_from_a2a_agent(
    self,
    db: Session,
    agent: DbA2AAgent,
    modified_by: Optional[str] = None,
    modified_from_ip: Optional[str] = None,
    modified_via: Optional[str] = None,
    modified_user_agent: Optional[str] = None,
) -> Optional[ToolRead]:
    """Update the tool associated with an A2A agent when the agent is updated.

    - If the agent has no ``tool_id``, nothing is changed.
    - If ``tool_id`` points to a tool that no longer exists, the dangling
      reference is cleared on the agent and committed.
    - Otherwise the tool is synced through ``update_tool`` with its existing
      visibility preserved.

    Args:
        db: Database session.
        agent: Updated A2A agent.
        modified_by: Username who modified this tool.
        modified_from_ip: IP address of modifier.
        modified_via: Modification method.
        modified_user_agent: User agent of modification request.

    Returns:
        The updated tool, or None if no associated tool exists.
    """
    # Use the tool_id from the agent for efficient lookup
    if not agent.tool_id:
        logger.debug(f"No tool_id found for A2A agent {agent.id}, skipping tool update")
        return None

    tool = db.get(DbTool, agent.tool_id)
    if not tool:
        logger.warning(f"Tool {agent.tool_id} not found for A2A agent {agent.id}, resetting tool_id")
        agent.tool_id = None
        db.commit()
        return None

    # Normalize tags: if agent.tags contains dicts like {'id':..,'label':..},
    # extract the human-friendly label. If tags are already strings, keep them.
    normalized_tags: list[str] = []
    for t in agent.tags or []:
        if isinstance(t, dict):
            # Prefer 'label', fall back to 'id' or stringified dict
            normalized_tags.append(t.get("label") or t.get("id") or str(t))
        elif hasattr(t, "label"):
            normalized_tags.append(t.label)
        else:
            normalized_tags.append(str(t))

    # Ensure we include identifying A2A tags, without duplicating ones the agent already carries.
    for marker in ("a2a", "agent"):
        if marker not in normalized_tags:
            normalized_tags.append(marker)

    # Prepare update data matching the agent's current state
    # IMPORTANT: Preserve the existing tool's visibility to avoid unintentionally
    # making private/team tools public (ToolUpdate defaults to "public")
    # Note: team_id is not a field on ToolUpdate schema, so team assignment is preserved
    # implicitly by not changing visibility (team tools stay team-scoped)
    new_tool_name = f"a2a_{agent.slug}"
    tool_update = ToolUpdate(
        name=new_tool_name,
        custom_name=new_tool_name,  # Also set custom_name to ensure name update works
        displayName=generate_display_name(agent.name),
        url=agent.endpoint_url,
        description=f"A2A Agent: {agent.description or agent.name}",
        auth=AuthenticationValues(auth_type=agent.auth_type, auth_value=agent.auth_value) if agent.auth_type else None,
        tags=normalized_tags,
        visibility=tool.visibility,  # Preserve existing visibility
    )

    # Update the tool
    return await self.update_tool(
        db=db,
        tool_id=tool.id,
        tool_update=tool_update,
        modified_by=modified_by,
        modified_from_ip=modified_from_ip,
        modified_via=modified_via or "a2a_sync",
        modified_user_agent=modified_user_agent,
    )

5963 

5964 async def delete_tool_from_a2a_agent(self, db: Session, agent: DbA2AAgent, user_email: Optional[str] = None, purge_metrics: bool = False) -> None: 

5965 """Delete the tool associated with an A2A agent when the agent is deleted. 

5966 

5967 Args: 

5968 db: Database session. 

5969 agent: The A2A agent being deleted. 

5970 user_email: Email of user performing delete (for ownership check). 

5971 purge_metrics: If True, delete raw + rollup metrics for this tool. 

5972 """ 

5973 # Use the tool_id from the agent for efficient lookup 

5974 if not agent.tool_id: 

5975 logger.debug(f"No tool_id found for A2A agent {agent.id}, skipping tool deletion") 

5976 return 

5977 

5978 tool = db.get(DbTool, agent.tool_id) 

5979 if not tool: 

5980 logger.warning(f"Tool {agent.tool_id} not found for A2A agent {agent.id}") 

5981 return 

5982 

5983 # Delete the tool 

5984 await self.delete_tool(db=db, tool_id=tool.id, user_email=user_email, purge_metrics=purge_metrics) 

5985 logger.info(f"Deleted tool {tool.id} associated with A2A agent {agent.id}") 

5986 

5987 async def _invoke_a2a_tool(self, db: Session, tool: DbTool, arguments: Dict[str, Any]) -> ToolResult: 

5988 """Invoke an A2A agent through its corresponding tool. 

5989 

5990 Args: 

5991 db: Database session. 

5992 tool: The tool record that represents the A2A agent. 

5993 arguments: Tool arguments. 

5994 

5995 Returns: 

5996 Tool result from A2A agent invocation. 

5997 

5998 Raises: 

5999 ToolNotFoundError: If the A2A agent is not found. 

6000 """ 

6001 

6002 # Extract A2A agent ID from tool annotations 

6003 agent_id = tool.annotations.get("a2a_agent_id") 

6004 if not agent_id: 

6005 raise ToolNotFoundError(f"A2A tool '{tool.name}' missing agent ID in annotations") 

6006 

6007 # Get the A2A agent 

6008 agent_query = select(DbA2AAgent).where(DbA2AAgent.id == agent_id) 

6009 agent = db.execute(agent_query).scalar_one_or_none() 

6010 

6011 if not agent: 

6012 raise ToolNotFoundError(f"A2A agent not found for tool '{tool.name}' (agent ID: {agent_id})") 

6013 

6014 if not agent.enabled: 

6015 raise ToolNotFoundError(f"A2A agent '{agent.name}' is disabled") 

6016 

6017 # Force-load all attributes needed by _call_a2a_agent before detaching 

6018 # (accessing them ensures they're loaded into the object's __dict__) 

6019 _ = (agent.name, agent.endpoint_url, agent.agent_type, agent.protocol_version, agent.auth_type, agent.auth_value, agent.auth_query_params) 

6020 

6021 # Detach agent from session so its loaded data remains accessible after close 

6022 db.expunge(agent) 

6023 

6024 # CRITICAL: Release DB connection back to pool BEFORE making HTTP calls 

6025 # This prevents "idle in transaction" connection pool exhaustion under load 

6026 db.commit() 

6027 db.close() 

6028 

6029 # Prepare parameters for A2A invocation 

6030 try: 

6031 # Make the A2A agent call (agent is now detached but data is loaded) 

6032 response_data = await self._call_a2a_agent(agent, arguments) 

6033 

6034 # Convert A2A response to MCP ToolResult format 

6035 if isinstance(response_data, dict) and "response" in response_data: 

6036 val = response_data["response"] 

6037 content = [TextContent(type="text", text=val if isinstance(val, str) else orjson.dumps(val).decode())] 

6038 else: 

6039 content = [TextContent(type="text", text=response_data if isinstance(response_data, str) else orjson.dumps(response_data).decode())] 

6040 

6041 result = ToolResult(content=content, is_error=False) 

6042 

6043 except Exception as e: 

6044 error_message = str(e) 

6045 content = [TextContent(type="text", text=f"A2A agent error: {error_message}")] 

6046 result = ToolResult(content=content, is_error=True) 

6047 

6048 # Note: Metrics are recorded by the calling invoke_tool method, not here 

6049 return result 

6050 

6051 async def _call_a2a_agent(self, agent: DbA2AAgent, parameters: Dict[str, Any]): 

6052 """Call an A2A agent directly. 

6053 

6054 Args: 

6055 agent: The A2A agent to call. 

6056 parameters: Parameters for the interaction. 

6057 

6058 Returns: 

6059 Response from the A2A agent. 

6060 

6061 Raises: 

6062 ToolInvocationError: If authentication decryption fails. 

6063 Exception: If the call fails. 

6064 """ 

6065 logger.info(f"Calling A2A agent '{agent.name}' at {agent.endpoint_url} with arguments: {parameters}") 

6066 

6067 # Build request data based on agent type 

6068 if agent.agent_type in ["generic", "jsonrpc"] or agent.endpoint_url.endswith("/"): 

6069 # JSONRPC agents: Convert flat query to nested message structure 

6070 params = None 

6071 if isinstance(parameters, dict) and "query" in parameters and isinstance(parameters["query"], str): 

6072 # Build the nested message object for JSONRPC protocol 

6073 message_id = f"admin-test-{int(time.time())}" 

6074 # A2A v0.3.x: message.parts use "kind" (not "type"). 

6075 params = { 

6076 "message": { 

6077 "kind": "message", 

6078 "messageId": message_id, 

6079 "role": "user", 

6080 "parts": [{"kind": "text", "text": parameters["query"]}], 

6081 } 

6082 } 

6083 method = parameters.get("method", "message/send") 

6084 else: 

6085 # Already in correct format or unknown, pass through 

6086 params = parameters.get("params", parameters) 

6087 method = parameters.get("method", "message/send") 

6088 

6089 try: 

6090 request_data = {"jsonrpc": "2.0", "method": method, "params": params, "id": 1} 

6091 logger.info(f"invoke tool JSONRPC request_data prepared: {request_data}") 

6092 except Exception as e: 

6093 logger.error(f"Error preparing JSONRPC request data: {e}") 

6094 raise 

6095 else: 

6096 # Custom agents: Pass parameters directly without JSONRPC message conversion 

6097 # Custom agents expect flat fields like {"query": "...", "message": "..."} 

6098 params = parameters if isinstance(parameters, dict) else {} 

6099 logger.info(f"invoke tool Using custom A2A format for A2A agent '{params}'") 

6100 request_data = {"interaction_type": params.get("interaction_type", "query"), "parameters": params, "protocol_version": agent.protocol_version} 

6101 logger.info(f"invoke tool request_data prepared: {request_data}") 

6102 # Make HTTP request to the agent endpoint using shared HTTP client 

6103 # First-Party 

6104 from mcpgateway.services.http_client_service import get_http_client # pylint: disable=import-outside-toplevel 

6105 

6106 client = await get_http_client() 

6107 headers = {"Content-Type": "application/json"} 

6108 

6109 # Determine the endpoint URL (may be modified for query_param auth) 

6110 endpoint_url = agent.endpoint_url 

6111 

6112 # Add authentication if configured 

6113 if agent.auth_type in ("api_key", "basic", "bearer", "authheaders") and agent.auth_value: 

6114 # Decrypt auth_value and extract headers (matches a2a_service.py pattern) 

6115 if isinstance(agent.auth_value, str): 

6116 try: 

6117 auth_headers = decode_auth(agent.auth_value) 

6118 headers.update(auth_headers) 

6119 except Exception as e: 

6120 logger.error(f"Failed to decrypt authentication for A2A agent '{agent.name}': {e}") 

6121 raise ToolInvocationError(f"Failed to decrypt authentication for A2A agent '{agent.name}'") 

6122 elif isinstance(agent.auth_value, dict): 

6123 auth_headers = {str(k): str(v) for k, v in agent.auth_value.items()} 

6124 headers.update(auth_headers) 

6125 elif agent.auth_type == "query_param" and agent.auth_query_params: 

6126 # Handle query parameter authentication (imports at top: decode_auth, apply_query_param_auth, sanitize_url_for_logging) 

6127 auth_query_params_decrypted: dict[str, str] = {} 

6128 for param_key, encrypted_value in agent.auth_query_params.items(): 

6129 if encrypted_value: 

6130 try: 

6131 decrypted = decode_auth(encrypted_value) 

6132 auth_query_params_decrypted[param_key] = decrypted.get(param_key, "") 

6133 except Exception: 

6134 logger.debug(f"Failed to decrypt query param for key '{param_key}'") 

6135 if auth_query_params_decrypted: 

6136 endpoint_url = apply_query_param_auth(endpoint_url, auth_query_params_decrypted) 

6137 # Log sanitized URL to avoid credential leakage 

6138 sanitized_url = sanitize_url_for_logging(endpoint_url, auth_query_params_decrypted) 

6139 logger.debug(f"Applied query param auth to A2A agent endpoint: {sanitized_url}") 

6140 

6141 http_response = await client.post(endpoint_url, json=request_data, headers=headers) 

6142 

6143 if http_response.status_code == 200: 

6144 return http_response.json() 

6145 

6146 raise Exception(f"HTTP {http_response.status_code}: {http_response.text}") 

6147 

6148 

6149# Lazy singleton - created on first access, not at module import time. 

6150# This avoids instantiation when only exception classes are imported. 

6151_tool_service_instance = None # pylint: disable=invalid-name 

6152 

6153 

6154def __getattr__(name: str): 

6155 """Module-level __getattr__ for lazy singleton creation. 

6156 

6157 Args: 

6158 name: The attribute name being accessed. 

6159 

6160 Returns: 

6161 The tool_service singleton instance if name is "tool_service". 

6162 

6163 Raises: 

6164 AttributeError: If the attribute name is not "tool_service". 

6165 """ 

6166 global _tool_service_instance # pylint: disable=global-statement 

6167 if name == "tool_service": 

6168 if _tool_service_instance is None: 

6169 _tool_service_instance = ToolService() 

6170 return _tool_service_instance 

6171 raise AttributeError(f"module {__name__!r} has no attribute {name!r}")