Coverage for mcpgateway / services / gateway_service.py: 93%

2074 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2# pylint: disable=import-outside-toplevel,no-name-in-module 

3"""Location: ./mcpgateway/services/gateway_service.py 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8Gateway Service Implementation. 

9This module implements gateway federation according to the MCP specification. 

10It handles: 

11- Gateway discovery and registration 

12- Capability aggregation 

13- Health monitoring 

14- Active/inactive gateway management 

15 

16Examples: 

17 >>> from mcpgateway.services.gateway_service import GatewayService, GatewayError 

18 >>> service = GatewayService() 

19 >>> isinstance(service, GatewayService) 

20 True 

21 >>> hasattr(service, '_active_gateways') 

22 True 

23 >>> isinstance(service._active_gateways, set) 

24 True 

25 

26 Test error classes: 

27 >>> error = GatewayError("Test error") 

28 >>> str(error) 

29 'Test error' 

30 >>> isinstance(error, Exception) 

31 True 

32 

33 >>> conflict_error = GatewayNameConflictError("test_gw") 

34 >>> "test_gw" in str(conflict_error) 

35 True 

36 >>> conflict_error.enabled 

37 True 

38 >>> 

39 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs 

40 >>> import asyncio 

41 >>> asyncio.run(service._http_client.aclose()) 

42""" 

43 

44# Standard 

45import asyncio 

46import binascii 

47from datetime import datetime, timezone 

48import logging 

49import mimetypes 

50import os 

51import ssl 

52import tempfile 

53import time 

54from typing import Any, AsyncGenerator, cast, Dict, List, Optional, Set, TYPE_CHECKING, Union 

55from urllib.parse import urlparse, urlunparse 

56import uuid 

57 

58# Third-Party 

59from filelock import FileLock, Timeout 

60import httpx 

61from mcp import ClientSession 

62from mcp.client.sse import sse_client 

63from mcp.client.streamable_http import streamablehttp_client 

64from pydantic import ValidationError 

65from sqlalchemy import and_, delete, desc, or_, select, update 

66from sqlalchemy.exc import IntegrityError 

67from sqlalchemy.orm import joinedload, selectinload, Session 

68 

69try: 

70 # Third-Party - check if redis is available 

71 # Third-Party 

72 import redis.asyncio as _aioredis # noqa: F401 # pylint: disable=unused-import 

73 

74 REDIS_AVAILABLE = True 

75 del _aioredis # Only needed for availability check 

76except ImportError: 

77 REDIS_AVAILABLE = False 

78 logging.info("Redis is not utilized in this environment.") 

79 

80# First-Party 

81from mcpgateway.config import settings 

82from mcpgateway.db import fresh_db_session 

83from mcpgateway.db import Gateway as DbGateway 

84from mcpgateway.db import get_for_update 

85from mcpgateway.db import Prompt as DbPrompt 

86from mcpgateway.db import PromptMetric 

87from mcpgateway.db import Resource as DbResource 

88from mcpgateway.db import ResourceMetric, ResourceSubscription, server_prompt_association, server_resource_association, server_tool_association, SessionLocal 

89from mcpgateway.db import Tool as DbTool 

90from mcpgateway.db import ToolMetric 

91from mcpgateway.observability import create_span 

92from mcpgateway.schemas import GatewayCreate, GatewayRead, GatewayUpdate, PromptCreate, ResourceCreate, ToolCreate 

93 

94# logging.getLogger("httpx").setLevel(logging.WARNING) # Disables httpx logs for regular health checks 

95from mcpgateway.services.audit_trail_service import get_audit_trail_service 

96from mcpgateway.services.base_service import BaseService 

97from mcpgateway.services.encryption_service import protect_oauth_config_for_storage 

98from mcpgateway.services.event_service import EventService 

99from mcpgateway.services.http_client_service import get_default_verify, get_http_timeout, get_isolated_http_client 

100from mcpgateway.services.logging_service import LoggingService 

101from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, register_gateway_capabilities_for_notifications, TransportType 

102from mcpgateway.services.oauth_manager import OAuthManager 

103from mcpgateway.services.structured_logger import get_structured_logger 

104from mcpgateway.services.team_management_service import TeamManagementService 

105from mcpgateway.utils.create_slug import slugify 

106from mcpgateway.utils.display_name import generate_display_name 

107from mcpgateway.utils.pagination import unified_paginate 

108from mcpgateway.utils.passthrough_headers import get_passthrough_headers 

109from mcpgateway.utils.redis_client import get_redis_client 

110from mcpgateway.utils.retry_manager import ResilientHttpClient 

111from mcpgateway.utils.services_auth import decode_auth, encode_auth 

112from mcpgateway.utils.sqlalchemy_modifier import json_contains_tag_expr 

113from mcpgateway.utils.ssl_context_cache import get_cached_ssl_context 

114from mcpgateway.utils.url_auth import apply_query_param_auth, sanitize_exception_message, sanitize_url_for_logging 

115from mcpgateway.utils.validate_signature import validate_signature 

116from mcpgateway.validation.tags import validate_tags_field 

117 

118# Cache import (lazy to avoid circular dependencies) 

119_REGISTRY_CACHE = None 

120_TOOL_LOOKUP_CACHE = None 

121 

122 

123def _get_registry_cache(): 

124 """Get registry cache singleton lazily. 

125 

126 Returns: 

127 RegistryCache instance. 

128 """ 

129 global _REGISTRY_CACHE # pylint: disable=global-statement 

130 if _REGISTRY_CACHE is None: 

131 # First-Party 

132 from mcpgateway.cache.registry_cache import registry_cache # pylint: disable=import-outside-toplevel 

133 

134 _REGISTRY_CACHE = registry_cache 

135 return _REGISTRY_CACHE 

136 

137 

138def _get_tool_lookup_cache(): 

139 """Get tool lookup cache singleton lazily. 

140 

141 Returns: 

142 ToolLookupCache instance. 

143 """ 

144 global _TOOL_LOOKUP_CACHE # pylint: disable=global-statement 

145 if _TOOL_LOOKUP_CACHE is None: 

146 # First-Party 

147 from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache # pylint: disable=import-outside-toplevel 

148 

149 _TOOL_LOOKUP_CACHE = tool_lookup_cache 

150 return _TOOL_LOOKUP_CACHE 

151 

152 

153# Initialize logging service first 

154logging_service = LoggingService() 

155logger = logging_service.get_logger(__name__) 

156 

157# Initialize structured logger and audit trail for gateway operations 

158structured_logger = get_structured_logger("gateway_service") 

159audit_trail = get_audit_trail_service() 

160 

161 

162GW_FAILURE_THRESHOLD = settings.unhealthy_threshold 

163GW_HEALTH_CHECK_INTERVAL = settings.health_check_interval 

164 

165 

166class GatewayError(Exception): 

167 """Base class for gateway-related errors. 

168 

169 Examples: 

170 >>> error = GatewayError("Test error") 

171 >>> str(error) 

172 'Test error' 

173 >>> isinstance(error, Exception) 

174 True 

175 """ 

176 

177 

178class GatewayNotFoundError(GatewayError): 

179 """Raised when a requested gateway is not found. 

180 

181 Examples: 

182 >>> error = GatewayNotFoundError("Gateway not found") 

183 >>> str(error) 

184 'Gateway not found' 

185 >>> isinstance(error, GatewayError) 

186 True 

187 """ 

188 

189 

190class GatewayNameConflictError(GatewayError): 

191 """Raised when a gateway name conflicts with existing (active or inactive) gateway. 

192 

193 Args: 

194 name: The conflicting gateway name 

195 enabled: Whether the existing gateway is enabled 

196 gateway_id: ID of the existing gateway if available 

197 visibility: The visibility of the gateway ("public" or "team"). 

198 

199 Examples: 

200 >>> error = GatewayNameConflictError("test_gateway") 

201 >>> str(error) 

202 'Public Gateway already exists with name: test_gateway' 

203 >>> error.name 

204 'test_gateway' 

205 >>> error.enabled 

206 True 

207 >>> error.gateway_id is None 

208 True 

209 

210 >>> error_inactive = GatewayNameConflictError("inactive_gw", enabled=False, gateway_id=123) 

211 >>> str(error_inactive) 

212 'Public Gateway already exists with name: inactive_gw (currently inactive, ID: 123)' 

213 >>> error_inactive.enabled 

214 False 

215 >>> error_inactive.gateway_id 

216 123 

217 """ 

218 

219 def __init__(self, name: str, enabled: bool = True, gateway_id: Optional[int] = None, visibility: Optional[str] = "public"): 

220 """Initialize the error with gateway information. 

221 

222 Args: 

223 name: The conflicting gateway name 

224 enabled: Whether the existing gateway is enabled 

225 gateway_id: ID of the existing gateway if available 

226 visibility: The visibility of the gateway ("public" or "team"). 

227 """ 

228 self.name = name 

229 self.enabled = enabled 

230 self.gateway_id = gateway_id 

231 if visibility == "team": 

232 vis_label = "Team-level" 

233 else: 

234 vis_label = "Public" 

235 message = f"{vis_label} Gateway already exists with name: {name}" 

236 if not enabled: 

237 message += f" (currently inactive, ID: {gateway_id})" 

238 super().__init__(message) 

239 

240 

241class GatewayDuplicateConflictError(GatewayError): 

242 """Raised when a gateway conflicts with an existing gateway (same URL + credentials). 

243 

244 This error is raised when attempting to register a gateway with a URL and 

245 authentication credentials that already exist within the same scope: 

246 - Public: Global uniqueness required across all public gateways. 

247 - Team: Uniqueness required within the same team. 

248 - Private: Uniqueness required for the same user, a user cannot have two private gateways with the same URL and credentials. 

249 

250 Args: 

251 duplicate_gateway: The existing conflicting gateway (DbGateway instance). 

252 

253 Examples: 

254 >>> # Public gateway conflict with the same URL and basic auth 

255 >>> existing_gw = DbGateway(url="https://api.example.com", id="abc-123", enabled=True, visibility="public", team_id=None, name="API Gateway", owner_email="alice@example.com") 

256 >>> error = GatewayDuplicateConflictError( 

257 ... duplicate_gateway=existing_gw 

258 ... ) 

259 >>> str(error) 

260 'The Server already exists in Public scope (Name: API Gateway, Status: active)' 

261 

262 >>> # Team gateway conflict with the same URL and OAuth credentials 

263 >>> team_gw = DbGateway(url="https://api.example.com", id="def-456", enabled=False, visibility="team", team_id="engineering-team", name="API Gateway", owner_email="bob@example.com") 

264 >>> error = GatewayDuplicateConflictError( 

265 ... duplicate_gateway=team_gw 

266 ... ) 

267 >>> str(error) 

268 'The Server already exists in your Team (Name: API Gateway, Status: inactive). You may want to re-enable the existing gateway instead.' 

269 

270 >>> # Private gateway conflict (same user cannot have two gateways with the same URL) 

271 >>> private_gw = DbGateway(url="https://api.example.com", id="ghi-789", enabled=True, visibility="private", team_id="none", name="API Gateway", owner_email="charlie@example.com") 

272 >>> error = GatewayDuplicateConflictError( 

273 ... duplicate_gateway=private_gw 

274 ... ) 

275 >>> str(error) 

276 'The Server already exists in "private" scope (Name: API Gateway, Status: active)' 

277 """ 

278 

279 def __init__( 

280 self, 

281 duplicate_gateway: "DbGateway", 

282 ): 

283 """Initialize the error with gateway information. 

284 

285 Args: 

286 duplicate_gateway: The existing conflicting gateway (DbGateway instance) 

287 """ 

288 self.duplicate_gateway = duplicate_gateway 

289 self.url = duplicate_gateway.url 

290 self.gateway_id = duplicate_gateway.id 

291 self.enabled = duplicate_gateway.enabled 

292 self.visibility = duplicate_gateway.visibility 

293 self.team_id = duplicate_gateway.team_id 

294 self.name = duplicate_gateway.name 

295 

296 # Build scope description 

297 if self.visibility == "public": 

298 scope_desc = "Public scope" 

299 elif self.visibility == "team" and self.team_id: 

300 scope_desc = "your Team" 

301 else: 

302 scope_desc = f'"{self.visibility}" scope' 

303 

304 # Build status description 

305 status = "active" if self.enabled else "inactive" 

306 

307 # Construct error message 

308 message = f"The Server already exists in {scope_desc} " f"(Name: {self.name}, Status: {status})" 

309 

310 # Add helpful hint for inactive gateways 

311 if not self.enabled: 

312 message += ". You may want to re-enable the existing gateway instead." 

313 

314 super().__init__(message) 

315 

316 

317class GatewayConnectionError(GatewayError): 

318 """Raised when gateway connection fails. 

319 

320 Examples: 

321 >>> error = GatewayConnectionError("Connection failed") 

322 >>> str(error) 

323 'Connection failed' 

324 >>> isinstance(error, GatewayError) 

325 True 

326 """ 

327 

328 

329class OAuthToolValidationError(GatewayConnectionError): 

330 """Raised when tool validation fails during OAuth-driven fetch.""" 

331 

332 

333class GatewayService(BaseService): # pylint: disable=too-many-instance-attributes 

334 """Service for managing federated gateways. 

335 

336 Handles: 

337 - Gateway registration and health checks 

338 - Capability negotiation 

339 - Federation events 

340 - Active/inactive status management 

341 """ 

342 

343 _visibility_model_cls = DbGateway 

344 

345 def __init__(self) -> None: 

346 """Initialize the gateway service. 

347 

348 Examples: 

349 >>> from mcpgateway.services.gateway_service import GatewayService 

350 >>> from mcpgateway.services.event_service import EventService 

351 >>> from mcpgateway.utils.retry_manager import ResilientHttpClient 

352 >>> from mcpgateway.services.tool_service import ToolService 

353 >>> service = GatewayService() 

354 >>> isinstance(service._event_service, EventService) 

355 True 

356 >>> isinstance(service._http_client, ResilientHttpClient) 

357 True 

358 >>> service._health_check_interval == GW_HEALTH_CHECK_INTERVAL 

359 True 

360 >>> service._health_check_task is None 

361 True 

362 >>> isinstance(service._active_gateways, set) 

363 True 

364 >>> len(service._active_gateways) 

365 0 

366 >>> service._stream_response is None 

367 True 

368 >>> isinstance(service._pending_responses, dict) 

369 True 

370 >>> len(service._pending_responses) 

371 0 

372 >>> isinstance(service.tool_service, ToolService) 

373 True 

374 >>> isinstance(service._gateway_failure_counts, dict) 

375 True 

376 >>> len(service._gateway_failure_counts) 

377 0 

378 >>> hasattr(service, 'redis_url') 

379 True 

380 >>> 

381 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs 

382 >>> import asyncio 

383 >>> asyncio.run(service._http_client.aclose()) 

384 """ 

385 self._http_client = ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify}) 

386 self._health_check_interval = GW_HEALTH_CHECK_INTERVAL 

387 self._health_check_task: Optional[asyncio.Task] = None 

388 self._active_gateways: Set[str] = set() # Track active gateway URLs 

389 self._stream_response = None 

390 self._pending_responses = {} 

391 # Prefer using the globally-initialized singletons from the service modules 

392 # so events propagate via their initialized EventService/Redis clients. 

393 # Import lazily and fall back to creating local instances when the module-level 

394 # __getattr__ singletons are not yet available (e.g. circular import during 

395 # Gunicorn --preload). 

396 # First-Party 

397 try: 

398 # First-Party 

399 from mcpgateway.services.prompt_service import prompt_service 

400 except ImportError: 

401 # First-Party 

402 from mcpgateway.services.prompt_service import PromptService 

403 

404 prompt_service = PromptService() 

405 try: 

406 # First-Party 

407 from mcpgateway.services.resource_service import resource_service 

408 except ImportError: 

409 # First-Party 

410 from mcpgateway.services.resource_service import ResourceService 

411 

412 resource_service = ResourceService() 

413 try: 

414 # First-Party 

415 from mcpgateway.services.tool_service import tool_service 

416 except ImportError: 

417 # First-Party 

418 from mcpgateway.services.tool_service import ToolService 

419 

420 tool_service = ToolService() 

421 

422 self.tool_service = tool_service 

423 self.prompt_service = prompt_service 

424 self.resource_service = resource_service 

425 self._gateway_failure_counts: dict[str, int] = {} 

426 self.oauth_manager = OAuthManager(request_timeout=int(os.getenv("OAUTH_REQUEST_TIMEOUT", "30")), max_retries=int(os.getenv("OAUTH_MAX_RETRIES", "3"))) 

427 self._event_service = EventService(channel_name="mcpgateway:gateway_events") 

428 

429 # Per-gateway refresh locks to prevent concurrent refreshes for the same gateway 

430 self._refresh_locks: Dict[str, asyncio.Lock] = {} 

431 

432 # For health checks, we determine the leader instance. 

433 self.redis_url = settings.redis_url if settings.cache_type == "redis" else None 

434 

435 # Initialize optional Redis client holder (set in initialize()) 

436 self._redis_client: Optional[Any] = None 

437 

438 # Leader election settings from config 

439 if self.redis_url and REDIS_AVAILABLE: 

440 self._instance_id = str(uuid.uuid4()) # Unique ID for this process 

441 self._leader_key = settings.redis_leader_key 

442 self._leader_ttl = settings.redis_leader_ttl 

443 self._leader_heartbeat_interval = settings.redis_leader_heartbeat_interval 

444 self._leader_heartbeat_task: Optional[asyncio.Task] = None 

445 

446 # Always initialize file lock as fallback (used if Redis connection fails at runtime) 

447 if settings.cache_type != "none": 

448 temp_dir = tempfile.gettempdir() 

449 user_path = os.path.normpath(settings.filelock_name) 

450 if os.path.isabs(user_path): 

451 user_path = os.path.relpath(user_path, start=os.path.splitdrive(user_path)[0] + os.sep) 

452 full_path = os.path.join(temp_dir, user_path) 

453 self._lock_path = full_path.replace("\\", "/") 

454 self._file_lock = FileLock(self._lock_path) 

455 

456 @staticmethod 

457 def normalize_url(url: str) -> str: 

458 """ 

459 Normalize a URL by ensuring it's properly formatted. 

460 

461 Special handling for localhost to prevent duplicates: 

462 - Converts 127.0.0.1 to localhost for consistency 

463 - Preserves all other domain names as-is for CDN/load balancer support 

464 

465 Args: 

466 url (str): The URL to normalize. 

467 

468 Returns: 

469 str: The normalized URL. 

470 

471 Examples: 

472 >>> GatewayService.normalize_url('http://localhost:8080/path') 

473 'http://localhost:8080/path' 

474 >>> GatewayService.normalize_url('http://127.0.0.1:8080/path') 

475 'http://localhost:8080/path' 

476 >>> GatewayService.normalize_url('https://example.com/api') 

477 'https://example.com/api' 

478 """ 

479 parsed = urlparse(url) 

480 hostname = parsed.hostname 

481 

482 # Special case: normalize 127.0.0.1 to localhost to prevent duplicates 

483 # but preserve all other domains as-is for CDN/load balancer support 

484 if hostname == "127.0.0.1": 

485 netloc = "localhost" 

486 if parsed.port: 

487 netloc += f":{parsed.port}" 

488 normalized = parsed._replace(netloc=netloc) 

489 return str(urlunparse(normalized)) 

490 

491 # For all other URLs, preserve the domain name 

492 return url 

493 

494 def create_ssl_context(self, ca_certificate: str) -> ssl.SSLContext: 

495 """Create an SSL context with the provided CA certificate. 

496 

497 Uses caching to avoid repeated SSL context creation for the same certificate. 

498 

499 Args: 

500 ca_certificate: CA certificate in PEM format 

501 

502 Returns: 

503 ssl.SSLContext: Configured SSL context 

504 """ 

505 return get_cached_ssl_context(ca_certificate) 

506 

507 async def initialize(self) -> None: 

508 """Initialize the service and start health check if this instance is the leader. 

509 

510 Raises: 

511 ConnectionError: When redis ping fails 

512 """ 

513 logger.info("Initializing gateway service") 

514 

515 # Initialize event service with shared Redis client 

516 await self._event_service.initialize() 

517 

518 # NOTE: We intentionally do NOT create a long-lived DB session here. 

519 # Health checks use fresh_db_session() only when DB access is actually needed, 

520 # avoiding holding connections during HTTP calls to MCP servers. 

521 

522 user_email = settings.platform_admin_email 

523 

524 # Get shared Redis client from factory 

525 if self.redis_url and REDIS_AVAILABLE: 

526 self._redis_client = await get_redis_client() 

527 

528 if self._redis_client: 

529 # Check if Redis is available (ping already done by factory, but verify) 

530 try: 

531 await self._redis_client.ping() 

532 except Exception as e: 

533 raise ConnectionError(f"Redis ping failed: {e}") from e 

534 

535 is_leader = await self._redis_client.set(self._leader_key, self._instance_id, ex=self._leader_ttl, nx=True) 

536 if is_leader: 

537 logger.info("Acquired Redis leadership. Starting health check and heartbeat tasks.") 

538 self._health_check_task = asyncio.create_task(self._run_health_checks(user_email)) 

539 self._leader_heartbeat_task = asyncio.create_task(self._run_leader_heartbeat()) 

540 else: 

541 # Always create the health check task in filelock mode; leader check is handled inside. 

542 self._health_check_task = asyncio.create_task(self._run_health_checks(user_email)) 

543 

544 async def shutdown(self) -> None: 

545 """Shutdown the service. 

546 

547 Examples: 

548 >>> service = GatewayService() 

549 >>> # Mock internal components 

550 >>> from unittest.mock import AsyncMock 

551 >>> service._event_service = AsyncMock() 

552 >>> service._active_gateways = {'test_gw'} 

553 >>> import asyncio 

554 >>> asyncio.run(service.shutdown()) 

555 >>> # Verify event service shutdown was called 

556 >>> service._event_service.shutdown.assert_awaited_once() 

557 >>> len(service._active_gateways) 

558 0 

559 """ 

560 if self._health_check_task: 

561 self._health_check_task.cancel() 

562 try: 

563 await self._health_check_task 

564 except asyncio.CancelledError: 

565 pass 

566 

567 # Cancel leader heartbeat task if running 

568 if getattr(self, "_leader_heartbeat_task", None): 

569 self._leader_heartbeat_task.cancel() 

570 try: 

571 await self._leader_heartbeat_task 

572 except asyncio.CancelledError: 

573 pass 

574 

575 # Release Redis leadership atomically if we hold it 

576 if self._redis_client: 

577 try: 

578 # Lua script for atomic check-and-delete (only delete if we own the key) 

579 release_script = """ 

580 if redis.call("get", KEYS[1]) == ARGV[1] then 

581 return redis.call("del", KEYS[1]) 

582 else 

583 return 0 

584 end 

585 """ 

586 result = await self._redis_client.eval(release_script, 1, self._leader_key, self._instance_id) 

587 if result: 

588 logger.info("Released Redis leadership on shutdown") 

589 except Exception as e: 

590 logger.warning(f"Failed to release Redis leader key on shutdown: {e}") 

591 

592 await self._http_client.aclose() 

593 await self._event_service.shutdown() 

594 self._active_gateways.clear() 

595 logger.info("Gateway service shutdown complete") 

596 

597 def _check_gateway_uniqueness( 

598 self, 

599 db: Session, 

600 url: str, 

601 auth_value: Optional[Dict[str, str]], 

602 oauth_config: Optional[Dict[str, Any]], 

603 team_id: Optional[str], 

604 owner_email: str, 

605 visibility: str, 

606 gateway_id: Optional[str] = None, 

607 ) -> Optional[DbGateway]: 

608 """ 

609 Check if a gateway with the same URL and credentials already exists. 

610 

611 Args: 

612 db: Database session 

613 url: Gateway URL (normalized) 

614 auth_value: Decoded auth_value dict (not encrypted) 

615 oauth_config: OAuth configuration dict 

616 team_id: Team ID for team-scoped gateways 

617 owner_email: Email of the gateway owner 

618 visibility: Gateway visibility (public/team/private) 

619 gateway_id: Optional gateway ID to exclude from check (for updates) 

620 

621 Returns: 

622 DbGateway if duplicate found, None otherwise 

623 """ 

624 # Build base query based on visibility 

625 if visibility == "public": 

626 query = db.query(DbGateway).filter(DbGateway.url == url, DbGateway.visibility == "public") 

627 elif visibility == "team" and team_id: 

628 query = db.query(DbGateway).filter(DbGateway.url == url, DbGateway.visibility == "team", DbGateway.team_id == team_id) 

629 elif visibility == "private": 

630 # Check for duplicates within the same user's private gateways 

631 query = db.query(DbGateway).filter(DbGateway.url == url, DbGateway.visibility == "private", DbGateway.owner_email == owner_email) # Scoped to same user 

632 else: 

633 return None 

634 

635 # Exclude current gateway if updating 

636 if gateway_id: 

637 query = query.filter(DbGateway.id != gateway_id) 

638 

639 existing_gateways = query.all() 

640 

641 # Check each existing gateway 

642 for existing in existing_gateways: 

643 # Case 1: Both have OAuth config 

644 if oauth_config and existing.oauth_config: 

645 # Compare OAuth configs (exclude dynamic fields like tokens) 

646 existing_oauth = existing.oauth_config or {} 

647 new_oauth = oauth_config or {} 

648 

649 # Compare key OAuth fields 

650 oauth_keys = ["grant_type", "client_id", "authorization_url", "token_url", "scope"] 

651 if all(existing_oauth.get(k) == new_oauth.get(k) for k in oauth_keys): 

652 return existing # Duplicate OAuth config found 

653 

654 # Case 2: Both have auth_value (need to decrypt and compare) 

655 elif auth_value and existing.auth_value: 

656 

657 try: 

658 # Decrypt existing auth_value 

659 if isinstance(existing.auth_value, str): 

660 existing_decoded = decode_auth(existing.auth_value) 

661 

662 elif isinstance(existing.auth_value, dict): 

663 existing_decoded = existing.auth_value 

664 

665 else: 

666 continue 

667 

668 # Compare decoded auth values 

669 if auth_value == existing_decoded: 

670 return existing # Duplicate credentials found 

671 except Exception as e: 

672 logger.warning(f"Failed to decode auth_value for comparison: {e}") 

673 continue 

674 

675 # Case 3: Both have no auth (URL only, not allowed) 

676 elif not auth_value and not oauth_config and not existing.auth_value and not existing.oauth_config: 

677 return existing # Duplicate URL without credentials 

678 

679 return None # No duplicate found 

680 

681 async def register_gateway( 

682 self, 

683 db: Session, 

684 gateway: GatewayCreate, 

685 created_by: Optional[str] = None, 

686 created_from_ip: Optional[str] = None, 

687 created_via: Optional[str] = None, 

688 created_user_agent: Optional[str] = None, 

689 team_id: Optional[str] = None, 

690 owner_email: Optional[str] = None, 

691 visibility: Optional[str] = None, 

692 initialize_timeout: Optional[float] = None, 

693 ) -> GatewayRead: 

694 """Register a new gateway. 

695 

696 Args: 

697 db: Database session 

698 gateway: Gateway creation schema 

699 created_by: Username who created this gateway 

700 created_from_ip: IP address of creator 

701 created_via: Creation method (ui, api, federation) 

702 created_user_agent: User agent of creation request 

703 team_id (Optional[str]): Team ID to assign the gateway to. 

704 owner_email (Optional[str]): Email of the user who owns this gateway. 

705 visibility (Optional[str]): Gateway visibility level (private, team, public). 

706 initialize_timeout (Optional[float]): Timeout in seconds for gateway initialization. 

707 

708 Returns: 

709 Created gateway information 

710 

711 Raises: 

712 GatewayNameConflictError: If gateway name already exists 

713 GatewayConnectionError: If there was an error connecting to the gateway 

714 ValueError: If required values are missing 

715 RuntimeError: If there is an error during processing that is not covered by other exceptions 

716 IntegrityError: If there is a database integrity error 

717 BaseException: If an unexpected error occurs 

718 

719 Examples: 

720 >>> from mcpgateway.services.gateway_service import GatewayService 

721 >>> from unittest.mock import MagicMock 

722 >>> service = GatewayService() 

723 >>> db = MagicMock() 

724 >>> gateway = MagicMock() 

725 >>> db.execute.return_value.scalar_one_or_none.return_value = None 

726 >>> db.add = MagicMock() 

727 >>> db.commit = MagicMock() 

728 >>> db.refresh = MagicMock() 

729 >>> service._notify_gateway_added = MagicMock() 

730 >>> import asyncio 

731 >>> try: 

732 ... asyncio.run(service.register_gateway(db, gateway)) 

733 ... except Exception: 

734 ... pass 

735 >>> 

736 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs 

737 >>> asyncio.run(service._http_client.aclose()) 

738 """ 

739 visibility = "public" if visibility not in ("private", "team", "public") else visibility 

740 try: 

741 # # Check for name conflicts (both active and inactive) 

742 # existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway.name)).scalar_one_or_none() 

743 

744 # if existing_gateway: 

745 # raise GatewayNameConflictError( 

746 # gateway.name, 

747 # enabled=existing_gateway.enabled, 

748 # gateway_id=existing_gateway.id, 

749 # ) 

750 # Check for existing gateway with the same slug and visibility 

751 slug_name = slugify(gateway.name) 

752 if visibility.lower() == "public": 

753 # Check for existing public gateway with the same slug (row-locked) 

754 existing_gateway = get_for_update( 

755 db, 

756 DbGateway, 

757 where=and_(DbGateway.slug == slug_name, DbGateway.visibility == "public"), 

758 ) 

759 if existing_gateway: 

760 raise GatewayNameConflictError(existing_gateway.slug, enabled=existing_gateway.enabled, gateway_id=existing_gateway.id, visibility=existing_gateway.visibility) 

761 elif visibility.lower() == "team" and team_id: 

762 # Check for existing team gateway with the same slug (row-locked) 

763 existing_gateway = get_for_update( 

764 db, 

765 DbGateway, 

766 where=and_(DbGateway.slug == slug_name, DbGateway.visibility == "team", DbGateway.team_id == team_id), 

767 ) 

768 if existing_gateway: 

769 raise GatewayNameConflictError(existing_gateway.slug, enabled=existing_gateway.enabled, gateway_id=existing_gateway.id, visibility=existing_gateway.visibility) 

770 

771 # Normalize the gateway URL 

772 normalized_url = self.normalize_url(str(gateway.url)) 

773 

774 decoded_auth_value = None 

775 if gateway.auth_value: 

776 if isinstance(gateway.auth_value, str): 

777 try: 

778 decoded_auth_value = decode_auth(gateway.auth_value) 

779 except Exception as e: 

780 logger.warning(f"Failed to decode provided auth_value: {e}") 

781 decoded_auth_value = None 

782 elif isinstance(gateway.auth_value, dict): 

783 decoded_auth_value = gateway.auth_value 

784 

785 # Check for duplicate gateway 

786 if not gateway.one_time_auth: 

787 duplicate_gateway = self._check_gateway_uniqueness( 

788 db=db, url=normalized_url, auth_value=decoded_auth_value, oauth_config=gateway.oauth_config, team_id=team_id, owner_email=owner_email, visibility=visibility 

789 ) 

790 

791 if duplicate_gateway: 

792 raise GatewayDuplicateConflictError(duplicate_gateway=duplicate_gateway) 

793 

794 # Prevent URL-only gateways (no auth at all) 

795 # if not decoded_auth_value and not gateway.oauth_config: 

796 # raise ValueError( 

797 # f"Gateway with URL '{normalized_url}' must have either auth_value or oauth_config. " 

798 # "URL-only gateways are not allowed." 

799 # ) 

800 

801 auth_type = getattr(gateway, "auth_type", None) 

802 # Support multiple custom headers 

803 auth_value = getattr(gateway, "auth_value", {}) 

804 authentication_headers: Optional[Dict[str, str]] = None 

805 

806 # Handle query_param auth - encrypt and prepare for storage 

807 auth_query_params_encrypted: Optional[Dict[str, str]] = None 

808 auth_query_params_decrypted: Optional[Dict[str, str]] = None 

809 init_url = normalized_url # URL to use for initialization 

810 

811 if auth_type == "query_param": 

812 # Extract and encrypt query param auth 

813 param_key = getattr(gateway, "auth_query_param_key", None) 

814 param_value = getattr(gateway, "auth_query_param_value", None) 

815 if param_key and param_value: 

816 # Get the actual secret value 

817 if hasattr(param_value, "get_secret_value"): 

818 raw_value = param_value.get_secret_value() 

819 else: 

820 raw_value = str(param_value) 

821 # Encrypt for storage 

822 encrypted_value = encode_auth({param_key: raw_value}) 

823 auth_query_params_encrypted = {param_key: encrypted_value} 

824 auth_query_params_decrypted = {param_key: raw_value} 

825 # Append query params to URL for initialization 

826 init_url = apply_query_param_auth(normalized_url, auth_query_params_decrypted) 

827 # Query param auth doesn't use auth_value 

828 auth_value = None 

829 authentication_headers = None 

830 

831 elif hasattr(gateway, "auth_headers") and gateway.auth_headers: 

832 # Convert list of {key, value} to dict 

833 header_dict = {h["key"]: h["value"] for h in gateway.auth_headers if h.get("key")} 

834 auth_value = header_dict # store plain dict, consistent with update path and DB column type 

835 authentication_headers = {str(k): str(v) for k, v in header_dict.items()} 

836 

837 elif isinstance(auth_value, str) and auth_value: 

838 # Decode persisted auth for initialization 

839 decoded = decode_auth(auth_value) 

840 authentication_headers = {str(k): str(v) for k, v in decoded.items()} 

841 else: 

842 authentication_headers = None 

843 

844 oauth_config = await protect_oauth_config_for_storage(getattr(gateway, "oauth_config", None)) 

845 ca_certificate = getattr(gateway, "ca_certificate", None) 

846 

847 # Check if gateway is in direct_proxy mode 

848 gateway_mode = getattr(gateway, "gateway_mode", "cache") 

849 

850 if gateway_mode == "direct_proxy" and not settings.mcpgateway_direct_proxy_enabled: 

851 raise GatewayError("direct_proxy gateway mode is disabled. Set MCPGATEWAY_DIRECT_PROXY_ENABLED=true to enable.") 

852 

853 if initialize_timeout is not None: 

854 try: 

855 capabilities, tools, resources, prompts = await asyncio.wait_for( 

856 self._initialize_gateway( 

857 init_url, # URL with query params if applicable 

858 authentication_headers, 

859 gateway.transport, 

860 auth_type, 

861 oauth_config, 

862 ca_certificate, 

863 auth_query_params=auth_query_params_decrypted, 

864 ), 

865 timeout=initialize_timeout, 

866 ) 

867 except asyncio.TimeoutError as exc: 

868 sanitized = sanitize_url_for_logging(init_url, auth_query_params_decrypted) 

869 raise GatewayConnectionError(f"Gateway initialization timed out after {initialize_timeout}s for {sanitized}") from exc 

870 else: 

871 capabilities, tools, resources, prompts = await self._initialize_gateway( 

872 init_url, # URL with query params if applicable 

873 authentication_headers, 

874 gateway.transport, 

875 auth_type, 

876 oauth_config, 

877 ca_certificate, 

878 auth_query_params=auth_query_params_decrypted, 

879 ) 

880 

881 if gateway.one_time_auth: 

882 # For one-time auth, clear auth_type and auth_value after initialization 

883 auth_type = "one_time_auth" 

884 auth_value = None 

885 oauth_config = None 

886 

887 # DbTool.auth_value is Mapped[Optional[str]] (Text), so encode the dict before 

888 # storing it there. DbGateway.auth_value is Mapped[Optional[Dict]] (JSON) and 

889 # receives the plain dict directly (see assignment above). 

890 tool_auth_value = encode_auth(auth_value) if isinstance(auth_value, dict) else auth_value 

891 

892 tools = [ 

893 DbTool( 

894 original_name=tool.name, 

895 custom_name=tool.name, 

896 custom_name_slug=slugify(tool.name), 

897 display_name=generate_display_name(tool.name), 

898 url=normalized_url, 

899 original_description=tool.description, 

900 description=tool.description, 

901 integration_type="MCP", # Gateway-discovered tools are MCP type 

902 request_type=tool.request_type, 

903 headers=tool.headers, 

904 input_schema=tool.input_schema, 

905 output_schema=tool.output_schema, 

906 annotations=tool.annotations, 

907 jsonpath_filter=tool.jsonpath_filter, 

908 auth_type=auth_type, 

909 auth_value=tool_auth_value, 

910 # Federation metadata 

911 created_by=created_by or "system", 

912 created_from_ip=created_from_ip, 

913 created_via="federation", # These are federated tools 

914 created_user_agent=created_user_agent, 

915 federation_source=gateway.name, 

916 version=1, 

917 # Inherit team assignment from gateway 

918 team_id=team_id, 

919 owner_email=owner_email, 

920 visibility=visibility, 

921 ) 

922 for tool in tools 

923 ] 

924 

925 # Create resource DB models with upsert logic for ORPHANED resources only 

926 # Query for existing ORPHANED resources (gateway_id IS NULL or points to non-existent gateway) 

927 # with same (team_id, owner_email, uri) to handle resources left behind from incomplete 

928 # gateway deletions (e.g., issue #2341 crash scenarios). 

929 # We only update orphaned resources - resources belonging to active gateways are not touched. 

930 resource_uris = [r.uri for r in resources] 

931 effective_owner = owner_email or created_by 

932 

933 # Build lookup map: (team_id, owner_email, uri) -> orphaned DbResource 

934 # We query all resources matching our URIs, then filter to orphaned ones in Python 

935 # to handle per-resource team/owner overrides correctly 

936 orphaned_resources_map: Dict[tuple, DbResource] = {} 

937 if resource_uris: 

938 try: 

939 # Get valid gateway IDs to identify orphaned resources 

940 valid_gateway_ids = set(gw_id for (gw_id,) in db.execute(select(DbGateway.id)).all()) 

941 candidate_resources = db.execute(select(DbResource).where(DbResource.uri.in_(resource_uris))).scalars().all() 

942 for res in candidate_resources: 

943 # Only consider orphaned resources (no gateway or gateway doesn't exist) 

944 is_orphaned = res.gateway_id is None or res.gateway_id not in valid_gateway_ids 

945 if is_orphaned: 

946 key = (res.team_id, res.owner_email, res.uri) 

947 orphaned_resources_map[key] = res 

948 if orphaned_resources_map: 

949 logger.info(f"Found {len(orphaned_resources_map)} orphaned resources to reassign for gateway {gateway.name}") 

950 except Exception as e: 

951 # If orphan detection fails (e.g., in mocked tests), skip upsert and create new resources 

952 # This is conservative - we won't accidentally reassign resources from active gateways 

953 logger.debug(f"Orphan resource detection skipped: {e}") 

954 

955 db_resources = [] 

956 for r in resources: 

957 mime_type = mimetypes.guess_type(r.uri)[0] or ("text/plain" if isinstance(r.content, str) else "application/octet-stream") 

958 r_team_id = getattr(r, "team_id", None) or team_id 

959 r_owner_email = getattr(r, "owner_email", None) or effective_owner 

960 r_visibility = getattr(r, "visibility", None) or visibility 

961 

962 # Check if there's an orphaned resource with matching unique key 

963 lookup_key = (r_team_id, r_owner_email, r.uri) 

964 if lookup_key in orphaned_resources_map: 

965 # Update orphaned resource - reassign to new gateway 

966 existing = orphaned_resources_map[lookup_key] 

967 existing.name = r.name 

968 existing.description = r.description 

969 existing.mime_type = mime_type 

970 existing.uri_template = r.uri_template or None 

971 existing.text_content = r.content if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) else None 

972 existing.binary_content = ( 

973 r.content.encode() if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) else r.content if isinstance(r.content, bytes) else None 

974 ) 

975 existing.size = len(r.content) if r.content else 0 

976 existing.tags = getattr(r, "tags", []) or [] 

977 existing.federation_source = gateway.name 

978 existing.modified_by = created_by 

979 existing.modified_from_ip = created_from_ip 

980 existing.modified_via = "federation" 

981 existing.modified_user_agent = created_user_agent 

982 existing.updated_at = datetime.now(timezone.utc) 

983 existing.visibility = r_visibility 

984 # Note: gateway_id will be set when gateway is created (relationship) 

985 db_resources.append(existing) 

986 else: 

987 # Create new resource 

988 db_resources.append( 

989 DbResource( 

990 uri=r.uri, 

991 name=r.name, 

992 description=r.description, 

993 mime_type=mime_type, 

994 uri_template=r.uri_template or None, 

995 text_content=r.content if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) else None, 

996 binary_content=( 

997 r.content.encode() 

998 if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) 

999 else r.content if isinstance(r.content, bytes) else None 

1000 ), 

1001 size=len(r.content) if r.content else 0, 

1002 tags=getattr(r, "tags", []) or [], 

1003 created_by=created_by or "system", 

1004 created_from_ip=created_from_ip, 

1005 created_via="federation", 

1006 created_user_agent=created_user_agent, 

1007 import_batch_id=None, 

1008 federation_source=gateway.name, 

1009 version=1, 

1010 team_id=r_team_id, 

1011 owner_email=r_owner_email, 

1012 visibility=r_visibility, 

1013 ) 

1014 ) 

1015 

1016 # Create prompt DB models with upsert logic for ORPHANED prompts only 

1017 # Query for existing ORPHANED prompts (gateway_id IS NULL or points to non-existent gateway) 

1018 # with same (team_id, owner_email, name) to handle prompts left behind from incomplete 

1019 # gateway deletions. We only update orphaned prompts - prompts belonging to active gateways are not touched. 

1020 prompt_names = [p.name for p in prompts] 

1021 

1022 # Build lookup map: (team_id, owner_email, name) -> orphaned DbPrompt 

1023 orphaned_prompts_map: Dict[tuple, DbPrompt] = {} 

1024 if prompt_names: 

1025 try: 

1026 # Get valid gateway IDs to identify orphaned prompts 

1027 valid_gateway_ids_for_prompts = set(gw_id for (gw_id,) in db.execute(select(DbGateway.id)).all()) 

1028 candidate_prompts = db.execute(select(DbPrompt).where(DbPrompt.name.in_(prompt_names))).scalars().all() 

1029 for pmt in candidate_prompts: 

1030 # Only consider orphaned prompts (no gateway or gateway doesn't exist) 

1031 is_orphaned = pmt.gateway_id is None or pmt.gateway_id not in valid_gateway_ids_for_prompts 

1032 if is_orphaned: 

1033 key = (pmt.team_id, pmt.owner_email, pmt.name) 

1034 orphaned_prompts_map[key] = pmt 

1035 if orphaned_prompts_map: 

1036 logger.info(f"Found {len(orphaned_prompts_map)} orphaned prompts to reassign for gateway {gateway.name}") 

1037 except Exception as e: 

1038 # If orphan detection fails (e.g., in mocked tests), skip upsert and create new prompts 

1039 logger.debug(f"Orphan prompt detection skipped: {e}") 

1040 

1041 db_prompts = [] 

1042 for prompt in prompts: 

1043 # Prompts inherit team/owner from gateway (no per-prompt overrides) 

1044 p_team_id = team_id 

1045 p_owner_email = owner_email or effective_owner 

1046 

1047 # Check if there's an orphaned prompt with matching unique key 

1048 lookup_key = (p_team_id, p_owner_email, prompt.name) 

1049 if lookup_key in orphaned_prompts_map: 

1050 # Update orphaned prompt - reassign to new gateway 

1051 existing = orphaned_prompts_map[lookup_key] 

1052 existing.original_name = prompt.name 

1053 existing.custom_name = prompt.name 

1054 existing.display_name = prompt.name 

1055 existing.description = prompt.description 

1056 existing.template = prompt.template if hasattr(prompt, "template") else "" 

1057 existing.federation_source = gateway.name 

1058 existing.modified_by = created_by 

1059 existing.modified_from_ip = created_from_ip 

1060 existing.modified_via = "federation" 

1061 existing.modified_user_agent = created_user_agent 

1062 existing.updated_at = datetime.now(timezone.utc) 

1063 existing.visibility = visibility 

1064 # Note: gateway_id will be set when gateway is created (relationship) 

1065 db_prompts.append(existing) 

1066 else: 

1067 # Create new prompt 

1068 db_prompts.append( 

1069 DbPrompt( 

1070 name=prompt.name, 

1071 original_name=prompt.name, 

1072 custom_name=prompt.name, 

1073 display_name=prompt.name, 

1074 description=prompt.description, 

1075 template=prompt.template if hasattr(prompt, "template") else "", 

1076 argument_schema={}, # Use argument_schema instead of arguments 

1077 # Federation metadata 

1078 created_by=created_by or "system", 

1079 created_from_ip=created_from_ip, 

1080 created_via="federation", # These are federated prompts 

1081 created_user_agent=created_user_agent, 

1082 federation_source=gateway.name, 

1083 version=1, 

1084 # Inherit team assignment from gateway 

1085 team_id=team_id, 

1086 owner_email=owner_email, 

1087 visibility=visibility, 

1088 ) 

1089 ) 

1090 

1091 # Create DB model 

1092 db_gateway = DbGateway( 

1093 name=gateway.name, 

1094 slug=slug_name, 

1095 url=normalized_url, 

1096 description=gateway.description, 

1097 tags=gateway.tags or [], 

1098 transport=gateway.transport, 

1099 capabilities=capabilities, 

1100 last_seen=datetime.now(timezone.utc), 

1101 auth_type=auth_type, 

1102 auth_value=auth_value, 

1103 auth_query_params=auth_query_params_encrypted, # Encrypted query param auth 

1104 oauth_config=oauth_config, 

1105 passthrough_headers=gateway.passthrough_headers, 

1106 tools=tools, 

1107 resources=db_resources, 

1108 prompts=db_prompts, 

1109 # Gateway metadata 

1110 created_by=created_by, 

1111 created_from_ip=created_from_ip, 

1112 created_via=created_via or "api", 

1113 created_user_agent=created_user_agent, 

1114 version=1, 

1115 # Team scoping fields 

1116 team_id=team_id, 

1117 owner_email=owner_email, 

1118 visibility=visibility, 

1119 ca_certificate=gateway.ca_certificate, 

1120 ca_certificate_sig=gateway.ca_certificate_sig, 

1121 signing_algorithm=gateway.signing_algorithm, 

1122 # Gateway mode configuration 

1123 gateway_mode=gateway_mode, 

1124 ) 

1125 

1126 # Add to DB 

1127 db.add(db_gateway) 

1128 db.flush() # Flush to get the ID without committing 

1129 db.refresh(db_gateway) 

1130 

1131 # Update tracking 

1132 self._active_gateways.add(db_gateway.url) 

1133 

1134 # Notify subscribers 

1135 await self._notify_gateway_added(db_gateway) 

1136 

1137 logger.info(f"Registered gateway: {gateway.name}") 

1138 

1139 # Structured logging: Audit trail for gateway creation 

1140 audit_trail.log_action( 

1141 user_id=created_by or "system", 

1142 action="create_gateway", 

1143 resource_type="gateway", 

1144 resource_id=str(db_gateway.id), 

1145 resource_name=db_gateway.name, 

1146 user_email=owner_email, 

1147 team_id=team_id, 

1148 client_ip=created_from_ip, 

1149 user_agent=created_user_agent, 

1150 new_values={ 

1151 "name": db_gateway.name, 

1152 "url": db_gateway.url, 

1153 "visibility": visibility, 

1154 "transport": db_gateway.transport, 

1155 "tools_count": len(tools), 

1156 "resources_count": len(db_resources), 

1157 "prompts_count": len(db_prompts), 

1158 }, 

1159 context={ 

1160 "created_via": created_via, 

1161 }, 

1162 db=db, 

1163 ) 

1164 

1165 # Structured logging: Log successful gateway creation 

1166 structured_logger.log( 

1167 level="INFO", 

1168 message="Gateway created successfully", 

1169 event_type="gateway_created", 

1170 component="gateway_service", 

1171 user_id=created_by, 

1172 user_email=owner_email, 

1173 team_id=team_id, 

1174 resource_type="gateway", 

1175 resource_id=str(db_gateway.id), 

1176 custom_fields={ 

1177 "gateway_name": db_gateway.name, 

1178 "gateway_url": normalized_url, 

1179 "visibility": visibility, 

1180 "transport": db_gateway.transport, 

1181 }, 

1182 ) 

1183 

1184 return GatewayRead.model_validate(self._prepare_gateway_for_read(db_gateway)).masked() 

1185 except* GatewayConnectionError as ge: # pragma: no mutate 

1186 if TYPE_CHECKING: 

1187 ge: ExceptionGroup[GatewayConnectionError] 

1188 logger.error(f"GatewayConnectionError in group: {ge.exceptions}") 

1189 db.rollback() 

1190 

1191 structured_logger.log( 

1192 level="ERROR", 

1193 message="Gateway creation failed due to connection error", 

1194 event_type="gateway_creation_failed", 

1195 component="gateway_service", 

1196 user_id=created_by, 

1197 user_email=owner_email, 

1198 error=ge.exceptions[0], 

1199 custom_fields={"gateway_name": gateway.name, "gateway_url": str(gateway.url)}, 

1200 ) 

1201 raise ge.exceptions[0] 

1202 except* GatewayNameConflictError as gnce: # pragma: no mutate 

1203 if TYPE_CHECKING: 

1204 gnce: ExceptionGroup[GatewayNameConflictError] 

1205 logger.error(f"GatewayNameConflictError in group: {gnce.exceptions}") 

1206 db.rollback() 

1207 

1208 structured_logger.log( 

1209 level="WARNING", 

1210 message="Gateway creation failed due to name conflict", 

1211 event_type="gateway_name_conflict", 

1212 component="gateway_service", 

1213 user_id=created_by, 

1214 user_email=owner_email, 

1215 custom_fields={"gateway_name": gateway.name, "visibility": visibility}, 

1216 ) 

1217 raise gnce.exceptions[0] 

1218 except* GatewayDuplicateConflictError as guce: # pragma: no mutate 

1219 if TYPE_CHECKING: 

1220 guce: ExceptionGroup[GatewayDuplicateConflictError] 

1221 logger.error(f"GatewayDuplicateConflictError in group: {guce.exceptions}") 

1222 db.rollback() 

1223 

1224 structured_logger.log( 

1225 level="WARNING", 

1226 message="Gateway creation failed due to duplicate", 

1227 event_type="gateway_duplicate_conflict", 

1228 component="gateway_service", 

1229 user_id=created_by, 

1230 user_email=owner_email, 

1231 custom_fields={"gateway_name": gateway.name}, 

1232 ) 

1233 raise guce.exceptions[0] 

1234 except* ValueError as ve: # pragma: no mutate 

1235 if TYPE_CHECKING: 

1236 ve: ExceptionGroup[ValueError] 

1237 logger.error(f"ValueErrors in group: {ve.exceptions}") 

1238 db.rollback() 

1239 

1240 structured_logger.log( 

1241 level="ERROR", 

1242 message="Gateway creation failed due to validation error", 

1243 event_type="gateway_creation_failed", 

1244 component="gateway_service", 

1245 user_id=created_by, 

1246 user_email=owner_email, 

1247 error=ve.exceptions[0], 

1248 custom_fields={"gateway_name": gateway.name}, 

1249 ) 

1250 raise ve.exceptions[0] 

1251 except* RuntimeError as re: # pragma: no mutate 

1252 if TYPE_CHECKING: 

1253 re: ExceptionGroup[RuntimeError] 

1254 logger.error(f"RuntimeErrors in group: {re.exceptions}") 

1255 db.rollback() 

1256 

1257 structured_logger.log( 

1258 level="ERROR", 

1259 message="Gateway creation failed due to runtime error", 

1260 event_type="gateway_creation_failed", 

1261 component="gateway_service", 

1262 user_id=created_by, 

1263 user_email=owner_email, 

1264 error=re.exceptions[0], 

1265 custom_fields={"gateway_name": gateway.name}, 

1266 ) 

1267 raise re.exceptions[0] 

1268 except* IntegrityError as ie: # pragma: no mutate 

1269 if TYPE_CHECKING: 

1270 ie: ExceptionGroup[IntegrityError] 

1271 logger.error(f"IntegrityErrors in group: {ie.exceptions}") 

1272 db.rollback() 

1273 

1274 structured_logger.log( 

1275 level="ERROR", 

1276 message="Gateway creation failed due to database integrity error", 

1277 event_type="gateway_creation_failed", 

1278 component="gateway_service", 

1279 user_id=created_by, 

1280 user_email=owner_email, 

1281 error=ie.exceptions[0], 

1282 custom_fields={"gateway_name": gateway.name}, 

1283 ) 

1284 raise ie.exceptions[0] 

1285 except* BaseException as other: # catches every other sub-exception # pragma: no mutate 

1286 if TYPE_CHECKING: 

1287 other: ExceptionGroup[Exception] 

1288 logger.error(f"Other grouped errors: {other.exceptions}") 

1289 db.rollback() 

1290 raise other.exceptions[0] 

1291 

async def fetch_tools_after_oauth(self, db: Session, gateway_id: str, app_user_email: str) -> Dict[str, Any]:
    """Fetch tools from MCP server after OAuth completion for Authorization Code flow.

    Loads the gateway (tools, resources and prompts eagerly loaded) and looks up
    the stored OAuth access token for ``app_user_email``. It then connects to the
    gateway's MCP server with that token and reconciles the database with what
    the server reports:

    * items produced by the ``_update_or_create_*`` helpers are added in chunks;
    * tools/resources/prompts the server no longer exposes are bulk-deleted,
      together with their metric, server-association (and, for resources,
      subscription) rows;
    * the gateway's ``capabilities`` and ``last_seen`` are refreshed.

    The session is then committed, and the registry, tool-lookup and admin tag
    caches are invalidated.

    Args:
        db: Database session
        gateway_id: ID of the gateway to fetch tools for
        app_user_email: ContextForge user email for token retrieval

    Returns:
        Dict with keys ``capabilities``, ``tools``, ``resources`` and ``prompts``
        holding the values returned by the MCP server connection.

    Raises:
        GatewayConnectionError: If connection or OAuth fails. Any other exception
            raised while fetching (including the ``ValueError`` for a missing
            gateway, missing OAuth config, non-authorization_code grant type or
            unsupported transport) is also re-raised as ``GatewayConnectionError``,
            chained to the original cause. The session is rolled back on failure.
    """
    try:
        # Get the gateway with eager loading for sync operations to avoid N+1 queries
        gateway = db.execute(
            select(DbGateway)
            .options(
                selectinload(DbGateway.tools),
                selectinload(DbGateway.resources),
                selectinload(DbGateway.prompts),
                joinedload(DbGateway.email_team),
            )
            .where(DbGateway.id == gateway_id)
        ).scalar_one_or_none()

        if not gateway:
            raise ValueError(f"Gateway {gateway_id} not found")

        if not gateway.oauth_config:
            raise ValueError(f"Gateway {gateway_id} has no OAuth configuration")

        grant_type = gateway.oauth_config.get("grant_type")
        if grant_type != "authorization_code":
            raise ValueError(f"Gateway {gateway_id} is not using Authorization Code flow")

        # Get OAuth tokens for this gateway
        # First-Party
        from mcpgateway.services.token_storage_service import TokenStorageService  # pylint: disable=import-outside-toplevel

        token_storage = TokenStorageService(db)

        # Get user-specific OAuth token
        if not app_user_email:
            raise GatewayConnectionError(f"User authentication required for OAuth gateway {gateway.name}")

        access_token = await token_storage.get_user_token(gateway.id, app_user_email)

        if not access_token:
            raise GatewayConnectionError(
                f"No OAuth tokens found for user {app_user_email} on gateway {gateway.name}. Please complete the OAuth authorization flow first at /oauth/authorize/{gateway.id}"
            )

        # Heuristic: a token that still carries this prefix looks like it was not decrypted
        if access_token.startswith("Z0FBQUFBQm"):
            logger.error("OAuth token decryption may have failed before gateway initialization")
        else:
            logger.info("Using decrypted OAuth token for gateway %s", gateway.name)

        # Now connect to MCP server with the access token
        authentication = {"Authorization": f"Bearer {access_token}"}

        # For OAuth servers, skip validation since we already validated via OAuth flow
        transport = gateway.transport.upper()
        if transport == "SSE":
            capabilities, tools, resources, prompts = await self._connect_to_sse_server_without_validation(gateway.url, authentication)
        elif transport == "STREAMABLEHTTP":
            capabilities, tools, resources, prompts = await self.connect_to_streamablehttp_server(gateway.url, authentication)
        else:
            raise ValueError(f"Unsupported transport type: {gateway.transport}")

        # Handle tools, resources, and prompts using helper methods
        tools_to_add = self._update_or_create_tools(db, tools, gateway, "oauth")
        resources_to_add = self._update_or_create_resources(db, resources, gateway, "oauth")
        prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, "oauth")

        # Sets give O(1) membership for the stale-item scans below (lists made them O(n*m))
        new_tool_names = {tool.name for tool in tools}
        new_resource_uris = {resource.uri for resource in resources}
        new_prompt_names = {prompt.name for prompt in prompts}

        # Use chunking to avoid SQLite's 999 parameter limit for IN clauses
        delete_chunk_size = 500

        # Bulk delete tools that are no longer available from the gateway
        stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names]
        for i in range(0, len(stale_tool_ids), delete_chunk_size):
            chunk = stale_tool_ids[i : i + delete_chunk_size]
            # Delete child records first to avoid FK constraint violations
            db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
            db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
            db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))

        # Bulk delete resources that are no longer available from the gateway
        stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris]
        for i in range(0, len(stale_resource_ids), delete_chunk_size):
            chunk = stale_resource_ids[i : i + delete_chunk_size]
            # Delete child records first to avoid FK constraint violations
            db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
            db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
            db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
            db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))

        # Bulk delete prompts that are no longer available from the gateway
        stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names]
        for i in range(0, len(stale_prompt_ids), delete_chunk_size):
            chunk = stale_prompt_ids[i : i + delete_chunk_size]
            # Delete child records first to avoid FK constraint violations
            db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
            db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
            db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))

        # Expire gateway to clear cached relationships after bulk deletes
        # This prevents SQLAlchemy from trying to re-delete already-deleted items
        if stale_tool_ids or stale_resource_ids or stale_prompt_ids:
            db.expire(gateway)

        # Update gateway relationships to reflect deletions
        gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names]
        gateway.resources = [resource for resource in gateway.resources if resource.uri in new_resource_uris]
        gateway.prompts = [prompt for prompt in gateway.prompts if prompt.original_name in new_prompt_names]

        # Log cleanup results
        if stale_tool_ids:
            logger.info(f"Removed {len(stale_tool_ids)} tools no longer available from gateway")
        if stale_resource_ids:
            logger.info(f"Removed {len(stale_resource_ids)} resources no longer available from gateway")
        if stale_prompt_ids:
            logger.info(f"Removed {len(stale_prompt_ids)} prompts no longer available from gateway")

        # Update gateway capabilities and last_seen
        gateway.capabilities = capabilities
        gateway.last_seen = datetime.now(timezone.utc)

        # Register capabilities for notification-driven actions
        register_gateway_capabilities_for_notifications(gateway.id, capabilities)

        # Add new items to DB in chunks to prevent lock escalation; flushing each
        # chunk keeps the pending set small
        items_added = 0
        chunk_size = 50
        for label, new_items in (("tools", tools_to_add), ("resources", resources_to_add), ("prompts", prompts_to_add)):
            if not new_items:
                continue
            for i in range(0, len(new_items), chunk_size):
                db.add_all(new_items[i : i + chunk_size])
                db.flush()
            items_added += len(new_items)
            logger.info(f"Added {len(new_items)} new {label} to database")

        # Commit even when nothing new was added so updates to existing items are saved
        db.commit()
        if items_added > 0:
            logger.info(f"Total {items_added} new items added to database")
        else:
            logger.info("No new items to add to database")

        cache = _get_registry_cache()
        await cache.invalidate_tools()
        await cache.invalidate_resources()
        await cache.invalidate_prompts()
        tool_lookup_cache = _get_tool_lookup_cache()
        await tool_lookup_cache.invalidate_gateway(str(gateway.id))
        # Also invalidate tags cache since tool/resource tags may have changed
        # First-Party
        from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

        await admin_stats_cache.invalidate_tags()

        return {"capabilities": capabilities, "tools": tools, "resources": resources, "prompts": prompts}

    except GatewayConnectionError as gce:
        db.rollback()
        # Surface validation or depth-related failures directly to the user
        logger.error(f"GatewayConnectionError during OAuth fetch for {gateway_id}: {gce}")
        raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(gce)}") from gce
    except Exception as e:
        db.rollback()
        logger.error(f"Failed to fetch tools after OAuth for gateway {gateway_id}: {e}")
        raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(e)}") from e

1496 

async def list_gateways(
    self,
    db: Session,
    include_inactive: bool = False,
    tags: Optional[List[str]] = None,
    cursor: Optional[str] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    per_page: Optional[int] = None,
    user_email: Optional[str] = None,
    team_id: Optional[str] = None,
    visibility: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
) -> Union[tuple[List[GatewayRead], Optional[str]], Dict[str, Any]]:
    """List all registered gateways with cursor pagination and optional team filtering.

    First-page, cursor-style requests from public-only callers (``token_teams=[]``
    and no ``user_email``) are read from and written to the registry cache. The
    cache key includes every argument that can change the result on that path
    (``include_inactive``, ``tags``, ``visibility``, ``team_id`` and ``limit``).
    Requests that differ only in filtering or page size therefore never share an
    entry.

    Args:
        db: Database session
        include_inactive: Whether to include inactive gateways
        tags (Optional[List[str]]): Filter resources by tags. If provided, only resources with at least one matching tag will be returned.
        cursor: Cursor for pagination (encoded last created_at and id).
        limit: Maximum number of gateways to return. None for default, 0 for unlimited.
        page: Page number for page-based pagination (1-indexed). Mutually exclusive with cursor.
        per_page: Items per page for page-based pagination. Defaults to pagination_default_page_size.
        user_email: Email of user for team-based access control. None for no access control.
        team_id: Optional team ID to filter by specific team (requires user_email).
        visibility: Optional visibility filter (private, team, public) (requires user_email).
        token_teams: Optional list of team IDs from the token (None=unrestricted, []=public-only).

    Returns:
        If page is provided: Dict with {"data": [...], "pagination": {...}, "links": {...}}
        If cursor is provided or neither: tuple of (list of GatewayRead objects, next_cursor).

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> from unittest.mock import MagicMock, AsyncMock, patch
        >>> from mcpgateway.schemas import GatewayRead
        >>> import asyncio
        >>> service = GatewayService()
        >>> db = MagicMock()
        >>> gateway_obj = MagicMock()
        >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway_obj]
        >>> gateway_read_obj = MagicMock(spec=GatewayRead)
        >>> service.convert_gateway_to_read = MagicMock(return_value=gateway_read_obj)
        >>> # Mock the cache to bypass caching logic
        >>> with patch('mcpgateway.services.gateway_service._get_registry_cache') as mock_cache_factory:
        ...     mock_cache = MagicMock()
        ...     mock_cache.get = AsyncMock(return_value=None)
        ...     mock_cache.set = AsyncMock(return_value=None)
        ...     mock_cache.hash_filters = MagicMock(return_value="hash")
        ...     mock_cache_factory.return_value = mock_cache
        ...     gateways, cursor = asyncio.run(service.list_gateways(db))
        ...     gateways == [gateway_read_obj] and cursor is None
        True

        >>> # Test empty result
        >>> db.execute.return_value.scalars.return_value.all.return_value = []
        >>> with patch('mcpgateway.services.gateway_service._get_registry_cache') as mock_cache_factory:
        ...     mock_cache = MagicMock()
        ...     mock_cache.get = AsyncMock(return_value=None)
        ...     mock_cache.set = AsyncMock(return_value=None)
        ...     mock_cache.hash_filters = MagicMock(return_value="hash")
        ...     mock_cache_factory.return_value = mock_cache
        ...     empty_result, cursor = asyncio.run(service.list_gateways(db))
        ...     empty_result == [] and cursor is None
        True
        >>>
        >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
        >>> asyncio.run(service._http_client.aclose())
    """
    # Check cache for first page only - only for public-only queries (no user/team filtering)
    # SECURITY: Only cache public-only results (token_teams=[]), never admin bypass or team-scoped
    cache = _get_registry_cache()
    is_public_only = token_teams is not None and len(token_teams) == 0
    use_cache = cursor is None and user_email is None and page is None and is_public_only
    filters_hash: Optional[str] = None
    if use_cache:
        # Every argument that narrows or sizes the result set on this path must be
        # part of the key. Otherwise a limit=5 or visibility="team" request would
        # populate the entry that is later served to an unfiltered caller.
        filters_hash = cache.hash_filters(
            include_inactive=include_inactive,
            tags=sorted(tags) if tags else None,
            visibility=visibility,
            team_id=team_id,
            limit=limit,
        )
        cached = await cache.get("gateways", filters_hash)
        if cached is not None:
            # Reconstruct GatewayRead objects from cached dicts
            # SECURITY: Always apply .masked() to ensure stale cache entries don't leak credentials
            cached_gateways = [GatewayRead.model_validate(g).masked() for g in cached["gateways"]]
            return (cached_gateways, cached.get("next_cursor"))

    # Build base query with ordering (newest first, id as a tiebreaker for stable cursors)
    query = select(DbGateway).options(joinedload(DbGateway.email_team)).order_by(desc(DbGateway.created_at), desc(DbGateway.id))

    # Apply active/inactive filter
    if not include_inactive:
        query = query.where(DbGateway.enabled)

    query = await self._apply_access_control(query, db, user_email, token_teams, team_id)

    if visibility:
        query = query.where(DbGateway.visibility == visibility)

    # Add tag filtering if tags are provided (supports both List[str] and List[Dict] formats)
    if tags:
        query = query.where(json_contains_tag_expr(db, DbGateway.tags, tags, match_any=True))

    # Use unified pagination helper - handles both page and cursor pagination
    pag_result = await unified_paginate(
        db=db,
        query=query,
        page=page,
        per_page=per_page,
        cursor=cursor,
        limit=limit,
        base_url="/admin/gateways",  # Used for page-based links
        query_params={"include_inactive": include_inactive} if include_inactive else {},
    )

    next_cursor = None
    # Extract gateways based on pagination type
    if page is not None:
        # Page-based: pag_result is a dict
        gateways_db = pag_result["data"]
    else:
        # Cursor-based: pag_result is a tuple
        gateways_db, next_cursor = pag_result

    db.commit()  # Release transaction to avoid idle-in-transaction

    # Convert to GatewayRead (common for both pagination types)
    result = []
    for s in gateways_db:
        try:
            result.append(self.convert_gateway_to_read(s))
        except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
            logger.exception(f"Failed to convert gateway {getattr(s, 'id', 'unknown')} ({getattr(s, 'name', 'unknown')}): {e}")
            # Continue with remaining gateways instead of failing completely

    # Return appropriate format based on pagination type
    if page is not None:
        # Page-based format
        return {
            "data": result,
            "pagination": pag_result["pagination"],
            "links": pag_result["links"],
        }

    # Cursor-based format

    # Cache first page results - only for public-only queries (no user/team filtering).
    # Reuse use_cache (page is necessarily None here) so filters_hash is always bound.
    # SECURITY: Only cache public-only results (token_teams=[]), never admin bypass or team-scoped
    if use_cache and filters_hash is not None:
        try:
            cache_data = {"gateways": [s.model_dump(mode="json") for s in result], "next_cursor": next_cursor}
            await cache.set("gateways", cache_data, filters_hash)
        except AttributeError:
            pass  # Skip caching if result objects don't support model_dump (e.g., in doctests)

    return (result, next_cursor)

1649 

async def list_gateways_for_user(
    self, db: Session, user_email: str, team_id: Optional[str] = None, visibility: Optional[str] = None, include_inactive: bool = False, skip: int = 0, limit: int = 100
) -> List[GatewayRead]:
    """
    DEPRECATED: Use list_gateways() with user_email parameter instead.

    This method is maintained for backward compatibility but is no longer used.
    New code should call list_gateways() with user_email, team_id, and visibility parameters.

    List gateways user has access to with team filtering.

    Results are ordered newest first (``created_at`` DESC, ``id`` DESC) so that
    ``skip``/``limit`` paging is deterministic across calls.

    Args:
        db: Database session
        user_email: Email of the user requesting gateways
        team_id: Optional team ID to filter by specific team
        visibility: Optional visibility filter (private, team, public)
        include_inactive: Whether to include inactive gateways
        skip: Number of gateways to skip for pagination
        limit: Maximum number of gateways to return

    Returns:
        List[GatewayRead]: Gateways the user has access to. Empty when ``team_id``
        names a team the user is not a member of.
    """
    # Build query following existing patterns from list_gateways()
    team_service = TeamManagementService(db)
    user_teams = await team_service.get_user_teams(user_email)
    team_ids = [team.id for team in user_teams]

    # Use joinedload to eager load email_team relationship (avoids N+1 queries)
    query = select(DbGateway).options(joinedload(DbGateway.email_team))

    # Apply active/inactive filter
    if not include_inactive:
        query = query.where(DbGateway.enabled.is_(True))

    if team_id:
        if team_id not in team_ids:
            return []  # No access to team

        access_conditions = [
            # Team-owned gateways visible to members of the selected team
            and_(DbGateway.team_id == team_id, DbGateway.visibility.in_(["team", "public"])),
            # The user's own gateways inside the selected team (any visibility)
            and_(DbGateway.team_id == team_id, DbGateway.owner_email == user_email),
            # Also include global public gateways so they are visible regardless of selected team
            DbGateway.visibility == "public",
        ]
    else:
        access_conditions = [
            # 1. User's personal resources (owner_email matches)
            DbGateway.owner_email == user_email,
        ]
        # 2. Team resources where user is member
        if team_ids:
            access_conditions.append(and_(DbGateway.team_id.in_(team_ids), DbGateway.visibility.in_(["team", "public"])))
        # 3. Public resources
        access_conditions.append(DbGateway.visibility == "public")

    query = query.where(or_(*access_conditions))

    # Apply visibility filter if specified
    if visibility:
        query = query.where(DbGateway.visibility == visibility)

    # OFFSET/LIMIT without ORDER BY returns rows in an unspecified order, so pages
    # could overlap or skip rows between calls. Use the same ordering as list_gateways().
    query = query.order_by(desc(DbGateway.created_at), desc(DbGateway.id)).offset(skip).limit(limit)

    gateways = db.execute(query).scalars().all()

    db.commit()  # Release transaction to avoid idle-in-transaction

    # Team names are loaded via joinedload(DbGateway.email_team)
    result = []
    for g in gateways:
        # Per-row diagnostics belong at DEBUG; INFO here floods logs on every listing.
        logger.debug("Gateway: %s, Team: %s", g.team_id, g.team)
        result.append(GatewayRead.model_validate(self._prepare_gateway_for_read(g)).masked())
    return result

1732 

1733 async def update_gateway( 

1734 self, 

1735 db: Session, 

1736 gateway_id: str, 

1737 gateway_update: GatewayUpdate, 

1738 modified_by: Optional[str] = None, 

1739 modified_from_ip: Optional[str] = None, 

1740 modified_via: Optional[str] = None, 

1741 modified_user_agent: Optional[str] = None, 

1742 include_inactive: bool = True, 

1743 user_email: Optional[str] = None, 

1744 ) -> Optional[GatewayRead]: 

1745 """Update a gateway. 

1746 

1747 Args: 

1748 db: Database session 

1749 gateway_id: Gateway ID to update 

1750 gateway_update: Updated gateway data 

1751 modified_by: Username of the person modifying the gateway 

1752 modified_from_ip: IP address where the modification request originated 

1753 modified_via: Source of modification (ui/api/import) 

1754 modified_user_agent: User agent string from the modification request 

1755 include_inactive: Whether to include inactive gateways 

1756 user_email: Email of user performing update (for ownership check) 

1757 

1758 Returns: 

1759 Updated gateway information 

1760 

1761 Raises: 

1762 GatewayNotFoundError: If gateway not found 

1763 PermissionError: If user doesn't own the gateway 

1764 GatewayError: For other update errors 

1765 GatewayNameConflictError: If gateway name conflict occurs 

1766 IntegrityError: If there is a database integrity error 

1767 ValidationError: If validation fails 

1768 """ 

1769 try: # pylint: disable=too-many-nested-blocks 

1770 # Acquire row lock and eager-load relationships while locked so 

1771 # concurrent updates are serialized on Postgres. 

1772 gateway = get_for_update( 

1773 db, 

1774 DbGateway, 

1775 gateway_id, 

1776 options=[ 

1777 selectinload(DbGateway.tools), 

1778 selectinload(DbGateway.resources), 

1779 selectinload(DbGateway.prompts), 

1780 selectinload(DbGateway.email_team), # Use selectinload to avoid locking email_teams 

1781 ], 

1782 ) 

1783 if not gateway: 

1784 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") 

1785 

1786 # Check ownership if user_email provided 

1787 if user_email: 

1788 # First-Party 

1789 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel 

1790 

1791 permission_service = PermissionService(db) 

1792 if not await permission_service.check_resource_ownership(user_email, gateway): 

1793 raise PermissionError("Only the owner can update this gateway") 

1794 

1795 if gateway.enabled or include_inactive: 

1796 # Check for name conflicts if name is being changed 

1797 if gateway_update.name is not None and gateway_update.name != gateway.name: 

1798 # existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway_update.name).where(DbGateway.id != gateway_id)).scalar_one_or_none() 

1799 

1800 # if existing_gateway: 

1801 # raise GatewayNameConflictError( 

1802 # gateway_update.name, 

1803 # enabled=existing_gateway.enabled, 

1804 # gateway_id=existing_gateway.id, 

1805 # ) 

1806 # Check for existing gateway with the same slug and visibility 

1807 new_slug = slugify(gateway_update.name) 

1808 if gateway_update.visibility is not None: 

1809 vis = gateway_update.visibility 

1810 else: 

1811 vis = gateway.visibility 

1812 if vis == "public": 

1813 # Check for existing public gateway with the same slug (row-locked) 

1814 existing_gateway = get_for_update( 

1815 db, 

1816 DbGateway, 

1817 where=and_(DbGateway.slug == new_slug, DbGateway.visibility == "public", DbGateway.id != gateway_id), 

1818 ) 

1819 if existing_gateway: 

1820 raise GatewayNameConflictError( 

1821 new_slug, 

1822 enabled=existing_gateway.enabled, 

1823 gateway_id=existing_gateway.id, 

1824 visibility=existing_gateway.visibility, 

1825 ) 

1826 elif vis == "team" and gateway.team_id: 

1827 # Check for existing team gateway with the same slug (row-locked) 

1828 existing_gateway = get_for_update( 

1829 db, 

1830 DbGateway, 

1831 where=and_(DbGateway.slug == new_slug, DbGateway.visibility == "team", DbGateway.team_id == gateway.team_id, DbGateway.id != gateway_id), 

1832 ) 

1833 if existing_gateway: 

1834 raise GatewayNameConflictError( 

1835 new_slug, 

1836 enabled=existing_gateway.enabled, 

1837 gateway_id=existing_gateway.id, 

1838 visibility=existing_gateway.visibility, 

1839 ) 

1840 # Check for existing gateway with the same URL and visibility 

1841 normalized_url = "" 

1842 if gateway_update.url is not None: 

1843 normalized_url = self.normalize_url(str(gateway_update.url)) 

1844 else: 

1845 normalized_url = None 

1846 

1847 # Prepare decoded auth_value for uniqueness check 

1848 decoded_auth_value = None 

1849 if gateway_update.auth_value: 

1850 if isinstance(gateway_update.auth_value, str): 

1851 try: 

1852 decoded_auth_value = decode_auth(gateway_update.auth_value) 

1853 except Exception as e: 

1854 logger.warning(f"Failed to decode provided auth_value: {e}") 

1855 elif isinstance(gateway_update.auth_value, dict): 

1856 decoded_auth_value = gateway_update.auth_value 

1857 

1858 # Determine final values for uniqueness check 

1859 final_auth_value = decoded_auth_value if gateway_update.auth_value is not None else (decode_auth(gateway.auth_value) if isinstance(gateway.auth_value, str) else gateway.auth_value) 

1860 final_oauth_config = gateway_update.oauth_config if gateway_update.oauth_config is not None else gateway.oauth_config 

1861 final_visibility = gateway_update.visibility if gateway_update.visibility is not None else gateway.visibility 

1862 

1863 # Check for duplicates with updated credentials 

1864 if not gateway_update.one_time_auth: 

1865 duplicate_gateway = self._check_gateway_uniqueness( 

1866 db=db, 

1867 url=normalized_url, 

1868 auth_value=final_auth_value, 

1869 oauth_config=final_oauth_config, 

1870 team_id=gateway.team_id, 

1871 visibility=final_visibility, 

1872 gateway_id=gateway_id, # Exclude current gateway from check 

1873 owner_email=user_email, 

1874 ) 

1875 

1876 if duplicate_gateway: 

1877 raise GatewayDuplicateConflictError(duplicate_gateway=duplicate_gateway) 

1878 

1879 # FIX for Issue #1025: Determine if URL actually changed before we update it 

1880 # We need this early because we update gateway.url below, and need to know 

1881 # if it actually changed to decide whether to re-fetch tools 

1882 # tools/resoures/prompts are need to be re-fetched not only if URL changed , in case any update like authentication and visibility changed 

1883 # url_changed = gateway_update.url is not None and self.normalize_url(str(gateway_update.url)) != gateway.url 

1884 

1885 # Save original values BEFORE updating for change detection checks later 

1886 original_url = gateway.url 

1887 original_auth_type = gateway.auth_type 

1888 

1889 # Update fields if provided 

1890 if gateway_update.name is not None: 

1891 gateway.name = gateway_update.name 

1892 gateway.slug = slugify(gateway_update.name) 

1893 if gateway_update.url is not None: 

1894 # Normalize the updated URL 

1895 gateway.url = self.normalize_url(str(gateway_update.url)) 

1896 if gateway_update.description is not None: 

1897 gateway.description = gateway_update.description 

1898 if gateway_update.transport is not None: 

1899 gateway.transport = gateway_update.transport 

1900 if gateway_update.tags is not None: 

1901 gateway.tags = gateway_update.tags 

1902 if gateway_update.visibility is not None: 

1903 gateway.visibility = gateway_update.visibility 

1904 # Propagate visibility to all linked items immediately so it 

1905 # takes effect even when the upstream server is unreachable 

1906 # and _initialize_gateway fails. 

1907 for tool in gateway.tools: 

1908 tool.visibility = gateway.visibility 

1909 for resource in gateway.resources: 

1910 resource.visibility = gateway.visibility 

1911 for prompt in gateway.prompts: 

1912 prompt.visibility = gateway.visibility 

1913 if gateway_update.passthrough_headers is not None: 

1914 if isinstance(gateway_update.passthrough_headers, list): 

1915 gateway.passthrough_headers = gateway_update.passthrough_headers 

1916 else: 

1917 if isinstance(gateway_update.passthrough_headers, str): 

1918 parsed: List[str] = [h.strip() for h in gateway_update.passthrough_headers.split(",") if h.strip()] 

1919 gateway.passthrough_headers = parsed 

1920 else: 

1921 raise GatewayError("Invalid passthrough_headers format: must be list[str] or comma-separated string") 

1922 

1923 logger.info("Updated passthrough_headers for gateway {gateway.id}: {gateway.passthrough_headers}") 

1924 

1925 # Only update auth_type if explicitly provided in the update 

1926 if gateway_update.auth_type is not None: 

1927 gateway.auth_type = gateway_update.auth_type 

1928 

1929 # If auth_type is empty, update the auth_value too 

1930 if gateway_update.auth_type == "": 

1931 gateway.auth_value = cast(Any, "") 

1932 

1933 # Clear auth_query_params when switching away from query_param auth 

1934 if original_auth_type == "query_param" and gateway_update.auth_type != "query_param": 

1935 gateway.auth_query_params = None 

1936 logger.debug(f"Cleared auth_query_params for gateway {gateway.id} (switched from query_param to {gateway_update.auth_type})") 

1937 

1938 # if auth_type is not None and only then check auth_value 

1939 # Handle OAuth configuration updates 

1940 if gateway_update.oauth_config is not None: 

1941 gateway.oauth_config = await protect_oauth_config_for_storage(gateway_update.oauth_config, existing_oauth_config=gateway.oauth_config) 

1942 

1943 # Handle auth_value updates (both existing and new auth values) 

1944 token = gateway_update.auth_token 

1945 password = gateway_update.auth_password 

1946 header_value = gateway_update.auth_header_value 

1947 

1948 # Support multiple custom headers on update 

1949 if hasattr(gateway_update, "auth_headers") and gateway_update.auth_headers: 

1950 existing_auth_raw = getattr(gateway, "auth_value", {}) or {} 

1951 if isinstance(existing_auth_raw, str): 

1952 try: 

1953 existing_auth = decode_auth(existing_auth_raw) 

1954 except Exception: 

1955 existing_auth = {} 

1956 elif isinstance(existing_auth_raw, dict): 

1957 existing_auth = existing_auth_raw 

1958 else: 

1959 existing_auth = {} 

1960 

1961 header_dict: Dict[str, str] = {} 

1962 for header in gateway_update.auth_headers: 

1963 key = header.get("key") 

1964 if not key: 

1965 continue 

1966 value = header.get("value", "") 

1967 if value == settings.masked_auth_value and key in existing_auth: 

1968 header_dict[key] = existing_auth[key] 

1969 else: 

1970 header_dict[key] = value 

1971 gateway.auth_value = header_dict # Store as dict for DB JSON field 

1972 elif settings.masked_auth_value not in (token, password, header_value): 

1973 # Check if values differ from existing ones or if setting for first time 

1974 decoded_auth = decode_auth(gateway_update.auth_value) if gateway_update.auth_value else {} 

1975 current_auth = getattr(gateway, "auth_value", {}) or {} 

1976 if current_auth != decoded_auth: 

1977 gateway.auth_value = decoded_auth 

1978 

1979 # Handle query_param auth updates with service-layer enforcement 

1980 auth_query_params_decrypted: Optional[Dict[str, str]] = None 

1981 init_url = gateway.url 

1982 

1983 # Check if updating to query_param auth or updating existing query_param credentials 

1984 # Use original_auth_type since gateway.auth_type may have been updated already 

1985 is_switching_to_queryparam = gateway_update.auth_type == "query_param" and original_auth_type != "query_param" 

1986 is_updating_queryparam_creds = original_auth_type == "query_param" and (gateway_update.auth_query_param_key is not None or gateway_update.auth_query_param_value is not None) 

1987 is_url_changing = gateway_update.url is not None and self.normalize_url(str(gateway_update.url)) != original_url 

1988 

1989 if is_switching_to_queryparam or is_updating_queryparam_creds or (is_url_changing and original_auth_type == "query_param"): 

1990 # Service-layer enforcement: Check feature flag 

1991 if not settings.insecure_allow_queryparam_auth: 

1992 # Grandfather clause: Allow updates to existing query_param gateways 

1993 # unless they're trying to change credentials 

1994 if is_switching_to_queryparam or is_updating_queryparam_creds: 

1995 raise ValueError("Query parameter authentication is disabled. " + "Set INSECURE_ALLOW_QUERYPARAM_AUTH=true to enable.") 

1996 

1997 # Service-layer enforcement: Check host allowlist 

1998 if settings.insecure_queryparam_auth_allowed_hosts: 

1999 check_url = str(gateway_update.url) if gateway_update.url else gateway.url 

2000 parsed = urlparse(check_url) 

2001 hostname = (parsed.hostname or "").lower() 

2002 if hostname not in settings.insecure_queryparam_auth_allowed_hosts: 

2003 allowed = ", ".join(settings.insecure_queryparam_auth_allowed_hosts) 

2004 raise ValueError(f"Host '{hostname}' is not in the allowed hosts for query param auth. Allowed: {allowed}") 

2005 

2006 # Process query_param auth credentials 

2007 param_key = getattr(gateway_update, "auth_query_param_key", None) or (next(iter(gateway.auth_query_params.keys()), None) if gateway.auth_query_params else None) 

2008 param_value = getattr(gateway_update, "auth_query_param_value", None) 

2009 

2010 # Get raw value from SecretStr if applicable 

2011 raw_value: Optional[str] = None 

2012 if param_value: 

2013 if hasattr(param_value, "get_secret_value"): 

2014 raw_value = param_value.get_secret_value() 

2015 else: 

2016 raw_value = str(param_value) 

2017 

2018 # Check if the value is the masked placeholder - if so, keep existing value 

2019 is_masked_placeholder = raw_value == settings.masked_auth_value 

2020 

2021 if param_key: 

2022 if raw_value and not is_masked_placeholder: 

2023 # New value provided - encrypt for storage 

2024 encrypted_value = encode_auth({param_key: raw_value}) 

2025 gateway.auth_query_params = {param_key: encrypted_value} 

2026 auth_query_params_decrypted = {param_key: raw_value} 

2027 elif gateway.auth_query_params: 

2028 # Use existing encrypted value 

2029 existing_encrypted = gateway.auth_query_params.get(param_key, "") 

2030 if existing_encrypted: 

2031 decrypted = decode_auth(existing_encrypted) 

2032 auth_query_params_decrypted = {param_key: decrypted.get(param_key, "")} 

2033 

2034 # Append query params to URL for initialization 

2035 if auth_query_params_decrypted: 

2036 init_url = apply_query_param_auth(gateway.url, auth_query_params_decrypted) 

2037 

2038 # Update auth_type if switching 

2039 if is_switching_to_queryparam: 

2040 gateway.auth_type = "query_param" 

2041 gateway.auth_value = None # Query param auth doesn't use auth_value 

2042 

2043 elif gateway.auth_type == "query_param" and gateway.auth_query_params: 

2044 # Existing query_param gateway without credential changes - decrypt for init 

2045 first_key = next(iter(gateway.auth_query_params.keys()), None) 

2046 if first_key: 

2047 encrypted_value = gateway.auth_query_params.get(first_key, "") 

2048 if encrypted_value: 

2049 decrypted = decode_auth(encrypted_value) 

2050 auth_query_params_decrypted = {first_key: decrypted.get(first_key, "")} 

2051 init_url = apply_query_param_auth(gateway.url, auth_query_params_decrypted) 

2052 

2053 # Try to reinitialize connection if URL actually changed 

2054 # if url_changed: 

2055 # Initialize empty lists in case initialization fails 

2056 tools_to_add = [] 

2057 resources_to_add = [] 

2058 prompts_to_add = [] 

2059 

2060 try: 

2061 ca_certificate = getattr(gateway, "ca_certificate", None) 

2062 capabilities, tools, resources, prompts = await self._initialize_gateway( 

2063 init_url, 

2064 gateway.auth_value, 

2065 gateway.transport, 

2066 gateway.auth_type, 

2067 gateway.oauth_config, 

2068 ca_certificate, 

2069 auth_query_params=auth_query_params_decrypted, 

2070 ) 

2071 new_tool_names = [tool.name for tool in tools] 

2072 new_resource_uris = [resource.uri for resource in resources] 

2073 new_prompt_names = [prompt.name for prompt in prompts] 

2074 

2075 if gateway_update.one_time_auth: 

2076 # For one-time auth, clear auth_type and auth_value after initialization 

2077 gateway.auth_type = "one_time_auth" 

2078 gateway.auth_value = None 

2079 gateway.oauth_config = None 

2080 

2081 # Update tools using helper method 

2082 tools_to_add = self._update_or_create_tools(db, tools, gateway, "update") 

2083 

2084 # Update resources using helper method 

2085 resources_to_add = self._update_or_create_resources(db, resources, gateway, "update") 

2086 

2087 # Update prompts using helper method 

2088 prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, "update") 

2089 

2090 # Log newly added items 

2091 items_added = len(tools_to_add) + len(resources_to_add) + len(prompts_to_add) 

2092 if items_added > 0: 

2093 if tools_to_add: 

2094 logger.info(f"Added {len(tools_to_add)} new tools during gateway update") 

2095 if resources_to_add: 

2096 logger.info(f"Added {len(resources_to_add)} new resources during gateway update") 

2097 if prompts_to_add: 

2098 logger.info(f"Added {len(prompts_to_add)} new prompts during gateway update") 

2099 logger.info(f"Total {items_added} new items added during gateway update") 

2100 

2101 # Count items before cleanup for logging 

2102 

2103 # Bulk delete tools that are no longer available from the gateway 

2104 # Use chunking to avoid SQLite's 999 parameter limit for IN clauses 

2105 stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names] 

2106 if stale_tool_ids: 

2107 # Delete child records first to avoid FK constraint violations 

2108 for i in range(0, len(stale_tool_ids), 500): 

2109 chunk = stale_tool_ids[i : i + 500] 

2110 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk))) 

2111 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk))) 

2112 db.execute(delete(DbTool).where(DbTool.id.in_(chunk))) 

2113 

2114 # Bulk delete resources that are no longer available from the gateway 

2115 stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris] 

2116 if stale_resource_ids: 

2117 # Delete child records first to avoid FK constraint violations 

2118 for i in range(0, len(stale_resource_ids), 500): 

2119 chunk = stale_resource_ids[i : i + 500] 

2120 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk))) 

2121 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk))) 

2122 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk))) 

2123 db.execute(delete(DbResource).where(DbResource.id.in_(chunk))) 

2124 

2125 # Bulk delete prompts that are no longer available from the gateway 

2126 stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names] 

2127 if stale_prompt_ids: 

2128 # Delete child records first to avoid FK constraint violations 

2129 for i in range(0, len(stale_prompt_ids), 500): 

2130 chunk = stale_prompt_ids[i : i + 500] 

2131 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk))) 

2132 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk))) 

2133 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk))) 

2134 

2135 # Expire gateway to clear cached relationships after bulk deletes 

2136 # This prevents SQLAlchemy from trying to re-delete already-deleted items 

2137 if stale_tool_ids or stale_resource_ids or stale_prompt_ids: 

2138 db.expire(gateway) 

2139 

2140 gateway.capabilities = capabilities 

2141 

2142 # Register capabilities for notification-driven actions 

2143 register_gateway_capabilities_for_notifications(gateway.id, capabilities) 

2144 

2145 gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names] # keep only still-valid rows 

2146 gateway.resources = [resource for resource in gateway.resources if resource.uri in new_resource_uris] # keep only still-valid rows 

2147 gateway.prompts = [prompt for prompt in gateway.prompts if prompt.original_name in new_prompt_names] # keep only still-valid rows 

2148 

2149 # Log cleanup results 

2150 tools_removed = len(stale_tool_ids) 

2151 resources_removed = len(stale_resource_ids) 

2152 prompts_removed = len(stale_prompt_ids) 

2153 

2154 if tools_removed > 0: 

2155 logger.info(f"Removed {tools_removed} tools no longer available during gateway update") 

2156 if resources_removed > 0: 

2157 logger.info(f"Removed {resources_removed} resources no longer available during gateway update") 

2158 if prompts_removed > 0: 

2159 logger.info(f"Removed {prompts_removed} prompts no longer available during gateway update") 

2160 

2161 gateway.last_seen = datetime.now(timezone.utc) 

2162 

2163 # Add new items to database session in chunks to prevent lock escalation 

2164 chunk_size = 50 

2165 

2166 if tools_to_add: 

2167 for i in range(0, len(tools_to_add), chunk_size): 

2168 chunk = tools_to_add[i : i + chunk_size] 

2169 db.add_all(chunk) 

2170 db.flush() 

2171 if resources_to_add: 

2172 for i in range(0, len(resources_to_add), chunk_size): 

2173 chunk = resources_to_add[i : i + chunk_size] 

2174 db.add_all(chunk) 

2175 db.flush() 

2176 if prompts_to_add: 

2177 for i in range(0, len(prompts_to_add), chunk_size): 

2178 chunk = prompts_to_add[i : i + chunk_size] 

2179 db.add_all(chunk) 

2180 db.flush() 

2181 

2182 # Update tracking with new URL 

2183 self._active_gateways.discard(gateway.url) 

2184 self._active_gateways.add(gateway.url) 

2185 except Exception as e: 

2186 logger.warning(f"Failed to initialize updated gateway: {e}") 

2187 

2188 # Update tags if provided 

2189 if gateway_update.tags is not None: 

2190 gateway.tags = gateway_update.tags 

2191 

2192 # Update gateway_mode if provided 

2193 if hasattr(gateway_update, "gateway_mode") and gateway_update.gateway_mode is not None: 

2194 if gateway_update.gateway_mode == "direct_proxy" and not settings.mcpgateway_direct_proxy_enabled: 

2195 raise GatewayError("direct_proxy gateway mode is disabled. Set MCPGATEWAY_DIRECT_PROXY_ENABLED=true to enable.") 

2196 gateway.gateway_mode = gateway_update.gateway_mode 

2197 

2198 # Update metadata fields 

2199 gateway.updated_at = datetime.now(timezone.utc) 

2200 if modified_by: 

2201 gateway.modified_by = modified_by 

2202 if modified_from_ip: 

2203 gateway.modified_from_ip = modified_from_ip 

2204 if modified_via: 

2205 gateway.modified_via = modified_via 

2206 if modified_user_agent: 

2207 gateway.modified_user_agent = modified_user_agent 

2208 if hasattr(gateway, "version") and gateway.version is not None: 

2209 gateway.version = gateway.version + 1 

2210 else: 

2211 gateway.version = 1 

2212 

2213 db.commit() 

2214 db.refresh(gateway) 

2215 

2216 # Invalidate cache after successful update 

2217 cache = _get_registry_cache() 

2218 await cache.invalidate_gateways() 

2219 tool_lookup_cache = _get_tool_lookup_cache() 

2220 await tool_lookup_cache.invalidate_gateway(str(gateway.id)) 

2221 # Also invalidate tags cache since gateway tags may have changed 

2222 # First-Party 

2223 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel 

2224 

2225 await admin_stats_cache.invalidate_tags() 

2226 

2227 # Notify subscribers 

2228 await self._notify_gateway_updated(gateway) 

2229 

2230 logger.info(f"Updated gateway: {gateway.name}") 

2231 

2232 # Structured logging: Audit trail for gateway update 

2233 audit_trail.log_action( 

2234 user_id=user_email or modified_by or "system", 

2235 action="update_gateway", 

2236 resource_type="gateway", 

2237 resource_id=str(gateway.id), 

2238 resource_name=gateway.name, 

2239 user_email=user_email, 

2240 team_id=gateway.team_id, 

2241 client_ip=modified_from_ip, 

2242 user_agent=modified_user_agent, 

2243 new_values={ 

2244 "name": gateway.name, 

2245 "url": gateway.url, 

2246 "version": gateway.version, 

2247 }, 

2248 context={ 

2249 "modified_via": modified_via, 

2250 }, 

2251 db=db, 

2252 ) 

2253 

2254 # Structured logging: Log successful gateway update 

2255 structured_logger.log( 

2256 level="INFO", 

2257 message="Gateway updated successfully", 

2258 event_type="gateway_updated", 

2259 component="gateway_service", 

2260 user_id=modified_by, 

2261 user_email=user_email, 

2262 team_id=gateway.team_id, 

2263 resource_type="gateway", 

2264 resource_id=str(gateway.id), 

2265 custom_fields={ 

2266 "gateway_name": gateway.name, 

2267 "version": gateway.version, 

2268 }, 

2269 ) 

2270 

2271 return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)).masked() 

2272 # Gateway is inactive and include_inactive is False → skip update, return None 

2273 return None 

2274 except GatewayNameConflictError as ge: 

2275 logger.error(f"GatewayNameConflictError in group: {ge}") 

2276 db.rollback() 

2277 

2278 structured_logger.log( 

2279 level="WARNING", 

2280 message="Gateway update failed due to name conflict", 

2281 event_type="gateway_name_conflict", 

2282 component="gateway_service", 

2283 user_email=user_email, 

2284 resource_type="gateway", 

2285 resource_id=gateway_id, 

2286 error=ge, 

2287 ) 

2288 raise ge 

2289 except GatewayNotFoundError as gnfe: 

2290 logger.error(f"GatewayNotFoundError: {gnfe}") 

2291 db.rollback() 

2292 

2293 structured_logger.log( 

2294 level="ERROR", 

2295 message="Gateway update failed - gateway not found", 

2296 event_type="gateway_not_found", 

2297 component="gateway_service", 

2298 user_email=user_email, 

2299 resource_type="gateway", 

2300 resource_id=gateway_id, 

2301 error=gnfe, 

2302 ) 

2303 raise gnfe 

2304 except IntegrityError as ie: 

2305 logger.error(f"IntegrityErrors in group: {ie}") 

2306 db.rollback() 

2307 

2308 structured_logger.log( 

2309 level="ERROR", 

2310 message="Gateway update failed due to database integrity error", 

2311 event_type="gateway_update_failed", 

2312 component="gateway_service", 

2313 user_email=user_email, 

2314 resource_type="gateway", 

2315 resource_id=gateway_id, 

2316 error=ie, 

2317 ) 

2318 raise ie 

2319 except PermissionError as pe: 

2320 db.rollback() 

2321 

2322 structured_logger.log( 

2323 level="WARNING", 

2324 message="Gateway update failed due to permission error", 

2325 event_type="gateway_update_permission_denied", 

2326 component="gateway_service", 

2327 user_email=user_email, 

2328 resource_type="gateway", 

2329 resource_id=gateway_id, 

2330 error=pe, 

2331 ) 

2332 raise 

2333 except Exception as e: 

2334 db.rollback() 

2335 

2336 structured_logger.log( 

2337 level="ERROR", 

2338 message="Gateway update failed", 

2339 event_type="gateway_update_failed", 

2340 component="gateway_service", 

2341 user_email=user_email, 

2342 resource_type="gateway", 

2343 resource_id=gateway_id, 

2344 error=e, 

2345 ) 

2346 raise GatewayError(f"Failed to update gateway: {str(e)}") 

2347 

async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool = True) -> GatewayRead:
    """Fetch one gateway by primary key and return its masked read model.

    The gateway's tools, resources and prompts are selectin-loaded and its
    owning team is joined in, so building the read model does not trigger
    per-row lazy loads. A disabled gateway is only returned when
    ``include_inactive`` is true; otherwise it is reported exactly like a
    missing one.

    Args:
        db: Database session
        gateway_id: Gateway ID
        include_inactive: Whether to include inactive gateways

    Returns:
        GatewayRead object

    Raises:
        GatewayNotFoundError: If the gateway does not exist, or exists but is
            disabled while ``include_inactive`` is false

    Examples:
        >>> from unittest.mock import MagicMock
        >>> from mcpgateway.schemas import GatewayRead
        >>> service = GatewayService()
        >>> db = MagicMock()
        >>> gateway_mock = MagicMock()
        >>> gateway_mock.enabled = True
        >>> db.execute.return_value.scalar_one_or_none.return_value = gateway_mock
        >>> mocked_gateway_read = MagicMock()
        >>> mocked_gateway_read.masked.return_value = 'gateway_read'
        >>> GatewayRead.model_validate = MagicMock(return_value=mocked_gateway_read)
        >>> import asyncio
        >>> result = asyncio.run(service.get_gateway(db, 'gateway_id'))
        >>> result == 'gateway_read'
        True

        >>> # Test with inactive gateway but include_inactive=True
        >>> gateway_mock.enabled = False
        >>> result_inactive = asyncio.run(service.get_gateway(db, 'gateway_id', include_inactive=True))
        >>> result_inactive == 'gateway_read'
        True

        >>> # Test gateway not found
        >>> db.execute.return_value.scalar_one_or_none.return_value = None
        >>> try:
        ...     asyncio.run(service.get_gateway(db, 'missing_id'))
        ... except GatewayNotFoundError as e:
        ...     'Gateway not found: missing_id' in str(e)
        True

        >>> # Test inactive gateway with include_inactive=False
        >>> gateway_mock.enabled = False
        >>> db.execute.return_value.scalar_one_or_none.return_value = gateway_mock
        >>> try:
        ...     asyncio.run(service.get_gateway(db, 'gateway_id', include_inactive=False))
        ... except GatewayNotFoundError as e:
        ...     'Gateway not found: gateway_id' in str(e)
        True
        >>>
        >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
        >>> asyncio.run(service._http_client.aclose())
    """
    not_found_message = f"Gateway not found: {gateway_id}"

    # Eager-load child collections (selectin) and the team (join) in one go
    # to avoid N+1 queries when the read model is built.
    query = (
        select(DbGateway)
        .options(
            selectinload(DbGateway.tools),
            selectinload(DbGateway.resources),
            selectinload(DbGateway.prompts),
            joinedload(DbGateway.email_team),
        )
        .where(DbGateway.id == gateway_id)
    )
    record = db.execute(query).scalar_one_or_none()

    if not record:
        raise GatewayNotFoundError(not_found_message)

    if not record.enabled and not include_inactive:
        # A hidden (disabled) gateway is indistinguishable from a missing one.
        raise GatewayNotFoundError(not_found_message)

    # Structured logging: Log gateway view
    structured_logger.log(
        level="INFO",
        message="Gateway retrieved successfully",
        event_type="gateway_viewed",
        component="gateway_service",
        team_id=getattr(record, "team_id", None),
        resource_type="gateway",
        resource_id=str(record.id),
        custom_fields={
            "gateway_name": record.name,
            "gateway_url": record.url,
            "include_inactive": include_inactive,
        },
    )

    prepared = self._prepare_gateway_for_read(record)
    return GatewayRead.model_validate(prepared).masked()

2439 

async def set_gateway_state(self, db: Session, gateway_id: str, activate: bool, reachable: bool = True, only_update_reachable: bool = False, user_email: Optional[str] = None) -> GatewayRead:
    """
    Set the activation status of a gateway.

    If the requested ``enabled``/``reachable`` pair differs from the stored
    one, the gateway row is updated. When the gateway becomes enabled and
    reachable, its tools/resources/prompts are rediscovered: stale rows are
    deleted and new ones added. A failure during that rediscovery is logged
    and does not abort the state change. Afterwards:

    - the registry cache is invalidated;
    - subscribers are notified (deactivated / offline / activated);
    - child rows are bulk-updated. Tools are always updated; prompts and
      resources only when ``only_update_reachable`` is false.

    When nothing changes, the gateway is returned as-is.

    Args:
        db: Database session
        gateway_id: Gateway ID
        activate: True to activate, False to deactivate
        reachable: Whether the gateway is reachable
        only_update_reachable: Only sync the ``reachable`` flag onto the
            gateway's tools and skip the prompt/resource ``enabled`` sync.
            The gateway row itself still receives ``activate``/``reachable``.
        user_email: Optional[str] The email of the user to check if the user has permission to modify.

    Returns:
        The updated GatewayRead object

    Raises:
        GatewayNotFoundError: If the gateway is not found
        GatewayError: For other errors
        PermissionError: If user doesn't own the gateway.
    """
    try:
        # Eager-load collections for the gateway. Note: we don't use FOR UPDATE
        # here because _initialize_gateway does network I/O, and holding a row
        # lock during network calls would block other operations and risk timeouts.
        gateway = db.execute(
            select(DbGateway)
            .options(
                selectinload(DbGateway.tools),
                selectinload(DbGateway.resources),
                selectinload(DbGateway.prompts),
                joinedload(DbGateway.email_team),
            )
            .where(DbGateway.id == gateway_id)
        ).scalar_one_or_none()
        if not gateway:
            raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

        if user_email:
            # First-Party
            from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

            permission_service = PermissionService(db)
            if not await permission_service.check_resource_ownership(user_email, gateway):
                raise PermissionError("Only the owner can activate the gateway" if activate else "Only the owner can deactivate the gateway")

        # Update status if it's different
        if (gateway.enabled != activate) or (gateway.reachable != reachable):
            gateway.enabled = activate
            gateway.reachable = reachable
            gateway.updated_at = datetime.now(timezone.utc)
            # Update tracking
            if activate and reachable:
                self._active_gateways.add(gateway.url)

                # Initialize empty lists in case initialization fails
                tools_to_add = []
                resources_to_add = []
                prompts_to_add = []

                # Try to initialize if activating; failures are logged, not raised
                try:
                    # Handle query_param auth - decrypt and apply to URL
                    init_url = gateway.url
                    auth_query_params_decrypted: Optional[Dict[str, str]] = None
                    if gateway.auth_type == "query_param" and gateway.auth_query_params:
                        auth_query_params_decrypted = {}
                        for param_key, encrypted_value in gateway.auth_query_params.items():
                            if encrypted_value:
                                try:
                                    decrypted = decode_auth(encrypted_value)
                                    auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
                                except Exception:
                                    logger.debug(f"Failed to decrypt query param '{param_key}' for gateway activation")
                        if auth_query_params_decrypted:
                            init_url = apply_query_param_auth(gateway.url, auth_query_params_decrypted)

                    capabilities, tools, resources, prompts = await self._initialize_gateway(
                        init_url, gateway.auth_value, gateway.transport, gateway.auth_type, gateway.oauth_config, auth_query_params=auth_query_params_decrypted, oauth_auto_fetch_tool_flag=True
                    )
                    # Sets: these are only used for membership tests below, so
                    # avoid O(len(existing) * len(discovered)) list scans.
                    new_tool_names = {tool.name for tool in tools}
                    new_resource_uris = {resource.uri for resource in resources}
                    new_prompt_names = {prompt.name for prompt in prompts}

                    # Update tools, resources, and prompts using helper methods
                    tools_to_add = self._update_or_create_tools(db, tools, gateway, "rediscovery")
                    resources_to_add = self._update_or_create_resources(db, resources, gateway, "rediscovery")
                    prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, "rediscovery")

                    # Log newly added items
                    items_added = len(tools_to_add) + len(resources_to_add) + len(prompts_to_add)
                    if items_added > 0:
                        if tools_to_add:
                            logger.info(f"Added {len(tools_to_add)} new tools during gateway reactivation")
                        if resources_to_add:
                            logger.info(f"Added {len(resources_to_add)} new resources during gateway reactivation")
                        if prompts_to_add:
                            logger.info(f"Added {len(prompts_to_add)} new prompts during gateway reactivation")
                        logger.info(f"Total {items_added} new items added during gateway reactivation")

                    # Bulk delete tools that are no longer available from the gateway
                    # Use chunking to avoid SQLite's 999 parameter limit for IN clauses
                    stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names]
                    if stale_tool_ids:
                        # Delete child records first to avoid FK constraint violations
                        for i in range(0, len(stale_tool_ids), 500):
                            chunk = stale_tool_ids[i : i + 500]
                            db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
                            db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
                            db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))

                    # Bulk delete resources that are no longer available from the gateway
                    stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris]
                    if stale_resource_ids:
                        # Delete child records first to avoid FK constraint violations
                        for i in range(0, len(stale_resource_ids), 500):
                            chunk = stale_resource_ids[i : i + 500]
                            db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
                            db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
                            db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
                            db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))

                    # Bulk delete prompts that are no longer available from the gateway
                    stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names]
                    if stale_prompt_ids:
                        # Delete child records first to avoid FK constraint violations
                        for i in range(0, len(stale_prompt_ids), 500):
                            chunk = stale_prompt_ids[i : i + 500]
                            db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
                            db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
                            db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))

                    # Expire gateway to clear cached relationships after bulk deletes
                    # This prevents SQLAlchemy from trying to re-delete already-deleted items
                    if stale_tool_ids or stale_resource_ids or stale_prompt_ids:
                        db.expire(gateway)

                    gateway.capabilities = capabilities

                    # Register capabilities for notification-driven actions
                    register_gateway_capabilities_for_notifications(gateway.id, capabilities)

                    gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names]  # keep only still-valid rows
                    gateway.resources = [resource for resource in gateway.resources if resource.uri in new_resource_uris]  # keep only still-valid rows
                    gateway.prompts = [prompt for prompt in gateway.prompts if prompt.original_name in new_prompt_names]  # keep only still-valid rows

                    # Log cleanup results
                    tools_removed = len(stale_tool_ids)
                    resources_removed = len(stale_resource_ids)
                    prompts_removed = len(stale_prompt_ids)

                    if tools_removed > 0:
                        logger.info(f"Removed {tools_removed} tools no longer available during gateway reactivation")
                    if resources_removed > 0:
                        logger.info(f"Removed {resources_removed} resources no longer available during gateway reactivation")
                    if prompts_removed > 0:
                        logger.info(f"Removed {prompts_removed} prompts no longer available during gateway reactivation")

                    gateway.last_seen = datetime.now(timezone.utc)

                    # Add new items to database session in chunks to prevent lock escalation
                    chunk_size = 50

                    if tools_to_add:
                        for i in range(0, len(tools_to_add), chunk_size):
                            chunk = tools_to_add[i : i + chunk_size]
                            db.add_all(chunk)
                            db.flush()
                    if resources_to_add:
                        for i in range(0, len(resources_to_add), chunk_size):
                            chunk = resources_to_add[i : i + chunk_size]
                            db.add_all(chunk)
                            db.flush()
                    if prompts_to_add:
                        for i in range(0, len(prompts_to_add), chunk_size):
                            chunk = prompts_to_add[i : i + chunk_size]
                            db.add_all(chunk)
                            db.flush()
                except Exception as e:
                    logger.warning(f"Failed to initialize reactivated gateway: {e}")
            else:
                self._active_gateways.discard(gateway.url)

            db.commit()
            db.refresh(gateway)

            # Invalidate cache after status change
            cache = _get_registry_cache()
            await cache.invalidate_gateways()

            # Notify Subscribers
            if not gateway.enabled:
                # Inactive
                await self._notify_gateway_deactivated(gateway)
            elif gateway.enabled and not gateway.reachable:
                # Offline (Enabled but Unreachable)
                await self._notify_gateway_offline(gateway)
            else:
                # Active (Enabled and Reachable)
                await self._notify_gateway_activated(gateway)

            # Bulk update tools - single UPDATE statement instead of N FOR UPDATE locks
            # This prevents lock contention under high concurrent load
            now = datetime.now(timezone.utc)
            if only_update_reachable:
                # Only update reachable status, keep enabled as-is
                tools_result = db.execute(update(DbTool).where(DbTool.gateway_id == gateway_id).where(DbTool.reachable != reachable).values(reachable=reachable, updated_at=now))
            else:
                # Update both enabled and reachable
                tools_result = db.execute(
                    update(DbTool)
                    .where(DbTool.gateway_id == gateway_id)
                    .where(or_(DbTool.enabled != activate, DbTool.reachable != reachable))
                    .values(enabled=activate, reachable=reachable, updated_at=now)
                )
            tools_updated = tools_result.rowcount

            # Commit tool updates, then invalidate tool caches once after the bulk update
            if tools_updated > 0:
                db.commit()
                await cache.invalidate_tools()
                tool_lookup_cache = _get_tool_lookup_cache()
                await tool_lookup_cache.invalidate_gateway(str(gateway.id))

            # Bulk update prompts when gateway is deactivated/activated (skip for reachability-only updates)
            prompts_updated = 0
            if not only_update_reachable:
                prompts_result = db.execute(update(DbPrompt).where(DbPrompt.gateway_id == gateway_id).where(DbPrompt.enabled != activate).values(enabled=activate, updated_at=now))
                prompts_updated = prompts_result.rowcount
                if prompts_updated > 0:
                    db.commit()
                    await cache.invalidate_prompts()

            # Bulk update resources when gateway is deactivated/activated (skip for reachability-only updates)
            resources_updated = 0
            if not only_update_reachable:
                resources_result = db.execute(update(DbResource).where(DbResource.gateway_id == gateway_id).where(DbResource.enabled != activate).values(enabled=activate, updated_at=now))
                resources_updated = resources_result.rowcount
                if resources_updated > 0:
                    db.commit()
                    await cache.invalidate_resources()

            logger.debug(f"Gateway {gateway.name} bulk state update: {tools_updated} tools, {prompts_updated} prompts, {resources_updated} resources")

            logger.info(f"Gateway status: {gateway.name} - {'enabled' if activate else 'disabled'} and {'accessible' if reachable else 'inaccessible'}")

            # Structured logging: Audit trail for gateway state change
            audit_trail.log_action(
                user_id=user_email or "system",
                action="set_gateway_state",
                resource_type="gateway",
                resource_id=str(gateway.id),
                resource_name=gateway.name,
                user_email=user_email,
                team_id=gateway.team_id,
                new_values={
                    "enabled": gateway.enabled,
                    "reachable": gateway.reachable,
                },
                context={
                    "action": "activate" if activate else "deactivate",
                    "only_update_reachable": only_update_reachable,
                },
                db=db,
            )

            # Structured logging: Log successful gateway state change
            structured_logger.log(
                level="INFO",
                message=f"Gateway {'activated' if activate else 'deactivated'} successfully",
                event_type="gateway_state_changed",
                component="gateway_service",
                user_email=user_email,
                team_id=gateway.team_id,
                resource_type="gateway",
                resource_id=str(gateway.id),
                custom_fields={
                    "gateway_name": gateway.name,
                    "enabled": gateway.enabled,
                    "reachable": gateway.reachable,
                },
            )

        return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)).masked()

    except GatewayNotFoundError as gnfe:
        db.rollback()

        # Structured logging: Log missing gateway. Re-raised unchanged (instead of
        # being wrapped by the generic handler below) so callers can tell
        # "not found" apart from other failures, as documented.
        structured_logger.log(
            level="ERROR",
            message="Gateway state change failed - gateway not found",
            event_type="gateway_not_found",
            component="gateway_service",
            user_email=user_email,
            resource_type="gateway",
            resource_id=gateway_id,
            error=gnfe,
        )
        raise
    except PermissionError as e:
        db.rollback()

        # Structured logging: Log permission error
        structured_logger.log(
            level="WARNING",
            message="Gateway state change failed due to permission error",
            event_type="gateway_state_change_permission_denied",
            component="gateway_service",
            user_email=user_email,
            resource_type="gateway",
            resource_id=gateway_id,
            error=e,
        )
        raise
    except Exception as e:
        db.rollback()

        # Structured logging: Log generic gateway state change failure
        structured_logger.log(
            level="ERROR",
            message="Gateway state change failed",
            event_type="gateway_state_change_failed",
            component="gateway_service",
            user_email=user_email,
            resource_type="gateway",
            resource_id=gateway_id,
            error=e,
        )
        raise GatewayError(f"Failed to set gateway state: {str(e)}")

2759 

async def _notify_gateway_updated(self, gateway: DbGateway) -> None:
    """
    Publish a ``gateway_updated`` event carrying the gateway's current identity and state.

    Args:
        gateway: Gateway whose id, name, url, description and enabled flag are broadcast
    """
    payload = {
        "id": gateway.id,
        "name": gateway.name,
        "url": gateway.url,
        "description": gateway.description,
        "enabled": gateway.enabled,
    }
    await self._publish_event(
        {
            "type": "gateway_updated",
            "data": payload,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
    )

2779 

async def delete_gateway(self, db: Session, gateway_id: str, user_email: Optional[str] = None) -> None:
    """
    Delete a gateway by its ID.

    Child rows are removed explicitly before the gateway row, in chunks of 500 ids:
    - tools, with their metrics and server associations;
    - resources, with their metrics, server associations and subscriptions;
    - prompts, with their metrics and server associations.

    After the commit:
    - the registry, tool-lookup and admin tag caches are invalidated;
    - the gateway URL is dropped from active tracking;
    - subscribers are notified.

    Args:
        db: Database session
        gateway_id: Gateway ID
        user_email: Email of user performing deletion (for ownership check)

    Raises:
        GatewayNotFoundError: If the gateway is not found, including when a
            concurrent request deleted it between the lookup and the delete
        PermissionError: If user doesn't own the gateway
        GatewayError: For other deletion errors

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> from unittest.mock import MagicMock
        >>> service = GatewayService()
        >>> db = MagicMock()
        >>> gateway = MagicMock()
        >>> db.execute.return_value.scalar_one_or_none.return_value = gateway
        >>> db.delete = MagicMock()
        >>> db.commit = MagicMock()
        >>> service._notify_gateway_deleted = MagicMock()
        >>> import asyncio
        >>> try:
        ...     asyncio.run(service.delete_gateway(db, 'gateway_id', 'user@example.com'))
        ... except Exception:
        ...     pass
        >>>
        >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
        >>> asyncio.run(service._http_client.aclose())
    """
    try:
        # Find gateway with eager loading for deletion to avoid N+1 queries
        gateway = db.execute(
            select(DbGateway)
            .options(
                selectinload(DbGateway.tools),
                selectinload(DbGateway.resources),
                selectinload(DbGateway.prompts),
                joinedload(DbGateway.email_team),
            )
            .where(DbGateway.id == gateway_id)
        ).scalar_one_or_none()

        if not gateway:
            raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

        # Check ownership if user_email provided
        if user_email:
            # First-Party
            from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

            permission_service = PermissionService(db)
            if not await permission_service.check_resource_ownership(user_email, gateway):
                raise PermissionError("Only the owner can delete this gateway")

        # Store gateway info for notification before deletion
        gateway_info = {"id": gateway.id, "name": gateway.name, "url": gateway.url}
        gateway_name = gateway.name
        gateway_team_id = gateway.team_id
        gateway_url = gateway.url  # Store URL before expiring the object

        # Manually delete children first to avoid FK constraint violations
        # (passive_deletes=True means ORM won't auto-cascade, we must do it explicitly)
        # Use chunking to avoid SQLite's 999 parameter limit for IN clauses
        tool_ids = [t.id for t in gateway.tools]
        resource_ids = [r.id for r in gateway.resources]
        prompt_ids = [p.id for p in gateway.prompts]

        # Delete tool children and tools
        if tool_ids:
            for i in range(0, len(tool_ids), 500):
                chunk = tool_ids[i : i + 500]
                db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
                db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
                db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))

        # Delete resource children and resources
        if resource_ids:
            for i in range(0, len(resource_ids), 500):
                chunk = resource_ids[i : i + 500]
                db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
                db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
                db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
                db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))

        # Delete prompt children and prompts
        if prompt_ids:
            for i in range(0, len(prompt_ids), 500):
                chunk = prompt_ids[i : i + 500]
                db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
                db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
                db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))

        # Expire gateway to clear cached relationships after bulk deletes
        db.expire(gateway)

        # Use DELETE with rowcount check for database-agnostic atomic delete
        # (RETURNING is not supported on MySQL/MariaDB)
        stmt = delete(DbGateway).where(DbGateway.id == gateway_id)
        result = db.execute(stmt)
        if result.rowcount == 0:
            # Gateway was already deleted by another concurrent request
            raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

        db.commit()

        # Invalidate cache after successful deletion
        cache = _get_registry_cache()
        await cache.invalidate_gateways()
        tool_lookup_cache = _get_tool_lookup_cache()
        await tool_lookup_cache.invalidate_gateway(str(gateway_id))
        # Also invalidate tags cache since the deleted gateway's tags no longer count
        # First-Party
        from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

        await admin_stats_cache.invalidate_tags()

        # Update tracking
        self._active_gateways.discard(gateway_url)

        # Notify subscribers
        await self._notify_gateway_deleted(gateway_info)

        logger.info(f"Permanently deleted gateway: {gateway_name}")

        # Structured logging: Audit trail for gateway deletion
        audit_trail.log_action(
            user_id=user_email or "system",
            action="delete_gateway",
            resource_type="gateway",
            resource_id=str(gateway_info["id"]),
            resource_name=gateway_name,
            user_email=user_email,
            team_id=gateway_team_id,
            old_values={
                "name": gateway_name,
                "url": gateway_info["url"],
            },
            db=db,
        )

        # Structured logging: Log successful gateway deletion
        structured_logger.log(
            level="INFO",
            message="Gateway deleted successfully",
            event_type="gateway_deleted",
            component="gateway_service",
            user_email=user_email,
            team_id=gateway_team_id,
            resource_type="gateway",
            resource_id=str(gateway_info["id"]),
            custom_fields={
                "gateway_name": gateway_name,
                "gateway_url": gateway_info["url"],
            },
        )

    except GatewayNotFoundError as gnfe:
        db.rollback()

        # Structured logging: Log missing gateway. Re-raised unchanged (instead of
        # being wrapped by the generic handler below) so callers can map it to a
        # "not found" outcome, as documented above.
        structured_logger.log(
            level="ERROR",
            message="Gateway deletion failed - gateway not found",
            event_type="gateway_not_found",
            component="gateway_service",
            user_email=user_email,
            resource_type="gateway",
            resource_id=gateway_id,
            error=gnfe,
        )
        raise
    except PermissionError as pe:
        db.rollback()

        # Structured logging: Log permission error
        structured_logger.log(
            level="WARNING",
            message="Gateway deletion failed due to permission error",
            event_type="gateway_delete_permission_denied",
            component="gateway_service",
            user_email=user_email,
            resource_type="gateway",
            resource_id=gateway_id,
            error=pe,
        )
        raise
    except Exception as e:
        db.rollback()

        # Structured logging: Log generic gateway deletion failure
        structured_logger.log(
            level="ERROR",
            message="Gateway deletion failed",
            event_type="gateway_deletion_failed",
            component="gateway_service",
            user_email=user_email,
            resource_type="gateway",
            resource_id=gateway_id,
            error=e,
        )
        raise GatewayError(f"Failed to delete gateway: {str(e)}")

2970 

async def _handle_gateway_failure(self, gateway: DbGateway) -> None:
    """Count a failed health check and flag the gateway unreachable once the threshold is hit.

    Failures are tallied per gateway id in ``self._gateway_failure_counts``.
    Nothing is counted in three cases:
    - failure handling is disabled (``GW_FAILURE_THRESHOLD == -1``);
    - the gateway is disabled;
    - the gateway is already marked unreachable.

    When the tally reaches ``GW_FAILURE_THRESHOLD``, ``set_gateway_state`` is
    called in a fresh session with ``activate=True, reachable=False,
    only_update_reachable=True``. The gateway stays enabled and is only
    marked unreachable. The tally is then reset to zero.

    Args:
        gateway: The gateway object that failed its health check.

    Returns:
        None

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> service = GatewayService()
        >>> gateway = type('Gateway', (), {
        ...     'id': 'gw1', 'name': 'test_gw', 'enabled': True, 'reachable': True
        ... })()
        >>> service._gateway_failure_counts = {}
        >>> import asyncio
        >>> # Test failure counting
        >>> asyncio.run(service._handle_gateway_failure(gateway))  # doctest: +ELLIPSIS
        >>> service._gateway_failure_counts['gw1'] >= 1
        True

        >>> # Test disabled gateway (no action)
        >>> gateway.enabled = False
        >>> old_count = service._gateway_failure_counts.get('gw1', 0)
        >>> asyncio.run(service._handle_gateway_failure(gateway))  # doctest: +ELLIPSIS
        >>> service._gateway_failure_counts.get('gw1', 0) == old_count
        True
    """
    failure_handling_disabled = GW_FAILURE_THRESHOLD == -1
    if failure_handling_disabled or not gateway.enabled or not gateway.reachable:
        # Disabled handling, inactive gateways and already-unreachable gateways need no tracking.
        return

    gateway_key = gateway.id
    failures = self._gateway_failure_counts.get(gateway_key, 0) + 1
    self._gateway_failure_counts[gateway_key] = failures

    logger.warning(f"Gateway {gateway.name} failed health check {failures} time(s).")

    if failures < GW_FAILURE_THRESHOLD:
        return

    logger.error(f"Gateway {gateway.name} failed {GW_FAILURE_THRESHOLD} times. Deactivating...")
    with cast(Any, SessionLocal)() as db:
        # Keep the gateway enabled; only flip its reachability.
        await self.set_gateway_state(db, gateway_key, activate=True, reachable=False, only_update_reachable=True)
        # Start counting afresh after the state change.
        self._gateway_failure_counts[gateway_key] = 0

3020 

async def check_health_of_gateways(self, gateways: List[DbGateway], user_email: Optional[str] = None) -> bool:
    """Check health of a batch of gateways.

    Performs an asynchronous health-check for each gateway in `gateways`.
    Gateways are processed in chunks of at most ``concurrency_limit``, and a
    semaphore caps the number of checks in flight. Gateways whose
    ``auth_type`` is ``"one_time_auth"`` are skipped.

    Each check is bounded by ``settings.gateway_health_check_timeout``; a
    timeout is treated as a failed check and routed to
    ``_handle_gateway_failure``. When a gateway uses the authorization_code
    flow, the optional `user_email` is used to look up stored user tokens with
    fresh_db_session(). If a previously unreachable gateway becomes healthy
    again, the service attempts to update its reachable status.

    Any unexpected exception escaping an individual check is logged as a
    warning. It does not abort the rest of the batch.

    NOTE: This method intentionally does NOT take a db parameter.
    DB access uses fresh_db_session() only when needed, avoiding holding
    connections during HTTP calls to MCP servers.

    Args:
        gateways: List of DbGateway objects to check.
        user_email: Optional MCP gateway user email used to retrieve
            stored OAuth tokens for gateways using the
            "authorization_code" grant type. If not provided, authorization
            code flows that require a user token will be treated as failed.

    Returns:
        bool: True when the health-check batch completes. This return
        value indicates completion of the checks, not that every gateway
        was healthy. Individual gateway failures are handled internally
        (via _handle_gateway_failure and status updates).

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> from unittest.mock import MagicMock
        >>> service = GatewayService()
        >>> gateways = [MagicMock()]
        >>> gateways[0].ca_certificate = None
        >>> import asyncio
        >>> result = asyncio.run(service.check_health_of_gateways(gateways))
        >>> isinstance(result, bool)
        True

        >>> # Test empty gateway list
        >>> empty_result = asyncio.run(service.check_health_of_gateways([]))
        >>> empty_result
        True

        >>> # Test multiple gateways (basic smoke)
        >>> multiple_gateways = [MagicMock(), MagicMock(), MagicMock()]
        >>> for i, gw in enumerate(multiple_gateways):
        ...     gw.name = f"gateway_{i}"
        ...     gw.url = f"http://gateway{i}.example.com"
        ...     gw.transport = "SSE"
        ...     gw.enabled = True
        ...     gw.reachable = True
        ...     gw.auth_value = {}
        ...     gw.ca_certificate = None
        >>> multi_result = asyncio.run(service.check_health_of_gateways(multiple_gateways))
        >>> isinstance(multi_result, bool)
        True
    """
    start_time = time.monotonic()
    # os.cpu_count() returns None when the CPU count cannot be determined; fall back
    # to 1 instead of failing with a TypeError. The outer max(1, ...) keeps the limit
    # positive even if max_concurrent_health_checks is configured as 0 or negative.
    # A non-positive value would make range() below raise (step 0) or skip every chunk.
    cpu_count = os.cpu_count() or 1
    concurrency_limit = max(1, min(settings.max_concurrent_health_checks, max(10, cpu_count * 5)))  # adaptive concurrency
    semaphore = asyncio.Semaphore(concurrency_limit)

    async def limited_check(gateway: DbGateway) -> None:
        """
        Checks the health of a single gateway while respecting a concurrency limit.

        The check runs under the shared semaphore and is bounded by
        ``settings.gateway_health_check_timeout``. A timeout is logged and
        treated as a failed health check.

        Args:
            gateway (DbGateway): The database gateway whose health is to be checked.

        Raises:
            Any exceptions raised during the health check will be propagated to the caller.
        """
        async with semaphore:
            try:
                await asyncio.wait_for(
                    self._check_single_gateway_health(gateway, user_email),
                    timeout=settings.gateway_health_check_timeout,
                )
            except asyncio.TimeoutError:
                logger.warning(f"Gateway {getattr(gateway, 'name', 'unknown')} health check timed out after {settings.gateway_health_check_timeout}s")
                # Treat timeout as a failed health check
                await self._handle_gateway_failure(gateway)

    # Create trace span for health check batch
    with create_span("gateway.health_check_batch", {"gateway.count": len(gateways), "check.type": "health"}) as batch_span:
        if not gateways:
            return True

        # Chunk processing to avoid overload
        chunk_size = concurrency_limit
        for start in range(0, len(gateways), chunk_size):
            batch = gateways[start : start + chunk_size]

            # Gateways with auth_type == "one_time_auth" are excluded from periodic checks.
            checkable = [gw for gw in batch if gw.auth_type != "one_time_auth"]

            # return_exceptions=True makes gather wait for every check and hand back
            # exceptions as results. Log them rather than discarding them silently.
            results = await asyncio.gather(*(limited_check(gw) for gw in checkable), return_exceptions=True)
            for gw, outcome in zip(checkable, results):
                if isinstance(outcome, BaseException):
                    logger.warning(f"Unexpected error during health check of gateway {getattr(gw, 'name', 'unknown')}: {outcome}")

            await asyncio.sleep(0.05)  # small pause prevents network saturation

        elapsed = time.monotonic() - start_time

        if batch_span:
            batch_span.set_attribute("check.duration_ms", int(elapsed * 1000))
            batch_span.set_attribute("check.completed", True)

        logger.debug(f"Health check batch completed for {len(gateways)} gateways in {elapsed:.2f}s")

    return True

3137 

async def _check_single_gateway_health(self, gateway: DbGateway, user_email: Optional[str] = None) -> None:
    """Check health of a single gateway.

    How the gateway is probed:

    - Authentication: OAuth (authorization_code via a stored user token,
      otherwise client credentials), decoded ``auth_value`` headers, or
      decrypted query-param auth applied to the URL.
    - Transport: SSE gateways get a streamed GET that must not return 4xx/5xx.
      StreamableHTTP gateways are probed through the MCP session pool when it
      is enabled and initialized, otherwise through a fresh ``ClientSession``
      that is initialized.

    After a successful probe:

    - The gateway's consecutive-failure counter is cleared.
    - A gateway that is enabled but marked unreachable is marked reachable again.
    - ``last_seen`` is updated.
    - If ``settings.auto_refresh_servers`` is on, tools/resources/prompts are
      refreshed, subject to the per-gateway interval and refresh lock.

    Any failure is recorded on the trace span and passed to
    ``_handle_gateway_failure``.

    NOTE: This method intentionally does NOT take a db parameter.
    DB access uses fresh_db_session() only when needed, avoiding holding
    connections during HTTP calls to MCP servers.

    Args:
        gateway: Gateway to check (may be detached from session)
        user_email: Optional user email for OAuth token lookup

    Returns:
        None
    """
    # Extract gateway data upfront (gateway may be detached from session)
    gateway_id = gateway.id
    gateway_name = gateway.name
    gateway_url = gateway.url
    gateway_transport = gateway.transport
    gateway_enabled = gateway.enabled
    gateway_reachable = gateway.reachable
    gateway_ca_certificate = gateway.ca_certificate
    gateway_ca_certificate_sig = gateway.ca_certificate_sig
    gateway_auth_type = gateway.auth_type
    gateway_oauth_config = gateway.oauth_config
    gateway_auth_value = gateway.auth_value
    gateway_auth_query_params = gateway.auth_query_params

    # Handle query_param auth - decrypt and apply to URL for health check
    auth_query_params_decrypted: Optional[Dict[str, str]] = None
    if gateway_auth_type == "query_param" and gateway_auth_query_params:
        auth_query_params_decrypted = {}
        for param_key, encrypted_value in gateway_auth_query_params.items():
            if encrypted_value:
                try:
                    decrypted = decode_auth(encrypted_value)
                    auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
                except Exception:
                    logger.debug(f"Failed to decrypt query param '{param_key}' for health check")
        if auth_query_params_decrypted:
            gateway_url = apply_query_param_auth(gateway_url, auth_query_params_decrypted)

    # Sanitize URL for logging/telemetry (redacts sensitive query params)
    gateway_url_sanitized = sanitize_url_for_logging(gateway_url, auth_query_params_decrypted)

    # Create span for individual gateway health check
    with create_span(
        "gateway.health_check",
        {
            "gateway.name": gateway_name,
            "gateway.id": str(gateway_id),
            "gateway.url": gateway_url_sanitized,
            "gateway.transport": gateway_transport,
            "gateway.enabled": gateway_enabled,
            "http.method": "GET",
            "http.url": gateway_url_sanitized,
        },
    ) as span:
        # Use the gateway's CA certificate only if it is present and, when
        # Ed25519 signing is enabled, its signature validates.
        valid = False
        if gateway_ca_certificate:
            if settings.enable_ed25519_signing:
                public_key_pem = settings.ed25519_public_key
                valid = validate_signature(gateway_ca_certificate.encode(), gateway_ca_certificate_sig, public_key_pem)
            else:
                valid = True
        if valid:
            ssl_context = self.create_ssl_context(gateway_ca_certificate)
        else:
            ssl_context = None

        def get_httpx_client_factory(
            headers: dict[str, str] | None = None,
            timeout: httpx.Timeout | None = None,
            auth: httpx.Auth | None = None,
        ) -> httpx.AsyncClient:
            """Factory function to create httpx.AsyncClient with optional CA certificate.

            Args:
                headers: Optional headers for the client
                timeout: Optional timeout for the client
                auth: Optional auth for the client

            Returns:
                httpx.AsyncClient: Configured HTTPX async client
            """
            return httpx.AsyncClient(
                verify=ssl_context if ssl_context else get_default_verify(),
                follow_redirects=True,
                headers=headers,
                timeout=timeout if timeout else get_http_timeout(),
                auth=auth,
                limits=httpx.Limits(
                    max_connections=settings.httpx_max_connections,
                    max_keepalive_connections=settings.httpx_max_keepalive_connections,
                    keepalive_expiry=settings.httpx_keepalive_expiry,
                ),
            )

        # Use isolated client for gateway health checks (each gateway may have custom CA cert)
        # Use admin timeout for health checks (fail fast, don't wait 120s for slow upstreams)
        # Pass ssl_context if present, otherwise let get_isolated_http_client use skip_ssl_verify setting
        async with get_isolated_http_client(timeout=settings.httpx_admin_read_timeout, verify=ssl_context) as client:
            logger.debug(f"Checking health of gateway: {gateway_name} ({gateway_url_sanitized})")
            try:
                # Handle different authentication types
                headers = {}

                if gateway_auth_type == "oauth" and gateway_oauth_config:
                    grant_type = gateway_oauth_config.get("grant_type", "client_credentials")

                    if grant_type == "authorization_code":
                        # Without a user there is no stored token to look up. Fail fast,
                        # before opening a DB session.
                        if not user_email:
                            if span:
                                span.set_attribute("health.status", "unhealthy")
                                span.set_attribute("error.message", "User email required for OAuth token")
                            await self._handle_gateway_failure(gateway)
                            return

                        # For Authorization Code flow, look up the user's stored token.
                        # Keep only the lookup inside the try so a failure is handled exactly once.
                        try:
                            # First-Party
                            from mcpgateway.services.token_storage_service import TokenStorageService  # pylint: disable=import-outside-toplevel

                            # Use fresh session for OAuth token lookup; it is released before any failure handling
                            with fresh_db_session() as token_db:
                                token_storage = TokenStorageService(token_db)
                                access_token = await token_storage.get_user_token(gateway_id, user_email)
                        except Exception as e:
                            logger.error(f"Failed to obtain stored OAuth token for gateway {gateway_name}: {e}")
                            if span:
                                span.set_attribute("health.status", "unhealthy")
                                span.set_attribute("error.message", "Failed to obtain stored OAuth token")
                            await self._handle_gateway_failure(gateway)
                            return

                        if not access_token:
                            if span:
                                span.set_attribute("health.status", "unhealthy")
                                span.set_attribute("error.message", "No valid OAuth token for user")
                            await self._handle_gateway_failure(gateway)
                            return

                        headers["Authorization"] = f"Bearer {access_token}"
                    else:
                        # For Client Credentials flow, get token directly
                        try:
                            access_token = await self.oauth_manager.get_access_token(gateway_oauth_config)
                            headers["Authorization"] = f"Bearer {access_token}"
                        except Exception as e:
                            if span:
                                span.set_attribute("health.status", "unhealthy")
                                span.set_attribute("error.message", str(e))
                            await self._handle_gateway_failure(gateway)
                            return
                else:
                    # Handle non-OAuth authentication (existing logic)
                    auth_data = gateway_auth_value or {}
                    if isinstance(auth_data, str):
                        headers = decode_auth(auth_data)
                    elif isinstance(auth_data, dict):
                        headers = {str(k): str(v) for k, v in auth_data.items()}
                    else:
                        headers = {}

                # Perform the probe; HTTP 4xx/5xx and MCP errors raise into the except below
                if gateway_transport.lower() == "sse":
                    timeout = httpx.Timeout(settings.health_check_timeout)
                    async with client.stream("GET", gateway_url, headers=headers, timeout=timeout) as response:
                        # This will raise immediately if status is 4xx/5xx
                        response.raise_for_status()
                        if span:
                            span.set_attribute("http.status_code", response.status_code)
                elif gateway_transport.lower() == "streamablehttp":
                    # Use session pool if enabled for faster health checks
                    use_pool = False
                    pool = None
                    if settings.mcp_session_pool_enabled:
                        try:
                            pool = get_mcp_session_pool()
                            use_pool = True
                        except RuntimeError:
                            # Pool not initialized (e.g., in tests), fall back to per-call sessions
                            pass

                    if use_pool and pool is not None:
                        # Health checks are system operations, not user-driven.
                        # Use system identity to isolate from user sessions.
                        async with pool.session(
                            url=gateway_url,
                            headers=headers,
                            transport_type=TransportType.STREAMABLE_HTTP,
                            httpx_client_factory=get_httpx_client_factory,
                            user_identity="_system_health_check",
                            gateway_id=gateway_id,
                        ) as pooled:
                            # Optional explicit RPC verification (off by default for performance).
                            # Pool's internal staleness check handles health via _validate_session.
                            if settings.mcp_session_pool_explicit_health_rpc:
                                await asyncio.wait_for(
                                    pooled.session.list_tools(),
                                    timeout=settings.health_check_timeout,
                                )
                    else:
                        async with streamablehttp_client(url=gateway_url, headers=headers, timeout=settings.health_check_timeout, httpx_client_factory=get_httpx_client_factory) as (
                            read_stream,
                            write_stream,
                            _get_session_id,
                        ):
                            async with ClientSession(read_stream, write_stream) as session:
                                # A successful MCP handshake is the health signal; the result is not needed.
                                await session.initialize()

                # The probe succeeded: clear the failure streak so that only consecutive
                # failures count toward GW_FAILURE_THRESHOLD in _handle_gateway_failure.
                self._gateway_failure_counts.pop(gateway_id, None)

                # Reactivate gateway if it was previously inactive and health check passed now
                if gateway_enabled and not gateway_reachable:
                    logger.info(f"Reactivating gateway: {gateway_name}, as it is healthy now")
                    with cast(Any, SessionLocal)() as status_db:
                        await self.set_gateway_state(status_db, gateway_id, activate=True, reachable=True, only_update_reachable=True)

                # Update last_seen with fresh session (gateway object is detached)
                try:
                    with fresh_db_session() as update_db:
                        db_gateway = update_db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
                        if db_gateway:
                            db_gateway.last_seen = datetime.now(timezone.utc)
                            update_db.commit()
                except Exception as update_error:
                    logger.warning(f"Failed to update last_seen for gateway {gateway_name}: {update_error}")

                # Auto-refresh tools/resources/prompts if enabled
                if settings.auto_refresh_servers:
                    try:
                        # Throttling: skip the refresh if the last one is newer than the interval
                        refresh_needed = True
                        if gateway.last_refresh_at:
                            last_refresh = gateway.last_refresh_at
                            # Treat naive timestamps as UTC so the subtraction below is valid
                            if last_refresh.tzinfo is None:
                                last_refresh = last_refresh.replace(tzinfo=timezone.utc)

                            # Use per-gateway interval if set, otherwise fall back to global default
                            refresh_interval = getattr(settings, "gateway_auto_refresh_interval", 300)
                            if gateway.refresh_interval_seconds is not None:
                                refresh_interval = gateway.refresh_interval_seconds

                            time_since_refresh = (datetime.now(timezone.utc) - last_refresh).total_seconds()

                            if time_since_refresh < refresh_interval:
                                refresh_needed = False
                                logger.debug(f"Skipping auto-refresh for {gateway_name}: last refreshed {int(time_since_refresh)}s ago")

                        if refresh_needed:
                            # Locking: skip if the lock is already held, to avoid conflict with manual refresh
                            lock = self._get_refresh_lock(gateway_id)
                            if not lock.locked():
                                # Acquire lock to prevent concurrent manual refresh
                                async with lock:
                                    await self._refresh_gateway_tools_resources_prompts(
                                        gateway_id=gateway_id,
                                        _user_email=user_email,
                                        created_via="health_check",
                                        pre_auth_headers=headers if headers else None,
                                        gateway=gateway,
                                    )
                            else:
                                logger.debug(f"Skipping auto-refresh for {gateway_name}: lock held (likely manual refresh in progress)")
                    except Exception as refresh_error:
                        logger.warning(f"Failed to refresh tools for gateway {gateway_name}: {refresh_error}")

                if span:
                    span.set_attribute("health.status", "healthy")
                    span.set_attribute("success", True)

            except Exception as e:
                if span:
                    span.set_attribute("health.status", "unhealthy")
                    span.set_attribute("error.message", str(e))

                # Set the logger as debug as this check happens for each interval
                logger.debug(f"Health check failed for gateway {gateway_name}: {e}")
                await self._handle_gateway_failure(gateway)

3418 

async def aggregate_capabilities(self, db: Session) -> Dict[str, Any]:
    """
    Aggregate capabilities across all enabled gateways.

    Starts from default capabilities for prompts, resources, tools and
    logging. Each enabled gateway's ``capabilities`` mapping is merged in:

    - A key not seen yet is added. Dict values are copied, so later merges
      never mutate the gateway's own dict.
    - A dict value for a key that already holds a dict is shallow-merged
      with ``update``.
    - Any other combination keeps the value already present.

    Args:
        db: Database session

    Returns:
        Dictionary of aggregated capabilities

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> from unittest.mock import MagicMock
        >>> service = GatewayService()
        >>> db = MagicMock()
        >>> gateway_mock = MagicMock()
        >>> gateway_mock.capabilities = {"tools": {"listChanged": True}, "custom": {"feature": True}}
        >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway_mock]
        >>> import asyncio
        >>> result = asyncio.run(service.aggregate_capabilities(db))
        >>> isinstance(result, dict)
        True
        >>> 'prompts' in result
        True
        >>> 'resources' in result
        True
        >>> 'tools' in result
        True
        >>> 'logging' in result
        True
        >>> result['prompts']['listChanged']
        True
        >>> result['resources']['subscribe']
        True
        >>> result['resources']['listChanged']
        True
        >>> result['tools']['listChanged']
        True
        >>> isinstance(result['logging'], dict)
        True

        >>> # Test with no gateways
        >>> db.execute.return_value.scalars.return_value.all.return_value = []
        >>> empty_result = asyncio.run(service.aggregate_capabilities(db))
        >>> isinstance(empty_result, dict)
        True
        >>> 'tools' in empty_result
        True

        >>> # Test capability merging
        >>> gateway1 = MagicMock()
        >>> gateway1.capabilities = {"tools": {"feature1": True}}
        >>> gateway2 = MagicMock()
        >>> gateway2.capabilities = {"tools": {"feature2": True}}
        >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway1, gateway2]
        >>> merged_result = asyncio.run(service.aggregate_capabilities(db))
        >>> merged_result['tools']['listChanged']  # Default capability
        True

        >>> # Gateway-owned dicts are not mutated by the merge
        >>> g1 = MagicMock()
        >>> g1.capabilities = {"experimental": {"a": 1}}
        >>> g2 = MagicMock()
        >>> g2.capabilities = {"experimental": {"b": 2}}
        >>> db.execute.return_value.scalars.return_value.all.return_value = [g1, g2]
        >>> merged = asyncio.run(service.aggregate_capabilities(db))
        >>> merged["experimental"] == {"a": 1, "b": 2}, g1.capabilities["experimental"]
        (True, {'a': 1})
    """
    capabilities: Dict[str, Any] = {
        "prompts": {"listChanged": True},
        "resources": {"subscribe": True, "listChanged": True},
        "tools": {"listChanged": True},
        "logging": {},
    }

    # Get all active gateways
    gateways = db.execute(select(DbGateway).where(DbGateway.enabled)).scalars().all()

    # Combine capabilities
    for gateway in gateways:
        if not gateway.capabilities:
            continue
        for key, value in gateway.capabilities.items():
            if key not in capabilities:
                # Copy dicts: storing the gateway's own dict would let a later
                # update() below mutate that gateway's capabilities in place.
                capabilities[key] = dict(value) if isinstance(value, dict) else value
                continue
            existing = capabilities[key]
            # Only merge dict into dict; a non-dict on either side keeps the current value
            # (previously a dict arriving after a non-dict raised AttributeError).
            if isinstance(value, dict) and isinstance(existing, dict):
                existing.update(value)

    return capabilities

3498 

async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
    """Subscribe to gateway events.

    Opens a subscription on the underlying event service and re-yields each
    event as it is published. This generator may be closed explicitly via
    ``aclose()``, or iteration may be abandoned and the generator finalized.
    In either case, if the upstream subscription is a native async generator,
    it is closed as well. Its cleanup therefore runs promptly instead of
    waiting for the upstream generator to be finalized separately.

    Yields:
        Dict[str, Any]: Gateway event messages with 'type', 'data', and 'timestamp' fields

    Examples:
        >>> service = GatewayService()
        >>> import asyncio
        >>> from unittest.mock import MagicMock
        >>> # Create a mock async generator for the event service
        >>> async def mock_event_gen():
        ...     yield {"type": "test_event", "data": "payload"}
        >>>
        >>> # Mock the event service to return our generator
        >>> service._event_service = MagicMock()
        >>> service._event_service.subscribe_events.return_value = mock_event_gen()
        >>>
        >>> # Test the subscription
        >>> async def test_sub():
        ...     async for event in service.subscribe_events():
        ...         return event
        >>>
        >>> result = asyncio.run(test_sub())
        >>> result
        {'type': 'test_event', 'data': 'payload'}
    """
    import inspect  # pylint: disable=import-outside-toplevel

    source = self._event_service.subscribe_events()
    try:
        async for event in source:
            yield event
    finally:
        # Propagate closure upstream. Without this, closing this generator would
        # leave the inner subscription open until it is garbage-collected.
        if inspect.isasyncgen(source):
            await source.aclose()

3532 

async def _initialize_gateway(
    self,
    url: str,
    authentication: Optional[Dict[str, str]] = None,
    transport: str = "SSE",
    auth_type: Optional[str] = None,
    oauth_config: Optional[Dict[str, Any]] = None,
    ca_certificate: Optional[bytes] = None,
    pre_auth_headers: Optional[Dict[str, str]] = None,
    include_resources: bool = True,
    include_prompts: bool = True,
    auth_query_params: Optional[Dict[str, str]] = None,
    oauth_auto_fetch_tool_flag: Optional[bool] = False,
) -> tuple[Dict[str, Any], List[ToolCreate], List[ResourceCreate], List[PromptCreate]]:
    """Initialize connection to a gateway and retrieve its capabilities.

    Connects to an MCP gateway using the specified transport protocol,
    performs the MCP handshake, and retrieves capabilities, tools,
    resources, and prompts from the gateway. A transport other than "SSE"
    or "StreamableHTTP" (case-insensitive) results in empty results.

    Args:
        url: Gateway URL to connect to
        authentication: Optional authentication headers for the connection
        transport: Transport protocol - "SSE" or "StreamableHTTP"
        auth_type: Authentication type - "basic", "bearer", "authheaders", "oauth", "query_param" or None
        oauth_config: OAuth configuration if auth_type is "oauth"
        ca_certificate: CA certificate for SSL verification
        pre_auth_headers: Pre-authenticated headers to skip OAuth token fetch (for reuse)
        include_resources: Whether to include resources in the fetch
        include_prompts: Whether to include prompts in the fetch
        auth_query_params: Query param names for URL sanitization in error logs (decrypted values)
        oauth_auto_fetch_tool_flag: Whether to skip the early return for OAuth Authorization Code flow.
            When False (default), auth_code gateways return empty lists immediately (for health checks).
            When True, attempts to connect even for auth_code gateways (for activation after user authorization).

    Returns:
        tuple[Dict[str, Any], List[ToolCreate], List[ResourceCreate], List[PromptCreate]]:
            Capabilities dictionary, list of ToolCreate objects, list of ResourceCreate objects, and list of PromptCreate objects

    Raises:
        GatewayConnectionError: If connection or initialization fails. The
            original exception is attached as ``__cause__``.

    Examples:
        >>> service = GatewayService()
        >>> # Test parameter validation
        >>> import asyncio
        >>> from unittest.mock import AsyncMock
        >>> # Avoid opening a real SSE connection in doctests (it can leak anyio streams on failure paths)
        >>> service.connect_to_sse_server = AsyncMock(side_effect=GatewayConnectionError("boom"))
        >>> async def test_params():
        ...     try:
        ...         await service._initialize_gateway("hello//")
        ...     except Exception as e:
        ...         return isinstance(e, GatewayConnectionError) or "Failed" in str(e)

        >>> asyncio.run(test_params())
        True

        >>> # Test default parameters
        >>> hasattr(service, '_initialize_gateway')
        True
        >>> import inspect
        >>> sig = inspect.signature(service._initialize_gateway)
        >>> sig.parameters['transport'].default
        'SSE'
        >>> sig.parameters['authentication'].default is None
        True
        >>>
        >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
        >>> asyncio.run(service._http_client.aclose())
    """
    try:
        if authentication is None:
            authentication = {}

        # Use pre-authenticated headers if provided (avoids duplicate OAuth token fetch)
        if pre_auth_headers:
            authentication = pre_auth_headers
        # Handle OAuth authentication
        elif auth_type == "oauth" and oauth_config:
            grant_type = oauth_config.get("grant_type", "client_credentials")

            if grant_type == "authorization_code":
                if not oauth_auto_fetch_tool_flag:
                    # For Authorization Code flow during health checks, we can't initialize immediately
                    # because we need user consent. Skip the MCP server connection; tools are fetched
                    # after the user completes the OAuth flow.
                    logger.info("OAuth Authorization Code flow configured for gateway. User must complete authorization before gateway can be used.")
                    return {}, [], [], []
                # When flag is True (activation), skip token fetch but try to connect.
                # This allows activation to proceed - actual auth happens during tool invocation.
                logger.debug("OAuth Authorization Code gateway activation - skipping token fetch")
            elif grant_type == "client_credentials":
                # For Client Credentials flow, we can get the token immediately
                try:
                    logger.debug("Obtaining OAuth access token for Client Credentials flow")
                    access_token = await self.oauth_manager.get_access_token(oauth_config)
                    authentication = {"Authorization": f"Bearer {access_token}"}
                except Exception as e:
                    logger.error(f"Failed to obtain OAuth access token: {e}")
                    raise GatewayConnectionError(f"OAuth authentication failed: {str(e)}") from e

        capabilities = {}
        tools = []
        resources = []
        prompts = []
        # Stored credentials for these auth types may arrive encoded as a string
        if auth_type in ("basic", "bearer", "authheaders") and isinstance(authentication, str):
            authentication = decode_auth(authentication)
        if transport.lower() == "sse":
            capabilities, tools, resources, prompts = await self.connect_to_sse_server(url, authentication, ca_certificate, include_prompts, include_resources, auth_query_params)
        elif transport.lower() == "streamablehttp":
            capabilities, tools, resources, prompts = await self.connect_to_streamablehttp_server(url, authentication, ca_certificate, include_prompts, include_resources, auth_query_params)

        return capabilities, tools, resources, prompts
    except Exception as e:
        # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup; report the
        # innermost first exception, since it is the most informative.
        root_cause: BaseException = e
        while isinstance(root_cause, BaseExceptionGroup) and root_cause.exceptions:
            root_cause = root_cause.exceptions[0]
        sanitized_url = sanitize_url_for_logging(url, auth_query_params)
        sanitized_error = sanitize_exception_message(str(root_cause), auth_query_params)
        logger.error(f"Gateway initialization failed for {sanitized_url}: {sanitized_error}", exc_info=True)
        raise GatewayConnectionError(f"Failed to initialize gateway at {sanitized_url}: {sanitized_error}") from e

3663 

def _get_gateways(self, include_inactive: bool = True) -> list[DbGateway]:
    """Load gateways from the database in a short-lived session (synchronous; intended to run in a thread).

    Args:
        include_inactive: When True (the default) every gateway is returned;
            when False only gateways whose ``enabled`` column is true.

    Returns:
        List[DbGateway]: All gateways, or only the enabled ones when
        ``include_inactive`` is False.

    Examples:
        >>> from unittest.mock import patch, MagicMock
        >>> service = GatewayService()
        >>> with patch('mcpgateway.services.gateway_service.SessionLocal') as mock_session:
        ...     mock_db = MagicMock()
        ...     mock_session.return_value.__enter__.return_value = mock_db
        ...     mock_db.execute.return_value.scalars.return_value.all.return_value = []
        ...     result = service._get_gateways()
        ...     isinstance(result, list)
        True

        >>> # Test include_inactive parameter handling
        >>> with patch('mcpgateway.services.gateway_service.SessionLocal') as mock_session:
        ...     mock_db = MagicMock()
        ...     mock_session.return_value.__enter__.return_value = mock_db
        ...     mock_db.execute.return_value.scalars.return_value.all.return_value = []
        ...     result_active_only = service._get_gateways(include_inactive=False)
        ...     isinstance(result_active_only, list)
        True
    """
    with cast(Any, SessionLocal)() as db:
        stmt = select(DbGateway)
        if not include_inactive:
            # Restrict to enabled gateways only
            stmt = stmt.where(DbGateway.enabled)
        return db.execute(stmt).scalars().all()

3698 

def get_first_gateway_by_url(self, db: Session, url: str, team_id: Optional[str] = None, include_inactive: bool = False) -> Optional[GatewayRead]:
    """Return the first gateway whose stored URL equals ``url``, as a ``GatewayRead``.

    This is a synchronous helper intended for request handlers that need a
    simple DB lookup. The ``url`` argument is compared verbatim against the
    ``url`` column. No normalization is applied here, so callers must pass
    the URL in the same form in which it was stored.

    When ``team_id`` is truthy, the search is restricted to that team. By
    default only enabled gateways are considered.

    Args:
        db: Database session to use for the query
        url: Gateway URL to match exactly against the stored ``url`` column
        team_id: Optional team id to restrict search (ignored when falsy)
        include_inactive: Whether to include gateways that are not enabled

    Returns:
        Optional[GatewayRead]: The first match, built with
        ``GatewayRead.model_validate(self._prepare_gateway_for_read(...))`` with
        ``.masked()`` applied, or ``None`` when no gateway matches.
    """
    stmt = select(DbGateway).where(DbGateway.url == url)
    if not include_inactive:
        stmt = stmt.where(DbGateway.enabled)
    if team_id:
        stmt = stmt.where(DbGateway.team_id == team_id)

    db_gateway = db.execute(stmt).scalars().first()
    if db_gateway is None:
        return None
    # Wrap the DB object in the GatewayRead schema for consistency with other service methods.
    return GatewayRead.model_validate(self._prepare_gateway_for_read(db_gateway)).masked()

3727 

async def _run_leader_heartbeat(self) -> None:
    """Keep this instance's Redis leader key alive for as long as it is the leader.

    Runs separately from the health-check loop so the leader key is refreshed
    every ``_leader_heartbeat_interval`` seconds even while a long health-check
    pass is in progress. Each tick re-reads the leader key; if it no longer holds
    this instance's id, the heartbeat stops. If no Redis client is configured,
    it also stops. Errors are logged and the loop keeps going, leaving
    leadership-loss handling to the health-check loop.
    """
    while True:
        try:
            await asyncio.sleep(self._leader_heartbeat_interval)

            # Nothing to keep alive without a Redis client.
            if not self._redis_client:
                return

            leader_now = await self._redis_client.get(self._leader_key)
            still_leader = leader_now == self._instance_id
            if not still_leader:
                logger.info("Lost Redis leadership, stopping heartbeat")
                return

            # Still the leader: push the key's expiry out by another TTL.
            await self._redis_client.expire(self._leader_key, self._leader_ttl)
            logger.debug(f"Leader heartbeat: refreshed TTL to {self._leader_ttl}s")
        except Exception as err:
            # Best effort: log and try again on the next tick.
            logger.warning(f"Leader heartbeat error: {err}")

3757 

async def _run_health_checks(self, user_email: str) -> None:
    """Run gateway health checks periodically, coordinating between workers.

    The coordination strategy depends on configuration:

    * A Redis client is present and ``settings.cache_type == "redis"``: only the
      instance whose id is stored under ``self._leader_key`` runs checks, and the
      method returns as soon as another instance holds that key. The key's TTL is
      refreshed separately by ``_run_leader_heartbeat``.
    * ``settings.cache_type == "none"``: single-worker mode; checks run directly
      every interval and failures are logged.
    * Anything else (including ``"redis"`` without a client): a FileLock elects a
      leader. The holder runs checks in a loop; workers that cannot get the lock
      retry after one interval.

    After a failure the loop waits one ``_health_check_interval`` before retrying,
    so a persistent error cannot turn the loop into a busy spin.

    NOTE: This method intentionally does NOT take a db parameter.
    Health checks use fresh_db_session() only when DB access is needed,
    avoiding holding connections during HTTP calls to MCP servers.

    Args:
        user_email: Email of the user for OAuth token lookup

    Examples:
        >>> service = GatewayService()
        >>> service._health_check_interval = 0.1  # Short interval for testing
        >>> service._redis_client = None
        >>> import asyncio
        >>> # Test that method exists and is callable
        >>> callable(service._run_health_checks)
        True
        >>> # Test setup without actual execution (would run forever)
        >>> hasattr(service, '_health_check_interval')
        True
        >>> service._health_check_interval == 0.1
        True
    """

    while True:
        try:
            if self._redis_client and settings.cache_type == "redis":
                # Redis-based leader check (async, decode_responses=True returns strings)
                # Note: Leader key TTL refresh is handled by _run_leader_heartbeat task
                current_leader = await self._redis_client.get(self._leader_key)
                if current_leader != self._instance_id:
                    return

                # Run health checks
                gateways = await asyncio.to_thread(self._get_gateways)
                if gateways:
                    await self.check_health_of_gateways(gateways, user_email)

                await asyncio.sleep(self._health_check_interval)

            elif settings.cache_type == "none":
                try:
                    # For single worker mode, run health checks directly
                    gateways = await asyncio.to_thread(self._get_gateways)
                    if gateways:
                        await self.check_health_of_gateways(gateways, user_email)
                except Exception as e:
                    logger.error(f"Health check run failed: {str(e)}")

                await asyncio.sleep(self._health_check_interval)

            else:
                # FileLock-based leader fallback
                try:
                    self._file_lock.acquire(timeout=0)
                    logger.info("File lock acquired. Running health checks.")

                    while True:
                        gateways = await asyncio.to_thread(self._get_gateways)
                        if gateways:
                            await self.check_health_of_gateways(gateways, user_email)
                        await asyncio.sleep(self._health_check_interval)

                except Timeout:
                    logger.debug("File lock already held. Retrying later.")
                    await asyncio.sleep(self._health_check_interval)

                except Exception as e:
                    logger.error(f"FileLock health check failed: {str(e)}")
                    # Back off before the outer loop re-acquires the lock. Without this,
                    # a persistent failure re-acquired, failed and released the lock in a
                    # tight loop with no await that actually waits.
                    await asyncio.sleep(self._health_check_interval)

                finally:
                    if self._file_lock.is_locked:
                        try:
                            self._file_lock.release()
                            logger.info("Released file lock.")
                        except Exception as e:
                            logger.warning(f"Failed to release file lock: {str(e)}")

        except Exception as e:
            logger.error(f"Unexpected error in health check loop: {str(e)}")
            await asyncio.sleep(self._health_check_interval)

3842 

def _get_auth_headers(self) -> Dict[str, str]:
    """Build the default, credential-free headers for outbound gateway requests.

    SECURITY: The returned mapping deliberately carries no authentication.
    Each gateway is expected to have its own auth_value configured; this
    gateway's admin credentials must never be forwarded to remote servers.

    Returns:
        dict: A new dict containing only the JSON ``Content-Type`` header

    Examples:
        >>> service = GatewayService()
        >>> headers = service._get_auth_headers()
        >>> isinstance(headers, dict)
        True
        >>> 'Content-Type' in headers
        True
        >>> headers['Content-Type']
        'application/json'
        >>> 'Authorization' not in headers  # No credentials leaked
        True
    """
    # A fresh dict per call, so callers may add their own entries safely.
    default_headers: Dict[str, str] = dict()
    default_headers["Content-Type"] = "application/json"
    return default_headers

3866 

async def _notify_gateway_added(self, gateway: DbGateway) -> None:
    """Publish a ``gateway_added`` event for a newly registered gateway.

    The event data carries the gateway's id, name, URL, description and
    enabled flag, and is stamped with the current UTC time in ISO-8601 form.

    Args:
        gateway: Gateway that was added
    """
    details = {
        "id": gateway.id,
        "name": gateway.name,
        "url": gateway.url,
        "description": gateway.description,
        "enabled": gateway.enabled,
    }
    stamp = datetime.now(timezone.utc).isoformat()
    await self._publish_event(dict(type="gateway_added", data=details, timestamp=stamp))

3885 

async def _notify_gateway_activated(self, gateway: DbGateway) -> None:
    """Publish a ``gateway_activated`` event for ``gateway``.

    The event data carries the gateway's id, name, URL and its current
    ``enabled`` / ``reachable`` flags, stamped with the current UTC time.

    Args:
        gateway: Gateway that was activated
    """
    details = {
        "id": gateway.id,
        "name": gateway.name,
        "url": gateway.url,
        "enabled": gateway.enabled,
        "reachable": gateway.reachable,
    }
    stamp = datetime.now(timezone.utc).isoformat()
    await self._publish_event(dict(type="gateway_activated", data=details, timestamp=stamp))

3904 

async def _notify_gateway_deactivated(self, gateway: DbGateway) -> None:
    """Publish a ``gateway_deactivated`` event for ``gateway``.

    The event data mirrors the activation event: id, name, URL and the
    gateway's current ``enabled`` / ``reachable`` flags, plus a UTC timestamp.

    Args:
        gateway: Gateway database object
    """
    details = {
        "id": gateway.id,
        "name": gateway.name,
        "url": gateway.url,
        "enabled": gateway.enabled,
        "reachable": gateway.reachable,
    }
    stamp = datetime.now(timezone.utc).isoformat()
    await self._publish_event(dict(type="gateway_deactivated", data=details, timestamp=stamp))

3923 

async def _notify_gateway_offline(self, gateway: DbGateway) -> None:
    """Publish a ``gateway_offline`` event: the gateway is enabled but unreachable.

    Unlike the activation/deactivation events, the flags are not read from
    ``gateway``: the payload always reports ``enabled=True`` and
    ``reachable=False`` alongside the gateway's id, name and URL.

    Args:
        gateway: Gateway database object
    """
    details = {
        "id": gateway.id,
        "name": gateway.name,
        "url": gateway.url,
        # Fixed by definition of "offline": still enabled, but not reachable.
        "enabled": True,
        "reachable": False,
    }
    stamp = datetime.now(timezone.utc).isoformat()
    await self._publish_event(dict(type="gateway_offline", data=details, timestamp=stamp))

3943 

async def _notify_gateway_deleted(self, gateway_info: Dict[str, Any]) -> None:
    """Publish a ``gateway_deleted`` event.

    The gateway row is already gone by the time this is called, so the caller
    supplies a snapshot dict which is sent as the event data unchanged (the
    same object, not a copy), together with a UTC timestamp.

    Args:
        gateway_info: Dict containing information about the deleted gateway
    """
    stamp = datetime.now(timezone.utc).isoformat()
    await self._publish_event(dict(type="gateway_deleted", data=gateway_info, timestamp=stamp))

3956 

async def _notify_gateway_removed(self, gateway: DbGateway) -> None:
    """Publish a ``gateway_removed`` event (removal in the sense of deactivation).

    The event data is limited to the gateway's id, name and enabled flag,
    stamped with the current UTC time.

    Args:
        gateway: Gateway being removed
    """
    details = dict(id=gateway.id, name=gateway.name, enabled=gateway.enabled)
    stamp = datetime.now(timezone.utc).isoformat()
    await self._publish_event({"type": "gateway_removed", "data": details, "timestamp": stamp})

3969 

def convert_gateway_to_read(self, gateway: DbGateway) -> GatewayRead:
    """Build a masked ``GatewayRead`` model from a ``DbGateway`` without mutating it.

    Works on a snapshot of the instance's ``__dict__`` (minus SQLAlchemy's
    ``_sa_instance_state``), so the ORM object itself is left untouched. In the
    snapshot, dict-valued ``auth_value`` is encoded and legacy ``List[str]`` tags
    are converted via ``validate_tags_field``; empty or missing tags become ``[]``.
    Audit/ownership fields are read explicitly with ``getattr`` and added.

    Args:
        gateway: Gateway database object

    Returns:
        GatewayRead: Validated Pydantic model with sensitive values masked
    """
    snapshot = {key: value for key, value in gateway.__dict__.items() if key != "_sa_instance_state"}

    # The schema expects an encoded string, not the raw auth dict.
    if isinstance(gateway.auth_value, dict):
        snapshot["auth_value"] = encode_auth(gateway.auth_value)

    tags = gateway.tags
    if not tags:
        snapshot["tags"] = []
    elif isinstance(tags[0], str):
        # Legacy List[str] -> List[Dict[str, str]] expected by GatewayRead.
        snapshot["tags"] = validate_tags_field(tags)
    else:
        snapshot["tags"] = tags

    for field_name in ("created_by", "modified_by", "created_at", "updated_at", "version", "team"):
        snapshot[field_name] = getattr(gateway, field_name, None)

    return GatewayRead.model_validate(snapshot).masked()

4005 

def _prepare_gateway_for_read(self, gateway: DbGateway) -> DbGateway:
    """DEPRECATED: Use convert_gateway_to_read instead.

    Adjust ``gateway`` in place so it validates against the GatewayRead schema.
    A dict ``auth_value`` is replaced by its encoded string form. Legacy
    ``List[str]`` tags are replaced by the ``List[Dict[str, str]]`` format
    produced by ``validate_tags_field``. The ORM object itself is modified
    (not a copy) and then returned.

    Args:
        gateway: Gateway database object

    Returns:
        The same gateway object, with auth_value and tags formatted for GatewayRead
    """
    raw_auth = gateway.auth_value
    if isinstance(raw_auth, dict):
        gateway.auth_value = encode_auth(raw_auth)

    current_tags = gateway.tags
    needs_tag_upgrade = bool(current_tags) and isinstance(current_tags[0], str)
    if needs_tag_upgrade:
        gateway.tags = validate_tags_field(current_tags)

    return gateway

4031 

def _create_db_tool(
    self,
    tool: ToolCreate,
    gateway: DbGateway,
    created_by: Optional[str] = None,
    created_from_ip: Optional[str] = None,
    created_via: Optional[str] = None,
    created_user_agent: Optional[str] = None,
) -> DbTool:
    """Create a DbTool with consistent federation metadata across all scenarios.

    The tool takes its URL, auth settings, team, owner and visibility from
    ``gateway``. A dict-valued gateway ``auth_value`` is encoded before being
    stored on the tool.

    Args:
        tool: Tool creation schema reported by the MCP server
        gateway: Gateway database object that owns the tool
        created_by: Username who created/updated this tool (defaults to "system")
        created_from_ip: IP address of creator
        created_via: Creation method (ui, api, federation, rediscovery); defaults to "federation"
        created_user_agent: User agent of creation request

    Returns:
        DbTool: Consistently configured database tool object (not yet added to a session)
    """
    return DbTool(
        original_name=tool.name,
        custom_name=tool.name,
        custom_name_slug=slugify(tool.name),
        display_name=generate_display_name(tool.name),
        url=gateway.url,
        original_description=tool.description,
        description=tool.description,
        integration_type="MCP",  # Gateway-discovered tools are MCP type
        request_type=tool.request_type,
        headers=tool.headers,
        input_schema=tool.input_schema,
        # Store the upstream output schema at creation time. The refresh path
        # (_update_or_create_tools) compares and writes output_schema, so omitting it
        # here dropped the schema until the tool's first refresh.
        output_schema=getattr(tool, "output_schema", None),
        annotations=tool.annotations,
        jsonpath_filter=tool.jsonpath_filter,
        auth_type=gateway.auth_type,
        auth_value=encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value,
        # Federation metadata - consistent across all scenarios
        created_by=created_by or "system",
        created_from_ip=created_from_ip,
        created_via=created_via or "federation",
        created_user_agent=created_user_agent,
        federation_source=gateway.name,
        version=1,
        # Inherit team assignment and visibility from gateway. Visibility was previously
        # hard-coded to "public", leaving tools of non-public gateways marked public until
        # the refresh path reset them to gateway.visibility (resources and prompts already
        # inherit it). Fall back to "public" only when the gateway has no visibility set.
        team_id=gateway.team_id,
        owner_email=gateway.owner_email,
        visibility=gateway.visibility or "public",
    )

4082 

def _update_or_create_tools(self, db: Session, tools: List[Any], gateway: DbGateway, created_via: str) -> List[DbTool]:
    """Reconcile tools reported by an MCP server with those stored for ``gateway``.

    Incoming tools are matched against this gateway's existing ``DbTool`` rows by
    ``original_name`` using a single batched query. A match is updated in place
    when any of these differ from what the server/gateway now report:
    - URL, upstream (original) description, integration type or request type;
    - headers, input/output schema or JSONPath filter;
    - auth type, auth value or visibility.
    A user-customised ``description`` (one differing from ``original_description``)
    is preserved. Unmatched tools are built via ``_create_db_tool`` and returned.

    Args:
        db: Database session used for the batched lookup
        tools: Tools from the MCP server; ``None`` entries are skipped with a warning
        gateway: Gateway owning the tools; supplies URL, auth and visibility
        created_via: String indicating creation source ("oauth", "update", etc.)

    Returns:
        List of new tools to be added to the database. A tool that fails to
        process is logged and skipped.
    """
    if not tools:
        return []

    wanted_names = [item.name for item in tools if item is not None]
    if not wanted_names:
        return []

    stored_query = select(DbTool).where(DbTool.gateway_id == gateway.id, DbTool.original_name.in_(wanted_names))
    stored_by_name = {row.original_name: row for row in db.execute(stored_query).scalars().all()}

    created: List[DbTool] = []
    for tool in tools:
        if tool is None:
            logger.warning("Skipping None tool in tools list")
            continue

        try:
            stored = stored_by_name.get(tool.name)

            if not stored:
                new_tool = self._create_db_tool(
                    tool=tool,
                    gateway=gateway,
                    created_by="system",
                    created_via=created_via,
                )
                # Attach the relationship up front to avoid a NoneType error during flush.
                new_tool.gateway = gateway
                created.append(new_tool)
                logger.debug(f"Created new tool: {tool.name}")
                continue

            # original_description holds the upstream value; description may be user-edited.
            basics_differ = (
                stored.url != gateway.url
                or stored.original_description != tool.description
                or stored.integration_type != "MCP"
                or stored.request_type != tool.request_type
            )
            schemas_differ = (
                stored.headers != tool.headers
                or stored.input_schema != tool.input_schema
                or stored.output_schema != tool.output_schema
                or stored.jsonpath_filter != tool.jsonpath_filter
            )

            # DbTool.auth_value is an encoded string while DbGateway.auth_value may be a
            # dict. encode_auth() uses a random nonce, so ciphertexts never compare equal;
            # compare decoded plaintext instead, falling back to a direct comparison when
            # decoding fails (legacy/corrupt data).
            try:
                if isinstance(gateway.auth_value, dict):
                    gateway_plain = gateway.auth_value
                elif gateway.auth_value:
                    gateway_plain = decode_auth(gateway.auth_value)
                else:
                    gateway_plain = {}
                stored_plain = decode_auth(stored.auth_value) if stored.auth_value else {}
                auth_value_differs = stored_plain != gateway_plain
            except Exception:
                expected_auth = encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value
                auth_value_differs = stored.auth_value != expected_auth
            access_differs = stored.auth_type != gateway.auth_type or auth_value_differs or stored.visibility != gateway.visibility

            if not (basics_differ or schemas_differ or access_differs):
                continue

            stored.url = gateway.url
            # Only overwrite the user-facing description when it was never customised.
            if stored.description == stored.original_description:
                stored.description = tool.description
            stored.original_description = tool.description
            stored.integration_type = "MCP"
            stored.request_type = tool.request_type
            stored.headers = tool.headers
            stored.input_schema = tool.input_schema
            stored.output_schema = tool.output_schema
            stored.jsonpath_filter = tool.jsonpath_filter
            stored.auth_type = gateway.auth_type
            stored.auth_value = encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value
            stored.visibility = gateway.visibility
            logger.debug(f"Updated existing tool: {tool.name}")
        except Exception as e:
            logger.warning(f"Failed to process tool {getattr(tool, 'name', 'unknown')}: {e}")

    return created

4190 

def _update_or_create_resources(self, db: Session, resources: List[Any], gateway: DbGateway, created_via: str) -> List[DbResource]:
    """Reconcile resources reported by an MCP server with those stored for ``gateway``.

    Incoming resources are matched against this gateway's existing ``DbResource``
    rows by ``uri`` using a single batched query. A match is updated in place when
    its name, description, MIME type, URI template or visibility differ; unmatched
    resources are built as new ``DbResource`` objects and returned.

    Args:
        db: Database session used for the batched lookup
        resources: Resources from the MCP server; ``None`` entries are skipped with a warning
        gateway: Gateway owning the resources; supplies id and visibility
        created_via: String indicating creation source ("oauth", "update", etc.)

    Returns:
        List of new resources to be added to the database. A resource that fails
        to process is logged and skipped.
    """
    if not resources:
        return []

    wanted_uris = [item.uri for item in resources if item is not None]
    if not wanted_uris:
        return []

    stored_query = select(DbResource).where(DbResource.gateway_id == gateway.id, DbResource.uri.in_(wanted_uris))
    stored_by_uri = {row.uri: row for row in db.execute(stored_query).scalars().all()}

    created: List[DbResource] = []
    for resource in resources:
        if resource is None:
            logger.warning("Skipping None resource in resources list")
            continue

        try:
            stored = stored_by_uri.get(resource.uri)

            if not stored:
                # NOTE(review): unlike the tool and prompt paths, the new resource does not
                # get ``.gateway = gateway`` attached, nor inherit team_id/owner_email as
                # tools do via _create_db_tool — confirm whether that is intentional.
                created.append(
                    DbResource(
                        uri=resource.uri,
                        name=resource.name,
                        description=resource.description,
                        mime_type=resource.mime_type,
                        uri_template=resource.uri_template,
                        gateway_id=gateway.id,
                        created_by="system",
                        created_via=created_via,
                        visibility=gateway.visibility,
                    )
                )
                logger.debug(f"Created new resource: {resource.uri}")
                continue

            unchanged = (
                stored.name == resource.name
                and stored.description == resource.description
                and stored.mime_type == resource.mime_type
                and stored.uri_template == resource.uri_template
                and stored.visibility == gateway.visibility
            )
            if unchanged:
                continue

            stored.name = resource.name
            stored.description = resource.description
            stored.mime_type = resource.mime_type
            stored.uri_template = resource.uri_template
            stored.visibility = gateway.visibility
            logger.debug(f"Updated existing resource: {resource.uri}")
        except Exception as e:
            logger.warning(f"Failed to process resource {getattr(resource, 'uri', 'unknown')}: {e}")

    return created

4266 

def _update_or_create_prompts(self, db: Session, prompts: List[Any], gateway: DbGateway, created_via: str) -> List[DbPrompt]:
    """Reconcile prompts reported by an MCP server with those stored for ``gateway``.

    Incoming prompts are matched against this gateway's existing ``DbPrompt``
    rows by ``original_name`` using a single batched query. A match is updated
    in place when its description, template or visibility differ; unmatched
    prompts are built as new ``DbPrompt`` objects, linked to ``gateway``, and
    returned. A prompt without a ``template`` attribute is treated as having
    an empty template.

    Args:
        db: Database session used for the batched lookup
        prompts: Prompts from the MCP server; ``None`` entries are skipped with a warning
        gateway: Gateway owning the prompts; supplies id and visibility
        created_via: String indicating creation source ("oauth", "update", etc.)

    Returns:
        List of new prompts to be added to the database. A prompt that fails to
        process is logged and skipped.
    """
    if not prompts:
        return []

    wanted_names = [item.name for item in prompts if item is not None]
    if not wanted_names:
        return []

    stored_query = select(DbPrompt).where(DbPrompt.gateway_id == gateway.id, DbPrompt.original_name.in_(wanted_names))
    stored_by_name = {row.original_name: row for row in db.execute(stored_query).scalars().all()}

    created: List[DbPrompt] = []
    for prompt in prompts:
        if prompt is None:
            logger.warning("Skipping None prompt in prompts list")
            continue

        try:
            # Look up an existing prompt for this gateway by its upstream name.
            stored = stored_by_name.get(prompt.name)
            template = getattr(prompt, "template", "")

            if stored:
                differs = stored.description != prompt.description or stored.template != template or stored.visibility != gateway.visibility
                if differs:
                    stored.description = prompt.description
                    stored.template = template
                    stored.visibility = gateway.visibility
                    logger.debug(f"Updated existing prompt: {prompt.name}")
                continue

            new_prompt = DbPrompt(
                name=prompt.name,
                original_name=prompt.name,
                custom_name=prompt.name,
                display_name=prompt.name,
                description=prompt.description,
                template=template,
                argument_schema={},  # argument_schema is used instead of arguments
                gateway_id=gateway.id,
                created_by="system",
                created_via=created_via,
                visibility=gateway.visibility,
            )
            new_prompt.gateway = gateway
            created.append(new_prompt)
            logger.debug(f"Created new prompt: {prompt.name}")
        except Exception as e:
            logger.warning(f"Failed to process prompt {getattr(prompt, 'name', 'unknown')}: {e}")

    return created

4341 

4342 async def _refresh_gateway_tools_resources_prompts( 

4343 self, 

4344 gateway_id: str, 

4345 _user_email: Optional[str] = None, 

4346 created_via: str = "health_check", 

4347 pre_auth_headers: Optional[Dict[str, str]] = None, 

4348 gateway: Optional[DbGateway] = None, 

4349 include_resources: bool = True, 

4350 include_prompts: bool = True, 

4351 ) -> Dict[str, int]: 

4352 """Refresh tools, resources, and prompts for a gateway during health checks. 

4353 

4354 Fetches the latest tools/resources/prompts from the MCP server and syncs 

4355 with the database (add new, update changed, remove stale). Only performs 

4356 DB operations if actual changes are detected. 

4357 

4358 This method uses fresh_db_session() internally to avoid holding 

4359 connections during HTTP calls to MCP servers. 

4360 

4361 Args: 

4362 gateway_id: ID of the gateway to refresh 

4363 _user_email: Optional user email for OAuth token lookup (unused currently) 

4364 created_via: String indicating creation source (default: "health_check") 

4365 pre_auth_headers: Pre-authenticated headers from health check to avoid duplicate OAuth token fetch 

4366 gateway: Optional DbGateway object to avoid redundant DB lookup 

4367 include_resources: Whether to include resources in the refresh 

4368 include_prompts: Whether to include prompts in the refresh 

4369 

4370 Returns: 

4371 Dict with counts: {tools_added, tools_removed, resources_added, 

4372 resources_removed, prompts_added, prompts_removed} 

4373 

4374 Examples: 

4375 >>> from mcpgateway.services.gateway_service import GatewayService 

4376 >>> from unittest.mock import patch, MagicMock, AsyncMock 

4377 >>> import asyncio 

4378 

4379 >>> # Test gateway not found returns empty result 

4380 >>> service = GatewayService() 

4381 >>> mock_session = MagicMock() 

4382 >>> mock_session.execute.return_value.scalar_one_or_none.return_value = None 

4383 >>> with patch('mcpgateway.services.gateway_service.fresh_db_session') as mock_fresh: 

4384 ... mock_fresh.return_value.__enter__.return_value = mock_session 

4385 ... result = asyncio.run(service._refresh_gateway_tools_resources_prompts('gw-123')) 

4386 >>> result['tools_added'] == 0 and result['tools_removed'] == 0 

4387 True 

4388 >>> result['resources_added'] == 0 and result['resources_removed'] == 0 

4389 True 

4390 >>> result['success'] is True and result['error'] is None 

4391 True 

4392 

4393 >>> # Test disabled gateway returns empty result 

4394 >>> mock_gw = MagicMock() 

4395 >>> mock_gw.enabled = False 

4396 >>> mock_gw.reachable = True 

4397 >>> mock_gw.name = 'test_gw' 

4398 >>> mock_session.execute.return_value.scalar_one_or_none.return_value = mock_gw 

4399 >>> with patch('mcpgateway.services.gateway_service.fresh_db_session') as mock_fresh: 

4400 ... mock_fresh.return_value.__enter__.return_value = mock_session 

4401 ... result = asyncio.run(service._refresh_gateway_tools_resources_prompts('gw-123')) 

4402 >>> result['tools_added'] 

4403 0 

4404 

4405 >>> # Test unreachable gateway returns empty result 

4406 >>> mock_gw.enabled = True 

4407 >>> mock_gw.reachable = False 

4408 >>> with patch('mcpgateway.services.gateway_service.fresh_db_session') as mock_fresh: 

4409 ... mock_fresh.return_value.__enter__.return_value = mock_session 

4410 ... result = asyncio.run(service._refresh_gateway_tools_resources_prompts('gw-123')) 

4411 >>> result['tools_added'] 

4412 0 

4413 

4414 >>> # Test method is async and callable 

4415 >>> import inspect 

4416 >>> inspect.iscoroutinefunction(service._refresh_gateway_tools_resources_prompts) 

4417 True 

4418 >>> 

4419 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs 

4420 >>> asyncio.run(service._http_client.aclose()) 

4421 """ 

4422 result = { 

4423 "tools_added": 0, 

4424 "tools_removed": 0, 

4425 "resources_added": 0, 

4426 "resources_removed": 0, 

4427 "prompts_added": 0, 

4428 "prompts_removed": 0, 

4429 "tools_updated": 0, 

4430 "resources_updated": 0, 

4431 "prompts_updated": 0, 

4432 "success": True, 

4433 "error": None, 

4434 "validation_errors": [], 

4435 } 

4436 

4437 # Fetch gateway metadata only (no relationships needed for MCP call) 

4438 # Use provided gateway object if available to save a DB call 

4439 gateway_name = None 

4440 gateway_url = None 

4441 gateway_transport = None 

4442 gateway_auth_type = None 

4443 gateway_auth_value = None 

4444 gateway_oauth_config = None 

4445 gateway_ca_certificate = None 

4446 gateway_auth_query_params = None 

4447 

4448 if gateway: 

4449 if not gateway.enabled or not gateway.reachable: 

4450 logger.debug(f"Skipping tool refresh for disabled/unreachable gateway {gateway.name}") 

4451 return result 

4452 

4453 gateway_name = gateway.name 

4454 gateway_url = gateway.url 

4455 gateway_transport = gateway.transport 

4456 gateway_auth_type = gateway.auth_type 

4457 gateway_auth_value = gateway.auth_value 

4458 gateway_oauth_config = gateway.oauth_config 

4459 gateway_ca_certificate = gateway.ca_certificate 

4460 gateway_auth_query_params = gateway.auth_query_params 

4461 else: 

4462 with fresh_db_session() as db: 

4463 gateway_obj = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none() 

4464 

4465 if not gateway_obj: 

4466 logger.warning(f"Gateway {gateway_id} not found for tool refresh") 

4467 return result 

4468 

4469 if not gateway_obj.enabled or not gateway_obj.reachable: 

4470 logger.debug(f"Skipping tool refresh for disabled/unreachable gateway {gateway_obj.name}") 

4471 return result 

4472 

4473 # Extract metadata before session closes 

4474 gateway_name = gateway_obj.name 

4475 gateway_url = gateway_obj.url 

4476 gateway_transport = gateway_obj.transport 

4477 gateway_auth_type = gateway_obj.auth_type 

4478 gateway_auth_value = gateway_obj.auth_value 

4479 gateway_oauth_config = gateway_obj.oauth_config 

4480 gateway_ca_certificate = gateway_obj.ca_certificate 

4481 gateway_auth_query_params = gateway_obj.auth_query_params 

4482 

4483 # Handle query_param auth - decrypt and apply to URL for refresh 

4484 auth_query_params_decrypted: Optional[Dict[str, str]] = None 

4485 if gateway_auth_type == "query_param" and gateway_auth_query_params: 

4486 auth_query_params_decrypted = {} 

4487 for param_key, encrypted_value in gateway_auth_query_params.items(): 

4488 if encrypted_value: 

4489 try: 

4490 decrypted = decode_auth(encrypted_value) 

4491 auth_query_params_decrypted[param_key] = decrypted.get(param_key, "") 

4492 except Exception: 

4493 logger.debug(f"Failed to decrypt query param '{param_key}' for tool refresh") 

4494 if auth_query_params_decrypted: 

4495 gateway_url = apply_query_param_auth(gateway_url, auth_query_params_decrypted) 

4496 

4497 # Fetch tools/resources/prompts from MCP server (no DB connection held) 

4498 try: 

4499 _capabilities, tools, resources, prompts = await self._initialize_gateway( 

4500 url=gateway_url, 

4501 authentication=gateway_auth_value, 

4502 transport=gateway_transport, 

4503 auth_type=gateway_auth_type, 

4504 oauth_config=gateway_oauth_config, 

4505 ca_certificate=gateway_ca_certificate.encode() if gateway_ca_certificate else None, 

4506 pre_auth_headers=pre_auth_headers, 

4507 include_resources=include_resources, 

4508 include_prompts=include_prompts, 

4509 auth_query_params=auth_query_params_decrypted, 

4510 ) 

4511 except Exception as e: 

4512 logger.warning(f"Failed to fetch tools from gateway {gateway_name}: {e}") 

4513 result["success"] = False 

4514 result["error"] = str(e) 

4515 return result 

4516 

4517 # For authorization_code OAuth gateways, empty responses may indicate incomplete auth flow 

4518 # Skip only if it's an auth_code gateway with no data (user may not have completed authorization) 

4519 is_auth_code_gateway = gateway_oauth_config and isinstance(gateway_oauth_config, dict) and gateway_oauth_config.get("grant_type") == "authorization_code" 

4520 if not tools and not resources and not prompts and is_auth_code_gateway: 

4521 logger.debug(f"No tools/resources/prompts returned from auth_code gateway {gateway_name} (user may not have authorized)") 

4522 return result 

4523 

4524 # For non-auth_code gateways, empty responses are legitimate and will clear stale items 

4525 

4526 # Update database with fresh session 

4527 with fresh_db_session() as db: 

4528 # Fetch gateway with relationships for update/comparison 

4529 gateway = db.execute( 

4530 select(DbGateway) 

4531 .options( 

4532 selectinload(DbGateway.tools), 

4533 selectinload(DbGateway.resources), 

4534 selectinload(DbGateway.prompts), 

4535 ) 

4536 .where(DbGateway.id == gateway_id) 

4537 ).scalar_one_or_none() 

4538 

4539 if not gateway: 

4540 result["success"] = False 

4541 result["error"] = f"Gateway {gateway_id} not found during refresh" 

4542 return result 

4543 

4544 new_tool_names = [tool.name for tool in tools] 

4545 new_resource_uris = [resource.uri for resource in resources] if include_resources else None 

4546 new_prompt_names = [prompt.name for prompt in prompts] if include_prompts else None 

4547 

4548 # Track dirty objects before update operations to count per-type updates 

4549 pending_tools_before = {obj for obj in db.dirty if isinstance(obj, DbTool)} 

4550 pending_resources_before = {obj for obj in db.dirty if isinstance(obj, DbResource)} 

4551 pending_prompts_before = {obj for obj in db.dirty if isinstance(obj, DbPrompt)} 

4552 

4553 # Update/create tools, resources, and prompts 

4554 tools_to_add = self._update_or_create_tools(db, tools, gateway, created_via) 

4555 resources_to_add = self._update_or_create_resources(db, resources, gateway, created_via) if include_resources else [] 

4556 prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, created_via) if include_prompts else [] 

4557 

4558 # Count per-type updates 

4559 result["tools_updated"] = len({obj for obj in db.dirty if isinstance(obj, DbTool)} - pending_tools_before) 

4560 result["resources_updated"] = len({obj for obj in db.dirty if isinstance(obj, DbResource)} - pending_resources_before) 

4561 result["prompts_updated"] = len({obj for obj in db.dirty if isinstance(obj, DbPrompt)} - pending_prompts_before) 

4562 

4563 # Only delete MCP-discovered items (not user-created entries) 

4564 # Excludes "api", "ui", None (legacy/user-created) to preserve user entries 

4565 mcp_created_via_values = {"MCP", "federation", "health_check", "manual_refresh", "oauth", "update"} 

4566 

4567 # Find and remove stale tools (only MCP-discovered ones) 

4568 stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names and tool.created_via in mcp_created_via_values] 

4569 if stale_tool_ids: 

4570 for i in range(0, len(stale_tool_ids), 500): 

4571 chunk = stale_tool_ids[i : i + 500] 

4572 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk))) 

4573 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk))) 

4574 db.execute(delete(DbTool).where(DbTool.id.in_(chunk))) 

4575 result["tools_removed"] = len(stale_tool_ids) 

4576 

4577 # Find and remove stale resources (only MCP-discovered ones, only if resources were fetched) 

4578 stale_resource_ids = [] 

4579 if new_resource_uris is not None: 

4580 stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris and resource.created_via in mcp_created_via_values] 

4581 if stale_resource_ids: 

4582 for i in range(0, len(stale_resource_ids), 500): 

4583 chunk = stale_resource_ids[i : i + 500] 

4584 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk))) 

4585 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk))) 

4586 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk))) 

4587 db.execute(delete(DbResource).where(DbResource.id.in_(chunk))) 

4588 result["resources_removed"] = len(stale_resource_ids) 

4589 

4590 # Find and remove stale prompts (only MCP-discovered ones, only if prompts were fetched) 

4591 stale_prompt_ids = [] 

4592 if new_prompt_names is not None: 

4593 stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names and prompt.created_via in mcp_created_via_values] 

4594 if stale_prompt_ids: 

4595 for i in range(0, len(stale_prompt_ids), 500): 

4596 chunk = stale_prompt_ids[i : i + 500] 

4597 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk))) 

4598 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk))) 

4599 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk))) 

4600 result["prompts_removed"] = len(stale_prompt_ids) 

4601 

4602 # Expire gateway if stale items were deleted 

4603 if stale_tool_ids or stale_resource_ids or stale_prompt_ids: 

4604 db.expire(gateway) 

4605 

4606 # Add new items in chunks 

4607 chunk_size = 50 

4608 if tools_to_add: 

4609 for i in range(0, len(tools_to_add), chunk_size): 

4610 chunk = tools_to_add[i : i + chunk_size] 

4611 db.add_all(chunk) 

4612 db.flush() 

4613 result["tools_added"] = len(tools_to_add) 

4614 

4615 if resources_to_add: 

4616 for i in range(0, len(resources_to_add), chunk_size): 

4617 chunk = resources_to_add[i : i + chunk_size] 

4618 db.add_all(chunk) 

4619 db.flush() 

4620 result["resources_added"] = len(resources_to_add) 

4621 

4622 if prompts_to_add: 

4623 for i in range(0, len(prompts_to_add), chunk_size): 

4624 chunk = prompts_to_add[i : i + chunk_size] 

4625 db.add_all(chunk) 

4626 db.flush() 

4627 result["prompts_added"] = len(prompts_to_add) 

4628 

4629 gateway.last_refresh_at = datetime.now(timezone.utc) 

4630 

4631 total_changes = ( 

4632 result["tools_added"] 

4633 + result["tools_removed"] 

4634 + result["tools_updated"] 

4635 + result["resources_added"] 

4636 + result["resources_removed"] 

4637 + result["resources_updated"] 

4638 + result["prompts_added"] 

4639 + result["prompts_removed"] 

4640 + result["prompts_updated"] 

4641 ) 

4642 

4643 has_changes = total_changes > 0 

4644 

4645 if has_changes: 

4646 db.commit() 

4647 logger.info( 

4648 f"Refreshed gateway {gateway_name}: " 

4649 f"tools(+{result['tools_added']}/-{result['tools_removed']}/~{result['tools_updated']}), " 

4650 f"resources(+{result['resources_added']}/-{result['resources_removed']}/~{result['resources_updated']}), " 

4651 f"prompts(+{result['prompts_added']}/-{result['prompts_removed']}/~{result['prompts_updated']})" 

4652 ) 

4653 

4654 # Invalidate caches per-type based on actual changes 

4655 cache = _get_registry_cache() 

4656 if result["tools_added"] > 0 or result["tools_removed"] > 0 or result["tools_updated"] > 0: 

4657 await cache.invalidate_tools() 

4658 if result["resources_added"] > 0 or result["resources_removed"] > 0 or result["resources_updated"] > 0: 

4659 await cache.invalidate_resources() 

4660 if result["prompts_added"] > 0 or result["prompts_removed"] > 0 or result["prompts_updated"] > 0: 

4661 await cache.invalidate_prompts() 

4662 

4663 # Invalidate tool lookup cache for this gateway 

4664 tool_lookup_cache = _get_tool_lookup_cache() 

4665 await tool_lookup_cache.invalidate_gateway(str(gateway_id)) 

4666 else: 

4667 db.commit() 

4668 logger.debug(f"No changes detected during refresh of gateway {gateway_name}") 

4669 

4670 return result 

4671 

4672 def _get_refresh_lock(self, gateway_id: str) -> asyncio.Lock: 

4673 """Get or create a per-gateway refresh lock. 

4674 

4675 This ensures only one refresh operation can run for a given gateway at a time. 

4676 

4677 Args: 

4678 gateway_id: ID of the gateway to get the lock for 

4679 

4680 Returns: 

4681 asyncio.Lock: The lock for the specified gateway 

4682 

4683 Examples: 

4684 >>> from mcpgateway.services.gateway_service import GatewayService 

4685 >>> service = GatewayService() 

4686 >>> lock1 = service._get_refresh_lock('gw-123') 

4687 >>> lock2 = service._get_refresh_lock('gw-123') 

4688 >>> lock1 is lock2 

4689 True 

4690 >>> lock3 = service._get_refresh_lock('gw-456') 

4691 >>> lock1 is lock3 

4692 False 

4693 """ 

4694 if gateway_id not in self._refresh_locks: 

4695 self._refresh_locks[gateway_id] = asyncio.Lock() 

4696 return self._refresh_locks[gateway_id] 

4697 

4698 async def refresh_gateway_manually( 

4699 self, 

4700 gateway_id: str, 

4701 include_resources: bool = True, 

4702 include_prompts: bool = True, 

4703 user_email: Optional[str] = None, 

4704 request_headers: Optional[Dict[str, str]] = None, 

4705 ) -> Dict[str, Any]: 

4706 """Manually trigger a refresh of tools/resources/prompts for a gateway. 

4707 

4708 This method provides a public API for triggering an immediate refresh 

4709 of a gateway's tools, resources, and prompts from its MCP server. 

4710 It includes concurrency control via per-gateway locking. 

4711 

4712 Args: 

4713 gateway_id: Gateway ID to refresh 

4714 include_resources: Whether to include resources in the refresh 

4715 include_prompts: Whether to include prompts in the refresh 

4716 user_email: Email of the user triggering the refresh 

4717 request_headers: Optional request headers for passthrough authentication 

4718 

4719 Returns: 

4720 Dict with counts: {tools_added, tools_updated, tools_removed, 

4721 resources_added, resources_updated, resources_removed, 

4722 prompts_added, prompts_updated, prompts_removed, 

4723 validation_errors, duration_ms, refreshed_at} 

4724 

4725 Raises: 

4726 GatewayNotFoundError: If the gateway does not exist 

4727 GatewayError: If another refresh is already in progress for this gateway 

4728 

4729 Examples: 

4730 >>> from mcpgateway.services.gateway_service import GatewayService 

4731 >>> from unittest.mock import patch, MagicMock, AsyncMock 

4732 >>> import asyncio 

4733 

4734 >>> # Test method is async 

4735 >>> service = GatewayService() 

4736 >>> import inspect 

4737 >>> inspect.iscoroutinefunction(service.refresh_gateway_manually) 

4738 True 

4739 """ 

4740 start_time = time.monotonic() 

4741 

4742 pre_auth_headers = {} 

4743 

4744 # Check if gateway exists before acquiring lock 

4745 with fresh_db_session() as db: 

4746 gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none() 

4747 if not gateway: 

4748 raise GatewayNotFoundError(f"Gateway with ID '{gateway_id}' not found") 

4749 gateway_name = gateway.name 

4750 

4751 # Get passthrough headers if request headers provided 

4752 if request_headers: 

4753 pre_auth_headers = get_passthrough_headers(request_headers, {}, db, gateway) 

4754 

4755 lock = self._get_refresh_lock(gateway_id) 

4756 

4757 # Check if lock is already held (concurrent refresh in progress) 

4758 if lock.locked(): 

4759 raise GatewayError(f"Refresh already in progress for gateway {gateway_name}") 

4760 

4761 async with lock: 

4762 logger.info(f"Starting manual refresh for gateway {gateway_name} (ID: {gateway_id})") 

4763 

4764 result = await self._refresh_gateway_tools_resources_prompts( 

4765 gateway_id=gateway_id, 

4766 _user_email=user_email, 

4767 created_via="manual_refresh", 

4768 pre_auth_headers=pre_auth_headers, 

4769 gateway=gateway, 

4770 include_resources=include_resources, 

4771 include_prompts=include_prompts, 

4772 ) 

4773 # Note: last_refresh_at is updated inside _refresh_gateway_tools_resources_prompts on success 

4774 

4775 result["duration_ms"] = (time.monotonic() - start_time) * 1000 

4776 result["refreshed_at"] = datetime.now(timezone.utc) 

4777 

4778 log_level = logging.INFO if result.get("success", True) else logging.WARNING 

4779 status_msg = "succeeded" if result.get("success", True) else f"failed: {result.get('error')}" 

4780 

4781 logger.log( 

4782 log_level, 

4783 f"Manual refresh for gateway {gateway_id} {status_msg}. Stats: " 

4784 f"tools(+{result['tools_added']}/-{result['tools_removed']}), " 

4785 f"resources(+{result['resources_added']}/-{result['resources_removed']}), " 

4786 f"prompts(+{result['prompts_added']}/-{result['prompts_removed']}) " 

4787 f"in {result['duration_ms']:.2f}ms", 

4788 ) 

4789 

4790 return result 

4791 

4792 async def _publish_event(self, event: Dict[str, Any]) -> None: 

4793 """Publish event to all subscribers. 

4794 

4795 Args: 

4796 event: event dictionary 

4797 

4798 Examples: 

4799 >>> import asyncio 

4800 >>> from unittest.mock import AsyncMock 

4801 >>> service = GatewayService() 

4802 >>> # Mock the underlying event service 

4803 >>> service._event_service = AsyncMock() 

4804 >>> test_event = {"type": "test", "data": {}} 

4805 >>> 

4806 >>> asyncio.run(service._publish_event(test_event)) 

4807 >>> 

4808 >>> # Verify the event was passed to the event service 

4809 >>> service._event_service.publish_event.assert_awaited_with(test_event) 

4810 """ 

4811 await self._event_service.publish_event(event) 

4812 

4813 def _validate_tools(self, tools: list[dict[str, Any]], context: str = "default") -> tuple[list[ToolCreate], list[str]]: 

4814 """Validate tools individually with richer logging and error aggregation. 

4815 

4816 Args: 

4817 tools: list of tool dicts 

4818 context: caller context, e.g. "oauth" to tailor errors/messages 

4819 

4820 Returns: 

4821 tuple[list[ToolCreate], list[str]]: Tuple of (valid tools, validation errors) 

4822 

4823 Raises: 

4824 OAuthToolValidationError: If all tools fail validation in OAuth context 

4825 GatewayConnectionError: If all tools fail validation in default context 

4826 """ 

4827 valid_tools: list[ToolCreate] = [] 

4828 validation_errors: list[str] = [] 

4829 

4830 for i, tool_dict in enumerate(tools): 

4831 tool_name = tool_dict.get("name", f"unknown_tool_{i}") 

4832 try: 

4833 logger.debug(f"Validating tool: {tool_name}") 

4834 validated_tool = ToolCreate.model_validate(tool_dict) 

4835 valid_tools.append(validated_tool) 

4836 logger.debug(f"Tool '{tool_name}' validated successfully") 

4837 except ValidationError as e: 

4838 error_msg = f"Validation failed for tool '{tool_name}': {e.errors()}" 

4839 logger.error(error_msg) 

4840 logger.debug(f"Failed tool schema: {tool_dict}") 

4841 validation_errors.append(error_msg) 

4842 except ValueError as e: 

4843 if "JSON structure exceeds maximum depth" in str(e): 

4844 error_msg = f"Tool '{tool_name}' schema too deeply nested. " f"Current depth limit: {settings.validation_max_json_depth}" 

4845 logger.error(error_msg) 

4846 logger.warning("Consider increasing VALIDATION_MAX_JSON_DEPTH environment variable") 

4847 else: 

4848 error_msg = f"ValueError for tool '{tool_name}': {str(e)}" 

4849 logger.error(error_msg) 

4850 validation_errors.append(error_msg) 

4851 except Exception as e: # pragma: no cover - defensive 

4852 error_msg = f"Unexpected error validating tool '{tool_name}': {type(e).__name__}: {str(e)}" 

4853 logger.error(error_msg, exc_info=True) 

4854 validation_errors.append(error_msg) 

4855 

4856 if validation_errors: 

4857 logger.warning(f"Tool validation completed with {len(validation_errors)} error(s). " f"Successfully validated {len(valid_tools)} tool(s).") 

4858 for err in validation_errors[:3]: 

4859 logger.debug(f"Validation error: {err}") 

4860 

4861 if not valid_tools and validation_errors: 

4862 if context == "oauth": 

4863 raise OAuthToolValidationError(f"OAuth tool fetch failed: all {len(tools)} tools failed validation. " f"First error: {validation_errors[0][:200]}") 

4864 raise GatewayConnectionError(f"Failed to fetch tools: All {len(tools)} tools failed validation. " f"First error: {validation_errors[0][:200]}") 

4865 

4866 return valid_tools, validation_errors 

4867 

4868 async def _connect_to_sse_server_without_validation(self, server_url: str, authentication: Optional[Dict[str, str]] = None): 

4869 """Connect to an MCP server running with SSE transport, skipping URL validation. 

4870 

4871 This is used for OAuth-protected servers where we've already validated the token works. 

4872 

4873 Args: 

4874 server_url: The URL of the SSE MCP server to connect to. 

4875 authentication: Optional dictionary containing authentication headers. 

4876 

4877 Returns: 

4878 Tuple containing (capabilities, tools, resources, prompts) from the MCP server. 

4879 """ 

4880 if authentication is None: 

4881 authentication = {} 

4882 

4883 # Skip validation for OAuth servers - we already validated via OAuth flow 

4884 # Use async with for both sse_client and ClientSession 

4885 try: 

4886 async with sse_client(url=server_url, headers=authentication) as streams: 

4887 async with ClientSession(*streams) as session: 

4888 # Initialize the session 

4889 response = await session.initialize() 

4890 capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True) 

4891 logger.debug(f"Server capabilities: {capabilities}") 

4892 

4893 response = await session.list_tools() 

4894 tools = response.tools 

4895 tools = [tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools] 

4896 

4897 tools, _ = self._validate_tools(tools, context="oauth") 

4898 if tools: 

4899 logger.info(f"Fetched {len(tools)} tools from gateway") 

4900 # Fetch resources if supported 

4901 

4902 logger.debug(f"Checking for resources support: {capabilities.get('resources')}") 

4903 resources = [] 

4904 if capabilities.get("resources"): 

4905 try: 

4906 response = await session.list_resources() 

4907 raw_resources = response.resources 

4908 for resource in raw_resources: 

4909 resource_data = resource.model_dump(by_alias=True, exclude_none=True) 

4910 # Convert AnyUrl to string if present 

4911 if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"): 

4912 resource_data["uri"] = str(resource_data["uri"]) 

4913 # Add default content if not present (will be fetched on demand) 

4914 if "content" not in resource_data: 

4915 resource_data["content"] = "" 

4916 try: 

4917 resources.append(ResourceCreate.model_validate(resource_data)) 

4918 except Exception: 

4919 # If validation fails, create minimal resource 

4920 resources.append( 

4921 ResourceCreate( 

4922 uri=str(resource_data.get("uri", "")), 

4923 name=resource_data.get("name", ""), 

4924 description=resource_data.get("description"), 

4925 mime_type=resource_data.get("mimeType"), 

4926 uri_template=resource_data.get("uriTemplate") or None, 

4927 content="", 

4928 ) 

4929 ) 

4930 logger.info(f"Fetched {len(resources)} resources from gateway") 

4931 except Exception as e: 

4932 logger.warning(f"Failed to fetch resources: {e}") 

4933 

4934 # resource template URI 

4935 try: 

4936 response_templates = await session.list_resource_templates() 

4937 raw_resources_templates = response_templates.resourceTemplates 

4938 resource_templates = [] 

4939 for resource_template in raw_resources_templates: 

4940 resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True) 

4941 

4942 if "uriTemplate" in resource_template_data: # and hasattr(resource_template_data["uriTemplate"], "unicode_string"): 

4943 resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"]) 

4944 resource_template_data["uri"] = str(resource_template_data["uriTemplate"]) 

4945 

4946 if "content" not in resource_template_data: 

4947 resource_template_data["content"] = "" 

4948 

4949 resources.append(ResourceCreate.model_validate(resource_template_data)) 

4950 resource_templates.append(ResourceCreate.model_validate(resource_template_data)) 

4951 logger.info(f"Fetched {len(resource_templates)} resource templates from gateway") 

4952 except Exception as e: 

4953 logger.warning(f"Failed to fetch resource templates: {e}") 

4954 

4955 # Fetch prompts if supported 

4956 prompts = [] 

4957 logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}") 

4958 if capabilities.get("prompts"): 

4959 try: 

4960 response = await session.list_prompts() 

4961 raw_prompts = response.prompts 

4962 for prompt in raw_prompts: 

4963 prompt_data = prompt.model_dump(by_alias=True, exclude_none=True) 

4964 # Add default template if not present 

4965 if "template" not in prompt_data: 

4966 prompt_data["template"] = "" 

4967 try: 

4968 prompts.append(PromptCreate.model_validate(prompt_data)) 

4969 except Exception: 

4970 # If validation fails, create minimal prompt 

4971 prompts.append( 

4972 PromptCreate( 

4973 name=prompt_data.get("name", ""), 

4974 description=prompt_data.get("description"), 

4975 template=prompt_data.get("template", ""), 

4976 ) 

4977 ) 

4978 logger.info(f"Fetched {len(prompts)} prompts from gateway") 

4979 except Exception as e: 

4980 logger.warning(f"Failed to fetch prompts: {e}") 

4981 

4982 return capabilities, tools, resources, prompts 

4983 except Exception as e: 

4984 # Note: This function is for OAuth servers only, which don't use query param auth 

4985 # Still sanitize in case exception contains URL with static sensitive params 

4986 sanitized_url = sanitize_url_for_logging(server_url) 

4987 sanitized_error = sanitize_exception_message(str(e)) 

4988 logger.error(f"SSE connection error details: {type(e).__name__}: {sanitized_error}", exc_info=True) 

4989 raise GatewayConnectionError(f"Failed to connect to SSE server at {sanitized_url}: {sanitized_error}") 

4990 

4991 async def connect_to_sse_server( 

4992 self, 

4993 server_url: str, 

4994 authentication: Optional[Dict[str, str]] = None, 

4995 ca_certificate: Optional[bytes] = None, 

4996 include_prompts: bool = True, 

4997 include_resources: bool = True, 

4998 auth_query_params: Optional[Dict[str, str]] = None, 

4999 ): 

5000 """Connect to an MCP server running with SSE transport. 

5001 

5002 Args: 

5003 server_url: The URL of the SSE MCP server to connect to. 

5004 authentication: Optional dictionary containing authentication headers. 

5005 ca_certificate: Optional CA certificate for SSL verification. 

5006 include_prompts: Whether to fetch prompts from the server. 

5007 include_resources: Whether to fetch resources from the server. 

5008 auth_query_params: Query param names for URL sanitization in error logs. 

5009 

5010 Returns: 

5011 Tuple containing (capabilities, tools, resources, prompts) from the MCP server. 

5012 """ 

5013 if authentication is None: 

5014 authentication = {} 

5015 

5016 def get_httpx_client_factory( 

5017 headers: dict[str, str] | None = None, 

5018 timeout: httpx.Timeout | None = None, 

5019 auth: httpx.Auth | None = None, 

5020 ) -> httpx.AsyncClient: 

5021 """Factory function to create httpx.AsyncClient with optional CA certificate. 

5022 

5023 Args: 

5024 headers: Optional headers for the client 

5025 timeout: Optional timeout for the client 

5026 auth: Optional auth for the client 

5027 

5028 Returns: 

5029 httpx.AsyncClient: Configured HTTPX async client 

5030 """ 

5031 if ca_certificate: 

5032 ctx = self.create_ssl_context(ca_certificate) 

5033 else: 

5034 ctx = None 

5035 return httpx.AsyncClient( 

5036 verify=ctx if ctx else get_default_verify(), 

5037 follow_redirects=True, 

5038 headers=headers, 

5039 timeout=timeout if timeout else get_http_timeout(), 

5040 auth=auth, 

5041 limits=httpx.Limits( 

5042 max_connections=settings.httpx_max_connections, 

5043 max_keepalive_connections=settings.httpx_max_keepalive_connections, 

5044 keepalive_expiry=settings.httpx_keepalive_expiry, 

5045 ), 

5046 ) 

5047 

5048 # Use async with for both sse_client and ClientSession 

5049 async with sse_client(url=server_url, headers=authentication, httpx_client_factory=get_httpx_client_factory) as streams: 

5050 async with ClientSession(*streams) as session: 

5051 # Initialize the session 

5052 response = await session.initialize() 

5053 

5054 capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True) 

5055 logger.debug(f"Server capabilities: {capabilities}") 

5056 

5057 response = await session.list_tools() 

5058 tools = response.tools 

5059 tools = [tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools] 

5060 

5061 tools, _ = self._validate_tools(tools) 

5062 if tools: 

5063 logger.info(f"Fetched {len(tools)} tools from gateway") 

5064 # Fetch resources if supported 

5065 resources = [] 

5066 if include_resources: 

5067 logger.debug(f"Checking for resources support: {capabilities.get('resources')}") 

5068 if capabilities.get("resources"): 

5069 try: 

5070 response = await session.list_resources() 

5071 raw_resources = response.resources 

5072 for resource in raw_resources: 

5073 resource_data = resource.model_dump(by_alias=True, exclude_none=True) 

5074 # Convert AnyUrl to string if present 

5075 if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"): 

5076 resource_data["uri"] = str(resource_data["uri"]) 

5077 # Add default content if not present (will be fetched on demand) 

5078 if "content" not in resource_data: 

5079 resource_data["content"] = "" 

5080 try: 

5081 resources.append(ResourceCreate.model_validate(resource_data)) 

5082 except Exception: 

5083 # If validation fails, create minimal resource 

5084 resources.append( 

5085 ResourceCreate( 

5086 uri=str(resource_data.get("uri", "")), 

5087 name=resource_data.get("name", ""), 

5088 description=resource_data.get("description"), 

5089 mime_type=resource_data.get("mimeType"), 

5090 uri_template=resource_data.get("uriTemplate") or None, 

5091 content="", 

5092 ) 

5093 ) 

5094 logger.info(f"Fetched {len(resources)} resources from gateway") 

5095 except Exception as e: 

5096 logger.warning(f"Failed to fetch resources: {e}") 

5097 

5098 # resource template URI 

5099 try: 

5100 response_templates = await session.list_resource_templates() 

5101 raw_resources_templates = response_templates.resourceTemplates 

5102 resource_templates = [] 

5103 for resource_template in raw_resources_templates: 

5104 resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True) 

5105 

5106 if "uriTemplate" in resource_template_data: # and hasattr(resource_template_data["uriTemplate"], "unicode_string"): 

5107 resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"]) 

5108 resource_template_data["uri"] = str(resource_template_data["uriTemplate"]) 

5109 

5110 if "content" not in resource_template_data: 

5111 resource_template_data["content"] = "" 

5112 

5113 resources.append(ResourceCreate.model_validate(resource_template_data)) 

5114 resource_templates.append(ResourceCreate.model_validate(resource_template_data)) 

5115 logger.info(f"Fetched {len(raw_resources_templates)} resource templates from gateway") 

5116 except Exception as ei: 

5117 logger.warning(f"Failed to fetch resource templates: {ei}") 

5118 

5119 # Fetch prompts if supported 

5120 prompts = [] 

5121 if include_prompts: 

5122 logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}") 

5123 if capabilities.get("prompts"): 

5124 try: 

5125 response = await session.list_prompts() 

5126 raw_prompts = response.prompts 

5127 for prompt in raw_prompts: 

5128 prompt_data = prompt.model_dump(by_alias=True, exclude_none=True) 

5129 # Add default template if not present 

5130 if "template" not in prompt_data: 

5131 prompt_data["template"] = "" 

5132 try: 

5133 prompts.append(PromptCreate.model_validate(prompt_data)) 

5134 except Exception: 

5135 # If validation fails, create minimal prompt 

5136 prompts.append( 

5137 PromptCreate( 

5138 name=prompt_data.get("name", ""), 

5139 description=prompt_data.get("description"), 

5140 template=prompt_data.get("template", ""), 

5141 ) 

5142 ) 

5143 logger.info(f"Fetched {len(prompts)} prompts from gateway") 

5144 except Exception as e: 

5145 logger.warning(f"Failed to fetch prompts: {e}") 

5146 

5147 return capabilities, tools, resources, prompts 

5148 sanitized_url = sanitize_url_for_logging(server_url, auth_query_params) 

5149 raise GatewayConnectionError(f"Failed to initialize gateway at {sanitized_url}: Connection could not be established") 

5150 

5151 async def connect_to_streamablehttp_server( 

5152 self, 

5153 server_url: str, 

5154 authentication: Optional[Dict[str, str]] = None, 

5155 ca_certificate: Optional[bytes] = None, 

5156 include_prompts: bool = True, 

5157 include_resources: bool = True, 

5158 auth_query_params: Optional[Dict[str, str]] = None, 

5159 ): 

5160 """Connect to an MCP server running with Streamable HTTP transport. 

5161 

5162 Args: 

5163 server_url: The URL of the Streamable HTTP MCP server to connect to. 

5164 authentication: Optional dictionary containing authentication headers. 

5165 ca_certificate: Optional CA certificate for SSL verification. 

5166 include_prompts: Whether to fetch prompts from the server. 

5167 include_resources: Whether to fetch resources from the server. 

5168 auth_query_params: Query param names for URL sanitization in error logs. 

5169 

5170 Returns: 

5171 Tuple containing (capabilities, tools, resources, prompts) from the MCP server. 

5172 """ 

5173 if authentication is None: 

5174 authentication = {} 

5175 

5176 # Use authentication directly instead 

5177 def get_httpx_client_factory( 

5178 headers: dict[str, str] | None = None, 

5179 timeout: httpx.Timeout | None = None, 

5180 auth: httpx.Auth | None = None, 

5181 ) -> httpx.AsyncClient: 

5182 """Factory function to create httpx.AsyncClient with optional CA certificate. 

5183 

5184 Args: 

5185 headers: Optional headers for the client 

5186 timeout: Optional timeout for the client 

5187 auth: Optional auth for the client 

5188 

5189 Returns: 

5190 httpx.AsyncClient: Configured HTTPX async client 

5191 """ 

5192 if ca_certificate: 

5193 ctx = self.create_ssl_context(ca_certificate) 

5194 else: 

5195 ctx = None 

5196 return httpx.AsyncClient( 

5197 verify=ctx if ctx else get_default_verify(), 

5198 follow_redirects=True, 

5199 headers=headers, 

5200 timeout=timeout if timeout else get_http_timeout(), 

5201 auth=auth, 

5202 limits=httpx.Limits( 

5203 max_connections=settings.httpx_max_connections, 

5204 max_keepalive_connections=settings.httpx_max_keepalive_connections, 

5205 keepalive_expiry=settings.httpx_keepalive_expiry, 

5206 ), 

5207 ) 

5208 

5209 async with streamablehttp_client(url=server_url, headers=authentication, httpx_client_factory=get_httpx_client_factory) as (read_stream, write_stream, _get_session_id): 

5210 async with ClientSession(read_stream, write_stream) as session: 

5211 # Initialize the session 

5212 response = await session.initialize() 

5213 capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True) 

5214 logger.debug(f"Server capabilities: {capabilities}") 

5215 

5216 response = await session.list_tools() 

5217 tools = response.tools 

5218 tools = [tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools] 

5219 

5220 tools, _ = self._validate_tools(tools) 

5221 for tool in tools: 

5222 tool.request_type = "STREAMABLEHTTP" 

5223 if tools: 

5224 logger.info(f"Fetched {len(tools)} tools from gateway") 

5225 

5226 # Fetch resources if supported 

5227 resources = [] 

5228 if include_resources: 

5229 logger.debug(f"Checking for resources support: {capabilities.get('resources')}") 

5230 if capabilities.get("resources"): 

5231 try: 

5232 response = await session.list_resources() 

5233 raw_resources = response.resources 

5234 for resource in raw_resources: 

5235 resource_data = resource.model_dump(by_alias=True, exclude_none=True) 

5236 # Convert AnyUrl to string if present 

5237 if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"): 

5238 resource_data["uri"] = str(resource_data["uri"]) 

5239 # Add default content if not present 

5240 if "content" not in resource_data: 

5241 resource_data["content"] = "" 

5242 try: 

5243 resources.append(ResourceCreate.model_validate(resource_data)) 

5244 except Exception: 

5245 # If validation fails, create minimal resource 

5246 resources.append( 

5247 ResourceCreate( 

5248 uri=str(resource_data.get("uri", "")), 

5249 name=resource_data.get("name", ""), 

5250 description=resource_data.get("description"), 

5251 mime_type=resource_data.get("mimeType"), 

5252 uri_template=resource_data.get("uriTemplate") or None, 

5253 content="", 

5254 ) 

5255 ) 

5256 logger.info(f"Fetched {len(resources)} resources from gateway") 

5257 except Exception as e: 

5258 logger.warning(f"Failed to fetch resources: {e}") 

5259 

5260 # resource template URI 

5261 try: 

5262 response_templates = await session.list_resource_templates() 

5263 raw_resources_templates = response_templates.resourceTemplates 

5264 resource_templates = [] 

5265 for resource_template in raw_resources_templates: 

5266 resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True) 

5267 

5268 if "uriTemplate" in resource_template_data: # and hasattr(resource_template_data["uriTemplate"], "unicode_string"): 

5269 resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"]) 

5270 resource_template_data["uri"] = str(resource_template_data["uriTemplate"]) 

5271 

5272 if "content" not in resource_template_data: 

5273 resource_template_data["content"] = "" 

5274 

5275 resources.append(ResourceCreate.model_validate(resource_template_data)) 

5276 resource_templates.append(ResourceCreate.model_validate(resource_template_data)) 

5277 logger.info(f"Fetched {len(resource_templates)} resource templates from gateway") 

5278 except Exception as e: 

5279 logger.warning(f"Failed to fetch resource templates: {e}") 

5280 

5281 # Fetch prompts if supported 

5282 prompts = [] 

5283 if include_prompts: 

5284 logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}") 

5285 if capabilities.get("prompts"): 

5286 try: 

5287 response = await session.list_prompts() 

5288 raw_prompts = response.prompts 

5289 for prompt in raw_prompts: 

5290 prompt_data = prompt.model_dump(by_alias=True, exclude_none=True) 

5291 # Add default template if not present 

5292 if "template" not in prompt_data: 

5293 prompt_data["template"] = "" 

5294 prompts.append(PromptCreate.model_validate(prompt_data)) 

5295 logger.info(f"Fetched {len(prompts)} prompts from gateway") 

5296 except Exception as e: 

5297 logger.warning(f"Failed to fetch prompts: {e}") 

5298 

5299 return capabilities, tools, resources, prompts 

5300 sanitized_url = sanitize_url_for_logging(server_url, auth_query_params) 

5301 raise GatewayConnectionError(f"Failed to initialize gateway at {sanitized_url}: Connection could not be established") 

5302 

5303 

5304# Lazy singleton - created on first access, not at module import time. 

5305# This avoids instantiation when only exception classes are imported. 

5306_gateway_service_instance = None # pylint: disable=invalid-name 

5307 

5308 

5309def __getattr__(name: str): 

5310 """Module-level __getattr__ for lazy singleton creation. 

5311 

5312 Args: 

5313 name: The attribute name being accessed. 

5314 

5315 Returns: 

5316 The gateway_service singleton instance if name is "gateway_service". 

5317 

5318 Raises: 

5319 AttributeError: If the attribute name is not "gateway_service". 

5320 """ 

5321 global _gateway_service_instance # pylint: disable=global-statement 

5322 if name == "gateway_service": 

5323 if _gateway_service_instance is None: 

5324 _gateway_service_instance = GatewayService() 

5325 return _gateway_service_instance 

5326 raise AttributeError(f"module {__name__!r} has no attribute {name!r}")