Coverage for mcpgateway / plugins / framework / external / mcp / client.py: 100%

286 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/plugins/framework/external/mcp/client.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Teryl Taylor, Fred Araujo 

6 

7External plugin client which connects to a remote server through MCP. 

8Module that contains plugin MCP client code to serve external plugins. 

9""" 

10 

11# Standard 

12import asyncio 

13from contextlib import AsyncExitStack 

14from functools import partial 

15import logging 

16import os 

17from pathlib import Path 

18import sys 

19from typing import Any, Awaitable, Callable, Optional 

20 

21# Third-Party 

22import httpx 

23from mcp import ClientSession, StdioServerParameters 

24from mcp.client.stdio import stdio_client 

25from mcp.client.streamable_http import streamablehttp_client 

26from mcp.types import TextContent 

27import orjson 

28 

29# First-Party 

30from mcpgateway.plugins.framework.base import HookRef, Plugin, PluginRef 

31from mcpgateway.plugins.framework.constants import CONTEXT, ERROR, GET_PLUGIN_CONFIG, HOOK_TYPE, IGNORE_CONFIG_EXTERNAL, INVOKE_HOOK, NAME, PAYLOAD, PLUGIN_NAME, PYTHON_SUFFIX, RESULT 

32from mcpgateway.plugins.framework.errors import convert_exception_to_error, PluginError 

33from mcpgateway.plugins.framework.external.mcp.tls_utils import create_ssl_context 

34from mcpgateway.plugins.framework.hooks.registry import get_hook_registry 

35from mcpgateway.plugins.framework.models import MCPClientTLSConfig, PluginConfig, PluginContext, PluginErrorModel, PluginPayload, PluginResult, TransportType 

36from mcpgateway.plugins.framework.settings import get_http_client_settings 

37 

# Module-level logger named after this module so records can be filtered per package.
logger = logging.getLogger(__name__)

39 

40 

class ExternalPlugin(Plugin):
    """External plugin object for pre/post processing of inputs and outputs at various locations throughout the gateway.

    The External Plugin connects to a remote MCP server that contains plugins.
    Two transports are handled here: STDIO (a child process driven from a
    dedicated asyncio task) and STREAMABLEHTTP (an httpx based client with
    retry and explicit session termination).
    """

    def __init__(self, config: PluginConfig) -> None:
        """Initialize a plugin with a configuration and context.

        Args:
            config: The plugin configuration
        """
        super().__init__(config)
        self._session: Optional[ClientSession] = None
        self._exit_stack = AsyncExitStack()
        # Assign (not merely annotate) the transport handles so that reading them
        # before a connection exists yields None instead of AttributeError.
        self._http: Optional[Any] = None
        self._stdio: Optional[Any] = None
        self._write: Optional[Any] = None
        # asyncio.current_task() raises RuntimeError when no event loop is running;
        # fall back to None so the plugin can also be constructed from sync code.
        try:
            self._current_task: Optional[asyncio.Task[Any]] = asyncio.current_task()
        except RuntimeError:
            self._current_task = None
        self._stdio_exit_stack: Optional[AsyncExitStack] = None
        self._stdio_task: Optional[asyncio.Task[None]] = None
        self._stdio_ready: Optional[asyncio.Event] = None
        self._stdio_stop: Optional[asyncio.Event] = None
        self._stdio_error: Optional[BaseException] = None
        self._get_session_id: Optional[Callable[[], str | None]] = None
        self._session_id: Optional[str] = None
        self._http_client_factory: Optional[Callable[..., httpx.AsyncClient]] = None

    async def initialize(self) -> None:
        """Initialize the plugin's connection to the MCP server.

        Connects using the transport named in the ``mcp`` section, fetches the
        remote plugin configuration and merges it with the local one. Locally
        set values take precedence because they are applied on top of the
        remote dump.

        Raises:
            PluginError: if the mcp section is missing or incomplete, if the connection fails,
                or if unable to retrieve plugin configuration of external plugin.
        """
        mcp_config = self._config.mcp
        if not mcp_config:
            raise PluginError(error=PluginErrorModel(message="The mcp section must be defined for external plugin", plugin_name=self.name))
        if mcp_config.proto == TransportType.STDIO:
            if not (mcp_config.script or mcp_config.cmd):
                raise PluginError(error=PluginErrorModel(message="STDIO transport requires script or cmd", plugin_name=self.name))
            await self.__connect_to_stdio_server(mcp_config.script, mcp_config.cmd, mcp_config.env, mcp_config.cwd)
        elif mcp_config.proto == TransportType.STREAMABLEHTTP:
            if not mcp_config.url:
                raise PluginError(error=PluginErrorModel(message="STREAMABLEHTTP transport requires url", plugin_name=self.name))
            await self.__connect_to_http_server(mcp_config.url)

        try:
            config = await self.__get_plugin_config()

            if not config:
                raise PluginError(error=PluginErrorModel(message="Unable to retrieve configuration for external plugin", plugin_name=self.name))

            current_config = self._config.model_dump(exclude_unset=True)
            remote_config = config.model_dump(exclude_unset=True)
            # Local settings override whatever the remote server reported.
            remote_config.update(current_config)

            context = {IGNORE_CONFIG_EXTERNAL: True}

            self._config = PluginConfig.model_validate(remote_config, context=context)
        except PluginError as pe:
            await self.__shutdown_after_init_failure()
            logger.exception(pe)
            raise
        except Exception as e:
            await self.__shutdown_after_init_failure()
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e

    async def __shutdown_after_init_failure(self) -> None:
        """Best-effort shutdown used when initialization fails; never raises ``Exception``."""
        try:
            await self.shutdown()
        except Exception as shutdown_error:
            logger.error("Error during external plugin shutdown after init failure: %s", shutdown_error)

    def __resolve_stdio_command(self, script_path: str | None, cmd: list[str] | None, cwd: str | None) -> tuple[str, list[str]]:
        """Resolve the stdio command + args from config.

        An explicit ``cmd`` wins over ``script_path``. A script is run with the
        current Python interpreter when it has the Python suffix, with ``sh``
        when it ends in ``.sh``, and directly otherwise (it must then be
        executable).

        Args:
            script_path: Path to a server script or executable.
            cmd: Command list to execute (command + args).
            cwd: Working directory for resolving relative script paths.

        Returns:
            Tuple of (command, args).

        Raises:
            PluginError: if the script is invalid or cmd is malformed.
        """
        if cmd:
            if not isinstance(cmd, list) or not all(isinstance(part, str) and part.strip() for part in cmd):
                raise PluginError(error=PluginErrorModel(message="STDIO cmd must be a non-empty list of strings", plugin_name=self.name))
            return cmd[0], cmd[1:]

        if not script_path:
            raise PluginError(error=PluginErrorModel(message="STDIO transport requires script or cmd", plugin_name=self.name))

        server_path = Path(script_path).expanduser()
        if not server_path.is_absolute() and cwd:
            server_path = Path(cwd).expanduser() / server_path
        resolved_script_path = str(server_path)
        if not server_path.is_file():
            raise PluginError(error=PluginErrorModel(message=f"Server script {resolved_script_path} does not exist.", plugin_name=self.name))

        if server_path.suffix == PYTHON_SUFFIX:
            return sys.executable, [resolved_script_path]
        if server_path.suffix == ".sh":
            return "sh", [resolved_script_path]
        if not os.access(server_path, os.X_OK):
            raise PluginError(error=PluginErrorModel(message=f"Server script {resolved_script_path} must be executable.", plugin_name=self.name))
        return resolved_script_path, []

    def __build_stdio_env(self, extra_env: dict[str, str] | None) -> dict[str, str]:
        """Build environment for the stdio server process.

        Args:
            extra_env: Environment overrides to merge into the current process env.

        Returns:
            Combined environment dictionary for the plugin process.
        """
        current_env = os.environ.copy()
        if extra_env:
            current_env.update(extra_env)
        return current_env

    async def __run_stdio_session(self, server_script_path: str | None, cmd: list[str] | None, env: dict[str, str] | None, cwd: str | None) -> None:
        """Run a stdio session in a dedicated task for consistent setup/teardown.

        The stdio exit stack is both entered and closed inside this coroutine.
        Setup errors are stored in ``self._stdio_error`` rather than raised, and
        ``self._stdio_ready`` is always set so the waiting connector wakes up.

        Args:
            server_script_path: Path to the server script or executable.
            cmd: Command list to start the server (command + args).
            env: Environment overrides for the server process.
            cwd: Working directory for the server process.
        """
        try:
            command, args = self.__resolve_stdio_command(server_script_path, cmd, cwd)
            server_env = self.__build_stdio_env(env)
            server_params = StdioServerParameters(command=command, args=args, env=server_env, cwd=cwd)

            self._stdio_exit_stack = AsyncExitStack()
            stdio_transport = await self._stdio_exit_stack.enter_async_context(stdio_client(server_params))
            self._stdio, self._write = stdio_transport
            self._session = await self._stdio_exit_stack.enter_async_context(ClientSession(self._stdio, self._write))

            await self._session.initialize()

            response = await self._session.list_tools()
            tools = response.tools
            logger.info("\nConnected to plugin MCP server (stdio) with tools: %s", " ".join([tool.name for tool in tools]))
        except Exception as e:
            self._stdio_error = e
            logger.exception(e)
        finally:
            # Wake the connector whether setup succeeded or failed.
            if self._stdio_ready and not self._stdio_ready.is_set():
                self._stdio_ready.set()

        if self._stdio_error:
            # Setup failed: release whatever was opened and end the task.
            if self._stdio_exit_stack:
                await self._stdio_exit_stack.aclose()
            return

        # Keep the session open until shutdown() signals the stop event.
        if self._stdio_stop:
            await self._stdio_stop.wait()

        if self._stdio_exit_stack:
            await self._stdio_exit_stack.aclose()

    async def __connect_to_stdio_server(self, server_script_path: str | None, cmd: list[str] | None, env: dict[str, str] | None, cwd: str | None) -> None:
        """Connect to an MCP plugin server via stdio.

        Starts the session task and waits until it reports readiness or failure.

        Args:
            server_script_path: Path to the server script or executable.
            cmd: Command list to start the server (command + args).
            env: Environment overrides for the server process.
            cwd: Working directory for the server process.

        Raises:
            PluginError: if stdio script/cmd is invalid or if there is a connection error.
        """
        try:
            # Events left over from an earlier failed connect may already be set.
            # Clear them, otherwise the wait below returns before the new session
            # task has finished its setup (or the task stops immediately).
            if self._stdio_ready is None:
                self._stdio_ready = asyncio.Event()
            else:
                self._stdio_ready.clear()
            if self._stdio_stop is None:
                self._stdio_stop = asyncio.Event()
            else:
                self._stdio_stop.clear()
            self._stdio_error = None

            self._stdio_task = asyncio.create_task(
                self.__run_stdio_session(server_script_path, cmd, env, cwd),
                name=f"external-plugin-stdio-{self.name}",
            )

            await self._stdio_ready.wait()
            if self._stdio_error:
                raise PluginError(error=convert_exception_to_error(self._stdio_error, plugin_name=self.name)) from self._stdio_error
        except PluginError:
            raise
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e

    async def __connect_to_http_server(self, uri: str) -> None:
        """Connect to an MCP plugin server via streamable http with retry logic.

        Up to three attempts are made with exponential backoff (1s, 2s). TLS
        settings come from the plugin config or, failing that, the environment;
        they are ignored when a Unix domain socket is configured.

        Args:
            uri: the URI of the mcp plugin server.

        Raises:
            PluginError: if there is an external connection error after all retries.
        """
        plugin_tls = self._config.mcp.tls if self._config and self._config.mcp else None
        uds_path = self._config.mcp.uds if self._config and self._config.mcp else None
        if uds_path and plugin_tls:
            logger.warning("TLS configuration is ignored for Unix domain socket connections.")
        tls_config = None if uds_path else (plugin_tls or MCPClientTLSConfig.from_env())

        def _tls_httpx_client_factory(
            headers: Optional[dict[str, str]] = None,
            timeout: Optional[httpx.Timeout] = None,
            auth: Optional[httpx.Auth] = None,
        ) -> httpx.AsyncClient:
            """Build an httpx client with TLS configuration for external MCP servers.

            Args:
                headers: Optional HTTP headers to include in requests.
                timeout: Optional timeout configuration for HTTP requests.
                auth: Optional authentication handler for HTTP requests.

            Returns:
                Configured httpx AsyncClient with TLS settings applied.

            Raises:
                PluginError: If TLS configuration fails.
            """
            kwargs: dict[str, Any] = {"follow_redirects": True}
            if uds_path:
                kwargs["transport"] = httpx.AsyncHTTPTransport(uds=uds_path)
            if headers:
                kwargs["headers"] = headers
            http_settings = get_http_client_settings()
            kwargs["timeout"] = (
                timeout
                if timeout
                else httpx.Timeout(
                    connect=http_settings.httpx_connect_timeout,
                    read=http_settings.httpx_read_timeout,
                    write=http_settings.httpx_write_timeout,
                    pool=http_settings.httpx_pool_timeout,
                )
            )
            if auth is not None:
                kwargs["auth"] = auth

            # Add connection pool limits
            kwargs["limits"] = httpx.Limits(
                max_connections=http_settings.httpx_max_connections,
                max_keepalive_connections=http_settings.httpx_max_keepalive_connections,
                keepalive_expiry=http_settings.httpx_keepalive_expiry,
            )

            if not tls_config:
                # Use skip_ssl_verify setting when no custom TLS config
                kwargs["verify"] = not http_settings.skip_ssl_verify
                return httpx.AsyncClient(**kwargs)

            # Create SSL context using the utility function
            ssl_context = create_ssl_context(tls_config, self.name)
            kwargs["verify"] = ssl_context

            return httpx.AsyncClient(**kwargs)

        max_retries = 3
        base_delay = 1.0

        for attempt in range(max_retries):
            # shutdown() between attempts clears the stored factory. Restore it on
            # every attempt so a later session termination keeps the TLS/UDS setup.
            self._http_client_factory = _tls_httpx_client_factory

            try:
                streamable_client = streamablehttp_client(uri, httpx_client_factory=_tls_httpx_client_factory, terminate_on_close=False)
                http_transport = await self._exit_stack.enter_async_context(streamable_client)
                self._http, self._write, get_session_id = http_transport
                self._get_session_id = get_session_id
                self._session = await self._exit_stack.enter_async_context(ClientSession(self._http, self._write))

                await self._session.initialize()
                self._session_id = self._get_session_id() if self._get_session_id else None
                response = await self._session.list_tools()
                tools = response.tools
                logger.info(
                    "Successfully connected to plugin MCP server with tools: %s",
                    " ".join([tool.name for tool in tools]),
                )
                return
            except Exception as e:
                logger.warning("Connection attempt %d/%d failed: %s", attempt + 1, max_retries, e)
                if attempt == max_retries - 1:
                    # Final attempt failed
                    target = f"{uri} (uds={uds_path})" if uds_path else uri
                    error_msg = f"External plugin '{self.name}' connection failed after {max_retries} attempts: {target} is not reachable. Please ensure the MCP server is running."
                    logger.error(error_msg)
                    # Release any transport/session contexts this attempt managed to
                    # open; the caller does not shut down after a failed connect.
                    try:
                        await self.shutdown()
                    except Exception as shutdown_error:
                        logger.debug("Error releasing transport after failed connection for plugin %s: %s", self.name, shutdown_error)
                    self._exit_stack = AsyncExitStack()
                    raise PluginError(error=PluginErrorModel(message=error_msg, plugin_name=self.name)) from e
                await self.shutdown()
                self._exit_stack = AsyncExitStack()
                # Wait before retry
                delay = base_delay * (2**attempt)
                logger.info("Retrying in %ss...", delay)
                await asyncio.sleep(delay)

    async def invoke_hook(self, hook_type: str, payload: PluginPayload, context: PluginContext) -> PluginResult:
        """Invoke an external plugin hook using the MCP protocol.

        The first text content item that carries a result or an error decides
        the outcome. A returned context, when present, is copied back into the
        caller's ``context`` before the result is evaluated.

        Args:
            hook_type: The type of hook invoked (i.e., prompt_pre_fetch)
            payload: The payload to be passed to the hook.
            context: The plugin context passed to the run.

        Raises:
            PluginError: error passed from external plugin server, an unregistered hook type,
                a missing session, undecodable JSON, or a response with neither result nor error.

        Returns:
            The resulting payload from the plugin.
        """
        # Get the result type from the global registry
        registry = get_hook_registry()
        result_type = registry.get_result_type(hook_type)
        if not result_type:
            raise PluginError(error=PluginErrorModel(message=f"Hook type '{hook_type}' not registered in hook registry", plugin_name=self.name))

        if not self._session:
            raise PluginError(error=PluginErrorModel(message="Plugin session not initialized", plugin_name=self.name))

        try:
            result = await self._session.call_tool(INVOKE_HOOK, {HOOK_TYPE: hook_type, PLUGIN_NAME: self.name, PAYLOAD: payload, CONTEXT: context})
            for content in result.content:
                if not isinstance(content, TextContent):
                    continue
                try:
                    res = orjson.loads(content.text)
                except orjson.JSONDecodeError as decode_error:
                    raise PluginError(error=PluginErrorModel(message=f"Error trying to decode json: {content.text}", code="JSON_DECODE_ERROR", plugin_name=self.name)) from decode_error
                if CONTEXT in res:
                    # Propagate state changes made by the remote plugin to the caller.
                    cxt = PluginContext.model_validate(res[CONTEXT])
                    context.state = cxt.state
                    context.metadata = cxt.metadata
                    context.global_context.state = cxt.global_context.state
                if RESULT in res:
                    return result_type.model_validate(res[RESULT])
                if ERROR in res:
                    error = PluginErrorModel.model_validate(res[ERROR])
                    raise PluginError(error=error)
        except PluginError as pe:
            logger.exception(pe)
            raise
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e
        # No content item carried a result or an error.
        raise PluginError(error=PluginErrorModel(message=f"Received invalid response. Result = {result}", plugin_name=self.name))

    async def __get_plugin_config(self) -> PluginConfig | None:
        """Retrieve plugin configuration for the current plugin on the remote MCP server.

        Only the first text content item of the response is considered.

        Raises:
            PluginError: if there is a connection issue or validation issue.

        Returns:
            A plugin configuration for the current plugin from a remote MCP server,
            or None when the server returns an empty configuration or no text content.
        """
        if not self._session:
            raise PluginError(error=PluginErrorModel(message="Plugin session not initialized", plugin_name=self.name))
        try:
            configs = await self._session.call_tool(GET_PLUGIN_CONFIG, {NAME: self.name})
            for content in configs.content:
                if not isinstance(content, TextContent):
                    continue
                conf = orjson.loads(content.text)
                if not conf:
                    return None
                return PluginConfig.model_validate(conf)
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e

        return None

    async def shutdown(self) -> None:
        """Plugin cleanup code.

        Stops the stdio session task if one is running, closes the HTTP exit
        stack, explicitly terminates the streamable HTTP session, and resets
        the connection state so a closed session is never reused.
        """
        if self._stdio_task:
            if self._stdio_stop:
                self._stdio_stop.set()
            try:
                await self._stdio_task
            except Exception as e:
                logger.error("Error shutting down stdio session for plugin %s: %s", self.name, e)
            self._stdio_task = None
            self._stdio_ready = None
            self._stdio_stop = None
            self._stdio_exit_stack = None
            self._stdio_error = None
            self._stdio = None
            self._write = None
            if self._config and self._config.mcp and self._config.mcp.proto == TransportType.STDIO:
                self._session = None

        if self._exit_stack:
            await self._exit_stack.aclose()
        if self._config and self._config.mcp and self._config.mcp.proto == TransportType.STREAMABLEHTTP:
            await self.__terminate_http_session()
            # The session and its streams were owned by the exit stack closed above.
            # Drop them so invoke_hook reports "not initialized" instead of using
            # a closed session.
            self._session = None
            self._http = None
            self._write = None
        self._get_session_id = None
        self._session_id = None
        self._http_client_factory = None

    async def __terminate_http_session(self) -> None:
        """Terminate streamable HTTP session explicitly to avoid lingering server state.

        Sends a DELETE carrying the session id header to the configured URL.
        Failures are logged at debug level and otherwise ignored.
        """
        if not self._session_id or not self._config or not self._config.mcp or not self._config.mcp.url:
            return
        # Third-Party
        from mcp.server.streamable_http import MCP_SESSION_ID_HEADER  # pylint: disable=import-outside-toplevel

        client_factory = self._http_client_factory
        try:
            if client_factory:
                client = client_factory()
            else:
                client = httpx.AsyncClient(follow_redirects=True)
            async with client:
                headers = {MCP_SESSION_ID_HEADER: self._session_id}
                await client.delete(self._config.mcp.url, headers=headers)
        except Exception as exc:
            logger.debug("Failed to terminate streamable HTTP session: %s", exc)

472 

473 

class ExternalHookRef(HookRef):
    """Hook reference that dispatches to an external plugin's ``invoke_hook``."""

    def __init__(self, hook: str, plugin_ref: PluginRef):  # pylint: disable=super-init-not-called
        """Bind a hook name to an external plugin.

        Note: super().__init__() is deliberately not called, because external
        plugins are driven through invoke_hook() and not through per-hook
        method attributes.

        Args:
            hook: name of the hook point.
            plugin_ref: The reference to the plugin to hook.

        Raises:
            PluginError: If the plugin is not an external plugin.
        """
        self._plugin_ref = plugin_ref
        self._hook = hook

        target_plugin = plugin_ref.plugin
        # Guard clause: anything without an invoke_hook attribute is not external.
        if not hasattr(target_plugin, INVOKE_HOOK):
            raise PluginError(error=PluginErrorModel(message=f"Plugin: {plugin_ref.plugin.name} is not an external plugin", plugin_name=plugin_ref.plugin.name))

        # Pre-bind the hook name so the callable only needs (payload, context).
        self._func: Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]] = partial(plugin_ref.plugin.invoke_hook, hook)  # type: ignore[attr-defined]