Coverage for mcpgateway / services / grpc_service.py: 95%

229 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/grpc_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: MCP Gateway Contributors 

6 

7gRPC Service Management 

8 

9This module implements gRPC service management for the MCP Gateway. 

10It handles gRPC service registration, reflection-based discovery, listing, 

11retrieval, updates, activation toggling, and deletion. 

12""" 

13 

14# Standard 

15import asyncio 

16from datetime import datetime, timezone 

17from pathlib import Path 

18from typing import Any, Dict, List, Optional 

19 

20try: 

21 # Third-Party 

22 import grpc 

23 from grpc_reflection.v1alpha import reflection_pb2, reflection_pb2_grpc 

24 

25 GRPC_AVAILABLE = True 

26except ImportError: 

27 GRPC_AVAILABLE = False 

28 # grpc module will not be used if not available 

29 grpc = None # type: ignore 

30 reflection_pb2 = None # type: ignore 

31 reflection_pb2_grpc = None # type: ignore 

32 

33# Third-Party 

34from sqlalchemy import and_, desc, select 

35from sqlalchemy.orm import Session 

36 

37# First-Party 

38from mcpgateway.db import GrpcService as DbGrpcService 

39from mcpgateway.schemas import GrpcServiceCreate, GrpcServiceRead, GrpcServiceUpdate 

40from mcpgateway.services.logging_service import LoggingService 

41from mcpgateway.services.team_management_service import TeamManagementService 

42 

# Module-level logging: a single LoggingService instance hands out the named
# logger used by every function and method in this module.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

46 

47 

class GrpcServiceError(Exception):
    """Root of the exception hierarchy raised by the gRPC service manager."""


class GrpcServiceNotFoundError(GrpcServiceError):
    """Signals that no gRPC service matched the requested identifier."""


class GrpcServiceNameConflictError(GrpcServiceError):
    """Signals that a gRPC service name is already taken by another record."""

    def __init__(self, name: str, is_active: bool = True, service_id: Optional[str] = None):
        """Record the conflicting service's details and build the error message.

        Args:
            name: Name that collided with an existing gRPC service
            is_active: ``False`` when the existing service is disabled; adds an " (inactive)" marker
            service_id: Identifier of the existing service; appended to the message when truthy
        """
        self.name, self.is_active, self.service_id = name, is_active, service_id

        # Optional qualifiers are appended in a fixed order: inactive marker, then ID.
        qualifiers = []
        if not is_active:
            qualifiers.append(" (inactive)")
        if service_id:
            qualifiers.append(f" (ID: {service_id})")

        super().__init__(f"gRPC service with name '{name}' already exists" + "".join(qualifiers))

76 

77 

class GrpcService:
    """Service for managing gRPC services with reflection-based discovery.

    Every method takes the SQLAlchemy ``Session`` to operate on; the manager
    itself keeps no per-instance state.
    """

    def __init__(self):
        """Initialize the gRPC service manager."""

    async def register_service(
        self,
        db: Session,
        service_data: GrpcServiceCreate,
        user_email: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> GrpcServiceRead:
        """Register a new gRPC service.

        When reflection is enabled for the new service, an initial reflection
        pass is attempted; a reflection failure is only logged as a warning and
        does not undo the registration.

        Args:
            db: Database session
            service_data: gRPC service creation data
            user_email: Email of the user creating the service
            metadata: Additional metadata (IP, user agent, etc.)

        Returns:
            GrpcServiceRead: The created service

        Raises:
            GrpcServiceNameConflictError: If service name already exists
        """
        # Check for name conflicts
        existing = db.execute(select(DbGrpcService).where(DbGrpcService.name == service_data.name)).scalar_one_or_none()

        if existing:
            raise GrpcServiceNameConflictError(name=service_data.name, is_active=existing.enabled, service_id=existing.id)

        # One timestamp so created_at and updated_at are identical on creation
        now = datetime.now(timezone.utc)

        db_service = DbGrpcService(
            name=service_data.name,
            target=service_data.target,
            description=service_data.description,
            reflection_enabled=service_data.reflection_enabled,
            tls_enabled=service_data.tls_enabled,
            tls_cert_path=service_data.tls_cert_path,
            tls_key_path=service_data.tls_key_path,
            grpc_metadata=service_data.grpc_metadata or {},
            tags=service_data.tags or [],
            team_id=service_data.team_id,
            owner_email=user_email or service_data.owner_email,
            visibility=service_data.visibility,
            created_at=now,
            updated_at=now,
        )

        # Set audit metadata if provided
        if metadata:
            db_service.created_by = user_email
            db_service.created_from_ip = metadata.get("ip")
            db_service.created_via = metadata.get("via")
            db_service.created_user_agent = metadata.get("user_agent")

        db.add(db_service)
        db.commit()
        db.refresh(db_service)

        logger.info(f"Registered gRPC service: {db_service.name} (target: {db_service.target})")

        # Perform initial reflection if enabled (best-effort: registration already succeeded)
        if db_service.reflection_enabled:
            try:
                await self._perform_reflection(db, db_service)
            except Exception as e:
                logger.warning(f"Initial reflection failed for {db_service.name}: {e}")

        return GrpcServiceRead.model_validate(db_service)

    async def list_services(
        self,
        db: Session,
        include_inactive: bool = False,
        user_email: Optional[str] = None,
        team_id: Optional[str] = None,
    ) -> List[GrpcServiceRead]:
        """List gRPC services with optional filtering, newest first.

        Args:
            db: Database session
            include_inactive: Include disabled services
            user_email: Filter by user email for team access control
            team_id: Filter by team ID

        Returns:
            List of gRPC services
        """
        query = select(DbGrpcService)

        # Apply team filtering.
        # NOTE(review): when user_email is given without team_id, no team filter is
        # applied here (get_service filters on user_email alone) — confirm intended.
        if user_email and team_id:
            team_service = TeamManagementService(db)
            team_filter = await team_service.build_team_filter_clause(DbGrpcService, user_email, team_id)  # pylint: disable=no-member
            if team_filter is not None:
                query = query.where(team_filter)
        elif team_id:
            query = query.where(DbGrpcService.team_id == team_id)

        # Apply active filter
        if not include_inactive:
            query = query.where(DbGrpcService.enabled.is_(True))  # pylint: disable=singleton-comparison

        query = query.order_by(desc(DbGrpcService.created_at))

        services = db.execute(query).scalars().all()
        return [GrpcServiceRead.model_validate(svc) for svc in services]

    async def get_service(
        self,
        db: Session,
        service_id: str,
        user_email: Optional[str] = None,
    ) -> GrpcServiceRead:
        """Get a specific gRPC service by ID.

        Args:
            db: Database session
            service_id: Service ID
            user_email: Email for team access control

        Returns:
            The gRPC service

        Raises:
            GrpcServiceNotFoundError: If service not found or access denied
        """
        query = select(DbGrpcService).where(DbGrpcService.id == service_id)

        # Apply team access control
        if user_email:
            team_service = TeamManagementService(db)
            team_filter = await team_service.build_team_filter_clause(DbGrpcService, user_email, None)  # pylint: disable=no-member
            if team_filter is not None:
                query = query.where(team_filter)

        service = db.execute(query).scalar_one_or_none()

        if not service:
            raise GrpcServiceNotFoundError(f"gRPC service with ID '{service_id}' not found")

        return GrpcServiceRead.model_validate(service)

    async def update_service(
        self,
        db: Session,
        service_id: str,
        service_data: GrpcServiceUpdate,
        user_email: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> GrpcServiceRead:
        """Update an existing gRPC service.

        Only fields explicitly set on ``service_data`` are applied; the
        service's ``version`` is incremented on every update.

        Args:
            db: Database session
            service_id: Service ID to update
            service_data: Update data
            user_email: Email of user performing update
            metadata: Audit metadata

        Returns:
            Updated service

        Raises:
            GrpcServiceNotFoundError: If service not found
            GrpcServiceNameConflictError: If new name conflicts
        """
        service = db.execute(select(DbGrpcService).where(DbGrpcService.id == service_id)).scalar_one_or_none()

        if not service:
            raise GrpcServiceNotFoundError(f"gRPC service with ID '{service_id}' not found")

        # Check name conflict if name is being changed
        if service_data.name and service_data.name != service.name:
            existing = db.execute(select(DbGrpcService).where(and_(DbGrpcService.name == service_data.name, DbGrpcService.id != service_id))).scalar_one_or_none()

            if existing:
                raise GrpcServiceNameConflictError(name=service_data.name, is_active=existing.enabled, service_id=existing.id)

        # Update only the fields the caller actually set
        update_data = service_data.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(service, field, value)

        service.updated_at = datetime.now(timezone.utc)

        # Set audit metadata
        if metadata and user_email:
            service.modified_by = user_email
            service.modified_from_ip = metadata.get("ip")
            service.modified_via = metadata.get("via")
            service.modified_user_agent = metadata.get("user_agent")

        service.version += 1

        db.commit()
        db.refresh(service)

        logger.info(f"Updated gRPC service: {service.name}")

        return GrpcServiceRead.model_validate(service)

    async def set_service_state(
        self,
        db: Session,
        service_id: str,
        activate: bool,
    ) -> GrpcServiceRead:
        """Set a gRPC service's enabled status.

        Args:
            db: Database session
            service_id: Service ID
            activate: True to enable, False to disable

        Returns:
            Updated service

        Raises:
            GrpcServiceNotFoundError: If service not found
        """
        service = db.execute(select(DbGrpcService).where(DbGrpcService.id == service_id)).scalar_one_or_none()

        if not service:
            raise GrpcServiceNotFoundError(f"gRPC service with ID '{service_id}' not found")

        service.enabled = activate
        service.updated_at = datetime.now(timezone.utc)

        db.commit()
        db.refresh(service)

        action = "activated" if activate else "deactivated"
        logger.info(f"gRPC service {service.name} {action}")

        return GrpcServiceRead.model_validate(service)

    async def delete_service(
        self,
        db: Session,
        service_id: str,
    ) -> None:
        """Delete a gRPC service.

        Args:
            db: Database session
            service_id: Service ID to delete

        Raises:
            GrpcServiceNotFoundError: If service not found
        """
        service = db.execute(select(DbGrpcService).where(DbGrpcService.id == service_id)).scalar_one_or_none()

        if not service:
            raise GrpcServiceNotFoundError(f"gRPC service with ID '{service_id}' not found")

        # Capture the name before the row is deleted so logging never touches a deleted instance
        service_name = service.name

        db.delete(service)
        db.commit()

        logger.info(f"Deleted gRPC service: {service_name}")

    async def reflect_service(
        self,
        db: Session,
        service_id: str,
    ) -> GrpcServiceRead:
        """Trigger reflection on a gRPC service to discover services and methods.

        Args:
            db: Database session
            service_id: Service ID

        Returns:
            Updated service with reflection results

        Raises:
            GrpcServiceNotFoundError: If service not found
            GrpcServiceError: If reflection fails
        """
        service = db.execute(select(DbGrpcService).where(DbGrpcService.id == service_id)).scalar_one_or_none()

        if not service:
            raise GrpcServiceNotFoundError(f"gRPC service with ID '{service_id}' not found")

        try:
            await self._perform_reflection(db, service)
            logger.info(f"Reflection completed for {service.name}: {service.service_count} services, {service.method_count} methods")
        except Exception as e:
            logger.error(f"Reflection failed for {service.name}: {e}")
            # Also covers failures raised before _perform_reflection's own error handling (e.g. missing TLS files)
            service.reachable = False
            db.commit()
            raise GrpcServiceError(f"Reflection failed: {e}") from e

        return GrpcServiceRead.model_validate(service)

    async def get_service_methods(
        self,
        db: Session,
        service_id: str,
    ) -> List[Dict[str, Any]]:
        """Get the list of methods for a gRPC service from its stored reflection data.

        Args:
            db: Database session
            service_id: Service ID

        Returns:
            List of method descriptors (one dict per method of every discovered service)

        Raises:
            GrpcServiceNotFoundError: If service not found
        """
        service = db.execute(select(DbGrpcService).where(DbGrpcService.id == service_id)).scalar_one_or_none()

        if not service:
            raise GrpcServiceNotFoundError(f"gRPC service with ID '{service_id}' not found")

        discovered = service.discovered_services or {}

        return [
            {
                "service": service_name,
                "method": method["name"],
                "full_name": f"{service_name}.{method['name']}",
                "input_type": method.get("input_type"),
                "output_type": method.get("output_type"),
                "client_streaming": method.get("client_streaming", False),
                "server_streaming": method.get("server_streaming", False),
            }
            for service_name, service_desc in discovered.items()
            for method in service_desc.get("methods", [])
        ]

    async def _create_channel(self, service: DbGrpcService):
        """Create a gRPC channel to ``service.target`` honoring its TLS settings.

        With TLS enabled and both ``tls_cert_path`` and ``tls_key_path`` set, the
        files are loaded as the client certificate chain and private key (mutual
        TLS) and the server is verified against the default system roots. With
        TLS enabled but no cert/key pair, only the default system roots are used.

        Args:
            service: GrpcService model instance

        Returns:
            A ``grpc`` channel (secure or insecure)

        Raises:
            GrpcServiceError: If TLS certificate or key files are not found
        """
        if not service.tls_enabled:
            return grpc.insecure_channel(service.target)

        if service.tls_cert_path and service.tls_key_path:
            # Read files off the event loop
            try:
                cert = await asyncio.to_thread(Path(service.tls_cert_path).read_bytes)
                key = await asyncio.to_thread(Path(service.tls_key_path).read_bytes)
            except FileNotFoundError as e:
                raise GrpcServiceError(f"TLS certificate or key file not found: {e}") from e
            # A private key is only valid together with its certificate chain; passing the key alone
            # (or using the client cert as a root CA) produces an invalid key/cert pair.
            credentials = grpc.ssl_channel_credentials(private_key=key, certificate_chain=cert)
        else:
            # Use default system certificates
            credentials = grpc.ssl_channel_credentials()

        return grpc.secure_channel(service.target, credentials)

    @staticmethod
    def _discover_services(channel) -> Dict[str, Dict[str, Any]]:
        """Query the server reflection API over ``channel`` (blocking network I/O).

        Args:
            channel: An open ``grpc`` channel to the target server

        Returns:
            Mapping of fully-qualified service name to
            ``{"name", "methods", "package"}``; services whose descriptors could
            not be resolved are recorded as ``{"name", "methods": []}``.

        Raises:
            Exception: If listing services over reflection fails
        """
        # Import here to avoid circular dependency
        # Third-Party
        from google.protobuf.descriptor_pb2 import FileDescriptorProto  # pylint: disable=import-outside-toplevel,no-name-in-module

        stub = reflection_pb2_grpc.ServerReflectionStub(channel)

        # List services
        list_request = reflection_pb2.ServerReflectionRequest(list_services="")  # pylint: disable=no-member

        service_names = []
        for resp in stub.ServerReflectionInfo(iter([list_request])):
            if resp.HasField("list_services_response"):
                for svc in resp.list_services_response.service:
                    # Skip reflection service itself
                    if "ServerReflection" in svc.name:
                        continue
                    service_names.append(svc.name)

        # Get detailed information for each service
        discovered_services: Dict[str, Dict[str, Any]] = {}

        for service_name in service_names:
            try:
                # Request file descriptor containing this service
                file_request = reflection_pb2.ServerReflectionRequest(file_containing_symbol=service_name)  # pylint: disable=no-member

                entry: Optional[Dict[str, Any]] = None
                for resp in stub.ServerReflectionInfo(iter([file_request])):
                    if not resp.HasField("file_descriptor_response"):
                        continue
                    for file_desc_proto_bytes in resp.file_descriptor_response.file_descriptor_proto:
                        file_desc_proto = FileDescriptorProto()
                        file_desc_proto.ParseFromString(file_desc_proto_bytes)
                        package = file_desc_proto.package

                        for service_desc in file_desc_proto.service:
                            full_service_name = f"{package}.{service_desc.name}" if package else service_desc.name
                            # The response may also carry dependency files; keep only the exact
                            # service requested (substring matching picked up unrelated services).
                            if full_service_name != service_name:
                                continue

                            entry = {
                                "name": full_service_name,
                                "methods": [
                                    {
                                        "name": method_desc.name,
                                        "input_type": method_desc.input_type,
                                        "output_type": method_desc.output_type,
                                        "client_streaming": method_desc.client_streaming,
                                        "server_streaming": method_desc.server_streaming,
                                    }
                                    for method_desc in service_desc.method
                                ],
                                "package": package,
                            }

                if entry is None:
                    logger.warning(f"No descriptor found for {service_name}; recording it without methods")
                    entry = {"name": service_name, "methods": []}

                discovered_services[service_name] = entry

            except Exception as detail_error:
                logger.warning(f"Failed to get details for {service_name}: {detail_error}")
                # Add basic info even if detailed discovery fails
                discovered_services[service_name] = {
                    "name": service_name,
                    "methods": [],
                }

        return discovered_services

    async def _perform_reflection(
        self,
        db: Session,
        service: DbGrpcService,
    ) -> None:
        """Perform gRPC server reflection to discover services and persist the results.

        On success stores ``discovered_services``, ``service_count``,
        ``method_count``, ``last_reflection`` and ``reachable=True`` and commits.
        On a reflection/connection failure marks ``reachable=False``, commits,
        and re-raises.

        Args:
            db: Database session
            service: GrpcService model instance

        Raises:
            GrpcServiceError: If gRPC libraries are unavailable or TLS certificate files are not found
            Exception: If reflection or connection fails
        """
        if grpc is None or reflection_pb2 is None or reflection_pb2_grpc is None:
            raise GrpcServiceError("gRPC support is not available; install grpcio and grpcio-reflection")

        channel = await self._create_channel(service)

        try:
            # The reflection stub is synchronous; run it in a worker thread so the event loop is not blocked.
            discovered_services = await asyncio.to_thread(self._discover_services, channel)

            service.discovered_services = discovered_services
            # Derive counts from the de-duplicated result so nothing is counted twice
            service.service_count = len(discovered_services)
            service.method_count = sum(len(desc["methods"]) for desc in discovered_services.values())
            service.last_reflection = datetime.now(timezone.utc)
            service.reachable = True

            db.commit()

        except Exception as e:
            logger.error(f"Reflection error for {service.target}: {e}")
            service.reachable = False
            db.commit()
            raise

        finally:
            channel.close()

    async def invoke_method(
        self,
        db: Session,
        service_id: str,
        method_name: str,
        request_data: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Invoke a gRPC method on a registered service.

        Args:
            db: Database session
            service_id: Service ID
            method_name: Full method name (service.Method)
            request_data: JSON request data

        Returns:
            JSON response data

        Raises:
            GrpcServiceNotFoundError: If service not found
            GrpcServiceError: If the service is disabled, the method name is malformed, or invocation fails
        """
        service = db.execute(select(DbGrpcService).where(DbGrpcService.id == service_id)).scalar_one_or_none()

        if not service:
            raise GrpcServiceNotFoundError(f"gRPC service with ID '{service_id}' not found")

        if not service.enabled:
            raise GrpcServiceError(f"Service '{service.name}' is disabled")

        # Import here to avoid circular dependency
        # First-Party
        from mcpgateway.translate_grpc import GrpcEndpoint  # pylint: disable=import-outside-toplevel

        # Parse method name (service.Method format); both parts must be non-empty
        service_name, _, method = method_name.rpartition(".")
        if not service_name or not method:
            raise GrpcServiceError(f"Invalid method name '{method_name}', expected 'service.Method' format")

        # Create endpoint and invoke
        endpoint = GrpcEndpoint(
            target=service.target,
            reflection_enabled=False,  # Assume already discovered
            tls_enabled=service.tls_enabled,
            tls_cert_path=service.tls_cert_path,
            tls_key_path=service.tls_key_path,
            metadata=service.grpc_metadata or {},
        )

        try:
            # Start connection
            await endpoint.start()

            # If we have stored service info, use it
            if service.discovered_services:
                endpoint._services = service.discovered_services  # pylint: disable=protected-access

            # Invoke method
            return await endpoint.invoke(service_name, method, request_data)

        except Exception as e:
            logger.error(f"Failed to invoke {method_name} on {service.name}: {e}")
            raise GrpcServiceError(f"Method invocation failed: {e}") from e

        finally:
            await endpoint.close()