Coverage for mcpgateway / plugins / framework / external / unix / server / server.py: 97%
179 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/plugins/framework/external/unix/server/server.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Teryl Taylor
7Unix socket server for external plugins.
9This module provides a high-performance server that handles plugin requests
10over Unix domain sockets using length-prefixed protobuf messages.
12Examples:
13 Run the server:
15 >>> import asyncio
16 >>> from mcpgateway.plugins.framework.external.unix.server.server import UnixSocketPluginServer
18 >>> async def main():
19 ... server = UnixSocketPluginServer(
20 ... config_path="plugins/config.yaml",
21 ... socket_path="/tmp/plugin.sock",
22 ... )
23 ... await server.start()
24 ... # Server runs until stopped
25 ... await server.stop()
27 >>> # asyncio.run(main())
28"""
30# pylint: disable=no-member,no-name-in-module
32# Standard
33import asyncio
34import logging
35import os
36import signal
37from typing import Optional
39# Third-Party
40from google.protobuf import json_format
41from google.protobuf.struct_pb2 import Struct
43# First-Party
44from mcpgateway.plugins.framework.external.grpc.proto import plugin_service_pb2
45from mcpgateway.plugins.framework.external.mcp.server.server import ExternalPluginServer
46from mcpgateway.plugins.framework.external.proto_convert import (
47 proto_context_to_pydantic,
48 pydantic_context_to_proto,
49)
50from mcpgateway.plugins.framework.external.unix.protocol import ProtocolError, read_message, write_message_async
51from mcpgateway.plugins.framework.models import PluginContext
# Module-level logger named after this module's import path; used by the server and run_server().
logger = logging.getLogger(__name__)
class UnixSocketPluginServer:
    """Unix socket server for handling external plugin requests.

    This server listens on a Unix domain socket and handles plugin
    requests using length-prefixed protobuf messages. It wraps the
    ExternalPluginServer for actual plugin execution.

    Attributes:
        socket_path: Path to the Unix socket file.

    Examples:
        >>> server = UnixSocketPluginServer(
        ...     config_path="plugins/config.yaml",
        ...     socket_path="/tmp/test.sock",
        ... )
        >>> server.socket_path
        '/tmp/test.sock'
    """

    def __init__(
        self,
        config_path: str,
        socket_path: str = "/tmp/mcpgateway-plugins.sock",  # nosec B108 - configurable default
        client_timeout: float = 300.0,
    ) -> None:
        """Initialize the Unix socket server.

        Args:
            config_path: Path to the plugin configuration file.
            socket_path: Path for the Unix socket file.
            client_timeout: Timeout in seconds passed to ``read_message`` for each
                request read from a client; a client that times out is disconnected.
                Defaults to 300 seconds (5 minutes).
        """
        self._config_path = config_path
        self._socket_path = socket_path
        self._client_timeout = client_timeout
        self._plugin_server: Optional[ExternalPluginServer] = None
        self._server: Optional[asyncio.Server] = None
        self._running = False
        # Writers of currently connected clients. stop() closes them so handlers
        # parked in read_message() exit instead of blocking Server.wait_closed().
        self._clients: set[asyncio.StreamWriter] = set()

    @property
    def socket_path(self) -> str:
        """Get the socket path.

        Returns:
            str: The Unix socket file path.
        """
        return self._socket_path

    @property
    def running(self) -> bool:
        """Check if the server is running.

        Returns:
            bool: True if the server is running, False otherwise.
        """
        return self._running

    async def _handle_client(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ) -> None:
        """Serve requests from one client connection until it ends.

        Reads one length-prefixed request at a time, dispatches it and writes
        the response, until the client disconnects, times out, sends a
        malformed frame, or the server is stopped.

        Args:
            reader: The stream reader for the client.
            writer: The stream writer for the client.
        """
        peer = writer.get_extra_info("peername") or "unknown"
        logger.debug("Client connected: %s", peer)
        self._clients.add(writer)

        try:
            while self._running:
                try:
                    data = await read_message(reader, timeout=self._client_timeout)
                except asyncio.TimeoutError:
                    logger.debug("Client %s timed out", peer)
                    break
                except asyncio.IncompleteReadError:
                    # Client disconnected (or stop() closed the connection)
                    break
                except ProtocolError as e:
                    logger.warning("Protocol error from %s: %s", peer, e)
                    break

                response_bytes = await self._handle_message(data)

                try:
                    await write_message_async(writer, response_bytes)
                except OSError:  # covers BrokenPipeError / ConnectionResetError
                    logger.debug("Client %s disconnected during write", peer)
                    break

        except Exception as e:
            logger.exception("Error handling client %s: %s", peer, e)
        finally:
            self._clients.discard(writer)
            logger.debug("Client disconnected: %s", peer)
            try:
                writer.close()
                await writer.wait_closed()
            except Exception:  # nosec B110 - cleanup code, exceptions should not propagate
                pass

    @staticmethod
    def _try_parse(message_cls, data: bytes):
        """Parse ``data`` as ``message_cls``, returning None if it does not parse.

        Args:
            message_cls: A generated protobuf message class.
            data: The raw message bytes.

        Returns:
            A populated ``message_cls`` instance, or None on a parse failure.
        """
        message = message_cls()
        try:
            message.ParseFromString(data)
        except Exception:  # protobuf parse attempt; caller tries the next message type
            return None
        return message

    async def _handle_message(self, data: bytes) -> bytes:
        """Handle a single message and return the response.

        The wire format carries no type tag, so the message type is inferred
        by attempting to parse each known request type in turn. Only the parse
        is guarded: an error inside a handler is not mistaken for "wrong type".

        Args:
            data: The raw message bytes.

        Returns:
            The serialized response bytes.
        """
        # InvokeHookRequest first (most common)
        invoke_request = self._try_parse(plugin_service_pb2.InvokeHookRequest, data)
        if invoke_request is not None and invoke_request.hook_type and invoke_request.plugin_name:
            return await self._handle_invoke_hook(invoke_request)

        config_request = self._try_parse(plugin_service_pb2.GetPluginConfigRequest, data)
        if config_request is not None and config_request.name:
            return await self._handle_get_plugin_config(config_request)

        # GetPluginConfigsRequest has no required fields, so it is recognized
        # only when the payload is empty or near-empty.
        configs_request = self._try_parse(plugin_service_pb2.GetPluginConfigsRequest, data)
        if configs_request is not None and len(data) <= 2:
            return await self._handle_get_plugin_configs(configs_request)

        # Unknown message type
        logger.warning("Unknown message type, length=%d", len(data))
        error_response = plugin_service_pb2.InvokeHookResponse()
        error_response.error.message = "Unknown message type"
        error_response.error.code = "UNKNOWN_MESSAGE"
        return error_response.SerializeToString()

    @staticmethod
    def _error_field(error_dict, key: str, default):
        """Return ``error_dict[key]``, or ``default`` when it is missing or None.

        Proto scalar fields reject None, so an explicitly-None value (e.g. an
        optional ``code``) must be replaced rather than assigned.

        Args:
            error_dict: Mapping describing a plugin error.
            key: Field name to read.
            default: Fallback value.

        Returns:
            The field value, or ``default``.
        """
        value = error_dict.get(key)
        return default if value is None else value

    async def _handle_invoke_hook(
        self,
        request: plugin_service_pb2.InvokeHookRequest,
    ) -> bytes:
        """Handle an InvokeHook request.

        Args:
            request: The InvokeHookRequest.

        Returns:
            Serialized InvokeHookResponse. Any failure is reported in the
            response's ``error`` field rather than raised.
        """
        response = plugin_service_pb2.InvokeHookResponse(plugin_name=request.plugin_name)

        try:
            # Convert payload to dict (still polymorphic)
            payload_dict = json_format.MessageToDict(request.payload)

            # Convert explicit PluginContext proto directly to Pydantic
            context_pydantic = proto_context_to_pydantic(request.context)

            # Invoke the hook (passing Pydantic context directly, no dict conversion)
            result = await self._plugin_server.invoke_hook(
                hook_type=request.hook_type,
                plugin_name=request.plugin_name,
                payload=payload_dict,
                context=context_pydantic,
            )

            if "error" in result:
                error_obj = result["error"]
                error_dict = error_obj.model_dump() if hasattr(error_obj, "model_dump") else error_obj

                response.error.message = self._error_field(error_dict, "message", "Unknown error")
                response.error.plugin_name = self._error_field(error_dict, "plugin_name", "unknown")
                response.error.code = self._error_field(error_dict, "code", "")
                response.error.mcp_error_code = self._error_field(error_dict, "mcp_error_code", -32603)
            else:
                if "result" in result:
                    json_format.ParseDict(result["result"], response.result)
                if "context" in result:
                    ctx = result["context"]
                    # Handle both Pydantic (optimized path) and dict (MCP compat)
                    if not isinstance(ctx, PluginContext):
                        ctx = PluginContext.model_validate(ctx)
                    response.context.CopyFrom(pydantic_context_to_proto(ctx))

        except Exception as e:
            logger.exception("Error invoking hook: %s", e)
            response.error.message = str(e)
            response.error.code = "INTERNAL_ERROR"
            response.error.mcp_error_code = -32603

        return response.SerializeToString()

    async def _handle_get_plugin_config(
        self,
        request: plugin_service_pb2.GetPluginConfigRequest,
    ) -> bytes:
        """Handle a GetPluginConfig request.

        Args:
            request: The GetPluginConfigRequest.

        Returns:
            Serialized GetPluginConfigResponse; ``found`` is False when the
            plugin is unknown or the lookup failed.
        """
        response = plugin_service_pb2.GetPluginConfigResponse()

        try:
            config = await self._plugin_server.get_plugin_config(request.name)

            if config:
                response.found = True
                json_format.ParseDict(config, response.config)
            else:
                response.found = False

        except Exception as e:
            logger.exception("Error getting plugin config: %s", e)
            response.found = False

        return response.SerializeToString()

    async def _handle_get_plugin_configs(
        self,
        _request: plugin_service_pb2.GetPluginConfigsRequest,
    ) -> bytes:
        """Handle a GetPluginConfigs request.

        Args:
            _request: The GetPluginConfigsRequest (unused, included for API consistency).

        Returns:
            Serialized GetPluginConfigsResponse. On error, the failure is logged
            and whatever configs were already added are returned.
        """
        response = plugin_service_pb2.GetPluginConfigsResponse()

        try:
            configs = await self._plugin_server.get_plugin_configs()

            for config in configs:
                config_struct = Struct()
                json_format.ParseDict(config, config_struct)
                response.configs.append(config_struct)

        except Exception as e:
            logger.exception("Error getting plugin configs: %s", e)

        return response.SerializeToString()

    @staticmethod
    def _path_is_socket(path: str) -> Optional[bool]:
        """Classify what currently exists at ``path`` (following symlinks).

        Args:
            path: Filesystem path to inspect.

        Returns:
            None if nothing exists there, True if it is a socket, False otherwise.
        """
        import stat  # pylint: disable=import-outside-toplevel

        try:
            return stat.S_ISSOCK(os.stat(path).st_mode)
        except FileNotFoundError:
            return None

    async def start(self) -> None:
        """Start the Unix socket server.

        This initializes the plugin server and starts listening for
        connections on the Unix socket.

        Raises:
            FileExistsError: If a non-socket file already exists at ``socket_path``.
        """
        logger.info("Starting Unix socket plugin server on %s", self._socket_path)

        # Remove a stale socket left by a previous run, but never delete an
        # arbitrary file that happens to sit at the configured path.
        is_socket = self._path_is_socket(self._socket_path)
        if is_socket is False:
            raise FileExistsError(f"{self._socket_path} exists and is not a socket; refusing to remove it")
        if is_socket:
            os.unlink(self._socket_path)

        # Initialize the plugin server
        self._plugin_server = ExternalPluginServer(config_path=self._config_path)
        await self._plugin_server.initialize()

        try:
            self._server = await asyncio.start_unix_server(
                self._handle_client,
                path=self._socket_path,
            )
        except Exception:
            # Don't leak an initialized plugin server when binding fails.
            await self._plugin_server.shutdown()
            self._plugin_server = None
            raise

        # Set restrictive permissions on the socket file (owner read/write only).
        # NOTE(review): between bind and chmod the socket has umask-derived
        # permissions; if that window matters, bind inside a 0700 directory.
        if os.path.exists(self._socket_path):
            os.chmod(self._socket_path, 0o600)

        self._running = True
        logger.info("Unix socket plugin server started on %s", self._socket_path)

    async def serve_forever(self) -> None:
        """Serve requests until stopped.

        Raises:
            RuntimeError: If the server has not been started.
        """
        if not self._server:
            raise RuntimeError("Server not started. Call start() first.")

        async with self._server:
            await self._server.serve_forever()

    async def stop(self) -> None:
        """Stop the Unix socket server, close client connections and remove the socket file."""
        logger.info("Stopping Unix socket plugin server")
        self._running = False

        if self._server:
            self._server.close()
            # Close live connections so their handlers (possibly waiting in
            # read_message() with a long timeout) exit; since Python 3.12,
            # Server.wait_closed() waits for every connection to finish.
            for writer in list(self._clients):
                writer.close()
            await self._server.wait_closed()
            self._server = None

        if self._plugin_server:
            await self._plugin_server.shutdown()
            self._plugin_server = None

        # Clean up the socket file (only if it is actually a socket)
        try:
            if self._path_is_socket(self._socket_path):
                os.unlink(self._socket_path)
        except OSError:
            pass

        logger.info("Unix socket plugin server stopped")
async def run_server(
    config_path: str,
    socket_path: str = "/tmp/mcpgateway-plugins.sock",  # nosec B108 - configurable default
) -> None:
    """Run the Unix socket server until interrupted.

    Installs SIGINT/SIGTERM handlers, starts the server, prints ``READY`` on
    stdout once listening, and waits for a shutdown signal. The server is
    stopped and the signal handlers are removed even if startup fails or the
    task is cancelled.

    Args:
        config_path: Path to the plugin configuration file.
        socket_path: Path for the Unix socket file.
    """
    server = UnixSocketPluginServer(config_path=config_path, socket_path=socket_path)

    loop = asyncio.get_running_loop()
    stop_event = asyncio.Event()

    def signal_handler() -> None:
        """Handle SIGINT/SIGTERM by setting the stop event."""
        logger.info("Received shutdown signal")
        stop_event.set()

    shutdown_signals = (signal.SIGINT, signal.SIGTERM)
    for sig in shutdown_signals:
        loop.add_signal_handler(sig, signal_handler)

    try:
        await server.start()

        # Signal ready (for parent process coordination)
        print("READY", flush=True)

        await stop_event.wait()
    finally:
        try:
            await server.stop()
        finally:
            # Restore default signal handling only after cleanup, so a second
            # signal during stop() just re-sets the event instead of aborting it.
            for sig in shutdown_signals:
                loop.remove_signal_handler(sig)