Coverage for mcpgateway / plugins / framework / external / unix / client.py: 99%
138 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/plugins/framework/external/unix/client.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Teryl Taylor
7Unix socket client for external plugins.
9This module provides a high-performance client for communicating with
10external plugins over Unix domain sockets using length-prefixed protobuf
11messages.
13Examples:
14 Create and use a Unix socket plugin client:
16 >>> from mcpgateway.plugins.framework.external.unix.client import UnixSocketExternalPlugin
17 >>> from mcpgateway.plugins.framework.models import PluginConfig, UnixSocketClientConfig
19 >>> config = PluginConfig(
20 ... name="MyPlugin",
21 ... kind="external",
22 ... hooks=["tool_pre_invoke"],
23 ... unix_socket=UnixSocketClientConfig(path="/tmp/plugin.sock"),
24 ... )
25 >>> plugin = UnixSocketExternalPlugin(config)
26 >>> # await plugin.initialize()
27 >>> # result = await plugin.invoke_hook(hook_type, payload, context)
28"""
30# pylint: disable=no-member,no-name-in-module
32# Standard
33import asyncio
34import logging
35from typing import Any, Optional
37# Third-Party
38from google.protobuf import json_format
39from google.protobuf.struct_pb2 import Struct
41# First-Party
42from mcpgateway.plugins.framework.base import Plugin
43from mcpgateway.plugins.framework.errors import convert_exception_to_error, PluginError
44from mcpgateway.plugins.framework.external.grpc.proto import plugin_service_pb2
45from mcpgateway.plugins.framework.external.proto_convert import pydantic_context_to_proto, update_pydantic_context_from_proto
46from mcpgateway.plugins.framework.external.unix.protocol import read_message, write_message_async
47from mcpgateway.plugins.framework.hooks.registry import get_hook_registry
48from mcpgateway.plugins.framework.models import PluginConfig, PluginContext, PluginErrorModel, PluginResult
50logger = logging.getLogger(__name__)
class UnixSocketExternalPlugin(Plugin):
    """External plugin client using raw Unix domain sockets.

    This client provides high-performance IPC for local plugins using
    length-prefixed protobuf messages. It includes automatic reconnection
    with configurable retry logic.

    Replies are not correlated with requests: after writing a request the
    client simply reads the next frame from the socket and treats it as the
    answer. Exchanges are therefore serialized with ``self._lock``, and any
    exchange that is interrupted part-way (timeout, I/O error, cancellation,
    failed verification) causes the connection to be marked unusable so the
    next request starts on a freshly opened stream.

    Attributes:
        config: The plugin configuration.

    Examples:
        >>> from mcpgateway.plugins.framework.models import PluginConfig, UnixSocketClientConfig
        >>> config = PluginConfig(
        ...     name="TestPlugin",
        ...     kind="external",
        ...     hooks=["tool_pre_invoke"],
        ...     unix_socket=UnixSocketClientConfig(path="/tmp/test.sock"),
        ... )
        >>> plugin = UnixSocketExternalPlugin(config)
        >>> plugin.name
        'TestPlugin'
    """

    def __init__(self, config: PluginConfig) -> None:
        """Initialize the Unix socket plugin client.

        No connection is opened here; see ``initialize``.

        Args:
            config: The plugin configuration with unix_socket settings.

        Raises:
            PluginError: If unix_socket configuration is missing.
        """
        super().__init__(config)

        if not config.unix_socket:
            raise PluginError(error=PluginErrorModel(message="The unix_socket section must be defined for Unix socket plugin", plugin_name=config.name))

        self._socket_path = config.unix_socket.path
        self._reconnect_attempts = config.unix_socket.reconnect_attempts
        self._reconnect_delay = config.unix_socket.reconnect_delay
        self._timeout = config.unix_socket.timeout

        self._reader: Optional[asyncio.StreamReader] = None
        self._writer: Optional[asyncio.StreamWriter] = None
        self._connected = False
        # Serializes request/response exchanges on the single shared stream.
        self._lock = asyncio.Lock()

    @property
    def connected(self) -> bool:
        """Check if the client is connected.

        Returns:
            bool: True if connected and writer is active, False otherwise.
        """
        return self._connected and self._writer is not None and not self._writer.is_closing()

    async def _connect(self) -> None:
        """Establish connection to the Unix socket server.

        Raises:
            PluginError: If connection fails.
        """
        try:
            self._reader, self._writer = await asyncio.open_unix_connection(self._socket_path)
            self._connected = True
            logger.debug("Connected to Unix socket: %s", self._socket_path)
        except OSError as e:
            self._connected = False
            raise PluginError(error=PluginErrorModel(message=f"Failed to connect to {self._socket_path}: {e}", plugin_name=self.name)) from e

    async def _disconnect(self) -> None:
        """Close the connection (best effort) and reset connection state.

        Errors raised while closing are logged at debug level and suppressed,
        because the connection is being discarded anyway.
        """
        writer = self._writer
        try:
            if writer is not None:
                writer.close()
                await writer.wait_closed()
        except Exception as e:  # cleanup code, exceptions should not propagate
            logger.debug("Error while closing Unix socket %s: %s", self._socket_path, e)
        finally:
            # Reset even if wait_closed() was cancelled, so stale streams are never reused.
            self._writer = None
            self._reader = None
            self._connected = False

    async def _reconnect(self) -> None:
        """Attempt to reconnect with retry logic.

        Makes up to ``reconnect_attempts`` connection attempts, waiting
        ``reconnect_delay * attempt`` seconds between consecutive attempts.

        Raises:
            PluginError: If all reconnection attempts fail.
        """
        await self._disconnect()

        last_error: Optional[Exception] = None
        for attempt in range(1, self._reconnect_attempts + 1):
            try:
                logger.debug("Reconnection attempt %d/%d to %s", attempt, self._reconnect_attempts, self._socket_path)
                await self._connect()
                logger.info("Reconnected to %s on attempt %d", self._socket_path, attempt)
                return
            except PluginError as e:
                last_error = e
                if attempt < self._reconnect_attempts:
                    # Linear backoff: delay, 2*delay, 3*delay, ...
                    await asyncio.sleep(self._reconnect_delay * attempt)

        raise PluginError(error=PluginErrorModel(message=f"Failed to reconnect after {self._reconnect_attempts} attempts: {last_error}", plugin_name=self.name)) from last_error

    async def _send_request(self, request: plugin_service_pb2.InvokeHookRequest) -> plugin_service_pb2.InvokeHookResponse:
        """Send a request and receive response, with reconnection on failure.

        Connection-level errors (``OSError``, ``asyncio.IncompleteReadError``)
        are retried up to ``reconnect_attempts`` extra times. Timeouts are not
        retried. Errors from parsing the response are not wrapped here.

        Args:
            request: The protobuf request to send.

        Returns:
            The protobuf response.

        Raises:
            PluginError: If sending fails after reconnection attempts, on timeout,
                or if reconnecting fails.
        """
        request_bytes = request.SerializeToString()
        max_attempts = self._reconnect_attempts + 1

        async with self._lock:
            for attempt in range(1, max_attempts + 1):
                if not self.connected:
                    await self._reconnect()

                try:
                    await write_message_async(self._writer, request_bytes)
                    response_bytes = await read_message(self._reader, timeout=self._timeout)
                except asyncio.TimeoutError as e:
                    # Checked before OSError: on Python 3.11+ asyncio.TimeoutError is the
                    # builtin TimeoutError, an OSError subclass. The server may still reply
                    # later; that late frame would be read as the response to the *next*
                    # request, so this stream must not be reused.
                    self._connected = False
                    logger.warning("Request timed out after %s seconds", self._timeout)
                    raise PluginError(error=PluginErrorModel(message=f"Request timed out after {self._timeout}s", plugin_name=self.name)) from e
                except (OSError, asyncio.IncompleteReadError) as e:
                    logger.warning("Connection error on attempt %d: %s", attempt, e)
                    self._connected = False
                    if attempt < max_attempts:
                        await asyncio.sleep(self._reconnect_delay * attempt)
                        continue
                    raise PluginError(error=PluginErrorModel(message=f"Request failed after {max_attempts} attempts: {e}", plugin_name=self.name)) from e
                except asyncio.CancelledError:
                    # Cancelled mid-exchange: the stream position is unknown, force a reconnect.
                    self._connected = False
                    raise

                response = plugin_service_pb2.InvokeHookResponse()
                response.ParseFromString(response_bytes)
                return response

            # Should not reach here
            raise PluginError(error=PluginErrorModel(message="Unexpected state in _send_request", plugin_name=self.name))

    async def initialize(self) -> None:
        """Initialize the plugin client by connecting to the server.

        This establishes the Unix socket connection and then asks the server
        for this plugin's configuration as a best-effort verification step.

        Raises:
            PluginError: If initial connection fails.
        """
        logger.info("Initializing Unix socket plugin: %s -> %s", self.name, self._socket_path)

        try:
            await self._connect()
        except PluginError:
            raise
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e

        await self._verify_remote_config()

        logger.info("Unix socket plugin initialized: %s", self.name)

    async def _verify_remote_config(self) -> None:
        """Request this plugin's config from the server and log whether it was found.

        Any failure is logged as a warning and tolerated. Because a failed
        exchange may leave an unread or late reply on the stream, the
        connection is then marked unusable so the first hook request reconnects.
        """
        async with self._lock:
            try:
                request = plugin_service_pb2.GetPluginConfigRequest(name=self.name)
                request_bytes = request.SerializeToString()

                await write_message_async(self._writer, request_bytes)
                response_bytes = await read_message(self._reader, timeout=self._timeout)

                response = plugin_service_pb2.GetPluginConfigResponse()
                response.ParseFromString(response_bytes)

                if response.found:
                    logger.debug("Remote plugin config verified for %s", self.name)
                else:
                    logger.warning("Plugin %s not found on remote server", self.name)

            except Exception as e:
                logger.warning("Could not verify remote plugin config: %s", e)
                # Continue anyway - the plugin might still work, but on a clean stream.
                self._connected = False

    async def shutdown(self) -> None:
        """Shutdown the plugin client and close the connection."""
        logger.info("Shutting down Unix socket plugin: %s", self.name)
        await self._disconnect()

    async def invoke_hook(
        self,
        hook_type: str,
        payload: Any,
        context: PluginContext,
    ) -> PluginResult:
        """Invoke a plugin hook over the Unix socket connection.

        Args:
            hook_type: The type of hook to invoke (e.g., "tool_pre_invoke").
            payload: The hook payload (will be serialized to protobuf Struct).
            context: The plugin context; updated in place if the server returns one.

        Returns:
            The plugin result, validated as the result type registered for ``hook_type``.

        Raises:
            PluginError: If the hook type is invalid, the payload/context cannot be
                serialized, the server reports an error, or the request fails after retries.
        """
        # Get the result type from the global registry
        registry = get_hook_registry()
        result_type = registry.get_result_type(hook_type)
        if not result_type:
            raise PluginError(error=PluginErrorModel(message=f"Hook type '{hook_type}' not registered in hook registry", plugin_name=self.name))

        try:
            # Convert payload to Struct (still polymorphic). Done inside the try so
            # serialization failures surface as PluginError like every other failure.
            payload_struct = Struct()
            payload_dict = payload.model_dump() if hasattr(payload, "model_dump") else payload
            json_format.ParseDict(payload_dict, payload_struct)

            # Convert context to explicit proto message (faster than Struct)
            context_proto = pydantic_context_to_proto(context)

            request = plugin_service_pb2.InvokeHookRequest(
                hook_type=hook_type,
                plugin_name=self.name,
                payload=payload_struct,
                context=context_proto,
            )

            # Send request and get response
            response = await self._send_request(request)

            # Handle error response
            if response.HasField("error") and response.error.message:
                error = PluginErrorModel(
                    message=response.error.message,
                    plugin_name=response.error.plugin_name or self.name,
                    code=response.error.code,
                    mcp_error_code=response.error.mcp_error_code,
                )
                if response.error.HasField("details"):
                    error.details = json_format.MessageToDict(response.error.details)
                raise PluginError(error=error)

            # Update context if modified (using explicit proto message)
            if response.HasField("context"):
                update_pydantic_context_from_proto(context, response.context)

            # Parse and return result
            if response.HasField("result"):
                result_dict = json_format.MessageToDict(response.result)
                return result_type.model_validate(result_dict)

            raise PluginError(
                error=PluginErrorModel(
                    message="Received invalid response from Unix socket plugin server",
                    plugin_name=self.name,
                )
            )

        except PluginError:
            raise
        except Exception as e:
            logger.exception(e)
            raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) from e