Coverage for mcpgateway / services / gateway_service.py: 93%

2306 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2# pylint: disable=import-outside-toplevel,no-name-in-module 

3"""Location: ./mcpgateway/services/gateway_service.py 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8Gateway Service Implementation. 

9This module implements gateway federation according to the MCP specification. 

10It handles: 

11- Gateway discovery and registration 

12- Capability aggregation 

13- Health monitoring 

14- Active/inactive gateway management 

15 

16Examples: 

17 >>> from mcpgateway.services.gateway_service import GatewayService, GatewayError 

18 >>> service = GatewayService() 

19 >>> isinstance(service, GatewayService) 

20 True 

21 >>> hasattr(service, '_active_gateways') 

22 True 

23 >>> isinstance(service._active_gateways, set) 

24 True 

25 

26 Test error classes: 

27 >>> error = GatewayError("Test error") 

28 >>> str(error) 

29 'Test error' 

30 >>> isinstance(error, Exception) 

31 True 

32 

33 >>> conflict_error = GatewayNameConflictError("test_gw") 

34 >>> "test_gw" in str(conflict_error) 

35 True 

36 >>> conflict_error.enabled 

37 True 

38 >>> 

39 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs 

40 >>> import asyncio 

41 >>> asyncio.run(service._http_client.aclose()) 

42""" 

43 

44# Standard 

45import asyncio 

46import binascii 

47from datetime import datetime, timezone 

48import logging 

49import mimetypes 

50import os 

51import ssl 

52import tempfile 

53import time 

54from typing import Any, AsyncGenerator, cast, Dict, List, Optional, Set, TYPE_CHECKING, Union 

55from urllib.parse import urlparse, urlunparse 

56import uuid 

57 

58# Third-Party 

59import anyio 

60from filelock import FileLock, Timeout 

61import httpx 

62from mcp import ClientSession 

63from mcp.client.sse import sse_client 

64from mcp.client.streamable_http import streamablehttp_client 

65from pydantic import ValidationError 

66from sqlalchemy import and_, delete, desc, or_, select, update 

67from sqlalchemy.exc import IntegrityError 

68from sqlalchemy.orm import joinedload, selectinload, Session 

69 

70try: 

71 # Third-Party - check if redis is available 

72 # Third-Party 

73 import redis.asyncio as _aioredis # noqa: F401 # pylint: disable=unused-import 

74 

75 REDIS_AVAILABLE = True 

76 del _aioredis # Only needed for availability check 

77except ImportError: 

78 REDIS_AVAILABLE = False 

79 logging.info("Redis is not utilized in this environment.") 

80 

81# First-Party 

82from mcpgateway.common.validators import SecurityValidator 

83from mcpgateway.config import settings 

84from mcpgateway.db import EmailTeam as DbEmailTeam 

85from mcpgateway.db import EmailTeamMember as DbEmailTeamMember 

86from mcpgateway.db import fresh_db_session 

87from mcpgateway.db import Gateway as DbGateway 

88from mcpgateway.db import get_for_update 

89from mcpgateway.db import Prompt as DbPrompt 

90from mcpgateway.db import PromptMetric 

91from mcpgateway.db import Resource as DbResource 

92from mcpgateway.db import ResourceMetric, ResourceSubscription, server_prompt_association, server_resource_association, server_tool_association, SessionLocal 

93from mcpgateway.db import Tool as DbTool 

94from mcpgateway.db import ToolMetric 

95from mcpgateway.observability import create_span, set_span_attribute, set_span_error 

96from mcpgateway.schemas import GatewayCreate, GatewayRead, GatewayUpdate, PromptCreate, ResourceCreate, ToolCreate 

97 

98# logging.getLogger("httpx").setLevel(logging.WARNING) # Disables httpx logs for regular health checks 

99from mcpgateway.services.audit_trail_service import get_audit_trail_service 

100from mcpgateway.services.base_service import BaseService 

101from mcpgateway.services.encryption_service import get_encryption_service, protect_oauth_config_for_storage 

102from mcpgateway.services.event_service import EventService 

103from mcpgateway.services.http_client_service import get_default_verify, get_http_timeout, get_isolated_http_client 

104from mcpgateway.services.logging_service import LoggingService 

105from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, register_gateway_capabilities_for_notifications, TransportType 

106from mcpgateway.services.oauth_manager import OAuthManager 

107from mcpgateway.services.structured_logger import get_structured_logger 

108from mcpgateway.services.team_management_service import TeamManagementService 

109from mcpgateway.utils.create_slug import slugify 

110from mcpgateway.utils.display_name import generate_display_name 

111from mcpgateway.utils.pagination import unified_paginate 

112from mcpgateway.utils.passthrough_headers import get_passthrough_headers 

113from mcpgateway.utils.redis_client import get_redis_client 

114from mcpgateway.utils.retry_manager import ResilientHttpClient 

115from mcpgateway.utils.services_auth import decode_auth, encode_auth 

116from mcpgateway.utils.sqlalchemy_modifier import json_contains_tag_expr 

117from mcpgateway.utils.ssl_context_cache import get_cached_ssl_context 

118from mcpgateway.utils.url_auth import apply_query_param_auth, sanitize_exception_message, sanitize_url_for_logging 

119from mcpgateway.utils.validate_signature import validate_signature 

120from mcpgateway.validation.tags import validate_tags_field 

121 

122 

123def _resolve_tool_title(tool) -> Optional[str]: 

124 """Resolve the display title for a tool per MCP spec precedence. 

125 

126 MCP 2025-11-25: "Display name precedence order is: title, 

127 annotations.title, then name." 

128 

129 1. ``tool.title`` — top-level ``BaseMetadata`` field (canonical). 

130 2. ``tool.annotations.title`` — ``ToolAnnotations`` (legacy fallback). 

131 3. ``None`` if neither is available (caller may fall back to ``name``). 

132 

133 All return paths are guarded with ``isinstance(str)`` so the function 

134 never leaks non-string values from mock objects or malformed payloads. 

135 

136 Args: 

137 tool: An object representing a tool. It may define a top-level 

138 ``title`` attribute and/or an ``annotations`` attribute 

139 (``ToolAnnotations`` model or ``dict``). 

140 

141 Returns: 

142 Optional[str]: The resolved title string if found, otherwise None. 

143 

144 Examples: 

145 >>> class Tool: 

146 ... def __init__(self, title=None, annotations=None): 

147 ... self.title = title 

148 ... self.annotations = annotations 

149 ... 

150 >>> # 1. top-level title takes precedence 

151 >>> tool = Tool(title="Top Level", annotations={"title": "Annotated"}) 

152 >>> _resolve_tool_title(tool) 

153 'Top Level' 

154 

155 >>> # 2. Fallback to annotations.title 

156 >>> tool = Tool(annotations={"title": "Annotated"}) 

157 >>> _resolve_tool_title(tool) 

158 'Annotated' 

159 

160 >>> # 3. No title available 

161 >>> tool = Tool() 

162 >>> _resolve_tool_title(tool) is None 

163 True 

164 

165 >>> # 4. annotations is not a dict 

166 >>> tool = Tool(title="Top Level", annotations="invalid") 

167 >>> _resolve_tool_title(tool) 

168 'Top Level' 

169 """ 

170 # MCP spec: "Display name precedence order is: title, annotations.title, then name." 

171 title = getattr(tool, "title", None) 

172 if isinstance(title, str): 

173 return title 

174 annotations = getattr(tool, "annotations", None) 

175 if annotations is not None: 

176 if isinstance(annotations, dict): 

177 ann_title = annotations.get("title") 

178 else: 

179 ann_title = getattr(annotations, "title", None) 

180 if isinstance(ann_title, str): 

181 return ann_title 

182 return None 

183 

184 

# Cache import (lazy to avoid circular dependencies).
# Both names start as None and are filled in on first use by
# _get_registry_cache() / _get_tool_lookup_cache() below.
_REGISTRY_CACHE = None
_TOOL_LOOKUP_CACHE = None

188 

189 

def _get_registry_cache():
    """Return the shared registry cache, importing it on first use.

    The import happens inside the function (not at module load) to avoid a
    circular dependency; the result is memoised in ``_REGISTRY_CACHE``.

    Returns:
        RegistryCache instance.
    """
    global _REGISTRY_CACHE  # pylint: disable=global-statement
    if _REGISTRY_CACHE is not None:
        return _REGISTRY_CACHE

    # First-Party
    from mcpgateway.cache.registry_cache import registry_cache as _singleton  # pylint: disable=import-outside-toplevel

    _REGISTRY_CACHE = _singleton
    return _REGISTRY_CACHE

203 

204 

def _get_tool_lookup_cache():
    """Return the shared tool lookup cache, importing it on first use.

    The import happens inside the function (not at module load) to avoid a
    circular dependency; the result is memoised in ``_TOOL_LOOKUP_CACHE``.

    Returns:
        ToolLookupCache instance.
    """
    global _TOOL_LOOKUP_CACHE  # pylint: disable=global-statement
    if _TOOL_LOOKUP_CACHE is not None:
        return _TOOL_LOOKUP_CACHE

    # First-Party
    from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache as _singleton  # pylint: disable=import-outside-toplevel

    _TOOL_LOOKUP_CACHE = _singleton
    return _TOOL_LOOKUP_CACHE

218 

219 

# Initialize logging service first; `logger` is the module-wide logger used
# by every function and method in this file.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

# Initialize structured logger and audit trail for gateway operations
structured_logger = get_structured_logger("gateway_service")
audit_trail = get_audit_trail_service()


# Health-monitoring knobs read once from settings at import time.
# GW_HEALTH_CHECK_INTERVAL is copied into GatewayService._health_check_interval.
# NOTE(review): GW_FAILURE_THRESHOLD is presumably the number of failed health
# checks tolerated before a gateway is treated as unhealthy - confirm at its use site.
GW_FAILURE_THRESHOLD = settings.unhealthy_threshold
GW_HEALTH_CHECK_INTERVAL = settings.health_check_interval

231 

232 

class GatewayError(Exception):
    """Root of the gateway exception hierarchy.

    Every gateway-specific error in this module derives from this class, so
    callers can catch ``GatewayError`` to handle all of them at once.

    Examples:
        >>> error = GatewayError("Test error")
        >>> str(error)
        'Test error'
        >>> isinstance(error, Exception)
        True
    """

243 

244 

class GatewayNotFoundError(GatewayError):
    """Signals that the requested gateway could not be found.

    Examples:
        >>> error = GatewayNotFoundError("Gateway not found")
        >>> str(error)
        'Gateway not found'
        >>> isinstance(error, GatewayError)
        True
    """

255 

256 

class GatewayNameConflictError(GatewayError):
    """Signals that a gateway name is already taken by an active or inactive gateway.

    Args:
        name: The conflicting gateway name
        enabled: Whether the existing gateway is enabled
        gateway_id: ID of the existing gateway if available
        visibility: The visibility of the gateway ("public" or "team").

    Examples:
        >>> error = GatewayNameConflictError("test_gateway")
        >>> str(error)
        'Public Gateway already exists with name: test_gateway'
        >>> error.name
        'test_gateway'
        >>> error.enabled
        True
        >>> error.gateway_id is None
        True

        >>> error_inactive = GatewayNameConflictError("inactive_gw", enabled=False, gateway_id=123)
        >>> str(error_inactive)
        'Public Gateway already exists with name: inactive_gw (currently inactive, ID: 123)'
        >>> error_inactive.enabled
        False
        >>> error_inactive.gateway_id
        123
    """

    def __init__(self, name: str, enabled: bool = True, gateway_id: Optional[int] = None, visibility: Optional[str] = "public"):
        """Record the conflict details and compose the error message.

        Args:
            name: The conflicting gateway name
            enabled: Whether the existing gateway is enabled
            gateway_id: ID of the existing gateway if available
            visibility: The visibility of the gateway ("public" or "team").
        """
        self.name = name
        self.enabled = enabled
        self.gateway_id = gateway_id

        # Anything other than "team" is reported under the "Public" label.
        scope_label = "Team-level" if visibility == "team" else "Public"
        # The existing gateway's ID is only mentioned when it is inactive.
        inactive_note = "" if enabled else f" (currently inactive, ID: {gateway_id})"

        super().__init__(f"{scope_label} Gateway already exists with name: {name}{inactive_note}")

306 

307 

class GatewayDuplicateConflictError(GatewayError):
    """Signals that an equivalent gateway (same URL + credentials) already exists.

    Raised when a registration would duplicate the URL and authentication
    credentials of a gateway that already exists within the same scope:
    - Public: Global uniqueness required across all public gateways.
    - Team: Uniqueness required within the same team.
    - Private: Uniqueness required for the same user, a user cannot have two private gateways with the same URL and credentials.

    Args:
        duplicate_gateway: The existing conflicting gateway (DbGateway instance).

    Examples:
        >>> # Public gateway conflict with the same URL and basic auth
        >>> existing_gw = DbGateway(url="https://api.example.com", id="abc-123", enabled=True, visibility="public", team_id=None, name="API Gateway", owner_email="alice@example.com")
        >>> error = GatewayDuplicateConflictError(
        ...     duplicate_gateway=existing_gw
        ... )
        >>> str(error)
        'The Server already exists in Public scope (Name: API Gateway, Status: active)'

        >>> # Team gateway conflict with the same URL and OAuth credentials
        >>> team_gw = DbGateway(url="https://api.example.com", id="def-456", enabled=False, visibility="team", team_id="engineering-team", name="API Gateway", owner_email="bob@example.com")
        >>> error = GatewayDuplicateConflictError(
        ...     duplicate_gateway=team_gw
        ... )
        >>> str(error)
        'The Server already exists in your Team (Name: API Gateway, Status: inactive). You may want to re-enable the existing gateway instead.'

        >>> # Private gateway conflict (same user cannot have two gateways with the same URL)
        >>> private_gw = DbGateway(url="https://api.example.com", id="ghi-789", enabled=True, visibility="private", team_id="none", name="API Gateway", owner_email="charlie@example.com")
        >>> error = GatewayDuplicateConflictError(
        ...     duplicate_gateway=private_gw
        ... )
        >>> str(error)
        'The Server already exists in "private" scope (Name: API Gateway, Status: active)'
    """

    def __init__(
        self,
        duplicate_gateway: "DbGateway",
    ):
        """Copy the conflicting gateway's details and compose the error message.

        Args:
            duplicate_gateway: The existing conflicting gateway (DbGateway instance)
        """
        # Mirror the fields of the existing gateway onto the exception so
        # handlers can inspect them without touching the ORM object.
        self.duplicate_gateway = duplicate_gateway
        self.url = duplicate_gateway.url
        self.gateway_id = duplicate_gateway.id
        self.enabled = duplicate_gateway.enabled
        self.visibility = duplicate_gateway.visibility
        self.team_id = duplicate_gateway.team_id
        self.name = duplicate_gateway.name

        # Human-readable scope; a "team" gateway without a team_id falls
        # through to the generic quoted form.
        if self.visibility == "public":
            where = "Public scope"
        elif self.visibility == "team" and self.team_id:
            where = "your Team"
        else:
            where = f'"{self.visibility}" scope'

        state = "active" if self.enabled else "inactive"

        parts = [f"The Server already exists in {where} (Name: {self.name}, Status: {state})"]
        if not self.enabled:
            # Hint that re-enabling is probably what the caller wants.
            parts.append(". You may want to re-enable the existing gateway instead.")

        super().__init__("".join(parts))

382 

383 

class GatewayConnectionError(GatewayError):
    """Signals that connecting to a gateway failed.

    Examples:
        >>> error = GatewayConnectionError("Connection failed")
        >>> str(error)
        'Connection failed'
        >>> isinstance(error, GatewayError)
        True
    """

394 

395 

class OAuthToolValidationError(GatewayConnectionError):
    """Signals a tool validation failure that occurred during an OAuth-driven fetch."""

398 

399 

def _validate_gateway_team_assignment(db: Session, user_email: Optional[str], target_team_id: Optional[str]) -> None:
    """Check that a gateway may be assigned to the given team.

    The checks run in order: a team id must be supplied, the team must
    exist, and - when a user email is given - that user must hold an active
    ``owner`` membership in the team.

    Args:
        db: Database session used for membership checks.
        user_email: Requesting user email. When omitted, ownership checks are skipped.
        target_team_id: Team identifier to validate.

    Raises:
        ValueError: If team does not exist or caller lacks ownership.
    """
    if not target_team_id:
        raise ValueError("Cannot set visibility to 'team' without a team_id")

    team_lookup = db.query(DbEmailTeam).filter(DbEmailTeam.id == target_team_id)
    if not team_lookup.first():
        raise ValueError(f"Team {target_team_id} not found")

    if user_email:
        # Only an active owner of the target team is allowed to proceed.
        owner_conditions = (
            DbEmailTeamMember.team_id == target_team_id,
            DbEmailTeamMember.user_email == user_email,
            DbEmailTeamMember.is_active,
            DbEmailTeamMember.role == "owner",
        )
        owner_row = db.query(DbEmailTeamMember).filter(*owner_conditions).first()
        if not owner_row:
            raise ValueError("User membership in team not sufficient for this update.")

428 

429 

430class GatewayService(BaseService): # pylint: disable=too-many-instance-attributes 

431 """Service for managing federated gateways. 

432 

433 Handles: 

434 - Gateway registration and health checks 

435 - Capability negotiation 

436 - Federation events 

437 - Active/inactive status management 

438 """ 

439 

440 _visibility_model_cls = DbGateway 

441 

def __init__(self) -> None:
    """Create the service's collaborators and in-memory bookkeeping.

    No network or database I/O is started here; background tasks and the
    Redis client are set up later in ``initialize()``.

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> from mcpgateway.services.event_service import EventService
        >>> from mcpgateway.utils.retry_manager import ResilientHttpClient
        >>> from mcpgateway.services.tool_service import ToolService
        >>> service = GatewayService()
        >>> isinstance(service._event_service, EventService)
        True
        >>> isinstance(service._http_client, ResilientHttpClient)
        True
        >>> service._health_check_interval == GW_HEALTH_CHECK_INTERVAL
        True
        >>> service._health_check_task is None
        True
        >>> isinstance(service._active_gateways, set)
        True
        >>> len(service._active_gateways)
        0
        >>> service._stream_response is None
        True
        >>> isinstance(service._pending_responses, dict)
        True
        >>> len(service._pending_responses)
        0
        >>> isinstance(service.tool_service, ToolService)
        True
        >>> isinstance(service._gateway_failure_counts, dict)
        True
        >>> len(service._gateway_failure_counts)
        0
        >>> hasattr(service, 'redis_url')
        True
        >>>
        >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
        >>> import asyncio
        >>> asyncio.run(service._http_client.aclose())
    """
    http_client_args = {
        "timeout": settings.federation_timeout,
        "verify": not settings.skip_ssl_verify,
    }
    self._http_client = ResilientHttpClient(client_args=http_client_args)
    self._health_check_interval = GW_HEALTH_CHECK_INTERVAL
    self._health_check_task: Optional[asyncio.Task] = None
    self._active_gateways: Set[str] = set()  # Track active gateway URLs
    self._stream_response = None
    self._pending_responses = {}
    # Hot/cold server classification service (created in initialize())
    self._classification_service: Optional[Any] = None

    # Reuse the module-level singletons of the sibling services so events
    # propagate through their already-initialized EventService/Redis
    # clients. If the singleton cannot be imported yet (e.g. circular
    # import during Gunicorn --preload), build a local instance instead.
    try:
        # First-Party
        from mcpgateway.services.prompt_service import prompt_service as prompts
    except ImportError:
        # First-Party
        from mcpgateway.services.prompt_service import PromptService

        prompts = PromptService()
    try:
        # First-Party
        from mcpgateway.services.resource_service import resource_service as resources
    except ImportError:
        # First-Party
        from mcpgateway.services.resource_service import ResourceService

        resources = ResourceService()
    try:
        # First-Party
        from mcpgateway.services.tool_service import tool_service as tools
    except ImportError:
        # First-Party
        from mcpgateway.services.tool_service import ToolService

        tools = ToolService()

    self.tool_service = tools
    self.prompt_service = prompts
    self.resource_service = resources
    self._gateway_failure_counts: dict[str, int] = {}

    oauth_timeout = int(os.getenv("OAUTH_REQUEST_TIMEOUT", "30"))
    oauth_retries = int(os.getenv("OAUTH_MAX_RETRIES", "3"))
    self.oauth_manager = OAuthManager(request_timeout=oauth_timeout, max_retries=oauth_retries)
    self._event_service = EventService(channel_name="mcpgateway:gateway_events")

    # Per-gateway refresh locks to prevent concurrent refreshes for the same gateway
    self._refresh_locks: Dict[str, asyncio.Lock] = {}

    # Redis is only considered for leader election when it is the configured cache backend.
    self.redis_url = settings.redis_url if settings.cache_type == "redis" else None

    # Optional Redis client holder (populated in initialize())
    self._redis_client: Optional[Any] = None

    # Leader-election state exists only when Redis is configured and importable;
    # other code must therefore use getattr() for the two task handles.
    if self.redis_url and REDIS_AVAILABLE:
        self._instance_id = str(uuid.uuid4())  # Unique ID for this process
        self._leader_key = settings.redis_leader_key
        self._leader_ttl = settings.redis_leader_ttl
        self._leader_heartbeat_interval = settings.redis_leader_heartbeat_interval
        self._leader_heartbeat_task: Optional[asyncio.Task] = None
        self._follower_election_task: Optional[asyncio.Task] = None

        # Log instance mapping for debugging
        logger.info(f"Instance started: instance_id={self._instance_id}, port={settings.port}, pid={os.getpid()}")

    # File lock is always prepared as a fallback unless caching is disabled entirely.
    if settings.cache_type != "none":
        lock_name = os.path.normpath(settings.filelock_name)
        if os.path.isabs(lock_name):
            # Strip the drive/root so the name can be re-anchored under the temp directory.
            drive_root = os.path.splitdrive(lock_name)[0] + os.sep
            lock_name = os.path.relpath(lock_name, start=drive_root)
        anchored = os.path.join(tempfile.gettempdir(), lock_name)
        self._lock_path = anchored.replace("\\", "/")
        self._file_lock = FileLock(self._lock_path)

558 

@staticmethod
def normalize_url(url: str) -> str:
    """
    Normalize a URL by ensuring it's properly formatted.

    Special handling for localhost to prevent duplicates:
    - Converts 127.0.0.1 to localhost for consistency
    - Preserves all other domain names as-is for CDN/load balancer support

    When the host is rewritten, everything else in the URL is kept: any
    ``user:password@`` userinfo in the network location and any explicit
    port (including ``0``) are carried over unchanged.

    Args:
        url (str): The URL to normalize.

    Returns:
        str: The normalized URL.

    Examples:
        >>> GatewayService.normalize_url('http://localhost:8080/path')
        'http://localhost:8080/path'
        >>> GatewayService.normalize_url('http://127.0.0.1:8080/path')
        'http://localhost:8080/path'
        >>> GatewayService.normalize_url('http://user:pw@127.0.0.1:8080/path')
        'http://user:pw@localhost:8080/path'
        >>> GatewayService.normalize_url('https://example.com/api')
        'https://example.com/api'
    """
    parsed = urlparse(url)

    # Only the loopback literal is rewritten; every other URL is returned untouched.
    if parsed.hostname != "127.0.0.1":
        return url

    # Keep the userinfo part (everything before the last "@") so credentials
    # embedded in the URL are not silently discarded by the rewrite.
    userinfo, at_sign, _ = parsed.netloc.rpartition("@")
    netloc = f"{userinfo}@localhost" if at_sign else "localhost"

    # Compare against None (not truthiness) so an explicit port 0 survives.
    if parsed.port is not None:
        netloc += f":{parsed.port}"

    return str(urlunparse(parsed._replace(netloc=netloc)))

596 

597 @staticmethod 

598 async def _encrypt_client_key(client_key: Optional[str]) -> Optional[str]: 

599 """Encrypt a client private key for storage. 

600 

601 Args: 

602 client_key: Plaintext client private key or None. 

603 

604 Returns: 

605 Encrypted client key or None if input is None/empty. 

606 """ 

607 if not client_key: 

608 return None 

609 encryption = get_encryption_service(settings.auth_encryption_secret) 

610 if encryption.is_encrypted(client_key): 

611 return client_key 

612 return await encryption.encrypt_secret_async(client_key) 

613 

def create_ssl_context(self, ca_certificate: str) -> ssl.SSLContext:
    """Build (or reuse) an SSL context that trusts the given CA certificate.

    Delegates to ``get_cached_ssl_context`` so repeated calls with the same
    certificate do not rebuild the context.

    Args:
        ca_certificate: CA certificate in PEM format

    Returns:
        ssl.SSLContext: Configured SSL context
    """
    context = get_cached_ssl_context(ca_certificate)
    return context

626 

async def initialize(self) -> None:
    """Start the service and its background tasks.

    With a working Redis client, this instance tries to become the leader:
    the winner runs health checks plus a leadership heartbeat, everyone
    else runs a follower election loop. Without Redis, the health-check
    task is always started (file-lock mode).

    Raises:
        ConnectionError: When redis ping fails
    """
    logger.info("Initializing gateway service")

    # Initialize event service with shared Redis client
    await self._event_service.initialize()

    # NOTE: No long-lived DB session is created here on purpose. Health
    # checks open fresh_db_session() only when DB access is actually needed,
    # so connections are not held during HTTP calls to MCP servers.

    admin_email = settings.platform_admin_email

    # Get shared Redis client from factory
    if self.redis_url and REDIS_AVAILABLE:
        self._redis_client = await get_redis_client()

    redis_conn = self._redis_client
    if not redis_conn:
        # No Redis available - always create the health check task in filelock mode
        self._health_check_task = asyncio.create_task(self._run_health_checks(admin_email))
    else:
        # The factory already pinged, but verify the connection once more.
        try:
            await redis_conn.ping()
        except Exception as e:
            raise ConnectionError(f"Redis ping failed: {e}") from e

        # SET ... NX with a TTL: only one instance can claim the leader key.
        won_election = await redis_conn.set(self._leader_key, self._instance_id, ex=self._leader_ttl, nx=True)
        if not won_election:
            # Did not acquire leadership - start follower election loop
            logger.info("Did not acquire leadership. Starting follower election loop.")
            self._follower_election_task = asyncio.create_task(self._run_follower_election(admin_email))
        else:
            logger.info("Acquired Redis leadership. Starting health check and heartbeat tasks.")
            self._health_check_task = asyncio.create_task(self._run_health_checks(admin_email))
            self._leader_heartbeat_task = asyncio.create_task(self._run_leader_heartbeat())

    # Initialize hot/cold classification service (if enabled)
    if settings.hot_cold_classification_enabled:
        # First-Party
        from mcpgateway.services.server_classification_service import ServerClassificationService

        classifier = ServerClassificationService(redis_client=self._redis_client)
        self._classification_service = classifier
        await classifier.start()
        logger.info("Hot/cold classification service initialized")

676 

async def shutdown(self) -> None:
    """Shutdown the service.

    Teardown order:

    1. Cancel the follower election task (first, so it cannot spawn new
       health-check / heartbeat tasks during teardown).
    2. Cancel the health-check task.
    3. Stop the classification service, if one was started.
    4. Cancel the leader heartbeat task.
    5. Release the Redis leader key, but only if this instance owns it.
    6. Close the HTTP client and the event service, clear active gateways.

    A background task that had already failed with an exception is logged
    and skipped rather than allowed to abort the rest of the shutdown, so
    leadership is still released and clients are still closed.

    Examples:
        >>> service = GatewayService()
        >>> # Mock internal components
        >>> from unittest.mock import AsyncMock
        >>> service._event_service = AsyncMock()
        >>> service._active_gateways = {'test_gw'}
        >>> import asyncio
        >>> asyncio.run(service.shutdown())
        >>> # Verify event service shutdown was called
        >>> service._event_service.shutdown.assert_awaited_once()
        >>> len(service._active_gateways)
        0
    """

    async def _stop_task(task: Optional[asyncio.Task], label: str) -> None:
        """Cancel a background task and wait for it to finish.

        Args:
            task: Task handle, or None/falsy when no task was started.
            label: Human-readable task name used in the warning log.
        """
        if not task:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception as exc:  # pylint: disable=broad-exception-caught
            # Awaiting a task that already crashed re-raises its exception;
            # log it so the remaining teardown steps still run.
            logger.warning(f"{label} task raised during shutdown: {exc}")

    # Cancel follower election FIRST to prevent it from spawning new
    # health-check / heartbeat tasks while we are tearing down.
    # getattr: this attribute only exists when Redis leader election is configured.
    await _stop_task(getattr(self, "_follower_election_task", None), "Follower election")

    # Now safe to cancel health-check and heartbeat (handles may have been
    # overwritten by follower election just before cancellation - that is fine,
    # we always cancel whichever task the attribute currently points to).
    await _stop_task(self._health_check_task, "Health check")

    # Stop classification service
    if self._classification_service:
        await self._classification_service.stop()
        logger.info("Classification service stopped")

    # Cancel leader heartbeat task if running
    await _stop_task(getattr(self, "_leader_heartbeat_task", None), "Leader heartbeat")

    # Release Redis leadership atomically if we hold it
    if self._redis_client:
        try:
            # Lua script for atomic check-and-delete (only delete if we own the key)
            release_script = """
            if redis.call("get", KEYS[1]) == ARGV[1] then
                return redis.call("del", KEYS[1])
            else
                return 0
            end
            """
            result = await self._redis_client.eval(release_script, 1, self._leader_key, self._instance_id)
            if result:
                logger.info("Released Redis leadership on shutdown")
        except Exception as e:
            # Best effort: the key expires via its TTL if it cannot be released here.
            logger.warning(f"Failed to release Redis leader key on shutdown: {e}")

    await self._http_client.aclose()
    await self._event_service.shutdown()
    self._active_gateways.clear()
    logger.info("Gateway service shutdown complete")

746 

def _check_gateway_uniqueness(
    self,
    db: Session,
    url: str,
    auth_value: Optional[Dict[str, str]],
    oauth_config: Optional[Dict[str, Any]],
    team_id: Optional[str],
    owner_email: str,
    visibility: str,
    gateway_id: Optional[str] = None,
) -> Optional[DbGateway]:
    """
    Look for an existing gateway with the same URL and credentials in the same scope.

    Scope is derived from ``visibility``: all public gateways, the gateways
    of ``team_id``, or the private gateways of ``owner_email``. Any other
    combination (including "team" without a ``team_id``) is not checked.

    A candidate counts as a duplicate when one of these holds:
    - both sides have an OAuth config and the key OAuth fields match;
    - both sides have an ``auth_value`` and the decoded values are equal;
    - neither side has any credentials at all.

    Args:
        db: Database session
        url: Gateway URL (normalized)
        auth_value: Decoded auth_value dict (not encrypted)
        oauth_config: OAuth configuration dict
        team_id: Team ID for team-scoped gateways
        owner_email: Email of the gateway owner
        visibility: Gateway visibility (public/team/private)
        gateway_id: Optional gateway ID to exclude from check (for updates)

    Returns:
        DbGateway if duplicate found, None otherwise
    """
    # Every scope shares the URL + visibility match; the branches add the scope key.
    criteria = [DbGateway.url == url, DbGateway.visibility == visibility]
    if visibility == "public":
        pass
    elif visibility == "team" and team_id:
        criteria.append(DbGateway.team_id == team_id)
    elif visibility == "private":
        # Scoped to the same user's private gateways
        criteria.append(DbGateway.owner_email == owner_email)
    else:
        return None

    candidates = db.query(DbGateway).filter(*criteria)

    # Exclude current gateway if updating
    if gateway_id:
        candidates = candidates.filter(DbGateway.id != gateway_id)

    # Only stable OAuth fields are compared (dynamic fields like tokens are excluded).
    oauth_fields = ("grant_type", "client_id", "authorization_url", "token_url", "scope")

    for candidate in candidates.all():
        stored_oauth = candidate.oauth_config
        stored_auth = candidate.auth_value

        # Case 1: Both have OAuth config
        if oauth_config and stored_oauth:
            if all(stored_oauth.get(field) == oauth_config.get(field) for field in oauth_fields):
                return candidate  # Duplicate OAuth config found
            continue

        # Case 2: Both have auth_value (stored value may need decoding first)
        if auth_value and stored_auth:
            try:
                if isinstance(stored_auth, str):
                    stored_decoded = decode_auth(stored_auth)
                elif isinstance(stored_auth, dict):
                    stored_decoded = stored_auth
                else:
                    continue

                if auth_value == stored_decoded:
                    return candidate  # Duplicate credentials found
            except Exception as e:
                logger.warning(f"Failed to decode auth_value for comparison: {e}")
            continue

        # Case 3: Both have no auth (URL only, not allowed)
        if not (auth_value or oauth_config or stored_auth or stored_oauth):
            return candidate  # Duplicate URL without credentials

    return None  # No duplicate found

830 

async def register_gateway(
    self,
    db: Session,
    gateway: GatewayCreate,
    created_by: Optional[str] = None,
    created_from_ip: Optional[str] = None,
    created_via: Optional[str] = None,
    created_user_agent: Optional[str] = None,
    team_id: Optional[str] = None,
    owner_email: Optional[str] = None,
    visibility: Optional[str] = None,
    initialize_timeout: Optional[float] = None,
) -> GatewayRead:
    """Register a new gateway and persist what it exposes.

    The method performs these steps, in order:

    1. Normalize ``visibility``; any value other than ``private``, ``team`` or
       ``public`` (including ``None``) becomes ``public``.
    2. For public gateways, and for team gateways when ``team_id`` is given,
       reject the request if a gateway with the same slug already exists in
       that scope (looked up through ``get_for_update``).
    3. Unless ``gateway.one_time_auth`` is set, reject the request if
       ``_check_gateway_uniqueness`` reports a duplicate.
    4. Build the authentication material (query parameter, header list, or an
       encoded ``auth_value`` string) and call ``_initialize_gateway``,
       bounded by ``initialize_timeout`` when one is supplied.
    5. Build ``DbTool``, ``DbResource`` and ``DbPrompt`` rows. Existing
       resources/prompts whose ``gateway_id`` is ``None`` or refers to a
       gateway that no longer exists are reused when their
       ``(team_id, owner_email, uri)`` / ``(team_id, owner_email, name)`` key
       matches.
    6. Add the ``DbGateway`` row, commit, refresh, record the URL in
       ``self._active_gateways``, notify subscribers, invalidate caches and
       write audit / structured log entries.

    On any failure the session is rolled back and the first exception of the
    matched exception group is re-raised.

    Args:
        db: Database session
        gateway: Gateway creation schema
        created_by: Username who created this gateway
        created_from_ip: IP address of creator
        created_via: Creation method (ui, api, federation); stored as "api" when falsy
        created_user_agent: User agent of creation request
        team_id (Optional[str]): Team ID to assign the gateway to.
        owner_email (Optional[str]): Email of the user who owns this gateway.
        visibility (Optional[str]): Gateway visibility level (private, team, public).
        initialize_timeout (Optional[float]): Timeout in seconds for gateway initialization.
            When None, initialization is awaited without a timeout.

    Returns:
        Created gateway information

    Raises:
        GatewayNameConflictError: If a public gateway, or a team gateway in the same team,
            already uses the same slug
        GatewayDuplicateConflictError: If the uniqueness check finds a duplicate gateway
        GatewayConnectionError: If there was an error connecting to the gateway, including
            an initialization timeout
        GatewayError: If direct_proxy mode is requested while it is disabled in settings
        ValueError: If required values are missing
        RuntimeError: If there is an error during processing that is not covered by other exceptions
        IntegrityError: If there is a database integrity error
        BaseException: If an unexpected error occurs

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> from unittest.mock import MagicMock
        >>> service = GatewayService()
        >>> db = MagicMock()
        >>> gateway = MagicMock()
        >>> db.execute.return_value.scalar_one_or_none.return_value = None
        >>> db.add = MagicMock()
        >>> db.commit = MagicMock()
        >>> db.refresh = MagicMock()
        >>> service._notify_gateway_added = MagicMock()
        >>> import asyncio
        >>> try:
        ...     asyncio.run(service.register_gateway(db, gateway))
        ... except Exception:
        ...     pass
        >>>
        >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
        >>> asyncio.run(service._http_client.aclose())
    """
    # Unknown or missing visibility values fall back to "public".
    visibility = "public" if visibility not in ("private", "team", "public") else visibility
    try:
        # Legacy global name-conflict check, currently disabled and kept for reference:
        # # Check for name conflicts (both active and inactive)
        # existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway.name)).scalar_one_or_none()

        # if existing_gateway:
        #     raise GatewayNameConflictError(
        #         gateway.name,
        #         enabled=existing_gateway.enabled,
        #         gateway_id=existing_gateway.id,
        #     )
        # Check for existing gateway with the same slug and visibility.
        # NOTE(review): no slug pre-check is made here for private gateways, nor for
        # team visibility without a team_id - confirm that is intentional.
        slug_name = slugify(gateway.name)
        if visibility.lower() == "public":
            # Check for existing public gateway with the same slug (row-locked)
            existing_gateway = get_for_update(
                db,
                DbGateway,
                where=and_(DbGateway.slug == slug_name, DbGateway.visibility == "public"),
            )
            if existing_gateway:
                raise GatewayNameConflictError(existing_gateway.slug, enabled=existing_gateway.enabled, gateway_id=existing_gateway.id, visibility=existing_gateway.visibility)
        elif visibility.lower() == "team" and team_id:
            # Check for existing team gateway with the same slug (row-locked)
            existing_gateway = get_for_update(
                db,
                DbGateway,
                where=and_(DbGateway.slug == slug_name, DbGateway.visibility == "team", DbGateway.team_id == team_id),
            )
            if existing_gateway:
                raise GatewayNameConflictError(existing_gateway.slug, enabled=existing_gateway.enabled, gateway_id=existing_gateway.id, visibility=existing_gateway.visibility)

        # Normalize the gateway URL
        normalized_url = self.normalize_url(str(gateway.url))

        # Decode the supplied auth_value for the duplicate check only. A decode
        # failure is logged and treated as "no auth value" for that check.
        decoded_auth_value = None
        if gateway.auth_value:
            if isinstance(gateway.auth_value, str):
                try:
                    decoded_auth_value = decode_auth(gateway.auth_value)
                except Exception as e:
                    logger.warning(f"Failed to decode provided auth_value: {e}")
                    decoded_auth_value = None
            elif isinstance(gateway.auth_value, dict):
                decoded_auth_value = gateway.auth_value

        # Check for duplicate gateway (skipped entirely for one-time auth registrations)
        if not gateway.one_time_auth:
            duplicate_gateway = self._check_gateway_uniqueness(
                db=db, url=normalized_url, auth_value=decoded_auth_value, oauth_config=gateway.oauth_config, team_id=team_id, owner_email=owner_email, visibility=visibility
            )

            if duplicate_gateway:
                raise GatewayDuplicateConflictError(duplicate_gateway=duplicate_gateway)

        # Disabled guard, kept for reference:
        # Prevent URL-only gateways (no auth at all)
        # if not decoded_auth_value and not gateway.oauth_config:
        #     raise ValueError(
        #         f"Gateway with URL '{normalized_url}' must have either auth_value or oauth_config. "
        #         "URL-only gateways are not allowed."
        #     )

        auth_type = getattr(gateway, "auth_type", None)
        # Support multiple custom headers
        auth_value = getattr(gateway, "auth_value", {})
        authentication_headers: Optional[Dict[str, str]] = None

        # Handle query_param auth - encrypt and prepare for storage
        auth_query_params_encrypted: Optional[Dict[str, str]] = None
        auth_query_params_decrypted: Optional[Dict[str, str]] = None
        init_url = normalized_url  # URL to use for initialization

        if auth_type == "query_param":
            # Extract and encrypt query param auth
            param_key = getattr(gateway, "auth_query_param_key", None)
            param_value = getattr(gateway, "auth_query_param_value", None)
            if param_key and param_value:
                # Get the actual secret value
                if hasattr(param_value, "get_secret_value"):
                    raw_value = param_value.get_secret_value()
                else:
                    raw_value = str(param_value)
                # Encrypt for storage
                encrypted_value = encode_auth({param_key: raw_value})
                auth_query_params_encrypted = {param_key: encrypted_value}
                auth_query_params_decrypted = {param_key: raw_value}
                # Append query params to URL for initialization
                init_url = apply_query_param_auth(normalized_url, auth_query_params_decrypted)
            # Query param auth doesn't use auth_value
            # NOTE(review): nesting reconstructed from a de-indented dump; this reset is
            # assumed to apply to the whole query_param branch - confirm against the repository.
            auth_value = None
            authentication_headers = None

        elif hasattr(gateway, "auth_headers") and gateway.auth_headers:
            # Convert list of {key, value} to dict (entries without a key are dropped)
            header_dict = {h["key"]: h["value"] for h in gateway.auth_headers if h.get("key")}
            auth_value = header_dict  # store plain dict, consistent with update path and DB column type
            authentication_headers = {str(k): str(v) for k, v in header_dict.items()}

        elif isinstance(auth_value, str) and auth_value:
            # Decode persisted auth for initialization.
            # Unlike the decode above, a failure here is not caught locally and
            # propagates to the except* handlers at the end of this method.
            decoded = decode_auth(auth_value)
            authentication_headers = {str(k): str(v) for k, v in decoded.items()}
        else:
            authentication_headers = None

        oauth_config = await protect_oauth_config_for_storage(getattr(gateway, "oauth_config", None))
        ca_certificate = getattr(gateway, "ca_certificate", None)
        init_client_cert = getattr(gateway, "client_cert", None)
        init_client_key = getattr(gateway, "client_key", None)

        # Check if gateway is in direct_proxy mode
        gateway_mode = getattr(gateway, "gateway_mode", "cache")

        if gateway_mode == "direct_proxy" and not settings.mcpgateway_direct_proxy_enabled:
            raise GatewayError("direct_proxy gateway mode is disabled. Set MCPGATEWAY_DIRECT_PROXY_ENABLED=true to enable.")

        # Both branches make the same _initialize_gateway call; the first one
        # additionally bounds it with asyncio.wait_for.
        if initialize_timeout is not None:
            try:
                capabilities, tools, resources, prompts = await asyncio.wait_for(
                    self._initialize_gateway(
                        init_url,  # URL with query params if applicable
                        authentication_headers,
                        gateway.transport,
                        auth_type,
                        oauth_config,
                        ca_certificate,
                        auth_query_params=auth_query_params_decrypted,
                        client_cert=init_client_cert,
                        client_key=init_client_key,
                    ),
                    timeout=initialize_timeout,
                )
            except asyncio.TimeoutError as exc:
                # Sanitize the URL before it goes into the error message, since
                # init_url may carry the query-param secret.
                sanitized = sanitize_url_for_logging(init_url, auth_query_params_decrypted)
                raise GatewayConnectionError(f"Gateway initialization timed out after {initialize_timeout}s for {sanitized}") from exc
        else:
            capabilities, tools, resources, prompts = await self._initialize_gateway(
                init_url,  # URL with query params if applicable
                authentication_headers,
                gateway.transport,
                auth_type,
                oauth_config,
                ca_certificate,
                auth_query_params=auth_query_params_decrypted,
                client_cert=init_client_cert,
                client_key=init_client_key,
            )

        if gateway.one_time_auth:
            # For one-time auth, clear auth_type and auth_value after initialization
            # NOTE(review): auth_query_params_encrypted is not cleared here and is still
            # passed to DbGateway below - confirm whether one-time auth can be combined
            # with query_param auth.
            auth_type = "one_time_auth"
            auth_value = None
            oauth_config = None

        # DbTool.auth_value is Mapped[Optional[str]] (Text), so encode the dict before
        # storing it there. DbGateway.auth_value is Mapped[Optional[Dict]] (JSON) and
        # receives the plain dict directly (see assignment above).
        tool_auth_value = encode_auth(auth_value) if isinstance(auth_value, dict) else auth_value

        # From here on `tools` holds DbTool rows instead of the discovered tool objects.
        tools = [
            DbTool(
                original_name=tool.name,
                custom_name=tool.name,
                custom_name_slug=slugify(tool.name),
                display_name=generate_display_name(tool.name),
                title=_resolve_tool_title(tool),
                url=normalized_url,
                original_description=tool.description,
                description=tool.description,
                integration_type="MCP",  # Gateway-discovered tools are MCP type
                request_type=tool.request_type,
                headers=tool.headers,
                input_schema=tool.input_schema,
                output_schema=tool.output_schema,
                annotations=tool.annotations,
                jsonpath_filter=tool.jsonpath_filter,
                auth_type=auth_type,
                auth_value=tool_auth_value,
                # Federation metadata
                created_by=created_by or "system",
                created_from_ip=created_from_ip,
                created_via="federation",  # These are federated tools
                created_user_agent=created_user_agent,
                federation_source=gateway.name,
                version=1,
                # Inherit team assignment from gateway
                team_id=team_id,
                owner_email=owner_email,
                visibility=visibility,
            )
            for tool in tools
        ]

        # Create resource DB models with upsert logic for ORPHANED resources only
        # Query for existing ORPHANED resources (gateway_id IS NULL or points to non-existent gateway)
        # with same (team_id, owner_email, uri) to handle resources left behind from incomplete
        # gateway deletions (e.g., issue #2341 crash scenarios).
        # We only update orphaned resources - resources belonging to active gateways are not touched.
        resource_uris = [r.uri for r in resources]
        effective_owner = owner_email or created_by

        # Build lookup map: (team_id, owner_email, uri) -> orphaned DbResource
        # We query all resources matching our URIs, then filter to orphaned ones in Python
        # to handle per-resource team/owner overrides correctly
        orphaned_resources_map: Dict[tuple, DbResource] = {}
        if resource_uris:
            try:
                # Get valid gateway IDs to identify orphaned resources
                valid_gateway_ids = set(gw_id for (gw_id,) in db.execute(select(DbGateway.id)).all())
                candidate_resources = db.execute(select(DbResource).where(DbResource.uri.in_(resource_uris))).scalars().all()
                for res in candidate_resources:
                    # Only consider orphaned resources (no gateway or gateway doesn't exist)
                    is_orphaned = res.gateway_id is None or res.gateway_id not in valid_gateway_ids
                    if is_orphaned:
                        key = (res.team_id, res.owner_email, res.uri)
                        orphaned_resources_map[key] = res
                if orphaned_resources_map:
                    logger.info(f"Found {len(orphaned_resources_map)} orphaned resources to reassign for gateway {SecurityValidator.sanitize_log_message(gateway.name)}")
            except Exception as e:
                # If orphan detection fails (e.g., in mocked tests), skip upsert and create new resources
                # This is conservative - we won't accidentally reassign resources from active gateways
                logger.debug(f"Orphan resource detection skipped: {e}")

        db_resources = []
        for r in resources:
            # MIME type comes from the URI extension when recognised; otherwise it is
            # chosen from the Python type of the content.
            mime_type = mimetypes.guess_type(r.uri)[0] or ("text/plain" if isinstance(r.content, str) else "application/octet-stream")
            # Per-resource team/owner/visibility override the gateway-level values when present.
            r_team_id = getattr(r, "team_id", None) or team_id
            r_owner_email = getattr(r, "owner_email", None) or effective_owner
            r_visibility = getattr(r, "visibility", None) or visibility

            # Check if there's an orphaned resource with matching unique key
            lookup_key = (r_team_id, r_owner_email, r.uri)
            if lookup_key in orphaned_resources_map:
                # Update orphaned resource - reassign to new gateway
                existing = orphaned_resources_map[lookup_key]
                existing.name = r.name
                existing.description = r.description
                existing.mime_type = mime_type
                existing.uri_template = r.uri_template or None
                # The condition below is logically equivalent to isinstance(r.content, str):
                # str content fills text_content (and its encoded form fills binary_content),
                # bytes content fills binary_content only, anything else leaves both None.
                existing.text_content = r.content if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) else None
                existing.binary_content = (
                    r.content.encode() if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) else r.content if isinstance(r.content, bytes) else None
                )
                existing.size = len(r.content) if r.content else 0
                existing.title = getattr(r, "title", None)
                existing.tags = getattr(r, "tags", []) or []
                existing.federation_source = gateway.name
                existing.modified_by = created_by
                existing.modified_from_ip = created_from_ip
                existing.modified_via = "federation"
                existing.modified_user_agent = created_user_agent
                existing.updated_at = datetime.now(timezone.utc)
                existing.visibility = r_visibility
                # Note: gateway_id will be set when gateway is created (relationship)
                db_resources.append(existing)
            else:
                # Create new resource (same content rules as the orphan-update branch above)
                db_resources.append(
                    DbResource(
                        uri=r.uri,
                        name=r.name,
                        title=getattr(r, "title", None),
                        description=r.description,
                        mime_type=mime_type,
                        uri_template=r.uri_template or None,
                        text_content=r.content if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) else None,
                        binary_content=(
                            r.content.encode()
                            if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str)
                            else r.content if isinstance(r.content, bytes) else None
                        ),
                        size=len(r.content) if r.content else 0,
                        tags=getattr(r, "tags", []) or [],
                        created_by=created_by or "system",
                        created_from_ip=created_from_ip,
                        created_via="federation",
                        created_user_agent=created_user_agent,
                        import_batch_id=None,
                        federation_source=gateway.name,
                        version=1,
                        team_id=r_team_id,
                        owner_email=r_owner_email,
                        visibility=r_visibility,
                    )
                )

        # Create prompt DB models with upsert logic for ORPHANED prompts only
        # Query for existing ORPHANED prompts (gateway_id IS NULL or points to non-existent gateway)
        # with same (team_id, owner_email, name) to handle prompts left behind from incomplete
        # gateway deletions. We only update orphaned prompts - prompts belonging to active gateways are not touched.
        prompt_names = [p.name for p in prompts]

        # Build lookup map: (team_id, owner_email, name) -> orphaned DbPrompt
        orphaned_prompts_map: Dict[tuple, DbPrompt] = {}
        if prompt_names:
            try:
                # Get valid gateway IDs to identify orphaned prompts
                # (queried again here rather than reusing the set built for resources)
                valid_gateway_ids_for_prompts = set(gw_id for (gw_id,) in db.execute(select(DbGateway.id)).all())
                candidate_prompts = db.execute(select(DbPrompt).where(DbPrompt.name.in_(prompt_names))).scalars().all()
                for pmt in candidate_prompts:
                    # Only consider orphaned prompts (no gateway or gateway doesn't exist)
                    is_orphaned = pmt.gateway_id is None or pmt.gateway_id not in valid_gateway_ids_for_prompts
                    if is_orphaned:
                        key = (pmt.team_id, pmt.owner_email, pmt.name)
                        orphaned_prompts_map[key] = pmt
                if orphaned_prompts_map:
                    logger.info(f"Found {len(orphaned_prompts_map)} orphaned prompts to reassign for gateway {SecurityValidator.sanitize_log_message(gateway.name)}")
            except Exception as e:
                # If orphan detection fails (e.g., in mocked tests), skip upsert and create new prompts
                logger.debug(f"Orphan prompt detection skipped: {e}")

        db_prompts = []
        for prompt in prompts:
            # Prompts inherit team/owner from gateway (no per-prompt overrides)
            p_team_id = team_id
            # effective_owner is already `owner_email or created_by`, so this equals effective_owner.
            p_owner_email = owner_email or effective_owner

            # Check if there's an orphaned prompt with matching unique key
            lookup_key = (p_team_id, p_owner_email, prompt.name)
            if lookup_key in orphaned_prompts_map:
                # Update orphaned prompt - reassign to new gateway
                existing = orphaned_prompts_map[lookup_key]
                existing.original_name = prompt.name
                existing.custom_name = prompt.name
                existing.display_name = prompt.name
                existing.title = getattr(prompt, "title", None)
                existing.description = prompt.description
                existing.template = prompt.template if hasattr(prompt, "template") else ""
                existing.argument_schema = self._build_prompt_argument_schema(prompt)
                existing.federation_source = gateway.name
                existing.modified_by = created_by
                existing.modified_from_ip = created_from_ip
                existing.modified_via = "federation"
                existing.modified_user_agent = created_user_agent
                existing.updated_at = datetime.now(timezone.utc)
                existing.visibility = visibility
                # Note: gateway_id will be set when gateway is created (relationship)
                db_prompts.append(existing)
            else:
                # Create new prompt
                # NOTE(review): new prompts are stored with owner_email, while the orphan
                # lookup key above uses owner_email-or-created_by - confirm the difference
                # is intended when owner_email is None.
                db_prompts.append(
                    DbPrompt(
                        name=prompt.name,
                        original_name=prompt.name,
                        custom_name=prompt.name,
                        display_name=prompt.name,
                        title=getattr(prompt, "title", None),
                        description=prompt.description,
                        template=prompt.template if hasattr(prompt, "template") else "",
                        argument_schema=self._build_prompt_argument_schema(prompt),
                        # Federation metadata
                        created_by=created_by or "system",
                        created_from_ip=created_from_ip,
                        created_via="federation",  # These are federated prompts
                        created_user_agent=created_user_agent,
                        federation_source=gateway.name,
                        version=1,
                        # Inherit team assignment from gateway
                        team_id=team_id,
                        owner_email=owner_email,
                        visibility=visibility,
                    )
                )

        # Create DB model
        db_gateway = DbGateway(
            name=gateway.name,
            slug=slug_name,
            url=normalized_url,
            description=gateway.description,
            tags=gateway.tags or [],
            transport=gateway.transport,
            capabilities=capabilities,
            last_seen=datetime.now(timezone.utc),
            auth_type=auth_type,
            auth_value=auth_value,
            auth_query_params=auth_query_params_encrypted,  # Encrypted query param auth
            oauth_config=oauth_config,
            passthrough_headers=gateway.passthrough_headers,
            tools=tools,
            resources=db_resources,
            prompts=db_prompts,
            # Gateway metadata
            created_by=created_by,
            created_from_ip=created_from_ip,
            created_via=created_via or "api",
            created_user_agent=created_user_agent,
            version=1,
            # Team scoping fields
            team_id=team_id,
            owner_email=owner_email,
            visibility=visibility,
            ca_certificate=gateway.ca_certificate,
            ca_certificate_sig=gateway.ca_certificate_sig,
            signing_algorithm=gateway.signing_algorithm,
            # mTLS client certificate/key (the key goes through _encrypt_client_key before storage)
            client_cert=getattr(gateway, "client_cert", None),
            client_key=await self._encrypt_client_key(getattr(gateway, "client_key", None)),
            # Gateway mode configuration
            gateway_mode=gateway_mode,
        )

        # Add to DB and commit immediately so tools/resources/prompts are visible
        # to other workers before the HTTP response reaches the client.
        # Without this, clients issuing follow-up requests (e.g., manual refresh)
        # can hit a different worker that hasn't seen the uncommitted data yet.
        db.add(db_gateway)
        db.commit()
        db.refresh(db_gateway)

        # Update tracking
        self._active_gateways.add(db_gateway.url)

        # Notify subscribers
        await self._notify_gateway_added(db_gateway)

        # Invalidate caches so other workers see the new gateway and its tools/resources/prompts
        cache = _get_registry_cache()
        await cache.invalidate_gateways()
        await cache.invalidate_tools()
        await cache.invalidate_resources()
        await cache.invalidate_prompts()
        tool_lookup_cache = _get_tool_lookup_cache()
        await tool_lookup_cache.invalidate_gateway(str(db_gateway.id))
        # First-Party
        from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

        await admin_stats_cache.invalidate_tags()

        # Invalidate loopback passthrough cache when a new gateway has passthrough headers (#3640)
        if gateway.passthrough_headers:
            # First-Party
            from mcpgateway.utils.passthrough_headers import invalidate_passthrough_header_caches  # pylint: disable=import-outside-toplevel

            invalidate_passthrough_header_caches()

        logger.info(f"Registered gateway: {SecurityValidator.sanitize_log_message(gateway.name)}")

        # Structured logging: Audit trail for gateway creation
        audit_trail.log_action(
            user_id=created_by or "system",
            action="create_gateway",
            resource_type="gateway",
            resource_id=str(db_gateway.id),
            resource_name=db_gateway.name,
            user_email=owner_email,
            team_id=team_id,
            client_ip=created_from_ip,
            user_agent=created_user_agent,
            new_values={
                "name": db_gateway.name,
                "url": db_gateway.url,
                "visibility": visibility,
                "transport": db_gateway.transport,
                "tools_count": len(tools),
                "resources_count": len(db_resources),
                "prompts_count": len(db_prompts),
            },
            context={
                "created_via": created_via,
            },
            db=db,
        )

        # Structured logging: Log successful gateway creation
        structured_logger.log(
            level="INFO",
            message="Gateway created successfully",
            event_type="gateway_created",
            component="gateway_service",
            user_id=created_by,
            user_email=owner_email,
            team_id=team_id,
            resource_type="gateway",
            resource_id=str(db_gateway.id),
            custom_fields={
                "gateway_name": db_gateway.name,
                "gateway_url": normalized_url,
                "visibility": visibility,
                "transport": db_gateway.transport,
            },
        )

        return self.convert_gateway_to_read(db_gateway)
    # Each except* handler below follows the same pattern: log the matched group,
    # roll back the session, emit a structured log entry (not done by the final
    # catch-all handler), then re-raise the first exception of the matched group.
    except* GatewayConnectionError as ge:  # pragma: no mutate
        if TYPE_CHECKING:
            ge: ExceptionGroup[GatewayConnectionError]
        logger.error(f"GatewayConnectionError in group: {ge.exceptions}")
        db.rollback()

        structured_logger.log(
            level="ERROR",
            message="Gateway creation failed due to connection error",
            event_type="gateway_creation_failed",
            component="gateway_service",
            user_id=created_by,
            user_email=owner_email,
            error=ge.exceptions[0],
            custom_fields={"gateway_name": gateway.name, "gateway_url": str(gateway.url)},
        )
        raise ge.exceptions[0]
    except* GatewayNameConflictError as gnce:  # pragma: no mutate
        if TYPE_CHECKING:
            gnce: ExceptionGroup[GatewayNameConflictError]
        logger.error(f"GatewayNameConflictError in group: {gnce.exceptions}")
        db.rollback()

        structured_logger.log(
            level="WARNING",
            message="Gateway creation failed due to name conflict",
            event_type="gateway_name_conflict",
            component="gateway_service",
            user_id=created_by,
            user_email=owner_email,
            custom_fields={"gateway_name": gateway.name, "visibility": visibility},
        )
        raise gnce.exceptions[0]
    except* GatewayDuplicateConflictError as guce:  # pragma: no mutate
        if TYPE_CHECKING:
            guce: ExceptionGroup[GatewayDuplicateConflictError]
        logger.error(f"GatewayDuplicateConflictError in group: {guce.exceptions}")
        db.rollback()

        structured_logger.log(
            level="WARNING",
            message="Gateway creation failed due to duplicate",
            event_type="gateway_duplicate_conflict",
            component="gateway_service",
            user_id=created_by,
            user_email=owner_email,
            custom_fields={"gateway_name": gateway.name},
        )
        raise guce.exceptions[0]
    except* ValueError as ve:  # pragma: no mutate
        if TYPE_CHECKING:
            ve: ExceptionGroup[ValueError]
        logger.error(f"ValueErrors in group: {ve.exceptions}")
        db.rollback()

        structured_logger.log(
            level="ERROR",
            message="Gateway creation failed due to validation error",
            event_type="gateway_creation_failed",
            component="gateway_service",
            user_id=created_by,
            user_email=owner_email,
            error=ve.exceptions[0],
            custom_fields={"gateway_name": gateway.name},
        )
        raise ve.exceptions[0]
    except* RuntimeError as re:  # pragma: no mutate
        if TYPE_CHECKING:
            re: ExceptionGroup[RuntimeError]
        logger.error(f"RuntimeErrors in group: {re.exceptions}")
        db.rollback()

        structured_logger.log(
            level="ERROR",
            message="Gateway creation failed due to runtime error",
            event_type="gateway_creation_failed",
            component="gateway_service",
            user_id=created_by,
            user_email=owner_email,
            error=re.exceptions[0],
            custom_fields={"gateway_name": gateway.name},
        )
        raise re.exceptions[0]
    except* IntegrityError as ie:  # pragma: no mutate
        if TYPE_CHECKING:
            ie: ExceptionGroup[IntegrityError]
        logger.error(f"IntegrityErrors in group: {ie.exceptions}")
        db.rollback()

        structured_logger.log(
            level="ERROR",
            message="Gateway creation failed due to database integrity error",
            event_type="gateway_creation_failed",
            component="gateway_service",
            user_id=created_by,
            user_email=owner_email,
            error=ie.exceptions[0],
            custom_fields={"gateway_name": gateway.name},
        )
        raise ie.exceptions[0]
    except* BaseException as other:  # catches every other sub-exception  # pragma: no mutate
        if TYPE_CHECKING:
            other: ExceptionGroup[Exception]
        logger.error(f"Other grouped errors: {other.exceptions}")
        db.rollback()
        raise other.exceptions[0]

1479 

1480 async def fetch_tools_after_oauth(self, db: Session, gateway_id: str, app_user_email: str) -> Dict[str, Any]: 

1481 """Fetch tools from MCP server after OAuth completion for Authorization Code flow. 

1482 

1483 Args: 

1484 db: Database session 

1485 gateway_id: ID of the gateway to fetch tools for 

1486 app_user_email: ContextForge user email for token retrieval 

1487 

1488 Returns: 

1489 Dict containing capabilities, tools, resources, and prompts 

1490 

1491 Raises: 

1492 GatewayConnectionError: If connection or OAuth fails 

1493 """ 

1494 try: 

1495 # Get the gateway with eager loading for sync operations to avoid N+1 queries 

1496 gateway = db.execute( 

1497 select(DbGateway) 

1498 .options( 

1499 selectinload(DbGateway.tools), 

1500 selectinload(DbGateway.resources), 

1501 selectinload(DbGateway.prompts), 

1502 joinedload(DbGateway.email_team), 

1503 ) 

1504 .where(DbGateway.id == gateway_id) 

1505 ).scalar_one_or_none() 

1506 

1507 if not gateway: 

1508 raise ValueError(f"Gateway {gateway_id} not found") 

1509 

1510 if not gateway.oauth_config: 

1511 raise ValueError(f"Gateway {gateway_id} has no OAuth configuration") 

1512 

1513 grant_type = gateway.oauth_config.get("grant_type") 

1514 if grant_type != "authorization_code": 

1515 raise ValueError(f"Gateway {gateway_id} is not using Authorization Code flow") 

1516 

1517 # Get OAuth tokens for this gateway 

1518 # First-Party 

1519 from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel 

1520 

1521 token_storage = TokenStorageService(db) 

1522 

1523 # Get user-specific OAuth token 

1524 if not app_user_email: 

1525 raise GatewayConnectionError(f"User authentication required for OAuth gateway {gateway.name}") 

1526 

1527 access_token = await token_storage.get_user_token(gateway.id, app_user_email) 

1528 

1529 if not access_token: 

1530 raise GatewayConnectionError( 

1531 f"No OAuth tokens found for user {app_user_email} on gateway {gateway.name}. Please complete the OAuth authorization flow first at /oauth/authorize/{gateway.id}" 

1532 ) 

1533 

1534 # Debug: Check if token was decrypted 

1535 if access_token.startswith("Z0FBQUFBQm"): # Encrypted tokens start with this 

1536 logger.error("OAuth token decryption may have failed before gateway initialization") 

1537 else: 

1538 logger.info("Using decrypted OAuth token for gateway %s", gateway.name) 

1539 

1540 # Now connect to MCP server with the access token 

1541 authentication = {"Authorization": f"Bearer {access_token}"} 

1542 

1543 # Use the existing connection logic 

1544 # Note: For OAuth servers, skip validation since we already validated via OAuth flow 

1545 if gateway.transport.upper() == "SSE": 

1546 capabilities, tools, resources, prompts = await self._connect_to_sse_server_without_validation(gateway.url, authentication) 

1547 elif gateway.transport.upper() == "STREAMABLEHTTP": 

1548 capabilities, tools, resources, prompts = await self.connect_to_streamablehttp_server(gateway.url, authentication) 

1549 else: 

1550 raise ValueError(f"Unsupported transport type: {gateway.transport}") 

1551 

1552 # Handle tools, resources, and prompts using helper methods 

1553 tools_to_add = self._update_or_create_tools(db, tools, gateway, "oauth") 

1554 resources_to_add = self._update_or_create_resources(db, resources, gateway, "oauth") 

1555 prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, "oauth") 

1556 

1557 # Clean up items that are no longer available from the gateway 

1558 new_tool_names = [tool.name for tool in tools] 

1559 new_resource_uris = [resource.uri for resource in resources] 

1560 new_prompt_names = [prompt.name for prompt in prompts] 

1561 

1562 # Count items before cleanup for logging 

1563 

1564 # Bulk delete tools that are no longer available from the gateway 

1565 # Use chunking to avoid SQLite's 999 parameter limit for IN clauses 

1566 stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names] 

1567 if stale_tool_ids: 

1568 # Delete child records first to avoid FK constraint violations 

1569 for i in range(0, len(stale_tool_ids), 500): 

1570 chunk = stale_tool_ids[i : i + 500] 

1571 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk))) 

1572 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk))) 

1573 db.execute(delete(DbTool).where(DbTool.id.in_(chunk))) 

1574 

1575 # Bulk delete resources that are no longer available from the gateway 

1576 stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris] 

1577 if stale_resource_ids: 

1578 # Delete child records first to avoid FK constraint violations 

1579 for i in range(0, len(stale_resource_ids), 500): 

1580 chunk = stale_resource_ids[i : i + 500] 

1581 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk))) 

1582 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk))) 

1583 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk))) 

1584 db.execute(delete(DbResource).where(DbResource.id.in_(chunk))) 

1585 

1586 # Bulk delete prompts that are no longer available from the gateway 

1587 stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names] 

1588 if stale_prompt_ids: 

1589 # Delete child records first to avoid FK constraint violations 

1590 for i in range(0, len(stale_prompt_ids), 500): 

1591 chunk = stale_prompt_ids[i : i + 500] 

1592 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk))) 

1593 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk))) 

1594 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk))) 

1595 

1596 # Expire gateway to clear cached relationships after bulk deletes 

1597 # This prevents SQLAlchemy from trying to re-delete already-deleted items 

1598 if stale_tool_ids or stale_resource_ids or stale_prompt_ids: 

1599 db.expire(gateway) 

1600 

1601 # Update gateway relationships to reflect deletions 

1602 gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names] 

1603 gateway.resources = [resource for resource in gateway.resources if resource.uri in new_resource_uris] 

1604 gateway.prompts = [prompt for prompt in gateway.prompts if prompt.original_name in new_prompt_names] 

1605 

1606 # Log cleanup results 

1607 tools_removed = len(stale_tool_ids) 

1608 resources_removed = len(stale_resource_ids) 

1609 prompts_removed = len(stale_prompt_ids) 

1610 

1611 if tools_removed > 0: 

1612 logger.info(f"Removed {tools_removed} tools no longer available from gateway") 

1613 if resources_removed > 0: 

1614 logger.info(f"Removed {resources_removed} resources no longer available from gateway") 

1615 if prompts_removed > 0: 

1616 logger.info(f"Removed {prompts_removed} prompts no longer available from gateway") 

1617 

1618 # Update gateway capabilities and last_seen 

1619 gateway.capabilities = capabilities 

1620 gateway.last_seen = datetime.now(timezone.utc) 

1621 

1622 # Register capabilities for notification-driven actions 

1623 register_gateway_capabilities_for_notifications(gateway.id, capabilities) 

1624 

1625 # Add new items to DB in chunks to prevent lock escalation 

1626 items_added = 0 

1627 chunk_size = 50 

1628 

1629 if tools_to_add: 

1630 for i in range(0, len(tools_to_add), chunk_size): 

1631 chunk = tools_to_add[i : i + chunk_size] 

1632 db.add_all(chunk) 

1633 db.flush() # Flush each chunk to avoid excessive memory usage 

1634 items_added += len(tools_to_add) 

1635 logger.info(f"Added {len(tools_to_add)} new tools to database") 

1636 

1637 if resources_to_add: 

1638 for i in range(0, len(resources_to_add), chunk_size): 

1639 chunk = resources_to_add[i : i + chunk_size] 

1640 db.add_all(chunk) 

1641 db.flush() 

1642 items_added += len(resources_to_add) 

1643 logger.info(f"Added {len(resources_to_add)} new resources to database") 

1644 

1645 if prompts_to_add: 

1646 for i in range(0, len(prompts_to_add), chunk_size): 

1647 chunk = prompts_to_add[i : i + chunk_size] 

1648 db.add_all(chunk) 

1649 db.flush() 

1650 items_added += len(prompts_to_add) 

1651 logger.info(f"Added {len(prompts_to_add)} new prompts to database") 

1652 

1653 if items_added > 0: 

1654 db.commit() 

1655 logger.info(f"Total {items_added} new items added to database") 

1656 else: 

1657 logger.info("No new items to add to database") 

1658 # Still commit to save any updates to existing items 

1659 db.commit() 

1660 

1661 cache = _get_registry_cache() 

1662 await cache.invalidate_tools() 

1663 await cache.invalidate_resources() 

1664 await cache.invalidate_prompts() 

1665 tool_lookup_cache = _get_tool_lookup_cache() 

1666 await tool_lookup_cache.invalidate_gateway(str(gateway.id)) 

1667 # Also invalidate tags cache since tool/resource tags may have changed 

1668 # First-Party 

1669 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel 

1670 

1671 await admin_stats_cache.invalidate_tags() 

1672 

1673 return {"capabilities": capabilities, "tools": tools, "resources": resources, "prompts": prompts} 

1674 

1675 except GatewayConnectionError as gce: 

1676 db.rollback() 

1677 # Surface validation or depth-related failures directly to the user 

1678 logger.error(f"GatewayConnectionError during OAuth fetch for {SecurityValidator.sanitize_log_message(gateway_id)}: {gce}") 

1679 raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(gce)}") 

1680 except Exception as e: 

1681 db.rollback() 

1682 logger.error(f"Failed to fetch tools after OAuth for gateway {SecurityValidator.sanitize_log_message(gateway_id)}: {e}") 

1683 raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(e)}") 

1684 

async def list_gateways(
    self,
    db: Session,
    include_inactive: bool = False,
    tags: Optional[List[str]] = None,
    cursor: Optional[str] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    per_page: Optional[int] = None,
    user_email: Optional[str] = None,
    team_id: Optional[str] = None,
    visibility: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
) -> Union[tuple[List[GatewayRead], Optional[str]], Dict[str, Any]]:
    """List registered gateways with cursor- or page-based pagination and optional team filtering.

    Gateways are ordered newest first (``created_at`` descending, then ``id`` descending).
    The first page of a public-only query (``token_teams=[]`` with no ``user_email``,
    ``cursor`` or ``page``) is served from, and written to, the registry cache.

    Args:
        db: Database session
        include_inactive: Whether to include inactive gateways
        tags (Optional[List[str]]): Filter gateways by tags. If provided, only gateways with at least one matching tag will be returned.
        cursor: Cursor for pagination (encoded last created_at and id).
        limit: Maximum number of gateways to return. None for default, 0 for unlimited.
        page: Page number for page-based pagination (1-indexed). Mutually exclusive with cursor.
        per_page: Items per page for page-based pagination. Defaults to pagination_default_page_size.
        user_email: Email of user for team-based access control. None for no access control.
        team_id: Optional team ID to filter by specific team (requires user_email).
        visibility: Optional visibility filter (private, team, public) (requires user_email).
        token_teams: Optional list of team IDs from the token (None=unrestricted, []=public-only).

    Returns:
        If page is provided: Dict with {"data": [...], "pagination": {...}, "links": {...}}
        If cursor is provided or neither: tuple of (list of GatewayRead objects, next_cursor).

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> from unittest.mock import MagicMock, AsyncMock, patch
        >>> from mcpgateway.schemas import GatewayRead
        >>> import asyncio
        >>> service = GatewayService()
        >>> db = MagicMock()
        >>> gateway_obj = MagicMock()
        >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway_obj]
        >>> gateway_read_obj = MagicMock(spec=GatewayRead)
        >>> service.convert_gateway_to_read = MagicMock(return_value=gateway_read_obj)
        >>> # Mock the cache to bypass caching logic
        >>> with patch('mcpgateway.services.gateway_service._get_registry_cache') as mock_cache_factory:
        ...     mock_cache = MagicMock()
        ...     mock_cache.get = AsyncMock(return_value=None)
        ...     mock_cache.set = AsyncMock(return_value=None)
        ...     mock_cache.hash_filters = MagicMock(return_value="hash")
        ...     mock_cache_factory.return_value = mock_cache
        ...     gateways, cursor = asyncio.run(service.list_gateways(db))
        ...     gateways == [gateway_read_obj] and cursor is None
        True

        >>> # Test empty result
        >>> db.execute.return_value.scalars.return_value.all.return_value = []
        >>> with patch('mcpgateway.services.gateway_service._get_registry_cache') as mock_cache_factory:
        ...     mock_cache = MagicMock()
        ...     mock_cache.get = AsyncMock(return_value=None)
        ...     mock_cache.set = AsyncMock(return_value=None)
        ...     mock_cache.hash_filters = MagicMock(return_value="hash")
        ...     mock_cache_factory.return_value = mock_cache
        ...     empty_result, cursor = asyncio.run(service.list_gateways(db))
        ...     empty_result == [] and cursor is None
        True
        >>>
        >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
        >>> asyncio.run(service._http_client.aclose())
    """
    # Check cache for first page only - only for public-only queries (no user/team filtering)
    # SECURITY: Only cache public-only results (token_teams=[]), never admin bypass or team-scoped
    cache = _get_registry_cache()
    is_public_only = token_teams is not None and len(token_teams) == 0
    use_cache = cursor is None and user_email is None and page is None and is_public_only
    filters_hash: Optional[str] = None
    if use_cache:
        # The key must cover every argument that shapes the first page. `limit` and
        # `team_id` both feed the query below, so leaving them out would let a request
        # with one limit be answered with the cached page of another.
        # NOTE(review): assumes hash_filters accepts arbitrary keyword filters (it is
        # already called with keyword-only filter names) - confirm against its definition.
        filters_hash = cache.hash_filters(
            include_inactive=include_inactive,
            tags=sorted(tags) if tags else None,
            visibility=visibility,
            limit=limit,
            team_id=team_id,
        )
        cached = await cache.get("gateways", filters_hash)
        if cached is not None:
            # Reconstruct GatewayRead objects from cached dicts
            # SECURITY: Always apply .masked() to ensure stale cache entries don't leak credentials
            cached_gateways = [GatewayRead.model_validate(g).masked() for g in cached["gateways"]]
            return (cached_gateways, cached.get("next_cursor"))

    # Build base query with ordering
    query = select(DbGateway).options(joinedload(DbGateway.email_team)).order_by(desc(DbGateway.created_at), desc(DbGateway.id))

    # Apply active/inactive filter
    if not include_inactive:
        query = query.where(DbGateway.enabled)

    query = await self._apply_access_control(query, db, user_email, token_teams, team_id)

    if visibility:
        query = query.where(DbGateway.visibility == visibility)

    # Add tag filtering if tags are provided (supports both List[str] and List[Dict] formats)
    if tags:
        query = query.where(json_contains_tag_expr(db, DbGateway.tags, tags, match_any=True))

    # Use unified pagination helper - handles both page and cursor pagination
    pag_result = await unified_paginate(
        db=db,
        query=query,
        page=page,
        per_page=per_page,
        cursor=cursor,
        limit=limit,
        base_url="/admin/gateways",  # Used for page-based links
        query_params={"include_inactive": include_inactive} if include_inactive else {},
    )

    next_cursor = None
    # Extract gateways based on pagination type
    if page is not None:
        # Page-based: pag_result is a dict
        gateways_db = pag_result["data"]
    else:
        # Cursor-based: pag_result is a tuple
        gateways_db, next_cursor = pag_result

    db.commit()  # Release transaction to avoid idle-in-transaction

    # Convert to GatewayRead (common for both pagination types)
    result = []
    for s in gateways_db:
        try:
            result.append(self.convert_gateway_to_read(s))
        except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
            logger.exception(f"Failed to convert gateway {getattr(s, 'id', 'unknown')} ({getattr(s, 'name', 'unknown')}): {e}")
            # Continue with remaining gateways instead of failing completely

    # Return appropriate format based on pagination type
    if page is not None:
        # Page-based format
        return {
            "data": result,
            "pagination": pag_result["pagination"],
            "links": pag_result["links"],
        }

    # Cursor-based format

    # Cache first page results - only for public-only queries (no user/team filtering)
    # SECURITY: Only cache public-only results (token_teams=[]), never admin bypass or team-scoped
    # `use_cache` already encodes that condition (page is None on this path), and it is
    # the same flag that guarantees `filters_hash` was computed above.
    if use_cache:
        try:
            cache_data = {"gateways": [s.model_dump(mode="json") for s in result], "next_cursor": next_cursor}
            await cache.set("gateways", cache_data, filters_hash)
        except AttributeError:
            pass  # Skip caching if result objects don't support model_dump (e.g., in doctests)

    return (result, next_cursor)

1837 

async def list_gateways_for_user(
    self, db: Session, user_email: str, team_id: Optional[str] = None, visibility: Optional[str] = None, include_inactive: bool = False, skip: int = 0, limit: int = 100
) -> List[GatewayRead]:
    """
    DEPRECATED: Use list_gateways() with user_email parameter instead.

    This method is maintained for backward compatibility but is no longer used.
    New code should call list_gateways() with user_email, team_id, and visibility parameters.

    List gateways user has access to with team filtering. Results are ordered
    newest first (``created_at`` descending, then ``id`` descending), matching
    list_gateways(), so that ``skip``/``limit`` pages are stable between calls.

    Args:
        db: Database session
        user_email: Email of the user requesting gateways
        team_id: Optional team ID to filter by specific team
        visibility: Optional visibility filter (private, team, public)
        include_inactive: Whether to include inactive gateways
        skip: Number of gateways to skip for pagination
        limit: Maximum number of gateways to return

    Returns:
        List[GatewayRead]: Gateways the user has access to. Empty when ``team_id``
        is given but the user is not a member of that team.
    """
    # Resolve the teams the user belongs to; they drive the access conditions below
    team_service = TeamManagementService(db)
    user_teams = await team_service.get_user_teams(user_email)
    team_ids = [team.id for team in user_teams]

    if team_id and team_id not in team_ids:
        return []  # No access to team

    # Use joinedload to eager load email_team relationship (avoids N+1 queries).
    # An explicit ORDER BY is required: OFFSET/LIMIT over an unordered query may
    # return rows in a different order on each call, repeating or skipping gateways.
    query = select(DbGateway).options(joinedload(DbGateway.email_team)).order_by(desc(DbGateway.created_at), desc(DbGateway.id))

    # Apply active/inactive filter
    if not include_inactive:
        query = query.where(DbGateway.enabled.is_(True))

    if team_id:
        access_conditions = [
            # Team-owned gateways (team-scoped gateways)
            and_(DbGateway.team_id == team_id, DbGateway.visibility.in_(["team", "public"])),
            # Gateways in that team owned by the requesting user
            and_(DbGateway.team_id == team_id, DbGateway.owner_email == user_email),
            # Also include global public gateways (no team_id) so public gateways are visible regardless of selected team
            DbGateway.visibility == "public",
        ]
    else:
        # 1. User's personal resources (owner_email matches)
        access_conditions = [DbGateway.owner_email == user_email]
        # 2. Team resources where user is member
        if team_ids:
            access_conditions.append(and_(DbGateway.team_id.in_(team_ids), DbGateway.visibility.in_(["team", "public"])))
        # 3. Public resources (if visibility allows)
        access_conditions.append(DbGateway.visibility == "public")

    query = query.where(or_(*access_conditions))

    # Apply visibility filter if specified
    if visibility:
        query = query.where(DbGateway.visibility == visibility)

    # Apply pagination following existing patterns
    query = query.offset(skip).limit(limit)

    gateways = db.execute(query).scalars().all()

    db.commit()  # Release transaction to avoid idle-in-transaction

    # Team names are loaded via joinedload(DbGateway.email_team)
    result = []
    for g in gateways:
        logger.info(f"Gateway: {SecurityValidator.sanitize_log_message(g.team_id)}, Team: {g.team}")
        result.append(self.convert_gateway_to_read(g))
    return result

1920 

1921 async def update_gateway( 

1922 self, 

1923 db: Session, 

1924 gateway_id: str, 

1925 gateway_update: GatewayUpdate, 

1926 modified_by: Optional[str] = None, 

1927 modified_from_ip: Optional[str] = None, 

1928 modified_via: Optional[str] = None, 

1929 modified_user_agent: Optional[str] = None, 

1930 include_inactive: bool = True, 

1931 user_email: Optional[str] = None, 

1932 ) -> Optional[GatewayRead]: 

1933 """Update a gateway. 

1934 

1935 Args: 

1936 db: Database session 

1937 gateway_id: Gateway ID to update 

1938 gateway_update: Updated gateway data 

1939 modified_by: Username of the person modifying the gateway 

1940 modified_from_ip: IP address where the modification request originated 

1941 modified_via: Source of modification (ui/api/import) 

1942 modified_user_agent: User agent string from the modification request 

1943 include_inactive: Whether to include inactive gateways 

1944 user_email: Email of user performing update (for ownership check) 

1945 

1946 Returns: 

1947 Updated gateway information 

1948 

1949 Raises: 

1950 GatewayNotFoundError: If gateway not found 

1951 PermissionError: If user doesn't own the gateway 

1952 GatewayError: For other update errors 

1953 GatewayNameConflictError: If gateway name conflict occurs 

1954 IntegrityError: If there is a database integrity error 

1955 ValidationError: If validation fails 

1956 """ 

1957 try: # pylint: disable=too-many-nested-blocks 

1958 # Acquire row lock and eager-load relationships while locked so 

1959 # concurrent updates are serialized on Postgres. 

1960 gateway = get_for_update( 

1961 db, 

1962 DbGateway, 

1963 gateway_id, 

1964 options=[ 

1965 selectinload(DbGateway.tools), 

1966 selectinload(DbGateway.resources), 

1967 selectinload(DbGateway.prompts), 

1968 selectinload(DbGateway.email_team), # Use selectinload to avoid locking email_teams 

1969 ], 

1970 ) 

1971 if not gateway: 

1972 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") 

1973 

1974 # Check ownership if user_email provided 

1975 if user_email: 

1976 # First-Party 

1977 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel 

1978 

1979 permission_service = PermissionService(db) 

1980 if not await permission_service.check_resource_ownership(user_email, gateway): 

1981 raise PermissionError("Only the owner can update this gateway") 

1982 

1983 if gateway.enabled or include_inactive: 

1984 # Check for name conflicts if name is being changed 

1985 if gateway_update.name is not None and gateway_update.name != gateway.name: 

1986 # existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway_update.name).where(DbGateway.id != gateway_id)).scalar_one_or_none() 

1987 

1988 # if existing_gateway: 

1989 # raise GatewayNameConflictError( 

1990 # gateway_update.name, 

1991 # enabled=existing_gateway.enabled, 

1992 # gateway_id=existing_gateway.id, 

1993 # ) 

1994 # Check for existing gateway with the same slug and visibility 

1995 new_slug = slugify(gateway_update.name) 

1996 if gateway_update.visibility is not None: 

1997 vis = gateway_update.visibility 

1998 else: 

1999 vis = gateway.visibility 

2000 if vis == "public": 

2001 # Check for existing public gateway with the same slug (row-locked) 

2002 existing_gateway = get_for_update( 

2003 db, 

2004 DbGateway, 

2005 where=and_(DbGateway.slug == new_slug, DbGateway.visibility == "public", DbGateway.id != gateway_id), 

2006 ) 

2007 if existing_gateway: 

2008 raise GatewayNameConflictError( 

2009 new_slug, 

2010 enabled=existing_gateway.enabled, 

2011 gateway_id=existing_gateway.id, 

2012 visibility=existing_gateway.visibility, 

2013 ) 

2014 elif vis == "team" and gateway.team_id: 

2015 # Check for existing team gateway with the same slug (row-locked) 

2016 existing_gateway = get_for_update( 

2017 db, 

2018 DbGateway, 

2019 where=and_(DbGateway.slug == new_slug, DbGateway.visibility == "team", DbGateway.team_id == gateway.team_id, DbGateway.id != gateway_id), 

2020 ) 

2021 if existing_gateway: 

2022 raise GatewayNameConflictError( 

2023 new_slug, 

2024 enabled=existing_gateway.enabled, 

2025 gateway_id=existing_gateway.id, 

2026 visibility=existing_gateway.visibility, 

2027 ) 

2028 # Check for existing gateway with the same URL and visibility 

2029 normalized_url = "" 

2030 if gateway_update.url is not None: 

2031 normalized_url = self.normalize_url(str(gateway_update.url)) 

2032 else: 

2033 normalized_url = None 

2034 

2035 # Prepare decoded auth_value for uniqueness check 

2036 decoded_auth_value = None 

2037 if gateway_update.auth_value: 

2038 if isinstance(gateway_update.auth_value, str): 

2039 try: 

2040 decoded_auth_value = decode_auth(gateway_update.auth_value) 

2041 except Exception as e: 

2042 logger.warning(f"Failed to decode provided auth_value: {e}") 

2043 elif isinstance(gateway_update.auth_value, dict): 

2044 decoded_auth_value = gateway_update.auth_value 

2045 

2046 # Determine final values for uniqueness check 

2047 final_auth_value = decoded_auth_value if gateway_update.auth_value is not None else (decode_auth(gateway.auth_value) if isinstance(gateway.auth_value, str) else gateway.auth_value) 

2048 final_oauth_config = gateway_update.oauth_config if gateway_update.oauth_config is not None else gateway.oauth_config 

2049 final_visibility = gateway_update.visibility if gateway_update.visibility is not None else gateway.visibility 

2050 

2051 # Check for duplicates with updated credentials 

2052 if not gateway_update.one_time_auth: 

2053 duplicate_gateway = self._check_gateway_uniqueness( 

2054 db=db, 

2055 url=normalized_url, 

2056 auth_value=final_auth_value, 

2057 oauth_config=final_oauth_config, 

2058 team_id=gateway.team_id, 

2059 visibility=final_visibility, 

2060 gateway_id=gateway_id, # Exclude current gateway from check 

2061 owner_email=user_email, 

2062 ) 

2063 

2064 if duplicate_gateway: 

2065 raise GatewayDuplicateConflictError(duplicate_gateway=duplicate_gateway) 

2066 

2067 # FIX for Issue #1025: Determine if URL actually changed before we update it 

2068 # We need this early because we update gateway.url below, and need to know 

2069 # if it actually changed to decide whether to re-fetch tools 

2070 # tools/resources/prompts need to be re-fetched not only when the URL changes, but also when other settings such as authentication or visibility change

2071 # url_changed = gateway_update.url is not None and self.normalize_url(str(gateway_update.url)) != gateway.url 

2072 

2073 # Save original values BEFORE updating for change detection checks later 

2074 original_url = gateway.url 

2075 original_auth_type = gateway.auth_type 

2076 

2077 # Update fields if provided 

2078 if gateway_update.name is not None: 

2079 gateway.name = gateway_update.name 

2080 gateway.slug = slugify(gateway_update.name) 

2081 if gateway_update.url is not None: 

2082 # Normalize the updated URL 

2083 gateway.url = self.normalize_url(str(gateway_update.url)) 

2084 if gateway_update.description is not None: 

2085 gateway.description = gateway_update.description 

2086 if gateway_update.transport is not None: 

2087 gateway.transport = gateway_update.transport 

2088 if gateway_update.tags is not None: 

2089 gateway.tags = gateway_update.tags 

2090 if gateway_update.visibility is not None: 

2091 old_visibility = gateway.visibility 

2092 # Validate visibility transitions 

2093 if gateway_update.visibility == "team": 

2094 target_team_id = gateway_update.team_id if gateway_update.team_id is not None else gateway.team_id 

2095 _validate_gateway_team_assignment(db, user_email, target_team_id) 

2096 gateway.visibility = gateway_update.visibility 

2097 # Propagate visibility to all linked items immediately so it 

2098 # takes effect even when the upstream server is unreachable 

2099 # and _initialize_gateway fails. 

2100 # Only update items that inherited the old gateway visibility; 

2101 # preserve per-item overrides (e.g. a resource set to "team" 

2102 # while the gateway was "public"). 

2103 for tool in gateway.tools: 

2104 if tool.visibility == old_visibility: 

2105 tool.visibility = gateway.visibility 

2106 for resource in gateway.resources: 

2107 if resource.visibility == old_visibility: 

2108 resource.visibility = gateway.visibility 

2109 for prompt in gateway.prompts: 

2110 if prompt.visibility == old_visibility: 

2111 prompt.visibility = gateway.visibility 

2112 if gateway_update.passthrough_headers is not None: 

2113 if isinstance(gateway_update.passthrough_headers, list): 

2114 gateway.passthrough_headers = gateway_update.passthrough_headers 

2115 else: 

2116 if isinstance(gateway_update.passthrough_headers, str): 

2117 parsed: List[str] = [h.strip() for h in gateway_update.passthrough_headers.split(",") if h.strip()] 

2118 gateway.passthrough_headers = parsed 

2119 else: 

2120 raise GatewayError("Invalid passthrough_headers format: must be list[str] or comma-separated string") 

2121 

2122 logger.info("Updated passthrough_headers for gateway {gateway.id}: {gateway.passthrough_headers}") 

2123 

2124 # Update team assignment if provided, validating ownership 

2125 if gateway_update.team_id is not None: 

2126 if gateway_update.team_id != gateway.team_id: 

2127 _validate_gateway_team_assignment(db, user_email, gateway_update.team_id) 

2128 gateway.team_id = gateway_update.team_id 

2129 

2130 # Update CA certificate fields if provided 

2131 if getattr(gateway_update, "ca_certificate", None) is not None: 

2132 gateway.ca_certificate = gateway_update.ca_certificate 

2133 if getattr(gateway_update, "ca_certificate_sig", None) is not None: 

2134 gateway.ca_certificate_sig = gateway_update.ca_certificate_sig 

2135 if getattr(gateway_update, "signing_algorithm", None) is not None: 

2136 gateway.signing_algorithm = gateway_update.signing_algorithm 

2137 

2138 # Update mTLS client certificate/key if provided 

2139 if getattr(gateway_update, "client_cert", None) is not None: 

2140 gateway.client_cert = gateway_update.client_cert 

2141 if getattr(gateway_update, "client_key", None) is not None: 

2142 if gateway_update.client_key == settings.masked_auth_value: 

2143 pass # Preserve existing encrypted value 

2144 else: 

2145 gateway.client_key = await self._encrypt_client_key(gateway_update.client_key) 

2146 

2147 # Only update auth_type if explicitly provided in the update 

2148 if gateway_update.auth_type is not None: 

2149 gateway.auth_type = gateway_update.auth_type 

2150 

2151 # If auth_type is empty, update the auth_value too 

2152 if gateway_update.auth_type == "": 

2153 gateway.auth_value = cast(Any, "") 

2154 

2155 # Clear auth_query_params when switching away from query_param auth 

2156 if original_auth_type == "query_param" and gateway_update.auth_type != "query_param": 

2157 gateway.auth_query_params = None 

2158 logger.debug(f"Cleared auth_query_params for gateway {SecurityValidator.sanitize_log_message(gateway.id)} (switched from query_param to {gateway_update.auth_type})") 

2159 

2160 # if auth_type is not None and only then check auth_value 

2161 # Handle OAuth configuration updates 

2162 if gateway_update.oauth_config is not None: 

2163 gateway.oauth_config = await protect_oauth_config_for_storage(gateway_update.oauth_config, existing_oauth_config=gateway.oauth_config) 

2164 

2165 # Handle auth_value updates (both existing and new auth values) 

2166 token = gateway_update.auth_token 

2167 password = gateway_update.auth_password 

2168 header_value = gateway_update.auth_header_value 

2169 

2170 # Support multiple custom headers on update 

2171 if hasattr(gateway_update, "auth_headers") and gateway_update.auth_headers: 

2172 existing_auth_raw = getattr(gateway, "auth_value", {}) or {} 

2173 if isinstance(existing_auth_raw, str): 

2174 try: 

2175 existing_auth = decode_auth(existing_auth_raw) 

2176 except Exception: 

2177 existing_auth = {} 

2178 elif isinstance(existing_auth_raw, dict): 

2179 existing_auth = existing_auth_raw 

2180 else: 

2181 existing_auth = {} 

2182 

2183 header_dict: Dict[str, str] = {} 

2184 for header in gateway_update.auth_headers: 

2185 key = header.get("key") 

2186 if not key: 

2187 continue 

2188 value = header.get("value", "") 

2189 if value == settings.masked_auth_value and key in existing_auth: 

2190 header_dict[key] = existing_auth[key] 

2191 else: 

2192 header_dict[key] = value 

2193 gateway.auth_value = header_dict # Store as dict for DB JSON field 

2194 elif settings.masked_auth_value not in (token, password, header_value): 

2195 # Check if values differ from existing ones or if setting for first time 

2196 decoded_auth = decode_auth(gateway_update.auth_value) if gateway_update.auth_value else {} 

2197 current_auth = getattr(gateway, "auth_value", {}) or {} 

2198 if current_auth != decoded_auth: 

2199 gateway.auth_value = decoded_auth 

2200 

2201 # Handle query_param auth updates with service-layer enforcement 

2202 auth_query_params_decrypted: Optional[Dict[str, str]] = None 

2203 init_url = gateway.url 

2204 

2205 # Check if updating to query_param auth or updating existing query_param credentials 

2206 # Use original_auth_type since gateway.auth_type may have been updated already 

2207 is_switching_to_queryparam = gateway_update.auth_type == "query_param" and original_auth_type != "query_param" 

2208 is_updating_queryparam_creds = original_auth_type == "query_param" and (gateway_update.auth_query_param_key is not None or gateway_update.auth_query_param_value is not None) 

2209 is_url_changing = gateway_update.url is not None and self.normalize_url(str(gateway_update.url)) != original_url 

2210 

2211 if is_switching_to_queryparam or is_updating_queryparam_creds or (is_url_changing and original_auth_type == "query_param"): 

2212 # Service-layer enforcement: Check feature flag 

2213 if not settings.insecure_allow_queryparam_auth: 

2214 # Grandfather clause: Allow updates to existing query_param gateways 

2215 # unless they're trying to change credentials 

2216 if is_switching_to_queryparam or is_updating_queryparam_creds: 

2217 raise ValueError("Query parameter authentication is disabled. " + "Set INSECURE_ALLOW_QUERYPARAM_AUTH=true to enable.") 

2218 

2219 # Service-layer enforcement: Check host allowlist 

2220 if settings.insecure_queryparam_auth_allowed_hosts: 

2221 check_url = str(gateway_update.url) if gateway_update.url else gateway.url 

2222 parsed = urlparse(check_url) 

2223 hostname = (parsed.hostname or "").lower() 

2224 if hostname not in settings.insecure_queryparam_auth_allowed_hosts: 

2225 allowed = ", ".join(settings.insecure_queryparam_auth_allowed_hosts) 

2226 raise ValueError(f"Host '{hostname}' is not in the allowed hosts for query param auth. Allowed: {allowed}") 

2227 

2228 # Process query_param auth credentials 

2229 param_key = getattr(gateway_update, "auth_query_param_key", None) or (next(iter(gateway.auth_query_params.keys()), None) if gateway.auth_query_params else None) 

2230 param_value = getattr(gateway_update, "auth_query_param_value", None) 

2231 

2232 # Get raw value from SecretStr if applicable 

2233 raw_value: Optional[str] = None 

2234 if param_value: 

2235 if hasattr(param_value, "get_secret_value"): 

2236 raw_value = param_value.get_secret_value() 

2237 else: 

2238 raw_value = str(param_value) 

2239 

2240 # Check if the value is the masked placeholder - if so, keep existing value 

2241 is_masked_placeholder = raw_value == settings.masked_auth_value 

2242 

2243 if param_key: 

2244 if raw_value and not is_masked_placeholder: 

2245 # New value provided - encrypt for storage 

2246 encrypted_value = encode_auth({param_key: raw_value}) 

2247 gateway.auth_query_params = {param_key: encrypted_value} 

2248 auth_query_params_decrypted = {param_key: raw_value} 

2249 elif gateway.auth_query_params: 

2250 # Use existing encrypted value 

2251 existing_encrypted = gateway.auth_query_params.get(param_key, "") 

2252 if existing_encrypted: 

2253 decrypted = decode_auth(existing_encrypted) 

2254 auth_query_params_decrypted = {param_key: decrypted.get(param_key, "")} 

2255 

2256 # Append query params to URL for initialization 

2257 if auth_query_params_decrypted: 

2258 init_url = apply_query_param_auth(gateway.url, auth_query_params_decrypted) 

2259 

2260 # Update auth_type if switching 

2261 if is_switching_to_queryparam: 

2262 gateway.auth_type = "query_param" 

2263 gateway.auth_value = None # Query param auth doesn't use auth_value 

2264 

2265 elif gateway.auth_type == "query_param" and gateway.auth_query_params: 

2266 # Existing query_param gateway without credential changes - decrypt for init 

2267 first_key = next(iter(gateway.auth_query_params.keys()), None) 

2268 if first_key: 

2269 encrypted_value = gateway.auth_query_params.get(first_key, "") 

2270 if encrypted_value: 

2271 decrypted = decode_auth(encrypted_value) 

2272 auth_query_params_decrypted = {first_key: decrypted.get(first_key, "")} 

2273 init_url = apply_query_param_auth(gateway.url, auth_query_params_decrypted) 

2274 

2275 # Try to reinitialize connection if URL actually changed 

2276 # if url_changed: 

2277 # Initialize empty lists in case initialization fails 

2278 tools_to_add = [] 

2279 resources_to_add = [] 

2280 prompts_to_add = [] 

2281 reinit_succeeded = False 

2282 

2283 try: 

2284 ca_certificate = getattr(gateway, "ca_certificate", None) 

2285 update_client_cert = getattr(gateway, "client_cert", None) 

2286 update_client_key = getattr(gateway, "client_key", None) 

2287 # Decrypt client_key for initialization (stored encrypted) 

2288 if update_client_key: 

2289 try: 

2290 _enc = get_encryption_service(settings.auth_encryption_secret) 

2291 update_client_key = _enc.decrypt_secret_or_plaintext(update_client_key) 

2292 except Exception: 

2293 logger.debug("client_key decryption skipped during gateway re-init") 

2294 capabilities, tools, resources, prompts = await self._initialize_gateway( 

2295 init_url, 

2296 gateway.auth_value, 

2297 gateway.transport, 

2298 gateway.auth_type, 

2299 gateway.oauth_config, 

2300 ca_certificate, 

2301 auth_query_params=auth_query_params_decrypted, 

2302 client_cert=update_client_cert, 

2303 client_key=update_client_key, 

2304 ) 

2305 new_tool_names = [tool.name for tool in tools] 

2306 new_resource_uris = [resource.uri for resource in resources] 

2307 new_prompt_names = [prompt.name for prompt in prompts] 

2308 

2309 if gateway_update.one_time_auth: 

2310 # For one-time auth, clear auth_type and auth_value after initialization 

2311 gateway.auth_type = "one_time_auth" 

2312 gateway.auth_value = None 

2313 gateway.oauth_config = None 

2314 

2315 # Update tools using helper method — only propagate visibility 

2316 # when the user explicitly changed it in this request 

2317 _vis_changed = gateway_update.visibility is not None 

2318 tools_to_add = self._update_or_create_tools(db, tools, gateway, "update", update_visibility=_vis_changed) 

2319 

2320 # Update resources using helper method 

2321 resources_to_add = self._update_or_create_resources(db, resources, gateway, "update", update_visibility=_vis_changed) 

2322 

2323 # Update prompts using helper method 

2324 prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, "update", update_visibility=_vis_changed) 

2325 

2326 # Log newly added items 

2327 items_added = len(tools_to_add) + len(resources_to_add) + len(prompts_to_add) 

2328 if items_added > 0: 

2329 if tools_to_add: 

2330 logger.info(f"Added {len(tools_to_add)} new tools during gateway update") 

2331 if resources_to_add: 

2332 logger.info(f"Added {len(resources_to_add)} new resources during gateway update") 

2333 if prompts_to_add: 

2334 logger.info(f"Added {len(prompts_to_add)} new prompts during gateway update") 

2335 logger.info(f"Total {items_added} new items added during gateway update") 

2336 

2337 # Count items before cleanup for logging 

2338 

2339 # Bulk delete tools that are no longer available from the gateway 

2340 # Use chunking to avoid SQLite's 999 parameter limit for IN clauses 

2341 stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names] 

2342 if stale_tool_ids: 

2343 # Delete child records first to avoid FK constraint violations 

2344 for i in range(0, len(stale_tool_ids), 500): 

2345 chunk = stale_tool_ids[i : i + 500] 

2346 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk))) 

2347 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk))) 

2348 db.execute(delete(DbTool).where(DbTool.id.in_(chunk))) 

2349 

2350 # Bulk delete resources that are no longer available from the gateway 

2351 stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris] 

2352 if stale_resource_ids: 

2353 # Delete child records first to avoid FK constraint violations 

2354 for i in range(0, len(stale_resource_ids), 500): 

2355 chunk = stale_resource_ids[i : i + 500] 

2356 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk))) 

2357 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk))) 

2358 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk))) 

2359 db.execute(delete(DbResource).where(DbResource.id.in_(chunk))) 

2360 

2361 # Bulk delete prompts that are no longer available from the gateway 

2362 stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names] 

2363 if stale_prompt_ids: 

2364 # Delete child records first to avoid FK constraint violations 

2365 for i in range(0, len(stale_prompt_ids), 500): 

2366 chunk = stale_prompt_ids[i : i + 500] 

2367 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk))) 

2368 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk))) 

2369 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk))) 

2370 

2371 # Expire gateway to clear cached relationships after bulk deletes 

2372 # This prevents SQLAlchemy from trying to re-delete already-deleted items 

2373 if stale_tool_ids or stale_resource_ids or stale_prompt_ids: 

2374 db.expire(gateway) 

2375 

2376 gateway.capabilities = capabilities 

2377 

2378 # Register capabilities for notification-driven actions 

2379 register_gateway_capabilities_for_notifications(gateway.id, capabilities) 

2380 

2381 gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names] # keep only still-valid rows 

2382 gateway.resources = [resource for resource in gateway.resources if resource.uri in new_resource_uris] # keep only still-valid rows 

2383 gateway.prompts = [prompt for prompt in gateway.prompts if prompt.original_name in new_prompt_names] # keep only still-valid rows 

2384 

2385 # Log cleanup results 

2386 tools_removed = len(stale_tool_ids) 

2387 resources_removed = len(stale_resource_ids) 

2388 prompts_removed = len(stale_prompt_ids) 

2389 

2390 if tools_removed > 0: 

2391 logger.info(f"Removed {tools_removed} tools no longer available during gateway update") 

2392 if resources_removed > 0: 

2393 logger.info(f"Removed {resources_removed} resources no longer available during gateway update") 

2394 if prompts_removed > 0: 

2395 logger.info(f"Removed {prompts_removed} prompts no longer available during gateway update") 

2396 

2397 gateway.last_seen = datetime.now(timezone.utc) 

2398 

2399 # Add new items to database session in chunks to prevent lock escalation 

2400 chunk_size = 50 

2401 

2402 if tools_to_add: 

2403 for i in range(0, len(tools_to_add), chunk_size): 

2404 chunk = tools_to_add[i : i + chunk_size] 

2405 db.add_all(chunk) 

2406 db.flush() 

2407 if resources_to_add: 

2408 for i in range(0, len(resources_to_add), chunk_size): 

2409 chunk = resources_to_add[i : i + chunk_size] 

2410 db.add_all(chunk) 

2411 db.flush() 

2412 if prompts_to_add: 

2413 for i in range(0, len(prompts_to_add), chunk_size): 

2414 chunk = prompts_to_add[i : i + chunk_size] 

2415 db.add_all(chunk) 

2416 db.flush() 

2417 

2418 # Update tracking with new URL 

2419 self._active_gateways.discard(gateway.url) 

2420 self._active_gateways.add(gateway.url) 

2421 reinit_succeeded = True 

2422 except Exception as e: 

2423 logger.warning(f"Failed to initialize updated gateway: {e}") 

2424 reinit_succeeded = False 

2425 

2426 # Update tags if provided 

2427 if gateway_update.tags is not None: 

2428 gateway.tags = gateway_update.tags 

2429 

2430 # Update gateway_mode if provided 

2431 if hasattr(gateway_update, "gateway_mode") and gateway_update.gateway_mode is not None: 

2432 if gateway_update.gateway_mode == "direct_proxy" and not settings.mcpgateway_direct_proxy_enabled: 

2433 raise GatewayError("direct_proxy gateway mode is disabled. Set MCPGATEWAY_DIRECT_PROXY_ENABLED=true to enable.") 

2434 gateway.gateway_mode = gateway_update.gateway_mode 

2435 

2436 # Update metadata fields 

2437 gateway.updated_at = datetime.now(timezone.utc) 

2438 if modified_by: 

2439 gateway.modified_by = modified_by 

2440 if modified_from_ip: 

2441 gateway.modified_from_ip = modified_from_ip 

2442 if modified_via: 

2443 gateway.modified_via = modified_via 

2444 if modified_user_agent: 

2445 gateway.modified_user_agent = modified_user_agent 

2446 if hasattr(gateway, "version") and gateway.version is not None: 

2447 gateway.version = gateway.version + 1 

2448 else: 

2449 gateway.version = 1 

2450 

2451 db.commit() 

2452 db.refresh(gateway) 

2453 

2454 # Invalidate cache after successful update 

2455 cache = _get_registry_cache() 

2456 await cache.invalidate_gateways() 

2457 tool_lookup_cache = _get_tool_lookup_cache() 

2458 await tool_lookup_cache.invalidate_gateway(str(gateway.id)) 

2459 # Also invalidate tags cache since gateway tags may have changed 

2460 # First-Party 

2461 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel 

2462 

2463 await admin_stats_cache.invalidate_tags() 

2464 

2465 # Advance hot/cold poll schedule only after successful tool re-init 

2466 if reinit_succeeded and self._classification_service and gateway.url: 

2467 try: 

2468 await self._classification_service.mark_poll_completed(gateway.url, "tool_discovery", gateway_id=str(gateway.id)) 

2469 except Exception as poll_ts_err: 

2470 logger.debug(f"Best-effort tool_discovery poll timestamp update failed: {poll_ts_err}") 

2471 

2472 # Invalidate loopback passthrough cache when gateway headers change (#3640) 

2473 if gateway_update.passthrough_headers is not None: 

2474 # First-Party 

2475 from mcpgateway.utils.passthrough_headers import invalidate_passthrough_header_caches # pylint: disable=import-outside-toplevel 

2476 

2477 invalidate_passthrough_header_caches() 

2478 

2479 # Notify subscribers 

2480 await self._notify_gateway_updated(gateway) 

2481 

2482 logger.info(f"Updated gateway: {SecurityValidator.sanitize_log_message(gateway.name)}") 

2483 

2484 # Structured logging: Audit trail for gateway update 

2485 audit_trail.log_action( 

2486 user_id=user_email or modified_by or "system", 

2487 action="update_gateway", 

2488 resource_type="gateway", 

2489 resource_id=str(gateway.id), 

2490 resource_name=gateway.name, 

2491 user_email=user_email, 

2492 team_id=gateway.team_id, 

2493 client_ip=modified_from_ip, 

2494 user_agent=modified_user_agent, 

2495 new_values={ 

2496 "name": gateway.name, 

2497 "url": gateway.url, 

2498 "version": gateway.version, 

2499 }, 

2500 context={ 

2501 "modified_via": modified_via, 

2502 }, 

2503 db=db, 

2504 ) 

2505 

2506 # Structured logging: Log successful gateway update 

2507 structured_logger.log( 

2508 level="INFO", 

2509 message="Gateway updated successfully", 

2510 event_type="gateway_updated", 

2511 component="gateway_service", 

2512 user_id=modified_by, 

2513 user_email=user_email, 

2514 team_id=gateway.team_id, 

2515 resource_type="gateway", 

2516 resource_id=str(gateway.id), 

2517 custom_fields={ 

2518 "gateway_name": gateway.name, 

2519 "version": gateway.version, 

2520 }, 

2521 ) 

2522 

2523 return self.convert_gateway_to_read(gateway) 

2524 # Gateway is inactive and include_inactive is False → skip update, return None 

2525 return None 

2526 except GatewayNameConflictError as ge: 

2527 logger.error(f"GatewayNameConflictError in group: {ge}") 

2528 db.rollback() 

2529 

2530 structured_logger.log( 

2531 level="WARNING", 

2532 message="Gateway update failed due to name conflict", 

2533 event_type="gateway_name_conflict", 

2534 component="gateway_service", 

2535 user_email=user_email, 

2536 resource_type="gateway", 

2537 resource_id=gateway_id, 

2538 error=ge, 

2539 ) 

2540 raise ge 

2541 except GatewayNotFoundError as gnfe: 

2542 logger.error(f"GatewayNotFoundError: {gnfe}") 

2543 db.rollback() 

2544 

2545 structured_logger.log( 

2546 level="ERROR", 

2547 message="Gateway update failed - gateway not found", 

2548 event_type="gateway_not_found", 

2549 component="gateway_service", 

2550 user_email=user_email, 

2551 resource_type="gateway", 

2552 resource_id=gateway_id, 

2553 error=gnfe, 

2554 ) 

2555 raise gnfe 

2556 except IntegrityError as ie: 

2557 logger.error(f"IntegrityErrors in group: {ie}") 

2558 db.rollback() 

2559 

2560 structured_logger.log( 

2561 level="ERROR", 

2562 message="Gateway update failed due to database integrity error", 

2563 event_type="gateway_update_failed", 

2564 component="gateway_service", 

2565 user_email=user_email, 

2566 resource_type="gateway", 

2567 resource_id=gateway_id, 

2568 error=ie, 

2569 ) 

2570 raise ie 

2571 except PermissionError as pe: 

2572 db.rollback() 

2573 

2574 structured_logger.log( 

2575 level="WARNING", 

2576 message="Gateway update failed due to permission error", 

2577 event_type="gateway_update_permission_denied", 

2578 component="gateway_service", 

2579 user_email=user_email, 

2580 resource_type="gateway", 

2581 resource_id=gateway_id, 

2582 error=pe, 

2583 ) 

2584 raise 

2585 except Exception as e: 

2586 db.rollback() 

2587 

2588 structured_logger.log( 

2589 level="ERROR", 

2590 message="Gateway update failed", 

2591 event_type="gateway_update_failed", 

2592 component="gateway_service", 

2593 user_email=user_email, 

2594 resource_type="gateway", 

2595 resource_id=gateway_id, 

2596 error=e, 

2597 ) 

2598 raise GatewayError(f"Failed to update gateway: {str(e)}") 

2599 

async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool = True) -> GatewayRead:
    """Look up a single gateway by ID and return its read model.

    A gateway that exists but is disabled is reported as not found when
    ``include_inactive`` is False, using the same error as a missing row.

    Args:
        db: Database session
        gateway_id: ID of the gateway to look up
        include_inactive: Whether a disabled gateway may be returned

    Returns:
        GatewayRead object produced by ``convert_gateway_to_read``

    Raises:
        GatewayNotFoundError: If no gateway has this ID, or the gateway is
            disabled and ``include_inactive`` is False

    Examples:
        >>> from unittest.mock import MagicMock
        >>> from mcpgateway.schemas import GatewayRead
        >>> service = GatewayService()
        >>> db = MagicMock()
        >>> gateway_mock = MagicMock()
        >>> gateway_mock.enabled = True
        >>> db.execute.return_value.scalar_one_or_none.return_value = gateway_mock
        >>> mocked_gateway_read = MagicMock()
        >>> mocked_gateway_read.masked.return_value = 'gateway_read'
        >>> GatewayRead.model_validate = MagicMock(return_value=mocked_gateway_read)
        >>> import asyncio
        >>> result = asyncio.run(service.get_gateway(db, 'gateway_id'))
        >>> result == 'gateway_read'
        True

        >>> # Test with inactive gateway but include_inactive=True
        >>> gateway_mock.enabled = False
        >>> result_inactive = asyncio.run(service.get_gateway(db, 'gateway_id', include_inactive=True))
        >>> result_inactive == 'gateway_read'
        True

        >>> # Test gateway not found
        >>> db.execute.return_value.scalar_one_or_none.return_value = None
        >>> try:
        ...     asyncio.run(service.get_gateway(db, 'missing_id'))
        ... except GatewayNotFoundError as e:
        ...     'Gateway not found: missing_id' in str(e)
        True

        >>> # Test inactive gateway with include_inactive=False
        >>> gateway_mock.enabled = False
        >>> db.execute.return_value.scalar_one_or_none.return_value = gateway_mock
        >>> try:
        ...     asyncio.run(service.get_gateway(db, 'gateway_id', include_inactive=False))
        ... except GatewayNotFoundError as e:
        ...     'Gateway not found: gateway_id' in str(e)
        True
        >>>
        >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
        >>> asyncio.run(service._http_client.aclose())
    """
    # Build the lookup up front; relationships and the team are eager-loaded
    # so converting to the read model does not trigger N+1 queries.
    lookup = (
        select(DbGateway)
        .options(
            selectinload(DbGateway.tools),
            selectinload(DbGateway.resources),
            selectinload(DbGateway.prompts),
            joinedload(DbGateway.email_team),
        )
        .where(DbGateway.id == gateway_id)
    )
    gateway = db.execute(lookup).scalar_one_or_none()

    # A missing row and a disabled row hidden by include_inactive=False are
    # deliberately indistinguishable to the caller.
    if not gateway or not (gateway.enabled or include_inactive):
        raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

    # Structured logging: record the gateway view
    owning_team = getattr(gateway, "team_id", None)
    gateway_ref = str(gateway.id)
    view_details = {
        "gateway_name": gateway.name,
        "gateway_url": gateway.url,
        "include_inactive": include_inactive,
    }
    structured_logger.log(
        level="INFO",
        message="Gateway retrieved successfully",
        event_type="gateway_viewed",
        component="gateway_service",
        team_id=owning_team,
        resource_type="gateway",
        resource_id=gateway_ref,
        custom_fields=view_details,
    )

    return self.convert_gateway_to_read(gateway)

2691 

2692 async def set_gateway_state(self, db: Session, gateway_id: str, activate: bool, reachable: bool = True, only_update_reachable: bool = False, user_email: Optional[str] = None) -> GatewayRead: 

2693 """ 

2694 Set the activation status of a gateway. 

2695 

2696 Args: 

2697 db: Database session 

2698 gateway_id: Gateway ID 

2699 activate: True to activate, False to deactivate 

2700 reachable: Whether the gateway is reachable 

2701 only_update_reachable: Only update reachable status 

2702 user_email: Optional[str] The email of the user to check if the user has permission to modify. 

2703 

2704 Returns: 

2705 The updated GatewayRead object 

2706 

2707 Raises: 

2708 GatewayNotFoundError: If the gateway is not found 

2709 GatewayError: For other errors 

2710 PermissionError: If user doesn't own the agent. 

2711 """ 

2712 try: 

2713 # Eager-load collections for the gateway. Note: we don't use FOR UPDATE 

2714 # here because _initialize_gateway does network I/O, and holding a row 

2715 # lock during network calls would block other operations and risk timeouts. 

2716 gateway = db.execute( 

2717 select(DbGateway) 

2718 .options( 

2719 selectinload(DbGateway.tools), 

2720 selectinload(DbGateway.resources), 

2721 selectinload(DbGateway.prompts), 

2722 joinedload(DbGateway.email_team), 

2723 ) 

2724 .where(DbGateway.id == gateway_id) 

2725 ).scalar_one_or_none() 

2726 if not gateway: 

2727 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") 

2728 

2729 if user_email: 

2730 # First-Party 

2731 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel 

2732 

2733 permission_service = PermissionService(db) 

2734 if not await permission_service.check_resource_ownership(user_email, gateway): 

2735 raise PermissionError("Only the owner can activate the gateway" if activate else "Only the owner can deactivate the gateway") 

2736 

2737 # Update status if it's different 

2738 if (gateway.enabled != activate) or (gateway.reachable != reachable): 

2739 gateway.enabled = activate 

2740 gateway.reachable = reachable 

2741 gateway.updated_at = datetime.now(timezone.utc) 

2742 # Update tracking 

2743 if activate and reachable: 

2744 self._active_gateways.add(gateway.url) 

2745 

2746 # Initialize empty lists in case initialization fails 

2747 tools_to_add = [] 

2748 resources_to_add = [] 

2749 prompts_to_add = [] 

2750 

2751 # Try to initialize if activating 

2752 try: 

2753 # Handle query_param auth - decrypt and apply to URL 

2754 init_url = gateway.url 

2755 auth_query_params_decrypted: Optional[Dict[str, str]] = None 

2756 if gateway.auth_type == "query_param" and gateway.auth_query_params: 

2757 auth_query_params_decrypted = {} 

2758 for param_key, encrypted_value in gateway.auth_query_params.items(): 

2759 if encrypted_value: 

2760 try: 

2761 decrypted = decode_auth(encrypted_value) 

2762 auth_query_params_decrypted[param_key] = decrypted.get(param_key, "") 

2763 except Exception: 

2764 logger.debug(f"Failed to decrypt query param '{param_key}' for gateway activation") 

2765 if auth_query_params_decrypted: 

2766 init_url = apply_query_param_auth(gateway.url, auth_query_params_decrypted) 

2767 

2768 act_client_cert = getattr(gateway, "client_cert", None) 

2769 act_client_key = getattr(gateway, "client_key", None) 

2770 if act_client_key: 

2771 try: 

2772 _enc = get_encryption_service(settings.auth_encryption_secret) 

2773 act_client_key = _enc.decrypt_secret_or_plaintext(act_client_key) 

2774 except Exception: 

2775 logger.debug("client_key decryption skipped during gateway activation") 

2776 capabilities, tools, resources, prompts = await self._initialize_gateway( 

2777 init_url, 

2778 gateway.auth_value, 

2779 gateway.transport, 

2780 gateway.auth_type, 

2781 gateway.oauth_config, 

2782 auth_query_params=auth_query_params_decrypted, 

2783 oauth_auto_fetch_tool_flag=True, 

2784 client_cert=act_client_cert, 

2785 client_key=act_client_key, 

2786 ) 

2787 new_tool_names = [tool.name for tool in tools] 

2788 new_resource_uris = [resource.uri for resource in resources] 

2789 new_prompt_names = [prompt.name for prompt in prompts] 

2790 

2791 # Update tools, resources, and prompts using helper methods 

2792 tools_to_add = self._update_or_create_tools(db, tools, gateway, "rediscovery") 

2793 resources_to_add = self._update_or_create_resources(db, resources, gateway, "rediscovery") 

2794 prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, "rediscovery") 

2795 

2796 # Log newly added items 

2797 items_added = len(tools_to_add) + len(resources_to_add) + len(prompts_to_add) 

2798 if items_added > 0: 

2799 if tools_to_add: 

2800 logger.info(f"Added {len(tools_to_add)} new tools during gateway reactivation") 

2801 if resources_to_add: 

2802 logger.info(f"Added {len(resources_to_add)} new resources during gateway reactivation") 

2803 if prompts_to_add: 

2804 logger.info(f"Added {len(prompts_to_add)} new prompts during gateway reactivation") 

2805 logger.info(f"Total {items_added} new items added during gateway reactivation") 

2806 

2807 # Count items before cleanup for logging 

2808 

2809 # For authorization_code OAuth gateways, empty responses may indicate 

2810 # a missing auth token rather than genuine removal of all items. 

2811 # Skip stale cleanup to prevent destructive deletion of tools, 

2812 # resources, prompts, and their virtual server associations. 

2813 # Mirrors the guard in _refresh_gateway_tools_resources_prompts. 

2814 is_auth_code_gateway = gateway.oauth_config and isinstance(gateway.oauth_config, dict) and gateway.oauth_config.get("grant_type") == "authorization_code" 

2815 skip_stale_cleanup = not tools and not resources and not prompts and is_auth_code_gateway 

2816 if skip_stale_cleanup: 

2817 logger.debug(f"Empty response from auth_code gateway {gateway.name} during reactivation, preserving existing items") 

2818 

2819 # Bulk delete tools that are no longer available from the gateway 

2820 # Use chunking to avoid SQLite's 999 parameter limit for IN clauses 

2821 stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names] if not skip_stale_cleanup else [] 

2822 if stale_tool_ids: 

2823 # Delete child records first to avoid FK constraint violations 

2824 for i in range(0, len(stale_tool_ids), 500): 

2825 chunk = stale_tool_ids[i : i + 500] 

2826 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk))) 

2827 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk))) 

2828 db.execute(delete(DbTool).where(DbTool.id.in_(chunk))) 

2829 

2830 # Bulk delete resources that are no longer available from the gateway 

2831 stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris] if not skip_stale_cleanup else [] 

2832 if stale_resource_ids: 

2833 # Delete child records first to avoid FK constraint violations 

2834 for i in range(0, len(stale_resource_ids), 500): 

2835 chunk = stale_resource_ids[i : i + 500] 

2836 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk))) 

2837 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk))) 

2838 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk))) 

2839 db.execute(delete(DbResource).where(DbResource.id.in_(chunk))) 

2840 

2841 # Bulk delete prompts that are no longer available from the gateway 

2842 stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names] if not skip_stale_cleanup else [] 

2843 if stale_prompt_ids: 

2844 # Delete child records first to avoid FK constraint violations 

2845 for i in range(0, len(stale_prompt_ids), 500): 

2846 chunk = stale_prompt_ids[i : i + 500] 

2847 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk))) 

2848 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk))) 

2849 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk))) 

2850 

2851 # Expire gateway to clear cached relationships after bulk deletes 

2852 # This prevents SQLAlchemy from trying to re-delete already-deleted items 

2853 if stale_tool_ids or stale_resource_ids or stale_prompt_ids: 

2854 db.expire(gateway) 

2855 

2856 gateway.capabilities = capabilities 

2857 

2858 # Register capabilities for notification-driven actions 

2859 register_gateway_capabilities_for_notifications(gateway.id, capabilities) 

2860 

2861 if not skip_stale_cleanup: 

2862 gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names] # keep only still-valid rows 

2863 gateway.resources = [resource for resource in gateway.resources if resource.uri in new_resource_uris] # keep only still-valid rows 

2864 gateway.prompts = [prompt for prompt in gateway.prompts if prompt.original_name in new_prompt_names] # keep only still-valid rows 

2865 

2866 # Log cleanup results 

2867 tools_removed = len(stale_tool_ids) 

2868 resources_removed = len(stale_resource_ids) 

2869 prompts_removed = len(stale_prompt_ids) 

2870 

2871 if tools_removed > 0: 

2872 logger.info(f"Removed {tools_removed} tools no longer available during gateway reactivation") 

2873 if resources_removed > 0: 

2874 logger.info(f"Removed {resources_removed} resources no longer available during gateway reactivation") 

2875 if prompts_removed > 0: 

2876 logger.info(f"Removed {prompts_removed} prompts no longer available during gateway reactivation") 

2877 

2878 gateway.last_seen = datetime.now(timezone.utc) 

2879 

2880 # Add new items to database session in chunks to prevent lock escalation 

2881 chunk_size = 50 

2882 

2883 if tools_to_add: 

2884 for i in range(0, len(tools_to_add), chunk_size): 

2885 chunk = tools_to_add[i : i + chunk_size] 

2886 db.add_all(chunk) 

2887 db.flush() 

2888 if resources_to_add: 

2889 for i in range(0, len(resources_to_add), chunk_size): 

2890 chunk = resources_to_add[i : i + chunk_size] 

2891 db.add_all(chunk) 

2892 db.flush() 

2893 if prompts_to_add: 

2894 for i in range(0, len(prompts_to_add), chunk_size): 

2895 chunk = prompts_to_add[i : i + chunk_size] 

2896 db.add_all(chunk) 

2897 db.flush() 

2898 except Exception as e: 

2899 logger.warning(f"Failed to initialize reactivated gateway: {e}") 

2900 else: 

2901 self._active_gateways.discard(gateway.url) 

2902 

2903 db.commit() 

2904 db.refresh(gateway) 

2905 

2906 # Invalidate cache after status change 

2907 cache = _get_registry_cache() 

2908 await cache.invalidate_gateways() 

2909 

2910 # Notify Subscribers 

2911 if not gateway.enabled: 

2912 # Inactive 

2913 await self._notify_gateway_deactivated(gateway) 

2914 elif gateway.enabled and not gateway.reachable: 

2915 # Offline (Enabled but Unreachable) 

2916 await self._notify_gateway_offline(gateway) 

2917 else: 

2918 # Active (Enabled and Reachable) 

2919 await self._notify_gateway_activated(gateway) 

2920 

2921 # Bulk update tools - single UPDATE statement instead of N FOR UPDATE locks 

2922 # This prevents lock contention under high concurrent load 

2923 now = datetime.now(timezone.utc) 

2924 if only_update_reachable: 

2925 # Only update reachable status, keep enabled as-is 

2926 tools_result = db.execute(update(DbTool).where(DbTool.gateway_id == gateway_id).where(DbTool.reachable != reachable).values(reachable=reachable, updated_at=now)) 

2927 else: 

2928 # Update both enabled and reachable 

2929 tools_result = db.execute( 

2930 update(DbTool) 

2931 .where(DbTool.gateway_id == gateway_id) 

2932 .where(or_(DbTool.enabled != activate, DbTool.reachable != reachable)) 

2933 .values(enabled=activate, reachable=reachable, updated_at=now) 

2934 ) 

2935 tools_updated = tools_result.rowcount 

2936 

2937 # Commit tool updates 

2938 if tools_updated > 0: 

2939 db.commit() 

2940 

2941 # Invalidate tools cache once after bulk update 

2942 if tools_updated > 0: 

2943 await cache.invalidate_tools() 

2944 tool_lookup_cache = _get_tool_lookup_cache() 

2945 await tool_lookup_cache.invalidate_gateway(str(gateway.id)) 

2946 

2947 # Bulk update prompts when gateway is deactivated/activated (skip for reachability-only updates) 

2948 prompts_updated = 0 

2949 if not only_update_reachable: 

2950 prompts_result = db.execute(update(DbPrompt).where(DbPrompt.gateway_id == gateway_id).where(DbPrompt.enabled != activate).values(enabled=activate, updated_at=now)) 

2951 prompts_updated = prompts_result.rowcount 

2952 if prompts_updated > 0: 

2953 db.commit() 

2954 await cache.invalidate_prompts() 

2955 

2956 # Bulk update resources when gateway is deactivated/activated (skip for reachability-only updates) 

2957 resources_updated = 0 

2958 if not only_update_reachable: 

2959 resources_result = db.execute(update(DbResource).where(DbResource.gateway_id == gateway_id).where(DbResource.enabled != activate).values(enabled=activate, updated_at=now)) 

2960 resources_updated = resources_result.rowcount 

2961 if resources_updated > 0: 

2962 db.commit() 

2963 await cache.invalidate_resources() 

2964 

2965 logger.debug(f"Gateway {SecurityValidator.sanitize_log_message(gateway.name)} bulk state update: {tools_updated} tools, {prompts_updated} prompts, {resources_updated} resources") 

2966 

2967 logger.info(f"Gateway status: {SecurityValidator.sanitize_log_message(gateway.name)} - {'enabled' if activate else 'disabled'} and {'accessible' if reachable else 'inaccessible'}") 

2968 

2969 # Structured logging: Audit trail for gateway state change 

2970 audit_trail.log_action( 

2971 user_id=user_email or "system", 

2972 action="set_gateway_state", 

2973 resource_type="gateway", 

2974 resource_id=str(gateway.id), 

2975 resource_name=gateway.name, 

2976 user_email=user_email, 

2977 team_id=gateway.team_id, 

2978 new_values={ 

2979 "enabled": gateway.enabled, 

2980 "reachable": gateway.reachable, 

2981 }, 

2982 context={ 

2983 "action": "activate" if activate else "deactivate", 

2984 "only_update_reachable": only_update_reachable, 

2985 }, 

2986 db=db, 

2987 ) 

2988 

2989 # Structured logging: Log successful gateway state change 

2990 structured_logger.log( 

2991 level="INFO", 

2992 message=f"Gateway {'activated' if activate else 'deactivated'} successfully", 

2993 event_type="gateway_state_changed", 

2994 component="gateway_service", 

2995 user_email=user_email, 

2996 team_id=gateway.team_id, 

2997 resource_type="gateway", 

2998 resource_id=str(gateway.id), 

2999 custom_fields={ 

3000 "gateway_name": gateway.name, 

3001 "enabled": gateway.enabled, 

3002 "reachable": gateway.reachable, 

3003 }, 

3004 ) 

3005 

3006 return self.convert_gateway_to_read(gateway) 

3007 

3008 except PermissionError as e: 

3009 db.rollback() 

3010 

3011 # Structured logging: Log permission error 

3012 structured_logger.log( 

3013 level="WARNING", 

3014 message="Gateway state change failed due to permission error", 

3015 event_type="gateway_state_change_permission_denied", 

3016 component="gateway_service", 

3017 user_email=user_email, 

3018 resource_type="gateway", 

3019 resource_id=gateway_id, 

3020 error=e, 

3021 ) 

3022 raise e 

3023 except Exception as e: 

3024 db.rollback() 

3025 

3026 # Structured logging: Log generic gateway state change failure 

3027 structured_logger.log( 

3028 level="ERROR", 

3029 message="Gateway state change failed", 

3030 event_type="gateway_state_change_failed", 

3031 component="gateway_service", 

3032 user_email=user_email, 

3033 resource_type="gateway", 

3034 resource_id=gateway_id, 

3035 error=e, 

3036 ) 

3037 raise GatewayError(f"Failed to set gateway state: {str(e)}") 

3038 

async def _notify_gateway_updated(self, gateway: DbGateway) -> None:
    """
    Publish a ``gateway_updated`` event for the given gateway.

    The event carries the gateway's id, name, url, description and enabled
    flag, plus a UTC ISO-8601 timestamp taken when the event is built.

    Args:
        gateway: Gateway whose current attributes are broadcast
    """
    # Snapshot the attributes first so the event body is a plain dict
    snapshot = {
        "id": gateway.id,
        "name": gateway.name,
        "url": gateway.url,
        "description": gateway.description,
        "enabled": gateway.enabled,
    }
    issued_at = datetime.now(timezone.utc).isoformat()
    await self._publish_event({"type": "gateway_updated", "data": snapshot, "timestamp": issued_at})

3058 

async def delete_gateway(self, db: Session, gateway_id: str, user_email: Optional[str] = None) -> None:
    """
    Delete a gateway by its ID, together with its tools, resources and prompts.

    Child rows (metrics, server associations, subscriptions) are removed
    explicitly in chunks before the gateway row itself is deleted. After the
    commit the gateway-related caches are invalidated, subscribers are
    notified, and audit/structured log entries are written.

    Args:
        db: Database session
        gateway_id: Gateway ID
        user_email: Email of user performing deletion (for ownership check)

    Raises:
        GatewayNotFoundError: If the gateway is not found
        PermissionError: If user doesn't own the gateway
        GatewayError: For other deletion errors

    NOTE(review): ``GatewayNotFoundError`` is raised inside the ``try`` body, so
    unless it derives from ``PermissionError`` the generic handler below re-raises
    it wrapped in ``GatewayError`` - confirm which form callers expect.

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> from unittest.mock import MagicMock
        >>> service = GatewayService()
        >>> db = MagicMock()
        >>> gateway = MagicMock()
        >>> db.execute.return_value.scalar_one_or_none.return_value = gateway
        >>> db.delete = MagicMock()
        >>> db.commit = MagicMock()
        >>> service._notify_gateway_deleted = MagicMock()
        >>> import asyncio
        >>> try:
        ...     asyncio.run(service.delete_gateway(db, 'gateway_id', 'user@example.com'))
        ... except Exception:
        ...     pass
        >>>
        >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
        >>> asyncio.run(service._http_client.aclose())
    """
    try:
        # Find gateway with eager loading for deletion to avoid N+1 queries
        gateway = db.execute(
            select(DbGateway)
            .options(
                selectinload(DbGateway.tools),
                selectinload(DbGateway.resources),
                selectinload(DbGateway.prompts),
                joinedload(DbGateway.email_team),
            )
            .where(DbGateway.id == gateway_id)
        ).scalar_one_or_none()

        if not gateway:
            raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

        # Check ownership if user_email provided
        if user_email:
            # First-Party
            from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

            permission_service = PermissionService(db)
            if not await permission_service.check_resource_ownership(user_email, gateway):
                raise PermissionError("Only the owner can delete this gateway")

        # Store gateway info for notification before deletion
        gateway_info = {"id": gateway.id, "name": gateway.name, "url": gateway.url}
        gateway_name = gateway.name
        gateway_team_id = gateway.team_id
        gateway_url = gateway.url  # Store URL before expiring the object

        # Manually delete children first to avoid FK constraint violations
        # (passive_deletes=True means ORM won't auto-cascade, we must do it explicitly)
        # Use chunking to avoid SQLite's 999 parameter limit for IN clauses
        chunk_size = 500
        tool_ids = [t.id for t in gateway.tools]
        resource_ids = [r.id for r in gateway.resources]
        prompt_ids = [p.id for p in gateway.prompts]

        # Delete tool children and tools
        for i in range(0, len(tool_ids), chunk_size):
            chunk = tool_ids[i : i + chunk_size]
            db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
            db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
            db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))

        # Delete resource children and resources
        for i in range(0, len(resource_ids), chunk_size):
            chunk = resource_ids[i : i + chunk_size]
            db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
            db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
            db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
            db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))

        # Delete prompt children and prompts
        for i in range(0, len(prompt_ids), chunk_size):
            chunk = prompt_ids[i : i + chunk_size]
            db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
            db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
            db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))

        # Expire gateway to clear cached relationships after bulk deletes
        db.expire(gateway)

        # Use DELETE with rowcount check for database-agnostic atomic delete
        stmt = delete(DbGateway).where(DbGateway.id == gateway_id)
        result = db.execute(stmt)
        if result.rowcount == 0:
            # Gateway was already deleted by another concurrent request
            raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

        db.commit()

        # Invalidate cache after successful deletion
        cache = _get_registry_cache()
        await cache.invalidate_gateways()
        tool_lookup_cache = _get_tool_lookup_cache()
        await tool_lookup_cache.invalidate_gateway(str(gateway_id))
        # Also invalidate tags cache since gateway tags may have changed
        # First-Party
        from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

        await admin_stats_cache.invalidate_tags()

        # Invalidate loopback passthrough cache when a gateway is deleted (#3640)
        # First-Party
        from mcpgateway.utils.passthrough_headers import invalidate_passthrough_header_caches  # pylint: disable=import-outside-toplevel

        invalidate_passthrough_header_caches()

        # Update tracking
        self._active_gateways.discard(gateway_url)

        # Notify subscribers
        await self._notify_gateway_deleted(gateway_info)

        # Gateway names are user-supplied; sanitize before logging, as the other log calls in this file do
        logger.info(f"Permanently deleted gateway: {SecurityValidator.sanitize_log_message(gateway_name)}")

        # Structured logging: Audit trail for gateway deletion
        audit_trail.log_action(
            user_id=user_email or "system",
            action="delete_gateway",
            resource_type="gateway",
            resource_id=str(gateway_info["id"]),
            resource_name=gateway_name,
            user_email=user_email,
            team_id=gateway_team_id,
            old_values={
                "name": gateway_name,
                "url": gateway_info["url"],
            },
            db=db,
        )

        # Structured logging: Log successful gateway deletion
        structured_logger.log(
            level="INFO",
            message="Gateway deleted successfully",
            event_type="gateway_deleted",
            component="gateway_service",
            user_email=user_email,
            team_id=gateway_team_id,
            resource_type="gateway",
            resource_id=str(gateway_info["id"]),
            custom_fields={
                "gateway_name": gateway_name,
                "gateway_url": gateway_info["url"],
            },
        )

    except PermissionError as pe:
        db.rollback()

        # Structured logging: Log permission error
        structured_logger.log(
            level="WARNING",
            message="Gateway deletion failed due to permission error",
            event_type="gateway_delete_permission_denied",
            component="gateway_service",
            user_email=user_email,
            resource_type="gateway",
            resource_id=gateway_id,
            error=pe,
        )
        raise
    except Exception as e:
        db.rollback()

        # Structured logging: Log generic gateway deletion failure
        structured_logger.log(
            level="ERROR",
            message="Gateway deletion failed",
            event_type="gateway_deletion_failed",
            component="gateway_service",
            user_email=user_email,
            resource_type="gateway",
            resource_id=gateway_id,
            error=e,
        )
        # Chain explicitly so the original failure is recorded as the direct cause
        raise GatewayError(f"Failed to delete gateway: {str(e)}") from e

3254 

async def _handle_gateway_failure(self, gateway: DbGateway) -> None:
    """Record a failed health check and react once the failure threshold is hit.

    Each call for an enabled, reachable gateway increments its entry in
    ``self._gateway_failure_counts``. When the count reaches
    ``GW_FAILURE_THRESHOLD`` the gateway state is updated through
    ``set_gateway_state`` with ``reachable=False`` (reachability-only update)
    and the counter is reset to zero. A threshold of ``-1`` turns this
    handling off entirely.

    Args:
        gateway: The gateway object that failed its health check.

    Returns:
        None

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> service = GatewayService()
        >>> gateway = type('Gateway', (), {
        ...     'id': 'gw1', 'name': 'test_gw', 'enabled': True, 'reachable': True
        ... })()
        >>> service._gateway_failure_counts = {}
        >>> import asyncio
        >>> # Test failure counting
        >>> asyncio.run(service._handle_gateway_failure(gateway))  # doctest: +ELLIPSIS
        >>> service._gateway_failure_counts['gw1'] >= 1
        True

        >>> # Test disabled gateway (no action)
        >>> gateway.enabled = False
        >>> old_count = service._gateway_failure_counts.get('gw1', 0)
        >>> asyncio.run(service._handle_gateway_failure(gateway))  # doctest: +ELLIPSIS
        >>> service._gateway_failure_counts.get('gw1', 0) == old_count
        True
    """
    # Nothing to do when failure handling is switched off, or when the gateway
    # is already inactive or already marked unreachable.
    if GW_FAILURE_THRESHOLD == -1 or not gateway.enabled or not gateway.reachable:
        return

    failures = self._gateway_failure_counts.get(gateway.id, 0) + 1
    self._gateway_failure_counts[gateway.id] = failures

    logger.warning(f"Gateway {SecurityValidator.sanitize_log_message(gateway.name)} failed health check {failures} time(s).")

    if failures < GW_FAILURE_THRESHOLD:
        return

    logger.error(f"Gateway {SecurityValidator.sanitize_log_message(gateway.name)} failed {GW_FAILURE_THRESHOLD} times. Deactivating...")
    session_factory = cast(Any, SessionLocal)
    with session_factory() as db:
        await self.set_gateway_state(db, gateway.id, activate=True, reachable=False, only_update_reachable=True)
        # Start counting afresh once the state change has been applied
        self._gateway_failure_counts[gateway.id] = 0

3304 

async def check_health_of_gateways(self, gateways: List[DbGateway], user_email: Optional[str] = None) -> bool:
    """Check health of a batch of gateways.

    Performs an asynchronous health-check for each gateway in `gateways`,
    delegating the per-gateway work to `_check_single_gateway_health`.
    Gateways are processed in chunks sized by an adaptive concurrency limit,
    and each individual check is bounded by
    `settings.gateway_health_check_timeout`; a timeout is treated as a failed
    health check and passed to `_handle_gateway_failure`. Gateways whose
    `auth_type` is "one_time_auth" are skipped.

    NOTE: This method intentionally does NOT take a db parameter.
    DB access uses fresh_db_session() only when needed, avoiding holding
    connections during HTTP calls to MCP servers.

    Args:
        gateways: List of DbGateway objects to check.
        user_email: Optional MCP gateway user email used to retrieve
            stored OAuth tokens for gateways using the
            "authorization_code" grant type. If not provided, authorization
            code flows that require a user token will be treated as failed.

    Returns:
        bool: True when the health-check batch completes. This return
        value indicates completion of the checks, not that every gateway
        was healthy. Individual gateway failures are handled internally
        (via _handle_gateway_failure and status updates).

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> from unittest.mock import MagicMock
        >>> service = GatewayService()
        >>> gateways = [MagicMock()]
        >>> gateways[0].ca_certificate = None
        >>> import asyncio
        >>> result = asyncio.run(service.check_health_of_gateways(gateways))
        >>> isinstance(result, bool)
        True

        >>> # Test empty gateway list
        >>> empty_result = asyncio.run(service.check_health_of_gateways([]))
        >>> empty_result
        True

        >>> # Test multiple gateways (basic smoke)
        >>> multiple_gateways = [MagicMock(), MagicMock(), MagicMock()]
        >>> for i, gw in enumerate(multiple_gateways):
        ...     gw.name = f"gateway_{i}"
        ...     gw.url = f"http://gateway{i}.example.com"
        ...     gw.transport = "SSE"
        ...     gw.enabled = True
        ...     gw.reachable = True
        ...     gw.auth_value = {}
        ...     gw.ca_certificate = None
        >>> multi_result = asyncio.run(service.check_health_of_gateways(multiple_gateways))
        >>> isinstance(multi_result, bool)
        True
    """
    start_time = time.monotonic()
    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to 1 so the adaptive limit never fails with a TypeError.
    cpu_count = os.cpu_count() or 1
    concurrency_limit = min(settings.max_concurrent_health_checks, max(10, cpu_count * 5))  # adaptive concurrency
    semaphore = asyncio.Semaphore(concurrency_limit)

    async def limited_check(gateway: DbGateway):
        """
        Checks the health of a single gateway while respecting a concurrency limit.

        The check runs under the shared semaphore and is bounded by
        ``settings.gateway_health_check_timeout``. A timeout is logged and
        handed to ``_handle_gateway_failure`` instead of being re-raised.

        Args:
            gateway (DbGateway): The database gateway whose health is to be checked.

        Raises:
            Any exceptions other than the timeout raised during the health check will be propagated to the caller.
        """
        async with semaphore:
            try:
                await asyncio.wait_for(
                    self._check_single_gateway_health(gateway, user_email),
                    timeout=settings.gateway_health_check_timeout,
                )
            except asyncio.TimeoutError:
                logger.warning(f"Gateway {getattr(gateway, 'name', 'unknown')} health check timed out after {settings.gateway_health_check_timeout}s")
                # Treat timeout as a failed health check
                await self._handle_gateway_failure(gateway)

    # Create trace span for health check batch
    with create_span("gateway.health_check_batch", {"gateway.count": len(gateways), "check.type": "health"}) as batch_span:
        # Chunk processing to avoid overload
        if not gateways:
            return True
        chunk_size = concurrency_limit
        for i in range(0, len(gateways), chunk_size):
            # batch will be a sublist of gateways from index i to i + chunk_size
            batch = gateways[i : i + chunk_size]

            # Each task is a health check for a gateway in the batch, excluding those with auth_type == "one_time_auth"
            tasks = [limited_check(gw) for gw in batch if gw.auth_type != "one_time_auth"]

            # Execute all health checks concurrently; exceptions are collected, not raised
            await asyncio.gather(*tasks, return_exceptions=True)
            await asyncio.sleep(0.05)  # small pause prevents network saturation

        elapsed = time.monotonic() - start_time

        if batch_span:
            set_span_attribute(batch_span, "check.duration_ms", int(elapsed * 1000))
            set_span_attribute(batch_span, "check.completed", True)

        logger.debug(f"Health check batch completed for {len(gateways)} gateways in {elapsed:.2f}s")

    return True

3421 

3422 async def _check_single_gateway_health(self, gateway: DbGateway, user_email: Optional[str] = None) -> None: 

3423 """Check health of a single gateway. 

3424 

3425 NOTE: This method intentionally does NOT take a db parameter. 

3426 DB access uses fresh_db_session() only when needed, avoiding holding 

3427 connections during HTTP calls to MCP servers. 

3428 

3429 Args: 

3430 gateway: Gateway to check (may be detached from session) 

3431 user_email: Optional user email for OAuth token lookup 

3432 """ 

3433 # Extract gateway data upfront (gateway may be detached from session) 

3434 gateway_id = gateway.id 

3435 gateway_name = gateway.name 

3436 gateway_url = gateway.url 

3437 gateway_transport = gateway.transport 

3438 gateway_enabled = gateway.enabled 

3439 gateway_reachable = gateway.reachable 

3440 gateway_ca_certificate = gateway.ca_certificate 

3441 gateway_ca_certificate_sig = gateway.ca_certificate_sig 

3442 gateway_auth_type = gateway.auth_type 

3443 gateway_oauth_config = gateway.oauth_config 

3444 gateway_auth_value = gateway.auth_value 

3445 gateway_auth_query_params = gateway.auth_query_params 

3446 health_client_cert = getattr(gateway, "client_cert", None) 

3447 health_client_key = getattr(gateway, "client_key", None) 

3448 

3449 # Handle query_param auth - decrypt and apply to URL for health check 

3450 auth_query_params_decrypted: Optional[Dict[str, str]] = None 

3451 # Preserve the base URL (without auth query params) for classification lookups. 

3452 # Classification uses Gateway.url from the DB, so poll-state keys must match. 

3453 gateway_base_url = gateway_url 

3454 if gateway_auth_type == "query_param" and gateway_auth_query_params: 

3455 auth_query_params_decrypted = {} 

3456 for param_key, encrypted_value in gateway_auth_query_params.items(): 

3457 if encrypted_value: 

3458 try: 

3459 decrypted = decode_auth(encrypted_value) 

3460 auth_query_params_decrypted[param_key] = decrypted.get(param_key, "") 

3461 except Exception: 

3462 logger.debug(f"Failed to decrypt query param '{param_key}' for health check") 

3463 if auth_query_params_decrypted: 

3464 gateway_url = apply_query_param_auth(gateway_url, auth_query_params_decrypted) 

3465 

3466 # Sanitize URL for logging/telemetry (redacts sensitive query params) 

3467 gateway_url_sanitized = sanitize_url_for_logging(gateway_url, auth_query_params_decrypted) 

3468 

3469 # NOTE: Health checks always run regardless of hot/cold classification. 

3470 # Classification only gates auto-refresh (tool discovery), not health monitoring. 

3471 # Skipping health checks would blind the gateway to outages on cold servers. 

3472 

3473 # Create span for individual gateway health check 

3474 with create_span( 

3475 "gateway.health_check", 

3476 { 

3477 "gateway.name": gateway_name, 

3478 "gateway.id": str(gateway_id), 

3479 "gateway.url": gateway_url_sanitized, 

3480 "gateway.transport": gateway_transport, 

3481 "gateway.enabled": gateway_enabled, 

3482 "http.method": "GET", 

3483 "http.url": gateway_url_sanitized, 

3484 }, 

3485 ) as span: 

3486 valid = False 

3487 if gateway_ca_certificate: 

3488 if settings.enable_ed25519_signing: 

3489 public_key_pem = settings.ed25519_public_key 

3490 valid = validate_signature(gateway_ca_certificate.encode(), gateway_ca_certificate_sig, public_key_pem) 

3491 else: 

3492 valid = True 

3493 

3494 # Decrypt client_key for health check mTLS 

3495 _hc_client_key = health_client_key 

3496 if _hc_client_key: 

3497 try: 

3498 _enc = get_encryption_service(settings.auth_encryption_secret) 

3499 _hc_client_key = _enc.decrypt_secret_or_plaintext(_hc_client_key) 

3500 except Exception: 

3501 logger.debug("client_key decryption skipped during health check") 

3502 

3503 if gateway_url and gateway_url.lower().startswith("http://"): 

3504 ssl_context = None 

3505 elif valid and gateway_ca_certificate: 

3506 ssl_context = get_cached_ssl_context(gateway_ca_certificate, client_cert=health_client_cert, client_key=_hc_client_key) 

3507 else: 

3508 ssl_context = None 

3509 

3510 def get_httpx_client_factory( 

3511 headers: dict[str, str] | None = None, 

3512 timeout: httpx.Timeout | None = None, 

3513 auth: httpx.Auth | None = None, 

3514 ) -> httpx.AsyncClient: 

3515 """Factory function to create httpx.AsyncClient with optional CA certificate. 

3516 

3517 Args: 

3518 headers: Optional headers for the client 

3519 timeout: Optional timeout for the client 

3520 auth: Optional auth for the client 

3521 

3522 Returns: 

3523 httpx.AsyncClient: Configured HTTPX async client 

3524 """ 

3525 return httpx.AsyncClient( 

3526 verify=ssl_context if ssl_context else get_default_verify(), 

3527 follow_redirects=True, 

3528 headers=headers, 

3529 timeout=timeout if timeout else get_http_timeout(), 

3530 auth=auth, 

3531 limits=httpx.Limits( 

3532 max_connections=settings.httpx_max_connections, 

3533 max_keepalive_connections=settings.httpx_max_keepalive_connections, 

3534 keepalive_expiry=settings.httpx_keepalive_expiry, 

3535 ), 

3536 ) 

3537 

3538 # Use isolated client for gateway health checks (each gateway may have custom CA cert) 

3539 # Use admin timeout for health checks (fail fast, don't wait 120s for slow upstreams) 

3540 # Pass ssl_context if present, otherwise let get_isolated_http_client use skip_ssl_verify setting 

3541 async with get_isolated_http_client(timeout=settings.httpx_admin_read_timeout, verify=ssl_context) as client: 

3542 logger.debug(f"Checking health of gateway: {gateway_name} ({gateway_url_sanitized})") 

3543 try: 

3544 # Handle different authentication types 

3545 headers = {} 

3546 

3547 if gateway_auth_type == "oauth" and gateway_oauth_config: 

3548 grant_type = gateway_oauth_config.get("grant_type", "client_credentials") 

3549 

3550 if grant_type == "authorization_code": 

3551 # For Authorization Code flow, try to get stored tokens 

3552 try: 

3553 # First-Party 

3554 from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel 

3555 

3556 # Use fresh session for OAuth token lookup 

3557 with fresh_db_session() as token_db: 

3558 token_storage = TokenStorageService(token_db) 

3559 

3560 # Get user-specific OAuth token 

3561 if not user_email: 

3562 if span: 

3563 set_span_attribute(span, "health.status", "unhealthy") 

3564 set_span_error(span, "User email required for OAuth token") 

3565 await self._handle_gateway_failure(gateway) 

3566 return 

3567 

3568 access_token = await token_storage.get_user_token(gateway_id, user_email) 

3569 

3570 if access_token: 

3571 headers["Authorization"] = f"Bearer {access_token}" 

3572 else: 

3573 if span: 

3574 set_span_attribute(span, "health.status", "unhealthy") 

3575 set_span_error(span, "No valid OAuth token for user") 

3576 await self._handle_gateway_failure(gateway) 

3577 return 

3578 except Exception as e: 

3579 logger.error(f"Failed to obtain stored OAuth token for gateway {gateway_name}: {e}") 

3580 if span: 

3581 set_span_attribute(span, "health.status", "unhealthy") 

3582 set_span_error(span, "Failed to obtain stored OAuth token") 

3583 await self._handle_gateway_failure(gateway) 

3584 return 

3585 else: 

3586 # For Client Credentials flow, get token directly 

3587 try: 

3588 access_token = await self.oauth_manager.get_access_token(gateway_oauth_config) 

3589 headers["Authorization"] = f"Bearer {access_token}" 

3590 except Exception as e: 

3591 if span: 

3592 set_span_attribute(span, "health.status", "unhealthy") 

3593 set_span_error(span, e) 

3594 await self._handle_gateway_failure(gateway) 

3595 return 

3596 else: 

3597 # Handle non-OAuth authentication (existing logic) 

3598 auth_data = gateway_auth_value or {} 

3599 if isinstance(auth_data, str): 

3600 headers = decode_auth(auth_data) 

3601 elif isinstance(auth_data, dict): 

3602 headers = {str(k): str(v) for k, v in auth_data.items()} 

3603 else: 

3604 headers = {} 

3605 

3606 # Perform the GET and raise on 4xx/5xx 

3607 if (gateway_transport).lower() == "sse": 

3608 timeout = httpx.Timeout(settings.health_check_timeout) 

3609 async with client.stream("GET", gateway_url, headers=headers, timeout=timeout) as response: 

3610 # This will raise immediately if status is 4xx/5xx 

3611 response.raise_for_status() 

3612 if span: 

3613 set_span_attribute(span, "http.status_code", response.status_code) 

3614 elif (gateway_transport).lower() == "streamablehttp": 

3615 # Use session pool if enabled for faster health checks 

3616 use_pool = False 

3617 pool = None 

3618 if settings.mcp_session_pool_enabled: 

3619 try: 

3620 pool = get_mcp_session_pool() 

3621 use_pool = True 

3622 except RuntimeError: 

3623 # Pool not initialized (e.g., in tests), fall back to per-call sessions 

3624 pass 

3625 

3626 if use_pool and pool is not None: 

3627 # Health checks are system operations, not user-driven. 

3628 # Use system identity to isolate from user sessions. 

3629 async with pool.session( 

3630 url=gateway_url, 

3631 headers=headers, 

3632 transport_type=TransportType.STREAMABLE_HTTP, 

3633 httpx_client_factory=get_httpx_client_factory, 

3634 user_identity="_system_health_check", 

3635 gateway_id=gateway_id, 

3636 ) as pooled: 

3637 # Optional explicit RPC verification (off by default for performance). 

3638 # Pool's internal staleness check handles health via _validate_session. 

3639 if settings.mcp_session_pool_explicit_health_rpc: 

3640 with anyio.fail_after(settings.health_check_timeout): 

3641 await pooled.session.list_tools() 

3642 else: 

3643 async with streamablehttp_client(url=gateway_url, headers=headers, timeout=settings.health_check_timeout, httpx_client_factory=get_httpx_client_factory) as ( 

3644 read_stream, 

3645 write_stream, 

3646 _get_session_id, 

3647 ): 

3648 async with ClientSession(read_stream, write_stream) as session: 

3649 # Initialize the session 

3650 response = await session.initialize() 

3651 

3652 # Reactivate gateway if it was previously inactive and health check passed now 

3653 if gateway_enabled and not gateway_reachable: 

3654 logger.info(f"Reactivating gateway: {gateway_name}, as it is healthy now") 

3655 with cast(Any, SessionLocal)() as status_db: 

3656 await self.set_gateway_state(status_db, gateway_id, activate=True, reachable=True, only_update_reachable=True) 

3657 

3658 # Update last_seen with fresh session (gateway object is detached) 

3659 try: 

3660 with fresh_db_session() as update_db: 

3661 db_gateway = update_db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none() 

3662 if db_gateway: 

3663 db_gateway.last_seen = datetime.now(timezone.utc) 

3664 update_db.commit() 

3665 except Exception as update_error: 

3666 logger.warning(f"Failed to update last_seen for gateway {gateway_name}: {update_error}") 

3667 

3668 # Auto-refresh tools/resources/prompts if enabled 

3669 should_auto_refresh = False 

3670 if settings.auto_refresh_servers: 

3671 # Hot/cold classification: Check if this server should have tools refreshed now 

3672 if self._classification_service: 

3673 try: 

3674 should_auto_refresh = await self._classification_service.should_poll_server(gateway_base_url, "tool_discovery", gateway_id=str(gateway_id)) 

3675 if not should_auto_refresh: 

3676 logger.debug(f"Skipping auto-refresh for {SecurityValidator.sanitize_log_message(gateway_name)}: " f"not yet due based on hot/cold classification") 

3677 except Exception as e: 

3678 # Fail open: proceed with auto-refresh if classification check fails 

3679 logger.warning(f"Classification check failed for {SecurityValidator.sanitize_log_message(gateway_name)}, proceeding with auto-refresh (fail-open): {e}") 

3680 should_auto_refresh = True 

3681 else: 

3682 should_auto_refresh = True 

3683 

3684 if should_auto_refresh: 

3685 try: 

3686 # Throttling: Check if refresh is needed based on last_refresh_at 

3687 refresh_needed = True 

3688 if gateway.last_refresh_at: 

3689 # Default to config value if configured interval is missing 

3690 

3691 last_refresh = gateway.last_refresh_at 

3692 if last_refresh.tzinfo is None: 

3693 last_refresh = last_refresh.replace(tzinfo=timezone.utc) 

3694 

3695 # Use per-gateway interval if set, otherwise fall back to global default 

3696 refresh_interval = getattr(settings, "gateway_auto_refresh_interval", 300) 

3697 if gateway.refresh_interval_seconds is not None: 

3698 refresh_interval = gateway.refresh_interval_seconds 

3699 

3700 time_since_refresh = (datetime.now(timezone.utc) - last_refresh).total_seconds() 

3701 

3702 if time_since_refresh < refresh_interval: 

3703 refresh_needed = False 

3704 logger.debug(f"Skipping auto-refresh for {gateway_name}: last refreshed {int(time_since_refresh)}s ago") 

3705 

3706 if refresh_needed: 

3707 # Locking: Try to acquire lock to avoid conflict with manual refresh 

3708 lock = self._get_refresh_lock(gateway_id) 

3709 if not lock.locked(): 

3710 # Acquire lock to prevent concurrent manual refresh 

3711 async with lock: 

3712 await self._refresh_gateway_tools_resources_prompts( 

3713 gateway_id=gateway_id, 

3714 _user_email=user_email, 

3715 created_via="health_check", 

3716 pre_auth_headers=headers if headers else None, 

3717 gateway=gateway, 

3718 ) 

3719 # mark_poll_completed is called inside _refresh_gateway_tools_resources_prompts 

3720 else: 

3721 logger.debug(f"Skipping auto-refresh for {gateway_name}: lock held (likely manual refresh in progress)") 

3722 except Exception as refresh_error: 

3723 logger.warning(f"Failed to refresh tools for gateway {gateway_name}: {refresh_error}") 

3724 

3725 if span: 

3726 set_span_attribute(span, "health.status", "healthy") 

3727 set_span_attribute(span, "success", True) 

3728 

3729 except Exception as e: 

3730 if span: 

3731 set_span_attribute(span, "health.status", "unhealthy") 

3732 set_span_error(span, e) 

3733 

3734 # Set the logger as debug as this check happens for each interval 

3735 logger.debug(f"Health check failed for gateway {gateway_name}: {e}") 

3736 await self._handle_gateway_failure(gateway) 

3737 

async def aggregate_capabilities(self, db: Session) -> Dict[str, Any]:
    """
    Aggregate capabilities across all enabled gateways.

    Starts from the built-in defaults (prompts, resources, tools, logging) and
    folds in the ``capabilities`` mapping of every enabled gateway. Keys that
    are new are added; dict values for an existing key are merged into it.
    Non-dict values for a key that is already present are ignored.

    Gateway capability values are copied before being stored or merged, so the
    aggregation never mutates a gateway's own ``capabilities`` object. If a key
    currently holds a non-dict value and a later gateway reports a dict for it,
    the dict replaces the earlier value instead of raising.

    Args:
        db: Database session

    Returns:
        Dictionary of aggregated capabilities

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> from unittest.mock import MagicMock
        >>> service = GatewayService()
        >>> db = MagicMock()
        >>> gateway_mock = MagicMock()
        >>> gateway_mock.capabilities = {"tools": {"listChanged": True}, "custom": {"feature": True}}
        >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway_mock]
        >>> import asyncio
        >>> result = asyncio.run(service.aggregate_capabilities(db))
        >>> isinstance(result, dict)
        True
        >>> 'prompts' in result
        True
        >>> 'resources' in result
        True
        >>> 'tools' in result
        True
        >>> 'logging' in result
        True
        >>> result['prompts']['listChanged']
        True
        >>> result['resources']['subscribe']
        True
        >>> result['resources']['listChanged']
        True
        >>> result['tools']['listChanged']
        True
        >>> isinstance(result['logging'], dict)
        True

        >>> # Test with no gateways
        >>> db.execute.return_value.scalars.return_value.all.return_value = []
        >>> empty_result = asyncio.run(service.aggregate_capabilities(db))
        >>> isinstance(empty_result, dict)
        True
        >>> 'tools' in empty_result
        True

        >>> # Test capability merging
        >>> gateway1 = MagicMock()
        >>> gateway1.capabilities = {"tools": {"feature1": True}}
        >>> gateway2 = MagicMock()
        >>> gateway2.capabilities = {"tools": {"feature2": True}}
        >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway1, gateway2]
        >>> merged_result = asyncio.run(service.aggregate_capabilities(db))
        >>> merged_result['tools']['listChanged']  # Default capability
        True

        >>> # Merging never mutates a gateway's own capability mapping
        >>> gateway1.capabilities = {"custom": {"a": True}}
        >>> gateway2.capabilities = {"custom": {"b": True}}
        >>> isolated = asyncio.run(service.aggregate_capabilities(db))
        >>> isolated['custom'] == {"a": True, "b": True}
        True
        >>> gateway1.capabilities == {"custom": {"a": True}}
        True
    """
    # Standard
    import copy  # local import: only needed here, keeps the block self-contained

    capabilities: Dict[str, Any] = {
        "prompts": {"listChanged": True},
        "resources": {"subscribe": True, "listChanged": True},
        "tools": {"listChanged": True},
        "logging": {},
    }

    # Get all active gateways
    gateways = db.execute(select(DbGateway).where(DbGateway.enabled)).scalars().all()

    # Combine capabilities
    for gateway in gateways:
        if not gateway.capabilities:
            continue
        for key, value in gateway.capabilities.items():
            if key not in capabilities:
                # Copy so later merges cannot write through into the gateway's own data
                capabilities[key] = copy.deepcopy(value)
            elif isinstance(value, dict):
                current = capabilities[key]
                if isinstance(current, dict):
                    current.update(copy.deepcopy(value))
                else:
                    # Existing entry is not mergeable; the structured value supersedes it
                    capabilities[key] = copy.deepcopy(value)

    return capabilities

3817 

async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
    """Stream gateway events to the caller.

    Delegates to the event service's ``subscribe_events()`` and re-yields
    every event it produces, in order, until that stream ends or this
    generator is closed.

    Yields:
        Dict[str, Any]: Gateway event messages as produced by the event service

    Examples:
        >>> service = GatewayService()
        >>> import asyncio
        >>> from unittest.mock import MagicMock
        >>> # Create a mock async generator for the event service
        >>> async def mock_event_gen():
        ...     yield {"type": "test_event", "data": "payload"}
        >>>
        >>> # Mock the event service to return our generator
        >>> service._event_service = MagicMock()
        >>> service._event_service.subscribe_events.return_value = mock_event_gen()
        >>>
        >>> # Test the subscription
        >>> async def test_sub():
        ...     async for event in service.subscribe_events():
        ...         return event
        >>>
        >>> result = asyncio.run(test_sub())
        >>> result
        {'type': 'test_event', 'data': 'payload'}
    """
    event_stream = self._event_service.subscribe_events()
    async for message in event_stream:
        yield message

3851 

async def _initialize_gateway(
    self,
    url: str,
    authentication: Optional[Dict[str, str]] = None,
    transport: str = "SSE",
    auth_type: Optional[str] = None,
    oauth_config: Optional[Dict[str, Any]] = None,
    ca_certificate: Optional[bytes] = None,
    pre_auth_headers: Optional[Dict[str, str]] = None,
    include_resources: bool = True,
    include_prompts: bool = True,
    auth_query_params: Optional[Dict[str, str]] = None,
    oauth_auto_fetch_tool_flag: Optional[bool] = False,
    client_cert: Optional[str] = None,
    client_key: Optional[str] = None,
) -> tuple[Dict[str, Any], List[ToolCreate], List[ResourceCreate], List[PromptCreate]]:
    """Initialize connection to a gateway and retrieve its capabilities.

    Resolves the authentication headers (pre-authenticated headers take
    precedence, then OAuth, then the supplied ``authentication``), then
    connects over the requested transport and returns whatever the connect
    helper reports. A transport other than "SSE"/"StreamableHTTP"
    (case-insensitive) performs no connection and yields empty results.

    Any failure is logged with a sanitized URL/message and re-raised as
    ``GatewayConnectionError`` chained to the original exception.

    Args:
        url: Gateway URL to connect to
        authentication: Optional authentication headers for the connection
        transport: Transport protocol - "SSE" or "StreamableHTTP"
        auth_type: Authentication type - "basic", "bearer", "authheaders", "oauth", "query_param" or None
        oauth_config: OAuth configuration if auth_type is "oauth"
        ca_certificate: CA certificate for SSL verification
        pre_auth_headers: Pre-authenticated headers to skip OAuth token fetch (for reuse)
        include_resources: Whether to include resources in the fetch
        include_prompts: Whether to include prompts in the fetch
        auth_query_params: Query param names for URL sanitization in error logs (decrypted values)
        oauth_auto_fetch_tool_flag: Whether to skip the early return for OAuth Authorization Code flow.
            When False (default), auth_code gateways return empty lists immediately (for health checks).
            When True, attempts to connect even for auth_code gateways (for activation after user authorization).
        client_cert: Optional client certificate path or PEM for mTLS
        client_key: Optional client private key path or PEM for mTLS

    Returns:
        tuple[Dict[str, Any], List[ToolCreate], List[ResourceCreate], List[PromptCreate]]:
            Capabilities dictionary, list of ToolCreate objects, list of ResourceCreate objects, and list of PromptCreate objects

    Raises:
        GatewayConnectionError: If connection or initialization fails

    Examples:
        >>> service = GatewayService()
        >>> # Test parameter validation
        >>> import asyncio
        >>> from unittest.mock import AsyncMock
        >>> # Avoid opening a real SSE connection in doctests (it can leak anyio streams on failure paths)
        >>> service.connect_to_sse_server = AsyncMock(side_effect=GatewayConnectionError("boom"))
        >>> async def test_params():
        ...     try:
        ...         await service._initialize_gateway("hello//")
        ...     except Exception as e:
        ...         return isinstance(e, GatewayConnectionError) or "Failed" in str(e)

        >>> asyncio.run(test_params())
        True

        >>> # Test default parameters
        >>> hasattr(service, '_initialize_gateway')
        True
        >>> import inspect
        >>> sig = inspect.signature(service._initialize_gateway)
        >>> sig.parameters['transport'].default
        'SSE'
        >>> sig.parameters['authentication'].default is None
        True
        >>>
        >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
        >>> asyncio.run(service._http_client.aclose())
    """
    try:
        if authentication is None:
            authentication = {}

        # Use pre-authenticated headers if provided (avoids duplicate OAuth token fetch)
        if pre_auth_headers:
            authentication = pre_auth_headers
        # Handle OAuth authentication
        elif auth_type == "oauth" and oauth_config:
            grant_type = oauth_config.get("grant_type", "client_credentials")

            if grant_type == "authorization_code":
                if not oauth_auto_fetch_tool_flag:
                    # For Authorization Code flow during health checks, we can't initialize immediately
                    # because we need user consent. Just store the configuration
                    # and let the user complete the OAuth flow later.
                    logger.info("""OAuth Authorization Code flow configured for gateway. User must complete authorization before gateway can be used.""")

                    # Skip MCP server connection for Authorization Code flow
                    # Tools will be fetched after OAuth completion
                    return {}, [], [], []
                # When flag is True (activation), skip token fetch but try to connect
                # This allows activation to proceed - actual auth happens during tool invocation
                logger.debug("OAuth Authorization Code gateway activation - skipping token fetch")
            elif grant_type == "client_credentials":
                # For Client Credentials flow, we can get the token immediately
                try:
                    logger.debug("Obtaining OAuth access token for Client Credentials flow")
                    access_token = await self.oauth_manager.get_access_token(oauth_config)
                    authentication = {"Authorization": f"Bearer {access_token}"}
                except Exception as e:
                    logger.error(f"Failed to obtain OAuth access token: {e}")
                    # Chain explicitly so the token failure is recorded as the cause
                    raise GatewayConnectionError(f"OAuth authentication failed: {str(e)}") from e

        # Defaults returned unchanged when the transport matches neither branch below
        capabilities = {}
        tools = []
        resources = []
        prompts = []
        if auth_type in ("basic", "bearer", "authheaders") and isinstance(authentication, str):
            authentication = decode_auth(authentication)
        if transport.lower() == "sse":
            capabilities, tools, resources, prompts = await self.connect_to_sse_server(
                url, authentication, ca_certificate, include_prompts, include_resources, auth_query_params, client_cert=client_cert, client_key=client_key
            )
        elif transport.lower() == "streamablehttp":
            capabilities, tools, resources, prompts = await self.connect_to_streamablehttp_server(
                url, authentication, ca_certificate, include_prompts, include_resources, auth_query_params, client_cert=client_cert, client_key=client_key
            )

        return capabilities, tools, resources, prompts
    except Exception as e:
        # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup;
        # descend to the first leaf so the message names the real failure.
        root_cause: BaseException = e
        while isinstance(root_cause, BaseExceptionGroup) and root_cause.exceptions:
            root_cause = root_cause.exceptions[0]
        sanitized_url = sanitize_url_for_logging(url, auth_query_params)
        raw_error = str(root_cause) or type(root_cause).__name__
        sanitized_error = sanitize_exception_message(raw_error, auth_query_params)
        logger.error(f"Gateway initialization failed for {sanitized_url}: {sanitized_error}", exc_info=True)
        raise GatewayConnectionError(f"Failed to initialize gateway at {sanitized_url}: {sanitized_error}") from e

3991 

def _get_gateways(self, include_inactive: bool = True) -> list[DbGateway]:
    """Load gateways from the database using a short-lived session.

    Synchronous on purpose: callers in this module run it via
    ``asyncio.to_thread``. The session is opened and closed inside this call,
    so the returned objects outlive the session that loaded them.

    Args:
        include_inactive: When True (default) return every gateway; when False
            return only gateways whose ``enabled`` flag is set

    Returns:
        List[DbGateway]: Gateways matching the filter

    Examples:
        >>> from unittest.mock import patch, MagicMock
        >>> service = GatewayService()
        >>> with patch('mcpgateway.services.gateway_service.SessionLocal') as mock_session:
        ...     mock_db = MagicMock()
        ...     mock_session.return_value.__enter__.return_value = mock_db
        ...     mock_db.execute.return_value.scalars.return_value.all.return_value = []
        ...     result = service._get_gateways()
        ...     isinstance(result, list)
        True

        >>> # Test include_inactive parameter handling
        >>> with patch('mcpgateway.services.gateway_service.SessionLocal') as mock_session:
        ...     mock_db = MagicMock()
        ...     mock_session.return_value.__enter__.return_value = mock_db
        ...     mock_db.execute.return_value.scalars.return_value.all.return_value = []
        ...     result_active_only = service._get_gateways(include_inactive=False)
        ...     isinstance(result_active_only, list)
        True
    """
    statement = select(DbGateway)
    if not include_inactive:
        # Restrict to active gateways only
        statement = statement.where(DbGateway.enabled)

    with cast(Any, SessionLocal)() as db:
        return db.execute(statement).scalars().all()

4026 

def get_first_gateway_by_url(self, db: Session, url: str, team_id: Optional[str] = None, include_inactive: bool = False) -> Optional[GatewayRead]:
    """Return the first gateway whose ``url`` column equals the given URL.

    Synchronous helper for request handlers that need a simple lookup. The
    URL is compared exactly as passed in; this method performs no
    normalization itself, so callers should supply the URL in the same form
    in which gateways are stored. Disabled gateways are excluded unless
    ``include_inactive`` is set, and a truthy ``team_id`` narrows the search
    to that team.

    Args:
        db: Database session to use for the query
        url: Gateway URL to match
        team_id: Optional team id to restrict search
        include_inactive: Whether to include inactive gateways

    Returns:
        Optional[GatewayRead]: The first match converted via
        ``convert_gateway_to_read``, or None when nothing matches
    """
    criteria = [DbGateway.url == url]
    if not include_inactive:
        criteria.append(DbGateway.enabled)
    if team_id:
        criteria.append(DbGateway.team_id == team_id)

    match = db.execute(select(DbGateway).where(*criteria)).scalars().first()
    # Wrap the DB object in the read schema for consistency with other service methods
    return None if match is None else self.convert_gateway_to_read(match)

4055 

async def _run_leader_heartbeat(self) -> None:
    """Run leader heartbeat loop with Redis reconnection support.

    Every heartbeat interval the loop verifies that the leader key still
    holds this instance's id and, if so, refreshes the key's TTL.

    The loop stops and hands over to follower election when:
      * another instance holds the leader key,
      * the Redis client has been unavailable for three consecutive beats, or
      * three consecutive heartbeat attempts raised.

    Starting follower election on every exit path keeps this instance in the
    leadership rotation instead of leaving it with no election task running.
    """
    consecutive_failures = 0
    max_failures = 3

    while True:
        try:
            await asyncio.sleep(self._leader_heartbeat_interval)

            if not self._redis_client:
                logger.warning("Redis client unavailable in heartbeat")
                consecutive_failures += 1
                if consecutive_failures >= max_failures:
                    logger.error("Lost Redis connection, stopping heartbeat")
                    # Same hand-off as the other exits: the election loop
                    # keeps polling until Redis is usable again.
                    self._start_follower_election()
                    return
                continue

            # Check if we're still the leader
            current_leader = await self._redis_client.get(self._leader_key)
            if current_leader != self._instance_id:
                logger.info("Lost Redis leadership, stopping heartbeat")
                self._start_follower_election()
                return

            # Refresh the leader key TTL
            await self._redis_client.expire(self._leader_key, self._leader_ttl)
            logger.debug(f"Leader heartbeat: refreshed TTL to {self._leader_ttl}s")
            consecutive_failures = 0

        except Exception as e:
            consecutive_failures += 1
            logger.warning(f"Leader heartbeat error (failure {consecutive_failures}/{max_failures}): {e}")
            if consecutive_failures >= max_failures:
                logger.error("Too many consecutive heartbeat failures, starting follower election")
                self._start_follower_election()
                return

4096 

def _start_follower_election(self) -> None:
    """Schedule the follower election loop unless one is still in progress.

    A new task is created when no election task exists yet or the previous
    one has finished; the platform admin email from settings is passed to
    the election coroutine.
    """
    current_task = self._follower_election_task
    if current_task is not None and not current_task.done():
        # An election loop is already running; leave it alone
        return
    self._follower_election_task = asyncio.create_task(self._run_follower_election(settings.platform_admin_email))

4101 

async def _run_follower_election(self, user_email: str) -> None:
    """Poll Redis until this instance manages to claim leadership.

    Sleeps for a third of the leader TTL (at least one second) between
    attempts and tries to set the leader key only if it does not exist.
    On success it cancels any health-check or heartbeat task left over from
    an earlier leadership period, starts fresh ones, and exits. Errors are
    logged and the loop keeps polling.

    Args:
        user_email: Email of the user for OAuth token lookup
    """
    poll_seconds = max(1, self._leader_ttl // 3)  # Poll at 1/3 of TTL

    while True:
        try:
            await asyncio.sleep(poll_seconds)

            redis_client = self._redis_client
            if not redis_client:
                logger.warning("Redis client unavailable, cannot attempt election.")
                continue

            # Attempt to acquire leadership (nx=True: only set when the key is absent)
            acquired = await redis_client.set(self._leader_key, self._instance_id, ex=self._leader_ttl, nx=True)
            if not acquired:
                continue

            logger.info("Acquired Redis leadership via follower election. Starting health check and heartbeat.")

            # Cancel stale tasks from a previous leadership period to prevent
            # orphaned loops running alongside the new ones.
            stale_health_task = self._health_check_task
            if stale_health_task and not stale_health_task.done():
                stale_health_task.cancel()
            stale_heartbeat_task = getattr(self, "_leader_heartbeat_task", None)
            if stale_heartbeat_task and not stale_heartbeat_task.done():
                stale_heartbeat_task.cancel()

            self._health_check_task = asyncio.create_task(self._run_health_checks(user_email))
            self._leader_heartbeat_task = asyncio.create_task(self._run_leader_heartbeat())
            return  # Exit follower loop, now running as leader

        except Exception as e:
            logger.warning(f"Follower election error: {e}", exc_info=True)

4138 

async def _run_health_checks(self, user_email: str) -> None:
    """Run health checks periodically,
    Uses Redis or FileLock - for multiple workers.
    Uses simple health check for single worker mode.

    Three modes, chosen on every iteration:
      * Redis (client present and ``cache_type == "redis"``): run a check
        only while this instance holds the leader key; return as soon as it
        does not.
      * ``cache_type == "none"``: single worker, run checks directly.
      * Otherwise: FileLock fallback - the worker that acquires the lock
        keeps running checks until an error occurs.

    NOTE: This method intentionally does NOT take a db parameter.
    Health checks use fresh_db_session() only when DB access is needed,
    avoiding holding connections during HTTP calls to MCP servers.

    Args:
        user_email: Email of the user for OAuth token lookup

    Examples:
        >>> service = GatewayService()
        >>> service._health_check_interval = 0.1  # Short interval for testing
        >>> service._redis_client = None
        >>> import asyncio
        >>> # Test that method exists and is callable
        >>> callable(service._run_health_checks)
        True
        >>> # Test setup without actual execution (would run forever)
        >>> hasattr(service, '_health_check_interval')
        True
        >>> service._health_check_interval == 0.1
        True
    """

    while True:
        try:
            if self._redis_client and settings.cache_type == "redis":
                # Redis-based leader check (async, decode_responses=True returns strings)
                # Note: Leader key TTL refresh is handled by _run_leader_heartbeat task
                current_leader = await self._redis_client.get(self._leader_key)
                if current_leader != self._instance_id:
                    return

                # Run health checks
                gateways = await asyncio.to_thread(self._get_gateways)
                if gateways:
                    await self.check_health_of_gateways(gateways, user_email)

                await asyncio.sleep(self._health_check_interval)

            elif settings.cache_type == "none":
                try:
                    # For single worker mode, run health checks directly
                    gateways = await asyncio.to_thread(self._get_gateways)
                    if gateways:
                        await self.check_health_of_gateways(gateways, user_email)
                except Exception as e:
                    logger.error(f"Health check run failed: {str(e)}")

                await asyncio.sleep(self._health_check_interval)

            else:
                # FileLock-based leader fallback
                back_off = False
                try:
                    self._file_lock.acquire(timeout=0)
                    logger.info("File lock acquired. Running health checks.")

                    while True:
                        gateways = await asyncio.to_thread(self._get_gateways)
                        if gateways:
                            await self.check_health_of_gateways(gateways, user_email)
                        await asyncio.sleep(self._health_check_interval)

                except Timeout:
                    logger.debug("File lock already held. Retrying later.")
                    await asyncio.sleep(self._health_check_interval)

                except Exception as e:
                    logger.error(f"FileLock health check failed: {str(e)}")
                    back_off = True

                finally:
                    if self._file_lock.is_locked:
                        try:
                            self._file_lock.release()
                            logger.info("Released file lock.")
                        except Exception as e:
                            logger.warning(f"Failed to release file lock: {str(e)}")

                if back_off:
                    # Wait one interval before retrying so a persistent failure
                    # does not spin acquire/fail/release in a tight loop. The wait
                    # happens after the lock is released so another worker can
                    # take over in the meantime.
                    await asyncio.sleep(self._health_check_interval)

        except Exception as e:
            logger.error(f"Unexpected error in health check loop: {str(e)}")
            await asyncio.sleep(self._health_check_interval)

4223 

def _get_auth_headers(self) -> Dict[str, str]:
    """Build the default header set for outgoing gateway requests.

    SECURITY: The result deliberately carries no authentication material.
    Each gateway should have its own auth_value configured. Never send this
    gateway's admin credentials to remote servers.

    Returns:
        dict: A new dictionary containing only the JSON content type

    Examples:
        >>> service = GatewayService()
        >>> headers = service._get_auth_headers()
        >>> isinstance(headers, dict)
        True
        >>> 'Content-Type' in headers
        True
        >>> headers['Content-Type']
        'application/json'
        >>> 'Authorization' not in headers  # No credentials leaked
        True
    """
    default_headers: Dict[str, str] = {"Content-Type": "application/json"}
    return default_headers

4247 

async def _notify_gateway_added(self, gateway: DbGateway) -> None:
    """Publish a ``gateway_added`` event for a newly added gateway.

    The event carries the gateway's id, name, url, description and enabled
    flag plus a UTC ISO-8601 timestamp.

    Args:
        gateway: Gateway that was added
    """
    payload = {
        "id": gateway.id,
        "name": gateway.name,
        "url": gateway.url,
        "description": gateway.description,
        "enabled": gateway.enabled,
    }
    await self._publish_event({"type": "gateway_added", "data": payload, "timestamp": datetime.now(timezone.utc).isoformat()})

4266 

async def _notify_gateway_activated(self, gateway: DbGateway) -> None:
    """Publish a ``gateway_activated`` event.

    The event carries the gateway's id, name, url, enabled and reachable
    flags plus a UTC ISO-8601 timestamp.

    Args:
        gateway: Gateway that was activated
    """
    payload = {
        "id": gateway.id,
        "name": gateway.name,
        "url": gateway.url,
        "enabled": gateway.enabled,
        "reachable": gateway.reachable,
    }
    await self._publish_event({"type": "gateway_activated", "data": payload, "timestamp": datetime.now(timezone.utc).isoformat()})

4285 

async def _notify_gateway_deactivated(self, gateway: DbGateway) -> None:
    """Publish a ``gateway_deactivated`` event.

    The event carries the gateway's id, name, url, enabled and reachable
    flags plus a UTC ISO-8601 timestamp.

    Args:
        gateway: Gateway database object
    """
    payload = {
        "id": gateway.id,
        "name": gateway.name,
        "url": gateway.url,
        "enabled": gateway.enabled,
        "reachable": gateway.reachable,
    }
    await self._publish_event({"type": "gateway_deactivated", "data": payload, "timestamp": datetime.now(timezone.utc).isoformat()})

4304 

async def _notify_gateway_offline(self, gateway: DbGateway) -> None:
    """
    Publish a ``gateway_offline`` event (Enabled but Unreachable).

    The ``enabled`` and ``reachable`` fields of the event are fixed to
    True and False respectively rather than read from the gateway object.

    Args:
        gateway: Gateway database object
    """
    payload = {
        "id": gateway.id,
        "name": gateway.name,
        "url": gateway.url,
        "enabled": True,
        "reachable": False,
    }
    await self._publish_event({"type": "gateway_offline", "data": payload, "timestamp": datetime.now(timezone.utc).isoformat()})

4324 

async def _notify_gateway_deleted(self, gateway_info: Dict[str, Any]) -> None:
    """Publish a ``gateway_deleted`` event.

    The supplied mapping is forwarded unchanged as the event's ``data``
    alongside a UTC ISO-8601 timestamp.

    Args:
        gateway_info: Dict containing information about the deleted gateway
    """
    deletion_event = {"type": "gateway_deleted", "data": gateway_info}
    deletion_event["timestamp"] = datetime.now(timezone.utc).isoformat()
    await self._publish_event(deletion_event)

4337 

async def _notify_gateway_removed(self, gateway: DbGateway) -> None:
    """Publish a ``gateway_removed`` event (deactivation).

    The event carries only the gateway's id, name and enabled flag plus a
    UTC ISO-8601 timestamp.

    Args:
        gateway: Gateway that was removed
    """
    payload = {"id": gateway.id, "name": gateway.name, "enabled": gateway.enabled}
    await self._publish_event({"type": "gateway_removed", "data": payload, "timestamp": datetime.now(timezone.utc).isoformat()})

4350 

def convert_gateway_to_read(self, gateway: DbGateway) -> GatewayRead:
    """Build a masked ``GatewayRead`` model from a ``DbGateway`` row.

    Takes a snapshot of the attributes currently loaded on the instance,
    then normalizes the fields the read schema needs in a particular shape:
    a dict ``auth_value`` is encoded, string tags are expanded through
    ``validate_tags_field``, audit/metadata fields are filled in (None when
    absent), and ``tool_count`` is derived from the ``tools`` relationship
    only if it is already loaded on the instance (otherwise 0).

    Args:
        gateway: Gateway database object

    Returns:
        GatewayRead: Validated Pydantic model, passed through ``masked()``
    """
    # Snapshot loaded attributes, leaving out SQLAlchemy's bookkeeping entry
    fields = {attr: val for attr, val in gateway.__dict__.items() if attr != "_sa_instance_state"}

    # Ensure auth_value is properly encoded
    if isinstance(gateway.auth_value, dict):
        fields["auth_value"] = encode_auth(gateway.auth_value)

    raw_tags = gateway.tags
    if not raw_tags:
        fields["tags"] = []
    elif isinstance(raw_tags[0], str):
        # Convert tags from List[str] to List[Dict[str, str]] for GatewayRead
        fields["tags"] = validate_tags_field(raw_tags)
    else:
        fields["tags"] = raw_tags

    # Include metadata fields
    for meta_name in ("created_by", "modified_by", "created_at", "updated_at", "version", "team"):
        fields[meta_name] = getattr(gateway, meta_name, None)

    # Read from __dict__ so an unloaded relationship is not fetched just to count it
    loaded_tools = gateway.__dict__.get("tools")
    fields["tool_count"] = 0 if loaded_tools is None else len(loaded_tools)

    return GatewayRead.model_validate(fields).masked()

4390 

def _create_db_tool(
    self,
    tool: ToolCreate,
    gateway: DbGateway,
    created_by: Optional[str] = None,
    created_from_ip: Optional[str] = None,
    created_via: Optional[str] = None,
    created_user_agent: Optional[str] = None,
) -> DbTool:
    """Create a DbTool with consistent federation metadata across all scenarios.

    The tool inherits URL, auth settings, team, owner and (unless the tool
    sets its own) visibility from the gateway. The upstream output schema is
    stored as well, so a freshly created tool matches what the
    update-or-create path later compares against.

    Args:
        tool: Tool creation schema
        gateway: Gateway database object
        created_by: Username who created/updated this tool
        created_from_ip: IP address of creator
        created_via: Creation method (ui, api, federation, rediscovery)
        created_user_agent: User agent of creation request

    Returns:
        DbTool: Consistently configured database tool object
    """
    # The tool column stores the encoded form; a dict on the gateway must be encoded first
    tool_auth_value = encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value

    return DbTool(
        original_name=tool.name,
        custom_name=tool.name,
        custom_name_slug=slugify(tool.name),
        display_name=generate_display_name(tool.name),
        title=_resolve_tool_title(tool),
        url=gateway.url,
        original_description=tool.description,
        description=tool.description,
        integration_type="MCP",  # Gateway-discovered tools are MCP type
        request_type=tool.request_type,
        headers=tool.headers,
        input_schema=tool.input_schema,
        # Previously dropped on creation; getattr tolerates schemas without the field
        output_schema=getattr(tool, "output_schema", None),
        annotations=tool.annotations,
        jsonpath_filter=tool.jsonpath_filter,
        auth_type=gateway.auth_type,
        auth_value=tool_auth_value,
        # Federation metadata - consistent across all scenarios
        created_by=created_by or "system",
        created_from_ip=created_from_ip,
        created_via=created_via or "federation",
        created_user_agent=created_user_agent,
        federation_source=gateway.name,
        version=1,
        # Inherit team assignment from gateway; respect per-tool visibility if set
        team_id=gateway.team_id,
        owner_email=gateway.owner_email,
        visibility=getattr(tool, "visibility", None) or gateway.visibility,
    )

4442 

def _update_or_create_tools(self, db: Session, tools: List[Any], gateway: DbGateway, created_via: str, update_visibility: bool = False) -> List[DbTool]:
    """Reconcile tools reported by an MCP server with the rows stored for a gateway.

    Tools already stored for the gateway (matched on ``original_name``) are
    updated in place when any tracked field differs. Tools that are not yet
    stored are built with ``_create_db_tool`` and returned so the caller can
    add them to the session. A failure while handling one tool is logged and
    the remaining tools are still processed.

    Args:
        db: Database session
        tools: List of tools from MCP server (``None`` entries are skipped)
        gateway: Gateway object
        created_via: String indicating creation source ("oauth", "update", etc.)
        update_visibility: Whether a visibility value present on the upstream
            tool should overwrite the visibility of an existing tool

    Returns:
        List of new tools to be added to the database
    """
    if not tools:
        return []

    upstream_names = [item.name for item in tools if item is not None]
    if not upstream_names:
        return []

    # Single query for every candidate row instead of one lookup per tool.
    lookup_stmt = select(DbTool).where(DbTool.gateway_id == gateway.id, DbTool.original_name.in_(upstream_names))
    stored_by_name = {row.original_name: row for row in db.execute(lookup_stmt).scalars().all()}

    created = []
    for incoming in tools:
        if incoming is None:
            logger.warning("Skipping None tool in tools list")
            continue

        try:
            stored = stored_by_name.get(incoming.name)

            if not stored:
                # Unknown tool: build a new row and hand it back to the caller.
                fresh_tool = self._create_db_tool(
                    tool=incoming,
                    gateway=gateway,
                    created_by="system",
                    created_via=created_via,
                )
                # Attach relationship to avoid NoneType during flush
                fresh_tool.gateway = gateway
                created.append(fresh_tool)
                logger.debug(f"Created new tool: {incoming.name}")
                continue

            # Compare against original_description (the upstream value), not
            # description, because description may carry a user customisation.
            basics_differ = (
                stored.url != gateway.url
                or stored.original_description != incoming.description
                or stored.integration_type != "MCP"
                or stored.request_type != incoming.request_type
            )

            schema_differs = (
                stored.headers != incoming.headers
                or stored.input_schema != incoming.input_schema
                or stored.output_schema != incoming.output_schema
                or stored.jsonpath_filter != incoming.jsonpath_filter
            )

            # DbTool.auth_value is Text (encoded str); DbGateway.auth_value is JSON (dict).
            # encode_auth() uses a random nonce, so ciphertext would always differ even
            # for unchanged plaintext. Compare decoded values, and fall back to a direct
            # comparison if decoding fails (legacy/corrupt data).
            try:
                if isinstance(gateway.auth_value, dict):
                    gateway_plain = gateway.auth_value
                elif gateway.auth_value:
                    gateway_plain = decode_auth(gateway.auth_value)
                else:
                    gateway_plain = {}
                stored_plain = decode_auth(stored.auth_value) if stored.auth_value else {}
                auth_differs = stored_plain != gateway_plain
            except Exception:
                if isinstance(gateway.auth_value, dict):
                    encoded_candidate = encode_auth(gateway.auth_value)
                else:
                    encoded_candidate = gateway.auth_value
                auth_differs = stored.auth_value != encoded_candidate

            wanted_visibility = getattr(incoming, "visibility", None)
            visibility_applies = update_visibility and wanted_visibility is not None
            auth_or_visibility_differs = (
                stored.auth_type != gateway.auth_type
                or auth_differs
                or (visibility_applies and stored.visibility != wanted_visibility)
            )

            title_differs = stored.title != _resolve_tool_title(incoming)

            if not (basics_differ or schema_differs or auth_or_visibility_differs or title_differs):
                continue

            stored.url = gateway.url
            # Only overwrite the user-facing description if it has not been customised
            # (mirrors the original_name/custom_name pattern).
            if stored.description == stored.original_description:
                stored.description = incoming.description
            stored.original_description = incoming.description
            stored.integration_type = "MCP"
            stored.request_type = incoming.request_type
            stored.headers = incoming.headers
            stored.input_schema = incoming.input_schema
            stored.output_schema = incoming.output_schema
            stored.jsonpath_filter = incoming.jsonpath_filter
            stored.title = _resolve_tool_title(incoming)
            stored.auth_type = gateway.auth_type
            # Encoding happens only on the write path.
            stored.auth_value = encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value
            if visibility_applies:
                stored.visibility = wanted_visibility
            logger.debug(f"Updated existing tool: {incoming.name}")
        except Exception as exc:
            logger.warning(f"Failed to process tool {getattr(incoming, 'name', 'unknown')}: {exc}")
            continue

    return created

4561 

def _update_or_create_resources(self, db: Session, resources: List[Any], gateway: DbGateway, created_via: str, update_visibility: bool = False) -> List[DbResource]:
    """Reconcile resources reported by an MCP server with the rows stored for a gateway.

    Resources already stored for the gateway (matched on ``uri``) are updated
    in place when a tracked field differs. Unknown resources are instantiated
    and returned so the caller can add them to the session. A failure on one
    resource is logged and processing continues with the next one.

    Args:
        db: Database session
        resources: List of resources from MCP server (``None`` entries are skipped)
        gateway: Gateway object
        created_via: String indicating creation source ("oauth", "update", etc.)
        update_visibility: Whether a visibility value present on the upstream
            resource should overwrite the visibility of an existing resource

    Returns:
        List of new resources to be added to the database
    """
    if not resources:
        return []

    upstream_uris = [item.uri for item in resources if item is not None]
    if not upstream_uris:
        return []

    # Single query for every candidate row instead of one lookup per resource.
    lookup_stmt = select(DbResource).where(DbResource.gateway_id == gateway.id, DbResource.uri.in_(upstream_uris))
    stored_by_uri = {row.uri: row for row in db.execute(lookup_stmt).scalars().all()}

    created = []
    for incoming in resources:
        if incoming is None:
            logger.warning("Skipping None resource in resources list")
            continue

        try:
            stored = stored_by_uri.get(incoming.uri)

            if not stored:
                # Unknown resource: new rows fall back to the gateway's visibility
                # when the upstream object does not carry one.
                fresh_resource = DbResource(
                    uri=incoming.uri,
                    name=incoming.name,
                    title=getattr(incoming, "title", None),
                    description=incoming.description,
                    mime_type=incoming.mime_type,
                    uri_template=incoming.uri_template,
                    gateway_id=gateway.id,
                    created_by="system",
                    created_via=created_via,
                    visibility=getattr(incoming, "visibility", None) or gateway.visibility,
                )
                created.append(fresh_resource)
                logger.debug(f"Created new resource: {incoming.uri}")
                continue

            wanted_visibility = getattr(incoming, "visibility", None)
            visibility_applies = update_visibility and wanted_visibility is not None

            differs = (
                stored.name != incoming.name
                or stored.description != incoming.description
                or stored.mime_type != incoming.mime_type
                or stored.uri_template != incoming.uri_template
                or (visibility_applies and stored.visibility != wanted_visibility)
                or stored.title != getattr(incoming, "title", None)
            )
            if not differs:
                continue

            stored.name = incoming.name
            stored.description = incoming.description
            stored.mime_type = incoming.mime_type
            stored.uri_template = incoming.uri_template
            stored.title = getattr(incoming, "title", None)
            if visibility_applies:
                stored.visibility = wanted_visibility
            logger.debug(f"Updated existing resource: {incoming.uri}")
        except Exception as exc:
            logger.warning(f"Failed to process resource {getattr(incoming, 'uri', 'unknown')}: {exc}")
            continue

    return created

4643 

4644 @staticmethod 

4645 def _build_prompt_argument_schema(prompt: Any) -> Dict[str, Any]: 

4646 """Build a JSON-schema-compatible argument_schema dict from a PromptCreate's arguments list. 

4647 

4648 The MCP protocol's ``prompts/list`` response includes argument metadata 

4649 (name, description, required) on each prompt. This helper converts that 

4650 list into the internal ``argument_schema`` structure expected by 

4651 ``DbPrompt`` so that the UI and API can surface the arguments correctly. 

4652 

4653 Args: 

4654 prompt: A PromptCreate (or any object with an ``arguments`` attribute 

4655 whose items have ``name``, optional ``description``, and 

4656 optional ``required`` fields). 

4657 

4658 Returns: 

4659 Dict with ``type``, ``properties``, and ``required`` keys. 

4660 """ 

4661 schema: Dict[str, Any] = {"type": "object", "properties": {}, "required": []} 

4662 for arg in getattr(prompt, "arguments", []) or []: 

4663 prop: Dict[str, Any] = {"type": "string"} 

4664 if getattr(arg, "description", None): 

4665 prop["description"] = arg.description 

4666 schema["properties"][arg.name] = prop 

4667 if getattr(arg, "required", False): 

4668 schema["required"].append(arg.name) 

4669 return schema 

4670 

def _update_or_create_prompts(self, db: Session, prompts: List[Any], gateway: DbGateway, created_via: str, update_visibility: bool = False) -> List[DbPrompt]:
    """Reconcile prompts reported by an MCP server with the rows stored for a gateway.

    Prompts already stored for the gateway (matched on ``original_name``) are
    updated in place when a tracked field differs. Unknown prompts are
    instantiated and returned so the caller can add them to the session. A
    failure on one prompt is logged and processing continues with the next.

    Args:
        db: Database session
        prompts: List of prompts from MCP server (``None`` entries are skipped)
        gateway: Gateway object
        created_via: String indicating creation source ("oauth", "update", etc.)
        update_visibility: Whether a visibility value present on the upstream
            prompt should overwrite the visibility of an existing prompt

    Returns:
        List of new prompts to be added to the database
    """
    if not prompts:
        return []

    upstream_names = [item.name for item in prompts if item is not None]
    if not upstream_names:
        return []

    # Single query for every candidate row instead of one lookup per prompt.
    lookup_stmt = select(DbPrompt).where(DbPrompt.gateway_id == gateway.id, DbPrompt.original_name.in_(upstream_names))
    stored_by_name = {row.original_name: row for row in db.execute(lookup_stmt).scalars().all()}

    created = []
    for incoming in prompts:
        if incoming is None:
            logger.warning("Skipping None prompt in prompts list")
            continue

        try:
            stored = stored_by_name.get(incoming.name)

            if not stored:
                # Unknown prompt: all name variants start out as the upstream name.
                fresh_prompt = DbPrompt(
                    name=incoming.name,
                    original_name=incoming.name,
                    custom_name=incoming.name,
                    display_name=incoming.name,
                    title=getattr(incoming, "title", None),
                    description=incoming.description,
                    template=incoming.template if hasattr(incoming, "template") else "",
                    argument_schema=self._build_prompt_argument_schema(incoming),
                    gateway_id=gateway.id,
                    created_by="system",
                    created_via=created_via,
                    visibility=getattr(incoming, "visibility", None) or gateway.visibility,
                )
                fresh_prompt.gateway = gateway
                created.append(fresh_prompt)
                logger.debug(f"Created new prompt: {incoming.name}")
                continue

            upstream_schema = self._build_prompt_argument_schema(incoming)
            wanted_visibility = getattr(incoming, "visibility", None)
            visibility_applies = update_visibility and wanted_visibility is not None

            differs = (
                stored.description != incoming.description
                or stored.template != (incoming.template if hasattr(incoming, "template") else "")
                or (visibility_applies and stored.visibility != wanted_visibility)
                # A stored NULL schema is compared as an empty dict.
                or (stored.argument_schema or {}) != upstream_schema
                or stored.title != getattr(incoming, "title", None)
            )
            if not differs:
                continue

            stored.description = incoming.description
            stored.template = incoming.template if hasattr(incoming, "template") else ""
            stored.argument_schema = upstream_schema
            stored.title = getattr(incoming, "title", None)
            if visibility_applies:
                stored.visibility = wanted_visibility
            logger.debug(f"Updated existing prompt: {incoming.name}")
        except Exception as exc:
            logger.warning(f"Failed to process prompt {getattr(incoming, 'name', 'unknown')}: {exc}")
            continue

    return created

4754 

async def _refresh_gateway_tools_resources_prompts(
    self,
    gateway_id: str,
    _user_email: Optional[str] = None,
    created_via: str = "health_check",
    pre_auth_headers: Optional[Dict[str, str]] = None,
    gateway: Optional[DbGateway] = None,
    include_resources: bool = True,
    include_prompts: bool = True,
) -> Dict[str, Any]:
    """Refresh tools, resources, and prompts for a gateway during health checks.

    Fetches the latest tools/resources/prompts from the MCP server and syncs
    with the database (add new, update changed, remove stale).

    This method uses fresh_db_session() internally to avoid holding
    connections during HTTP calls to MCP servers.

    Args:
        gateway_id: ID of the gateway to refresh
        _user_email: Optional user email for OAuth token lookup (unused currently)
        created_via: String indicating creation source (default: "health_check")
        pre_auth_headers: Pre-authenticated headers from health check to avoid duplicate OAuth token fetch
        gateway: Optional DbGateway object to avoid redundant DB lookup
        include_resources: Whether to include resources in the refresh
        include_prompts: Whether to include prompts in the refresh

    Returns:
        Dict with integer counts ``tools_added``, ``tools_removed``,
        ``tools_updated``, ``resources_added``, ``resources_removed``,
        ``resources_updated``, ``prompts_added``, ``prompts_removed``,
        ``prompts_updated``, plus ``success`` (bool), ``error`` (``None`` or
        an error message) and ``validation_errors`` (list). A missing,
        disabled or unreachable gateway yields the all-zero result with
        ``success`` left as ``True``.

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> from unittest.mock import patch, MagicMock, AsyncMock
        >>> import asyncio

        >>> # Test gateway not found returns empty result
        >>> service = GatewayService()
        >>> mock_session = MagicMock()
        >>> mock_session.execute.return_value.scalar_one_or_none.return_value = None
        >>> with patch('mcpgateway.services.gateway_service.fresh_db_session') as mock_fresh:
        ...     mock_fresh.return_value.__enter__.return_value = mock_session
        ...     result = asyncio.run(service._refresh_gateway_tools_resources_prompts('gw-123'))
        >>> result['tools_added'] == 0 and result['tools_removed'] == 0
        True
        >>> result['resources_added'] == 0 and result['resources_removed'] == 0
        True
        >>> result['success'] is True and result['error'] is None
        True

        >>> # Test disabled gateway returns empty result
        >>> mock_gw = MagicMock()
        >>> mock_gw.enabled = False
        >>> mock_gw.reachable = True
        >>> mock_gw.name = 'test_gw'
        >>> mock_session.execute.return_value.scalar_one_or_none.return_value = mock_gw
        >>> with patch('mcpgateway.services.gateway_service.fresh_db_session') as mock_fresh:
        ...     mock_fresh.return_value.__enter__.return_value = mock_session
        ...     result = asyncio.run(service._refresh_gateway_tools_resources_prompts('gw-123'))
        >>> result['tools_added']
        0

        >>> # Test unreachable gateway returns empty result
        >>> mock_gw.enabled = True
        >>> mock_gw.reachable = False
        >>> with patch('mcpgateway.services.gateway_service.fresh_db_session') as mock_fresh:
        ...     mock_fresh.return_value.__enter__.return_value = mock_session
        ...     result = asyncio.run(service._refresh_gateway_tools_resources_prompts('gw-123'))
        >>> result['tools_added']
        0

        >>> # Test method is async and callable
        >>> import inspect
        >>> inspect.iscoroutinefunction(service._refresh_gateway_tools_resources_prompts)
        True
        >>>
        >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
        >>> asyncio.run(service._http_client.aclose())
    """
    result: Dict[str, Any] = {
        "tools_added": 0,
        "tools_removed": 0,
        "resources_added": 0,
        "resources_removed": 0,
        "prompts_added": 0,
        "prompts_removed": 0,
        "tools_updated": 0,
        "resources_updated": 0,
        "prompts_updated": 0,
        "success": True,
        "error": None,
        "validation_errors": [],
    }

    # Fetch gateway metadata only (no relationships needed for MCP call).
    # Use the provided gateway object if available to save a DB call.
    gateway_name = None
    gateway_url = None
    gateway_transport = None
    gateway_auth_type = None
    gateway_auth_value = None
    gateway_oauth_config = None
    gateway_ca_certificate = None
    gateway_auth_query_params = None
    refresh_client_cert = None
    refresh_client_key = None

    if gateway:
        if not gateway.enabled or not gateway.reachable:
            logger.debug(f"Skipping tool refresh for disabled/unreachable gateway {SecurityValidator.sanitize_log_message(gateway.name)}")
            return result

        gateway_name = gateway.name
        gateway_url = gateway.url
        gateway_transport = gateway.transport
        gateway_auth_type = gateway.auth_type
        gateway_auth_value = gateway.auth_value
        gateway_oauth_config = gateway.oauth_config
        gateway_ca_certificate = gateway.ca_certificate
        gateway_auth_query_params = gateway.auth_query_params
        refresh_client_cert = getattr(gateway, "client_cert", None)
        refresh_client_key = getattr(gateway, "client_key", None)
    else:
        with fresh_db_session() as db:
            gateway_obj = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()

            if not gateway_obj:
                logger.warning(f"Gateway {SecurityValidator.sanitize_log_message(gateway_id)} not found for tool refresh")
                return result

            if not gateway_obj.enabled or not gateway_obj.reachable:
                logger.debug(f"Skipping tool refresh for disabled/unreachable gateway {gateway_obj.name}")
                return result

            # Extract metadata before session closes
            gateway_name = gateway_obj.name
            gateway_url = gateway_obj.url
            gateway_transport = gateway_obj.transport
            gateway_auth_type = gateway_obj.auth_type
            gateway_auth_value = gateway_obj.auth_value
            gateway_oauth_config = gateway_obj.oauth_config
            gateway_ca_certificate = gateway_obj.ca_certificate
            gateway_auth_query_params = gateway_obj.auth_query_params
            refresh_client_cert = getattr(gateway_obj, "client_cert", None)
            refresh_client_key = getattr(gateway_obj, "client_key", None)

    # Preserve base URL before auth mutation for classification poll-state keys
    gateway_base_url = gateway_url

    # Handle query_param auth - decrypt and apply to URL for refresh
    auth_query_params_decrypted: Optional[Dict[str, str]] = None
    if gateway_auth_type == "query_param" and gateway_auth_query_params:
        auth_query_params_decrypted = {}
        for param_key, encrypted_value in gateway_auth_query_params.items():
            if encrypted_value:
                try:
                    decrypted = decode_auth(encrypted_value)
                    auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
                except Exception:
                    logger.debug(f"Failed to decrypt query param '{param_key}' for tool refresh")
        if auth_query_params_decrypted:
            gateway_url = apply_query_param_auth(gateway_url, auth_query_params_decrypted)

    # Fetch tools/resources/prompts from MCP server (no DB connection held)
    try:
        # Decrypt client_key for refresh initialization; on failure the stored
        # value is passed through unchanged (best effort).
        _refresh_key = refresh_client_key
        if _refresh_key:
            try:
                _enc = get_encryption_service(settings.auth_encryption_secret)
                _refresh_key = _enc.decrypt_secret_or_plaintext(_refresh_key)
            except Exception:
                logger.debug("client_key decryption skipped during gateway refresh")
        _capabilities, tools, resources, prompts = await self._initialize_gateway(
            url=gateway_url,
            authentication=gateway_auth_value,
            transport=gateway_transport,
            auth_type=gateway_auth_type,
            oauth_config=gateway_oauth_config,
            ca_certificate=gateway_ca_certificate.encode() if gateway_ca_certificate else None,
            pre_auth_headers=pre_auth_headers,
            include_resources=include_resources,
            include_prompts=include_prompts,
            auth_query_params=auth_query_params_decrypted,
            client_cert=refresh_client_cert,
            client_key=_refresh_key,
        )
    except Exception as e:
        logger.warning(f"Failed to fetch tools from gateway {gateway_name}: {e}")
        result["success"] = False
        result["error"] = str(e)
        return result

    # For authorization_code OAuth gateways, empty responses may indicate incomplete auth flow.
    # Skip only if it's an auth_code gateway with no data (user may not have completed authorization).
    is_auth_code_gateway = gateway_oauth_config and isinstance(gateway_oauth_config, dict) and gateway_oauth_config.get("grant_type") == "authorization_code"
    if not tools and not resources and not prompts and is_auth_code_gateway:
        logger.debug(f"No tools/resources/prompts returned from auth_code gateway {gateway_name} (user may not have authorized)")
        return result

    # For non-auth_code gateways, empty responses are legitimate and will clear stale items

    # Update database with fresh session
    with fresh_db_session() as db:

        def _dirty_instances(model: Any) -> Set[Any]:
            """Return the session's currently dirty objects of the given mapped class."""
            return {obj for obj in db.dirty if isinstance(obj, model)}

        # Fetch gateway with relationships for update/comparison
        db_gateway = db.execute(
            select(DbGateway)
            .options(
                selectinload(DbGateway.tools),
                selectinload(DbGateway.resources),
                selectinload(DbGateway.prompts),
            )
            .where(DbGateway.id == gateway_id)
        ).scalar_one_or_none()

        if not db_gateway:
            result["success"] = False
            result["error"] = f"Gateway {gateway_id} not found during refresh"
            return result

        # Sets give O(1) membership tests for the stale-item scans below. None
        # entries are skipped here just as the update-or-create helpers skip them.
        new_tool_names = {tool.name for tool in tools if tool is not None}
        new_resource_uris = {resource.uri for resource in resources if resource is not None} if include_resources else None
        new_prompt_names = {prompt.name for prompt in prompts if prompt is not None} if include_prompts else None

        # Update/create per type and count updates immediately after each helper.
        # Each helper starts with a SELECT; on a session with autoflush enabled that
        # SELECT would flush (and so clear from db.dirty) the previous type's
        # updates, so counting only at the end could under-report them.
        # NOTE(review): session autoflush setting is not visible here - confirm.
        dirty_tools_before = _dirty_instances(DbTool)
        tools_to_add = self._update_or_create_tools(db, tools, db_gateway, created_via)
        result["tools_updated"] = len(_dirty_instances(DbTool) - dirty_tools_before)

        resources_to_add = []
        if include_resources:
            dirty_resources_before = _dirty_instances(DbResource)
            resources_to_add = self._update_or_create_resources(db, resources, db_gateway, created_via)
            result["resources_updated"] = len(_dirty_instances(DbResource) - dirty_resources_before)

        prompts_to_add = []
        if include_prompts:
            dirty_prompts_before = _dirty_instances(DbPrompt)
            prompts_to_add = self._update_or_create_prompts(db, prompts, db_gateway, created_via)
            result["prompts_updated"] = len(_dirty_instances(DbPrompt) - dirty_prompts_before)

        # Only delete MCP-discovered items (not user-created entries).
        # Excludes "api", "ui", None (legacy/user-created) to preserve user entries.
        mcp_created_via_values = {"MCP", "federation", "health_check", "manual_refresh", "oauth", "update"}
        delete_chunk_size = 500  # bounds the size of each IN (...) list

        # Find and remove stale tools (only MCP-discovered ones)
        stale_tool_ids = [tool.id for tool in db_gateway.tools if tool.original_name not in new_tool_names and tool.created_via in mcp_created_via_values]
        if stale_tool_ids:
            for i in range(0, len(stale_tool_ids), delete_chunk_size):
                chunk = stale_tool_ids[i : i + delete_chunk_size]
                # Dependent rows are removed before the tools themselves.
                db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
                db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
                db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))
            result["tools_removed"] = len(stale_tool_ids)

        # Find and remove stale resources (only MCP-discovered ones, only if resources were fetched)
        stale_resource_ids = []
        if new_resource_uris is not None:
            stale_resource_ids = [resource.id for resource in db_gateway.resources if resource.uri not in new_resource_uris and resource.created_via in mcp_created_via_values]
            if stale_resource_ids:
                for i in range(0, len(stale_resource_ids), delete_chunk_size):
                    chunk = stale_resource_ids[i : i + delete_chunk_size]
                    db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
                    db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
                    db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
                    db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))
                result["resources_removed"] = len(stale_resource_ids)

        # Find and remove stale prompts (only MCP-discovered ones, only if prompts were fetched)
        stale_prompt_ids = []
        if new_prompt_names is not None:
            stale_prompt_ids = [prompt.id for prompt in db_gateway.prompts if prompt.original_name not in new_prompt_names and prompt.created_via in mcp_created_via_values]
            if stale_prompt_ids:
                for i in range(0, len(stale_prompt_ids), delete_chunk_size):
                    chunk = stale_prompt_ids[i : i + delete_chunk_size]
                    db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
                    db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
                    db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))
                result["prompts_removed"] = len(stale_prompt_ids)

        # Expire gateway if stale items were deleted, so its loaded collections
        # are not reused after the bulk deletes above.
        if stale_tool_ids or stale_resource_ids or stale_prompt_ids:
            db.expire(db_gateway)

        # Add new items in chunks
        chunk_size = 50
        if tools_to_add:
            for i in range(0, len(tools_to_add), chunk_size):
                db.add_all(tools_to_add[i : i + chunk_size])
                db.flush()
            result["tools_added"] = len(tools_to_add)

        if resources_to_add:
            for i in range(0, len(resources_to_add), chunk_size):
                db.add_all(resources_to_add[i : i + chunk_size])
                db.flush()
            result["resources_added"] = len(resources_to_add)

        if prompts_to_add:
            for i in range(0, len(prompts_to_add), chunk_size):
                db.add_all(prompts_to_add[i : i + chunk_size])
                db.flush()
            result["prompts_added"] = len(prompts_to_add)

        db_gateway.last_refresh_at = datetime.now(timezone.utc)

        tools_changed = result["tools_added"] + result["tools_removed"] + result["tools_updated"]
        resources_changed = result["resources_added"] + result["resources_removed"] + result["resources_updated"]
        prompts_changed = result["prompts_added"] + result["prompts_removed"] + result["prompts_updated"]

        # Commit in every case: last_refresh_at is written even when nothing else changed.
        db.commit()

        if tools_changed + resources_changed + prompts_changed > 0:
            logger.info(
                f"Refreshed gateway {gateway_name}: "
                f"tools(+{result['tools_added']}/-{result['tools_removed']}/~{result['tools_updated']}), "
                f"resources(+{result['resources_added']}/-{result['resources_removed']}/~{result['resources_updated']}), "
                f"prompts(+{result['prompts_added']}/-{result['prompts_removed']}/~{result['prompts_updated']})"
            )

            # Invalidate caches per-type based on actual changes
            cache = _get_registry_cache()
            if tools_changed > 0:
                await cache.invalidate_tools()
            if resources_changed > 0:
                await cache.invalidate_resources()
            if prompts_changed > 0:
                await cache.invalidate_prompts()

            # Invalidate tool lookup cache for this gateway
            tool_lookup_cache = _get_tool_lookup_cache()
            await tool_lookup_cache.invalidate_gateway(str(gateway_id))
        else:
            logger.debug(f"No changes detected during refresh of gateway {gateway_name}")

    # Advance poll schedule so hot/cold classification tracks the actual last refresh
    # regardless of whether the refresh was triggered by health check, manual API, or registration.
    # Use gateway_base_url (pre-auth) to match classification keys.
    if self._classification_service and gateway_base_url:
        try:
            await self._classification_service.mark_poll_completed(gateway_base_url, "tool_discovery", gateway_id=str(gateway_id))
        except Exception as poll_ts_err:
            logger.debug(f"Best-effort tool_discovery poll timestamp update failed: {poll_ts_err}")

    return result

5112 

5113 def _get_refresh_lock(self, gateway_id: str) -> asyncio.Lock: 

5114 """Get or create a per-gateway refresh lock. 

5115 

5116 This ensures only one refresh operation can run for a given gateway at a time. 

5117 

5118 Args: 

5119 gateway_id: ID of the gateway to get the lock for 

5120 

5121 Returns: 

5122 asyncio.Lock: The lock for the specified gateway 

5123 

5124 Examples: 

5125 >>> from mcpgateway.services.gateway_service import GatewayService 

5126 >>> service = GatewayService() 

5127 >>> lock1 = service._get_refresh_lock('gw-123') 

5128 >>> lock2 = service._get_refresh_lock('gw-123') 

5129 >>> lock1 is lock2 

5130 True 

5131 >>> lock3 = service._get_refresh_lock('gw-456') 

5132 >>> lock1 is lock3 

5133 False 

5134 """ 

5135 if gateway_id not in self._refresh_locks: 

5136 self._refresh_locks[gateway_id] = asyncio.Lock() 

5137 return self._refresh_locks[gateway_id] 

5138 

async def refresh_gateway_manually(
    self,
    gateway_id: str,
    include_resources: bool = True,
    include_prompts: bool = True,
    user_email: Optional[str] = None,
    request_headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    """Manually trigger a refresh of tools/resources/prompts for a gateway.

    This method provides a public API for triggering an immediate refresh
    of a gateway's tools, resources, and prompts from its MCP server.
    It includes concurrency control via per-gateway locking: a second
    request for the same gateway is rejected rather than queued.

    Args:
        gateway_id: Gateway ID to refresh
        include_resources: Whether to include resources in the refresh
        include_prompts: Whether to include prompts in the refresh
        user_email: Email of the user triggering the refresh
        request_headers: Optional request headers for passthrough authentication

    Returns:
        Dict with counts: {tools_added, tools_updated, tools_removed,
            resources_added, resources_updated, resources_removed,
            prompts_added, prompts_updated, prompts_removed,
            validation_errors, duration_ms, refreshed_at}

    Raises:
        GatewayNotFoundError: If the gateway does not exist
        GatewayError: If another refresh is already in progress for this gateway

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> from unittest.mock import patch, MagicMock, AsyncMock
        >>> import asyncio

        >>> # Test method is async
        >>> service = GatewayService()
        >>> import inspect
        >>> inspect.iscoroutinefunction(service.refresh_gateway_manually)
        True
    """
    start_time = time.monotonic()

    pre_auth_headers = {}

    # Check if gateway exists before acquiring lock
    with fresh_db_session() as db:
        gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
        if not gateway:
            raise GatewayNotFoundError(f"Gateway with ID '{gateway_id}' not found")
        gateway_name = gateway.name

        # Get passthrough headers if request headers provided
        # (done here because the lookup needs the open DB session)
        if request_headers:
            pre_auth_headers = get_passthrough_headers(request_headers, {}, db, gateway)

    lock = self._get_refresh_lock(gateway_id)

    # Check if lock is already held (concurrent refresh in progress)
    if lock.locked():
        raise GatewayError(f"Refresh already in progress for gateway {gateway_name}")

    # gateway_id is caller-supplied: sanitize it once and use the sanitized
    # form in every log line emitted by this method.
    safe_gateway_id = SecurityValidator.sanitize_log_message(gateway_id)

    async with lock:
        logger.info(f"Starting manual refresh for gateway {gateway_name} (ID: {safe_gateway_id})")

        result = await self._refresh_gateway_tools_resources_prompts(
            gateway_id=gateway_id,
            _user_email=user_email,
            created_via="manual_refresh",
            pre_auth_headers=pre_auth_headers,
            gateway=gateway,
            include_resources=include_resources,
            include_prompts=include_prompts,
        )
        # Note: last_refresh_at is updated inside _refresh_gateway_tools_resources_prompts on success

        # Duration covers the whole call, including the existence check above.
        result["duration_ms"] = (time.monotonic() - start_time) * 1000
        result["refreshed_at"] = datetime.now(timezone.utc)

        # A result without a "success" key is treated as successful.
        succeeded = result.get("success", True)
        log_level = logging.INFO if succeeded else logging.WARNING
        status_msg = "succeeded" if succeeded else f"failed: {result.get('error')}"

        logger.log(
            log_level,
            f"Manual refresh for gateway {safe_gateway_id} {status_msg}. Stats: "
            f"tools(+{result['tools_added']}/-{result['tools_removed']}), "
            f"resources(+{result['resources_added']}/-{result['resources_removed']}), "
            f"prompts(+{result['prompts_added']}/-{result['prompts_removed']}) "
            f"in {result['duration_ms']:.2f}ms",
        )

        return result

5232 

async def _publish_event(self, event: Dict[str, Any]) -> None:
    """Hand an event to the underlying event service for publication.

    Args:
        event: event dictionary

    Examples:
        >>> import asyncio
        >>> from unittest.mock import AsyncMock
        >>> service = GatewayService()
        >>> # Mock the underlying event service
        >>> service._event_service = AsyncMock()
        >>> test_event = {"type": "test", "data": {}}
        >>>
        >>> asyncio.run(service._publish_event(test_event))
        >>>
        >>> # Verify the event was passed to the event service
        >>> service._event_service.publish_event.assert_awaited_with(test_event)
    """
    # Pure delegation: the event is forwarded unchanged.
    event_service = self._event_service
    await event_service.publish_event(event)

5253 

def _validate_tools(self, tools: list[dict[str, Any]], context: str = "default") -> tuple[list[ToolCreate], list[str]]:
    """Check every tool definition on its own, collecting successes and failures.

    A tool that fails validation is logged and recorded but does not stop
    the remaining tools from being checked.

    Args:
        tools: list of tool dicts
        context: caller context, e.g. "oauth" to tailor errors/messages

    Returns:
        tuple[list[ToolCreate], list[str]]: Tuple of (valid tools, validation errors)

    Raises:
        OAuthToolValidationError: If all tools fail validation in OAuth context
        GatewayConnectionError: If all tools fail validation in default context
    """
    accepted: list[ToolCreate] = []
    problems: list[str] = []

    for index, raw_tool in enumerate(tools):
        # Tools without a name get a positional placeholder for log messages.
        label = raw_tool.get("name", f"unknown_tool_{index}")
        try:
            logger.debug(f"Validating tool: {label}")
            accepted.append(ToolCreate.model_validate(raw_tool))
            logger.debug(f"Tool '{label}' validated successfully")
        except ValidationError as exc:
            message = f"Validation failed for tool '{label}': {exc.errors()}"
            logger.error(message)
            logger.debug(f"Failed tool schema: {raw_tool}")
            problems.append(message)
        except ValueError as exc:
            # Depth-limit failures get a dedicated message plus a tuning hint.
            too_deep = "JSON structure exceeds maximum depth" in str(exc)
            if too_deep:
                message = f"Tool '{label}' schema too deeply nested. Current depth limit: {settings.validation_max_json_depth}"
            else:
                message = f"ValueError for tool '{label}': {str(exc)}"
            logger.error(message)
            if too_deep:
                logger.warning("Consider increasing VALIDATION_MAX_JSON_DEPTH environment variable")
            problems.append(message)
        except Exception as exc:  # pragma: no cover - defensive
            message = f"Unexpected error validating tool '{label}': {type(exc).__name__}: {str(exc)}"
            logger.error(message, exc_info=True)
            problems.append(message)

    if problems:
        logger.warning(f"Tool validation completed with {len(problems)} error(s). Successfully validated {len(accepted)} tool(s).")
        # Only the first few errors are echoed at debug level.
        for shown in problems[:3]:
            logger.debug(f"Validation error: {shown}")

    # Partial success (or an empty input) is returned to the caller as-is.
    if accepted or not problems:
        return accepted, problems

    # Nothing validated: fail with an error type matching the caller context.
    first_error = problems[0][:200]
    if context == "oauth":
        raise OAuthToolValidationError(f"OAuth tool fetch failed: all {len(tools)} tools failed validation. First error: {first_error}")
    raise GatewayConnectionError(f"Failed to fetch tools: All {len(tools)} tools failed validation. First error: {first_error}")

5308 

async def _connect_to_sse_server_without_validation(self, server_url: str, authentication: Optional[Dict[str, str]] = None):
    """Connect to an MCP server running with SSE transport, skipping URL validation.

    This is used for OAuth-protected servers where we've already validated the token works.

    Tools are always fetched. Resources, resource templates and prompts are
    fetched only when the server advertises the matching capability, and a
    failure while fetching any of those is logged as a warning instead of
    aborting the connection.

    Args:
        server_url: The URL of the SSE MCP server to connect to.
        authentication: Optional dictionary containing authentication headers.

    Returns:
        Tuple containing (capabilities, tools, resources, prompts) from the MCP server.

    Raises:
        GatewayConnectionError: If connecting, initializing the session or
            listing/validating tools raises; the message carries a sanitized
            URL and a sanitized error text.
    """
    if authentication is None:
        authentication = {}

    # Skip validation for OAuth servers - we already validated via OAuth flow
    # Use async with for both sse_client and ClientSession
    try:
        async with sse_client(url=server_url, headers=authentication) as streams:
            async with ClientSession(*streams) as session:
                # Initialize the session
                response = await session.initialize()
                capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
                logger.debug(f"Server capabilities: {capabilities}")

                response = await session.list_tools()
                tools = [tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in response.tools]

                tools, _ = self._validate_tools(tools, context="oauth")
                if tools:
                    logger.info(f"Fetched {len(tools)} tools from gateway")

                # Fetch resources if supported
                logger.debug(f"Checking for resources support: {capabilities.get('resources')}")
                resources = []
                if capabilities.get("resources"):
                    try:
                        response = await session.list_resources()
                        raw_resources = response.resources
                        for resource in raw_resources:
                            resource_data = resource.model_dump(by_alias=True, exclude_none=True)
                            # Convert AnyUrl to string if present
                            if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"):
                                resource_data["uri"] = str(resource_data["uri"])
                            # Add default content if not present (will be fetched on demand)
                            if "content" not in resource_data:
                                resource_data["content"] = ""
                            try:
                                resources.append(ResourceCreate.model_validate(resource_data))
                            except Exception:
                                # If validation fails, create minimal resource
                                resources.append(
                                    ResourceCreate(
                                        uri=str(resource_data.get("uri", "")),
                                        name=resource_data.get("name", ""),
                                        description=resource_data.get("description"),
                                        mime_type=resource_data.get("mimeType"),
                                        uri_template=resource_data.get("uriTemplate") or None,
                                        content="",
                                    )
                                )
                        logger.info(f"Fetched {len(resources)} resources from gateway")
                    except Exception as e:
                        logger.warning(f"Failed to fetch resources: {e}")

                    # Resource templates are stored alongside plain resources,
                    # using the template string as both uri and uri_template.
                    try:
                        response_templates = await session.list_resource_templates()
                        raw_resources_templates = response_templates.resourceTemplates
                        for resource_template in raw_resources_templates:
                            resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True)

                            if "uriTemplate" in resource_template_data:
                                resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"])
                                resource_template_data["uri"] = str(resource_template_data["uriTemplate"])

                            if "content" not in resource_template_data:
                                resource_template_data["content"] = ""

                            # Validate once per template; the count for the log
                            # line comes from the raw list.
                            resources.append(ResourceCreate.model_validate(resource_template_data))
                        logger.info(f"Fetched {len(raw_resources_templates)} resource templates from gateway")
                    except Exception as e:
                        logger.warning(f"Failed to fetch resource templates: {e}")

                # Fetch prompts if supported
                prompts = []
                logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}")
                if capabilities.get("prompts"):
                    try:
                        response = await session.list_prompts()
                        raw_prompts = response.prompts
                        for prompt in raw_prompts:
                            prompt_data = prompt.model_dump(by_alias=True, exclude_none=True)
                            # Add default template if not present
                            if "template" not in prompt_data:
                                prompt_data["template"] = ""
                            try:
                                prompts.append(PromptCreate.model_validate(prompt_data))
                            except Exception:
                                # If validation fails, create minimal prompt
                                prompts.append(
                                    PromptCreate(
                                        name=prompt_data.get("name", ""),
                                        description=prompt_data.get("description"),
                                        template=prompt_data.get("template", ""),
                                    )
                                )
                        logger.info(f"Fetched {len(prompts)} prompts from gateway")
                    except Exception as e:
                        logger.warning(f"Failed to fetch prompts: {e}")

                return capabilities, tools, resources, prompts
    except Exception as e:
        # Note: This function is for OAuth servers only, which don't use query param auth
        # Still sanitize in case exception contains URL with static sensitive params
        sanitized_url = sanitize_url_for_logging(server_url)
        sanitized_error = sanitize_exception_message(str(e))
        logger.error(f"SSE connection error details: {type(e).__name__}: {sanitized_error}", exc_info=True)
        raise GatewayConnectionError(f"Failed to connect to SSE server at {sanitized_url}: {sanitized_error}") from e

5431 

async def connect_to_sse_server(
    self,
    server_url: str,
    authentication: Optional[Dict[str, str]] = None,
    ca_certificate: Optional[bytes] = None,
    include_prompts: bool = True,
    include_resources: bool = True,
    auth_query_params: Optional[Dict[str, str]] = None,
    client_cert: Optional[str] = None,
    client_key: Optional[str] = None,
):
    """Connect to an MCP server running with SSE transport.

    Tools are always fetched and validated. Resources, resource templates
    and prompts are fetched only when requested and advertised by the
    server; a failure while fetching any of those is logged as a warning
    and leaves whatever was collected so far.

    Args:
        server_url: The URL of the SSE MCP server to connect to.
        authentication: Optional dictionary containing authentication headers.
        ca_certificate: Optional CA certificate for SSL verification.
        include_prompts: Whether to fetch prompts from the server.
        include_resources: Whether to fetch resources from the server.
        auth_query_params: Query param names for URL sanitization in error logs.
        client_cert: Optional client certificate path or PEM for mTLS.
        client_key: Optional client private key path or PEM for mTLS.

    Returns:
        Tuple containing (capabilities, tools, resources, prompts) from the MCP server.

    Raises:
        GatewayConnectionError: If every tool fails validation, or if the
            connection context exits without a result having been returned.
    """
    if authentication is None:
        authentication = {}

    def get_httpx_client_factory(
        headers: dict[str, str] | None = None,
        timeout: httpx.Timeout | None = None,
        auth: httpx.Auth | None = None,
    ) -> httpx.AsyncClient:
        """Factory function to create httpx.AsyncClient with optional CA certificate.

        Args:
            headers: Optional headers for the client
            timeout: Optional timeout for the client
            auth: Optional auth for the client

        Returns:
            httpx.AsyncClient: Configured HTTPX async client
        """
        # A custom SSL context is only built for non-http:// URLs that come
        # with a CA certificate; everything else uses the default verify.
        if server_url and server_url.lower().startswith("http://"):
            ctx = None
        elif ca_certificate:
            ctx = get_cached_ssl_context(ca_certificate, client_cert=client_cert, client_key=client_key)
        else:
            ctx = None
        return httpx.AsyncClient(
            verify=ctx if ctx else get_default_verify(),
            follow_redirects=True,
            headers=headers,
            timeout=timeout if timeout else get_http_timeout(),
            auth=auth,
            limits=httpx.Limits(
                max_connections=settings.httpx_max_connections,
                max_keepalive_connections=settings.httpx_max_keepalive_connections,
                keepalive_expiry=settings.httpx_keepalive_expiry,
            ),
        )

    # Use async with for both sse_client and ClientSession
    async with sse_client(url=server_url, headers=authentication, httpx_client_factory=get_httpx_client_factory) as streams:
        async with ClientSession(*streams) as session:
            # Initialize the session
            response = await session.initialize()

            capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
            logger.debug(f"Server capabilities: {capabilities}")

            response = await session.list_tools()
            tools = [tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in response.tools]

            tools, _ = self._validate_tools(tools)
            if tools:
                logger.info(f"Fetched {len(tools)} tools from gateway")

            # Fetch resources if supported
            resources = []
            if include_resources:
                logger.debug(f"Checking for resources support: {capabilities.get('resources')}")
                if capabilities.get("resources"):
                    try:
                        response = await session.list_resources()
                        raw_resources = response.resources
                        for resource in raw_resources:
                            resource_data = resource.model_dump(by_alias=True, exclude_none=True)
                            # Convert AnyUrl to string if present
                            if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"):
                                resource_data["uri"] = str(resource_data["uri"])
                            # Add default content if not present (will be fetched on demand)
                            if "content" not in resource_data:
                                resource_data["content"] = ""
                            try:
                                resources.append(ResourceCreate.model_validate(resource_data))
                            except Exception:
                                # If validation fails, create minimal resource
                                resources.append(
                                    ResourceCreate(
                                        uri=str(resource_data.get("uri", "")),
                                        name=resource_data.get("name", ""),
                                        description=resource_data.get("description"),
                                        mime_type=resource_data.get("mimeType"),
                                        uri_template=resource_data.get("uriTemplate") or None,
                                        content="",
                                    )
                                )
                        logger.info(f"Fetched {len(resources)} resources from gateway")
                    except Exception as e:
                        logger.warning(f"Failed to fetch resources: {e}")

                    # Resource templates are stored alongside plain resources,
                    # using the template string as both uri and uri_template.
                    try:
                        response_templates = await session.list_resource_templates()
                        raw_resources_templates = response_templates.resourceTemplates
                        for resource_template in raw_resources_templates:
                            resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True)

                            if "uriTemplate" in resource_template_data:
                                resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"])
                                resource_template_data["uri"] = str(resource_template_data["uriTemplate"])

                            if "content" not in resource_template_data:
                                resource_template_data["content"] = ""

                            # Validate once per template; the count for the log
                            # line comes from the raw list.
                            resources.append(ResourceCreate.model_validate(resource_template_data))
                        logger.info(f"Fetched {len(raw_resources_templates)} resource templates from gateway")
                    except Exception as e:
                        logger.warning(f"Failed to fetch resource templates: {e}")

            # Fetch prompts if supported
            prompts = []
            if include_prompts:
                logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}")
                if capabilities.get("prompts"):
                    try:
                        response = await session.list_prompts()
                        raw_prompts = response.prompts
                        for prompt in raw_prompts:
                            prompt_data = prompt.model_dump(by_alias=True, exclude_none=True)
                            # Add default template if not present
                            if "template" not in prompt_data:
                                prompt_data["template"] = ""
                            try:
                                prompts.append(PromptCreate.model_validate(prompt_data))
                            except Exception:
                                # If validation fails, create minimal prompt
                                prompts.append(
                                    PromptCreate(
                                        name=prompt_data.get("name", ""),
                                        description=prompt_data.get("description"),
                                        template=prompt_data.get("template", ""),
                                    )
                                )
                        logger.info(f"Fetched {len(prompts)} prompts from gateway")
                    except Exception as e:
                        logger.warning(f"Failed to fetch prompts: {e}")

            return capabilities, tools, resources, prompts

    # Only reached when the context managers above exit without the return
    # statement having run (i.e. an exception was swallowed during teardown).
    sanitized_url = sanitize_url_for_logging(server_url, auth_query_params)
    raise GatewayConnectionError(f"Failed to initialize gateway at {sanitized_url}: Connection could not be established")

5597 

async def connect_to_streamablehttp_server(
    self,
    server_url: str,
    authentication: Optional[Dict[str, str]] = None,
    ca_certificate: Optional[bytes] = None,
    include_prompts: bool = True,
    include_resources: bool = True,
    auth_query_params: Optional[Dict[str, str]] = None,
    client_cert: Optional[str] = None,
    client_key: Optional[str] = None,
):
    """Connect to an MCP server running with Streamable HTTP transport.

    Tools are always fetched and validated, and each valid tool is tagged
    with request type "STREAMABLEHTTP". Resources, resource templates and
    prompts are fetched only when requested and advertised by the server; a
    failure while fetching any of those is logged as a warning and leaves
    whatever was collected so far.

    Args:
        server_url: The URL of the Streamable HTTP MCP server to connect to.
        authentication: Optional dictionary containing authentication headers.
        ca_certificate: Optional CA certificate for SSL verification.
        include_prompts: Whether to fetch prompts from the server.
        include_resources: Whether to fetch resources from the server.
        auth_query_params: Query param names for URL sanitization in error logs.
        client_cert: Optional client certificate path or PEM for mTLS.
        client_key: Optional client private key path or PEM for mTLS.

    Returns:
        Tuple containing (capabilities, tools, resources, prompts) from the MCP server.

    Raises:
        GatewayConnectionError: If every tool fails validation, or if the
            connection context exits without a result having been returned.
    """
    if authentication is None:
        authentication = {}

    # The transport builds its HTTP client through this factory so that the
    # gateway's TLS, timeout and connection-limit settings are applied.
    def get_httpx_client_factory(
        headers: dict[str, str] | None = None,
        timeout: httpx.Timeout | None = None,
        auth: httpx.Auth | None = None,
    ) -> httpx.AsyncClient:
        """Factory function to create httpx.AsyncClient with optional CA certificate.

        Args:
            headers: Optional headers for the client
            timeout: Optional timeout for the client
            auth: Optional auth for the client

        Returns:
            httpx.AsyncClient: Configured HTTPX async client
        """
        # A custom SSL context is only built for non-http:// URLs that come
        # with a CA certificate; everything else uses the default verify.
        if server_url and server_url.lower().startswith("http://"):
            ctx = None
        elif ca_certificate:
            ctx = get_cached_ssl_context(ca_certificate, client_cert=client_cert, client_key=client_key)
        else:
            ctx = None
        return httpx.AsyncClient(
            verify=ctx if ctx else get_default_verify(),
            follow_redirects=True,
            headers=headers,
            timeout=timeout if timeout else get_http_timeout(),
            auth=auth,
            limits=httpx.Limits(
                max_connections=settings.httpx_max_connections,
                max_keepalive_connections=settings.httpx_max_keepalive_connections,
                keepalive_expiry=settings.httpx_keepalive_expiry,
            ),
        )

    async with streamablehttp_client(url=server_url, headers=authentication, httpx_client_factory=get_httpx_client_factory) as (read_stream, write_stream, _get_session_id):
        async with ClientSession(read_stream, write_stream) as session:
            # Initialize the session
            response = await session.initialize()
            capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
            logger.debug(f"Server capabilities: {capabilities}")

            response = await session.list_tools()
            tools = [tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in response.tools]

            tools, _ = self._validate_tools(tools)
            for tool in tools:
                tool.request_type = "STREAMABLEHTTP"
            if tools:
                logger.info(f"Fetched {len(tools)} tools from gateway")

            # Fetch resources if supported
            resources = []
            if include_resources:
                logger.debug(f"Checking for resources support: {capabilities.get('resources')}")
                if capabilities.get("resources"):
                    try:
                        response = await session.list_resources()
                        raw_resources = response.resources
                        for resource in raw_resources:
                            resource_data = resource.model_dump(by_alias=True, exclude_none=True)
                            # Convert AnyUrl to string if present
                            if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"):
                                resource_data["uri"] = str(resource_data["uri"])
                            # Add default content if not present
                            if "content" not in resource_data:
                                resource_data["content"] = ""
                            try:
                                resources.append(ResourceCreate.model_validate(resource_data))
                            except Exception:
                                # If validation fails, create minimal resource
                                resources.append(
                                    ResourceCreate(
                                        uri=str(resource_data.get("uri", "")),
                                        name=resource_data.get("name", ""),
                                        description=resource_data.get("description"),
                                        mime_type=resource_data.get("mimeType"),
                                        uri_template=resource_data.get("uriTemplate") or None,
                                        content="",
                                    )
                                )
                        logger.info(f"Fetched {len(resources)} resources from gateway")
                    except Exception as e:
                        logger.warning(f"Failed to fetch resources: {e}")

                    # Resource templates are stored alongside plain resources,
                    # using the template string as both uri and uri_template.
                    try:
                        response_templates = await session.list_resource_templates()
                        raw_resources_templates = response_templates.resourceTemplates
                        for resource_template in raw_resources_templates:
                            resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True)

                            if "uriTemplate" in resource_template_data:
                                resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"])
                                resource_template_data["uri"] = str(resource_template_data["uriTemplate"])

                            if "content" not in resource_template_data:
                                resource_template_data["content"] = ""

                            # Validate once per template; the count for the log
                            # line comes from the raw list.
                            resources.append(ResourceCreate.model_validate(resource_template_data))
                        logger.info(f"Fetched {len(raw_resources_templates)} resource templates from gateway")
                    except Exception as e:
                        logger.warning(f"Failed to fetch resource templates: {e}")

            # Fetch prompts if supported
            prompts = []
            if include_prompts:
                logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}")
                if capabilities.get("prompts"):
                    try:
                        response = await session.list_prompts()
                        raw_prompts = response.prompts
                        for prompt in raw_prompts:
                            prompt_data = prompt.model_dump(by_alias=True, exclude_none=True)
                            # Add default template if not present
                            if "template" not in prompt_data:
                                prompt_data["template"] = ""
                            try:
                                prompts.append(PromptCreate.model_validate(prompt_data))
                            except Exception:
                                # If validation fails, create minimal prompt so one
                                # bad entry does not discard the prompts after it
                                # (same fallback as the SSE connection paths).
                                prompts.append(
                                    PromptCreate(
                                        name=prompt_data.get("name", ""),
                                        description=prompt_data.get("description"),
                                        template=prompt_data.get("template", ""),
                                    )
                                )
                        logger.info(f"Fetched {len(prompts)} prompts from gateway")
                    except Exception as e:
                        logger.warning(f"Failed to fetch prompts: {e}")

            return capabilities, tools, resources, prompts

    # Only reached when the context managers above exit without the return
    # statement having run (i.e. an exception was swallowed during teardown).
    sanitized_url = sanitize_url_for_logging(server_url, auth_query_params)
    raise GatewayConnectionError(f"Failed to initialize gateway at {sanitized_url}: Connection could not be established")

5755 

5756 

# The shared GatewayService is built lazily, on first attribute access,
# instead of at import time. Importing only the exception classes from this
# module therefore never instantiates the service.
_gateway_service_instance = None  # pylint: disable=invalid-name


def __getattr__(name: str):
    """Resolve module attributes that are not defined statically.

    Only ``gateway_service`` is supported: the first lookup builds the
    GatewayService singleton and every later lookup returns that same object.

    Args:
        name: The attribute name being accessed.

    Returns:
        The gateway_service singleton instance if name is "gateway_service".

    Raises:
        AttributeError: If the attribute name is not "gateway_service".
    """
    global _gateway_service_instance  # pylint: disable=global-statement
    if name != "gateway_service":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    if _gateway_service_instance is None:
        _gateway_service_instance = GatewayService()
    return _gateway_service_instance