Coverage for mcpgateway / plugins / framework / external / unix / server / server.py: 97%

179 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/plugins/framework/external/unix/server/server.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Teryl Taylor 

6 

7Unix socket server for external plugins. 

8 

9This module provides a high-performance server that handles plugin requests 

10over Unix domain sockets using length-prefixed protobuf messages. 

11 

12Examples: 

13 Run the server: 

14 

15 >>> import asyncio 

16 >>> from mcpgateway.plugins.framework.external.unix.server.server import UnixSocketPluginServer 

17 

18 >>> async def main(): 

19 ... server = UnixSocketPluginServer( 

20 ... config_path="plugins/config.yaml", 

21 ... socket_path="/tmp/plugin.sock", 

22 ... ) 

23 ... await server.start() 

24 ... # Server runs until stopped 

25 ... await server.stop() 

26 

27 >>> # asyncio.run(main()) 

28""" 

29 

30# pylint: disable=no-member,no-name-in-module 

31 

32# Standard 

33import asyncio 

34import logging 

35import os 

36import signal 

37from typing import Optional 

38 

39# Third-Party 

40from google.protobuf import json_format 

41from google.protobuf.struct_pb2 import Struct 

42 

43# First-Party 

44from mcpgateway.plugins.framework.external.grpc.proto import plugin_service_pb2 

45from mcpgateway.plugins.framework.external.mcp.server.server import ExternalPluginServer 

46from mcpgateway.plugins.framework.external.proto_convert import ( 

47 proto_context_to_pydantic, 

48 pydantic_context_to_proto, 

49) 

50from mcpgateway.plugins.framework.external.unix.protocol import ProtocolError, read_message, write_message_async 

51from mcpgateway.plugins.framework.models import PluginContext 

52 

# Module-scoped logger; its name follows this module's import path.
logger = logging.getLogger(name=__name__)

54 

55 

class UnixSocketPluginServer:
    """Unix socket server for handling external plugin requests.

    This server listens on a Unix domain socket and handles plugin
    requests using length-prefixed protobuf messages. It wraps the
    ExternalPluginServer for actual plugin execution.

    Attributes:
        socket_path: Path to the Unix socket file.

    Examples:
        >>> server = UnixSocketPluginServer(
        ...     config_path="plugins/config.yaml",
        ...     socket_path="/tmp/test.sock",
        ... )
        >>> server.socket_path
        '/tmp/test.sock'
        >>> server.running
        False
    """

    def __init__(
        self,
        config_path: str,
        socket_path: str = "/tmp/mcpgateway-plugins.sock",  # nosec B108 - configurable default
    ) -> None:
        """Initialize the Unix socket server.

        No I/O happens here; the plugin server and the listening socket are
        created by ``start()``.

        Args:
            config_path: Path to the plugin configuration file.
            socket_path: Path for the Unix socket file.
        """
        self._config_path = config_path
        self._socket_path = socket_path
        self._plugin_server: Optional[ExternalPluginServer] = None
        self._server: Optional[asyncio.Server] = None
        self._running = False
        # Writers of currently connected clients, tracked so stop() can close
        # them. Since Python 3.12.1, Server.wait_closed() waits for every
        # active connection to finish, so an idle client blocked in
        # read_message() would otherwise stall shutdown until its read timeout.
        self._clients: set[asyncio.StreamWriter] = set()

    @property
    def socket_path(self) -> str:
        """Get the socket path.

        Returns:
            str: The Unix socket file path.
        """
        return self._socket_path

    @property
    def running(self) -> bool:
        """Check if the server is running.

        Returns:
            bool: True between a successful ``start()`` and ``stop()``, False otherwise.
        """
        return self._running

    async def _handle_client(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ) -> None:
        """Serve request/response exchanges on one client connection.

        Reads one framed request at a time, dispatches it, and writes back the
        response, until the server stops, the client disconnects, the client
        idles past the read timeout, or a framing error occurs. The connection
        is always closed on exit.

        Args:
            reader: The stream reader for the client.
            writer: The stream writer for the client.
        """
        peer = writer.get_extra_info("peername") or "unknown"
        logger.debug("Client connected: %s", peer)
        self._clients.add(writer)

        try:
            while self._running:
                try:
                    # Drop connections that send nothing for 5 minutes.
                    data = await read_message(reader, timeout=300.0)
                except asyncio.TimeoutError:
                    logger.debug("Client %s timed out", peer)
                    break
                except asyncio.IncompleteReadError:
                    # Client disconnected
                    break
                except ProtocolError as e:
                    logger.warning("Protocol error from %s: %s", peer, e)
                    break
                except OSError as e:
                    # Connection reset/aborted while reading: a normal disconnect,
                    # not worth a full traceback. (Listed after TimeoutError,
                    # which is itself an OSError subclass on Python 3.11+.)
                    logger.debug("Client %s disconnected during read: %s", peer, e)
                    break

                response_bytes = await self._handle_message(data)

                try:
                    await write_message_async(writer, response_bytes)
                except OSError:  # covers BrokenPipeError / ConnectionResetError
                    logger.debug("Client %s disconnected during write", peer)
                    break

        except Exception as e:
            logger.exception("Error handling client %s: %s", peer, e)
        finally:
            self._clients.discard(writer)
            logger.debug("Client disconnected: %s", peer)
            try:
                writer.close()
                await writer.wait_closed()
            except Exception:  # nosec B110 - cleanup code, exceptions should not propagate
                pass

    @staticmethod
    def _parse_or_none(message_cls, data: bytes):
        """Try to decode ``data`` as an instance of ``message_cls``.

        Args:
            message_cls: A generated protobuf message class.
            data: The raw message bytes.

        Returns:
            The parsed message, or None if decoding raised.
        """
        message = message_cls()
        try:
            message.ParseFromString(data)
        except Exception:  # nosec B110 - protobuf parse attempt, caller tries the next message type
            return None
        return message

    @staticmethod
    def _non_null(value, default):
        """Return ``value`` unless it is None, in which case return ``default``.

        Args:
            value: The candidate value.
            default: Fallback used when ``value`` is None.

        Returns:
            ``value`` if it is not None, else ``default``.
        """
        return default if value is None else value

    async def _handle_message(self, data: bytes) -> bytes:
        """Identify the request type of a frame, dispatch it, and return the response.

        Frames carry no explicit type tag, so the request type is inferred by
        decoding the bytes as each candidate message in turn and checking which
        identifying fields are populated. Only the decode attempts are guarded;
        exceptions raised by a handler propagate to the caller instead of
        causing the bytes to be re-interpreted as a different message type.

        Args:
            data: The raw message bytes.

        Returns:
            The serialized response bytes. Unrecognized frames get an
            ``InvokeHookResponse`` carrying an ``UNKNOWN_MESSAGE`` error.
        """
        # InvokeHookRequest first (most common); both fields must be set.
        invoke_request = self._parse_or_none(plugin_service_pb2.InvokeHookRequest, data)
        if invoke_request is not None and invoke_request.hook_type and invoke_request.plugin_name:
            return await self._handle_invoke_hook(invoke_request)

        config_request = self._parse_or_none(plugin_service_pb2.GetPluginConfigRequest, data)
        if config_request is not None and config_request.name:
            return await self._handle_get_plugin_config(config_request)

        # GetPluginConfigsRequest has no required fields, so it is recognized
        # by being an empty or near-empty frame.
        configs_request = self._parse_or_none(plugin_service_pb2.GetPluginConfigsRequest, data)
        if configs_request is not None and len(data) <= 2:
            return await self._handle_get_plugin_configs(configs_request)

        logger.warning("Unknown message type, length=%d", len(data))
        error_response = plugin_service_pb2.InvokeHookResponse()
        error_response.error.message = "Unknown message type"
        error_response.error.code = "UNKNOWN_MESSAGE"
        return error_response.SerializeToString()

    async def _handle_invoke_hook(
        self,
        request: plugin_service_pb2.InvokeHookRequest,
    ) -> bytes:
        """Handle an InvokeHook request.

        Any exception raised while converting or invoking is reported inside
        the response as an ``INTERNAL_ERROR`` rather than propagated.

        Args:
            request: The InvokeHookRequest.

        Returns:
            Serialized InvokeHookResponse.
        """
        response = plugin_service_pb2.InvokeHookResponse(plugin_name=request.plugin_name)

        try:
            # Payload stays a dict because its schema depends on the hook type.
            payload_dict = json_format.MessageToDict(request.payload)

            # Convert the explicit PluginContext proto directly to Pydantic.
            context_pydantic = proto_context_to_pydantic(request.context)

            result = await self._plugin_server.invoke_hook(
                hook_type=request.hook_type,
                plugin_name=request.plugin_name,
                payload=payload_dict,
                context=context_pydantic,
            )

            if "error" in result:
                error_obj = result["error"]
                error_dict = error_obj.model_dump() if hasattr(error_obj, "model_dump") else error_obj

                # Keys may be present with a None value (e.g. optional model
                # fields); assigning None to a protobuf scalar raises TypeError,
                # which would mask the plugin's real error, so coalesce to the
                # defaults instead.
                response.error.message = self._non_null(error_dict.get("message"), "Unknown error")
                response.error.plugin_name = self._non_null(error_dict.get("plugin_name"), "unknown")
                response.error.code = self._non_null(error_dict.get("code"), "")
                response.error.mcp_error_code = self._non_null(error_dict.get("mcp_error_code"), -32603)
            else:
                if "result" in result:
                    json_format.ParseDict(result["result"], response.result)
                if "context" in result:
                    ctx = result["context"]
                    # Accept both a Pydantic context (optimized path) and a dict (MCP compat).
                    if not isinstance(ctx, PluginContext):
                        ctx = PluginContext.model_validate(ctx)
                    response.context.CopyFrom(pydantic_context_to_proto(ctx))

        except Exception as e:
            logger.exception("Error invoking hook: %s", e)
            response.error.message = str(e)
            response.error.code = "INTERNAL_ERROR"
            response.error.mcp_error_code = -32603

        return response.SerializeToString()

    async def _handle_get_plugin_config(
        self,
        request: plugin_service_pb2.GetPluginConfigRequest,
    ) -> bytes:
        """Handle a GetPluginConfig request.

        Args:
            request: The GetPluginConfigRequest.

        Returns:
            Serialized GetPluginConfigResponse; ``found`` is False when the
            plugin has no config or the lookup raised.
        """
        response = plugin_service_pb2.GetPluginConfigResponse()

        try:
            config = await self._plugin_server.get_plugin_config(request.name)

            if config:
                response.found = True
                json_format.ParseDict(config, response.config)
            else:
                response.found = False

        except Exception as e:
            logger.exception("Error getting plugin config: %s", e)
            response.found = False

        return response.SerializeToString()

    async def _handle_get_plugin_configs(
        self,
        _request: plugin_service_pb2.GetPluginConfigsRequest,
    ) -> bytes:
        """Handle a GetPluginConfigs request.

        Args:
            _request: The GetPluginConfigsRequest (unused, included for API consistency).

        Returns:
            Serialized GetPluginConfigsResponse. If the lookup raises, the
            error is logged and whatever configs were collected so far are
            returned.
        """
        response = plugin_service_pb2.GetPluginConfigsResponse()

        try:
            configs = await self._plugin_server.get_plugin_configs()

            for config in configs:
                config_struct = Struct()
                json_format.ParseDict(config, config_struct)
                response.configs.append(config_struct)

        except Exception as e:
            logger.exception("Error getting plugin configs: %s", e)

        return response.SerializeToString()

    async def start(self) -> None:
        """Start the Unix socket server.

        Removes any existing file at the socket path, initializes the plugin
        server, starts listening on the Unix socket, and restricts the socket
        file to owner read/write.
        """
        logger.info("Starting Unix socket plugin server on %s", self._socket_path)

        # Clean up old socket file
        if os.path.exists(self._socket_path):
            os.unlink(self._socket_path)

        self._plugin_server = ExternalPluginServer(config_path=self._config_path)
        await self._plugin_server.initialize()

        self._server = await asyncio.start_unix_server(
            self._handle_client,
            path=self._socket_path,
        )

        # Set restrictive permissions on the socket file (owner read/write only).
        # NOTE(review): the socket briefly exists with umask-derived permissions
        # before this chmod; a restrictive umask or private parent directory
        # would close that window.
        if os.path.exists(self._socket_path):
            os.chmod(self._socket_path, 0o600)

        self._running = True
        logger.info("Unix socket plugin server started on %s", self._socket_path)

    async def serve_forever(self) -> None:
        """Serve requests until stopped.

        Raises:
            RuntimeError: If the server has not been started.
        """
        if not self._server:
            raise RuntimeError("Server not started. Call start() first.")

        async with self._server:
            await self._server.serve_forever()

    async def stop(self) -> None:
        """Stop the Unix socket server.

        Stops accepting connections, closes every connected client, waits for
        the listening server to close, shuts down the plugin server, and
        removes the socket file. Safe to call when the server was never started.
        """
        logger.info("Stopping Unix socket plugin server")
        self._running = False

        if self._server:
            self._server.close()

        # Close client connections so their handlers see EOF and exit; otherwise
        # wait_closed() (Python 3.12.1+) would wait on idle clients until their
        # read timeout. Iterate a copy: handlers remove themselves on exit.
        for client_writer in list(self._clients):
            client_writer.close()

        if self._server:
            await self._server.wait_closed()
            self._server = None

        if self._plugin_server:
            await self._plugin_server.shutdown()
            self._plugin_server = None

        # Clean up socket file (a missing file is fine).
        try:
            os.unlink(self._socket_path)
        except OSError:
            pass

        logger.info("Unix socket plugin server stopped")

385 

386 

async def run_server(
    config_path: str,
    socket_path: str = "/tmp/mcpgateway-plugins.sock",  # nosec B108 - configurable default
) -> None:
    """Run the Unix socket server until interrupted.

    Installs SIGINT/SIGTERM handlers on the running event loop, starts the
    server, prints ``READY`` to stdout once it is listening (for parent
    process coordination), then waits for a shutdown signal. The server is
    stopped even if the wait is cancelled, and the signal handlers are
    removed before returning.

    Args:
        config_path: Path to the plugin configuration file.
        socket_path: Path for the Unix socket file.
    """
    server = UnixSocketPluginServer(config_path=config_path, socket_path=socket_path)

    loop = asyncio.get_running_loop()
    stop_event = asyncio.Event()
    shutdown_signals = (signal.SIGINT, signal.SIGTERM)

    def signal_handler() -> None:
        """Handle SIGINT/SIGTERM by setting the stop event."""
        logger.info("Received shutdown signal")
        stop_event.set()

    for sig in shutdown_signals:
        loop.add_signal_handler(sig, signal_handler)

    try:
        await server.start()
        try:
            # Signal ready (for parent process coordination)
            print("READY", flush=True)
            await stop_event.wait()
        finally:
            # Runs on normal shutdown and on cancellation, so the plugin
            # server is shut down and the socket file is removed either way.
            await server.stop()
    finally:
        # Don't leave handlers bound to this (now finished) coroutine's event.
        for sig in shutdown_signals:
            loop.remove_signal_handler(sig)