Coverage for mcpgateway / plugins / framework / external / unix / client.py: 99%

138 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/plugins/framework/external/unix/client.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Teryl Taylor 

6 

7Unix socket client for external plugins. 

8 

9This module provides a high-performance client for communicating with 

10external plugins over Unix domain sockets using length-prefixed protobuf 

11messages. 

12 

13Examples: 

14 Create and use a Unix socket plugin client: 

15 

16 >>> from mcpgateway.plugins.framework.external.unix.client import UnixSocketExternalPlugin 

17 >>> from mcpgateway.plugins.framework.models import PluginConfig, UnixSocketClientConfig 

18 

19 >>> config = PluginConfig( 

20 ... name="MyPlugin", 

21 ... kind="external", 

22 ... hooks=["tool_pre_invoke"], 

23 ... unix_socket=UnixSocketClientConfig(path="/tmp/plugin.sock"), 

24 ... ) 

25 >>> plugin = UnixSocketExternalPlugin(config) 

26 >>> # await plugin.initialize() 

27 >>> # result = await plugin.invoke_hook(hook_type, payload, context) 

28""" 

29 

30# pylint: disable=no-member,no-name-in-module 

31 

32# Standard 

33import asyncio 

34import logging 

35from typing import Any, Optional 

36 

37# Third-Party 

38from google.protobuf import json_format 

39from google.protobuf.struct_pb2 import Struct 

40 

41# First-Party 

42from mcpgateway.plugins.framework.base import Plugin 

43from mcpgateway.plugins.framework.errors import convert_exception_to_error, PluginError 

44from mcpgateway.plugins.framework.external.grpc.proto import plugin_service_pb2 

45from mcpgateway.plugins.framework.external.proto_convert import pydantic_context_to_proto, update_pydantic_context_from_proto 

46from mcpgateway.plugins.framework.external.unix.protocol import read_message, write_message_async 

47from mcpgateway.plugins.framework.hooks.registry import get_hook_registry 

48from mcpgateway.plugins.framework.models import PluginConfig, PluginContext, PluginErrorModel, PluginResult 

49 

50logger = logging.getLogger(__name__) 

51 

52 

class UnixSocketExternalPlugin(Plugin):
    """External plugin client using raw Unix domain sockets.

    This client provides high-performance IPC for local plugins using
    length-prefixed protobuf messages. It includes automatic reconnection
    with configurable retry logic.

    The wire protocol is strictly request/response over a single stream, so
    hook invocations are serialized with an internal lock and any exchange
    that is abandoned half-way (timeout, cancellation, I/O error) invalidates
    the connection: a late reply left in the stream would otherwise be read
    as the answer to the next request.

    Attributes:
        config: The plugin configuration.

    Examples:
        >>> from mcpgateway.plugins.framework.models import PluginConfig, UnixSocketClientConfig
        >>> config = PluginConfig(
        ...     name="TestPlugin",
        ...     kind="external",
        ...     hooks=["tool_pre_invoke"],
        ...     unix_socket=UnixSocketClientConfig(path="/tmp/test.sock"),
        ... )
        >>> plugin = UnixSocketExternalPlugin(config)
        >>> plugin.name
        'TestPlugin'
    """

    def __init__(self, config: PluginConfig) -> None:
        """Initialize the Unix socket plugin client.

        No connection is opened here; call ``initialize()`` to connect.

        Args:
            config: The plugin configuration with unix_socket settings.

        Raises:
            PluginError: If unix_socket configuration is missing.
        """
        super().__init__(config)

        if not config.unix_socket:
            raise PluginError(error=PluginErrorModel(message="The unix_socket section must be defined for Unix socket plugin", plugin_name=config.name))

        # Connection settings copied from the unix_socket config section.
        self._socket_path = config.unix_socket.path
        self._reconnect_attempts = config.unix_socket.reconnect_attempts
        self._reconnect_delay = config.unix_socket.reconnect_delay
        self._timeout = config.unix_socket.timeout

        # Stream pair for the current connection; both are None while disconnected.
        self._reader: Optional[asyncio.StreamReader] = None
        self._writer: Optional[asyncio.StreamWriter] = None
        self._connected = False
        # Serializes request/response exchanges in _send_request.
        self._lock = asyncio.Lock()

    @property
    def connected(self) -> bool:
        """Check if the client is connected.

        Returns:
            bool: True if connected and writer is active, False otherwise.
        """
        return self._connected and self._writer is not None and not self._writer.is_closing()

    async def _connect(self) -> None:
        """Establish connection to the Unix socket server.

        Raises:
            PluginError: If connection fails.
        """
        try:
            self._reader, self._writer = await asyncio.open_unix_connection(self._socket_path)
        except OSError as e:
            self._connected = False
            raise PluginError(error=PluginErrorModel(message=f"Failed to connect to {self._socket_path}: {e}", plugin_name=self.name)) from e

        self._connected = True
        logger.debug("Connected to Unix socket: %s", self._socket_path)

    async def _disconnect(self) -> None:
        """Close the connection, ignoring errors raised while closing.

        The connection state is cleared before the writer is closed so the
        client is reliably marked as disconnected even if waiting for the
        close to complete is interrupted (for example by task cancellation).
        """
        writer = self._writer
        self._writer = None
        self._reader = None
        self._connected = False

        if not writer:
            return

        try:
            writer.close()
            await writer.wait_closed()
        except Exception as e:
            # Best-effort cleanup: a failure to close must not propagate,
            # but leave a trace for debugging instead of discarding it.
            logger.debug("Ignoring error while closing Unix socket %s: %s", self._socket_path, e)

    async def _reconnect(self) -> None:
        """Drop the current connection and attempt to reconnect with retries.

        Up to ``reconnect_attempts`` connection attempts are made, sleeping
        ``reconnect_delay * attempt`` seconds between consecutive attempts.

        Raises:
            PluginError: If all reconnection attempts fail.
        """
        await self._disconnect()

        last_error: Optional[Exception] = None
        for attempt in range(1, self._reconnect_attempts + 1):
            try:
                logger.debug("Reconnection attempt %d/%d to %s", attempt, self._reconnect_attempts, self._socket_path)
                await self._connect()
                logger.info("Reconnected to %s on attempt %d", self._socket_path, attempt)
                return
            except PluginError as e:
                last_error = e
                if attempt < self._reconnect_attempts:
                    # Linear backoff: the wait grows by one reconnect_delay per failed attempt.
                    await asyncio.sleep(self._reconnect_delay * attempt)

        # Chain the last connection failure so its traceback is not lost.
        raise PluginError(error=PluginErrorModel(message=f"Failed to reconnect after {self._reconnect_attempts} attempts: {last_error}", plugin_name=self.name)) from last_error

    async def _send_request(self, request: plugin_service_pb2.InvokeHookRequest) -> plugin_service_pb2.InvokeHookResponse:
        """Send a request and receive response, with reconnection on failure.

        The exchange runs under the instance lock so that concurrent callers
        cannot interleave their messages on the shared stream. Connection
        errors are retried; a timeout is not retried, but it does drop the
        connection so a late reply cannot be mistaken for the response to a
        later request.

        Args:
            request: The protobuf request to send.

        Returns:
            The protobuf response.

        Raises:
            PluginError: If the request times out or sending fails after
                reconnection attempts.
        """
        request_bytes = request.SerializeToString()

        async with self._lock:
            for attempt in range(self._reconnect_attempts + 1):
                try:
                    if not self.connected:
                        await self._reconnect()

                    # Send request
                    await write_message_async(self._writer, request_bytes)

                    # Read response
                    response_bytes = await read_message(self._reader, timeout=self._timeout)

                    # Parse response
                    response = plugin_service_pb2.InvokeHookResponse()
                    response.ParseFromString(response_bytes)
                    return response

                except asyncio.CancelledError:
                    # The exchange was abandoned mid-flight; flag the stream as
                    # stale (without awaiting during cancellation) so the next
                    # request reconnects instead of reading a leftover reply.
                    self._connected = False
                    raise

                except asyncio.TimeoutError as e:
                    logger.warning("Request timed out after %s seconds", self._timeout)
                    # The reply may still arrive (or be partially consumed);
                    # close the connection so it can never be paired with a
                    # subsequent request.
                    await self._disconnect()
                    raise PluginError(error=PluginErrorModel(message=f"Request timed out after {self._timeout}s", plugin_name=self.name)) from e

                except (OSError, asyncio.IncompleteReadError, BrokenPipeError) as e:
                    logger.warning("Connection error on attempt %d: %s", attempt + 1, e)
                    self._connected = False

                    if attempt < self._reconnect_attempts:
                        await asyncio.sleep(self._reconnect_delay * (attempt + 1))
                        continue
                    raise PluginError(error=PluginErrorModel(message=f"Request failed after {self._reconnect_attempts + 1} attempts: {e}", plugin_name=self.name)) from e

        # Should not reach here
        raise PluginError(error=PluginErrorModel(message="Unexpected state in _send_request", plugin_name=self.name))

    async def initialize(self) -> None:
        """Initialize the plugin client by connecting to the server.

        This establishes the Unix socket connection and then makes a
        best-effort attempt to fetch the remote plugin configuration as a
        connectivity check. A failed check is logged and never raised.

        Raises:
            PluginError: If initial connection fails.
        """
        logger.info("Initializing Unix socket plugin: %s -> %s", self.name, self._socket_path)

        try:
            await self._connect()
        except PluginError:
            raise
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e

        # Optionally fetch remote config to verify connection
        try:
            request = plugin_service_pb2.GetPluginConfigRequest(name=self.name)
            request_bytes = request.SerializeToString()

            await write_message_async(self._writer, request_bytes)
            response_bytes = await read_message(self._reader, timeout=self._timeout)

            response = plugin_service_pb2.GetPluginConfigResponse()
            response.ParseFromString(response_bytes)

            if response.found:
                logger.debug("Remote plugin config verified for %s", self.name)
            else:
                logger.warning("Plugin %s not found on remote server", self.name)

        except Exception as e:
            logger.warning("Could not verify remote plugin config: %s", e)
            # Continue anyway - the plugin might still work. The failed
            # exchange may have left an unread (or late) reply in the stream,
            # which the first hook call would consume as its own response, so
            # start over on a fresh connection.
            await self._disconnect()
            try:
                await self._connect()
            except Exception as reconnect_error:
                # Stay disconnected; the first request goes through the normal reconnect path.
                logger.warning("Could not re-establish connection to %s: %s", self._socket_path, reconnect_error)

        logger.info("Unix socket plugin initialized: %s", self.name)

    async def shutdown(self) -> None:
        """Shutdown the plugin client and close the connection."""
        logger.info("Shutting down Unix socket plugin: %s", self.name)
        await self._disconnect()

    async def invoke_hook(
        self,
        hook_type: str,
        payload: Any,
        context: PluginContext,
    ) -> PluginResult:
        """Invoke a plugin hook over the Unix socket connection.

        The given ``context`` is updated in place when the response carries
        a context message.

        Args:
            hook_type: The type of hook to invoke (e.g., "tool_pre_invoke").
            payload: The hook payload (will be serialized to protobuf Struct).
            context: The plugin context.

        Returns:
            The plugin result.

        Raises:
            PluginError: If the request fails after retries or hook type is invalid.
        """
        # Get the result type from the global registry
        registry = get_hook_registry()
        result_type = registry.get_result_type(hook_type)
        if not result_type:
            raise PluginError(error=PluginErrorModel(message=f"Hook type '{hook_type}' not registered in hook registry", plugin_name=self.name))

        # Convert payload to Struct (still polymorphic)
        payload_struct = Struct()
        if hasattr(payload, "model_dump"):
            json_format.ParseDict(payload.model_dump(), payload_struct)
        else:
            json_format.ParseDict(payload, payload_struct)

        # Convert context to explicit proto message (faster than Struct)
        context_proto = pydantic_context_to_proto(context)

        # Build request
        request = plugin_service_pb2.InvokeHookRequest(
            hook_type=hook_type,
            plugin_name=self.name,
            payload=payload_struct,
            context=context_proto,
        )

        try:
            # Send request and get response
            response = await self._send_request(request)

            # Handle error response
            if response.HasField("error") and response.error.message:
                error = PluginErrorModel(
                    message=response.error.message,
                    plugin_name=response.error.plugin_name or self.name,
                    code=response.error.code,
                    mcp_error_code=response.error.mcp_error_code,
                )
                if response.error.HasField("details"):
                    error.details = json_format.MessageToDict(response.error.details)
                raise PluginError(error=error)

            # Update context if modified (using explicit proto message)
            if response.HasField("context"):
                update_pydantic_context_from_proto(context, response.context)

            # Parse and return result
            if response.HasField("result"):
                result_dict = json_format.MessageToDict(response.result)
                return result_type.model_validate(result_dict)

            raise PluginError(
                error=PluginErrorModel(
                    message="Received invalid response from Unix socket plugin server",
                    plugin_name=self.name,
                )
            )

        except PluginError:
            raise
        except Exception as e:
            logger.exception(e)
            # Keep the original exception as the cause of the wrapped error.
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e