Coverage for mcpgateway / plugins / framework / external / mcp / client.py: 100%

286 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/plugins/framework/external/mcp/client.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Teryl Taylor 

6 

7External plugin client which connects to a remote server through MCP. 

8Module that contains plugin MCP client code to serve external plugins. 

9""" 

10 

11# Standard 

12import asyncio 

13from contextlib import AsyncExitStack 

14from functools import partial 

15import logging 

16import os 

17from pathlib import Path 

18import sys 

19from typing import Any, Awaitable, Callable, Optional 

20 

21# Third-Party 

22import httpx 

23from mcp import ClientSession, StdioServerParameters 

24from mcp.client.stdio import stdio_client 

25from mcp.client.streamable_http import streamablehttp_client 

26from mcp.types import TextContent 

27import orjson 

28 

29# First-Party 

30from mcpgateway.common.models import TransportType 

31from mcpgateway.config import settings 

32from mcpgateway.plugins.framework.base import HookRef, Plugin, PluginRef 

33from mcpgateway.plugins.framework.constants import CONTEXT, ERROR, GET_PLUGIN_CONFIG, HOOK_TYPE, IGNORE_CONFIG_EXTERNAL, INVOKE_HOOK, NAME, PAYLOAD, PLUGIN_NAME, PYTHON_SUFFIX, RESULT 

34from mcpgateway.plugins.framework.errors import convert_exception_to_error, PluginError 

35from mcpgateway.plugins.framework.external.mcp.tls_utils import create_ssl_context 

36from mcpgateway.plugins.framework.hooks.registry import get_hook_registry 

37from mcpgateway.plugins.framework.models import MCPClientTLSConfig, PluginConfig, PluginContext, PluginErrorModel, PluginPayload, PluginResult 

38 

39logger = logging.getLogger(__name__) 

40 

41 

class ExternalPlugin(Plugin):
    """External plugin object for pre/post processing of inputs and outputs at various locations throughout the gateway.

    The External Plugin connects to a remote MCP server that contains plugins.
    Two transports are handled: STDIO (a server process started from a script
    or command) and STREAMABLEHTTP (a server reached by URL, optionally over a
    Unix domain socket).
    """

    def __init__(self, config: PluginConfig) -> None:
        """Initialize a plugin with a configuration and context.

        Args:
            config: The plugin configuration
        """
        super().__init__(config)
        self._session: Optional[ClientSession] = None
        self._exit_stack = AsyncExitStack()
        # Transport stream handles. Assigned (not just annotated) so the
        # attributes exist before any connection has been made.
        self._http: Optional[Any] = None
        self._stdio: Optional[Any] = None
        self._write: Optional[Any] = None
        # asyncio.current_task() raises RuntimeError when no event loop is
        # running; allow the plugin to be constructed from synchronous code.
        self._current_task: Optional[asyncio.Task[Any]]
        try:
            self._current_task = asyncio.current_task()
        except RuntimeError:
            self._current_task = None
        self._stdio_exit_stack: Optional[AsyncExitStack] = None
        self._stdio_task: Optional[asyncio.Task[None]] = None
        self._stdio_ready: Optional[asyncio.Event] = None
        self._stdio_stop: Optional[asyncio.Event] = None
        self._stdio_error: Optional[BaseException] = None
        self._get_session_id: Optional[Callable[[], str | None]] = None
        self._session_id: Optional[str] = None
        self._http_client_factory: Optional[Callable[..., httpx.AsyncClient]] = None

    async def _shutdown_after_failed_init(self) -> None:
        """Best-effort shutdown used when initialization fails; shutdown errors are logged, not raised."""
        try:
            await self.shutdown()
        except Exception as shutdown_error:
            logger.error("Error during external plugin shutdown after init failure: %s", shutdown_error)

    async def initialize(self) -> None:
        """Initialize the plugin's connection to the MCP server.

        Connects over the configured transport, fetches the plugin
        configuration from the remote server and merges it with the local
        configuration. Locally set values take precedence over remote ones.

        Raises:
            PluginError: if the mcp section or required transport settings are missing,
                if the connection fails, or if unable to retrieve plugin configuration
                of external plugin.
        """

        if not self._config.mcp:
            raise PluginError(error=PluginErrorModel(message="The mcp section must be defined for external plugin", plugin_name=self.name))
        if self._config.mcp.proto == TransportType.STDIO:
            if not (self._config.mcp.script or self._config.mcp.cmd):
                raise PluginError(error=PluginErrorModel(message="STDIO transport requires script or cmd", plugin_name=self.name))
            await self.__connect_to_stdio_server(self._config.mcp.script, self._config.mcp.cmd, self._config.mcp.env, self._config.mcp.cwd)
        elif self._config.mcp.proto == TransportType.STREAMABLEHTTP:
            if not self._config.mcp.url:
                raise PluginError(error=PluginErrorModel(message="STREAMABLEHTTP transport requires url", plugin_name=self.name))
            await self.__connect_to_http_server(self._config.mcp.url)

        try:
            config = await self.__get_plugin_config()

            if not config:
                raise PluginError(error=PluginErrorModel(message="Unable to retrieve configuration for external plugin", plugin_name=self.name))

            # Remote values form the base; explicitly set local values override them.
            current_config = self._config.model_dump(exclude_unset=True)
            remote_config = config.model_dump(exclude_unset=True)
            remote_config.update(current_config)

            context = {IGNORE_CONFIG_EXTERNAL: True}

            self._config = PluginConfig.model_validate(remote_config, context=context)
        except PluginError as pe:
            await self._shutdown_after_failed_init()
            logger.exception(pe)
            raise
        except Exception as e:
            await self._shutdown_after_failed_init()
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e

    def __resolve_stdio_command(self, script_path: str | None, cmd: list[str] | None, cwd: str | None) -> tuple[str, list[str]]:
        """Resolve the stdio command + args from config.

        A non-empty ``cmd`` takes precedence over ``script_path``. Otherwise the
        script is launched with the current Python interpreter (Python suffix),
        with ``sh`` (``.sh`` suffix), or directly when it is executable.

        Args:
            script_path: Path to a server script or executable.
            cmd: Command list to execute (command + args).
            cwd: Working directory for resolving relative script paths.

        Returns:
            Tuple of (command, args).

        Raises:
            PluginError: if the script is invalid or cmd is malformed.
        """
        if cmd:
            if not isinstance(cmd, list) or not all(isinstance(part, str) and part.strip() for part in cmd):
                raise PluginError(error=PluginErrorModel(message="STDIO cmd must be a non-empty list of strings", plugin_name=self.name))
            return cmd[0], cmd[1:]

        if not script_path:
            raise PluginError(error=PluginErrorModel(message="STDIO transport requires script or cmd", plugin_name=self.name))

        server_path = Path(script_path).expanduser()
        if not server_path.is_absolute() and cwd:
            server_path = Path(cwd).expanduser() / server_path
        resolved_script_path = str(server_path)
        if not server_path.is_file():
            raise PluginError(error=PluginErrorModel(message=f"Server script {resolved_script_path} does not exist.", plugin_name=self.name))

        if server_path.suffix == PYTHON_SUFFIX:
            return sys.executable, [resolved_script_path]
        if server_path.suffix == ".sh":
            return "sh", [resolved_script_path]
        if not os.access(server_path, os.X_OK):
            raise PluginError(error=PluginErrorModel(message=f"Server script {resolved_script_path} must be executable.", plugin_name=self.name))
        return resolved_script_path, []

    def __build_stdio_env(self, extra_env: dict[str, str] | None) -> dict[str, str]:
        """Build environment for the stdio server process.

        Args:
            extra_env: Environment overrides to merge into the current process env.

        Returns:
            Combined environment dictionary for the plugin process.
        """
        current_env = os.environ.copy()
        if extra_env:
            current_env.update(extra_env)
        return current_env

    async def __run_stdio_session(self, server_script_path: str | None, cmd: list[str] | None, env: dict[str, str] | None, cwd: str | None) -> None:
        """Run a stdio session in a dedicated task for consistent setup/teardown.

        The coroutine opens the stdio transport and client session, signals
        ``_stdio_ready`` (on success or failure), then waits for ``_stdio_stop``
        before closing the exit stack it opened. A setup failure is stored in
        ``_stdio_error`` for the connecting caller to inspect.

        Args:
            server_script_path: Path to the server script or executable.
            cmd: Command list to start the server (command + args).
            env: Environment overrides for the server process.
            cwd: Working directory for the server process.
        """
        try:
            command, args = self.__resolve_stdio_command(server_script_path, cmd, cwd)
            server_env = self.__build_stdio_env(env)
            server_params = StdioServerParameters(command=command, args=args, env=server_env, cwd=cwd)

            self._stdio_exit_stack = AsyncExitStack()
            stdio_transport = await self._stdio_exit_stack.enter_async_context(stdio_client(server_params))
            self._stdio, self._write = stdio_transport
            self._session = await self._stdio_exit_stack.enter_async_context(ClientSession(self._stdio, self._write))

            await self._session.initialize()

            response = await self._session.list_tools()
            tools = response.tools
            logger.info("\nConnected to plugin MCP server (stdio) with tools: %s", " ".join([tool.name for tool in tools]))
        except Exception as e:
            # Recorded rather than raised: the connecting caller reads it after _stdio_ready fires.
            self._stdio_error = e
            logger.exception(e)
        finally:
            # Always wake the waiter in __connect_to_stdio_server, whether setup worked or not.
            if self._stdio_ready and not self._stdio_ready.is_set():
                self._stdio_ready.set()

            if self._stdio_error:
                if self._stdio_exit_stack:
                    await self._stdio_exit_stack.aclose()
                return

            # Keep the session open until shutdown() sets the stop event.
            if self._stdio_stop:
                await self._stdio_stop.wait()

            if self._stdio_exit_stack:
                await self._stdio_exit_stack.aclose()

    async def __connect_to_stdio_server(self, server_script_path: str | None, cmd: list[str] | None, env: dict[str, str] | None, cwd: str | None) -> None:
        """Connect to an MCP plugin server via stdio.

        Starts ``__run_stdio_session`` as a background task and waits until it
        reports readiness or an error.

        Args:
            server_script_path: Path to the server script or executable.
            cmd: Command list to start the server (command + args).
            env: Environment overrides for the server process.
            cwd: Working directory for the server process.

        Raises:
            PluginError: if stdio script/cmd is invalid or if there is a connection error.
        """
        try:
            if not self._stdio_ready:
                self._stdio_ready = asyncio.Event()
            if not self._stdio_stop:
                self._stdio_stop = asyncio.Event()
            self._stdio_error = None

            self._stdio_task = asyncio.create_task(
                self.__run_stdio_session(server_script_path, cmd, env, cwd),
                name=f"external-plugin-stdio-{self.name}",
            )

            await self._stdio_ready.wait()
            if self._stdio_error:
                raise PluginError(error=convert_exception_to_error(self._stdio_error, plugin_name=self.name)) from self._stdio_error
        except PluginError:
            raise
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e

    async def __connect_to_http_server(self, uri: str) -> None:
        """Connect to an MCP plugin server via streamable http with retry logic.

        Up to three attempts are made, sleeping 1s then 2s between them. Every
        failed attempt, including the last one, shuts the plugin down so that
        partially opened transports are released.

        Args:
            uri: the URI of the mcp plugin server.

        Raises:
            PluginError: if there is an external connection error after all retries.
        """
        plugin_tls = self._config.mcp.tls if self._config and self._config.mcp else None
        uds_path = self._config.mcp.uds if self._config and self._config.mcp else None
        if uds_path and plugin_tls:
            logger.warning("TLS configuration is ignored for Unix domain socket connections.")
        tls_config = None if uds_path else (plugin_tls or MCPClientTLSConfig.from_env())

        def _tls_httpx_client_factory(
            headers: Optional[dict[str, str]] = None,
            timeout: Optional[httpx.Timeout] = None,
            auth: Optional[httpx.Auth] = None,
        ) -> httpx.AsyncClient:
            """Build an httpx client with TLS configuration for external MCP servers.

            Args:
                headers: Optional HTTP headers to include in requests.
                timeout: Optional timeout configuration for HTTP requests.
                auth: Optional authentication handler for HTTP requests.

            Returns:
                Configured httpx AsyncClient with TLS settings applied.

            Raises:
                PluginError: If TLS configuration fails.
            """

            # First-Party
            from mcpgateway.services.http_client_service import get_default_verify, get_http_timeout  # pylint: disable=import-outside-toplevel

            kwargs: dict[str, Any] = {"follow_redirects": True}
            if uds_path:
                kwargs["transport"] = httpx.AsyncHTTPTransport(uds=uds_path)
            if headers:
                kwargs["headers"] = headers
            kwargs["timeout"] = timeout if timeout else get_http_timeout()
            if auth is not None:
                kwargs["auth"] = auth

            # Add connection pool limits
            kwargs["limits"] = httpx.Limits(
                max_connections=settings.httpx_max_connections,
                max_keepalive_connections=settings.httpx_max_keepalive_connections,
                keepalive_expiry=settings.httpx_keepalive_expiry,
            )

            if not tls_config:
                # Use skip_ssl_verify setting when no custom TLS config
                kwargs["verify"] = get_default_verify()
                return httpx.AsyncClient(**kwargs)

            # Create SSL context using the utility function
            # This implements certificate validation per test_client_certificate_validation.py
            ssl_context = create_ssl_context(tls_config, self.name)
            kwargs["verify"] = ssl_context

            return httpx.AsyncClient(**kwargs)

        max_retries = 3
        base_delay = 1.0

        for attempt in range(max_retries):
            # shutdown() between attempts clears the stored factory; register it
            # again on every attempt so a later session termination still uses
            # the TLS/UDS-aware client instead of a bare default client.
            self._http_client_factory = _tls_httpx_client_factory

            try:
                streamable_client = streamablehttp_client(uri, httpx_client_factory=_tls_httpx_client_factory, terminate_on_close=False)
                http_transport = await self._exit_stack.enter_async_context(streamable_client)
                self._http, self._write, get_session_id = http_transport
                self._get_session_id = get_session_id
                self._session = await self._exit_stack.enter_async_context(ClientSession(self._http, self._write))

                await self._session.initialize()
                self._session_id = self._get_session_id() if self._get_session_id else None
                response = await self._session.list_tools()
                tools = response.tools
                logger.info(
                    "Successfully connected to plugin MCP server with tools: %s",
                    " ".join([tool.name for tool in tools]),
                )
                return
            except Exception as e:
                logger.warning("Connection attempt %d/%d failed: %s", attempt + 1, max_retries, e)
                if attempt == max_retries - 1:
                    # Final attempt failed
                    target = f"{uri} (uds={uds_path})" if uds_path else uri
                    error_msg = f"External plugin '{self.name}' connection failed after {max_retries} attempts: {target} is not reachable. Please ensure the MCP server is running."
                    logger.error(error_msg)
                    # Release whatever the last attempt managed to open; a cleanup
                    # failure must not mask the connection error reported below.
                    try:
                        await self.shutdown()
                    except Exception as shutdown_error:
                        logger.error("Error during external plugin shutdown after connection failure: %s", shutdown_error)
                    raise PluginError(error=PluginErrorModel(message=error_msg, plugin_name=self.name)) from e
                await self.shutdown()
                # Wait before retry (exponential backoff)
                delay = base_delay * (2**attempt)
                logger.info("Retrying in %ss...", delay)
                await asyncio.sleep(delay)

    async def invoke_hook(self, hook_type: str, payload: PluginPayload, context: PluginContext) -> PluginResult:
        """Invoke an external plugin hook using the MCP protocol.

        The first text content item that carries a result or an error decides
        the outcome. When the response includes a context, its state, metadata
        and global state are copied back onto ``context``.

        Args:
            hook_type: The type of hook invoked (i.e., prompt_pre_fetch)
            payload: The payload to be passed to the hook.
            context: The plugin context passed to the run.

        Raises:
            PluginError: error passed from external plugin server, or raised when the hook
                type is unknown, the session is not initialized, or the response is invalid.

        Returns:
            The resulting payload from the plugin.
        """
        # Get the result type from the global registry
        registry = get_hook_registry()
        result_type = registry.get_result_type(hook_type)
        if not result_type:
            raise PluginError(error=PluginErrorModel(message=f"Hook type '{hook_type}' not registered in hook registry", plugin_name=self.name))

        if not self._session:
            raise PluginError(error=PluginErrorModel(message="Plugin session not initialized", plugin_name=self.name))

        try:
            result = await self._session.call_tool(INVOKE_HOOK, {HOOK_TYPE: hook_type, PLUGIN_NAME: self.name, PAYLOAD: payload, CONTEXT: context})
            for content in result.content:
                if not isinstance(content, TextContent):
                    continue
                try:
                    res = orjson.loads(content.text)
                except orjson.JSONDecodeError as decode_error:
                    raise PluginError(error=PluginErrorModel(message=f"Error trying to decode json: {content.text}", code="JSON_DECODE_ERROR", plugin_name=self.name)) from decode_error
                if CONTEXT in res:
                    cxt = PluginContext.model_validate(res[CONTEXT])
                    context.state = cxt.state
                    context.metadata = cxt.metadata
                    context.global_context.state = cxt.global_context.state
                if RESULT in res:
                    return result_type.model_validate(res[RESULT])
                if ERROR in res:
                    error = PluginErrorModel.model_validate(res[ERROR])
                    raise PluginError(error=error)
        except PluginError as pe:
            logger.exception(pe)
            raise
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e
        # No content item produced a result or an error.
        raise PluginError(error=PluginErrorModel(message=f"Received invalid response. Result = {result}", plugin_name=self.name))

    async def __get_plugin_config(self) -> PluginConfig | None:
        """Retrieve plugin configuration for the current plugin on the remote MCP server.

        Only the first text content item of the response is considered.

        Raises:
            PluginError: if there is a connection issue or validation issue.

        Returns:
            A plugin configuration for the current plugin from a remote MCP server,
            or None when the server returns an empty configuration or no text content.
        """
        if not self._session:
            raise PluginError(error=PluginErrorModel(message="Plugin session not initialized", plugin_name=self.name))
        try:
            configs = await self._session.call_tool(GET_PLUGIN_CONFIG, {NAME: self.name})
            for content in configs.content:
                if not isinstance(content, TextContent):
                    continue
                conf = orjson.loads(content.text)
                if not conf:
                    return None
                return PluginConfig.model_validate(conf)
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e

        return None

    async def shutdown(self) -> None:
        """Plugin cleanup code.

        Stops the stdio session task (if any), closes the shared exit stack and,
        for streamable HTTP, explicitly terminates the remote session. Session
        bookkeeping is reset even when closing the exit stack raises.
        """
        if self._stdio_task:
            if self._stdio_stop:
                self._stdio_stop.set()
            try:
                await self._stdio_task
            except Exception as e:
                logger.error("Error shutting down stdio session for plugin %s: %s", self.name, e)
            self._stdio_task = None
            self._stdio_ready = None
            self._stdio_stop = None
            self._stdio_exit_stack = None
            self._stdio_error = None
            self._stdio = None
            self._write = None
            if self._config and self._config.mcp and self._config.mcp.proto == TransportType.STDIO:
                self._session = None

        is_http = bool(self._config and self._config.mcp and self._config.mcp.proto == TransportType.STREAMABLEHTTP)
        try:
            if self._exit_stack:
                await self._exit_stack.aclose()
        finally:
            try:
                if is_http:
                    await self.__terminate_http_session()
            finally:
                if is_http:
                    # The transport and session are closed; drop the stale
                    # references so later calls report an uninitialized session.
                    self._session = None
                    self._http = None
                    self._write = None
                self._get_session_id = None
                self._session_id = None
                self._http_client_factory = None

    async def __terminate_http_session(self) -> None:
        """Terminate streamable HTTP session explicitly to avoid lingering server state.

        Sends an HTTP DELETE carrying the session id header to the configured
        URL. Does nothing when no session id or URL is known. Failures are
        logged at debug level and never raised.
        """
        if not self._session_id or not self._config or not self._config.mcp or not self._config.mcp.url:
            return
        # Third-Party
        from mcp.server.streamable_http import MCP_SESSION_ID_HEADER  # pylint: disable=import-outside-toplevel

        client_factory = self._http_client_factory
        try:
            if client_factory:
                client = client_factory()
            else:
                client = httpx.AsyncClient(follow_redirects=True)
            async with client:
                headers = {MCP_SESSION_ID_HEADER: self._session_id}
                await client.delete(self._config.mcp.url, headers=headers)
        except Exception as exc:
            logger.debug("Failed to terminate streamable HTTP session: %s", exc)

466 

467class ExternalHookRef(HookRef): 

468 """A Hook reference point for external plugins.""" 

469 

470 def __init__(self, hook: str, plugin_ref: PluginRef): # pylint: disable=super-init-not-called 

471 """Initialize a hook reference point for an external plugin. 

472 

473 Note: We intentionally don't call super().__init__() because external plugins 

474 use invoke_hook() rather than direct method attributes. 

475 

476 Args: 

477 hook: name of the hook point. 

478 plugin_ref: The reference to the plugin to hook. 

479 

480 Raises: 

481 PluginError: If the plugin is not an external plugin. 

482 """ 

483 self._plugin_ref = plugin_ref 

484 self._hook = hook 

485 if hasattr(plugin_ref.plugin, INVOKE_HOOK): 

486 self._func: Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]] = partial(plugin_ref.plugin.invoke_hook, hook) # type: ignore[attr-defined] 

487 else: 

488 raise PluginError(error=PluginErrorModel(message=f"Plugin: {plugin_ref.plugin.name} is not an external plugin", plugin_name=plugin_ref.plugin.name))