Coverage for mcpgateway / translate_grpc.py: 100%

224 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/translate_grpc.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: MCP Gateway Contributors 

6 

7gRPC to MCP Translation Module 

8 

9This module provides gRPC to MCP protocol translation capabilities. 

10It enables exposing gRPC services as MCP tools via HTTP/SSE endpoints 

11using automatic service discovery through gRPC server reflection. 

12""" 

13 

14# Standard 

15import asyncio 

16from pathlib import Path 

17from typing import Any, AsyncGenerator, Dict, List, Optional 

18 

19try: 

20 # Third-Party 

21 from google.protobuf import descriptor_pool, json_format, message_factory 

22 from google.protobuf.descriptor_pb2 import FileDescriptorProto # pylint: disable=no-name-in-module 

23 import grpc 

24 from grpc_reflection.v1alpha import reflection_pb2, reflection_pb2_grpc # pylint: disable=no-member 

25 

26 GRPC_AVAILABLE = True 

27except ImportError: 

28 GRPC_AVAILABLE = False 

29 # Placeholder values for when grpc is not available 

30 descriptor_pool = None # type: ignore 

31 json_format = None # type: ignore 

32 message_factory = None # type: ignore 

33 FileDescriptorProto = None # type: ignore 

34 grpc = None # type: ignore 

35 reflection_pb2 = None # type: ignore 

36 reflection_pb2_grpc = None # type: ignore 

37 

38# First-Party 

39from mcpgateway.services.logging_service import LoggingService 

40 

# Initialize logging: module-level logger obtained from the gateway's LoggingService,
# shared by every class and helper in this module.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

44 

45 

# Protobuf ``FieldDescriptor.type`` codes mapped to JSON Schema primitive types.
# Covers every scalar wire type plus TYPE_MESSAGE, so it agrees with the field
# mapping used when building MCP tool input schemas. (TYPE_GROUP = 10 is a
# deprecated proto2 construct and intentionally absent.)
PROTO_TO_JSON_TYPE_MAP = {
    1: "number",  # TYPE_DOUBLE
    2: "number",  # TYPE_FLOAT
    3: "integer",  # TYPE_INT64
    4: "integer",  # TYPE_UINT64
    5: "integer",  # TYPE_INT32
    6: "integer",  # TYPE_FIXED64
    7: "integer",  # TYPE_FIXED32
    8: "boolean",  # TYPE_BOOL
    9: "string",  # TYPE_STRING
    11: "object",  # TYPE_MESSAGE
    12: "string",  # TYPE_BYTES (base64)
    13: "integer",  # TYPE_UINT32
    14: "string",  # TYPE_ENUM
    15: "integer",  # TYPE_SFIXED32
    16: "integer",  # TYPE_SFIXED64
    17: "integer",  # TYPE_SINT32
    18: "integer",  # TYPE_SINT64
}

58 

59 

class GrpcEndpoint:
    """Wrapper around a gRPC channel with reflection-based introspection.

    ``start()`` opens the channel and, when reflection is enabled, queries the
    server's reflection service, loads the returned file descriptors into a
    descriptor pool and records every service's methods. ``invoke()`` and
    ``invoke_streaming()`` then call those methods with JSON-compatible dicts.
    """

    def __init__(
        self,
        target: str,
        reflection_enabled: bool = True,
        tls_enabled: bool = False,
        tls_cert_path: Optional[str] = None,
        tls_key_path: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
    ):
        """Initialize gRPC endpoint.

        Args:
            target: gRPC server address (host:port)
            reflection_enabled: Enable server reflection for discovery
            tls_enabled: Use TLS for connection
            tls_cert_path: Path to a PEM client certificate chain; used for mutual TLS together with ``tls_key_path``
            tls_key_path: Path to the PEM private key matching ``tls_cert_path``
            metadata: gRPC metadata headers sent with every call, including reflection requests

        Raises:
            ImportError: If the optional gRPC/protobuf dependencies are not installed
        """
        if not GRPC_AVAILABLE:
            raise ImportError("gRPC support requires the 'grpcio', 'grpcio-reflection' and 'protobuf' packages")

        self._target = target
        self._reflection_enabled = reflection_enabled
        self._tls_enabled = tls_enabled
        self._tls_cert_path = tls_cert_path
        self._tls_key_path = tls_key_path
        self._metadata = metadata or {}
        self._channel: Optional[grpc.Channel] = None
        self._services: Dict[str, Any] = {}
        self._descriptors: Dict[str, Any] = {}
        # Reflected descriptors are added to the process-wide default pool.
        self._pool = descriptor_pool.Default()
        # Only used as a fallback on protobuf versions lacking message_factory.GetMessageClass.
        self._factory = message_factory.MessageFactory()

    async def start(self) -> None:
        """Open the gRPC channel and perform reflection if enabled."""
        logger.info(f"Starting gRPC endpoint connection to {self._target}")

        if self._tls_enabled:
            credentials = await self._build_tls_credentials()
            self._channel = grpc.secure_channel(self._target, credentials)
        else:
            self._channel = grpc.insecure_channel(self._target)

        if self._reflection_enabled:
            await self._discover_services()

    async def _build_tls_credentials(self) -> Any:
        """Create TLS channel credentials, attaching a client certificate when configured.

        Returns:
            Channel credentials suitable for ``grpc.secure_channel``
        """
        if self._tls_cert_path and self._tls_key_path:
            cert_chain = await asyncio.to_thread(Path(self._tls_cert_path).read_bytes)
            private_key = await asyncio.to_thread(Path(self._tls_key_path).read_bytes)
            # A certificate supplied together with a private key is the client's
            # own identity (mutual TLS), so it belongs in certificate_chain rather
            # than root_certificates; gRPC needs the key and chain as a pair.
            # The server is verified against gRPC's default root certificates.
            return grpc.ssl_channel_credentials(private_key=private_key, certificate_chain=cert_chain)
        return grpc.ssl_channel_credentials()

    def _call_metadata(self) -> Optional[List[tuple]]:
        """Return the configured metadata in the form gRPC call objects accept.

        Keys are lower-cased because gRPC only accepts lower-case metadata keys
        (header names are case-insensitive on the wire).

        Returns:
            A list of ``(key, value)`` pairs, or None when no metadata is configured
        """
        if not self._metadata:
            return None
        return [(str(key).lower(), value) for key, value in self._metadata.items()]

    def _require_channel(self) -> Any:
        """Return the open channel.

        Returns:
            The gRPC channel created by ``start()``

        Raises:
            RuntimeError: If ``start()`` has not been called (or ``close()`` already ran)
        """
        if self._channel is None:
            raise RuntimeError(f"gRPC endpoint {self._target} is not started; call start() first")
        return self._channel

    def _reflect(self, stub: Any, request: Any) -> List[Any]:
        """Send a single reflection request and collect every response (blocking).

        Args:
            stub: gRPC reflection stub
            request: ServerReflectionRequest to send

        Returns:
            All ServerReflectionResponse messages received
        """
        return list(stub.ServerReflectionInfo(iter([request]), metadata=self._call_metadata()))

    async def _discover_services(self) -> None:
        """Use gRPC reflection to discover services and methods.

        Raises:
            Exception: If service discovery fails
        """
        logger.info(f"Discovering services on {self._target} via reflection")

        try:
            stub = reflection_pb2_grpc.ServerReflectionStub(self._channel)

            # List all services; the reflection call blocks, so keep it off the event loop.
            request = reflection_pb2.ServerReflectionRequest(list_services="")  # pylint: disable=no-member
            responses = await asyncio.to_thread(self._reflect, stub, request)

            service_names = []
            for resp in responses:
                if not resp.HasField("list_services_response"):
                    continue
                for svc in resp.list_services_response.service:
                    service_name = svc.name
                    # Skip the reflection service itself (v1alpha and v1 both live under this package).
                    if service_name.startswith("grpc.reflection."):
                        continue
                    service_names.append(service_name)
                    logger.debug(f"Discovered service: {service_name}")

            # Get file descriptors for each service
            for service_name in service_names:
                await self._discover_service_details(stub, service_name)

            logger.info(f"Discovered {len(self._services)} gRPC services")

        except Exception as e:
            logger.error(f"Service discovery failed: {e}")
            raise

    async def _discover_service_details(self, stub, service_name: str) -> None:
        """Discover detailed information about a service including methods and message types.

        Failures are logged and the service is still registered with an empty
        method list, so discovery of other services continues.

        Args:
            stub: gRPC reflection stub
            service_name: Fully-qualified name of the service to discover
        """
        try:
            # Request the file descriptor containing this service (plus the files it depends on).
            request = reflection_pb2.ServerReflectionRequest(file_containing_symbol=service_name)  # pylint: disable=no-member
            responses = await asyncio.to_thread(self._reflect, stub, request)

            file_protos = []
            for resp in responses:
                if resp.HasField("file_descriptor_response"):
                    for file_desc_proto_bytes in resp.file_descriptor_response.file_descriptor_proto:
                        file_desc_proto = FileDescriptorProto()
                        file_desc_proto.ParseFromString(file_desc_proto_bytes)
                        file_protos.append(file_desc_proto)
                elif resp.HasField("error_response"):
                    logger.warning(f"Reflection error for {service_name}: {resp.error_response.error_message}")

            self._add_to_pool(file_protos)
            self._register_service_from_files(file_protos, service_name)

        except Exception as e:
            logger.warning(f"Failed to get details for {service_name}: {e}")
            # Still add basic service info even if details fail (without clobbering a full entry).
            self._services.setdefault(
                service_name,
                {
                    "name": service_name,
                    "methods": [],
                },
            )

    def _add_to_pool(self, file_protos: List[Any]) -> None:
        """Add file descriptors to the descriptor pool, tolerating arbitrary order.

        Some protobuf backends build a file eagerly when it is added, which
        fails if one of its imports is not in the pool yet. Files that fail are
        retried after each pass that made progress; whatever still fails
        (typically because an identical file is already present) is skipped.

        Args:
            file_protos: Parsed FileDescriptorProto messages
        """
        add_serialized = getattr(self._pool, "AddSerializedFile", None)
        pending = list(file_protos)
        while pending:
            failed = []
            for file_proto in pending:
                try:
                    if add_serialized is not None:
                        add_serialized(file_proto.SerializeToString())
                    else:
                        self._pool.Add(file_proto)
                except Exception as e:  # pylint: disable=broad-except
                    failed.append((file_proto, e))
            if len(failed) == len(pending):
                # No progress this pass: remaining files are duplicates or truly unresolvable.
                for file_proto, error in failed:
                    logger.debug(f"Descriptor not added to pool ({file_proto.name}): {error}")
                return
            pending = [file_proto for file_proto, _ in failed]

    def _register_service_from_files(self, file_protos: List[Any], service_name: str) -> None:
        """Record methods of ``service_name`` from whichever file defines it.

        Only an exact fully-qualified name match is registered, so services
        that merely appear in dependency files are not picked up by accident.

        Args:
            file_protos: Parsed FileDescriptorProto messages returned by reflection
            service_name: Fully-qualified service name being discovered
        """
        for file_desc_proto in file_protos:
            package = file_desc_proto.package
            for service_desc in file_desc_proto.service:
                full_service_name = f"{package}.{service_desc.name}" if package else service_desc.name
                if full_service_name != service_name:
                    continue

                methods = [
                    {
                        "name": method_desc.name,
                        "input_type": method_desc.input_type,
                        "output_type": method_desc.output_type,
                        "client_streaming": method_desc.client_streaming,
                        "server_streaming": method_desc.server_streaming,
                    }
                    for method_desc in service_desc.method
                ]

                self._services[full_service_name] = {
                    "name": full_service_name,
                    "methods": methods,
                    "package": package,
                }

                # Store descriptors for this service
                self._descriptors[full_service_name] = file_desc_proto

                logger.debug(f"Service {full_service_name} has {len(methods)} methods")

    def _find_method(self, service: str, method: str) -> Dict[str, Any]:
        """Look up the recorded method info for ``service``/``method``.

        Args:
            service: Service name
            method: Method name

        Returns:
            The method info dict recorded during discovery

        Raises:
            ValueError: If service or method not found
        """
        if service not in self._services:
            raise ValueError(f"Service {service} not found")
        for method_info in self._services[service]["methods"]:
            if method_info["name"] == method:
                return method_info
        raise ValueError(f"Method {method} not found in service {service}")

    def _message_classes(self, method_info: Dict[str, Any]) -> tuple:
        """Resolve request/response message classes for a method.

        Args:
            method_info: Method info dict recorded during discovery

        Returns:
            ``(request_class, response_class)``

        Raises:
            ValueError: If a message type is not in the descriptor pool
        """
        input_type = method_info["input_type"].lstrip(".")
        output_type = method_info["output_type"].lstrip(".")
        try:
            input_desc = self._pool.FindMessageTypeByName(input_type)
            output_desc = self._pool.FindMessageTypeByName(output_type)
        except KeyError as e:
            raise ValueError(f"Message type not found in descriptor pool: {e}") from e
        return self._message_class(input_desc), self._message_class(output_desc)

    def _message_class(self, descriptor: Any) -> Any:
        """Return the concrete message class for a descriptor across protobuf versions.

        Args:
            descriptor: Protobuf message descriptor

        Returns:
            Message class for ``descriptor``
        """
        get_message_class = getattr(message_factory, "GetMessageClass", None)
        if get_message_class is not None:
            return get_message_class(descriptor)
        # Older protobuf releases only provide MessageFactory.GetPrototype.
        return self._factory.GetPrototype(descriptor)  # pylint: disable=no-member

    @staticmethod
    def _message_to_dict(message: Any) -> Dict[str, Any]:
        """Convert a protobuf message to a dict, including fields left at their defaults.

        Args:
            message: Protobuf message instance

        Returns:
            JSON-compatible dict keyed by the original proto field names
        """
        try:
            # pylint: disable=unexpected-keyword-arg
            return json_format.MessageToDict(message, preserving_proto_field_name=True, including_default_value_fields=True)
        except TypeError:
            # Newer protobuf releases removed including_default_value_fields in favour of this flag.
            return json_format.MessageToDict(message, preserving_proto_field_name=True, always_print_fields_with_no_presence=True)

    async def invoke(
        self,
        service: str,
        method: str,
        request_data: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Invoke a unary gRPC method with JSON request data.

        Args:
            service: Service name
            method: Method name
            request_data: JSON request data

        Returns:
            JSON response data

        Raises:
            ValueError: If service or method not found, or the method is streaming
            RuntimeError: If the endpoint has not been started
            Exception: If invocation fails
        """
        logger.debug(f"Invoking {service}.{method}")

        method_info = self._find_method(service, method)
        if method_info["client_streaming"] or method_info["server_streaming"]:
            raise ValueError(f"Method {method} is streaming, use invoke_streaming instead")

        request_class, response_class = self._message_classes(method_info)

        # Convert JSON to protobuf message
        request_msg = json_format.ParseDict(request_data, request_class())

        # The serializer receives the message as its argument, so it must be the
        # class-level (unbound) SerializeToString, as in generated gRPC stubs.
        rpc = self._require_channel().unary_unary(
            f"/{service}/{method}",
            request_serializer=request_class.SerializeToString,
            response_deserializer=response_class.FromString,
        )
        # The synchronous call blocks; run it in a worker thread.
        response_msg = await asyncio.to_thread(rpc, request_msg, metadata=self._call_metadata())

        response_dict = self._message_to_dict(response_msg)

        logger.debug(f"Successfully invoked {service}.{method}")
        return response_dict

    async def invoke_streaming(
        self,
        service: str,
        method: str,
        request_data: Dict[str, Any],
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Invoke a server-streaming gRPC method.

        Args:
            service: Service name
            method: Method name
            request_data: JSON request data

        Yields:
            JSON response chunks

        Raises:
            ValueError: If service or method not found or not streaming
            RuntimeError: If the endpoint has not been started
            grpc.RpcError: If streaming RPC fails
        """
        logger.debug(f"Invoking streaming {service}.{method}")

        method_info = self._find_method(service, method)

        if not method_info["server_streaming"]:
            raise ValueError(f"Method {method} is not server-streaming")

        if method_info["client_streaming"]:
            raise ValueError("Client streaming not yet supported")

        request_class, response_class = self._message_classes(method_info)

        # Convert JSON to protobuf message
        request_msg = json_format.ParseDict(request_data, request_class())

        stream_call = self._require_channel().unary_stream(
            f"/{service}/{method}",
            request_serializer=request_class.SerializeToString,
            response_deserializer=response_class.FromString,
        )(request_msg, metadata=self._call_metadata())

        responses = iter(stream_call)
        exhausted = object()
        completed = False
        try:
            while True:
                # next() blocks until the server sends a message; keep it off the event loop.
                response_msg = await asyncio.to_thread(next, responses, exhausted)
                if response_msg is exhausted:
                    break
                yield self._message_to_dict(response_msg)
            completed = True
        except grpc.RpcError as e:
            logger.error(f"Streaming RPC error: {e}")
            raise
        finally:
            if not completed:
                # Consumer stopped early or the call failed: release the server-side stream.
                stream_call.cancel()

        logger.debug(f"Streaming complete for {service}.{method}")

    async def close(self) -> None:
        """Close the gRPC channel (safe to call more than once)."""
        if self._channel is not None:
            self._channel.close()
            self._channel = None
            logger.info(f"Closed gRPC connection to {self._target}")

    def get_services(self) -> List[str]:
        """Get list of discovered service names.

        Returns:
            List of service names
        """
        return list(self._services.keys())

    def get_methods(self, service: str) -> List[str]:
        """Get list of methods for a service.

        Args:
            service: Service name

        Returns:
            List of method names (empty for unknown services)
        """
        if service in self._services:
            return [m["name"] for m in self._services[service].get("methods", [])]
        return []

390 

391 

class GrpcToMcpTranslator:
    """Translates between gRPC and MCP protocols.

    Builds MCP tool definitions from the services recorded on a GrpcEndpoint,
    deriving each tool's JSON input schema from the request message descriptor.
    """

    # Protobuf FieldDescriptor.type codes mapped to JSON Schema primitive types.
    _SCALAR_TYPE_MAP: Dict[int, str] = {
        1: "number",  # TYPE_DOUBLE
        2: "number",  # TYPE_FLOAT
        3: "integer",  # TYPE_INT64
        4: "integer",  # TYPE_UINT64
        5: "integer",  # TYPE_INT32
        6: "integer",  # TYPE_FIXED64
        7: "integer",  # TYPE_FIXED32
        8: "boolean",  # TYPE_BOOL
        9: "string",  # TYPE_STRING
        11: "object",  # TYPE_MESSAGE
        12: "string",  # TYPE_BYTES (base64)
        13: "integer",  # TYPE_UINT32
        14: "string",  # TYPE_ENUM
        15: "integer",  # TYPE_SFIXED32
        16: "integer",  # TYPE_SFIXED64
        17: "integer",  # TYPE_SINT32
        18: "integer",  # TYPE_SINT64
    }

    # Well-known types have a dedicated proto3 JSON form (what json_format.ParseDict
    # accepts) that differs from their message structure, e.g. Timestamp is an
    # RFC 3339 string and wrapper types are the bare wrapped value.
    _WELL_KNOWN_TYPE_SCHEMAS: Dict[str, Dict[str, Any]] = {
        "google.protobuf.Timestamp": {"type": "string", "format": "date-time"},
        "google.protobuf.Duration": {"type": "string"},
        "google.protobuf.FieldMask": {"type": "string"},
        "google.protobuf.Struct": {"type": "object"},
        "google.protobuf.Value": {},
        "google.protobuf.ListValue": {"type": "array"},
        "google.protobuf.Empty": {"type": "object"},
        "google.protobuf.Any": {"type": "object"},
        "google.protobuf.DoubleValue": {"type": "number"},
        "google.protobuf.FloatValue": {"type": "number"},
        "google.protobuf.Int64Value": {"type": "integer"},
        "google.protobuf.UInt64Value": {"type": "integer"},
        "google.protobuf.Int32Value": {"type": "integer"},
        "google.protobuf.UInt32Value": {"type": "integer"},
        "google.protobuf.BoolValue": {"type": "boolean"},
        "google.protobuf.StringValue": {"type": "string"},
        "google.protobuf.BytesValue": {"type": "string"},
    }

    def __init__(self, endpoint: "GrpcEndpoint"):
        """Initialize translator.

        Args:
            endpoint: gRPC endpoint to translate
        """
        self._endpoint = endpoint

    def grpc_service_to_mcp_server(self, service_name: str) -> Dict[str, Any]:
        """Convert a gRPC service to an MCP virtual server definition.

        Args:
            service_name: gRPC service name

        Returns:
            MCP server definition
        """
        return {
            "name": service_name,
            "description": f"gRPC service: {service_name}",
            "transport": ["sse", "http"],
            "tools": self.grpc_methods_to_mcp_tools(service_name),
        }

    def grpc_methods_to_mcp_tools(self, service_name: str) -> List[Dict[str, Any]]:
        """Convert gRPC methods to MCP tool definitions.

        Args:
            service_name: gRPC service name

        Returns:
            List of MCP tool definitions (empty for unknown services)
        """
        # pylint: disable=protected-access
        services = self._endpoint._services
        if service_name not in services:
            return []

        tools = []
        for method_info in services[service_name].get("methods", []):
            method_name = method_info["name"]
            input_type = method_info["input_type"].lstrip(".")

            # Try to get input schema from descriptor
            try:
                input_desc = self._endpoint._pool.FindMessageTypeByName(input_type)
                input_schema = self.protobuf_to_json_schema(input_desc)
            except KeyError:
                # Fallback to generic schema if descriptor not found
                input_schema = {"type": "object", "properties": {}}

            tools.append({"name": f"{service_name}.{method_name}", "description": f"gRPC method {service_name}.{method_name}", "inputSchema": input_schema})

        return tools

    def protobuf_to_json_schema(self, message_descriptor: Any) -> Dict[str, Any]:
        """Convert protobuf message descriptor to JSON schema.

        Args:
            message_descriptor: Protobuf message descriptor

        Returns:
            JSON schema with ``properties`` for every field and ``required`` for proto2 required fields
        """
        return self._message_to_json_schema(message_descriptor, frozenset())

    def _message_to_json_schema(self, message_descriptor: Any, seen: frozenset) -> Dict[str, Any]:
        """Build the object schema for a message, tracking enclosing message types.

        Args:
            message_descriptor: Protobuf message descriptor
            seen: Full names of messages already being expanded higher up (cycle guard)

        Returns:
            JSON schema for the message
        """
        full_name = getattr(message_descriptor, "full_name", None)
        if full_name is not None:
            seen = seen | {full_name}

        schema: Dict[str, Any] = {"type": "object", "properties": {}, "required": []}
        for field in message_descriptor.fields:
            schema["properties"][field.name] = self._protobuf_field_to_json_schema(field, seen)
            # Add to required if field is required (proto2 LABEL_REQUIRED)
            if getattr(field, "label", None) == 2:
                schema["required"].append(field.name)
        return schema

    def _protobuf_field_to_json_schema(self, field: Any, _seen: frozenset = frozenset()) -> Dict[str, Any]:
        """Convert a protobuf field to JSON schema type.

        Args:
            field: Protobuf field descriptor
            _seen: Message full names already being expanded (internal cycle guard)

        Returns:
            JSON schema for the field
        """
        repeated = getattr(field, "label", None) == 3  # LABEL_REPEATED

        if field.type == 11:  # TYPE_MESSAGE
            try:
                nested_desc = field.message_type
            except Exception:  # pylint: disable=broad-except
                nested_desc = None
            # Map fields are repeated synthetic "Entry" messages, but their JSON form is an object.
            if repeated and self._is_map_entry(nested_desc):
                return {"type": "object", "additionalProperties": self._map_value_schema(nested_desc, _seen)}
            element = self._nested_message_schema(nested_desc, _seen)
        else:
            element = {"type": self._SCALAR_TYPE_MAP.get(field.type, "string")}

        if repeated:
            return {"type": "array", "items": element}
        return element

    def _nested_message_schema(self, message_descriptor: Any, seen: frozenset) -> Dict[str, Any]:
        """Schema for a message-typed field, handling well-known and recursive types.

        Args:
            message_descriptor: Descriptor of the field's message type (may be None)
            seen: Message full names already being expanded

        Returns:
            JSON schema for the nested message
        """
        if message_descriptor is None:
            return {"type": "object"}
        full_name = getattr(message_descriptor, "full_name", None)
        if isinstance(full_name, str) and full_name in self._WELL_KNOWN_TYPE_SCHEMAS:
            return dict(self._WELL_KNOWN_TYPE_SCHEMAS[full_name])
        if full_name is not None and full_name in seen:
            # Recursive message: stop expanding instead of recursing forever.
            return {"type": "object"}
        try:
            return self._message_to_json_schema(message_descriptor, seen)
        except Exception:  # pylint: disable=broad-except
            # Schema generation is best-effort; fall back to a generic object.
            return {"type": "object"}

    @staticmethod
    def _is_map_entry(message_descriptor: Any) -> bool:
        """Return True if the descriptor is a synthetic map-entry message.

        Args:
            message_descriptor: Protobuf message descriptor (may be None)

        Returns:
            Whether the message carries the ``map_entry`` option
        """
        if message_descriptor is None:
            return False
        try:
            return message_descriptor.GetOptions().map_entry is True
        except Exception:  # pylint: disable=broad-except
            return False

    def _map_value_schema(self, entry_descriptor: Any, seen: frozenset) -> Dict[str, Any]:
        """Schema for the values of a map field (JSON object keys are always strings).

        Args:
            entry_descriptor: Map-entry message descriptor with ``key``/``value`` fields
            seen: Message full names already being expanded

        Returns:
            JSON schema for a map value, or an unconstrained schema if none is found
        """
        for entry_field in entry_descriptor.fields:
            if entry_field.name == "value":
                return self._protobuf_field_to_json_schema(entry_field, seen)
        return {}

519 

520 

521# Utility functions for CLI usage 

522 

523 

async def expose_grpc_via_sse(
    target: str,
    port: int = 9000,
    tls_enabled: bool = False,
    tls_cert: Optional[str] = None,
    tls_key: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
) -> None:
    """Connect to a gRPC server, discover its services via reflection, and keep the connection open.

    This helper does not start an HTTP/SSE listener itself: ``port`` is only
    reported in the log. To serve the discovered services, register the target
    through the gateway admin UI/API. The endpoint is closed on exit.

    Args:
        target: gRPC server address (host:port)
        port: HTTP port reported in the startup log
        tls_enabled: Use TLS for gRPC connection
        tls_cert: TLS certificate path
        tls_key: TLS key path
        metadata: gRPC metadata headers

    Raises:
        asyncio.CancelledError: Re-raised after logging when the task is cancelled
            (``asyncio.run`` cancels the main task on Ctrl-C).
    """
    logger.info(f"Exposing gRPC service {target} via SSE on port {port}")

    endpoint = GrpcEndpoint(
        target=target,
        reflection_enabled=True,
        tls_enabled=tls_enabled,
        tls_cert_path=tls_cert,
        tls_key_path=tls_key,
        metadata=metadata,
    )

    try:
        await endpoint.start()

        services = endpoint.get_services()
        logger.info(f"gRPC service exposed. Discovered services: {services}")
        logger.info("To expose via HTTP/SSE, register this service in the gateway admin UI")
        logger.info(f" Target: {target}")
        logger.info(f" Discovered: {len(services)} services")

        # Keep endpoint connection alive
        # Note: For full HTTP/SSE exposure, register the service via the gateway admin API
        # which will make it accessible through the existing multi-protocol server infrastructure
        while True:
            await asyncio.sleep(1)

    except KeyboardInterrupt:
        logger.info("Shutting down...")
    except asyncio.CancelledError:
        # Under asyncio.run, Ctrl-C arrives as task cancellation rather than KeyboardInterrupt.
        # Re-raise so the runner can finish the cancellation (and raise KeyboardInterrupt itself).
        logger.info("Shutting down...")
        raise
    finally:
        await endpoint.close()