# Coverage report for mcpgateway/plugins/framework/external/mcp/client.py: 100% (286 statements)
# Generated by coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
# -*- coding: utf-8 -*-
"""Location: ./mcpgateway/plugins/framework/external/mcp/client.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Teryl Taylor, Fred Araujo

External plugin client which connects to a remote server through MCP.
Module that contains plugin MCP client code to serve external plugins.
"""

# Standard
import asyncio
from contextlib import AsyncExitStack
from functools import partial
import logging
import os
from pathlib import Path
import sys
from typing import Any, Awaitable, Callable, Optional

# Third-Party
import httpx
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import TextContent
import orjson

# First-Party
from mcpgateway.plugins.framework.base import HookRef, Plugin, PluginRef
from mcpgateway.plugins.framework.constants import CONTEXT, ERROR, GET_PLUGIN_CONFIG, HOOK_TYPE, IGNORE_CONFIG_EXTERNAL, INVOKE_HOOK, NAME, PAYLOAD, PLUGIN_NAME, PYTHON_SUFFIX, RESULT
from mcpgateway.plugins.framework.errors import convert_exception_to_error, PluginError
from mcpgateway.plugins.framework.external.mcp.tls_utils import create_ssl_context
from mcpgateway.plugins.framework.hooks.registry import get_hook_registry
from mcpgateway.plugins.framework.models import MCPClientTLSConfig, PluginConfig, PluginContext, PluginErrorModel, PluginPayload, PluginResult, TransportType
from mcpgateway.plugins.framework.settings import get_http_client_settings

# Module-level logger shared by every class in this file.
logger = logging.getLogger(__name__)
class ExternalPlugin(Plugin):
    """External plugin object for pre/post processing of inputs and outputs at various locations throughout the gateway.

    The External Plugin connects to a remote MCP server that contains plugins.
    Two transports are handled here: STDIO (the server is started as a child
    process and driven from a dedicated task) and STREAMABLEHTTP (optionally
    over a Unix domain socket or with a custom TLS configuration).
    """

    def __init__(self, config: PluginConfig) -> None:
        """Initialize a plugin with a configuration and context.

        Args:
            config: The plugin configuration
        """
        super().__init__(config)
        self._session: Optional[ClientSession] = None
        self._exit_stack = AsyncExitStack()
        # A bare annotation (``self._http: Optional[Any]``) creates no attribute,
        # so the stream handles are assigned explicitly.
        self._http: Optional[Any] = None
        self._stdio: Optional[Any] = None
        self._write: Optional[Any] = None
        # asyncio.current_task() raises RuntimeError when no event loop is
        # running; constructing the plugin outside a loop must not fail.
        self._current_task: Optional[asyncio.Task[Any]]
        try:
            self._current_task = asyncio.current_task()
        except RuntimeError:
            self._current_task = None
        self._stdio_exit_stack: Optional[AsyncExitStack] = None
        self._stdio_task: Optional[asyncio.Task[None]] = None
        self._stdio_ready: Optional[asyncio.Event] = None
        self._stdio_stop: Optional[asyncio.Event] = None
        self._stdio_error: Optional[BaseException] = None
        self._get_session_id: Optional[Callable[[], str | None]] = None
        self._session_id: Optional[str] = None
        self._http_client_factory: Optional[Callable[..., httpx.AsyncClient]] = None

    async def initialize(self) -> None:
        """Initialize the plugin's connection to the MCP server.

        Connects over the configured transport, fetches the plugin
        configuration from the remote server and merges it with the local
        configuration; locally set values override the remote ones.

        Raises:
            PluginError: if the mcp section or its transport settings are missing,
                or if unable to retrieve plugin configuration of external plugin.
        """
        mcp_config = self._config.mcp
        if not mcp_config:
            raise PluginError(error=PluginErrorModel(message="The mcp section must be defined for external plugin", plugin_name=self.name))
        if mcp_config.proto == TransportType.STDIO:
            if not (mcp_config.script or mcp_config.cmd):
                raise PluginError(error=PluginErrorModel(message="STDIO transport requires script or cmd", plugin_name=self.name))
            await self.__connect_to_stdio_server(mcp_config.script, mcp_config.cmd, mcp_config.env, mcp_config.cwd)
        elif mcp_config.proto == TransportType.STREAMABLEHTTP:
            if not mcp_config.url:
                raise PluginError(error=PluginErrorModel(message="STREAMABLEHTTP transport requires url", plugin_name=self.name))
            await self.__connect_to_http_server(mcp_config.url)

        try:
            config = await self.__get_plugin_config()

            if not config:
                raise PluginError(error=PluginErrorModel(message="Unable to retrieve configuration for external plugin", plugin_name=self.name))

            current_config = self._config.model_dump(exclude_unset=True)
            remote_config = config.model_dump(exclude_unset=True)
            # Local settings win over what the remote server reports.
            remote_config.update(current_config)

            context = {IGNORE_CONFIG_EXTERNAL: True}

            self._config = PluginConfig.model_validate(remote_config, context=context)
        except Exception as e:
            # The connection is already open at this point; release it before
            # reporting the failure. A failing shutdown must not mask the
            # original error.
            try:
                await self.shutdown()
            except Exception as shutdown_error:
                logger.error("Error during external plugin shutdown after init failure: %s", shutdown_error)
            logger.exception(e)
            if isinstance(e, PluginError):
                raise
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e

    def __resolve_stdio_command(self, script_path: str | None, cmd: list[str] | None, cwd: str | None) -> tuple[str, list[str]]:
        """Resolve the stdio command + args from config.

        An explicit ``cmd`` takes precedence over ``script_path``. A script is
        run with the current interpreter when it has the Python suffix, with
        ``sh`` when it ends in ``.sh``, and directly otherwise (in which case
        it must be executable).

        Args:
            script_path: Path to a server script or executable.
            cmd: Command list to execute (command + args).
            cwd: Working directory for resolving relative script paths.

        Returns:
            Tuple of (command, args).

        Raises:
            PluginError: if the script is invalid or cmd is malformed.
        """
        if cmd:
            if not isinstance(cmd, list) or not all(isinstance(part, str) and part.strip() for part in cmd):
                raise PluginError(error=PluginErrorModel(message="STDIO cmd must be a non-empty list of strings", plugin_name=self.name))
            return cmd[0], cmd[1:]

        if not script_path:
            raise PluginError(error=PluginErrorModel(message="STDIO transport requires script or cmd", plugin_name=self.name))

        server_path = Path(script_path).expanduser()
        if not server_path.is_absolute() and cwd:
            server_path = Path(cwd).expanduser() / server_path
        resolved_script_path = str(server_path)
        if not server_path.is_file():
            raise PluginError(error=PluginErrorModel(message=f"Server script {resolved_script_path} does not exist.", plugin_name=self.name))

        if server_path.suffix == PYTHON_SUFFIX:
            return sys.executable, [resolved_script_path]
        if server_path.suffix == ".sh":
            return "sh", [resolved_script_path]
        if not os.access(server_path, os.X_OK):
            raise PluginError(error=PluginErrorModel(message=f"Server script {resolved_script_path} must be executable.", plugin_name=self.name))
        return resolved_script_path, []

    def __build_stdio_env(self, extra_env: dict[str, str] | None) -> dict[str, str]:
        """Build environment for the stdio server process.

        Args:
            extra_env: Environment overrides to merge into the current process env.

        Returns:
            Combined environment dictionary for the plugin process.
        """
        current_env = os.environ.copy()
        if extra_env:
            current_env.update(extra_env)
        return current_env

    async def __run_stdio_session(self, server_script_path: str | None, cmd: list[str] | None, env: dict[str, str] | None, cwd: str | None) -> None:
        """Run a stdio session in a dedicated task for consistent setup/teardown.

        The session's exit stack is entered and closed inside this coroutine.
        Setup errors are stored in ``self._stdio_error`` rather than raised, and
        ``self._stdio_ready`` is set in every case so the waiting connector is
        released. After a successful setup the coroutine waits on
        ``self._stdio_stop`` before closing the stack.

        Args:
            server_script_path: Path to the server script or executable.
            cmd: Command list to start the server (command + args).
            env: Environment overrides for the server process.
            cwd: Working directory for the server process.
        """
        try:
            command, args = self.__resolve_stdio_command(server_script_path, cmd, cwd)
            server_env = self.__build_stdio_env(env)
            server_params = StdioServerParameters(command=command, args=args, env=server_env, cwd=cwd)

            self._stdio_exit_stack = AsyncExitStack()
            stdio_transport = await self._stdio_exit_stack.enter_async_context(stdio_client(server_params))
            self._stdio, self._write = stdio_transport
            self._session = await self._stdio_exit_stack.enter_async_context(ClientSession(self._stdio, self._write))

            await self._session.initialize()

            response = await self._session.list_tools()
            tools = response.tools
            logger.info("\nConnected to plugin MCP server (stdio) with tools: %s", " ".join([tool.name for tool in tools]))
        except Exception as e:
            self._stdio_error = e
            logger.exception(e)
        finally:
            # Always release the connector, whether setup succeeded or failed.
            if self._stdio_ready and not self._stdio_ready.is_set():
                self._stdio_ready.set()

        if self._stdio_error:
            if self._stdio_exit_stack:
                await self._stdio_exit_stack.aclose()
            return

        # Keep the session open until shutdown() signals the stop event.
        if self._stdio_stop:
            await self._stdio_stop.wait()

        if self._stdio_exit_stack:
            await self._stdio_exit_stack.aclose()

    async def __connect_to_stdio_server(self, server_script_path: str | None, cmd: list[str] | None, env: dict[str, str] | None, cwd: str | None) -> None:
        """Connect to an MCP plugin server via stdio.

        Starts the session task and waits until it reports readiness.

        Args:
            server_script_path: Path to the server script or executable.
            cmd: Command list to start the server (command + args).
            env: Environment overrides for the server process.
            cwd: Working directory for the server process.

        Raises:
            PluginError: if stdio script/cmd is invalid or if there is a connection error.
        """
        try:
            if not self._stdio_ready:
                self._stdio_ready = asyncio.Event()
            if not self._stdio_stop:
                self._stdio_stop = asyncio.Event()
            self._stdio_error = None

            self._stdio_task = asyncio.create_task(
                self.__run_stdio_session(server_script_path, cmd, env, cwd),
                name=f"external-plugin-stdio-{self.name}",
            )

            await self._stdio_ready.wait()
            if self._stdio_error:
                raise PluginError(error=convert_exception_to_error(self._stdio_error, plugin_name=self.name))
        except PluginError:
            raise
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e

    async def __connect_to_http_server(self, uri: str) -> None:
        """Connect to an MCP plugin server via streamable http with retry logic.

        Up to three attempts are made, sleeping ``base_delay * 2**attempt``
        seconds between them. Every failed attempt, including the last one,
        releases whatever it managed to open before retrying or raising.

        Args:
            uri: the URI of the mcp plugin server.

        Raises:
            PluginError: if there is an external connection error after all retries.
        """
        plugin_tls = self._config.mcp.tls if self._config and self._config.mcp else None
        uds_path = self._config.mcp.uds if self._config and self._config.mcp else None
        if uds_path and plugin_tls:
            logger.warning("TLS configuration is ignored for Unix domain socket connections.")
        tls_config = None if uds_path else (plugin_tls or MCPClientTLSConfig.from_env())

        def _tls_httpx_client_factory(
            headers: Optional[dict[str, str]] = None,
            timeout: Optional[httpx.Timeout] = None,
            auth: Optional[httpx.Auth] = None,
        ) -> httpx.AsyncClient:
            """Build an httpx client with TLS configuration for external MCP servers.

            Args:
                headers: Optional HTTP headers to include in requests.
                timeout: Optional timeout configuration for HTTP requests.
                auth: Optional authentication handler for HTTP requests.

            Returns:
                Configured httpx AsyncClient with TLS settings applied.

            Raises:
                PluginError: If TLS configuration fails.
            """
            kwargs: dict[str, Any] = {"follow_redirects": True}
            if uds_path:
                kwargs["transport"] = httpx.AsyncHTTPTransport(uds=uds_path)
            if headers:
                kwargs["headers"] = headers
            http_settings = get_http_client_settings()
            kwargs["timeout"] = (
                timeout
                if timeout
                else httpx.Timeout(
                    connect=http_settings.httpx_connect_timeout,
                    read=http_settings.httpx_read_timeout,
                    write=http_settings.httpx_write_timeout,
                    pool=http_settings.httpx_pool_timeout,
                )
            )
            if auth is not None:
                kwargs["auth"] = auth

            # Add connection pool limits
            kwargs["limits"] = httpx.Limits(
                max_connections=http_settings.httpx_max_connections,
                max_keepalive_connections=http_settings.httpx_max_keepalive_connections,
                keepalive_expiry=http_settings.httpx_keepalive_expiry,
            )

            if not tls_config:
                # Use skip_ssl_verify setting when no custom TLS config
                kwargs["verify"] = not http_settings.skip_ssl_verify
                return httpx.AsyncClient(**kwargs)

            # Create SSL context using the utility function
            # This implements certificate validation per test_client_certificate_validation.py
            ssl_context = create_ssl_context(tls_config, self.name)
            kwargs["verify"] = ssl_context

            return httpx.AsyncClient(**kwargs)

        max_retries = 3
        base_delay = 1.0

        for attempt in range(max_retries):
            try:
                # shutdown() between attempts clears the stored factory, so it is
                # registered again on every attempt; otherwise the explicit
                # session termination would later lose the TLS/UDS settings.
                self._http_client_factory = _tls_httpx_client_factory
                streamable_client = streamablehttp_client(uri, httpx_client_factory=_tls_httpx_client_factory, terminate_on_close=False)
                http_transport = await self._exit_stack.enter_async_context(streamable_client)
                self._http, self._write, get_session_id = http_transport
                self._get_session_id = get_session_id
                self._session = await self._exit_stack.enter_async_context(ClientSession(self._http, self._write))

                await self._session.initialize()
                self._session_id = self._get_session_id() if self._get_session_id else None
                response = await self._session.list_tools()
                tools = response.tools
                logger.info(
                    "Successfully connected to plugin MCP server with tools: %s",
                    " ".join([tool.name for tool in tools]),
                )
                return
            except Exception as e:
                logger.warning("Connection attempt %d/%d failed: %s", attempt + 1, max_retries, e)
                # Release anything the failed attempt opened. A failing cleanup
                # is logged rather than raised so it can neither abort the
                # retries nor replace the connection error.
                try:
                    await self.shutdown()
                except Exception as cleanup_error:
                    logger.warning("Cleanup after failed connection attempt %d/%d raised: %s", attempt + 1, max_retries, cleanup_error)
                self._exit_stack = AsyncExitStack()
                if attempt == max_retries - 1:
                    # Final attempt failed
                    target = f"{uri} (uds={uds_path})" if uds_path else uri
                    error_msg = f"External plugin '{self.name}' connection failed after {max_retries} attempts: {target} is not reachable. Please ensure the MCP server is running."
                    logger.error(error_msg)
                    raise PluginError(error=PluginErrorModel(message=error_msg, plugin_name=self.name)) from e
                # Wait before retry
                delay = base_delay * (2**attempt)
                logger.info("Retrying in %ss...", delay)
                await asyncio.sleep(delay)

    async def invoke_hook(self, hook_type: str, payload: PluginPayload, context: PluginContext) -> PluginResult:
        """Invoke an external plugin hook using the MCP protocol.

        The first text content item that carries a result or an error decides
        the outcome. Context returned by the server is copied back onto the
        caller's ``context`` before the result is returned.

        Args:
            hook_type: The type of hook invoked (i.e., prompt_pre_fetch)
            payload: The payload to be passed to the hook.
            context: The plugin context passed to the run.

        Raises:
            PluginError: error passed from external plugin server, or raised when
                the hook type is unregistered, the session is missing, or the
                response cannot be decoded or contains neither result nor error.

        Returns:
            The resulting payload from the plugin.
        """
        # Get the result type from the global registry
        registry = get_hook_registry()
        result_type = registry.get_result_type(hook_type)
        if not result_type:
            raise PluginError(error=PluginErrorModel(message=f"Hook type '{hook_type}' not registered in hook registry", plugin_name=self.name))

        if not self._session:
            raise PluginError(error=PluginErrorModel(message="Plugin session not initialized", plugin_name=self.name))

        try:
            result = await self._session.call_tool(INVOKE_HOOK, {HOOK_TYPE: hook_type, PLUGIN_NAME: self.name, PAYLOAD: payload, CONTEXT: context})
            for content in result.content:
                if not isinstance(content, TextContent):
                    continue
                try:
                    res = orjson.loads(content.text)
                except orjson.JSONDecodeError as decode_error:
                    raise PluginError(error=PluginErrorModel(message=f"Error trying to decode json: {content.text}", code="JSON_DECODE_ERROR", plugin_name=self.name)) from decode_error
                if CONTEXT in res:
                    cxt = PluginContext.model_validate(res[CONTEXT])
                    context.state = cxt.state
                    context.metadata = cxt.metadata
                    context.global_context.state = cxt.global_context.state
                if RESULT in res:
                    return result_type.model_validate(res[RESULT])
                if ERROR in res:
                    error = PluginErrorModel.model_validate(res[ERROR])
                    raise PluginError(error)
        except PluginError as pe:
            logger.exception(pe)
            raise
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e
        # No content item carried a result or an error.
        raise PluginError(error=PluginErrorModel(message=f"Received invalid response. Result = {result}", plugin_name=self.name))

    async def __get_plugin_config(self) -> PluginConfig | None:
        """Retrieve plugin configuration for the current plugin on the remote MCP server.

        Only the first text content item of the response is considered.

        Raises:
            PluginError: if there is a connection issue or validation issue.

        Returns:
            A plugin configuration for the current plugin from a remote MCP server,
            or None when the server returns an empty configuration or no text content.
        """
        if not self._session:
            raise PluginError(error=PluginErrorModel(message="Plugin session not initialized", plugin_name=self.name))
        try:
            configs = await self._session.call_tool(GET_PLUGIN_CONFIG, {NAME: self.name})
            for content in configs.content:
                if not isinstance(content, TextContent):
                    continue
                conf = orjson.loads(content.text)
                if not conf:
                    return None
                return PluginConfig.model_validate(conf)
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e

        return None

    async def shutdown(self) -> None:
        """Plugin cleanup code.

        Stops the stdio session task if one is running, closes the exit stack,
        and for the streamable HTTP transport explicitly terminates the remote
        session. Connection state is reset even when closing the exit stack
        raises; that exception still propagates to the caller.
        """
        mcp_config = self._config.mcp if self._config else None
        proto = mcp_config.proto if mcp_config else None

        if self._stdio_task:
            if self._stdio_stop:
                self._stdio_stop.set()
            try:
                await self._stdio_task
            except Exception as e:
                logger.error("Error shutting down stdio session for plugin %s: %s", self.name, e)
            self._stdio_task = None
            self._stdio_ready = None
            self._stdio_stop = None
            self._stdio_exit_stack = None
            self._stdio_error = None
            self._stdio = None
            self._write = None
            if proto == TransportType.STDIO:
                self._session = None

        try:
            if self._exit_stack:
                await self._exit_stack.aclose()
        finally:
            if proto == TransportType.STREAMABLEHTTP:
                await self.__terminate_http_session()
                # The session and its streams lived on the exit stack closed
                # above; drop the references so later calls fail with a clear
                # "session not initialized" error.
                self._session = None
                self._http = None
                self._write = None
            self._get_session_id = None
            self._session_id = None
            self._http_client_factory = None

    async def __terminate_http_session(self) -> None:
        """Terminate streamable HTTP session explicitly to avoid lingering server state.

        Sends an HTTP DELETE carrying the session id header to the configured
        URL. Does nothing when no session id or URL is known. Failures are
        logged at debug level and never raised.
        """
        if not self._session_id or not self._config or not self._config.mcp or not self._config.mcp.url:
            return
        # Third-Party
        from mcp.server.streamable_http import MCP_SESSION_ID_HEADER  # pylint: disable=import-outside-toplevel

        client_factory = self._http_client_factory
        try:
            if client_factory:
                client = client_factory()
            else:
                client = httpx.AsyncClient(follow_redirects=True)
            async with client:
                headers = {MCP_SESSION_ID_HEADER: self._session_id}
                await client.delete(self._config.mcp.url, headers=headers)
        except Exception as exc:
            logger.debug("Failed to terminate streamable HTTP session: %s", exc)
class ExternalHookRef(HookRef):
    """A Hook reference point for external plugins."""

    def __init__(self, hook: str, plugin_ref: PluginRef):  # pylint: disable=super-init-not-called
        """Initialize a hook reference point for an external plugin.

        Note: We intentionally don't call super().__init__() because external plugins
        use invoke_hook() rather than direct method attributes.

        Args:
            hook: name of the hook point.
            plugin_ref: The reference to the plugin to hook.

        Raises:
            PluginError: If the plugin is not an external plugin.
        """
        self._plugin_ref = plugin_ref
        self._hook = hook
        # A plugin counts as external when it exposes the attribute named by INVOKE_HOOK.
        if not hasattr(plugin_ref.plugin, INVOKE_HOOK):
            raise PluginError(error=PluginErrorModel(message=f"Plugin: {plugin_ref.plugin.name} is not an external plugin", plugin_name=plugin_ref.plugin.name))
        # Bind the hook name so callers only supply payload and context.
        self._func: Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]] = partial(plugin_ref.plugin.invoke_hook, hook)  # type: ignore[attr-defined]
495 raise PluginError(error=PluginErrorModel(message=f"Plugin: {plugin_ref.plugin.name} is not an external plugin", plugin_name=plugin_ref.plugin.name))