Coverage for mcpgateway / plugins / framework / external / mcp / client.py: 100%
286 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/plugins/framework/external/mcp/client.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Teryl Taylor
7External plugin client which connects to a remote server through MCP.
8Module that contains plugin MCP client code to serve external plugins.
9"""
11# Standard
12import asyncio
13from contextlib import AsyncExitStack
14from functools import partial
15import logging
16import os
17from pathlib import Path
18import sys
19from typing import Any, Awaitable, Callable, Optional
21# Third-Party
22import httpx
23from mcp import ClientSession, StdioServerParameters
24from mcp.client.stdio import stdio_client
25from mcp.client.streamable_http import streamablehttp_client
26from mcp.types import TextContent
27import orjson
29# First-Party
30from mcpgateway.common.models import TransportType
31from mcpgateway.config import settings
32from mcpgateway.plugins.framework.base import HookRef, Plugin, PluginRef
33from mcpgateway.plugins.framework.constants import CONTEXT, ERROR, GET_PLUGIN_CONFIG, HOOK_TYPE, IGNORE_CONFIG_EXTERNAL, INVOKE_HOOK, NAME, PAYLOAD, PLUGIN_NAME, PYTHON_SUFFIX, RESULT
34from mcpgateway.plugins.framework.errors import convert_exception_to_error, PluginError
35from mcpgateway.plugins.framework.external.mcp.tls_utils import create_ssl_context
36from mcpgateway.plugins.framework.hooks.registry import get_hook_registry
37from mcpgateway.plugins.framework.models import MCPClientTLSConfig, PluginConfig, PluginContext, PluginErrorModel, PluginPayload, PluginResult
# Module-level logger used by the external MCP plugin client code in this module.
logger: logging.Logger = logging.getLogger(__name__)
class ExternalPlugin(Plugin):
    """External plugin object for pre/post processing of inputs and outputs at various locations throughout the gateway.

    The External Plugin connects to a remote MCP server that contains plugins.
    It supports two transports:

    - STDIO: the server is spawned as a child process, and its session is
      driven from a dedicated asyncio task so setup and teardown run in the
      same task.
    - STREAMABLEHTTP: the session is held in ``self._exit_stack``. It can
      optionally run over a Unix domain socket or with a TLS configuration.
    """

    def __init__(self, config: PluginConfig) -> None:
        """Initialize a plugin with a configuration and context.

        Args:
            config: The plugin configuration
        """
        super().__init__(config)
        self._session: Optional[ClientSession] = None
        # Holds the streamable HTTP transport and its ClientSession.
        self._exit_stack = AsyncExitStack()
        # Transport streams; populated once a connection is established.
        self._http: Optional[Any] = None
        self._stdio: Optional[Any] = None
        self._write: Optional[Any] = None
        try:
            self._current_task: Optional[asyncio.Task[Any]] = asyncio.current_task()
        except RuntimeError:
            # asyncio.current_task() raises when there is no running event loop
            # (e.g. the plugin is constructed from synchronous code).
            self._current_task = None
        self._stdio_exit_stack: Optional[AsyncExitStack] = None
        self._stdio_task: Optional[asyncio.Task[None]] = None
        self._stdio_ready: Optional[asyncio.Event] = None
        self._stdio_stop: Optional[asyncio.Event] = None
        self._stdio_error: Optional[BaseException] = None
        self._get_session_id: Optional[Callable[[], str | None]] = None
        self._session_id: Optional[str] = None
        self._http_client_factory: Optional[Callable[..., httpx.AsyncClient]] = None

    async def initialize(self) -> None:
        """Initialize the plugin's connection to the MCP server.

        Connects over the configured transport, then fetches the remote plugin
        configuration. The remote configuration is merged with the local one;
        locally set values win.

        Raises:
            PluginError: if unable to retrieve plugin configuration of external plugin.
        """
        if not self._config.mcp:
            raise PluginError(error=PluginErrorModel(message="The mcp section must be defined for external plugin", plugin_name=self.name))
        if self._config.mcp.proto == TransportType.STDIO:
            if not (self._config.mcp.script or self._config.mcp.cmd):
                raise PluginError(error=PluginErrorModel(message="STDIO transport requires script or cmd", plugin_name=self.name))
            await self.__connect_to_stdio_server(self._config.mcp.script, self._config.mcp.cmd, self._config.mcp.env, self._config.mcp.cwd)
        elif self._config.mcp.proto == TransportType.STREAMABLEHTTP:
            if not self._config.mcp.url:
                raise PluginError(error=PluginErrorModel(message="STREAMABLEHTTP transport requires url", plugin_name=self.name))
            await self.__connect_to_http_server(self._config.mcp.url)

        try:
            config = await self.__get_plugin_config()

            if not config:
                raise PluginError(error=PluginErrorModel(message="Unable to retrieve configuration for external plugin", plugin_name=self.name))

            current_config = self._config.model_dump(exclude_unset=True)
            remote_config = config.model_dump(exclude_unset=True)
            # Locally set values override what the remote server reports.
            remote_config.update(current_config)

            context = {IGNORE_CONFIG_EXTERNAL: True}

            self._config = PluginConfig.model_validate(remote_config, context=context)
        except PluginError as pe:
            await self.__shutdown_quietly()
            logger.exception(pe)
            raise
        except Exception as e:
            await self.__shutdown_quietly()
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e

    async def __shutdown_quietly(self) -> None:
        """Run ``shutdown()`` as best-effort cleanup after a failure, logging instead of raising."""
        try:
            await self.shutdown()
        except Exception as shutdown_error:
            logger.error("Error during external plugin shutdown after init failure: %s", shutdown_error)

    def __resolve_stdio_command(self, script_path: str | None, cmd: list[str] | None, cwd: str | None) -> tuple[str, list[str]]:
        """Resolve the stdio command + args from config.

        ``cmd`` takes precedence over ``script_path``. A ``.py`` script runs
        under the current interpreter, and a ``.sh`` script runs under ``sh``.
        Any other script must be executable.

        Args:
            script_path: Path to a server script or executable.
            cmd: Command list to execute (command + args).
            cwd: Working directory for resolving relative script paths.

        Returns:
            Tuple of (command, args).

        Raises:
            PluginError: if the script is invalid or cmd is malformed.
        """
        if cmd:
            if not isinstance(cmd, list) or not all(isinstance(part, str) and part.strip() for part in cmd):
                raise PluginError(error=PluginErrorModel(message="STDIO cmd must be a non-empty list of strings", plugin_name=self.name))
            return cmd[0], cmd[1:]

        if not script_path:
            raise PluginError(error=PluginErrorModel(message="STDIO transport requires script or cmd", plugin_name=self.name))

        server_path = Path(script_path).expanduser()
        if not server_path.is_absolute() and cwd:
            server_path = Path(cwd).expanduser() / server_path
        resolved_script_path = str(server_path)
        if not server_path.is_file():
            raise PluginError(error=PluginErrorModel(message=f"Server script {resolved_script_path} does not exist.", plugin_name=self.name))

        if server_path.suffix == PYTHON_SUFFIX:
            return sys.executable, [resolved_script_path]
        if server_path.suffix == ".sh":
            return "sh", [resolved_script_path]
        if not os.access(server_path, os.X_OK):
            raise PluginError(error=PluginErrorModel(message=f"Server script {resolved_script_path} must be executable.", plugin_name=self.name))
        return resolved_script_path, []

    def __build_stdio_env(self, extra_env: dict[str, str] | None) -> dict[str, str]:
        """Build environment for the stdio server process.

        Args:
            extra_env: Environment overrides to merge into the current process env.

        Returns:
            Combined environment dictionary for the plugin process.
        """
        current_env = os.environ.copy()
        if extra_env:
            current_env.update(extra_env)
        return current_env

    async def __run_stdio_session(self, server_script_path: str | None, cmd: list[str] | None, env: dict[str, str] | None, cwd: str | None) -> None:
        """Run a stdio session in a dedicated task for consistent setup/teardown.

        Sets ``self._stdio_ready`` once setup has finished, whether it
        succeeded or failed. Setup failures are stored in
        ``self._stdio_error``. On success, the task waits on
        ``self._stdio_stop`` before tearing down. The session's exit stack is
        always closed in this task, including when the task is cancelled.

        Args:
            server_script_path: Path to the server script or executable.
            cmd: Command list to start the server (command + args).
            env: Environment overrides for the server process.
            cwd: Working directory for the server process.
        """
        exit_stack = AsyncExitStack()
        self._stdio_exit_stack = exit_stack
        try:
            try:
                command, args = self.__resolve_stdio_command(server_script_path, cmd, cwd)
                server_env = self.__build_stdio_env(env)
                # Expand "~" the same way __resolve_stdio_command does, so the process
                # runs in the directory the script path was resolved against.
                process_cwd = os.path.expanduser(cwd) if cwd else cwd
                server_params = StdioServerParameters(command=command, args=args, env=server_env, cwd=process_cwd)

                self._stdio, self._write = await exit_stack.enter_async_context(stdio_client(server_params))
                self._session = await exit_stack.enter_async_context(ClientSession(self._stdio, self._write))

                await self._session.initialize()

                response = await self._session.list_tools()
                tools = response.tools
                logger.info("\nConnected to plugin MCP server (stdio) with tools: %s", " ".join([tool.name for tool in tools]))
            except Exception as e:
                self._stdio_error = e
                logger.exception(e)
            finally:
                # Always release the waiter in __connect_to_stdio_server.
                if self._stdio_ready and not self._stdio_ready.is_set():
                    self._stdio_ready.set()

            if self._stdio_error is None and self._stdio_stop:
                await self._stdio_stop.wait()
        finally:
            if self._stdio_error is not None:
                # Setup failed: don't leave handles to soon-to-be-closed streams behind.
                self._session = None
                self._stdio = None
                self._write = None
            await exit_stack.aclose()

    async def __connect_to_stdio_server(self, server_script_path: str | None, cmd: list[str] | None, env: dict[str, str] | None, cwd: str | None) -> None:
        """Connect to an MCP plugin server via stdio.

        Starts ``__run_stdio_session`` in its own task and waits until that
        task reports setup completion.

        Args:
            server_script_path: Path to the server script or executable.
            cmd: Command list to start the server (command + args).
            env: Environment overrides for the server process.
            cwd: Working directory for the server process.

        Raises:
            PluginError: if stdio script/cmd is invalid or if there is a connection error.
        """
        try:
            if not self._stdio_ready:
                self._stdio_ready = asyncio.Event()
            if not self._stdio_stop:
                self._stdio_stop = asyncio.Event()
            self._stdio_error = None

            self._stdio_task = asyncio.create_task(
                self.__run_stdio_session(server_script_path, cmd, env, cwd),
                name=f"external-plugin-stdio-{self.name}",
            )

            await self._stdio_ready.wait()
            if self._stdio_error is not None:
                raise PluginError(error=convert_exception_to_error(self._stdio_error, plugin_name=self.name)) from self._stdio_error
        except PluginError:
            raise
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e

    async def __connect_to_http_server(self, uri: str) -> None:
        """Connect to an MCP plugin server via streamable http with retry logic.

        Makes up to three attempts with exponential backoff (1s, then 2s).
        Anything a failed attempt opened is released before the next attempt
        or before giving up.

        Args:
            uri: the URI of the mcp plugin server.

        Raises:
            PluginError: if there is an external connection error after all retries.
        """
        plugin_tls = self._config.mcp.tls if self._config and self._config.mcp else None
        uds_path = self._config.mcp.uds if self._config and self._config.mcp else None
        if uds_path and plugin_tls:
            logger.warning("TLS configuration is ignored for Unix domain socket connections.")
        tls_config = None if uds_path else (plugin_tls or MCPClientTLSConfig.from_env())

        def _tls_httpx_client_factory(
            headers: Optional[dict[str, str]] = None,
            timeout: Optional[httpx.Timeout] = None,
            auth: Optional[httpx.Auth] = None,
        ) -> httpx.AsyncClient:
            """Build an httpx client with TLS configuration for external MCP servers.

            Args:
                headers: Optional HTTP headers to include in requests.
                timeout: Optional timeout configuration for HTTP requests.
                auth: Optional authentication handler for HTTP requests.

            Returns:
                Configured httpx AsyncClient with TLS settings applied.

            Raises:
                PluginError: If TLS configuration fails.
            """

            # First-Party
            from mcpgateway.services.http_client_service import get_default_verify, get_http_timeout  # pylint: disable=import-outside-toplevel

            kwargs: dict[str, Any] = {"follow_redirects": True}
            if uds_path:
                kwargs["transport"] = httpx.AsyncHTTPTransport(uds=uds_path)
            if headers:
                kwargs["headers"] = headers
            kwargs["timeout"] = timeout if timeout else get_http_timeout()
            if auth is not None:
                kwargs["auth"] = auth

            # Add connection pool limits
            kwargs["limits"] = httpx.Limits(
                max_connections=settings.httpx_max_connections,
                max_keepalive_connections=settings.httpx_max_keepalive_connections,
                keepalive_expiry=settings.httpx_keepalive_expiry,
            )

            if not tls_config:
                # Use skip_ssl_verify setting when no custom TLS config
                kwargs["verify"] = get_default_verify()
                return httpx.AsyncClient(**kwargs)

            # Create SSL context using the utility function
            # This implements certificate validation per test_client_certificate_validation.py
            ssl_context = create_ssl_context(tls_config, self.name)
            kwargs["verify"] = ssl_context

            return httpx.AsyncClient(**kwargs)

        max_retries = 3
        base_delay = 1.0

        for attempt in range(max_retries):
            # shutdown() clears the stored factory, so register it again on every attempt.
            # __terminate_http_session reuses it to keep the TLS/UDS settings.
            self._http_client_factory = _tls_httpx_client_factory
            try:
                streamable_client = streamablehttp_client(uri, httpx_client_factory=_tls_httpx_client_factory, terminate_on_close=False)
                self._http, self._write, get_session_id = await self._exit_stack.enter_async_context(streamable_client)
                self._get_session_id = get_session_id
                self._session = await self._exit_stack.enter_async_context(ClientSession(self._http, self._write))

                await self._session.initialize()
                self._session_id = self._get_session_id() if self._get_session_id else None
                response = await self._session.list_tools()
                tools = response.tools
                logger.info(
                    "Successfully connected to plugin MCP server with tools: %s",
                    " ".join([tool.name for tool in tools]),
                )
                return
            except Exception as e:
                logger.warning("Connection attempt %d/%d failed: %s", attempt + 1, max_retries, e)
                # Release whatever this attempt opened, including on the final attempt.
                await self.__shutdown_quietly()
                if attempt == max_retries - 1:
                    # Final attempt failed
                    target = f"{uri} (uds={uds_path})" if uds_path else uri
                    error_msg = f"External plugin '{self.name}' connection failed after {max_retries} attempts: {target} is not reachable. Please ensure the MCP server is running."
                    logger.error(error_msg)
                    raise PluginError(error=PluginErrorModel(message=error_msg, plugin_name=self.name)) from e
                # Wait before retry
                delay = base_delay * (2**attempt)
                logger.info("Retrying in %ss...", delay)
                await asyncio.sleep(delay)

    async def invoke_hook(self, hook_type: str, payload: PluginPayload, context: PluginContext) -> PluginResult:
        """Invoke an external plugin hook using the MCP protocol.

        Any context returned by the server is copied back onto ``context``:
        ``state``, ``metadata`` and ``global_context.state``.

        Args:
            hook_type: The type of hook invoked (i.e., prompt_pre_fetch)
            payload: The payload to be passed to the hook.
            context: The plugin context passed to the run.

        Raises:
            PluginError: error passed from external plugin server.

        Returns:
            The resulting payload from the plugin.
        """
        # Get the result type from the global registry
        registry = get_hook_registry()
        result_type = registry.get_result_type(hook_type)
        if not result_type:
            raise PluginError(error=PluginErrorModel(message=f"Hook type '{hook_type}' not registered in hook registry", plugin_name=self.name))

        if not self._session:
            raise PluginError(error=PluginErrorModel(message="Plugin session not initialized", plugin_name=self.name))

        try:
            result = await self._session.call_tool(INVOKE_HOOK, {HOOK_TYPE: hook_type, PLUGIN_NAME: self.name, PAYLOAD: payload, CONTEXT: context})
            for content in result.content:
                if not isinstance(content, TextContent):
                    continue
                try:
                    res = orjson.loads(content.text)
                except orjson.JSONDecodeError as decode_error:
                    raise PluginError(error=PluginErrorModel(message=f"Error trying to decode json: {content.text}", code="JSON_DECODE_ERROR", plugin_name=self.name)) from decode_error
                if CONTEXT in res:
                    cxt = PluginContext.model_validate(res[CONTEXT])
                    context.state = cxt.state
                    context.metadata = cxt.metadata
                    context.global_context.state = cxt.global_context.state
                if RESULT in res:
                    return result_type.model_validate(res[RESULT])
                if ERROR in res:
                    error = PluginErrorModel.model_validate(res[ERROR])
                    raise PluginError(error=error)
        except PluginError as pe:
            logger.exception(pe)
            raise
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e
        raise PluginError(error=PluginErrorModel(message=f"Received invalid response. Result = {result}", plugin_name=self.name))

    async def __get_plugin_config(self) -> PluginConfig | None:
        """Retrieve plugin configuration for the current plugin on the remote MCP server.

        Raises:
            PluginError: if there is a connection issue or validation issue.

        Returns:
            A plugin configuration for the current plugin from a remote MCP server.
        """
        if not self._session:
            raise PluginError(error=PluginErrorModel(message="Plugin session not initialized", plugin_name=self.name))
        try:
            configs = await self._session.call_tool(GET_PLUGIN_CONFIG, {NAME: self.name})
            for content in configs.content:
                if not isinstance(content, TextContent):
                    continue
                conf = orjson.loads(content.text)
                if not conf:
                    return None
                return PluginConfig.model_validate(conf)
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e

        return None

    async def shutdown(self) -> None:
        """Plugin cleanup code.

        Stops the stdio session task (if any) and closes the HTTP exit stack.
        For streamable HTTP, it also sends an explicit session termination
        request and clears the connection handles, so the plugin reads as
        disconnected afterwards. If closing the exit stack raises, the
        exception is re-raised after the HTTP state has been reset.
        """
        mcp_config = self._config.mcp if self._config else None
        proto = mcp_config.proto if mcp_config else None

        if self._stdio_task:
            if self._stdio_stop:
                self._stdio_stop.set()
            try:
                await self._stdio_task
            except Exception as e:
                logger.error("Error shutting down stdio session for plugin %s: %s", self.name, e)
            self._stdio_task = None
            self._stdio_ready = None
            self._stdio_stop = None
            self._stdio_exit_stack = None
            self._stdio_error = None
            self._stdio = None
            self._write = None
            if proto == TransportType.STDIO:
                self._session = None

        try:
            if self._exit_stack:
                await self._exit_stack.aclose()
        finally:
            if proto == TransportType.STREAMABLEHTTP:
                # The exit stack owned the session and streams; drop the closed handles.
                self._session = None
                self._http = None
                self._write = None
                await self.__terminate_http_session()
                self._get_session_id = None
                self._session_id = None
                self._http_client_factory = None

    async def __terminate_http_session(self) -> None:
        """Terminate streamable HTTP session explicitly to avoid lingering server state.

        Best effort: sends a DELETE carrying the session id header, and only
        logs (at debug level) if the request fails.
        """
        if not self._session_id or not self._config or not self._config.mcp or not self._config.mcp.url:
            return
        # Third-Party
        from mcp.server.streamable_http import MCP_SESSION_ID_HEADER  # pylint: disable=import-outside-toplevel

        client_factory = self._http_client_factory
        try:
            if client_factory:
                client = client_factory()
            else:
                client = httpx.AsyncClient(follow_redirects=True)
            async with client:
                headers = {MCP_SESSION_ID_HEADER: self._session_id}
                await client.delete(self._config.mcp.url, headers=headers)
        except Exception as exc:
            logger.debug("Failed to terminate streamable HTTP session: %s", exc)
class ExternalHookRef(HookRef):
    """Hook reference that routes a named hook to an external plugin's ``invoke_hook``."""

    def __init__(self, hook: str, plugin_ref: PluginRef):  # pylint: disable=super-init-not-called
        """Bind a hook name to an external plugin reference.

        Note: ``super().__init__()`` is deliberately skipped. External plugins
        expose a single ``invoke_hook`` entry point instead of one method per
        hook, so the hook callable is built here with ``functools.partial``.

        Args:
            hook: name of the hook point.
            plugin_ref: The reference to the plugin to hook.

        Raises:
            PluginError: If the plugin is not an external plugin.
        """
        self._plugin_ref = plugin_ref
        self._hook = hook
        target_plugin = plugin_ref.plugin
        # Guard clause: only plugins exposing invoke_hook can be driven this way.
        if not hasattr(target_plugin, INVOKE_HOOK):
            plugin_name = target_plugin.name
            raise PluginError(error=PluginErrorModel(message=f"Plugin: {plugin_name} is not an external plugin", plugin_name=plugin_name))
        bound_invoke = partial(target_plugin.invoke_hook, hook)  # type: ignore[attr-defined]
        self._func: Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]] = bound_invoke