Coverage for mcpgateway / plugins / framework / external / unix / server / server.py: 95%

179 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/plugins/framework/external/unix/server/server.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Teryl Taylor 

6 

7Unix socket server for external plugins. 

8 

9This module provides a high-performance server that handles plugin requests 

10over Unix domain sockets using length-prefixed protobuf messages. 

11 

12Examples: 

13 Run the server: 

14 

15 >>> import asyncio 

16 >>> from mcpgateway.plugins.framework.external.unix.server.server import UnixSocketPluginServer 

17 

18 >>> async def main(): 

19 ... server = UnixSocketPluginServer( 

20 ... config_path="plugins/config.yaml", 

21 ... socket_path="/tmp/plugin.sock", 

22 ... ) 

23 ... await server.start() 

24 ... # Server runs until stopped 

25 ... await server.stop() 

26 

27 >>> # asyncio.run(main()) 

28""" 

29# pylint: disable=no-member,no-name-in-module 

30 

31# Standard 

32import asyncio 

33import logging 

34import os 

35import signal 

36from typing import Optional 

37 

38# Third-Party 

39from google.protobuf import json_format 

40from google.protobuf.struct_pb2 import Struct 

41 

42# First-Party 

43from mcpgateway.plugins.framework.external.grpc.proto import plugin_service_pb2 

44from mcpgateway.plugins.framework.external.mcp.server.server import ExternalPluginServer 

45from mcpgateway.plugins.framework.external.proto_convert import ( 

46 proto_context_to_pydantic, 

47 pydantic_context_to_proto, 

48) 

49from mcpgateway.plugins.framework.external.unix.protocol import ProtocolError, read_message, write_message_async 

50from mcpgateway.plugins.framework.models import PluginContext 

51 

52logger = logging.getLogger(__name__) 

53 

54 

55class UnixSocketPluginServer: 

56 """Unix socket server for handling external plugin requests. 

57 

58 This server listens on a Unix domain socket and handles plugin 

59 requests using length-prefixed protobuf messages. It wraps the 

60 ExternalPluginServer for actual plugin execution. 

61 

62 Attributes: 

63 socket_path: Path to the Unix socket file. 

64 

65 Examples: 

66 >>> server = UnixSocketPluginServer( 

67 ... config_path="plugins/config.yaml", 

68 ... socket_path="/tmp/test.sock", 

69 ... ) 

70 >>> server.socket_path 

71 '/tmp/test.sock' 

72 """ 

73 

74 def __init__( 

75 self, 

76 config_path: str, 

77 socket_path: str = "/tmp/mcpgateway-plugins.sock", # nosec B108 - configurable default 

78 ) -> None: 

79 """Initialize the Unix socket server. 

80 

81 Args: 

82 config_path: Path to the plugin configuration file. 

83 socket_path: Path for the Unix socket file. 

84 """ 

85 self._config_path = config_path 

86 self._socket_path = socket_path 

87 self._plugin_server: Optional[ExternalPluginServer] = None 

88 self._server: Optional[asyncio.Server] = None 

89 self._running = False 

90 

91 @property 

92 def socket_path(self) -> str: 

93 """Get the socket path. 

94 

95 Returns: 

96 str: The Unix socket file path. 

97 """ 

98 return self._socket_path 

99 

100 @property 

101 def running(self) -> bool: 

102 """Check if the server is running. 

103 

104 Returns: 

105 bool: True if the server is running, False otherwise. 

106 """ 

107 return self._running 

108 

109 async def _handle_client( 

110 self, 

111 reader: asyncio.StreamReader, 

112 writer: asyncio.StreamWriter, 

113 ) -> None: 

114 """Handle a client connection. 

115 

116 Args: 

117 reader: The stream reader for the client. 

118 writer: The stream writer for the client. 

119 """ 

120 peer = writer.get_extra_info("peername") or "unknown" 

121 logger.debug("Client connected: %s", peer) 

122 

123 try: 

124 while self._running: 124 ↛ 151line 124 didn't jump to line 151 because the condition on line 124 was always true

125 try: 

126 # Read request with timeout 

127 data = await read_message(reader, timeout=300.0) # 5 min timeout 

128 except asyncio.TimeoutError: 

129 logger.debug("Client %s timed out", peer) 

130 break 

131 except asyncio.IncompleteReadError: 

132 # Client disconnected 

133 break 

134 except ProtocolError as e: 

135 logger.warning("Protocol error from %s: %s", peer, e) 

136 break 

137 

138 # Determine message type and handle 

139 response_bytes = await self._handle_message(data) 

140 

141 # Send response 

142 try: 

143 await write_message_async(writer, response_bytes) 

144 except (OSError, BrokenPipeError): 

145 logger.debug("Client %s disconnected during write", peer) 

146 break 

147 

148 except Exception as e: 

149 logger.exception("Error handling client %s: %s", peer, e) 

150 finally: 

151 logger.debug("Client disconnected: %s", peer) 

152 try: 

153 writer.close() 

154 await writer.wait_closed() 

155 except Exception: # nosec B110 - cleanup code, exceptions should not propagate 

156 pass 

157 

158 async def _handle_message(self, data: bytes) -> bytes: 

159 """Handle a single message and return the response. 

160 

161 Args: 

162 data: The raw message bytes. 

163 

164 Returns: 

165 The serialized response bytes. 

166 """ 

167 # Try to parse as InvokeHookRequest first (most common) 

168 try: 

169 request = plugin_service_pb2.InvokeHookRequest() 

170 request.ParseFromString(data) 

171 

172 if request.hook_type and request.plugin_name: 

173 return await self._handle_invoke_hook(request) 

174 except Exception: # nosec B110 - protobuf parse attempt, try next message type 

175 pass 

176 

177 # Try GetPluginConfigRequest 

178 try: 

179 request = plugin_service_pb2.GetPluginConfigRequest() 

180 request.ParseFromString(data) 

181 

182 if request.name: 

183 return await self._handle_get_plugin_config(request) 

184 except Exception: # nosec B110 - protobuf parse attempt, try next message type 

185 pass 

186 

187 # Try GetPluginConfigsRequest 

188 try: 

189 request = plugin_service_pb2.GetPluginConfigsRequest() 

190 request.ParseFromString(data) 

191 # This request has no required fields, so check if data is minimal 

192 if len(data) <= 2: # Empty or near-empty message 192 ↛ 198line 192 didn't jump to line 198 because the condition on line 192 was always true

193 return await self._handle_get_plugin_configs(request) 

194 except Exception: # nosec B110 - protobuf parse attempt, fall through to error 

195 pass 

196 

197 # Unknown message type 

198 logger.warning("Unknown message type, length=%d", len(data)) 

199 error_response = plugin_service_pb2.InvokeHookResponse() 

200 error_response.error.message = "Unknown message type" 

201 error_response.error.code = "UNKNOWN_MESSAGE" 

202 return error_response.SerializeToString() 

203 

204 async def _handle_invoke_hook( 

205 self, 

206 request: plugin_service_pb2.InvokeHookRequest, 

207 ) -> bytes: 

208 """Handle an InvokeHook request. 

209 

210 Args: 

211 request: The InvokeHookRequest. 

212 

213 Returns: 

214 Serialized InvokeHookResponse. 

215 """ 

216 response = plugin_service_pb2.InvokeHookResponse(plugin_name=request.plugin_name) 

217 

218 try: 

219 # Convert payload to dict (still polymorphic) 

220 payload_dict = json_format.MessageToDict(request.payload) 

221 

222 # Convert explicit PluginContext proto directly to Pydantic 

223 context_pydantic = proto_context_to_pydantic(request.context) 

224 

225 # Invoke the hook (passing Pydantic context directly, no dict conversion) 

226 result = await self._plugin_server.invoke_hook( 

227 hook_type=request.hook_type, 

228 plugin_name=request.plugin_name, 

229 payload=payload_dict, 

230 context=context_pydantic, 

231 ) 

232 

233 # Build response 

234 if "error" in result: 

235 error_obj = result["error"] 

236 if hasattr(error_obj, "model_dump"): 

237 error_dict = error_obj.model_dump() 

238 else: 

239 error_dict = error_obj 

240 

241 response.error.message = error_dict.get("message", "Unknown error") 

242 response.error.plugin_name = error_dict.get("plugin_name", "unknown") 

243 response.error.code = error_dict.get("code", "") 

244 response.error.mcp_error_code = error_dict.get("mcp_error_code", -32603) 

245 else: 

246 if "result" in result: 246 ↛ 248line 246 didn't jump to line 248 because the condition on line 246 was always true

247 json_format.ParseDict(result["result"], response.result) 

248 if "context" in result: 

249 ctx = result["context"] 

250 # Handle both Pydantic (optimized path) and dict (MCP compat) 

251 if isinstance(ctx, PluginContext): 

252 response.context.CopyFrom(pydantic_context_to_proto(ctx)) 

253 else: 

254 updated_context = PluginContext.model_validate(ctx) 

255 response.context.CopyFrom(pydantic_context_to_proto(updated_context)) 

256 

257 except Exception as e: 

258 logger.exception("Error invoking hook: %s", e) 

259 response.error.message = str(e) 

260 response.error.code = "INTERNAL_ERROR" 

261 response.error.mcp_error_code = -32603 

262 

263 return response.SerializeToString() 

264 

265 async def _handle_get_plugin_config( 

266 self, 

267 request: plugin_service_pb2.GetPluginConfigRequest, 

268 ) -> bytes: 

269 """Handle a GetPluginConfig request. 

270 

271 Args: 

272 request: The GetPluginConfigRequest. 

273 

274 Returns: 

275 Serialized GetPluginConfigResponse. 

276 """ 

277 response = plugin_service_pb2.GetPluginConfigResponse() 

278 

279 try: 

280 config = await self._plugin_server.get_plugin_config(request.name) 

281 

282 if config: 

283 response.found = True 

284 json_format.ParseDict(config, response.config) 

285 else: 

286 response.found = False 

287 

288 except Exception as e: 

289 logger.exception("Error getting plugin config: %s", e) 

290 response.found = False 

291 

292 return response.SerializeToString() 

293 

294 async def _handle_get_plugin_configs( 

295 self, 

296 _request: plugin_service_pb2.GetPluginConfigsRequest, 

297 ) -> bytes: 

298 """Handle a GetPluginConfigs request. 

299 

300 Args: 

301 _request: The GetPluginConfigsRequest (unused, included for API consistency). 

302 

303 Returns: 

304 Serialized GetPluginConfigsResponse. 

305 """ 

306 response = plugin_service_pb2.GetPluginConfigsResponse() 

307 

308 try: 

309 configs = await self._plugin_server.get_plugin_configs() 

310 

311 for config in configs: 

312 config_struct = Struct() 

313 json_format.ParseDict(config, config_struct) 

314 response.configs.append(config_struct) 

315 

316 except Exception as e: 

317 logger.exception("Error getting plugin configs: %s", e) 

318 

319 return response.SerializeToString() 

320 

321 async def start(self) -> None: 

322 """Start the Unix socket server. 

323 

324 This initializes the plugin server and starts listening for 

325 connections on the Unix socket. 

326 """ 

327 logger.info("Starting Unix socket plugin server on %s", self._socket_path) 

328 

329 # Clean up old socket file 

330 if os.path.exists(self._socket_path): 

331 os.unlink(self._socket_path) 

332 

333 # Initialize the plugin server 

334 self._plugin_server = ExternalPluginServer(config_path=self._config_path) 

335 await self._plugin_server.initialize() 

336 

337 # Create the Unix socket server 

338 self._server = await asyncio.start_unix_server( 

339 self._handle_client, 

340 path=self._socket_path, 

341 ) 

342 

343 # Set restrictive permissions on the socket file (owner read/write only) 

344 if os.path.exists(self._socket_path): 344 ↛ 347line 344 didn't jump to line 347 because the condition on line 344 was always true

345 os.chmod(self._socket_path, 0o600) 

346 

347 self._running = True 

348 logger.info("Unix socket plugin server started on %s", self._socket_path) 

349 

350 async def serve_forever(self) -> None: 

351 """Serve requests until stopped. 

352 

353 Raises: 

354 RuntimeError: If the server has not been started. 

355 """ 

356 if not self._server: 356 ↛ 359line 356 didn't jump to line 359 because the condition on line 356 was always true

357 raise RuntimeError("Server not started. Call start() first.") 

358 

359 async with self._server: 

360 await self._server.serve_forever() 

361 

362 async def stop(self) -> None: 

363 """Stop the Unix socket server.""" 

364 logger.info("Stopping Unix socket plugin server") 

365 self._running = False 

366 

367 if self._server: 

368 self._server.close() 

369 await self._server.wait_closed() 

370 self._server = None 

371 

372 if self._plugin_server: 

373 await self._plugin_server.shutdown() 

374 self._plugin_server = None 

375 

376 # Clean up socket file 

377 if os.path.exists(self._socket_path): 

378 try: 

379 os.unlink(self._socket_path) 

380 except OSError: 

381 pass 

382 

383 logger.info("Unix socket plugin server stopped") 

384 

385 

386async def run_server( 

387 config_path: str, 

388 socket_path: str = "/tmp/mcpgateway-plugins.sock", # nosec B108 - configurable default 

389) -> None: 

390 """Run the Unix socket server until interrupted. 

391 

392 Args: 

393 config_path: Path to the plugin configuration file. 

394 socket_path: Path for the Unix socket file. 

395 """ 

396 server = UnixSocketPluginServer(config_path=config_path, socket_path=socket_path) 

397 

398 # Set up signal handlers 

399 loop = asyncio.get_running_loop() 

400 stop_event = asyncio.Event() 

401 

402 def signal_handler() -> None: 

403 """Handle SIGINT/SIGTERM by setting the stop event.""" 

404 logger.info("Received shutdown signal") 

405 stop_event.set() 

406 

407 for sig in (signal.SIGINT, signal.SIGTERM): 

408 loop.add_signal_handler(sig, signal_handler) 

409 

410 await server.start() 

411 

412 # Signal ready (for parent process coordination) 

413 print("READY", flush=True) 

414 

415 # Wait for shutdown signal 

416 await stop_event.wait() 

417 

418 await server.stop()