Coverage for mcpgateway / translate_grpc.py: 100%
224 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/translate_grpc.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: MCP Gateway Contributors
7gRPC to MCP Translation Module
9This module provides gRPC to MCP protocol translation capabilities.
10It enables exposing gRPC services as MCP tools via HTTP/SSE endpoints
11using automatic service discovery through gRPC server reflection.
12"""
14# Standard
15import asyncio
16from pathlib import Path
17from typing import Any, AsyncGenerator, Dict, List, Optional
19try:
20 # Third-Party
21 from google.protobuf import descriptor_pool, json_format, message_factory
22 from google.protobuf.descriptor_pb2 import FileDescriptorProto # pylint: disable=no-name-in-module
23 import grpc
24 from grpc_reflection.v1alpha import reflection_pb2, reflection_pb2_grpc # pylint: disable=no-member
26 GRPC_AVAILABLE = True
27except ImportError:
28 GRPC_AVAILABLE = False
29 # Placeholder values for when grpc is not available
30 descriptor_pool = None # type: ignore
31 json_format = None # type: ignore
32 message_factory = None # type: ignore
33 FileDescriptorProto = None # type: ignore
34 grpc = None # type: ignore
35 reflection_pb2 = None # type: ignore
36 reflection_pb2_grpc = None # type: ignore
38# First-Party
39from mcpgateway.services.logging_service import LoggingService
# Initialize logging through the gateway's central LoggingService so that
# gRPC translation messages share the application-wide handlers and format.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)
# Mapping from protobuf FieldDescriptor type numbers to JSON schema types.
# Extended to cover all protobuf scalar types so it stays consistent with the
# conversion performed in GrpcToMcpTranslator._protobuf_field_to_json_schema.
# TYPE_MESSAGE (11) maps to "object"; nested expansion is done by the translator.
PROTO_TO_JSON_TYPE_MAP = {
    1: "number",  # TYPE_DOUBLE
    2: "number",  # TYPE_FLOAT
    3: "integer",  # TYPE_INT64
    4: "integer",  # TYPE_UINT64
    5: "integer",  # TYPE_INT32
    6: "integer",  # TYPE_FIXED64
    7: "integer",  # TYPE_FIXED32
    8: "boolean",  # TYPE_BOOL
    9: "string",  # TYPE_STRING
    11: "object",  # TYPE_MESSAGE
    12: "string",  # TYPE_BYTES (base64)
    13: "integer",  # TYPE_UINT32
    14: "string",  # TYPE_ENUM
    15: "integer",  # TYPE_SFIXED32
    16: "integer",  # TYPE_SFIXED64
    17: "integer",  # TYPE_SINT32
    18: "integer",  # TYPE_SINT64
}
class GrpcEndpoint:
    """Wrapper around a gRPC channel with reflection-based introspection.

    Connects to a gRPC server (optionally over TLS), discovers its services
    and methods via the server reflection protocol, and invokes unary or
    server-streaming methods using JSON <-> protobuf conversion.
    """

    def __init__(
        self,
        target: str,
        reflection_enabled: bool = True,
        tls_enabled: bool = False,
        tls_cert_path: Optional[str] = None,
        tls_key_path: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
    ):
        """Initialize gRPC endpoint.

        Args:
            target: gRPC server address (host:port)
            reflection_enabled: Enable server reflection for discovery
            tls_enabled: Use TLS for connection
            tls_cert_path: Path to TLS certificate
            tls_key_path: Path to TLS key
            metadata: gRPC metadata headers sent with every call
        """
        self._target = target
        self._reflection_enabled = reflection_enabled
        self._tls_enabled = tls_enabled
        self._tls_cert_path = tls_cert_path
        self._tls_key_path = tls_key_path
        self._metadata = metadata or {}
        self._channel: Optional[grpc.Channel] = None
        # Service name -> {"name", "methods", "package"} discovered via reflection.
        self._services: Dict[str, Any] = {}
        # Service name -> FileDescriptorProto that declared it.
        self._descriptors: Dict[str, Any] = {}
        self._pool = descriptor_pool.Default()
        self._factory = message_factory.MessageFactory()

    async def start(self) -> None:
        """Initialize gRPC channel and perform reflection if enabled."""
        logger.info(f"Starting gRPC endpoint connection to {self._target}")

        # Create channel (secure with optional client cert, or insecure).
        if self._tls_enabled:
            if self._tls_cert_path and self._tls_key_path:
                # Blocking file reads are pushed to a worker thread.
                cert = await asyncio.to_thread(Path(self._tls_cert_path).read_bytes)
                key = await asyncio.to_thread(Path(self._tls_key_path).read_bytes)
                credentials = grpc.ssl_channel_credentials(root_certificates=cert, private_key=key)
            else:
                # System default root certificates, no client certificate.
                credentials = grpc.ssl_channel_credentials()
            self._channel = grpc.secure_channel(self._target, credentials)
        else:
            self._channel = grpc.insecure_channel(self._target)

        # Perform reflection if enabled
        if self._reflection_enabled:
            await self._discover_services()

    async def _discover_services(self) -> None:
        """Use gRPC reflection to discover services and methods.

        Raises:
            Exception: If service discovery fails
        """
        logger.info(f"Discovering services on {self._target} via reflection")

        try:
            stub = reflection_pb2_grpc.ServerReflectionStub(self._channel)

            # List all services exposed by the server.
            request = reflection_pb2.ServerReflectionRequest(list_services="")  # pylint: disable=no-member
            response = stub.ServerReflectionInfo(iter([request]))

            service_names = []
            for resp in response:
                if resp.HasField("list_services_response"):
                    for svc in resp.list_services_response.service:
                        service_name = svc.name
                        # Skip the reflection service itself.
                        if "ServerReflection" in service_name:
                            continue
                        service_names.append(service_name)
                        logger.debug(f"Discovered service: {service_name}")

            # Get file descriptors for each discovered service.
            for service_name in service_names:
                await self._discover_service_details(stub, service_name)

            logger.info(f"Discovered {len(self._services)} gRPC services")

        except Exception as e:
            logger.error(f"Service discovery failed: {e}")
            raise

    async def _discover_service_details(self, stub, service_name: str) -> None:
        """Discover detailed information about a service including methods and message types.

        Best-effort: on failure the service is still registered with an empty
        method list so it remains visible to callers.

        Args:
            stub: gRPC reflection stub
            service_name: Name of the service to discover
        """
        try:  # pylint: disable=too-many-nested-blocks
            # Request the file descriptor set containing this service symbol.
            request = reflection_pb2.ServerReflectionRequest(file_containing_symbol=service_name)  # pylint: disable=no-member
            response = stub.ServerReflectionInfo(iter([request]))

            for resp in response:
                if resp.HasField("file_descriptor_response"):
                    # Process all file descriptors (service file + dependencies).
                    for file_desc_proto_bytes in resp.file_descriptor_response.file_descriptor_proto:
                        file_desc_proto = FileDescriptorProto()
                        file_desc_proto.ParseFromString(file_desc_proto_bytes)

                        # Add to pool (ignore if already registered).
                        try:
                            self._pool.Add(file_desc_proto)
                        except Exception as e:  # pylint: disable=broad-except
                            # Descriptor already in pool, safe to skip.
                            logger.debug(f"Descriptor already in pool: {e}")

                        # Extract service and method information.
                        for service_desc in file_desc_proto.service:
                            if service_desc.name in service_name or service_name.endswith(service_desc.name):
                                full_service_name = f"{file_desc_proto.package}.{service_desc.name}" if file_desc_proto.package else service_desc.name

                                methods = []
                                for method_desc in service_desc.method:
                                    methods.append(
                                        {
                                            "name": method_desc.name,
                                            "input_type": method_desc.input_type,
                                            "output_type": method_desc.output_type,
                                            "client_streaming": method_desc.client_streaming,
                                            "server_streaming": method_desc.server_streaming,
                                        }
                                    )

                                self._services[full_service_name] = {
                                    "name": full_service_name,
                                    "methods": methods,
                                    "package": file_desc_proto.package,
                                }

                                # Store descriptors for this service.
                                self._descriptors[full_service_name] = file_desc_proto

                                logger.debug(f"Service {full_service_name} has {len(methods)} methods")

        except Exception as e:
            logger.warning(f"Failed to get details for {service_name}: {e}")
            # Still add basic service info even if details fail.
            self._services[service_name] = {
                "name": service_name,
                "methods": [],
            }

    def _find_method(self, service: str, method: str) -> Dict[str, Any]:
        """Look up a discovered method's info dict.

        Args:
            service: Service name
            method: Method name

        Returns:
            The method info dict recorded during discovery

        Raises:
            ValueError: If the service or method is unknown
        """
        if service not in self._services:
            raise ValueError(f"Service {service} not found")

        method_info = next((m for m in self._services[service]["methods"] if m["name"] == method), None)
        if not method_info:
            raise ValueError(f"Method {method} not found in service {service}")
        return method_info

    def _message_classes(self, method_info: Dict[str, Any]) -> tuple:
        """Resolve request/response message classes from the descriptor pool.

        Args:
            method_info: Method info dict from discovery

        Returns:
            Tuple of (request message class, response message class)

        Raises:
            ValueError: If either message type is not in the descriptor pool
        """
        # Reflection returns fully-qualified type names with a leading dot.
        input_type = method_info["input_type"].lstrip(".")
        output_type = method_info["output_type"].lstrip(".")

        try:
            input_desc = self._pool.FindMessageTypeByName(input_type)
            output_desc = self._pool.FindMessageTypeByName(output_type)
        except KeyError as e:
            # Chain the KeyError so the missing type name isn't lost.
            raise ValueError(f"Message type not found in descriptor pool: {e}") from e

        # pylint: disable=no-member
        return self._factory.GetPrototype(input_desc), self._factory.GetPrototype(output_desc)

    def _grpc_metadata(self) -> Optional[List[tuple]]:
        """Return stored metadata as the (key, value) pair list grpc expects, or None."""
        return list(self._metadata.items()) or None

    async def invoke(
        self,
        service: str,
        method: str,
        request_data: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Invoke a unary gRPC method with JSON request data.

        Args:
            service: Service name
            method: Method name
            request_data: JSON request data

        Returns:
            JSON response data

        Raises:
            ValueError: If service or method not found, or method is streaming
            Exception: If invocation fails
        """
        logger.debug(f"Invoking {service}.{method}")

        method_info = self._find_method(service, method)

        if method_info["client_streaming"] or method_info["server_streaming"]:
            raise ValueError(f"Method {method} is streaming, use invoke_streaming instead")

        request_class, response_class = self._message_classes(method_info)

        # Convert JSON to protobuf message.
        request_msg = json_format.ParseDict(request_data, request_class())

        method_path = f"/{service}/{method}"
        call = self._channel.unary_unary(method_path, request_serializer=request_msg.SerializeToString, response_deserializer=response_class.FromString)

        # Run the blocking unary call in a worker thread so the event loop
        # stays responsive (matches the asyncio.to_thread usage in start()).
        # Metadata headers are now forwarded; previously they were stored but
        # never sent with the call.
        response_msg = await asyncio.to_thread(call, request_msg, metadata=self._grpc_metadata())

        # Convert protobuf response to JSON.
        # pylint: disable=unexpected-keyword-arg
        response_dict = json_format.MessageToDict(response_msg, preserving_proto_field_name=True, including_default_value_fields=True)

        logger.debug(f"Successfully invoked {service}.{method}")
        return response_dict

    async def invoke_streaming(
        self,
        service: str,
        method: str,
        request_data: Dict[str, Any],
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Invoke a server-streaming gRPC method.

        Args:
            service: Service name
            method: Method name
            request_data: JSON request data

        Yields:
            JSON response chunks

        Raises:
            ValueError: If service or method not found or not streaming
            grpc.RpcError: If streaming RPC fails
        """
        logger.debug(f"Invoking streaming {service}.{method}")

        method_info = self._find_method(service, method)

        if not method_info["server_streaming"]:
            raise ValueError(f"Method {method} is not server-streaming")

        if method_info["client_streaming"]:
            raise ValueError("Client streaming not yet supported")

        request_class, response_class = self._message_classes(method_info)

        # Convert JSON to protobuf message.
        request_msg = json_format.ParseDict(request_data, request_class())

        method_path = f"/{service}/{method}"
        stream_call = self._channel.unary_stream(method_path, request_serializer=request_msg.SerializeToString, response_deserializer=response_class.FromString)(request_msg, metadata=self._grpc_metadata())

        # Yield responses as they arrive.
        # NOTE(review): iterating the synchronous stream blocks the event loop
        # between messages — same as the original behavior; consider aio channel.
        try:
            for response_msg in stream_call:
                # pylint: disable=unexpected-keyword-arg
                response_dict = json_format.MessageToDict(response_msg, preserving_proto_field_name=True, including_default_value_fields=True)
                yield response_dict
        except grpc.RpcError as e:
            logger.error(f"Streaming RPC error: {e}")
            raise

        logger.debug(f"Streaming complete for {service}.{method}")

    async def close(self) -> None:
        """Close the gRPC channel."""
        if self._channel:
            self._channel.close()
            logger.info(f"Closed gRPC connection to {self._target}")

    def get_services(self) -> List[str]:
        """Get list of discovered service names.

        Returns:
            List of service names
        """
        return list(self._services.keys())

    def get_methods(self, service: str) -> List[str]:
        """Get list of methods for a service.

        Args:
            service: Service name

        Returns:
            List of method names (empty if the service is unknown)
        """
        if service in self._services:
            return [m["name"] for m in self._services[service].get("methods", [])]
        return []
class GrpcToMcpTranslator:
    """Translates between gRPC and MCP protocols.

    Converts discovered gRPC services and methods into MCP virtual-server and
    tool definitions, including protobuf descriptor -> JSON schema conversion.
    """

    def __init__(self, endpoint: "GrpcEndpoint"):
        """Initialize translator.

        Args:
            endpoint: gRPC endpoint to translate
        """
        self._endpoint = endpoint

    def grpc_service_to_mcp_server(self, service_name: str) -> Dict[str, Any]:
        """Convert a gRPC service to an MCP virtual server definition.

        Args:
            service_name: gRPC service name

        Returns:
            MCP server definition
        """
        return {
            "name": service_name,
            "description": f"gRPC service: {service_name}",
            "transport": ["sse", "http"],
            "tools": self.grpc_methods_to_mcp_tools(service_name),
        }

    def grpc_methods_to_mcp_tools(self, service_name: str) -> List[Dict[str, Any]]:
        """Convert gRPC methods to MCP tool definitions.

        Args:
            service_name: gRPC service name

        Returns:
            List of MCP tool definitions (empty if the service is unknown)
        """
        # pylint: disable=protected-access
        if service_name not in self._endpoint._services:
            return []

        service_info = self._endpoint._services[service_name]
        tools = []

        for method_info in service_info.get("methods", []):
            method_name = method_info["name"]
            # Strip the leading dot from the fully-qualified type name.
            input_type = method_info["input_type"].lstrip(".")

            # Try to derive the input schema from the descriptor pool.
            try:
                input_desc = self._endpoint._pool.FindMessageTypeByName(input_type)
                input_schema = self.protobuf_to_json_schema(input_desc)
            except KeyError:
                # Fallback to a generic schema if the descriptor is unknown.
                input_schema = {"type": "object", "properties": {}}

            tools.append({"name": f"{service_name}.{method_name}", "description": f"gRPC method {service_name}.{method_name}", "inputSchema": input_schema})

        return tools

    def protobuf_to_json_schema(self, message_descriptor: Any) -> Dict[str, Any]:
        """Convert protobuf message descriptor to JSON schema.

        Args:
            message_descriptor: Protobuf message descriptor

        Returns:
            JSON schema with "properties" and "required" keys
        """
        schema = {"type": "object", "properties": {}, "required": []}

        # Iterate over fields in the message.
        # NOTE(review): recursive (self-referential) message types would
        # recurse indefinitely here — assumed not to occur; confirm if needed.
        for field in message_descriptor.fields:
            field_name = field.name
            field_schema = self._protobuf_field_to_json_schema(field)
            schema["properties"][field_name] = field_schema

            # proto2 LABEL_REQUIRED fields become JSON-schema-required;
            # proto3 fields are always optional.
            if hasattr(field, "label") and field.label == 2:  # LABEL_REQUIRED
                schema["required"].append(field_name)

        return schema

    def _protobuf_field_to_json_schema(self, field: Any) -> Dict[str, Any]:
        """Convert a protobuf field to JSON schema type.

        Args:
            field: Protobuf field descriptor

        Returns:
            JSON schema for the field
        """
        # Map protobuf FieldDescriptor type numbers to JSON schema types.
        type_map = {
            1: "number",  # TYPE_DOUBLE
            2: "number",  # TYPE_FLOAT
            3: "integer",  # TYPE_INT64
            4: "integer",  # TYPE_UINT64
            5: "integer",  # TYPE_INT32
            6: "integer",  # TYPE_FIXED64
            7: "integer",  # TYPE_FIXED32
            8: "boolean",  # TYPE_BOOL
            9: "string",  # TYPE_STRING
            11: "object",  # TYPE_MESSAGE
            12: "string",  # TYPE_BYTES (base64)
            13: "integer",  # TYPE_UINT32
            14: "string",  # TYPE_ENUM
            15: "integer",  # TYPE_SFIXED32
            16: "integer",  # TYPE_SFIXED64
            17: "integer",  # TYPE_SINT32
            18: "integer",  # TYPE_SINT64
        }

        field_type = type_map.get(field.type, "string")

        # Handle repeated fields. Repeated message fields now get a full
        # nested object schema as "items" (previously flattened to a bare
        # {"type": "object"}).
        if hasattr(field, "label") and field.label == 3:  # LABEL_REPEATED
            if field.type == 11:  # TYPE_MESSAGE
                try:
                    return {"type": "array", "items": self.protobuf_to_json_schema(field.message_type)}
                except Exception:  # pylint: disable=broad-except
                    return {"type": "array", "items": {"type": "object"}}
            return {"type": "array", "items": {"type": field_type}}

        # Handle message types (nested objects).
        if field.type == 11:  # TYPE_MESSAGE
            try:
                nested_desc = field.message_type
                return self.protobuf_to_json_schema(nested_desc)
            except Exception:  # pylint: disable=broad-except
                return {"type": "object"}

        return {"type": field_type}
521# Utility functions for CLI usage
async def expose_grpc_via_sse(
    target: str,
    port: int = 9000,
    tls_enabled: bool = False,
    tls_cert: Optional[str] = None,
    tls_key: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
) -> None:
    """Expose a gRPC service via SSE/HTTP endpoints.

    Connects to the target, discovers its services via reflection, logs the
    result, then idles to keep the connection alive until interrupted. The
    channel is always closed on exit.

    Args:
        target: gRPC server address (host:port)
        port: HTTP port to listen on
        tls_enabled: Use TLS for gRPC connection
        tls_cert: TLS certificate path
        tls_key: TLS key path
        metadata: gRPC metadata headers
    """
    logger.info(f"Exposing gRPC service {target} via SSE on port {port}")

    endpoint = GrpcEndpoint(
        target=target,
        reflection_enabled=True,
        tls_enabled=tls_enabled,
        tls_cert_path=tls_cert,
        tls_key_path=tls_key,
        metadata=metadata,
    )

    try:
        await endpoint.start()

        discovered = endpoint.get_services()
        logger.info(f"gRPC service exposed. Discovered services: {discovered}")
        logger.info("To expose via HTTP/SSE, register this service in the gateway admin UI")
        logger.info(f"  Target: {target}")
        logger.info(f"  Discovered: {len(discovered)} services")

        # Keep the endpoint connection alive. Full HTTP/SSE exposure happens by
        # registering the service via the gateway admin API, which routes it
        # through the existing multi-protocol server infrastructure.
        while True:
            await asyncio.sleep(1)

    except KeyboardInterrupt:
        logger.info("Shutting down...")
    finally:
        await endpoint.close()