Coverage for mcpgateway / plugins / framework / external / unix / server / server.py: 95%
179 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/plugins/framework/external/unix/server/server.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Teryl Taylor
7Unix socket server for external plugins.
9This module provides a high-performance server that handles plugin requests
10over Unix domain sockets using length-prefixed protobuf messages.
12Examples:
13 Run the server:
15 >>> import asyncio
16 >>> from mcpgateway.plugins.framework.external.unix.server.server import UnixSocketPluginServer
18 >>> async def main():
19 ... server = UnixSocketPluginServer(
20 ... config_path="plugins/config.yaml",
21 ... socket_path="/tmp/plugin.sock",
22 ... )
23 ... await server.start()
24 ... # Server runs until stopped
25 ... await server.stop()
27 >>> # asyncio.run(main())
28"""
29# pylint: disable=no-member,no-name-in-module
31# Standard
32import asyncio
33import logging
34import os
35import signal
36from typing import Optional
38# Third-Party
39from google.protobuf import json_format
40from google.protobuf.struct_pb2 import Struct
42# First-Party
43from mcpgateway.plugins.framework.external.grpc.proto import plugin_service_pb2
44from mcpgateway.plugins.framework.external.mcp.server.server import ExternalPluginServer
45from mcpgateway.plugins.framework.external.proto_convert import (
46 proto_context_to_pydantic,
47 pydantic_context_to_proto,
48)
49from mcpgateway.plugins.framework.external.unix.protocol import ProtocolError, read_message, write_message_async
50from mcpgateway.plugins.framework.models import PluginContext
# Module-level logger, named after this module so log output can be filtered per component.
logger = logging.getLogger(name=__name__)
class UnixSocketPluginServer:
    """Unix socket server for handling external plugin requests.

    This server listens on a Unix domain socket and handles plugin
    requests using length-prefixed protobuf messages. It wraps the
    ExternalPluginServer for actual plugin execution.

    Attributes:
        socket_path: Path to the Unix socket file.

    Examples:
        >>> server = UnixSocketPluginServer(
        ...     config_path="plugins/config.yaml",
        ...     socket_path="/tmp/test.sock",
        ... )
        >>> server.socket_path
        '/tmp/test.sock'
        >>> server.running
        False
    """

    def __init__(
        self,
        config_path: str,
        socket_path: str = "/tmp/mcpgateway-plugins.sock",  # nosec B108 - configurable default
    ) -> None:
        """Initialize the Unix socket server.

        Args:
            config_path: Path to the plugin configuration file.
            socket_path: Path for the Unix socket file.
        """
        self._config_path = config_path
        self._socket_path = socket_path
        self._plugin_server: Optional[ExternalPluginServer] = None
        self._server: Optional[asyncio.Server] = None
        self._running = False
        # Writers of currently connected clients, tracked so stop() can close them.
        self._clients: set[asyncio.StreamWriter] = set()

    @property
    def socket_path(self) -> str:
        """Get the socket path.

        Returns:
            str: The Unix socket file path.
        """
        return self._socket_path

    @property
    def running(self) -> bool:
        """Check if the server is running.

        Returns:
            bool: True if the server is running, False otherwise.
        """
        return self._running

    def _require_plugin_server(self) -> ExternalPluginServer:
        """Return the wrapped plugin server, failing clearly if it is not initialized.

        Returns:
            The initialized ExternalPluginServer.

        Raises:
            RuntimeError: If start() has not been called or stop() has already run.
        """
        if self._plugin_server is None:
            raise RuntimeError("Plugin server is not initialized. Call start() first.")
        return self._plugin_server

    @staticmethod
    def _try_parse(message_cls, data: bytes):
        """Attempt to decode ``data`` as an instance of ``message_cls``.

        Args:
            message_cls: A protobuf message class to instantiate and parse into.
            data: The raw message bytes.

        Returns:
            The parsed message, or None if parsing failed.
        """
        try:
            message = message_cls()
            message.ParseFromString(data)
        except Exception:  # nosec B110 - protobuf parse attempt, caller tries the next message type
            return None
        return message

    @staticmethod
    def _value_or_default(mapping, key, default):
        """Look up ``key`` in ``mapping``, using ``default`` when it is missing or None.

        Args:
            mapping: A dict-like object supporting ``.get``.
            key: The key to look up.
            default: Value returned when the key is absent or maps to None.

        Returns:
            The stored value, or ``default``.
        """
        value = mapping.get(key)
        return default if value is None else value

    async def _handle_client(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ) -> None:
        """Handle a client connection.

        Serves requests on the connection one at a time until the server stops,
        the client disconnects or times out, or a protocol error occurs. The
        connection is always closed on exit.

        Args:
            reader: The stream reader for the client.
            writer: The stream writer for the client.
        """
        peer = writer.get_extra_info("peername") or "unknown"
        logger.debug("Client connected: %s", peer)
        self._clients.add(writer)

        try:
            while self._running:
                try:
                    # Wait up to 5 minutes for the next request; on timeout the connection is closed.
                    data = await read_message(reader, timeout=300.0)
                except asyncio.TimeoutError:
                    logger.debug("Client %s timed out", peer)
                    break
                except (asyncio.IncompleteReadError, ConnectionResetError):
                    # Client disconnected (cleanly or by connection reset)
                    break
                except ProtocolError as e:
                    logger.warning("Protocol error from %s: %s", peer, e)
                    break

                # Determine message type and handle
                response_bytes = await self._handle_message(data)

                # Send response
                try:
                    await write_message_async(writer, response_bytes)
                except OSError:  # includes BrokenPipeError and ConnectionResetError
                    logger.debug("Client %s disconnected during write", peer)
                    break

        except Exception as e:
            logger.exception("Error handling client %s: %s", peer, e)
        finally:
            self._clients.discard(writer)
            logger.debug("Client disconnected: %s", peer)
            try:
                writer.close()
                await writer.wait_closed()
            except Exception:  # nosec B110 - cleanup code, exceptions should not propagate
                pass

    async def _handle_message(self, data: bytes) -> bytes:
        """Handle a single message and return the response.

        The wire format carries no type tag, so the message type is inferred by
        trying each request type in turn. Only the decoding step is guarded:
        errors raised while *handling* a recognized request propagate instead of
        being misreported as an unknown message type.

        Args:
            data: The raw message bytes.

        Returns:
            The serialized response bytes.
        """
        # Try to parse as InvokeHookRequest first (most common)
        invoke_request = self._try_parse(plugin_service_pb2.InvokeHookRequest, data)
        if invoke_request is not None and invoke_request.hook_type and invoke_request.plugin_name:
            return await self._handle_invoke_hook(invoke_request)

        # Try GetPluginConfigRequest
        config_request = self._try_parse(plugin_service_pb2.GetPluginConfigRequest, data)
        if config_request is not None and config_request.name:
            return await self._handle_get_plugin_config(config_request)

        # Try GetPluginConfigsRequest. It has no required fields, so it is
        # recognized by its empty or near-empty encoding.
        configs_request = self._try_parse(plugin_service_pb2.GetPluginConfigsRequest, data)
        if configs_request is not None and len(data) <= 2:
            return await self._handle_get_plugin_configs(configs_request)

        # Unknown message type
        logger.warning("Unknown message type, length=%d", len(data))
        error_response = plugin_service_pb2.InvokeHookResponse()
        error_response.error.message = "Unknown message type"
        error_response.error.code = "UNKNOWN_MESSAGE"
        return error_response.SerializeToString()

    async def _handle_invoke_hook(
        self,
        request: plugin_service_pb2.InvokeHookRequest,
    ) -> bytes:
        """Handle an InvokeHook request.

        Any exception raised while invoking the hook is reported to the client
        as an ``INTERNAL_ERROR`` in the response rather than raised.

        Args:
            request: The InvokeHookRequest.

        Returns:
            Serialized InvokeHookResponse.
        """
        response = plugin_service_pb2.InvokeHookResponse(plugin_name=request.plugin_name)

        try:
            plugin_server = self._require_plugin_server()

            # Convert payload to dict (still polymorphic)
            payload_dict = json_format.MessageToDict(request.payload)

            # Convert explicit PluginContext proto directly to Pydantic
            context_pydantic = proto_context_to_pydantic(request.context)

            # Invoke the hook (passing Pydantic context directly, no dict conversion)
            result = await plugin_server.invoke_hook(
                hook_type=request.hook_type,
                plugin_name=request.plugin_name,
                payload=payload_dict,
                context=context_pydantic,
            )

            # Build response
            if "error" in result:
                error_obj = result["error"]
                error_dict = error_obj.model_dump() if hasattr(error_obj, "model_dump") else error_obj

                # Missing *or None* fields fall back to defaults; assigning None to a
                # protobuf scalar field would raise.
                response.error.message = self._value_or_default(error_dict, "message", "Unknown error")
                response.error.plugin_name = self._value_or_default(error_dict, "plugin_name", "unknown")
                response.error.code = self._value_or_default(error_dict, "code", "")
                response.error.mcp_error_code = self._value_or_default(error_dict, "mcp_error_code", -32603)
            else:
                if "result" in result:
                    json_format.ParseDict(result["result"], response.result)
                if "context" in result:
                    ctx = result["context"]
                    # Handle both Pydantic (optimized path) and dict (MCP compat)
                    if not isinstance(ctx, PluginContext):
                        ctx = PluginContext.model_validate(ctx)
                    response.context.CopyFrom(pydantic_context_to_proto(ctx))

        except Exception as e:
            logger.exception("Error invoking hook: %s", e)
            response.error.message = str(e)
            response.error.plugin_name = request.plugin_name
            response.error.code = "INTERNAL_ERROR"
            response.error.mcp_error_code = -32603

        return response.SerializeToString()

    async def _handle_get_plugin_config(
        self,
        request: plugin_service_pb2.GetPluginConfigRequest,
    ) -> bytes:
        """Handle a GetPluginConfig request.

        Args:
            request: The GetPluginConfigRequest.

        Returns:
            Serialized GetPluginConfigResponse; ``found`` is False when the plugin
            is unknown or the lookup failed.
        """
        response = plugin_service_pb2.GetPluginConfigResponse()

        try:
            config = await self._require_plugin_server().get_plugin_config(request.name)

            if config:
                response.found = True
                json_format.ParseDict(config, response.config)
            else:
                response.found = False

        except Exception as e:
            logger.exception("Error getting plugin config: %s", e)
            response.found = False

        return response.SerializeToString()

    async def _handle_get_plugin_configs(
        self,
        _request: plugin_service_pb2.GetPluginConfigsRequest,
    ) -> bytes:
        """Handle a GetPluginConfigs request.

        Args:
            _request: The GetPluginConfigsRequest (unused, included for API consistency).

        Returns:
            Serialized GetPluginConfigsResponse. On error, it contains whatever
            configs were converted before the failure.
        """
        response = plugin_service_pb2.GetPluginConfigsResponse()

        try:
            configs = await self._require_plugin_server().get_plugin_configs()

            for config in configs:
                config_struct = Struct()
                json_format.ParseDict(config, config_struct)
                response.configs.append(config_struct)

        except Exception as e:
            logger.exception("Error getting plugin configs: %s", e)

        return response.SerializeToString()

    async def start(self) -> None:
        """Start the Unix socket server.

        This initializes the plugin server and starts listening for
        connections on the Unix socket. The socket file's permissions are
        restricted to the owner *before* connections are accepted.
        """
        logger.info("Starting Unix socket plugin server on %s", self._socket_path)

        # Clean up an old socket file. lexists() also catches dangling symlinks,
        # which exists() would miss and which would make bind() fail.
        if os.path.lexists(self._socket_path):
            try:
                os.unlink(self._socket_path)
            except FileNotFoundError:
                pass  # removed concurrently; nothing to do

        # Initialize the plugin server
        self._plugin_server = ExternalPluginServer(config_path=self._config_path)
        await self._plugin_server.initialize()

        # Create (bind) the Unix socket server without accepting connections yet,
        # so no client can connect while the socket still has umask-derived permissions.
        self._server = await asyncio.start_unix_server(
            self._handle_client,
            path=self._socket_path,
            start_serving=False,
        )

        # Set restrictive permissions on the socket file (owner read/write only)
        if os.path.exists(self._socket_path):
            os.chmod(self._socket_path, 0o600)

        # Mark running before accepting, so client handlers never observe _running == False.
        self._running = True
        await self._server.start_serving()
        logger.info("Unix socket plugin server started on %s", self._socket_path)

    async def serve_forever(self) -> None:
        """Serve requests until stopped.

        Raises:
            RuntimeError: If the server has not been started.
        """
        if not self._server:
            raise RuntimeError("Server not started. Call start() first.")

        async with self._server:
            await self._server.serve_forever()

    async def stop(self) -> None:
        """Stop the Unix socket server.

        Closes the listening socket and any open client connections, shuts down
        the plugin server, and removes the socket file. The socket file is
        removed even if plugin shutdown raises.
        """
        logger.info("Stopping Unix socket plugin server")
        self._running = False

        try:
            if self._server:
                self._server.close()
                # Close open client connections too. On Python 3.12+, wait_closed()
                # waits for active connections, so idle clients blocked in
                # read_message would otherwise delay shutdown until their read times out.
                for client_writer in list(self._clients):
                    client_writer.close()
                await self._server.wait_closed()
                self._server = None

            if self._plugin_server:
                await self._plugin_server.shutdown()
                self._plugin_server = None
        finally:
            # Clean up socket file
            if os.path.lexists(self._socket_path):
                try:
                    os.unlink(self._socket_path)
                except OSError:
                    pass

        logger.info("Unix socket plugin server stopped")
async def run_server(
    config_path: str,
    socket_path: str = "/tmp/mcpgateway-plugins.sock",  # nosec B108 - configurable default
) -> None:
    """Run the Unix socket server until interrupted.

    Installs SIGINT/SIGTERM handlers on the running event loop, starts the
    server, prints ``READY`` to stdout once ``start()`` returns, and waits for
    a shutdown signal. The server is stopped and the signal handlers are
    removed on exit. This also happens when startup fails or the wait is
    cancelled.

    Args:
        config_path: Path to the plugin configuration file.
        socket_path: Path for the Unix socket file.
    """
    server = UnixSocketPluginServer(config_path=config_path, socket_path=socket_path)

    # Set up signal handlers
    loop = asyncio.get_running_loop()
    stop_event = asyncio.Event()
    shutdown_signals = (signal.SIGINT, signal.SIGTERM)

    def signal_handler() -> None:
        """Handle SIGINT/SIGTERM by setting the stop event."""
        logger.info("Received shutdown signal")
        stop_event.set()

    for sig in shutdown_signals:
        loop.add_signal_handler(sig, signal_handler)

    try:
        await server.start()

        # Signal ready (for parent process coordination)
        print("READY", flush=True)

        # Wait for shutdown signal
        await stop_event.wait()
    finally:
        try:
            # Also cleans up a partially started server if start() raised.
            await server.stop()
        finally:
            for sig in shutdown_signals:
                loop.remove_signal_handler(sig)