Coverage for mcpgateway / services / grpc_service.py: 95%
229 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/grpc_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: MCP Gateway Contributors
7gRPC Service Management
9This module implements gRPC service management for the MCP Gateway.
10It handles gRPC service registration, reflection-based discovery, listing,
11retrieval, updates, activation toggling, and deletion.
12"""
14# Standard
15import asyncio
16from datetime import datetime, timezone
17from pathlib import Path
18from typing import Any, Dict, List, Optional
20try:
21 # Third-Party
22 import grpc
23 from grpc_reflection.v1alpha import reflection_pb2, reflection_pb2_grpc
25 GRPC_AVAILABLE = True
26except ImportError:
27 GRPC_AVAILABLE = False
28 # grpc module will not be used if not available
29 grpc = None # type: ignore
30 reflection_pb2 = None # type: ignore
31 reflection_pb2_grpc = None # type: ignore
33# Third-Party
34from sqlalchemy import and_, desc, select
35from sqlalchemy.orm import Session
37# First-Party
38from mcpgateway.db import GrpcService as DbGrpcService
39from mcpgateway.schemas import GrpcServiceCreate, GrpcServiceRead, GrpcServiceUpdate
40from mcpgateway.services.logging_service import LoggingService
41from mcpgateway.services.team_management_service import TeamManagementService
# Initialize logging
# Module-level logger for this service, obtained from the gateway's LoggingService
# (imported from mcpgateway.services.logging_service) and named after this module.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)
class GrpcServiceError(Exception):
    """Common base for the errors raised by gRPC service management code."""


class GrpcServiceNotFoundError(GrpcServiceError):
    """Signals that no gRPC service matches the requested identifier."""


class GrpcServiceNameConflictError(GrpcServiceError):
    """Signals that a gRPC service name is already taken by another record."""

    def __init__(self, name: str, is_active: bool = True, service_id: Optional[str] = None):
        """Build the conflict error and its human-readable message.

        Args:
            name: Name that collided with an existing service
            is_active: False when the existing holder of the name is disabled
            service_id: Identifier of the existing service, when available
        """
        # Optional suffixes appended, in order, after the base message
        details = []
        if not is_active:
            details.append(" (inactive)")
        if service_id:
            details.append(f" (ID: {service_id})")
        super().__init__(f"gRPC service with name '{name}' already exists" + "".join(details))
        self.name = name
        self.is_active = is_active
        self.service_id = service_id
class GrpcService:
    """Service for managing gRPC services with reflection-based discovery.

    Handles registration, listing, retrieval, update, enable/disable and
    deletion of ``DbGrpcService`` records, discovers a target's services and
    methods through gRPC server reflection, and invokes methods through
    ``mcpgateway.translate_grpc.GrpcEndpoint``.
    """

    def __init__(self):
        """Initialize the gRPC service manager."""

    @staticmethod
    def _get_db_service(db: Session, service_id: str) -> DbGrpcService:
        """Fetch a gRPC service row by ID or raise.

        Args:
            db: Database session
            service_id: Service ID

        Returns:
            DbGrpcService: The matching database row

        Raises:
            GrpcServiceNotFoundError: If no service has this ID
        """
        service = db.execute(select(DbGrpcService).where(DbGrpcService.id == service_id)).scalar_one_or_none()
        if not service:
            raise GrpcServiceNotFoundError(f"gRPC service with ID '{service_id}' not found")
        return service

    async def register_service(
        self,
        db: Session,
        service_data: GrpcServiceCreate,
        user_email: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> GrpcServiceRead:
        """Register a new gRPC service.

        Args:
            db: Database session
            service_data: gRPC service creation data
            user_email: Email of the user creating the service
            metadata: Additional metadata (IP, user agent, etc.)

        Returns:
            GrpcServiceRead: The created service

        Raises:
            GrpcServiceNameConflictError: If service name already exists
        """
        # Check for name conflicts
        existing = db.execute(select(DbGrpcService).where(DbGrpcService.name == service_data.name)).scalar_one_or_none()

        if existing:
            raise GrpcServiceNameConflictError(name=service_data.name, is_active=existing.enabled, service_id=existing.id)

        # One timestamp for both fields so a new row has created_at == updated_at
        now = datetime.now(timezone.utc)
        db_service = DbGrpcService(
            name=service_data.name,
            target=service_data.target,
            description=service_data.description,
            reflection_enabled=service_data.reflection_enabled,
            tls_enabled=service_data.tls_enabled,
            tls_cert_path=service_data.tls_cert_path,
            tls_key_path=service_data.tls_key_path,
            grpc_metadata=service_data.grpc_metadata or {},
            tags=service_data.tags or [],
            team_id=service_data.team_id,
            owner_email=user_email or service_data.owner_email,
            visibility=service_data.visibility,
            created_at=now,
            updated_at=now,
        )

        # Set audit metadata if provided
        if metadata:
            db_service.created_by = user_email
            db_service.created_from_ip = metadata.get("ip")
            db_service.created_via = metadata.get("via")
            db_service.created_user_agent = metadata.get("user_agent")

        db.add(db_service)
        db.commit()
        db.refresh(db_service)

        logger.info(f"Registered gRPC service: {db_service.name} (target: {db_service.target})")

        # Initial reflection is best-effort: a failure is logged, registration still succeeds
        if db_service.reflection_enabled:
            try:
                await self._perform_reflection(db, db_service)
            except Exception as e:
                logger.warning(f"Initial reflection failed for {db_service.name}: {e}")

        return GrpcServiceRead.model_validate(db_service)

    async def list_services(
        self,
        db: Session,
        include_inactive: bool = False,
        user_email: Optional[str] = None,
        team_id: Optional[str] = None,
    ) -> List[GrpcServiceRead]:
        """List gRPC services with optional filtering.

        Args:
            db: Database session
            include_inactive: Include disabled services
            user_email: Filter by user email for team access control
            team_id: Filter by team ID

        Returns:
            List of gRPC services, newest first
        """
        query = select(DbGrpcService)

        # Apply team filtering: full access-control clause when the user is known,
        # plain team_id equality otherwise
        if user_email and team_id:
            team_service = TeamManagementService(db)
            team_filter = await team_service.build_team_filter_clause(DbGrpcService, user_email, team_id)  # pylint: disable=no-member
            if team_filter is not None:
                query = query.where(team_filter)
        elif team_id:
            query = query.where(DbGrpcService.team_id == team_id)

        # Apply active filter
        if not include_inactive:
            query = query.where(DbGrpcService.enabled.is_(True))  # pylint: disable=singleton-comparison

        query = query.order_by(desc(DbGrpcService.created_at))

        services = db.execute(query).scalars().all()
        return [GrpcServiceRead.model_validate(svc) for svc in services]

    async def get_service(
        self,
        db: Session,
        service_id: str,
        user_email: Optional[str] = None,
    ) -> GrpcServiceRead:
        """Get a specific gRPC service by ID.

        Args:
            db: Database session
            service_id: Service ID
            user_email: Email for team access control

        Returns:
            The gRPC service

        Raises:
            GrpcServiceNotFoundError: If service not found or access denied
        """
        query = select(DbGrpcService).where(DbGrpcService.id == service_id)

        # Apply team access control; a row the user cannot see is reported as not found
        if user_email:
            team_service = TeamManagementService(db)
            team_filter = await team_service.build_team_filter_clause(DbGrpcService, user_email, None)  # pylint: disable=no-member
            if team_filter is not None:
                query = query.where(team_filter)

        service = db.execute(query).scalar_one_or_none()

        if not service:
            raise GrpcServiceNotFoundError(f"gRPC service with ID '{service_id}' not found")

        return GrpcServiceRead.model_validate(service)

    async def update_service(
        self,
        db: Session,
        service_id: str,
        service_data: GrpcServiceUpdate,
        user_email: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> GrpcServiceRead:
        """Update an existing gRPC service.

        Args:
            db: Database session
            service_id: Service ID to update
            service_data: Update data
            user_email: Email of user performing update
            metadata: Audit metadata

        Returns:
            Updated service

        Raises:
            GrpcServiceNotFoundError: If service not found
            GrpcServiceNameConflictError: If new name conflicts
        """
        service = self._get_db_service(db, service_id)

        # Check name conflict only when the name is actually being changed
        if service_data.name and service_data.name != service.name:
            existing = db.execute(select(DbGrpcService).where(and_(DbGrpcService.name == service_data.name, DbGrpcService.id != service_id))).scalar_one_or_none()

            if existing:
                raise GrpcServiceNameConflictError(name=service_data.name, is_active=existing.enabled, service_id=existing.id)

        # Apply only the fields the caller explicitly set
        for field, value in service_data.model_dump(exclude_unset=True).items():
            setattr(service, field, value)

        service.updated_at = datetime.now(timezone.utc)

        # Set audit metadata
        if metadata and user_email:
            service.modified_by = user_email
            service.modified_from_ip = metadata.get("ip")
            service.modified_via = metadata.get("via")
            service.modified_user_agent = metadata.get("user_agent")

        service.version += 1

        db.commit()
        db.refresh(service)

        logger.info(f"Updated gRPC service: {service.name}")

        return GrpcServiceRead.model_validate(service)

    async def set_service_state(
        self,
        db: Session,
        service_id: str,
        activate: bool,
    ) -> GrpcServiceRead:
        """Set a gRPC service's enabled status.

        Args:
            db: Database session
            service_id: Service ID
            activate: True to enable, False to disable

        Returns:
            Updated service

        Raises:
            GrpcServiceNotFoundError: If service not found
        """
        service = self._get_db_service(db, service_id)

        service.enabled = activate
        service.updated_at = datetime.now(timezone.utc)

        db.commit()
        db.refresh(service)

        action = "activated" if activate else "deactivated"
        logger.info(f"gRPC service {service.name} {action}")

        return GrpcServiceRead.model_validate(service)

    async def delete_service(
        self,
        db: Session,
        service_id: str,
    ) -> None:
        """Delete a gRPC service.

        Args:
            db: Database session
            service_id: Service ID to delete

        Raises:
            GrpcServiceNotFoundError: If service not found
        """
        service = self._get_db_service(db, service_id)

        # Capture the name first rather than reading attributes of a deleted ORM object
        service_name = service.name

        db.delete(service)
        db.commit()

        logger.info(f"Deleted gRPC service: {service_name}")

    async def reflect_service(
        self,
        db: Session,
        service_id: str,
    ) -> GrpcServiceRead:
        """Trigger reflection on a gRPC service to discover services and methods.

        Args:
            db: Database session
            service_id: Service ID

        Returns:
            Updated service with reflection results

        Raises:
            GrpcServiceNotFoundError: If service not found
            GrpcServiceError: If reflection fails
        """
        service = self._get_db_service(db, service_id)

        try:
            await self._perform_reflection(db, service)
            logger.info(f"Reflection completed for {service.name}: {service.service_count} services, {service.method_count} methods")
        except Exception as e:
            logger.error(f"Reflection failed for {service.name}: {e}")
            # Some failures (e.g. missing TLS files) happen before _perform_reflection
            # records reachability, so mark the service unreachable here as well
            service.reachable = False
            db.commit()
            raise GrpcServiceError(f"Reflection failed: {str(e)}") from e

        return GrpcServiceRead.model_validate(service)

    async def get_service_methods(
        self,
        db: Session,
        service_id: str,
    ) -> List[Dict[str, Any]]:
        """Get the list of methods for a gRPC service.

        Method data comes from the descriptors stored by the last reflection run;
        no network call is made.

        Args:
            db: Database session
            service_id: Service ID

        Returns:
            List of method descriptors

        Raises:
            GrpcServiceNotFoundError: If service not found
        """
        service = self._get_db_service(db, service_id)
        discovered = service.discovered_services or {}

        return [
            {
                "service": service_name,
                "method": method["name"],
                "full_name": f"{service_name}.{method['name']}",
                "input_type": method.get("input_type"),
                "output_type": method.get("output_type"),
                "client_streaming": method.get("client_streaming", False),
                "server_streaming": method.get("server_streaming", False),
            }
            for service_name, service_desc in discovered.items()
            for method in service_desc.get("methods", [])
        ]

    async def _perform_reflection(
        self,
        db: Session,
        service: DbGrpcService,
    ) -> None:
        """Perform gRPC server reflection to discover services.

        On success, stores the discovered services, the service/method counts,
        the reflection timestamp and ``reachable=True`` on ``service`` and commits.
        On a reflection/connection failure, stores ``reachable=False``, commits
        and re-raises.

        Args:
            db: Database session
            service: GrpcService model instance

        Raises:
            GrpcServiceError: If gRPC support is not installed or TLS certificate files are not found
            Exception: If reflection or connection fails
        """
        if grpc is None:
            raise GrpcServiceError("gRPC support is not installed; install 'grpcio' and 'grpcio-reflection' to use reflection")

        channel = await self._create_channel(service)

        try:
            # The reflection stub is synchronous and does blocking network I/O, so run it
            # in a worker thread. The SQLAlchemy session is only touched on this side.
            discovered_services = await asyncio.to_thread(self._discover_services, channel)

            service.discovered_services = discovered_services
            # Derive counts from the final mapping so each service is counted once
            service.service_count = len(discovered_services)
            service.method_count = sum(len(svc_desc["methods"]) for svc_desc in discovered_services.values())
            service.last_reflection = datetime.now(timezone.utc)
            service.reachable = True

            db.commit()

        except Exception as e:
            logger.error(f"Reflection error for {service.target}: {e}")
            service.reachable = False
            db.commit()
            raise

        finally:
            channel.close()

    @staticmethod
    async def _create_channel(service: DbGrpcService) -> Any:
        """Open a gRPC channel to ``service.target`` honouring its TLS settings.

        Args:
            service: GrpcService model instance

        Returns:
            A secure channel when TLS is enabled, otherwise an insecure one

        Raises:
            GrpcServiceError: If TLS certificate or key files are not found
        """
        if not service.tls_enabled:
            return grpc.insecure_channel(service.target)

        if service.tls_cert_path and service.tls_key_path:
            # Load TLS certificates without blocking the event loop
            try:
                cert = await asyncio.to_thread(Path(service.tls_cert_path).read_bytes)
                key = await asyncio.to_thread(Path(service.tls_key_path).read_bytes)
            except FileNotFoundError as e:
                raise GrpcServiceError(f"TLS certificate or key file not found: {e}") from e
            # The cert/key pair is the client identity for mutual TLS. The key must be
            # paired with its certificate via ``certificate_chain``; passing the key
            # without a chain is an invalid credential. The server is verified against
            # the default system roots.
            credentials = grpc.ssl_channel_credentials(private_key=key, certificate_chain=cert)
        else:
            # Use default system certificates
            credentials = grpc.ssl_channel_credentials()

        return grpc.secure_channel(service.target, credentials)

    @staticmethod
    def _discover_services(channel: Any) -> Dict[str, Dict[str, Any]]:  # pylint: disable=too-many-nested-blocks
        """Query a server's reflection API and describe its services (blocking).

        Args:
            channel: Open gRPC channel to the target server

        Returns:
            Mapping of service name to a dict with ``name`` and ``methods`` (plus
            ``package`` when the service's file descriptor was resolved)
        """
        # Imported lazily: protobuf descriptors are only needed when reflection runs
        # Third-Party
        from google.protobuf.descriptor_pb2 import FileDescriptorProto  # pylint: disable=import-outside-toplevel,no-name-in-module

        stub = reflection_pb2_grpc.ServerReflectionStub(channel)

        # List services
        list_request = reflection_pb2.ServerReflectionRequest(list_services="")  # pylint: disable=no-member
        service_names = []
        for resp in stub.ServerReflectionInfo(iter([list_request])):
            if resp.HasField("list_services_response"):
                for svc in resp.list_services_response.service:
                    # Skip reflection service itself
                    if "ServerReflection" not in svc.name:
                        service_names.append(svc.name)

        # Get detailed information for each service
        discovered_services: Dict[str, Dict[str, Any]] = {}
        for service_name in service_names:
            described = False
            try:
                # Request file descriptor containing this service
                file_request = reflection_pb2.ServerReflectionRequest(file_containing_symbol=service_name)  # pylint: disable=no-member
                for resp in stub.ServerReflectionInfo(iter([file_request])):
                    if not resp.HasField("file_descriptor_response"):
                        continue
                    for file_desc_proto_bytes in resp.file_descriptor_response.file_descriptor_proto:
                        file_desc_proto = FileDescriptorProto()
                        file_desc_proto.ParseFromString(file_desc_proto_bytes)
                        package = file_desc_proto.package

                        for service_desc in file_desc_proto.service:
                            full_service_name = f"{package}.{service_desc.name}" if package else service_desc.name
                            # Match the requested service exactly. The response may also
                            # carry sibling services and dependency files; substring matching
                            # would record those too and double-count them.
                            if service_name not in (full_service_name, service_desc.name):
                                continue

                            discovered_services[full_service_name] = {
                                "name": full_service_name,
                                "methods": [
                                    {
                                        "name": method_desc.name,
                                        "input_type": method_desc.input_type,
                                        "output_type": method_desc.output_type,
                                        "client_streaming": method_desc.client_streaming,
                                        "server_streaming": method_desc.server_streaming,
                                    }
                                    for method_desc in service_desc.method
                                ],
                                "package": package,
                            }
                            described = True

            except Exception as detail_error:
                logger.warning(f"Failed to get details for {service_name}: {detail_error}")
                described = False

            if not described:
                # Keep a listed service visible with basic info even when its
                # descriptor could not be fetched or matched
                discovered_services[service_name] = {
                    "name": service_name,
                    "methods": [],
                }

        return discovered_services

    async def invoke_method(
        self,
        db: Session,
        service_id: str,
        method_name: str,
        request_data: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Invoke a gRPC method on a registered service.

        Args:
            db: Database session
            service_id: Service ID
            method_name: Full method name (service.Method)
            request_data: JSON request data

        Returns:
            JSON response data

        Raises:
            GrpcServiceNotFoundError: If service not found
            GrpcServiceError: If the service is disabled, the method name is malformed, or invocation fails
        """
        service = self._get_db_service(db, service_id)

        if not service.enabled:
            raise GrpcServiceError(f"Service '{service.name}' is disabled")

        # Split on the last dot: everything before it is the (possibly package-qualified)
        # service name. Both parts must be non-empty.
        service_name, _, method = method_name.rpartition(".")
        if not service_name or not method:
            raise GrpcServiceError(f"Invalid method name '{method_name}', expected 'service.Method' format")

        # Import here to avoid circular dependency
        # First-Party
        from mcpgateway.translate_grpc import GrpcEndpoint  # pylint: disable=import-outside-toplevel

        # Create endpoint and invoke
        endpoint = GrpcEndpoint(
            target=service.target,
            reflection_enabled=False,  # Assume already discovered
            tls_enabled=service.tls_enabled,
            tls_cert_path=service.tls_cert_path,
            tls_key_path=service.tls_key_path,
            metadata=service.grpc_metadata or {},
        )

        try:
            # Start connection
            await endpoint.start()

            # Reuse descriptors stored by the last reflection instead of re-reflecting
            if service.discovered_services:
                endpoint._services = service.discovered_services  # pylint: disable=protected-access

            return await endpoint.invoke(service_name, method, request_data)

        except Exception as e:
            logger.error(f"Failed to invoke {method_name} on {service.name}: {e}")
            raise GrpcServiceError(f"Method invocation failed: {e}") from e

        finally:
            await endpoint.close()