Coverage for mcpgateway / plugins / framework / external / mcp / client.py: 100%
352 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/plugins/framework/external/mcp/client.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Teryl Taylor, Fred Araujo
7External plugin client which connects to a remote server through MCP.
8Module that contains plugin MCP client code to serve external plugins.
9"""
11# Standard
12import asyncio
13from contextlib import AsyncExitStack
14from functools import partial
15import logging
16import os
17from pathlib import Path
18import sys
19from typing import Any, Awaitable, Callable, Optional
21# Third-Party
22import httpx
23from mcp import ClientSession, McpError, StdioServerParameters
24from mcp.client.stdio import stdio_client
25from mcp.client.streamable_http import streamablehttp_client
26from mcp.types import TextContent
27import orjson
29# First-Party
30from mcpgateway.plugins.framework.base import HookRef, Plugin, PluginRef
31from mcpgateway.plugins.framework.constants import CONTEXT, ERROR, GET_PLUGIN_CONFIG, HOOK_TYPE, IGNORE_CONFIG_EXTERNAL, INVOKE_HOOK, NAME, PAYLOAD, PLUGIN_NAME, PYTHON_SUFFIX, RESULT
32from mcpgateway.plugins.framework.errors import convert_exception_to_error, PluginError
33from mcpgateway.plugins.framework.external.mcp.tls_utils import create_ssl_context
34from mcpgateway.plugins.framework.hooks.registry import get_hook_registry
35from mcpgateway.plugins.framework.models import MCPClientTLSConfig, PluginConfig, PluginContext, PluginErrorModel, PluginPayload, PluginResult, TransportType
36from mcpgateway.plugins.framework.settings import get_http_client_settings
# Module-level logger shared by the external MCP plugin client classes below.
logger = logging.getLogger(__name__)
class ExternalPlugin(Plugin):
    """External plugin object for pre/post processing of inputs and outputs at various locations throughout the gateway.

    The External Plugin connects to a remote MCP server that contains plugins.
    Two transports are handled: STDIO, where the plugin server is spawned as a
    child process and driven from a dedicated background task, and
    STREAMABLEHTTP, optionally over TLS or a Unix domain socket.
    """

    def __init__(self, config: PluginConfig) -> None:
        """Initialize a plugin with a configuration and context.

        Args:
            config: The plugin configuration
        """
        super().__init__(config)
        self._session: Optional[ClientSession] = None
        self._exit_stack = AsyncExitStack()
        # Transport streams; populated once a connection has been established.
        self._http: Optional[Any] = None
        self._stdio: Optional[Any] = None
        self._write: Optional[Any] = None
        # asyncio.current_task() raises RuntimeError when no event loop is
        # running, so tolerate construction from synchronous code.
        try:
            self._current_task: Optional[asyncio.Task[Any]] = asyncio.current_task()
        except RuntimeError:
            self._current_task = None
        self._stdio_exit_stack: Optional[AsyncExitStack] = None
        self._stdio_task: Optional[asyncio.Task[None]] = None
        self._stdio_ready: Optional[asyncio.Event] = None
        self._stdio_stop: Optional[asyncio.Event] = None
        self._stdio_error: Optional[BaseException] = None
        self._get_session_id: Optional[Callable[[], str | None]] = None
        self._session_id: Optional[str] = None
        self._http_client_factory: Optional[Callable[..., httpx.AsyncClient]] = None
        self._reconnect_attempts: int = 3  # Will be loaded from config
        self._reconnect_delay: float = 0.1  # Will be loaded from config

    async def initialize(self) -> None:
        """Initialize the plugin's connection to the MCP server.

        Connects using the configured transport, fetches the plugin's
        configuration from the remote server, and merges it underneath the
        locally set configuration (local values take precedence).

        Raises:
            PluginError: if unable to retrieve plugin configuration of external plugin.
        """

        if not self._config.mcp:
            raise PluginError(error=PluginErrorModel(message="The mcp section must be defined for external plugin", plugin_name=self.name))

        # Load reconnect configuration
        self._reconnect_attempts = self._config.mcp.reconnect_attempts
        self._reconnect_delay = self._config.mcp.reconnect_delay

        if self._config.mcp.proto == TransportType.STDIO:
            if not (self._config.mcp.script or self._config.mcp.cmd):
                raise PluginError(error=PluginErrorModel(message="STDIO transport requires script or cmd", plugin_name=self.name))
            await self.__connect_to_stdio_server(self._config.mcp.script, self._config.mcp.cmd, self._config.mcp.env, self._config.mcp.cwd)
        elif self._config.mcp.proto == TransportType.STREAMABLEHTTP:
            if not self._config.mcp.url:
                raise PluginError(error=PluginErrorModel(message="STREAMABLEHTTP transport requires url", plugin_name=self.name))
            await self.__connect_to_http_server(self._config.mcp.url)

        try:
            config = await self.__get_plugin_config()

            if not config:
                raise PluginError(error=PluginErrorModel(message="Unable to retrieve configuration for external plugin", plugin_name=self.name))

            # Values explicitly set locally override those reported by the server.
            current_config = self._config.model_dump(exclude_unset=True)
            remote_config = config.model_dump(exclude_unset=True)
            remote_config.update(current_config)

            context = {IGNORE_CONFIG_EXTERNAL: True}

            self._config = PluginConfig.model_validate(remote_config, context=context)
        except PluginError as pe:
            try:
                await self.shutdown()
            except Exception as shutdown_error:
                logger.error("Error during external plugin shutdown after init failure: %s", shutdown_error)
            logger.exception(pe)
            raise
        except Exception as e:
            try:
                await self.shutdown()
            except Exception as shutdown_error:
                logger.error("Error during external plugin shutdown after init failure: %s", shutdown_error)
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name))

    def __resolve_stdio_command(self, script_path: str | None, cmd: list[str] | None, cwd: str | None) -> tuple[str, list[str]]:
        """Resolve the stdio command + args from config.

        ``cmd`` takes precedence over ``script_path``. Python scripts run under
        the current interpreter, ``.sh`` scripts under ``sh``, and any other
        script must be directly executable.

        Args:
            script_path: Path to a server script or executable.
            cmd: Command list to execute (command + args).
            cwd: Working directory for resolving relative script paths.

        Returns:
            Tuple of (command, args).

        Raises:
            PluginError: if the script is invalid or cmd is malformed.
        """
        if cmd:
            if not isinstance(cmd, list) or not cmd or not all(isinstance(part, str) and part.strip() for part in cmd):
                raise PluginError(error=PluginErrorModel(message="STDIO cmd must be a non-empty list of strings", plugin_name=self.name))
            return cmd[0], cmd[1:]

        if not script_path:
            raise PluginError(error=PluginErrorModel(message="STDIO transport requires script or cmd", plugin_name=self.name))

        server_path = Path(script_path).expanduser()
        if not server_path.is_absolute() and cwd:
            server_path = Path(cwd).expanduser() / server_path
        resolved_script_path = str(server_path)
        if not server_path.is_file():
            raise PluginError(error=PluginErrorModel(message=f"Server script {resolved_script_path} does not exist.", plugin_name=self.name))

        if server_path.suffix == PYTHON_SUFFIX:
            return sys.executable, [resolved_script_path]
        if server_path.suffix == ".sh":
            return "sh", [resolved_script_path]
        if not os.access(server_path, os.X_OK):
            raise PluginError(error=PluginErrorModel(message=f"Server script {resolved_script_path} must be executable.", plugin_name=self.name))
        return resolved_script_path, []

    def __build_stdio_env(self, extra_env: dict[str, str] | None) -> dict[str, str]:
        """Build environment for the stdio server process.

        Args:
            extra_env: Environment overrides to merge into the current process env.

        Returns:
            Combined environment dictionary for the plugin process.
        """
        current_env = os.environ.copy()
        if extra_env:
            current_env.update(extra_env)
        return current_env

    async def __run_stdio_session(self, server_script_path: str | None, cmd: list[str] | None, env: dict[str, str] | None, cwd: str | None) -> None:
        """Run a stdio session in a dedicated task for consistent setup/teardown.

        The stdio transport and client session are entered and closed from this
        single task. The task signals ``_stdio_ready`` once setup has finished.
        A setup failure is stored in ``_stdio_error``. After a successful setup
        the task waits until ``_stdio_stop`` is set, then closes the session and
        transport.

        Args:
            server_script_path: Path to the server script or executable.
            cmd: Command list to start the server (command + args).
            env: Environment overrides for the server process.
            cwd: Working directory for the server process.
        """
        try:
            command, args = self.__resolve_stdio_command(server_script_path, cmd, cwd)
            server_env = self.__build_stdio_env(env)
            server_params = StdioServerParameters(command=command, args=args, env=server_env, cwd=cwd)

            self._stdio_exit_stack = AsyncExitStack()
            stdio_transport = await self._stdio_exit_stack.enter_async_context(stdio_client(server_params))
            self._stdio, self._write = stdio_transport
            self._session = await self._stdio_exit_stack.enter_async_context(ClientSession(self._stdio, self._write))

            await self._session.initialize()

            response = await self._session.list_tools()
            tools = response.tools
            logger.info("\nConnected to plugin MCP server (stdio) with tools: %s", " ".join([tool.name for tool in tools]))
        except Exception as e:
            self._stdio_error = e
            logger.exception(e)
        finally:
            # Always release the waiter in __connect_to_stdio_server, even on failure.
            if self._stdio_ready and not self._stdio_ready.is_set():
                self._stdio_ready.set()

        try:
            # Keep the session alive until asked to stop; skipped if setup failed.
            if not self._stdio_error and self._stdio_stop:
                await self._stdio_stop.wait()
        finally:
            # Close from this task, which entered the contexts. The finally also
            # covers cancellation of the task while it is waiting to stop.
            if self._stdio_exit_stack:
                await self._stdio_exit_stack.aclose()

    async def __connect_to_stdio_server(self, server_script_path: str | None, cmd: list[str] | None, env: dict[str, str] | None, cwd: str | None) -> None:
        """Connect to an MCP plugin server via stdio.

        Args:
            server_script_path: Path to the server script or executable.
            cmd: Command list to start the server (command + args).
            env: Environment overrides for the server process.
            cwd: Working directory for the server process.

        Raises:
            PluginError: if stdio script/cmd is invalid or if there is a connection error.
        """
        try:
            # Create a fresh readiness event for every connection attempt. An
            # event left over from an earlier (failed) attempt is already set
            # and would make the wait below return before this session is ready.
            self._stdio_ready = asyncio.Event()
            if self._stdio_stop is None or self._stdio_stop.is_set():
                self._stdio_stop = asyncio.Event()
            self._stdio_error = None

            self._stdio_task = asyncio.create_task(
                self.__run_stdio_session(server_script_path, cmd, env, cwd),
                name=f"external-plugin-stdio-{self.name}",
            )

            await self._stdio_ready.wait()
            if self._stdio_error:
                raise PluginError(error=convert_exception_to_error(self._stdio_error, plugin_name=self.name))
        except PluginError:
            raise
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name))

    async def __connect_to_http_server(self, uri: str) -> None:
        """Connect to an MCP plugin server via streamable http with retry logic.

        Makes up to three attempts with exponential backoff (1s, then 2s).
        Each failed attempt is torn down through ``shutdown()`` so that no
        partially entered transport contexts are left behind.

        Args:
            uri: the URI of the mcp plugin server.

        Raises:
            PluginError: if there is an external connection error after all retries.
        """
        plugin_tls = self._config.mcp.tls if self._config and self._config.mcp else None
        uds_path = self._config.mcp.uds if self._config and self._config.mcp else None
        if uds_path and plugin_tls:
            logger.warning("TLS configuration is ignored for Unix domain socket connections.")
        tls_config = None if uds_path else (plugin_tls or MCPClientTLSConfig.from_env())

        def _tls_httpx_client_factory(
            headers: Optional[dict[str, str]] = None,
            timeout: Optional[httpx.Timeout] = None,
            auth: Optional[httpx.Auth] = None,
        ) -> httpx.AsyncClient:
            """Build an httpx client with TLS configuration for external MCP servers.

            Args:
                headers: Optional HTTP headers to include in requests.
                timeout: Optional timeout configuration for HTTP requests.
                auth: Optional authentication handler for HTTP requests.

            Returns:
                Configured httpx AsyncClient with TLS settings applied.

            Raises:
                PluginError: If TLS configuration fails.
            """

            kwargs: dict[str, Any] = {"follow_redirects": True}
            if uds_path:
                kwargs["transport"] = httpx.AsyncHTTPTransport(uds=uds_path)
            if headers:
                kwargs["headers"] = headers
            http_settings = get_http_client_settings()
            kwargs["timeout"] = (
                timeout
                if timeout
                else httpx.Timeout(
                    connect=http_settings.httpx_connect_timeout,
                    read=http_settings.httpx_read_timeout,
                    write=http_settings.httpx_write_timeout,
                    pool=http_settings.httpx_pool_timeout,
                )
            )
            if auth is not None:
                kwargs["auth"] = auth

            # Add connection pool limits
            kwargs["limits"] = httpx.Limits(
                max_connections=http_settings.httpx_max_connections,
                max_keepalive_connections=http_settings.httpx_max_keepalive_connections,
                keepalive_expiry=http_settings.httpx_keepalive_expiry,
            )

            if not tls_config:
                # Use skip_ssl_verify setting when no custom TLS config
                kwargs["verify"] = not http_settings.skip_ssl_verify
                return httpx.AsyncClient(**kwargs)

            # Create SSL context using the utility function
            # This implements certificate validation per test_client_certificate_validation.py
            ssl_context = create_ssl_context(tls_config, self.name)
            kwargs["verify"] = ssl_context

            return httpx.AsyncClient(**kwargs)

        max_retries = 3
        base_delay = 1.0

        for attempt in range(max_retries):
            # shutdown() (used to tear down failed attempts) clears the factory,
            # so restore it on every attempt. __terminate_http_session reuses it
            # to send the DELETE with the same TLS/UDS settings as this connection.
            self._http_client_factory = _tls_httpx_client_factory
            try:
                streamable_client = streamablehttp_client(uri, httpx_client_factory=_tls_httpx_client_factory, terminate_on_close=False)
                http_transport = await self._exit_stack.enter_async_context(streamable_client)
                self._http, self._write, get_session_id = http_transport
                self._get_session_id = get_session_id
                self._session = await self._exit_stack.enter_async_context(ClientSession(self._http, self._write))

                await self._session.initialize()
                self._session_id = self._get_session_id() if self._get_session_id else None
                response = await self._session.list_tools()
                tools = response.tools
                logger.info(
                    "Successfully connected to plugin MCP server with tools: %s",
                    " ".join([tool.name for tool in tools]),
                )
                return
            except Exception as e:
                logger.warning("Connection attempt %d/%d failed: %s", attempt + 1, max_retries, e)
                # Release whatever the failed attempt entered on the exit stack,
                # including after the final attempt, so nothing leaks.
                await self.shutdown()
                self._exit_stack = AsyncExitStack()
                if attempt == max_retries - 1:
                    # Final attempt failed
                    target = f"{uri} (uds={uds_path})" if uds_path else uri
                    error_msg = f"External plugin '{self.name}' connection failed after {max_retries} attempts: {target} is not reachable. Please ensure the MCP server is running."
                    logger.error(error_msg)
                    raise PluginError(error=PluginErrorModel(message=error_msg, plugin_name=self.name)) from e
                # Wait before retry
                delay = base_delay * (2**attempt)
                logger.info("Retrying in %ss...", delay)
                await asyncio.sleep(delay)

    async def _cleanup_session(self) -> None:
        """Clean up existing session without full shutdown.

        Resets all transport and session state so that a subsequent
        connection attempt starts from a clean slate. For STDIO
        transports this includes stopping the background task and
        resetting its synchronisation primitives so they are properly
        re-created on the next connect call.

        Cleanup is best-effort. Errors raised while stopping or closing the
        old session are logged rather than raised, because this runs just
        before a reconnect and a broken transport must not prevent it.
        """
        # Stop the stdio background task first (mirrors shutdown() logic)
        if self._stdio_task:
            if self._stdio_stop:
                self._stdio_stop.set()
            try:
                await self._stdio_task
            except Exception as e:
                logger.warning("Error stopping stdio session for plugin %s during cleanup: %s", self.name, e)
            self._stdio_task = None
        # Reset stdio synchronisation primitives so __connect_to_stdio_server
        # creates fresh ones on the next connection attempt.
        self._stdio_ready = None
        self._stdio_stop = None
        self._stdio_error = None

        if self._exit_stack:
            try:
                await self._exit_stack.aclose()
            except Exception as e:
                logger.warning("Error closing MCP connection for plugin %s during cleanup: %s", self.name, e)
        self._exit_stack = AsyncExitStack()
        if self._stdio_exit_stack:
            try:
                await self._stdio_exit_stack.aclose()
            except Exception as e:
                logger.warning("Error closing stdio transport for plugin %s during cleanup: %s", self.name, e)
            self._stdio_exit_stack = None
        self._session = None
        self._http = None
        self._write = None
        self._stdio = None
        self._get_session_id = None
        self._session_id = None

    async def _reconnect_session(self) -> None:
        """Tear down old session and reconnect to MCP server.

        Implements retry logic with linear backoff: the wait after failed
        attempt ``n`` is ``reconnect_delay * n``.

        Raises:
            PluginError: If reconnection fails after all attempts.
        """
        logger.info("Attempting to reconnect to MCP server: %s", self.name)

        # Clean up existing session
        await self._cleanup_session()

        last_error: Optional[Exception] = None
        for attempt in range(1, self._reconnect_attempts + 1):
            try:
                logger.debug("Reconnection attempt %d/%d to %s", attempt, self._reconnect_attempts, self.name)

                # Re-run connection based on transport type
                if self._config.mcp.proto == TransportType.STREAMABLEHTTP:
                    await self.__connect_to_http_server(self._config.mcp.url)
                elif self._config.mcp.proto == TransportType.STDIO:
                    await self.__connect_to_stdio_server(self._config.mcp.script, self._config.mcp.cmd, self._config.mcp.env, self._config.mcp.cwd)
                else:
                    # Never report success when nothing was (re)connected.
                    raise PluginError(error=PluginErrorModel(message=f"Unsupported transport for reconnection: {self._config.mcp.proto}", plugin_name=self.name))

                logger.info("Reconnected to MCP server on attempt %d: %s", attempt, self.name)
                return
            except Exception as e:
                last_error = e
                if attempt < self._reconnect_attempts:
                    delay = self._reconnect_delay * attempt  # Linear backoff
                    logger.warning("Reconnection attempt %d failed: %s. Retrying in %ss...", attempt, e, delay)
                    await asyncio.sleep(delay)

        raise PluginError(error=PluginErrorModel(message=f"Failed to reconnect after {self._reconnect_attempts} attempts: {last_error}", plugin_name=self.name))

    async def invoke_hook(self, hook_type: str, payload: PluginPayload, context: PluginContext) -> PluginResult:
        """Invoke an external plugin hook using the MCP protocol.

        On a "session terminated" PluginError or any McpError, the session is
        re-established once and the call is retried a single time.

        Args:
            hook_type: The type of hook invoked (i.e., prompt_pre_fetch)
            payload: The payload to be passed to the hook.
            context: The plugin context passed to the run.

        Raises:
            PluginError: Error passed from external plugin server, or if reconnection fails.

        Returns:
            The resulting payload from the plugin.
        """
        # Get the result type from the global registry
        registry = get_hook_registry()
        result_type = registry.get_result_type(hook_type)
        if not result_type:
            raise PluginError(error=PluginErrorModel(message=f"Hook type '{hook_type}' not registered in hook registry", plugin_name=self.name))

        if not self._session:
            raise PluginError(error=PluginErrorModel(message="Plugin session not initialized", plugin_name=self.name))

        async def _execute_call() -> PluginResult:
            """Execute the MCP tool call.

            Returns:
                The plugin result from the tool call.

            Raises:
                PluginError: If the call fails or returns invalid response.
            """
            result = await self._session.call_tool(INVOKE_HOOK, {HOOK_TYPE: hook_type, PLUGIN_NAME: self.name, PAYLOAD: payload, CONTEXT: context})
            for content in result.content:
                if not isinstance(content, TextContent):
                    continue
                try:
                    res = orjson.loads(content.text)
                except orjson.JSONDecodeError as decode_err:
                    raise PluginError(error=PluginErrorModel(message=f"Error trying to decode json: {content.text}", code="JSON_DECODE_ERROR", plugin_name=self.name)) from decode_err
                # Propagate context changes made by the remote plugin back to the caller.
                if CONTEXT in res:
                    cxt = PluginContext.model_validate(res[CONTEXT])
                    context.state = cxt.state
                    context.metadata = cxt.metadata
                    context.global_context.state = cxt.global_context.state
                if RESULT in res:
                    return result_type.model_validate(res[RESULT])
                if ERROR in res:
                    error = PluginErrorModel.model_validate(res[ERROR])
                    raise PluginError(error)
            raise PluginError(error=PluginErrorModel(message=f"Received invalid response. Result = {result}", plugin_name=self.name))

        try:
            return await _execute_call()
        except PluginError as pe:
            # Check if it's a session terminated error
            error_msg = str(pe.error.message).lower() if pe.error.message else ""
            if "session" in error_msg and "terminated" in error_msg:
                logger.warning("Session terminated for plugin %s, attempting reconnection...", self.name)
                try:
                    await self._reconnect_session()
                    # Retry the request once after successful reconnection
                    return await _execute_call()
                except Exception as reconn_err:
                    logger.exception("Reconnection failed for plugin %s: %s", self.name, reconn_err)
                    # Fall through to re-raise the original PluginError
            # Log and re-raise the original PluginError
            logger.exception(pe)
            raise
        except McpError as e:
            logger.warning("McpError for plugin %s: %s", self.name, e)
            try:
                await self._reconnect_session()
                return await _execute_call()
            except Exception as reconn_err:
                logger.exception("Reconnection failed for plugin %s: %s", self.name, reconn_err)
                raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name))
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name))

    async def __get_plugin_config(self) -> PluginConfig | None:
        """Retrieve plugin configuration for the current plugin on the remote MCP server.

        Raises:
            PluginError: if there is a connection issue or validation issue.

        Returns:
            A plugin configuration for the current plugin from a remote MCP server.
        """
        if not self._session:
            raise PluginError(error=PluginErrorModel(message="Plugin session not initialized", plugin_name=self.name))
        try:
            configs = await self._session.call_tool(GET_PLUGIN_CONFIG, {NAME: self.name})
            # Only the first text content item is considered.
            for content in configs.content:
                if not isinstance(content, TextContent):
                    continue
                conf = orjson.loads(content.text)
                if not conf:
                    return None
                return PluginConfig.model_validate(conf)
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name))

        return None

    async def shutdown(self) -> None:
        """Plugin cleanup code.

        Stops the stdio background task if one exists, then closes the HTTP
        exit stack. For STREAMABLEHTTP it explicitly terminates the remote
        session, and finally it resets connection state. Teardown errors are
        logged rather than raised, so every cleanup step still runs.
        """
        if self._stdio_task:
            if self._stdio_stop:
                self._stdio_stop.set()
            try:
                await self._stdio_task
            except Exception as e:
                logger.error("Error shutting down stdio session for plugin %s: %s", self.name, e)
            self._stdio_task = None
            self._stdio_ready = None
            self._stdio_stop = None
            self._stdio_exit_stack = None
            self._stdio_error = None
            self._stdio = None
            self._write = None
            if self._config and self._config.mcp and self._config.mcp.proto == TransportType.STDIO:
                self._session = None

        if self._exit_stack:
            try:
                await self._exit_stack.aclose()
            except Exception as e:
                logger.error("Error closing MCP connection for plugin %s: %s", self.name, e)
        if self._config and self._config.mcp and self._config.mcp.proto == TransportType.STREAMABLEHTTP:
            await self.__terminate_http_session()
            # The session and streams were closed with the exit stack above.
            self._session = None
            self._http = None
            self._write = None
        self._get_session_id = None
        self._session_id = None
        self._http_client_factory = None

    async def __terminate_http_session(self) -> None:
        """Terminate streamable HTTP session explicitly to avoid lingering server state.

        Sends a DELETE carrying the session id header to the configured URL.
        It prefers the client factory of the last connection, which carries
        the TLS/UDS settings, and falls back to a plain client. Failures are
        logged at debug level only.
        """
        if not self._session_id or not self._config or not self._config.mcp or not self._config.mcp.url:
            return
        # Third-Party
        from mcp.server.streamable_http import MCP_SESSION_ID_HEADER  # pylint: disable=import-outside-toplevel

        client_factory = self._http_client_factory
        try:
            if client_factory:
                client = client_factory()
            else:
                client = httpx.AsyncClient(follow_redirects=True)
            async with client:
                headers = {MCP_SESSION_ID_HEADER: self._session_id}
                await client.delete(self._config.mcp.url, headers=headers)
        except Exception as exc:
            logger.debug("Failed to terminate streamable HTTP session: %s", exc)
class ExternalHookRef(HookRef):
    """Hook reference that routes a named hook to an external plugin's ``invoke_hook``."""

    def __init__(self, hook: str, plugin_ref: PluginRef):  # pylint: disable=super-init-not-called
        """Bind ``hook`` to the ``invoke_hook`` coroutine of the referenced plugin.

        Note: ``super().__init__()`` is deliberately skipped. External plugins
        expose a single ``invoke_hook()`` entry point instead of one method
        attribute per hook.

        Args:
            hook: name of the hook point.
            plugin_ref: The reference to the plugin to hook.

        Raises:
            PluginError: If the plugin is not an external plugin.
        """
        self._plugin_ref = plugin_ref
        self._hook = hook
        target = plugin_ref.plugin
        # Guard clause: only plugins exposing the invoke entry point qualify.
        if not hasattr(target, INVOKE_HOOK):
            raise PluginError(error=PluginErrorModel(message=f"Plugin: {target.name} is not an external plugin", plugin_name=target.name))
        # Pre-bind the hook name so callers only pass (payload, context).
        self._func: Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]] = partial(target.invoke_hook, hook)  # type: ignore[attr-defined]