Coverage for mcpgateway / plugins / framework / external / unix / client.py: 99%

138 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/plugins/framework/external/unix/client.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Teryl Taylor 

6 

7Unix socket client for external plugins. 

8 

9This module provides a high-performance client for communicating with 

10external plugins over Unix domain sockets using length-prefixed protobuf 

11messages. 

12 

13Examples: 

14 Create and use a Unix socket plugin client: 

15 

16 >>> from mcpgateway.plugins.framework.external.unix.client import UnixSocketExternalPlugin 

17 >>> from mcpgateway.plugins.framework.models import PluginConfig, UnixSocketClientConfig 

18 

19 >>> config = PluginConfig( 

20 ... name="MyPlugin", 

21 ... kind="external", 

22 ... hooks=["tool_pre_invoke"], 

23 ... unix_socket=UnixSocketClientConfig(path="/tmp/plugin.sock"), 

24 ... ) 

25 >>> plugin = UnixSocketExternalPlugin(config) 

26 >>> # await plugin.initialize() 

27 >>> # result = await plugin.invoke_hook(hook_type, payload, context) 

28""" 

29# pylint: disable=no-member,no-name-in-module 

30 

31# Standard 

32import asyncio 

33import logging 

34from typing import Any, Optional 

35 

36# Third-Party 

37from google.protobuf import json_format 

38from google.protobuf.struct_pb2 import Struct 

39 

40# First-Party 

41from mcpgateway.plugins.framework.base import Plugin 

42from mcpgateway.plugins.framework.errors import convert_exception_to_error, PluginError 

43from mcpgateway.plugins.framework.external.grpc.proto import plugin_service_pb2 

44from mcpgateway.plugins.framework.external.proto_convert import pydantic_context_to_proto, update_pydantic_context_from_proto 

45from mcpgateway.plugins.framework.external.unix.protocol import read_message, write_message_async 

46from mcpgateway.plugins.framework.hooks.registry import get_hook_registry 

47from mcpgateway.plugins.framework.models import PluginConfig, PluginContext, PluginErrorModel, PluginResult 

48 

49logger = logging.getLogger(__name__) 

50 

51 

class UnixSocketExternalPlugin(Plugin):
    """External plugin client using raw Unix domain sockets.

    This client provides high-performance IPC for local plugins using
    length-prefixed protobuf messages. It includes automatic reconnection
    with configurable retry logic.

    All request/response exchanges on the socket are serialized through
    ``self._lock``. Replies are matched to requests purely by order (this
    client sends no request identifier), so any exchange that ends without
    its reply being consumed (timeout, cancellation) discards the connection
    rather than letting a later request read a stale reply.

    Attributes:
        config: The plugin configuration.

    Examples:
        >>> from mcpgateway.plugins.framework.models import PluginConfig, UnixSocketClientConfig
        >>> config = PluginConfig(
        ...     name="TestPlugin",
        ...     kind="external",
        ...     hooks=["tool_pre_invoke"],
        ...     unix_socket=UnixSocketClientConfig(path="/tmp/test.sock"),
        ... )
        >>> plugin = UnixSocketExternalPlugin(config)
        >>> plugin.name
        'TestPlugin'
    """

    def __init__(self, config: PluginConfig) -> None:
        """Initialize the Unix socket plugin client.

        No connection is opened here; see ``initialize``.

        Args:
            config: The plugin configuration with unix_socket settings.

        Raises:
            PluginError: If unix_socket configuration is missing.
        """
        super().__init__(config)

        if not config.unix_socket:
            raise PluginError(error=PluginErrorModel(message="The unix_socket section must be defined for Unix socket plugin", plugin_name=config.name))

        self._socket_path = config.unix_socket.path
        self._reconnect_attempts = config.unix_socket.reconnect_attempts
        self._reconnect_delay = config.unix_socket.reconnect_delay
        self._timeout = config.unix_socket.timeout

        self._reader: Optional[asyncio.StreamReader] = None
        self._writer: Optional[asyncio.StreamWriter] = None
        self._connected = False
        # Serializes socket exchanges so at most one request/response is in flight.
        self._lock = asyncio.Lock()

    @property
    def connected(self) -> bool:
        """Check if the client is connected.

        Returns:
            bool: True if connected and writer is active, False otherwise.
        """
        return self._connected and self._writer is not None and not self._writer.is_closing()

    async def _connect(self) -> None:
        """Establish connection to the Unix socket server.

        Raises:
            PluginError: If connection fails.
        """
        try:
            self._reader, self._writer = await asyncio.open_unix_connection(self._socket_path)
            self._connected = True
            logger.debug("Connected to Unix socket: %s", self._socket_path)
        except OSError as e:
            self._connected = False
            raise PluginError(error=PluginErrorModel(message=f"Failed to connect to {self._socket_path}: {e}", plugin_name=self.name)) from e

    def _drop_connection(self) -> Optional[asyncio.StreamWriter]:
        """Reset connection state and close the transport without awaiting shutdown.

        This is synchronous so it can be called from exception handlers,
        including on task cancellation, to discard a connection whose byte
        stream may no longer line up with the request/response sequence.

        Returns:
            The writer that was closed, or None if there was no connection.
        """
        writer = self._writer
        # Clear state first so the object is consistent even if close() fails.
        self._writer = None
        self._reader = None
        self._connected = False
        if writer is not None:
            try:
                writer.close()
            except Exception as e:  # nosec B110 - best-effort cleanup, must not propagate
                logger.debug("Error closing Unix socket %s: %s", self._socket_path, e)
        return writer

    async def _disconnect(self) -> None:
        """Close the connection and wait for the transport to shut down (best-effort).

        Errors raised while closing are logged at debug level and not propagated.
        """
        writer = self._drop_connection()
        if writer is None:
            return
        try:
            await writer.wait_closed()
        except Exception as e:  # nosec B110 - cleanup code, exceptions should not propagate
            logger.debug("Error waiting for Unix socket %s to close: %s", self._socket_path, e)

    async def _reconnect(self) -> None:
        """Attempt to reconnect with retry logic.

        Closes any existing connection, then tries up to ``reconnect_attempts``
        times, sleeping ``reconnect_delay * attempt`` seconds between attempts.

        Raises:
            PluginError: If all reconnection attempts fail.
        """
        await self._disconnect()

        last_error: Optional[Exception] = None
        for attempt in range(1, self._reconnect_attempts + 1):
            try:
                logger.debug("Reconnection attempt %d/%d to %s", attempt, self._reconnect_attempts, self._socket_path)
                await self._connect()
                logger.info("Reconnected to %s on attempt %d", self._socket_path, attempt)
                return
            except PluginError as e:
                last_error = e
                if attempt < self._reconnect_attempts:
                    await asyncio.sleep(self._reconnect_delay * attempt)  # Linear backoff

        raise PluginError(error=PluginErrorModel(message=f"Failed to reconnect after {self._reconnect_attempts} attempts: {last_error}", plugin_name=self.name))

    async def _send_request(self, request: plugin_service_pb2.InvokeHookRequest) -> plugin_service_pb2.InvokeHookResponse:
        """Send a request and receive its response, reconnecting on connection errors.

        The exchange runs under ``self._lock``. Connection errors (``OSError`` or
        an incomplete read) are retried up to ``reconnect_attempts`` additional
        times with a linearly increasing delay. Timeouts are not retried.

        Args:
            request: The protobuf request to send.

        Returns:
            The protobuf response.

        Raises:
            PluginError: If the request times out, reconnection fails, or sending
                fails after all retry attempts.
        """
        request_bytes = request.SerializeToString()

        async with self._lock:
            for attempt in range(self._reconnect_attempts + 1):
                try:
                    if not self.connected:
                        await self._reconnect()

                    # Send request
                    await write_message_async(self._writer, request_bytes)

                    # Read response
                    response_bytes = await read_message(self._reader, timeout=self._timeout)

                    # Parse response
                    response = plugin_service_pb2.InvokeHookResponse()
                    response.ParseFromString(response_bytes)
                    return response

                except asyncio.TimeoutError as e:
                    logger.warning("Request timed out after %s seconds", self._timeout)
                    # The server may still write the reply to this request later. Drop the
                    # connection so a subsequent request cannot read it as its own reply.
                    self._drop_connection()
                    raise PluginError(error=PluginErrorModel(message=f"Request timed out after {self._timeout}s", plugin_name=self.name)) from e

                except asyncio.CancelledError:
                    # Cancelled mid-exchange: the request may be on the wire with its reply
                    # unread, so the stream position is unknown. Discard it and propagate.
                    self._drop_connection()
                    raise

                # BrokenPipeError / ConnectionResetError are OSError subclasses.
                except (OSError, asyncio.IncompleteReadError) as e:
                    logger.warning("Connection error on attempt %d: %s", attempt + 1, e)
                    # Close the broken stream now instead of leaving it half-open.
                    self._drop_connection()

                    if attempt < self._reconnect_attempts:
                        await asyncio.sleep(self._reconnect_delay * (attempt + 1))
                        continue
                    raise PluginError(error=PluginErrorModel(message=f"Request failed after {self._reconnect_attempts + 1} attempts: {e}", plugin_name=self.name)) from e

        # Defensive: the loop above always returns or raises.
        raise PluginError(error=PluginErrorModel(message="Unexpected state in _send_request", plugin_name=self.name))

    async def initialize(self) -> None:
        """Initialize the plugin client by connecting to the server.

        Establishes the Unix socket connection, then sends a
        ``GetPluginConfigRequest`` as a best-effort check that the remote plugin
        exists. A failed check is logged and does not fail initialization.

        Raises:
            PluginError: If initial connection fails.
        """
        logger.info("Initializing Unix socket plugin: %s -> %s", self.name, self._socket_path)

        # Hold the request lock so the verification exchange cannot interleave
        # with a concurrent hook request on the same socket.
        async with self._lock:
            try:
                await self._connect()
            except PluginError:
                raise
            except Exception as e:
                logger.exception(e)
                raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e

            # Optionally fetch remote config to verify connection
            try:
                request = plugin_service_pb2.GetPluginConfigRequest(name=self.name)
                request_bytes = request.SerializeToString()

                await write_message_async(self._writer, request_bytes)
                response_bytes = await read_message(self._reader, timeout=self._timeout)

                response = plugin_service_pb2.GetPluginConfigResponse()
                response.ParseFromString(response_bytes)

                if response.found:
                    logger.debug("Remote plugin config verified for %s", self.name)
                else:
                    logger.warning("Plugin %s not found on remote server", self.name)

            except asyncio.CancelledError:
                # Stream position unknown after cancellation; discard it and propagate.
                self._drop_connection()
                raise

            except (asyncio.TimeoutError, OSError, asyncio.IncompleteReadError) as e:
                logger.warning("Could not verify remote plugin config: %s", e)
                # After a timeout a late reply could still arrive on this socket, and an
                # I/O error means the stream is broken: replace the connection so the first
                # hook request starts on a clean stream.
                await self._disconnect()
                try:
                    await self._connect()
                except PluginError as reconnect_error:
                    # Left disconnected; _send_request reconnects on demand.
                    logger.warning("Could not re-establish connection to %s after failed verification: %s", self._socket_path, reconnect_error)

            except Exception as e:
                logger.warning("Could not verify remote plugin config: %s", e)
                # Continue anyway - the plugin might still work

        logger.info("Unix socket plugin initialized: %s", self.name)

    async def shutdown(self) -> None:
        """Shutdown the plugin client and close the connection."""
        logger.info("Shutting down Unix socket plugin: %s", self.name)
        await self._disconnect()

    async def invoke_hook(
        self,
        hook_type: str,
        payload: Any,
        context: PluginContext,
    ) -> PluginResult:
        """Invoke a plugin hook over the Unix socket connection.

        The payload is sent as a protobuf ``Struct`` and the context as an
        explicit proto message. If the server returns a context, it is written
        back into ``context`` in place.

        Args:
            hook_type: The type of hook to invoke (e.g., "tool_pre_invoke").
            payload: The hook payload (will be serialized to protobuf Struct).
            context: The plugin context.

        Returns:
            The plugin result, validated as the result type registered for ``hook_type``.

        Raises:
            PluginError: If the request fails after retries, the hook type is invalid,
                the server reports an error, or the response carries no result.
        """
        # Get the result type from the global registry
        registry = get_hook_registry()
        result_type = registry.get_result_type(hook_type)
        if not result_type:
            raise PluginError(error=PluginErrorModel(message=f"Hook type '{hook_type}' not registered in hook registry", plugin_name=self.name))

        # Convert payload to Struct (still polymorphic)
        payload_struct = Struct()
        if hasattr(payload, "model_dump"):
            json_format.ParseDict(payload.model_dump(), payload_struct)
        else:
            json_format.ParseDict(payload, payload_struct)

        # Convert context to explicit proto message (faster than Struct)
        context_proto = pydantic_context_to_proto(context)

        # Build request
        request = plugin_service_pb2.InvokeHookRequest(
            hook_type=hook_type,
            plugin_name=self.name,
            payload=payload_struct,
            context=context_proto,
        )

        try:
            # Send request and get response
            response = await self._send_request(request)

            # Handle error response
            if response.HasField("error") and response.error.message:
                error = PluginErrorModel(
                    message=response.error.message,
                    plugin_name=response.error.plugin_name or self.name,
                    code=response.error.code,
                    mcp_error_code=response.error.mcp_error_code,
                )
                if response.error.HasField("details"):
                    error.details = json_format.MessageToDict(response.error.details)
                raise PluginError(error=error)

            # Update context if modified (using explicit proto message)
            if response.HasField("context"):
                update_pydantic_context_from_proto(context, response.context)

            # Parse and return result
            if response.HasField("result"):
                result_dict = json_format.MessageToDict(response.result)
                return result_type.model_validate(result_dict)

            raise PluginError(
                error=PluginErrorModel(
                    message="Received invalid response from Unix socket plugin server",
                    plugin_name=self.name,
                )
            )

        except PluginError:
            raise
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e