Coverage for mcpgateway / services / gateway_service.py: 90%

2210 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2# pylint: disable=import-outside-toplevel,no-name-in-module 

3"""Location: ./mcpgateway/services/gateway_service.py 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8Gateway Service Implementation. 

9This module implements gateway federation according to the MCP specification. 

10It handles: 

11- Gateway discovery and registration 

12- Request forwarding 

13- Capability aggregation 

14- Health monitoring 

15- Active/inactive gateway management 

16 

17Examples: 

18 >>> from mcpgateway.services.gateway_service import GatewayService, GatewayError 

19 >>> service = GatewayService() 

20 >>> isinstance(service, GatewayService) 

21 True 

22 >>> hasattr(service, '_active_gateways') 

23 True 

24 >>> isinstance(service._active_gateways, set) 

25 True 

26 

27 Test error classes: 

28 >>> error = GatewayError("Test error") 

29 >>> str(error) 

30 'Test error' 

31 >>> isinstance(error, Exception) 

32 True 

33 

34 >>> conflict_error = GatewayNameConflictError("test_gw") 

35 >>> "test_gw" in str(conflict_error) 

36 True 

37 >>> conflict_error.enabled 

38 True 

39 >>> 

40 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs 

41 >>> import asyncio 

42 >>> asyncio.run(service._http_client.aclose()) 

43""" 

44 

45# Standard 

46import asyncio 

47import binascii 

48from datetime import datetime, timezone 

49import logging 

50import mimetypes 

51import os 

52import ssl 

53import tempfile 

54import time 

55from typing import Any, AsyncGenerator, cast, Dict, List, Optional, Set, TYPE_CHECKING, Union 

56from urllib.parse import urljoin, urlparse, urlunparse 

57import uuid 

58 

59# Third-Party 

60from filelock import FileLock, Timeout 

61import httpx 

62from mcp import ClientSession 

63from mcp.client.sse import sse_client 

64from mcp.client.streamable_http import streamablehttp_client 

65from pydantic import ValidationError 

66from sqlalchemy import and_, delete, desc, or_, select, update 

67from sqlalchemy.exc import IntegrityError 

68from sqlalchemy.orm import joinedload, selectinload, Session 

69 

70try: 

71 # Third-Party - check if redis is available 

72 # Third-Party 

73 import redis.asyncio as _aioredis # noqa: F401 # pylint: disable=unused-import 

74 

75 REDIS_AVAILABLE = True 

76 del _aioredis # Only needed for availability check 

77except ImportError: 

78 REDIS_AVAILABLE = False 

79 logging.info("Redis is not utilized in this environment.") 

80 

81# First-Party 

82from mcpgateway.config import settings 

83from mcpgateway.db import fresh_db_session 

84from mcpgateway.db import Gateway as DbGateway 

85from mcpgateway.db import get_db, get_for_update 

86from mcpgateway.db import Prompt as DbPrompt 

87from mcpgateway.db import PromptMetric 

88from mcpgateway.db import Resource as DbResource 

89from mcpgateway.db import ResourceMetric, ResourceSubscription, server_prompt_association, server_resource_association, server_tool_association, SessionLocal 

90from mcpgateway.db import Tool as DbTool 

91from mcpgateway.db import ToolMetric 

92from mcpgateway.observability import create_span 

93from mcpgateway.schemas import GatewayCreate, GatewayRead, GatewayUpdate, PromptCreate, ResourceCreate, ToolCreate 

94 

95# logging.getLogger("httpx").setLevel(logging.WARNING) # Disables httpx logs for regular health checks 

96from mcpgateway.services.audit_trail_service import get_audit_trail_service 

97from mcpgateway.services.event_service import EventService 

98from mcpgateway.services.http_client_service import get_default_verify, get_http_timeout, get_isolated_http_client 

99from mcpgateway.services.logging_service import LoggingService 

100from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, register_gateway_capabilities_for_notifications, TransportType 

101from mcpgateway.services.oauth_manager import OAuthManager 

102from mcpgateway.services.structured_logger import get_structured_logger 

103from mcpgateway.services.team_management_service import TeamManagementService 

104from mcpgateway.utils.create_slug import slugify 

105from mcpgateway.utils.display_name import generate_display_name 

106from mcpgateway.utils.pagination import unified_paginate 

107from mcpgateway.utils.passthrough_headers import get_passthrough_headers 

108from mcpgateway.utils.redis_client import get_redis_client 

109from mcpgateway.utils.retry_manager import ResilientHttpClient 

110from mcpgateway.utils.services_auth import decode_auth, encode_auth 

111from mcpgateway.utils.sqlalchemy_modifier import json_contains_tag_expr 

112from mcpgateway.utils.ssl_context_cache import get_cached_ssl_context 

113from mcpgateway.utils.url_auth import apply_query_param_auth, sanitize_exception_message, sanitize_url_for_logging 

114from mcpgateway.utils.validate_signature import validate_signature 

115from mcpgateway.validation.tags import validate_tags_field 

116 

117# Cache import (lazy to avoid circular dependencies) 

118_REGISTRY_CACHE = None 

119_TOOL_LOOKUP_CACHE = None 

120 

121 

122def _get_registry_cache(): 

123 """Get registry cache singleton lazily. 

124 

125 Returns: 

126 RegistryCache instance. 

127 """ 

128 global _REGISTRY_CACHE # pylint: disable=global-statement 

129 if _REGISTRY_CACHE is None: 

130 # First-Party 

131 from mcpgateway.cache.registry_cache import registry_cache # pylint: disable=import-outside-toplevel 

132 

133 _REGISTRY_CACHE = registry_cache 

134 return _REGISTRY_CACHE 

135 

136 

137def _get_tool_lookup_cache(): 

138 """Get tool lookup cache singleton lazily. 

139 

140 Returns: 

141 ToolLookupCache instance. 

142 """ 

143 global _TOOL_LOOKUP_CACHE # pylint: disable=global-statement 

144 if _TOOL_LOOKUP_CACHE is None: 

145 # First-Party 

146 from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache # pylint: disable=import-outside-toplevel 

147 

148 _TOOL_LOOKUP_CACHE = tool_lookup_cache 

149 return _TOOL_LOOKUP_CACHE 

150 

151 

152# Initialize logging service first 

153logging_service = LoggingService() 

154logger = logging_service.get_logger(__name__) 

155 

156# Initialize structured logger and audit trail for gateway operations 

157structured_logger = get_structured_logger("gateway_service") 

158audit_trail = get_audit_trail_service() 

159 

160 

161GW_FAILURE_THRESHOLD = settings.unhealthy_threshold 

162GW_HEALTH_CHECK_INTERVAL = settings.health_check_interval 

163 

164 

165class GatewayError(Exception): 

166 """Base class for gateway-related errors. 

167 

168 Examples: 

169 >>> error = GatewayError("Test error") 

170 >>> str(error) 

171 'Test error' 

172 >>> isinstance(error, Exception) 

173 True 

174 """ 

175 

176 

177class GatewayNotFoundError(GatewayError): 

178 """Raised when a requested gateway is not found. 

179 

180 Examples: 

181 >>> error = GatewayNotFoundError("Gateway not found") 

182 >>> str(error) 

183 'Gateway not found' 

184 >>> isinstance(error, GatewayError) 

185 True 

186 """ 

187 

188 

189class GatewayNameConflictError(GatewayError): 

190 """Raised when a gateway name conflicts with existing (active or inactive) gateway. 

191 

192 Args: 

193 name: The conflicting gateway name 

194 enabled: Whether the existing gateway is enabled 

195 gateway_id: ID of the existing gateway if available 

196 visibility: The visibility of the gateway ("public" or "team"). 

197 

198 Examples: 

199 >>> error = GatewayNameConflictError("test_gateway") 

200 >>> str(error) 

201 'Public Gateway already exists with name: test_gateway' 

202 >>> error.name 

203 'test_gateway' 

204 >>> error.enabled 

205 True 

206 >>> error.gateway_id is None 

207 True 

208 

209 >>> error_inactive = GatewayNameConflictError("inactive_gw", enabled=False, gateway_id=123) 

210 >>> str(error_inactive) 

211 'Public Gateway already exists with name: inactive_gw (currently inactive, ID: 123)' 

212 >>> error_inactive.enabled 

213 False 

214 >>> error_inactive.gateway_id 

215 123 

216 """ 

217 

218 def __init__(self, name: str, enabled: bool = True, gateway_id: Optional[int] = None, visibility: Optional[str] = "public"): 

219 """Initialize the error with gateway information. 

220 

221 Args: 

222 name: The conflicting gateway name 

223 enabled: Whether the existing gateway is enabled 

224 gateway_id: ID of the existing gateway if available 

225 visibility: The visibility of the gateway ("public" or "team"). 

226 """ 

227 self.name = name 

228 self.enabled = enabled 

229 self.gateway_id = gateway_id 

230 if visibility == "team": 

231 vis_label = "Team-level" 

232 else: 

233 vis_label = "Public" 

234 message = f"{vis_label} Gateway already exists with name: {name}" 

235 if not enabled: 

236 message += f" (currently inactive, ID: {gateway_id})" 

237 super().__init__(message) 

238 

239 

240class GatewayDuplicateConflictError(GatewayError): 

241 """Raised when a gateway conflicts with an existing gateway (same URL + credentials). 

242 

243 This error is raised when attempting to register a gateway with a URL and 

244 authentication credentials that already exist within the same scope: 

245 - Public: Global uniqueness required across all public gateways. 

246 - Team: Uniqueness required within the same team. 

247 - Private: Uniqueness required for the same user, a user cannot have two private gateways with the same URL and credentials. 

248 

249 Args: 

250 duplicate_gateway: The existing conflicting gateway (DbGateway instance). 

251 

252 Examples: 

253 >>> # Public gateway conflict with the same URL and basic auth 

254 >>> existing_gw = DbGateway(url="https://api.example.com", id="abc-123", enabled=True, visibility="public", team_id=None, name="API Gateway", owner_email="alice@example.com") 

255 >>> error = GatewayDuplicateConflictError( 

256 ... duplicate_gateway=existing_gw 

257 ... ) 

258 >>> str(error) 

259 'The Server already exists in Public scope (Name: API Gateway, Status: active)' 

260 

261 >>> # Team gateway conflict with the same URL and OAuth credentials 

262 >>> team_gw = DbGateway(url="https://api.example.com", id="def-456", enabled=False, visibility="team", team_id="engineering-team", name="API Gateway", owner_email="bob@example.com") 

263 >>> error = GatewayDuplicateConflictError( 

264 ... duplicate_gateway=team_gw 

265 ... ) 

266 >>> str(error) 

267 'The Server already exists in your Team (Name: API Gateway, Status: inactive). You may want to re-enable the existing gateway instead.' 

268 

269 >>> # Private gateway conflict (same user cannot have two gateways with the same URL) 

270 >>> private_gw = DbGateway(url="https://api.example.com", id="ghi-789", enabled=True, visibility="private", team_id="none", name="API Gateway", owner_email="charlie@example.com") 

271 >>> error = GatewayDuplicateConflictError( 

272 ... duplicate_gateway=private_gw 

273 ... ) 

274 >>> str(error) 

275 'The Server already exists in "private" scope (Name: API Gateway, Status: active)' 

276 """ 

277 

278 def __init__( 

279 self, 

280 duplicate_gateway: "DbGateway", 

281 ): 

282 """Initialize the error with gateway information. 

283 

284 Args: 

285 duplicate_gateway: The existing conflicting gateway (DbGateway instance) 

286 """ 

287 self.duplicate_gateway = duplicate_gateway 

288 self.url = duplicate_gateway.url 

289 self.gateway_id = duplicate_gateway.id 

290 self.enabled = duplicate_gateway.enabled 

291 self.visibility = duplicate_gateway.visibility 

292 self.team_id = duplicate_gateway.team_id 

293 self.name = duplicate_gateway.name 

294 

295 # Build scope description 

296 if self.visibility == "public": 

297 scope_desc = "Public scope" 

298 elif self.visibility == "team" and self.team_id: 

299 scope_desc = "your Team" 

300 else: 

301 scope_desc = f'"{self.visibility}" scope' 

302 

303 # Build status description 

304 status = "active" if self.enabled else "inactive" 

305 

306 # Construct error message 

307 message = f"The Server already exists in {scope_desc} " f"(Name: {self.name}, Status: {status})" 

308 

309 # Add helpful hint for inactive gateways 

310 if not self.enabled: 

311 message += ". You may want to re-enable the existing gateway instead." 

312 

313 super().__init__(message) 

314 

315 

316class GatewayConnectionError(GatewayError): 

317 """Raised when gateway connection fails. 

318 

319 Examples: 

320 >>> error = GatewayConnectionError("Connection failed") 

321 >>> str(error) 

322 'Connection failed' 

323 >>> isinstance(error, GatewayError) 

324 True 

325 """ 

326 

327 

328class OAuthToolValidationError(GatewayConnectionError): 

329 """Raised when tool validation fails during OAuth-driven fetch.""" 

330 

331 

332class GatewayService: # pylint: disable=too-many-instance-attributes 

333 """Service for managing federated gateways. 

334 

335 Handles: 

336 - Gateway registration and health checks 

337 - Request forwarding 

338 - Capability negotiation 

339 - Federation events 

340 - Active/inactive status management 

341 """ 

342 

343 def __init__(self) -> None: 

344 """Initialize the gateway service. 

345 

346 Examples: 

347 >>> from mcpgateway.services.gateway_service import GatewayService 

348 >>> from mcpgateway.services.event_service import EventService 

349 >>> from mcpgateway.utils.retry_manager import ResilientHttpClient 

350 >>> from mcpgateway.services.tool_service import ToolService 

351 >>> service = GatewayService() 

352 >>> isinstance(service._event_service, EventService) 

353 True 

354 >>> isinstance(service._http_client, ResilientHttpClient) 

355 True 

356 >>> service._health_check_interval == GW_HEALTH_CHECK_INTERVAL 

357 True 

358 >>> service._health_check_task is None 

359 True 

360 >>> isinstance(service._active_gateways, set) 

361 True 

362 >>> len(service._active_gateways) 

363 0 

364 >>> service._stream_response is None 

365 True 

366 >>> isinstance(service._pending_responses, dict) 

367 True 

368 >>> len(service._pending_responses) 

369 0 

370 >>> isinstance(service.tool_service, ToolService) 

371 True 

372 >>> isinstance(service._gateway_failure_counts, dict) 

373 True 

374 >>> len(service._gateway_failure_counts) 

375 0 

376 >>> hasattr(service, 'redis_url') 

377 True 

378 >>> 

379 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs 

380 >>> import asyncio 

381 >>> asyncio.run(service._http_client.aclose()) 

382 """ 

383 self._http_client = ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify}) 

384 self._health_check_interval = GW_HEALTH_CHECK_INTERVAL 

385 self._health_check_task: Optional[asyncio.Task] = None 

386 self._active_gateways: Set[str] = set() # Track active gateway URLs 

387 self._stream_response = None 

388 self._pending_responses = {} 

389 # Prefer using the globally-initialized singletons from mcpgateway.main 

390 # (created at application startup). Import lazily to avoid circular 

391 # import issues during module import time. Fall back to creating 

392 # local instances if the singletons are not available. 

393 # Use the globally-exported singletons from the service modules so 

394 # events propagate via their initialized EventService/Redis clients. 

395 # First-Party 

396 from mcpgateway.services.prompt_service import prompt_service 

397 from mcpgateway.services.resource_service import resource_service 

398 from mcpgateway.services.tool_service import tool_service 

399 

400 self.tool_service = tool_service 

401 self.prompt_service = prompt_service 

402 self.resource_service = resource_service 

403 self._gateway_failure_counts: dict[str, int] = {} 

404 self.oauth_manager = OAuthManager(request_timeout=int(os.getenv("OAUTH_REQUEST_TIMEOUT", "30")), max_retries=int(os.getenv("OAUTH_MAX_RETRIES", "3"))) 

405 self._event_service = EventService(channel_name="mcpgateway:gateway_events") 

406 

407 # Per-gateway refresh locks to prevent concurrent refreshes for the same gateway 

408 self._refresh_locks: Dict[str, asyncio.Lock] = {} 

409 

410 # For health checks, we determine the leader instance. 

411 self.redis_url = settings.redis_url if settings.cache_type == "redis" else None 

412 

413 # Initialize optional Redis client holder (set in initialize()) 

414 self._redis_client: Optional[Any] = None 

415 

416 # Leader election settings from config 

417 if self.redis_url and REDIS_AVAILABLE: 

418 self._instance_id = str(uuid.uuid4()) # Unique ID for this process 

419 self._leader_key = settings.redis_leader_key 

420 self._leader_ttl = settings.redis_leader_ttl 

421 self._leader_heartbeat_interval = settings.redis_leader_heartbeat_interval 

422 self._leader_heartbeat_task: Optional[asyncio.Task] = None 

423 

424 # Always initialize file lock as fallback (used if Redis connection fails at runtime) 

425 if settings.cache_type != "none": 

426 temp_dir = tempfile.gettempdir() 

427 user_path = os.path.normpath(settings.filelock_name) 

428 if os.path.isabs(user_path): 

429 user_path = os.path.relpath(user_path, start=os.path.splitdrive(user_path)[0] + os.sep) 

430 full_path = os.path.join(temp_dir, user_path) 

431 self._lock_path = full_path.replace("\\", "/") 

432 self._file_lock = FileLock(self._lock_path) 

433 

434 @staticmethod 

435 def normalize_url(url: str) -> str: 

436 """ 

437 Normalize a URL by ensuring it's properly formatted. 

438 

439 Special handling for localhost to prevent duplicates: 

440 - Converts 127.0.0.1 to localhost for consistency 

441 - Preserves all other domain names as-is for CDN/load balancer support 

442 

443 Args: 

444 url (str): The URL to normalize. 

445 

446 Returns: 

447 str: The normalized URL. 

448 

449 Examples: 

450 >>> GatewayService.normalize_url('http://localhost:8080/path') 

451 'http://localhost:8080/path' 

452 >>> GatewayService.normalize_url('http://127.0.0.1:8080/path') 

453 'http://localhost:8080/path' 

454 >>> GatewayService.normalize_url('https://example.com/api') 

455 'https://example.com/api' 

456 """ 

457 parsed = urlparse(url) 

458 hostname = parsed.hostname 

459 

460 # Special case: normalize 127.0.0.1 to localhost to prevent duplicates 

461 # but preserve all other domains as-is for CDN/load balancer support 

462 if hostname == "127.0.0.1": 

463 netloc = "localhost" 

464 if parsed.port: 

465 netloc += f":{parsed.port}" 

466 normalized = parsed._replace(netloc=netloc) 

467 return str(urlunparse(normalized)) 

468 

469 # For all other URLs, preserve the domain name 

470 return url 

471 

472 def create_ssl_context(self, ca_certificate: str) -> ssl.SSLContext: 

473 """Create an SSL context with the provided CA certificate. 

474 

475 Uses caching to avoid repeated SSL context creation for the same certificate. 

476 

477 Args: 

478 ca_certificate: CA certificate in PEM format 

479 

480 Returns: 

481 ssl.SSLContext: Configured SSL context 

482 """ 

483 return get_cached_ssl_context(ca_certificate) 

484 

485 async def initialize(self) -> None: 

486 """Initialize the service and start health check if this instance is the leader. 

487 

488 Raises: 

489 ConnectionError: When redis ping fails 

490 """ 

491 logger.info("Initializing gateway service") 

492 

493 # Initialize event service with shared Redis client 

494 await self._event_service.initialize() 

495 

496 # NOTE: We intentionally do NOT create a long-lived DB session here. 

497 # Health checks use fresh_db_session() only when DB access is actually needed, 

498 # avoiding holding connections during HTTP calls to MCP servers. 

499 

500 user_email = settings.platform_admin_email 

501 

502 # Get shared Redis client from factory 

503 if self.redis_url and REDIS_AVAILABLE: 

504 self._redis_client = await get_redis_client() 

505 

506 if self._redis_client: 

507 # Check if Redis is available (ping already done by factory, but verify) 

508 try: 

509 await self._redis_client.ping() 

510 except Exception as e: 

511 raise ConnectionError(f"Redis ping failed: {e}") from e 

512 

513 is_leader = await self._redis_client.set(self._leader_key, self._instance_id, ex=self._leader_ttl, nx=True) 

514 if is_leader: 514 ↛ exitline 514 didn't return from function 'initialize' because the condition on line 514 was always true

515 logger.info("Acquired Redis leadership. Starting health check and heartbeat tasks.") 

516 self._health_check_task = asyncio.create_task(self._run_health_checks(user_email)) 

517 self._leader_heartbeat_task = asyncio.create_task(self._run_leader_heartbeat()) 

518 else: 

519 # Always create the health check task in filelock mode; leader check is handled inside. 

520 self._health_check_task = asyncio.create_task(self._run_health_checks(user_email)) 

521 

522 async def shutdown(self) -> None: 

523 """Shutdown the service. 

524 

525 Examples: 

526 >>> service = GatewayService() 

527 >>> # Mock internal components 

528 >>> from unittest.mock import AsyncMock 

529 >>> service._event_service = AsyncMock() 

530 >>> service._active_gateways = {'test_gw'} 

531 >>> import asyncio 

532 >>> asyncio.run(service.shutdown()) 

533 >>> # Verify event service shutdown was called 

534 >>> service._event_service.shutdown.assert_awaited_once() 

535 >>> len(service._active_gateways) 

536 0 

537 """ 

538 if self._health_check_task: 

539 self._health_check_task.cancel() 

540 try: 

541 await self._health_check_task 

542 except asyncio.CancelledError: 

543 pass 

544 

545 # Cancel leader heartbeat task if running 

546 if getattr(self, "_leader_heartbeat_task", None): 

547 self._leader_heartbeat_task.cancel() 

548 try: 

549 await self._leader_heartbeat_task 

550 except asyncio.CancelledError: 

551 pass 

552 

553 # Release Redis leadership atomically if we hold it 

554 if self._redis_client: 

555 try: 

556 # Lua script for atomic check-and-delete (only delete if we own the key) 

557 release_script = """ 

558 if redis.call("get", KEYS[1]) == ARGV[1] then 

559 return redis.call("del", KEYS[1]) 

560 else 

561 return 0 

562 end 

563 """ 

564 result = await self._redis_client.eval(release_script, 1, self._leader_key, self._instance_id) 

565 if result: 565 ↛ 570line 565 didn't jump to line 570 because the condition on line 565 was always true

566 logger.info("Released Redis leadership on shutdown") 

567 except Exception as e: 

568 logger.warning(f"Failed to release Redis leader key on shutdown: {e}") 

569 

570 await self._http_client.aclose() 

571 await self._event_service.shutdown() 

572 self._active_gateways.clear() 

573 logger.info("Gateway service shutdown complete") 

574 

575 def _check_gateway_uniqueness( 

576 self, 

577 db: Session, 

578 url: str, 

579 auth_value: Optional[Dict[str, str]], 

580 oauth_config: Optional[Dict[str, Any]], 

581 team_id: Optional[str], 

582 owner_email: str, 

583 visibility: str, 

584 gateway_id: Optional[str] = None, 

585 ) -> Optional[DbGateway]: 

586 """ 

587 Check if a gateway with the same URL and credentials already exists. 

588 

589 Args: 

590 db: Database session 

591 url: Gateway URL (normalized) 

592 auth_value: Decoded auth_value dict (not encrypted) 

593 oauth_config: OAuth configuration dict 

594 team_id: Team ID for team-scoped gateways 

595 owner_email: Email of the gateway owner 

596 visibility: Gateway visibility (public/team/private) 

597 gateway_id: Optional gateway ID to exclude from check (for updates) 

598 

599 Returns: 

600 DbGateway if duplicate found, None otherwise 

601 """ 

602 # Build base query based on visibility 

603 if visibility == "public": 

604 query = db.query(DbGateway).filter(DbGateway.url == url, DbGateway.visibility == "public") 

605 elif visibility == "team" and team_id: 

606 query = db.query(DbGateway).filter(DbGateway.url == url, DbGateway.visibility == "team", DbGateway.team_id == team_id) 

607 elif visibility == "private": 

608 # Check for duplicates within the same user's private gateways 

609 query = db.query(DbGateway).filter(DbGateway.url == url, DbGateway.visibility == "private", DbGateway.owner_email == owner_email) # Scoped to same user 

610 else: 

611 return None 

612 

613 # Exclude current gateway if updating 

614 if gateway_id: 

615 query = query.filter(DbGateway.id != gateway_id) 

616 

617 existing_gateways = query.all() 

618 

619 # Check each existing gateway 

620 for existing in existing_gateways: 

621 # Case 1: Both have OAuth config 

622 if oauth_config and existing.oauth_config: 

623 # Compare OAuth configs (exclude dynamic fields like tokens) 

624 existing_oauth = existing.oauth_config or {} 

625 new_oauth = oauth_config or {} 

626 

627 # Compare key OAuth fields 

628 oauth_keys = ["grant_type", "client_id", "authorization_url", "token_url", "scope"] 

629 if all(existing_oauth.get(k) == new_oauth.get(k) for k in oauth_keys): 629 ↛ 620line 629 didn't jump to line 620 because the condition on line 629 was always true

630 return existing # Duplicate OAuth config found 

631 

632 # Case 2: Both have auth_value (need to decrypt and compare) 

633 elif auth_value and existing.auth_value: 

634 

635 try: 

636 # Decrypt existing auth_value 

637 if isinstance(existing.auth_value, str): 

638 existing_decoded = decode_auth(existing.auth_value) 

639 

640 elif isinstance(existing.auth_value, dict): 

641 existing_decoded = existing.auth_value 

642 

643 else: 

644 continue 

645 

646 # Compare decoded auth values 

647 if auth_value == existing_decoded: 647 ↛ 620line 647 didn't jump to line 620 because the condition on line 647 was always true

648 return existing # Duplicate credentials found 

649 except Exception as e: 

650 logger.warning(f"Failed to decode auth_value for comparison: {e}") 

651 continue 

652 

653 # Case 3: Both have no auth (URL only, not allowed) 

654 elif not auth_value and not oauth_config and not existing.auth_value and not existing.oauth_config: 654 ↛ 620line 654 didn't jump to line 620 because the condition on line 654 was always true

655 return existing # Duplicate URL without credentials 

656 

657 return None # No duplicate found 

658 

659 async def register_gateway( 

660 self, 

661 db: Session, 

662 gateway: GatewayCreate, 

663 created_by: Optional[str] = None, 

664 created_from_ip: Optional[str] = None, 

665 created_via: Optional[str] = None, 

666 created_user_agent: Optional[str] = None, 

667 team_id: Optional[str] = None, 

668 owner_email: Optional[str] = None, 

669 visibility: Optional[str] = None, 

670 initialize_timeout: Optional[float] = None, 

671 ) -> GatewayRead: 

672 """Register a new gateway. 

673 

674 Args: 

675 db: Database session 

676 gateway: Gateway creation schema 

677 created_by: Username who created this gateway 

678 created_from_ip: IP address of creator 

679 created_via: Creation method (ui, api, federation) 

680 created_user_agent: User agent of creation request 

681 team_id (Optional[str]): Team ID to assign the gateway to. 

682 owner_email (Optional[str]): Email of the user who owns this gateway. 

683 visibility (Optional[str]): Gateway visibility level (private, team, public). 

684 initialize_timeout (Optional[float]): Timeout in seconds for gateway initialization. 

685 

686 Returns: 

687 Created gateway information 

688 

689 Raises: 

690 GatewayNameConflictError: If gateway name already exists 

691 GatewayConnectionError: If there was an error connecting to the gateway 

692 ValueError: If required values are missing 

693 RuntimeError: If there is an error during processing that is not covered by other exceptions 

694 IntegrityError: If there is a database integrity error 

695 BaseException: If an unexpected error occurs 

696 

697 Examples: 

698 >>> from mcpgateway.services.gateway_service import GatewayService 

699 >>> from unittest.mock import MagicMock 

700 >>> service = GatewayService() 

701 >>> db = MagicMock() 

702 >>> gateway = MagicMock() 

703 >>> db.execute.return_value.scalar_one_or_none.return_value = None 

704 >>> db.add = MagicMock() 

705 >>> db.commit = MagicMock() 

706 >>> db.refresh = MagicMock() 

707 >>> service._notify_gateway_added = MagicMock() 

708 >>> import asyncio 

709 >>> try: 

710 ... asyncio.run(service.register_gateway(db, gateway)) 

711 ... except Exception: 

712 ... pass 

713 >>> 

714 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs 

715 >>> asyncio.run(service._http_client.aclose()) 

716 """ 

717 visibility = "public" if visibility not in ("private", "team", "public") else visibility 

718 try: 

719 # # Check for name conflicts (both active and inactive) 

720 # existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway.name)).scalar_one_or_none() 

721 

722 # if existing_gateway: 

723 # raise GatewayNameConflictError( 

724 # gateway.name, 

725 # enabled=existing_gateway.enabled, 

726 # gateway_id=existing_gateway.id, 

727 # ) 

728 # Check for existing gateway with the same slug and visibility 

729 slug_name = slugify(gateway.name) 

730 if visibility.lower() == "public": 

731 # Check for existing public gateway with the same slug (row-locked) 

732 existing_gateway = get_for_update( 

733 db, 

734 DbGateway, 

735 where=and_(DbGateway.slug == slug_name, DbGateway.visibility == "public"), 

736 ) 

737 if existing_gateway: 

738 raise GatewayNameConflictError(existing_gateway.slug, enabled=existing_gateway.enabled, gateway_id=existing_gateway.id, visibility=existing_gateway.visibility) 

739 elif visibility.lower() == "team" and team_id: 739 ↛ 750line 739 didn't jump to line 750 because the condition on line 739 was always true

740 # Check for existing team gateway with the same slug (row-locked) 

741 existing_gateway = get_for_update( 

742 db, 

743 DbGateway, 

744 where=and_(DbGateway.slug == slug_name, DbGateway.visibility == "team", DbGateway.team_id == team_id), 

745 ) 

746 if existing_gateway: 746 ↛ 750line 746 didn't jump to line 750 because the condition on line 746 was always true

747 raise GatewayNameConflictError(existing_gateway.slug, enabled=existing_gateway.enabled, gateway_id=existing_gateway.id, visibility=existing_gateway.visibility) 

748 

749 # Normalize the gateway URL 

750 normalized_url = self.normalize_url(str(gateway.url)) 

751 

752 decoded_auth_value = None 

753 if gateway.auth_value: 

754 if isinstance(gateway.auth_value, str): 

755 try: 

756 decoded_auth_value = decode_auth(gateway.auth_value) 

757 except Exception as e: 

758 logger.warning(f"Failed to decode provided auth_value: {e}") 

759 decoded_auth_value = None 

760 elif isinstance(gateway.auth_value, dict): 760 ↛ 764line 760 didn't jump to line 764 because the condition on line 760 was always true

761 decoded_auth_value = gateway.auth_value 

762 

763 # Check for duplicate gateway 

764 if not gateway.one_time_auth: 

765 duplicate_gateway = self._check_gateway_uniqueness( 

766 db=db, url=normalized_url, auth_value=decoded_auth_value, oauth_config=gateway.oauth_config, team_id=team_id, owner_email=owner_email, visibility=visibility 

767 ) 

768 

769 if duplicate_gateway: 

770 raise GatewayDuplicateConflictError(duplicate_gateway=duplicate_gateway) 

771 

772 # Prevent URL-only gateways (no auth at all) 

773 # if not decoded_auth_value and not gateway.oauth_config: 

774 # raise ValueError( 

775 # f"Gateway with URL '{normalized_url}' must have either auth_value or oauth_config. " 

776 # "URL-only gateways are not allowed." 

777 # ) 

778 

779 auth_type = getattr(gateway, "auth_type", None) 

780 # Support multiple custom headers 

781 auth_value = getattr(gateway, "auth_value", {}) 

782 authentication_headers: Optional[Dict[str, str]] = None 

783 

784 # Handle query_param auth - encrypt and prepare for storage 

785 auth_query_params_encrypted: Optional[Dict[str, str]] = None 

786 auth_query_params_decrypted: Optional[Dict[str, str]] = None 

787 init_url = normalized_url # URL to use for initialization 

788 

789 if auth_type == "query_param": 

790 # Extract and encrypt query param auth 

791 param_key = getattr(gateway, "auth_query_param_key", None) 

792 param_value = getattr(gateway, "auth_query_param_value", None) 

793 if param_key and param_value: 793 ↛ 823line 793 didn't jump to line 823 because the condition on line 793 was always true

794 # Get the actual secret value 

795 if hasattr(param_value, "get_secret_value"): 

796 raw_value = param_value.get_secret_value() 

797 else: 

798 raw_value = str(param_value) 

799 # Encrypt for storage 

800 encrypted_value = encode_auth({param_key: raw_value}) 

801 auth_query_params_encrypted = {param_key: encrypted_value} 

802 auth_query_params_decrypted = {param_key: raw_value} 

803 # Append query params to URL for initialization 

804 init_url = apply_query_param_auth(normalized_url, auth_query_params_decrypted) 

805 # Query param auth doesn't use auth_value 

806 auth_value = None 

807 authentication_headers = None 

808 

809 elif hasattr(gateway, "auth_headers") and gateway.auth_headers: 

810 # Convert list of {key, value} to dict 

811 header_dict = {h["key"]: h["value"] for h in gateway.auth_headers if h.get("key")} 

812 # Keep encoded form for persistence, but pass raw headers for initialization 

813 auth_value = encode_auth(header_dict) # Encode the dict for consistency 

814 authentication_headers = {str(k): str(v) for k, v in header_dict.items()} 

815 

816 elif isinstance(auth_value, str) and auth_value: 

817 # Decode persisted auth for initialization 

818 decoded = decode_auth(auth_value) 

819 authentication_headers = {str(k): str(v) for k, v in decoded.items()} 

820 else: 

821 authentication_headers = None 

822 

823 oauth_config = getattr(gateway, "oauth_config", None) 

824 ca_certificate = getattr(gateway, "ca_certificate", None) 

825 if initialize_timeout is not None: 

826 try: 

827 capabilities, tools, resources, prompts = await asyncio.wait_for( 

828 self._initialize_gateway( 

829 init_url, # URL with query params if applicable 

830 authentication_headers, 

831 gateway.transport, 

832 auth_type, 

833 oauth_config, 

834 ca_certificate, 

835 auth_query_params=auth_query_params_decrypted, 

836 ), 

837 timeout=initialize_timeout, 

838 ) 

839 except asyncio.TimeoutError as exc: 

840 sanitized = sanitize_url_for_logging(init_url, auth_query_params_decrypted) 

841 raise GatewayConnectionError(f"Gateway initialization timed out after {initialize_timeout}s for {sanitized}") from exc 

842 else: 

843 capabilities, tools, resources, prompts = await self._initialize_gateway( 

844 init_url, # URL with query params if applicable 

845 authentication_headers, 

846 gateway.transport, 

847 auth_type, 

848 oauth_config, 

849 ca_certificate, 

850 auth_query_params=auth_query_params_decrypted, 

851 ) 

852 

853 if gateway.one_time_auth: 

854 # For one-time auth, clear auth_type and auth_value after initialization 

855 auth_type = "one_time_auth" 

856 auth_value = None 

857 oauth_config = None 

858 

859 tools = [ 

860 DbTool( 

861 original_name=tool.name, 

862 custom_name=tool.name, 

863 custom_name_slug=slugify(tool.name), 

864 display_name=generate_display_name(tool.name), 

865 url=normalized_url, 

866 description=tool.description, 

867 integration_type="MCP", # Gateway-discovered tools are MCP type 

868 request_type=tool.request_type, 

869 headers=tool.headers, 

870 input_schema=tool.input_schema, 

871 output_schema=tool.output_schema, 

872 annotations=tool.annotations, 

873 jsonpath_filter=tool.jsonpath_filter, 

874 auth_type=auth_type, 

875 auth_value=auth_value, 

876 # Federation metadata 

877 created_by=created_by or "system", 

878 created_from_ip=created_from_ip, 

879 created_via="federation", # These are federated tools 

880 created_user_agent=created_user_agent, 

881 federation_source=gateway.name, 

882 version=1, 

883 # Inherit team assignment from gateway 

884 team_id=team_id, 

885 owner_email=owner_email, 

886 visibility=visibility, 

887 ) 

888 for tool in tools 

889 ] 

890 

891 # Create resource DB models with upsert logic for ORPHANED resources only 

892 # Query for existing ORPHANED resources (gateway_id IS NULL or points to non-existent gateway) 

893 # with same (team_id, owner_email, uri) to handle resources left behind from incomplete 

894 # gateway deletions (e.g., issue #2341 crash scenarios). 

895 # We only update orphaned resources - resources belonging to active gateways are not touched. 

896 resource_uris = [r.uri for r in resources] 

897 effective_owner = owner_email or created_by 

898 

899 # Build lookup map: (team_id, owner_email, uri) -> orphaned DbResource 

900 # We query all resources matching our URIs, then filter to orphaned ones in Python 

901 # to handle per-resource team/owner overrides correctly 

902 orphaned_resources_map: Dict[tuple, DbResource] = {} 

903 if resource_uris: 

904 try: 

905 # Get valid gateway IDs to identify orphaned resources 

906 valid_gateway_ids = set(gw_id for (gw_id,) in db.execute(select(DbGateway.id)).all()) 

907 candidate_resources = db.execute(select(DbResource).where(DbResource.uri.in_(resource_uris))).scalars().all() 

908 for res in candidate_resources: 

909 # Only consider orphaned resources (no gateway or gateway doesn't exist) 

910 is_orphaned = res.gateway_id is None or res.gateway_id not in valid_gateway_ids 

911 if is_orphaned: 911 ↛ 908line 911 didn't jump to line 908 because the condition on line 911 was always true

912 key = (res.team_id, res.owner_email, res.uri) 

913 orphaned_resources_map[key] = res 

914 if orphaned_resources_map: 

915 logger.info(f"Found {len(orphaned_resources_map)} orphaned resources to reassign for gateway {gateway.name}") 

916 except Exception as e: 

917 # If orphan detection fails (e.g., in mocked tests), skip upsert and create new resources 

918 # This is conservative - we won't accidentally reassign resources from active gateways 

919 logger.debug(f"Orphan resource detection skipped: {e}") 

920 

921 db_resources = [] 

922 for r in resources: 

923 mime_type = mimetypes.guess_type(r.uri)[0] or ("text/plain" if isinstance(r.content, str) else "application/octet-stream") 

924 r_team_id = getattr(r, "team_id", None) or team_id 

925 r_owner_email = getattr(r, "owner_email", None) or effective_owner 

926 r_visibility = getattr(r, "visibility", None) or visibility 

927 

928 # Check if there's an orphaned resource with matching unique key 

929 lookup_key = (r_team_id, r_owner_email, r.uri) 

930 if lookup_key in orphaned_resources_map: 

931 # Update orphaned resource - reassign to new gateway 

932 existing = orphaned_resources_map[lookup_key] 

933 existing.name = r.name 

934 existing.description = r.description 

935 existing.mime_type = mime_type 

936 existing.uri_template = r.uri_template or None 

937 existing.text_content = r.content if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) else None 

938 existing.binary_content = ( 

939 r.content.encode() if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) else r.content if isinstance(r.content, bytes) else None 

940 ) 

941 existing.size = len(r.content) if r.content else 0 

942 existing.tags = getattr(r, "tags", []) or [] 

943 existing.federation_source = gateway.name 

944 existing.modified_by = created_by 

945 existing.modified_from_ip = created_from_ip 

946 existing.modified_via = "federation" 

947 existing.modified_user_agent = created_user_agent 

948 existing.updated_at = datetime.now(timezone.utc) 

949 existing.visibility = r_visibility 

950 # Note: gateway_id will be set when gateway is created (relationship) 

951 db_resources.append(existing) 

952 else: 

953 # Create new resource 

954 db_resources.append( 

955 DbResource( 

956 uri=r.uri, 

957 name=r.name, 

958 description=r.description, 

959 mime_type=mime_type, 

960 uri_template=r.uri_template or None, 

961 text_content=r.content if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) else None, 

962 binary_content=( 

963 r.content.encode() 

964 if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) 

965 else r.content if isinstance(r.content, bytes) else None 

966 ), 

967 size=len(r.content) if r.content else 0, 

968 tags=getattr(r, "tags", []) or [], 

969 created_by=created_by or "system", 

970 created_from_ip=created_from_ip, 

971 created_via="federation", 

972 created_user_agent=created_user_agent, 

973 import_batch_id=None, 

974 federation_source=gateway.name, 

975 version=1, 

976 team_id=r_team_id, 

977 owner_email=r_owner_email, 

978 visibility=r_visibility, 

979 ) 

980 ) 

981 

982 # Create prompt DB models with upsert logic for ORPHANED prompts only 

983 # Query for existing ORPHANED prompts (gateway_id IS NULL or points to non-existent gateway) 

984 # with same (team_id, owner_email, name) to handle prompts left behind from incomplete 

985 # gateway deletions. We only update orphaned prompts - prompts belonging to active gateways are not touched. 

986 prompt_names = [p.name for p in prompts] 

987 

988 # Build lookup map: (team_id, owner_email, name) -> orphaned DbPrompt 

989 orphaned_prompts_map: Dict[tuple, DbPrompt] = {} 

990 if prompt_names: 

991 try: 

992 # Get valid gateway IDs to identify orphaned prompts 

993 valid_gateway_ids_for_prompts = set(gw_id for (gw_id,) in db.execute(select(DbGateway.id)).all()) 

994 candidate_prompts = db.execute(select(DbPrompt).where(DbPrompt.name.in_(prompt_names))).scalars().all() 

995 for pmt in candidate_prompts: 

996 # Only consider orphaned prompts (no gateway or gateway doesn't exist) 

997 is_orphaned = pmt.gateway_id is None or pmt.gateway_id not in valid_gateway_ids_for_prompts 

998 if is_orphaned: 998 ↛ 995line 998 didn't jump to line 995 because the condition on line 998 was always true

999 key = (pmt.team_id, pmt.owner_email, pmt.name) 

1000 orphaned_prompts_map[key] = pmt 

1001 if orphaned_prompts_map: 

1002 logger.info(f"Found {len(orphaned_prompts_map)} orphaned prompts to reassign for gateway {gateway.name}") 

1003 except Exception as e: 

1004 # If orphan detection fails (e.g., in mocked tests), skip upsert and create new prompts 

1005 logger.debug(f"Orphan prompt detection skipped: {e}") 

1006 

1007 db_prompts = [] 

1008 for prompt in prompts: 

1009 # Prompts inherit team/owner from gateway (no per-prompt overrides) 

1010 p_team_id = team_id 

1011 p_owner_email = owner_email or effective_owner 

1012 

1013 # Check if there's an orphaned prompt with matching unique key 

1014 lookup_key = (p_team_id, p_owner_email, prompt.name) 

1015 if lookup_key in orphaned_prompts_map: 

1016 # Update orphaned prompt - reassign to new gateway 

1017 existing = orphaned_prompts_map[lookup_key] 

1018 existing.original_name = prompt.name 

1019 existing.custom_name = prompt.name 

1020 existing.display_name = prompt.name 

1021 existing.description = prompt.description 

1022 existing.template = prompt.template if hasattr(prompt, "template") else "" 

1023 existing.federation_source = gateway.name 

1024 existing.modified_by = created_by 

1025 existing.modified_from_ip = created_from_ip 

1026 existing.modified_via = "federation" 

1027 existing.modified_user_agent = created_user_agent 

1028 existing.updated_at = datetime.now(timezone.utc) 

1029 existing.visibility = visibility 

1030 # Note: gateway_id will be set when gateway is created (relationship) 

1031 db_prompts.append(existing) 

1032 else: 

1033 # Create new prompt 

1034 db_prompts.append( 

1035 DbPrompt( 

1036 name=prompt.name, 

1037 original_name=prompt.name, 

1038 custom_name=prompt.name, 

1039 display_name=prompt.name, 

1040 description=prompt.description, 

1041 template=prompt.template if hasattr(prompt, "template") else "", 

1042 argument_schema={}, # Use argument_schema instead of arguments 

1043 # Federation metadata 

1044 created_by=created_by or "system", 

1045 created_from_ip=created_from_ip, 

1046 created_via="federation", # These are federated prompts 

1047 created_user_agent=created_user_agent, 

1048 federation_source=gateway.name, 

1049 version=1, 

1050 # Inherit team assignment from gateway 

1051 team_id=team_id, 

1052 owner_email=owner_email, 

1053 visibility=visibility, 

1054 ) 

1055 ) 

1056 

1057 # Create DB model 

1058 db_gateway = DbGateway( 

1059 name=gateway.name, 

1060 slug=slug_name, 

1061 url=normalized_url, 

1062 description=gateway.description, 

1063 tags=gateway.tags or [], 

1064 transport=gateway.transport, 

1065 capabilities=capabilities, 

1066 last_seen=datetime.now(timezone.utc), 

1067 auth_type=auth_type, 

1068 auth_value=auth_value, 

1069 auth_query_params=auth_query_params_encrypted, # Encrypted query param auth 

1070 oauth_config=oauth_config, 

1071 passthrough_headers=gateway.passthrough_headers, 

1072 tools=tools, 

1073 resources=db_resources, 

1074 prompts=db_prompts, 

1075 # Gateway metadata 

1076 created_by=created_by, 

1077 created_from_ip=created_from_ip, 

1078 created_via=created_via or "api", 

1079 created_user_agent=created_user_agent, 

1080 version=1, 

1081 # Team scoping fields 

1082 team_id=team_id, 

1083 owner_email=owner_email, 

1084 visibility=visibility, 

1085 ca_certificate=gateway.ca_certificate, 

1086 ca_certificate_sig=gateway.ca_certificate_sig, 

1087 signing_algorithm=gateway.signing_algorithm, 

1088 ) 

1089 

1090 # Add to DB 

1091 db.add(db_gateway) 

1092 db.flush() # Flush to get the ID without committing 

1093 db.refresh(db_gateway) 

1094 

1095 # Update tracking 

1096 self._active_gateways.add(db_gateway.url) 

1097 

1098 # Notify subscribers 

1099 await self._notify_gateway_added(db_gateway) 

1100 

1101 logger.info(f"Registered gateway: {gateway.name}") 

1102 

1103 # Structured logging: Audit trail for gateway creation 

1104 audit_trail.log_action( 

1105 user_id=created_by or "system", 

1106 action="create_gateway", 

1107 resource_type="gateway", 

1108 resource_id=str(db_gateway.id), 

1109 resource_name=db_gateway.name, 

1110 user_email=owner_email, 

1111 team_id=team_id, 

1112 client_ip=created_from_ip, 

1113 user_agent=created_user_agent, 

1114 new_values={ 

1115 "name": db_gateway.name, 

1116 "url": db_gateway.url, 

1117 "visibility": visibility, 

1118 "transport": db_gateway.transport, 

1119 "tools_count": len(tools), 

1120 "resources_count": len(db_resources), 

1121 "prompts_count": len(db_prompts), 

1122 }, 

1123 context={ 

1124 "created_via": created_via, 

1125 }, 

1126 db=db, 

1127 ) 

1128 

1129 # Structured logging: Log successful gateway creation 

1130 structured_logger.log( 

1131 level="INFO", 

1132 message="Gateway created successfully", 

1133 event_type="gateway_created", 

1134 component="gateway_service", 

1135 user_id=created_by, 

1136 user_email=owner_email, 

1137 team_id=team_id, 

1138 resource_type="gateway", 

1139 resource_id=str(db_gateway.id), 

1140 custom_fields={ 

1141 "gateway_name": db_gateway.name, 

1142 "gateway_url": normalized_url, 

1143 "visibility": visibility, 

1144 "transport": db_gateway.transport, 

1145 }, 

1146 db=db, 

1147 ) 

1148 

1149 return GatewayRead.model_validate(self._prepare_gateway_for_read(db_gateway)).masked() 

1150 except* GatewayConnectionError as ge: # pragma: no mutate 

1151 if TYPE_CHECKING: 

1152 ge: ExceptionGroup[GatewayConnectionError] 

1153 logger.error(f"GatewayConnectionError in group: {ge.exceptions}") 

1154 

1155 structured_logger.log( 

1156 level="ERROR", 

1157 message="Gateway creation failed due to connection error", 

1158 event_type="gateway_creation_failed", 

1159 component="gateway_service", 

1160 user_id=created_by, 

1161 user_email=owner_email, 

1162 error=ge.exceptions[0], 

1163 custom_fields={"gateway_name": gateway.name, "gateway_url": str(gateway.url)}, 

1164 db=db, 

1165 ) 

1166 raise ge.exceptions[0] 

1167 except* GatewayNameConflictError as gnce: # pragma: no mutate 

1168 if TYPE_CHECKING: 

1169 gnce: ExceptionGroup[GatewayNameConflictError] 

1170 logger.error(f"GatewayNameConflictError in group: {gnce.exceptions}") 

1171 

1172 structured_logger.log( 

1173 level="WARNING", 

1174 message="Gateway creation failed due to name conflict", 

1175 event_type="gateway_name_conflict", 

1176 component="gateway_service", 

1177 user_id=created_by, 

1178 user_email=owner_email, 

1179 custom_fields={"gateway_name": gateway.name, "visibility": visibility}, 

1180 db=db, 

1181 ) 

1182 raise gnce.exceptions[0] 

1183 except* GatewayDuplicateConflictError as guce: # pragma: no mutate 

1184 if TYPE_CHECKING: 

1185 guce: ExceptionGroup[GatewayDuplicateConflictError] 

1186 logger.error(f"GatewayDuplicateConflictError in group: {guce.exceptions}") 

1187 

1188 structured_logger.log( 

1189 level="WARNING", 

1190 message="Gateway creation failed due to duplicate", 

1191 event_type="gateway_duplicate_conflict", 

1192 component="gateway_service", 

1193 user_id=created_by, 

1194 user_email=owner_email, 

1195 custom_fields={"gateway_name": gateway.name}, 

1196 db=db, 

1197 ) 

1198 raise guce.exceptions[0] 

1199 except* ValueError as ve: # pragma: no mutate 

1200 if TYPE_CHECKING: 

1201 ve: ExceptionGroup[ValueError] 

1202 logger.error(f"ValueErrors in group: {ve.exceptions}") 

1203 

1204 structured_logger.log( 

1205 level="ERROR", 

1206 message="Gateway creation failed due to validation error", 

1207 event_type="gateway_creation_failed", 

1208 component="gateway_service", 

1209 user_id=created_by, 

1210 user_email=owner_email, 

1211 error=ve.exceptions[0], 

1212 custom_fields={"gateway_name": gateway.name}, 

1213 db=db, 

1214 ) 

1215 raise ve.exceptions[0] 

1216 except* RuntimeError as re: # pragma: no mutate 

1217 if TYPE_CHECKING: 

1218 re: ExceptionGroup[RuntimeError] 

1219 logger.error(f"RuntimeErrors in group: {re.exceptions}") 

1220 

1221 structured_logger.log( 

1222 level="ERROR", 

1223 message="Gateway creation failed due to runtime error", 

1224 event_type="gateway_creation_failed", 

1225 component="gateway_service", 

1226 user_id=created_by, 

1227 user_email=owner_email, 

1228 error=re.exceptions[0], 

1229 custom_fields={"gateway_name": gateway.name}, 

1230 db=db, 

1231 ) 

1232 raise re.exceptions[0] 

1233 except* IntegrityError as ie: # pragma: no mutate 

1234 if TYPE_CHECKING: 

1235 ie: ExceptionGroup[IntegrityError] 

1236 logger.error(f"IntegrityErrors in group: {ie.exceptions}") 

1237 

1238 structured_logger.log( 

1239 level="ERROR", 

1240 message="Gateway creation failed due to database integrity error", 

1241 event_type="gateway_creation_failed", 

1242 component="gateway_service", 

1243 user_id=created_by, 

1244 user_email=owner_email, 

1245 error=ie.exceptions[0], 

1246 custom_fields={"gateway_name": gateway.name}, 

1247 db=db, 

1248 ) 

1249 raise ie.exceptions[0] 

1250 except* BaseException as other: # catches every other sub-exception # pragma: no mutate 

1251 if TYPE_CHECKING: 

1252 other: ExceptionGroup[Exception] 

1253 logger.error(f"Other grouped errors: {other.exceptions}") 

1254 raise other.exceptions[0] 

1255 

1256 async def fetch_tools_after_oauth(self, db: Session, gateway_id: str, app_user_email: str) -> Dict[str, Any]: 

1257 """Fetch tools from MCP server after OAuth completion for Authorization Code flow. 

1258 

1259 Args: 

1260 db: Database session 

1261 gateway_id: ID of the gateway to fetch tools for 

1262 app_user_email: MCP Gateway user email for token retrieval 

1263 

1264 Returns: 

1265 Dict containing capabilities, tools, resources, and prompts 

1266 

1267 Raises: 

1268 GatewayConnectionError: If connection or OAuth fails 

1269 """ 

1270 try: 

1271 # Get the gateway with eager loading for sync operations to avoid N+1 queries 

1272 gateway = db.execute( 

1273 select(DbGateway) 

1274 .options( 

1275 selectinload(DbGateway.tools), 

1276 selectinload(DbGateway.resources), 

1277 selectinload(DbGateway.prompts), 

1278 joinedload(DbGateway.email_team), 

1279 ) 

1280 .where(DbGateway.id == gateway_id) 

1281 ).scalar_one_or_none() 

1282 

1283 if not gateway: 

1284 raise ValueError(f"Gateway {gateway_id} not found") 

1285 

1286 if not gateway.oauth_config: 

1287 raise ValueError(f"Gateway {gateway_id} has no OAuth configuration") 

1288 

1289 grant_type = gateway.oauth_config.get("grant_type") 

1290 if grant_type != "authorization_code": 

1291 raise ValueError(f"Gateway {gateway_id} is not using Authorization Code flow") 

1292 

1293 # Get OAuth tokens for this gateway 

1294 # First-Party 

1295 from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel 

1296 

1297 token_storage = TokenStorageService(db) 

1298 

1299 # Get user-specific OAuth token 

1300 if not app_user_email: 

1301 raise GatewayConnectionError(f"User authentication required for OAuth gateway {gateway.name}") 

1302 

1303 access_token = await token_storage.get_user_token(gateway.id, app_user_email) 

1304 

1305 if not access_token: 

1306 raise GatewayConnectionError( 

1307 f"No OAuth tokens found for user {app_user_email} on gateway {gateway.name}. Please complete the OAuth authorization flow first at /oauth/authorize/{gateway.id}" 

1308 ) 

1309 

1310 # Debug: Check if token was decrypted 

1311 if access_token.startswith("Z0FBQUFBQm"): # Encrypted tokens start with this 

1312 logger.error(f"Token appears to be encrypted! Encryption service may have failed. Token length: {len(access_token)}") 

1313 else: 

1314 logger.info(f"Using decrypted OAuth token for {gateway.name} (length: {len(access_token)})") 

1315 

1316 # Now connect to MCP server with the access token 

1317 authentication = {"Authorization": f"Bearer {access_token}"} 

1318 

1319 # Use the existing connection logic 

1320 # Note: For OAuth servers, skip validation since we already validated via OAuth flow 

1321 if gateway.transport.upper() == "SSE": 

1322 capabilities, tools, resources, prompts = await self._connect_to_sse_server_without_validation(gateway.url, authentication) 

1323 elif gateway.transport.upper() == "STREAMABLEHTTP": 

1324 capabilities, tools, resources, prompts = await self.connect_to_streamablehttp_server(gateway.url, authentication) 

1325 else: 

1326 raise ValueError(f"Unsupported transport type: {gateway.transport}") 

1327 

1328 # Handle tools, resources, and prompts using helper methods 

1329 tools_to_add = self._update_or_create_tools(db, tools, gateway, "oauth") 

1330 resources_to_add = self._update_or_create_resources(db, resources, gateway, "oauth") 

1331 prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, "oauth") 

1332 

1333 # Clean up items that are no longer available from the gateway 

1334 new_tool_names = [tool.name for tool in tools] 

1335 new_resource_uris = [resource.uri for resource in resources] 

1336 new_prompt_names = [prompt.name for prompt in prompts] 

1337 

1338 # Count items before cleanup for logging 

1339 

1340 # Bulk delete tools that are no longer available from the gateway 

1341 # Use chunking to avoid SQLite's 999 parameter limit for IN clauses 

1342 stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names] 

1343 if stale_tool_ids: 

1344 # Delete child records first to avoid FK constraint violations 

1345 for i in range(0, len(stale_tool_ids), 500): 

1346 chunk = stale_tool_ids[i : i + 500] 

1347 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk))) 

1348 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk))) 

1349 db.execute(delete(DbTool).where(DbTool.id.in_(chunk))) 

1350 

1351 # Bulk delete resources that are no longer available from the gateway 

1352 stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris] 

1353 if stale_resource_ids: 

1354 # Delete child records first to avoid FK constraint violations 

1355 for i in range(0, len(stale_resource_ids), 500): 

1356 chunk = stale_resource_ids[i : i + 500] 

1357 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk))) 

1358 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk))) 

1359 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk))) 

1360 db.execute(delete(DbResource).where(DbResource.id.in_(chunk))) 

1361 

1362 # Bulk delete prompts that are no longer available from the gateway 

1363 stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names] 

1364 if stale_prompt_ids: 

1365 # Delete child records first to avoid FK constraint violations 

1366 for i in range(0, len(stale_prompt_ids), 500): 

1367 chunk = stale_prompt_ids[i : i + 500] 

1368 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk))) 

1369 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk))) 

1370 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk))) 

1371 

1372 # Expire gateway to clear cached relationships after bulk deletes 

1373 # This prevents SQLAlchemy from trying to re-delete already-deleted items 

1374 if stale_tool_ids or stale_resource_ids or stale_prompt_ids: 

1375 db.expire(gateway) 

1376 

1377 # Update gateway relationships to reflect deletions 

1378 gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names] 

1379 gateway.resources = [resource for resource in gateway.resources if resource.uri in new_resource_uris] 

1380 gateway.prompts = [prompt for prompt in gateway.prompts if prompt.original_name in new_prompt_names] 

1381 

1382 # Log cleanup results 

1383 tools_removed = len(stale_tool_ids) 

1384 resources_removed = len(stale_resource_ids) 

1385 prompts_removed = len(stale_prompt_ids) 

1386 

1387 if tools_removed > 0: 

1388 logger.info(f"Removed {tools_removed} tools no longer available from gateway") 

1389 if resources_removed > 0: 

1390 logger.info(f"Removed {resources_removed} resources no longer available from gateway") 

1391 if prompts_removed > 0: 

1392 logger.info(f"Removed {prompts_removed} prompts no longer available from gateway") 

1393 

1394 # Update gateway capabilities and last_seen 

1395 gateway.capabilities = capabilities 

1396 gateway.last_seen = datetime.now(timezone.utc) 

1397 

1398 # Register capabilities for notification-driven actions 

1399 register_gateway_capabilities_for_notifications(gateway.id, capabilities) 

1400 

1401 # Add new items to DB in chunks to prevent lock escalation 

1402 items_added = 0 

1403 chunk_size = 50 

1404 

1405 if tools_to_add: 

1406 for i in range(0, len(tools_to_add), chunk_size): 

1407 chunk = tools_to_add[i : i + chunk_size] 

1408 db.add_all(chunk) 

1409 db.flush() # Flush each chunk to avoid excessive memory usage 

1410 items_added += len(tools_to_add) 

1411 logger.info(f"Added {len(tools_to_add)} new tools to database") 

1412 

1413 if resources_to_add: 

1414 for i in range(0, len(resources_to_add), chunk_size): 

1415 chunk = resources_to_add[i : i + chunk_size] 

1416 db.add_all(chunk) 

1417 db.flush() 

1418 items_added += len(resources_to_add) 

1419 logger.info(f"Added {len(resources_to_add)} new resources to database") 

1420 

1421 if prompts_to_add: 

1422 for i in range(0, len(prompts_to_add), chunk_size): 

1423 chunk = prompts_to_add[i : i + chunk_size] 

1424 db.add_all(chunk) 

1425 db.flush() 

1426 items_added += len(prompts_to_add) 

1427 logger.info(f"Added {len(prompts_to_add)} new prompts to database") 

1428 

1429 if items_added > 0: 

1430 db.commit() 

1431 logger.info(f"Total {items_added} new items added to database") 

1432 else: 

1433 logger.info("No new items to add to database") 

1434 # Still commit to save any updates to existing items 

1435 db.commit() 

1436 

1437 cache = _get_registry_cache() 

1438 await cache.invalidate_tools() 

1439 await cache.invalidate_resources() 

1440 await cache.invalidate_prompts() 

1441 tool_lookup_cache = _get_tool_lookup_cache() 

1442 await tool_lookup_cache.invalidate_gateway(str(gateway.id)) 

1443 # Also invalidate tags cache since tool/resource tags may have changed 

1444 # First-Party 

1445 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel 

1446 

1447 await admin_stats_cache.invalidate_tags() 

1448 

1449 return {"capabilities": capabilities, "tools": tools, "resources": resources, "prompts": prompts} 

1450 

1451 except GatewayConnectionError as gce: 

1452 # Surface validation or depth-related failures directly to the user 

1453 logger.error(f"GatewayConnectionError during OAuth fetch for {gateway_id}: {gce}") 

1454 raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(gce)}") 

1455 except Exception as e: 

1456 logger.error(f"Failed to fetch tools after OAuth for gateway {gateway_id}: {e}") 

1457 raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(e)}") 

1458 

1459 async def list_gateways( 

1460 self, 

1461 db: Session, 

1462 include_inactive: bool = False, 

1463 tags: Optional[List[str]] = None, 

1464 cursor: Optional[str] = None, 

1465 limit: Optional[int] = None, 

1466 page: Optional[int] = None, 

1467 per_page: Optional[int] = None, 

1468 user_email: Optional[str] = None, 

1469 team_id: Optional[str] = None, 

1470 visibility: Optional[str] = None, 

1471 token_teams: Optional[List[str]] = None, 

1472 ) -> Union[tuple[List[GatewayRead], Optional[str]], Dict[str, Any]]: 

1473 """List all registered gateways with cursor pagination and optional team filtering. 

1474 

1475 Args: 

1476 db: Database session 

1477 include_inactive: Whether to include inactive gateways 

1478 tags (Optional[List[str]]): Filter resources by tags. If provided, only resources with at least one matching tag will be returned. 

1479 cursor: Cursor for pagination (encoded last created_at and id). 

1480 limit: Maximum number of gateways to return. None for default, 0 for unlimited. 

1481 page: Page number for page-based pagination (1-indexed). Mutually exclusive with cursor. 

1482 per_page: Items per page for page-based pagination. Defaults to pagination_default_page_size. 

1483 user_email: Email of user for team-based access control. None for no access control. 

1484 team_id: Optional team ID to filter by specific team (requires user_email). 

1485 visibility: Optional visibility filter (private, team, public) (requires user_email). 

1486 token_teams: Optional list of team IDs from the token (None=unrestricted, []=public-only). 

1487 

1488 Returns: 

1489 If page is provided: Dict with {"data": [...], "pagination": {...}, "links": {...}} 

1490 If cursor is provided or neither: tuple of (list of GatewayRead objects, next_cursor). 

1491 

1492 Examples: 

1493 >>> from mcpgateway.services.gateway_service import GatewayService 

1494 >>> from unittest.mock import MagicMock, AsyncMock, patch 

1495 >>> from mcpgateway.schemas import GatewayRead 

1496 >>> import asyncio 

1497 >>> service = GatewayService() 

1498 >>> db = MagicMock() 

1499 >>> gateway_obj = MagicMock() 

1500 >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway_obj] 

1501 >>> gateway_read_obj = MagicMock(spec=GatewayRead) 

1502 >>> service.convert_gateway_to_read = MagicMock(return_value=gateway_read_obj) 

1503 >>> # Mock the cache to bypass caching logic 

1504 >>> with patch('mcpgateway.services.gateway_service._get_registry_cache') as mock_cache_factory: 

1505 ... mock_cache = MagicMock() 

1506 ... mock_cache.get = AsyncMock(return_value=None) 

1507 ... mock_cache.set = AsyncMock(return_value=None) 

1508 ... mock_cache.hash_filters = MagicMock(return_value="hash") 

1509 ... mock_cache_factory.return_value = mock_cache 

1510 ... gateways, cursor = asyncio.run(service.list_gateways(db)) 

1511 ... gateways == [gateway_read_obj] and cursor is None 

1512 True 

1513 

1514 >>> # Test empty result 

1515 >>> db.execute.return_value.scalars.return_value.all.return_value = [] 

1516 >>> with patch('mcpgateway.services.gateway_service._get_registry_cache') as mock_cache_factory: 

1517 ... mock_cache = MagicMock() 

1518 ... mock_cache.get = AsyncMock(return_value=None) 

1519 ... mock_cache.set = AsyncMock(return_value=None) 

1520 ... mock_cache.hash_filters = MagicMock(return_value="hash") 

1521 ... mock_cache_factory.return_value = mock_cache 

1522 ... empty_result, cursor = asyncio.run(service.list_gateways(db)) 

1523 ... empty_result == [] and cursor is None 

1524 True 

1525 >>> 

1526 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs 

1527 >>> asyncio.run(service._http_client.aclose()) 

1528 """ 

1529 # Check cache for first page only - only for public-only queries (no user/team filtering) 

1530 # SECURITY: Only cache public-only results (token_teams=[]), never admin bypass or team-scoped 

1531 cache = _get_registry_cache() 

1532 is_public_only = token_teams is not None and len(token_teams) == 0 

1533 use_cache = cursor is None and user_email is None and page is None and is_public_only 

1534 if use_cache: 

1535 filters_hash = cache.hash_filters(include_inactive=include_inactive, tags=sorted(tags) if tags else None) 

1536 cached = await cache.get("gateways", filters_hash) 

1537 if cached is not None: 

1538 # Reconstruct GatewayRead objects from cached dicts 

1539 # SECURITY: Always apply .masked() to ensure stale cache entries don't leak credentials 

1540 cached_gateways = [GatewayRead.model_validate(g).masked() for g in cached["gateways"]] 

1541 return (cached_gateways, cached.get("next_cursor")) 

1542 

1543 # Build base query with ordering 

1544 query = select(DbGateway).options(joinedload(DbGateway.email_team)).order_by(desc(DbGateway.created_at), desc(DbGateway.id)) 

1545 

1546 # Apply active/inactive filter 

1547 if not include_inactive: 

1548 query = query.where(DbGateway.enabled) 

1549 

1550 # SECURITY: Apply token-based access control based on normalized token_teams 

1551 # - token_teams is None: admin bypass (is_admin=true with explicit null teams) - sees all 

1552 # - token_teams is []: public-only access (missing teams or explicit empty) 

1553 # - token_teams is [...]: access to specified teams + public + user's own 

1554 if token_teams is not None: 

1555 if len(token_teams) == 0: 

1556 # Public-only token: only access public gateways 

1557 query = query.where(DbGateway.visibility == "public") 

1558 else: 

1559 # Team-scoped token: public gateways + gateways in allowed teams + user's own 

1560 access_conditions = [ 

1561 DbGateway.visibility == "public", 

1562 and_(DbGateway.team_id.in_(token_teams), DbGateway.visibility.in_(["team", "public"])), 

1563 ] 

1564 if user_email: 

1565 access_conditions.append(and_(DbGateway.owner_email == user_email, DbGateway.visibility == "private")) 

1566 query = query.where(or_(*access_conditions)) 

1567 

1568 if visibility: 

1569 query = query.where(DbGateway.visibility == visibility) 

1570 

1571 # Apply team-based access control if user_email is provided (and no token_teams filtering) 

1572 elif user_email: 

1573 team_service = TeamManagementService(db) 

1574 user_teams = await team_service.get_user_teams(user_email) 

1575 team_ids = [team.id for team in user_teams] 

1576 

1577 if team_id: 

1578 # User requesting specific team - verify access 

1579 if team_id not in team_ids: 

1580 return ([], None) 

1581 access_conditions = [ 

1582 and_(DbGateway.team_id == team_id, DbGateway.visibility.in_(["team", "public"])), 

1583 and_(DbGateway.team_id == team_id, DbGateway.owner_email == user_email), 

1584 ] 

1585 query = query.where(or_(*access_conditions)) 

1586 else: 

1587 # General access: user's gateways + public gateways + team gateways 

1588 access_conditions = [ 

1589 DbGateway.owner_email == user_email, 

1590 DbGateway.visibility == "public", 

1591 ] 

1592 if team_ids: 

1593 access_conditions.append(and_(DbGateway.team_id.in_(team_ids), DbGateway.visibility.in_(["team", "public"]))) 

1594 query = query.where(or_(*access_conditions)) 

1595 

1596 if visibility: 

1597 query = query.where(DbGateway.visibility == visibility) 

1598 

1599 # Add tag filtering if tags are provided (supports both List[str] and List[Dict] formats) 

1600 if tags: 

1601 query = query.where(json_contains_tag_expr(db, DbGateway.tags, tags, match_any=True)) 

1602 # Use unified pagination helper - handles both page and cursor pagination 

1603 pag_result = await unified_paginate( 

1604 db=db, 

1605 query=query, 

1606 page=page, 

1607 per_page=per_page, 

1608 cursor=cursor, 

1609 limit=limit, 

1610 base_url="/admin/gateways", # Used for page-based links 

1611 query_params={"include_inactive": include_inactive} if include_inactive else {}, 

1612 ) 

1613 

1614 next_cursor = None 

1615 # Extract gateways based on pagination type 

1616 if page is not None: 

1617 # Page-based: pag_result is a dict 

1618 gateways_db = pag_result["data"] 

1619 else: 

1620 # Cursor-based: pag_result is a tuple 

1621 gateways_db, next_cursor = pag_result 

1622 

1623 db.commit() # Release transaction to avoid idle-in-transaction 

1624 

1625 # Convert to GatewayRead (common for both pagination types) 

1626 result = [] 

1627 for s in gateways_db: 

1628 try: 

1629 result.append(self.convert_gateway_to_read(s)) 

1630 except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e: 

1631 logger.exception(f"Failed to convert gateway {getattr(s, 'id', 'unknown')} ({getattr(s, 'name', 'unknown')}): {e}") 

1632 # Continue with remaining gateways instead of failing completely 

1633 

1634 # Return appropriate format based on pagination type 

1635 if page is not None: 

1636 # Page-based format 

1637 return { 

1638 "data": result, 

1639 "pagination": pag_result["pagination"], 

1640 "links": pag_result["links"], 

1641 } 

1642 

1643 # Cursor-based format 

1644 

1645 # Cache first page results - only for public-only queries (no user/team filtering) 

1646 # SECURITY: Only cache public-only results (token_teams=[]), never admin bypass or team-scoped 

1647 if cursor is None and user_email is None and is_public_only: 

1648 try: 

1649 cache_data = {"gateways": [s.model_dump(mode="json") for s in result], "next_cursor": next_cursor} 

1650 await cache.set("gateways", cache_data, filters_hash) 

1651 except AttributeError: 

1652 pass # Skip caching if result objects don't support model_dump (e.g., in doctests) 

1653 

1654 return (result, next_cursor) 

1655 

1656 async def list_gateways_for_user( 

1657 self, db: Session, user_email: str, team_id: Optional[str] = None, visibility: Optional[str] = None, include_inactive: bool = False, skip: int = 0, limit: int = 100 

1658 ) -> List[GatewayRead]: 

1659 """ 

1660 DEPRECATED: Use list_gateways() with user_email parameter instead. 

1661 

1662 This method is maintained for backward compatibility but is no longer used. 

1663 New code should call list_gateways() with user_email, team_id, and visibility parameters. 

1664 

1665 List gateways user has access to with team filtering. 

1666 

1667 Args: 

1668 db: Database session 

1669 user_email: Email of the user requesting gateways 

1670 team_id: Optional team ID to filter by specific team 

1671 visibility: Optional visibility filter (private, team, public) 

1672 include_inactive: Whether to include inactive gateways 

1673 skip: Number of gateways to skip for pagination 

1674 limit: Maximum number of gateways to return 

1675 

1676 Returns: 

1677 List[GatewayRead]: Gateways the user has access to 

1678 """ 

1679 # Build query following existing patterns from list_gateways() 

1680 team_service = TeamManagementService(db) 

1681 user_teams = await team_service.get_user_teams(user_email) 

1682 team_ids = [team.id for team in user_teams] 

1683 

1684 # Use joinedload to eager load email_team relationship (avoids N+1 queries) 

1685 query = select(DbGateway).options(joinedload(DbGateway.email_team)) 

1686 

1687 # Apply active/inactive filter 

1688 if not include_inactive: 

1689 query = query.where(DbGateway.enabled.is_(True)) 

1690 

1691 if team_id: 

1692 if team_id not in team_ids: 

1693 return [] # No access to team 

1694 

1695 access_conditions = [] 

1696 # Filter by specific team 

1697 

1698 # Team-owned gateways (team-scoped gateways) 

1699 access_conditions.append(and_(DbGateway.team_id == team_id, DbGateway.visibility.in_(["team", "public"]))) 

1700 

1701 access_conditions.append(and_(DbGateway.team_id == team_id, DbGateway.owner_email == user_email)) 

1702 

1703 # Also include global public gateways (no team_id) so public gateways are visible regardless of selected team 

1704 access_conditions.append(DbGateway.visibility == "public") 

1705 

1706 query = query.where(or_(*access_conditions)) 

1707 else: 

1708 # Get user's accessible teams 

1709 # Build access conditions following existing patterns 

1710 access_conditions = [] 

1711 # 1. User's personal resources (owner_email matches) 

1712 access_conditions.append(DbGateway.owner_email == user_email) 

1713 # 2. Team resources where user is member 

1714 if team_ids: 1714 ↛ 1715line 1714 didn't jump to line 1715 because the condition on line 1714 was never true

1715 access_conditions.append(and_(DbGateway.team_id.in_(team_ids), DbGateway.visibility.in_(["team", "public"]))) 

1716 # 3. Public resources (if visibility allows) 

1717 access_conditions.append(DbGateway.visibility == "public") 

1718 

1719 query = query.where(or_(*access_conditions)) 

1720 

1721 # Apply visibility filter if specified 

1722 if visibility: 

1723 query = query.where(DbGateway.visibility == visibility) 

1724 

1725 # Apply pagination following existing patterns 

1726 query = query.offset(skip).limit(limit) 

1727 

1728 gateways = db.execute(query).scalars().all() 

1729 

1730 db.commit() # Release transaction to avoid idle-in-transaction 

1731 

1732 # Team names are loaded via joinedload(DbGateway.email_team) 

1733 result = [] 

1734 for g in gateways: 

1735 logger.info(f"Gateway: {g.team_id}, Team: {g.team}") 

1736 result.append(GatewayRead.model_validate(self._prepare_gateway_for_read(g)).masked()) 

1737 return result 

1738 

1739 async def update_gateway( 

1740 self, 

1741 db: Session, 

1742 gateway_id: str, 

1743 gateway_update: GatewayUpdate, 

1744 modified_by: Optional[str] = None, 

1745 modified_from_ip: Optional[str] = None, 

1746 modified_via: Optional[str] = None, 

1747 modified_user_agent: Optional[str] = None, 

1748 include_inactive: bool = True, 

1749 user_email: Optional[str] = None, 

1750 ) -> GatewayRead: 

1751 """Update a gateway. 

1752 

1753 Args: 

1754 db: Database session 

1755 gateway_id: Gateway ID to update 

1756 gateway_update: Updated gateway data 

1757 modified_by: Username of the person modifying the gateway 

1758 modified_from_ip: IP address where the modification request originated 

1759 modified_via: Source of modification (ui/api/import) 

1760 modified_user_agent: User agent string from the modification request 

1761 include_inactive: Whether to include inactive gateways 

1762 user_email: Email of user performing update (for ownership check) 

1763 

1764 Returns: 

1765 Updated gateway information 

1766 

1767 Raises: 

1768 GatewayNotFoundError: If gateway not found 

1769 PermissionError: If user doesn't own the gateway 

1770 GatewayError: For other update errors 

1771 GatewayNameConflictError: If gateway name conflict occurs 

1772 IntegrityError: If there is a database integrity error 

1773 ValidationError: If validation fails 

1774 """ 

1775 try: # pylint: disable=too-many-nested-blocks 

1776 # Acquire row lock and eager-load relationships while locked so 

1777 # concurrent updates are serialized on Postgres. 

1778 gateway = get_for_update( 

1779 db, 

1780 DbGateway, 

1781 gateway_id, 

1782 options=[ 

1783 selectinload(DbGateway.tools), 

1784 selectinload(DbGateway.resources), 

1785 selectinload(DbGateway.prompts), 

1786 selectinload(DbGateway.email_team), # Use selectinload to avoid locking email_teams 

1787 ], 

1788 ) 

1789 if not gateway: 

1790 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") 

1791 

1792 # Check ownership if user_email provided 

1793 if user_email: 

1794 # First-Party 

1795 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel 

1796 

1797 permission_service = PermissionService(db) 

1798 if not await permission_service.check_resource_ownership(user_email, gateway): 1798 ↛ 1801line 1798 didn't jump to line 1801 because the condition on line 1798 was always true

1799 raise PermissionError("Only the owner can update this gateway") 

1800 

1801 if gateway.enabled or include_inactive: 

1802 # Check for name conflicts if name is being changed 

1803 if gateway_update.name is not None and gateway_update.name != gateway.name: 

1804 # existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway_update.name).where(DbGateway.id != gateway_id)).scalar_one_or_none() 

1805 

1806 # if existing_gateway: 

1807 # raise GatewayNameConflictError( 

1808 # gateway_update.name, 

1809 # enabled=existing_gateway.enabled, 

1810 # gateway_id=existing_gateway.id, 

1811 # ) 

1812 # Check for existing gateway with the same slug and visibility 

1813 new_slug = slugify(gateway_update.name) 

1814 if gateway_update.visibility is not None: 

1815 vis = gateway_update.visibility 

1816 else: 

1817 vis = gateway.visibility 

1818 if vis == "public": 

1819 # Check for existing public gateway with the same slug (row-locked) 

1820 existing_gateway = get_for_update( 

1821 db, 

1822 DbGateway, 

1823 where=and_(DbGateway.slug == new_slug, DbGateway.visibility == "public", DbGateway.id != gateway_id), 

1824 ) 

1825 if existing_gateway: 

1826 raise GatewayNameConflictError( 

1827 new_slug, 

1828 enabled=existing_gateway.enabled, 

1829 gateway_id=existing_gateway.id, 

1830 visibility=existing_gateway.visibility, 

1831 ) 

1832 elif vis == "team" and gateway.team_id: 1832 ↛ 1834line 1832 didn't jump to line 1834 because the condition on line 1832 was never true

1833 # Check for existing team gateway with the same slug (row-locked) 

1834 existing_gateway = get_for_update( 

1835 db, 

1836 DbGateway, 

1837 where=and_(DbGateway.slug == new_slug, DbGateway.visibility == "team", DbGateway.team_id == gateway.team_id, DbGateway.id != gateway_id), 

1838 ) 

1839 if existing_gateway: 

1840 raise GatewayNameConflictError( 

1841 new_slug, 

1842 enabled=existing_gateway.enabled, 

1843 gateway_id=existing_gateway.id, 

1844 visibility=existing_gateway.visibility, 

1845 ) 

1846 # Check for existing gateway with the same URL and visibility 

1847 normalized_url = "" 

1848 if gateway_update.url is not None: 

1849 normalized_url = self.normalize_url(str(gateway_update.url)) 

1850 else: 

1851 normalized_url = None 

1852 

1853 # Prepare decoded auth_value for uniqueness check 

1854 decoded_auth_value = None 

1855 if gateway_update.auth_value: 

1856 if isinstance(gateway_update.auth_value, str): 1856 ↛ 1861line 1856 didn't jump to line 1861 because the condition on line 1856 was always true

1857 try: 

1858 decoded_auth_value = decode_auth(gateway_update.auth_value) 

1859 except Exception as e: 

1860 logger.warning(f"Failed to decode provided auth_value: {e}") 

1861 elif isinstance(gateway_update.auth_value, dict): 

1862 decoded_auth_value = gateway_update.auth_value 

1863 

1864 # Determine final values for uniqueness check 

1865 final_auth_value = decoded_auth_value if gateway_update.auth_value is not None else (decode_auth(gateway.auth_value) if isinstance(gateway.auth_value, str) else gateway.auth_value) 

1866 final_oauth_config = gateway_update.oauth_config if gateway_update.oauth_config is not None else gateway.oauth_config 

1867 final_visibility = gateway_update.visibility if gateway_update.visibility is not None else gateway.visibility 

1868 

1869 # Check for duplicates with updated credentials 

1870 if not gateway_update.one_time_auth: 1870 ↛ 1892line 1870 didn't jump to line 1892 because the condition on line 1870 was always true

1871 duplicate_gateway = self._check_gateway_uniqueness( 

1872 db=db, 

1873 url=normalized_url, 

1874 auth_value=final_auth_value, 

1875 oauth_config=final_oauth_config, 

1876 team_id=gateway.team_id, 

1877 visibility=final_visibility, 

1878 gateway_id=gateway_id, # Exclude current gateway from check 

1879 owner_email=user_email, 

1880 ) 

1881 

1882 if duplicate_gateway: 1882 ↛ 1883line 1882 didn't jump to line 1883 because the condition on line 1882 was never true

1883 raise GatewayDuplicateConflictError(duplicate_gateway=duplicate_gateway) 

1884 

1885 # FIX for Issue #1025: Determine if URL actually changed before we update it 

1886 # We need this early because we update gateway.url below, and need to know 

1887 # if it actually changed to decide whether to re-fetch tools 

1888 # tools/resoures/prompts are need to be re-fetched not only if URL changed , in case any update like authentication and visibility changed 

1889 # url_changed = gateway_update.url is not None and self.normalize_url(str(gateway_update.url)) != gateway.url 

1890 

1891 # Save original values BEFORE updating for change detection checks later 

1892 original_url = gateway.url 

1893 original_auth_type = gateway.auth_type 

1894 

1895 # Update fields if provided 

1896 if gateway_update.name is not None: 

1897 gateway.name = gateway_update.name 

1898 gateway.slug = slugify(gateway_update.name) 

1899 if gateway_update.url is not None: 

1900 # Normalize the updated URL 

1901 gateway.url = self.normalize_url(str(gateway_update.url)) 

1902 if gateway_update.description is not None: 

1903 gateway.description = gateway_update.description 

1904 if gateway_update.transport is not None: 

1905 gateway.transport = gateway_update.transport 

1906 if gateway_update.tags is not None: 

1907 gateway.tags = gateway_update.tags 

1908 if gateway_update.visibility is not None: 

1909 gateway.visibility = gateway_update.visibility 

1910 if gateway_update.visibility is not None: 

1911 gateway.visibility = gateway_update.visibility 

1912 if gateway_update.passthrough_headers is not None: 

1913 if isinstance(gateway_update.passthrough_headers, list): 

1914 gateway.passthrough_headers = gateway_update.passthrough_headers 

1915 else: 

1916 if isinstance(gateway_update.passthrough_headers, str): 

1917 parsed: List[str] = [h.strip() for h in gateway_update.passthrough_headers.split(",") if h.strip()] 

1918 gateway.passthrough_headers = parsed 

1919 else: 

1920 raise GatewayError("Invalid passthrough_headers format: must be list[str] or comma-separated string") 

1921 

1922 logger.info("Updated passthrough_headers for gateway {gateway.id}: {gateway.passthrough_headers}") 

1923 

1924 # Only update auth_type if explicitly provided in the update 

1925 if gateway_update.auth_type is not None: 

1926 gateway.auth_type = gateway_update.auth_type 

1927 

1928 # If auth_type is empty, update the auth_value too 

1929 if gateway_update.auth_type == "": 

1930 gateway.auth_value = cast(Any, "") 

1931 

1932 # Clear auth_query_params when switching away from query_param auth 

1933 if original_auth_type == "query_param" and gateway_update.auth_type != "query_param": 

1934 gateway.auth_query_params = None 

1935 logger.debug(f"Cleared auth_query_params for gateway {gateway.id} (switched from query_param to {gateway_update.auth_type})") 

1936 

1937 # if auth_type is not None and only then check auth_value 

1938 # Handle OAuth configuration updates 

1939 if gateway_update.oauth_config is not None: 

1940 gateway.oauth_config = gateway_update.oauth_config 

1941 

1942 # Handle auth_value updates (both existing and new auth values) 

1943 token = gateway_update.auth_token 

1944 password = gateway_update.auth_password 

1945 header_value = gateway_update.auth_header_value 

1946 

1947 # Support multiple custom headers on update 

1948 if hasattr(gateway_update, "auth_headers") and gateway_update.auth_headers: 

1949 existing_auth_raw = getattr(gateway, "auth_value", {}) or {} 

1950 if isinstance(existing_auth_raw, str): 

1951 try: 

1952 existing_auth = decode_auth(existing_auth_raw) 

1953 except Exception: 

1954 existing_auth = {} 

1955 elif isinstance(existing_auth_raw, dict): 1955 ↛ 1958line 1955 didn't jump to line 1958 because the condition on line 1955 was always true

1956 existing_auth = existing_auth_raw 

1957 else: 

1958 existing_auth = {} 

1959 

1960 header_dict: Dict[str, str] = {} 

1961 for header in gateway_update.auth_headers: 

1962 key = header.get("key") 

1963 if not key: 1963 ↛ 1964line 1963 didn't jump to line 1964 because the condition on line 1963 was never true

1964 continue 

1965 value = header.get("value", "") 

1966 if value == settings.masked_auth_value and key in existing_auth: 

1967 header_dict[key] = existing_auth[key] 

1968 else: 

1969 header_dict[key] = value 

1970 gateway.auth_value = header_dict # Store as dict for DB JSON field 

1971 elif settings.masked_auth_value not in (token, password, header_value): 

1972 # Check if values differ from existing ones or if setting for first time 

1973 decoded_auth = decode_auth(gateway_update.auth_value) if gateway_update.auth_value else {} 

1974 current_auth = getattr(gateway, "auth_value", {}) or {} 

1975 if current_auth != decoded_auth: 

1976 gateway.auth_value = decoded_auth 

1977 

1978 # Handle query_param auth updates with service-layer enforcement 

1979 auth_query_params_decrypted: Optional[Dict[str, str]] = None 

1980 init_url = gateway.url 

1981 

1982 # Check if updating to query_param auth or updating existing query_param credentials 

1983 # Use original_auth_type since gateway.auth_type may have been updated already 

1984 is_switching_to_queryparam = gateway_update.auth_type == "query_param" and original_auth_type != "query_param" 

1985 is_updating_queryparam_creds = original_auth_type == "query_param" and (gateway_update.auth_query_param_key is not None or gateway_update.auth_query_param_value is not None) 

1986 is_url_changing = gateway_update.url is not None and self.normalize_url(str(gateway_update.url)) != original_url 

1987 

1988 if is_switching_to_queryparam or is_updating_queryparam_creds or (is_url_changing and original_auth_type == "query_param"): 

1989 # Service-layer enforcement: Check feature flag 

1990 if not settings.insecure_allow_queryparam_auth: 

1991 # Grandfather clause: Allow updates to existing query_param gateways 

1992 # unless they're trying to change credentials 

1993 if is_switching_to_queryparam or is_updating_queryparam_creds: 

1994 raise ValueError("Query parameter authentication is disabled. " + "Set INSECURE_ALLOW_QUERYPARAM_AUTH=true to enable.") 

1995 

1996 # Service-layer enforcement: Check host allowlist 

1997 if settings.insecure_queryparam_auth_allowed_hosts: 

1998 check_url = str(gateway_update.url) if gateway_update.url else gateway.url 

1999 parsed = urlparse(check_url) 

2000 hostname = (parsed.hostname or "").lower() 

2001 if hostname not in settings.insecure_queryparam_auth_allowed_hosts: 

2002 allowed = ", ".join(settings.insecure_queryparam_auth_allowed_hosts) 

2003 raise ValueError(f"Host '{hostname}' is not in the allowed hosts for query param auth. Allowed: {allowed}") 

2004 

2005 # Process query_param auth credentials 

2006 param_key = getattr(gateway_update, "auth_query_param_key", None) or (next(iter(gateway.auth_query_params.keys()), None) if gateway.auth_query_params else None) 

2007 param_value = getattr(gateway_update, "auth_query_param_value", None) 

2008 

2009 # Get raw value from SecretStr if applicable 

2010 raw_value: Optional[str] = None 

2011 if param_value: 

2012 if hasattr(param_value, "get_secret_value"): 

2013 raw_value = param_value.get_secret_value() 

2014 else: 

2015 raw_value = str(param_value) 

2016 

2017 # Check if the value is the masked placeholder - if so, keep existing value 

2018 is_masked_placeholder = raw_value == settings.masked_auth_value 

2019 

2020 if param_key: 2020 ↛ 2038line 2020 didn't jump to line 2038 because the condition on line 2020 was always true

2021 if raw_value and not is_masked_placeholder: 

2022 # New value provided - encrypt for storage 

2023 encrypted_value = encode_auth({param_key: raw_value}) 

2024 gateway.auth_query_params = {param_key: encrypted_value} 

2025 auth_query_params_decrypted = {param_key: raw_value} 

2026 elif gateway.auth_query_params: 2026 ↛ 2034line 2026 didn't jump to line 2034 because the condition on line 2026 was always true

2027 # Use existing encrypted value 

2028 existing_encrypted = gateway.auth_query_params.get(param_key, "") 

2029 if existing_encrypted: 2029 ↛ 2034line 2029 didn't jump to line 2034 because the condition on line 2029 was always true

2030 decrypted = decode_auth(existing_encrypted) 

2031 auth_query_params_decrypted = {param_key: decrypted.get(param_key, "")} 

2032 

2033 # Append query params to URL for initialization 

2034 if auth_query_params_decrypted: 2034 ↛ 2038line 2034 didn't jump to line 2038 because the condition on line 2034 was always true

2035 init_url = apply_query_param_auth(gateway.url, auth_query_params_decrypted) 

2036 

2037 # Update auth_type if switching 

2038 if is_switching_to_queryparam: 

2039 gateway.auth_type = "query_param" 

2040 gateway.auth_value = None # Query param auth doesn't use auth_value 

2041 

2042 elif gateway.auth_type == "query_param" and gateway.auth_query_params: 2042 ↛ 2044line 2042 didn't jump to line 2044 because the condition on line 2042 was never true

2043 # Existing query_param gateway without credential changes - decrypt for init 

2044 first_key = next(iter(gateway.auth_query_params.keys()), None) 

2045 if first_key: 

2046 encrypted_value = gateway.auth_query_params.get(first_key, "") 

2047 if encrypted_value: 

2048 decrypted = decode_auth(encrypted_value) 

2049 auth_query_params_decrypted = {first_key: decrypted.get(first_key, "")} 

2050 init_url = apply_query_param_auth(gateway.url, auth_query_params_decrypted) 

2051 

2052 # Try to reinitialize connection if URL actually changed 

2053 # if url_changed: 

2054 # Initialize empty lists in case initialization fails 

2055 tools_to_add = [] 

2056 resources_to_add = [] 

2057 prompts_to_add = [] 

2058 

2059 try: 

2060 ca_certificate = getattr(gateway, "ca_certificate", None) 

2061 capabilities, tools, resources, prompts = await self._initialize_gateway( 

2062 init_url, 

2063 gateway.auth_value, 

2064 gateway.transport, 

2065 gateway.auth_type, 

2066 gateway.oauth_config, 

2067 ca_certificate, 

2068 auth_query_params=auth_query_params_decrypted, 

2069 ) 

2070 new_tool_names = [tool.name for tool in tools] 

2071 new_resource_uris = [resource.uri for resource in resources] 

2072 new_prompt_names = [prompt.name for prompt in prompts] 

2073 

2074 if gateway_update.one_time_auth: 2074 ↛ 2076line 2074 didn't jump to line 2076 because the condition on line 2074 was never true

2075 # For one-time auth, clear auth_type and auth_value after initialization 

2076 gateway.auth_type = "one_time_auth" 

2077 gateway.auth_value = None 

2078 gateway.oauth_config = None 

2079 

2080 # Update tools using helper method 

2081 tools_to_add = self._update_or_create_tools(db, tools, gateway, "update") 

2082 

2083 # Update resources using helper method 

2084 resources_to_add = self._update_or_create_resources(db, resources, gateway, "update") 

2085 

2086 # Update prompts using helper method 

2087 prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, "update") 

2088 

2089 # Log newly added items 

2090 items_added = len(tools_to_add) + len(resources_to_add) + len(prompts_to_add) 

2091 if items_added > 0: 

2092 if tools_to_add: 2092 ↛ 2094line 2092 didn't jump to line 2094 because the condition on line 2092 was always true

2093 logger.info(f"Added {len(tools_to_add)} new tools during gateway update") 

2094 if resources_to_add: 2094 ↛ 2095line 2094 didn't jump to line 2095 because the condition on line 2094 was never true

2095 logger.info(f"Added {len(resources_to_add)} new resources during gateway update") 

2096 if prompts_to_add: 2096 ↛ 2097line 2096 didn't jump to line 2097 because the condition on line 2096 was never true

2097 logger.info(f"Added {len(prompts_to_add)} new prompts during gateway update") 

2098 logger.info(f"Total {items_added} new items added during gateway update") 

2099 

2100 # Count items before cleanup for logging 

2101 

2102 # Bulk delete tools that are no longer available from the gateway 

2103 # Use chunking to avoid SQLite's 999 parameter limit for IN clauses 

2104 stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names] 

2105 if stale_tool_ids: 

2106 # Delete child records first to avoid FK constraint violations 

2107 for i in range(0, len(stale_tool_ids), 500): 

2108 chunk = stale_tool_ids[i : i + 500] 

2109 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk))) 

2110 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk))) 

2111 db.execute(delete(DbTool).where(DbTool.id.in_(chunk))) 

2112 

2113 # Bulk delete resources that are no longer available from the gateway 

2114 stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris] 

2115 if stale_resource_ids: 2115 ↛ 2117line 2115 didn't jump to line 2117 because the condition on line 2115 was never true

2116 # Delete child records first to avoid FK constraint violations 

2117 for i in range(0, len(stale_resource_ids), 500): 

2118 chunk = stale_resource_ids[i : i + 500] 

2119 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk))) 

2120 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk))) 

2121 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk))) 

2122 db.execute(delete(DbResource).where(DbResource.id.in_(chunk))) 

2123 

2124 # Bulk delete prompts that are no longer available from the gateway 

2125 stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names] 

2126 if stale_prompt_ids: 2126 ↛ 2128line 2126 didn't jump to line 2128 because the condition on line 2126 was never true

2127 # Delete child records first to avoid FK constraint violations 

2128 for i in range(0, len(stale_prompt_ids), 500): 

2129 chunk = stale_prompt_ids[i : i + 500] 

2130 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk))) 

2131 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk))) 

2132 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk))) 

2133 

2134 # Expire gateway to clear cached relationships after bulk deletes 

2135 # This prevents SQLAlchemy from trying to re-delete already-deleted items 

2136 if stale_tool_ids or stale_resource_ids or stale_prompt_ids: 

2137 db.expire(gateway) 

2138 

2139 gateway.capabilities = capabilities 

2140 

2141 # Register capabilities for notification-driven actions 

2142 register_gateway_capabilities_for_notifications(gateway.id, capabilities) 

2143 

2144 gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names] # keep only still-valid rows 

2145 gateway.resources = [resource for resource in gateway.resources if resource.uri in new_resource_uris] # keep only still-valid rows 

2146 gateway.prompts = [prompt for prompt in gateway.prompts if prompt.original_name in new_prompt_names] # keep only still-valid rows 

2147 

2148 # Log cleanup results 

2149 tools_removed = len(stale_tool_ids) 

2150 resources_removed = len(stale_resource_ids) 

2151 prompts_removed = len(stale_prompt_ids) 

2152 

2153 if tools_removed > 0: 

2154 logger.info(f"Removed {tools_removed} tools no longer available during gateway update") 

2155 if resources_removed > 0: 2155 ↛ 2156line 2155 didn't jump to line 2156 because the condition on line 2155 was never true

2156 logger.info(f"Removed {resources_removed} resources no longer available during gateway update") 

2157 if prompts_removed > 0: 2157 ↛ 2158line 2157 didn't jump to line 2158 because the condition on line 2157 was never true

2158 logger.info(f"Removed {prompts_removed} prompts no longer available during gateway update") 

2159 

2160 gateway.last_seen = datetime.now(timezone.utc) 

2161 

2162 # Add new items to database session in chunks to prevent lock escalation 

2163 chunk_size = 50 

2164 

2165 if tools_to_add: 

2166 for i in range(0, len(tools_to_add), chunk_size): 

2167 chunk = tools_to_add[i : i + chunk_size] 

2168 db.add_all(chunk) 

2169 db.flush() 

2170 if resources_to_add: 2170 ↛ 2171line 2170 didn't jump to line 2171 because the condition on line 2170 was never true

2171 for i in range(0, len(resources_to_add), chunk_size): 

2172 chunk = resources_to_add[i : i + chunk_size] 

2173 db.add_all(chunk) 

2174 db.flush() 

2175 if prompts_to_add: 2175 ↛ 2176line 2175 didn't jump to line 2176 because the condition on line 2175 was never true

2176 for i in range(0, len(prompts_to_add), chunk_size): 

2177 chunk = prompts_to_add[i : i + chunk_size] 

2178 db.add_all(chunk) 

2179 db.flush() 

2180 

2181 # Update tracking with new URL 

2182 self._active_gateways.discard(gateway.url) 

2183 self._active_gateways.add(gateway.url) 

2184 except Exception as e: 

2185 logger.warning(f"Failed to initialize updated gateway: {e}") 

2186 

2187 # Update tags if provided 

2188 if gateway_update.tags is not None: 

2189 gateway.tags = gateway_update.tags 

2190 

2191 # Update metadata fields 

2192 gateway.updated_at = datetime.now(timezone.utc) 

2193 if modified_by: 

2194 gateway.modified_by = modified_by 

2195 if modified_from_ip: 

2196 gateway.modified_from_ip = modified_from_ip 

2197 if modified_via: 

2198 gateway.modified_via = modified_via 

2199 if modified_user_agent: 

2200 gateway.modified_user_agent = modified_user_agent 

2201 if hasattr(gateway, "version") and gateway.version is not None: 

2202 gateway.version = gateway.version + 1 

2203 else: 

2204 gateway.version = 1 

2205 

2206 db.commit() 

2207 db.refresh(gateway) 

2208 

2209 # Invalidate cache after successful update 

2210 cache = _get_registry_cache() 

2211 await cache.invalidate_gateways() 

2212 tool_lookup_cache = _get_tool_lookup_cache() 

2213 await tool_lookup_cache.invalidate_gateway(str(gateway.id)) 

2214 # Also invalidate tags cache since gateway tags may have changed 

2215 # First-Party 

2216 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel 

2217 

2218 await admin_stats_cache.invalidate_tags() 

2219 

2220 # Notify subscribers 

2221 await self._notify_gateway_updated(gateway) 

2222 

2223 logger.info(f"Updated gateway: {gateway.name}") 

2224 

2225 # Structured logging: Audit trail for gateway update 

2226 audit_trail.log_action( 

2227 user_id=user_email or modified_by or "system", 

2228 action="update_gateway", 

2229 resource_type="gateway", 

2230 resource_id=str(gateway.id), 

2231 resource_name=gateway.name, 

2232 user_email=user_email, 

2233 team_id=gateway.team_id, 

2234 client_ip=modified_from_ip, 

2235 user_agent=modified_user_agent, 

2236 new_values={ 

2237 "name": gateway.name, 

2238 "url": gateway.url, 

2239 "version": gateway.version, 

2240 }, 

2241 context={ 

2242 "modified_via": modified_via, 

2243 }, 

2244 db=db, 

2245 ) 

2246 

2247 # Structured logging: Log successful gateway update 

2248 structured_logger.log( 

2249 level="INFO", 

2250 message="Gateway updated successfully", 

2251 event_type="gateway_updated", 

2252 component="gateway_service", 

2253 user_id=modified_by, 

2254 user_email=user_email, 

2255 team_id=gateway.team_id, 

2256 resource_type="gateway", 

2257 resource_id=str(gateway.id), 

2258 custom_fields={ 

2259 "gateway_name": gateway.name, 

2260 "version": gateway.version, 

2261 }, 

2262 db=db, 

2263 ) 

2264 

2265 return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)).masked() 

2266 # Gateway is inactive and include_inactive is False → skip update, return None 

2267 return None 

2268 except GatewayNameConflictError as ge: 

2269 logger.error(f"GatewayNameConflictError in group: {ge}") 

2270 

2271 structured_logger.log( 

2272 level="WARNING", 

2273 message="Gateway update failed due to name conflict", 

2274 event_type="gateway_name_conflict", 

2275 component="gateway_service", 

2276 user_email=user_email, 

2277 resource_type="gateway", 

2278 resource_id=gateway_id, 

2279 error=ge, 

2280 db=db, 

2281 ) 

2282 raise ge 

2283 except GatewayNotFoundError as gnfe: 

2284 logger.error(f"GatewayNotFoundError: {gnfe}") 

2285 

2286 structured_logger.log( 

2287 level="ERROR", 

2288 message="Gateway update failed - gateway not found", 

2289 event_type="gateway_not_found", 

2290 component="gateway_service", 

2291 user_email=user_email, 

2292 resource_type="gateway", 

2293 resource_id=gateway_id, 

2294 error=gnfe, 

2295 db=db, 

2296 ) 

2297 raise gnfe 

2298 except IntegrityError as ie: 

2299 logger.error(f"IntegrityErrors in group: {ie}") 

2300 

2301 structured_logger.log( 

2302 level="ERROR", 

2303 message="Gateway update failed due to database integrity error", 

2304 event_type="gateway_update_failed", 

2305 component="gateway_service", 

2306 user_email=user_email, 

2307 resource_type="gateway", 

2308 resource_id=gateway_id, 

2309 error=ie, 

2310 db=db, 

2311 ) 

2312 raise ie 

2313 except PermissionError as pe: 

2314 db.rollback() 

2315 

2316 structured_logger.log( 

2317 level="WARNING", 

2318 message="Gateway update failed due to permission error", 

2319 event_type="gateway_update_permission_denied", 

2320 component="gateway_service", 

2321 user_email=user_email, 

2322 resource_type="gateway", 

2323 resource_id=gateway_id, 

2324 error=pe, 

2325 db=db, 

2326 ) 

2327 raise 

2328 except Exception as e: 

2329 db.rollback() 

2330 

2331 structured_logger.log( 

2332 level="ERROR", 

2333 message="Gateway update failed", 

2334 event_type="gateway_update_failed", 

2335 component="gateway_service", 

2336 user_email=user_email, 

2337 resource_type="gateway", 

2338 resource_id=gateway_id, 

2339 error=e, 

2340 db=db, 

2341 ) 

2342 raise GatewayError(f"Failed to update gateway: {str(e)}") 

2343 

2344 async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool = True) -> GatewayRead: 

2345 """Get a gateway by its ID. 

2346 

2347 Args: 

2348 db: Database session 

2349 gateway_id: Gateway ID 

2350 include_inactive: Whether to include inactive gateways 

2351 

2352 Returns: 

2353 GatewayRead object 

2354 

2355 Raises: 

2356 GatewayNotFoundError: If the gateway is not found 

2357 

2358 Examples: 

2359 >>> from unittest.mock import MagicMock 

2360 >>> from mcpgateway.schemas import GatewayRead 

2361 >>> service = GatewayService() 

2362 >>> db = MagicMock() 

2363 >>> gateway_mock = MagicMock() 

2364 >>> gateway_mock.enabled = True 

2365 >>> db.execute.return_value.scalar_one_or_none.return_value = gateway_mock 

2366 >>> mocked_gateway_read = MagicMock() 

2367 >>> mocked_gateway_read.masked.return_value = 'gateway_read' 

2368 >>> GatewayRead.model_validate = MagicMock(return_value=mocked_gateway_read) 

2369 >>> import asyncio 

2370 >>> result = asyncio.run(service.get_gateway(db, 'gateway_id')) 

2371 >>> result == 'gateway_read' 

2372 True 

2373 

2374 >>> # Test with inactive gateway but include_inactive=True 

2375 >>> gateway_mock.enabled = False 

2376 >>> result_inactive = asyncio.run(service.get_gateway(db, 'gateway_id', include_inactive=True)) 

2377 >>> result_inactive == 'gateway_read' 

2378 True 

2379 

2380 >>> # Test gateway not found 

2381 >>> db.execute.return_value.scalar_one_or_none.return_value = None 

2382 >>> try: 

2383 ... asyncio.run(service.get_gateway(db, 'missing_id')) 

2384 ... except GatewayNotFoundError as e: 

2385 ... 'Gateway not found: missing_id' in str(e) 

2386 True 

2387 

2388 >>> # Test inactive gateway with include_inactive=False 

2389 >>> gateway_mock.enabled = False 

2390 >>> db.execute.return_value.scalar_one_or_none.return_value = gateway_mock 

2391 >>> try: 

2392 ... asyncio.run(service.get_gateway(db, 'gateway_id', include_inactive=False)) 

2393 ... except GatewayNotFoundError as e: 

2394 ... 'Gateway not found: gateway_id' in str(e) 

2395 True 

2396 >>> 

2397 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs 

2398 >>> asyncio.run(service._http_client.aclose()) 

2399 """ 

2400 # Use eager loading to avoid N+1 queries for relationships and team name 

2401 gateway = db.execute( 

2402 select(DbGateway) 

2403 .options( 

2404 selectinload(DbGateway.tools), 

2405 selectinload(DbGateway.resources), 

2406 selectinload(DbGateway.prompts), 

2407 joinedload(DbGateway.email_team), 

2408 ) 

2409 .where(DbGateway.id == gateway_id) 

2410 ).scalar_one_or_none() 

2411 

2412 if not gateway: 

2413 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") 

2414 

2415 if gateway.enabled or include_inactive: 

2416 # Structured logging: Log gateway view 

2417 structured_logger.log( 

2418 level="INFO", 

2419 message="Gateway retrieved successfully", 

2420 event_type="gateway_viewed", 

2421 component="gateway_service", 

2422 team_id=getattr(gateway, "team_id", None), 

2423 resource_type="gateway", 

2424 resource_id=str(gateway.id), 

2425 custom_fields={ 

2426 "gateway_name": gateway.name, 

2427 "gateway_url": gateway.url, 

2428 "include_inactive": include_inactive, 

2429 }, 

2430 db=db, 

2431 ) 

2432 

2433 return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)).masked() 

2434 

2435 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") 

2436 

2437 async def set_gateway_state(self, db: Session, gateway_id: str, activate: bool, reachable: bool = True, only_update_reachable: bool = False, user_email: Optional[str] = None) -> GatewayRead: 

2438 """ 

2439 Set the activation status of a gateway. 

2440 

2441 Args: 

2442 db: Database session 

2443 gateway_id: Gateway ID 

2444 activate: True to activate, False to deactivate 

2445 reachable: Whether the gateway is reachable 

2446 only_update_reachable: Only update reachable status 

2447 user_email: Optional[str] The email of the user to check if the user has permission to modify. 

2448 

2449 Returns: 

2450 The updated GatewayRead object 

2451 

2452 Raises: 

2453 GatewayNotFoundError: If the gateway is not found 

2454 GatewayError: For other errors 

2455 PermissionError: If user doesn't own the agent. 

2456 """ 

2457 try: 

2458 # Eager-load collections for the gateway. Note: we don't use FOR UPDATE 

2459 # here because _initialize_gateway does network I/O, and holding a row 

2460 # lock during network calls would block other operations and risk timeouts. 

2461 gateway = db.execute( 

2462 select(DbGateway) 

2463 .options( 

2464 selectinload(DbGateway.tools), 

2465 selectinload(DbGateway.resources), 

2466 selectinload(DbGateway.prompts), 

2467 joinedload(DbGateway.email_team), 

2468 ) 

2469 .where(DbGateway.id == gateway_id) 

2470 ).scalar_one_or_none() 

2471 if not gateway: 

2472 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") 

2473 

2474 if user_email: 

2475 # First-Party 

2476 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel 

2477 

2478 permission_service = PermissionService(db) 

2479 if not await permission_service.check_resource_ownership(user_email, gateway): 2479 ↛ 2483line 2479 didn't jump to line 2483 because the condition on line 2479 was always true

2480 raise PermissionError("Only the owner can activate the gateway" if activate else "Only the owner can deactivate the gateway") 

2481 

2482 # Update status if it's different 

2483 if (gateway.enabled != activate) or (gateway.reachable != reachable): 

2484 gateway.enabled = activate 

2485 gateway.reachable = reachable 

2486 gateway.updated_at = datetime.now(timezone.utc) 

2487 # Update tracking 

2488 if activate and reachable: 

2489 self._active_gateways.add(gateway.url) 

2490 

2491 # Initialize empty lists in case initialization fails 

2492 tools_to_add = [] 

2493 resources_to_add = [] 

2494 prompts_to_add = [] 

2495 

2496 # Try to initialize if activating 

2497 try: 

2498 # Handle query_param auth - decrypt and apply to URL 

2499 init_url = gateway.url 

2500 auth_query_params_decrypted: Optional[Dict[str, str]] = None 

2501 if gateway.auth_type == "query_param" and gateway.auth_query_params: 

2502 auth_query_params_decrypted = {} 

2503 for param_key, encrypted_value in gateway.auth_query_params.items(): 

2504 if encrypted_value: 2504 ↛ 2503line 2504 didn't jump to line 2503 because the condition on line 2504 was always true

2505 try: 

2506 decrypted = decode_auth(encrypted_value) 

2507 auth_query_params_decrypted[param_key] = decrypted.get(param_key, "") 

2508 except Exception: 

2509 logger.debug(f"Failed to decrypt query param '{param_key}' for gateway activation") 

2510 if auth_query_params_decrypted: 

2511 init_url = apply_query_param_auth(gateway.url, auth_query_params_decrypted) 

2512 

2513 capabilities, tools, resources, prompts = await self._initialize_gateway( 

2514 init_url, gateway.auth_value, gateway.transport, gateway.auth_type, gateway.oauth_config, auth_query_params=auth_query_params_decrypted, oauth_auto_fetch_tool_flag=True 

2515 ) 

2516 new_tool_names = [tool.name for tool in tools] 

2517 new_resource_uris = [resource.uri for resource in resources] 

2518 new_prompt_names = [prompt.name for prompt in prompts] 

2519 

2520 # Update tools, resources, and prompts using helper methods 

2521 tools_to_add = self._update_or_create_tools(db, tools, gateway, "rediscovery") 

2522 resources_to_add = self._update_or_create_resources(db, resources, gateway, "rediscovery") 

2523 prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, "rediscovery") 

2524 

2525 # Log newly added items 

2526 items_added = len(tools_to_add) + len(resources_to_add) + len(prompts_to_add) 

2527 if items_added > 0: 

2528 if tools_to_add: 2528 ↛ 2530line 2528 didn't jump to line 2530 because the condition on line 2528 was always true

2529 logger.info(f"Added {len(tools_to_add)} new tools during gateway reactivation") 

2530 if resources_to_add: 2530 ↛ 2531line 2530 didn't jump to line 2531 because the condition on line 2530 was never true

2531 logger.info(f"Added {len(resources_to_add)} new resources during gateway reactivation") 

2532 if prompts_to_add: 2532 ↛ 2533line 2532 didn't jump to line 2533 because the condition on line 2532 was never true

2533 logger.info(f"Added {len(prompts_to_add)} new prompts during gateway reactivation") 

2534 logger.info(f"Total {items_added} new items added during gateway reactivation") 

2535 

2536 # Count items before cleanup for logging 

2537 

2538 # Bulk delete tools that are no longer available from the gateway 

2539 # Use chunking to avoid SQLite's 999 parameter limit for IN clauses 

2540 stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names] 

2541 if stale_tool_ids: 

2542 # Delete child records first to avoid FK constraint violations 

2543 for i in range(0, len(stale_tool_ids), 500): 

2544 chunk = stale_tool_ids[i : i + 500] 

2545 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk))) 

2546 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk))) 

2547 db.execute(delete(DbTool).where(DbTool.id.in_(chunk))) 

2548 

2549 # Bulk delete resources that are no longer available from the gateway 

2550 stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris] 

2551 if stale_resource_ids: 2551 ↛ 2553line 2551 didn't jump to line 2553 because the condition on line 2551 was never true

2552 # Delete child records first to avoid FK constraint violations 

2553 for i in range(0, len(stale_resource_ids), 500): 

2554 chunk = stale_resource_ids[i : i + 500] 

2555 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk))) 

2556 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk))) 

2557 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk))) 

2558 db.execute(delete(DbResource).where(DbResource.id.in_(chunk))) 

2559 

2560 # Bulk delete prompts that are no longer available from the gateway 

2561 stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names] 

2562 if stale_prompt_ids: 2562 ↛ 2564line 2562 didn't jump to line 2564 because the condition on line 2562 was never true

2563 # Delete child records first to avoid FK constraint violations 

2564 for i in range(0, len(stale_prompt_ids), 500): 

2565 chunk = stale_prompt_ids[i : i + 500] 

2566 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk))) 

2567 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk))) 

2568 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk))) 

2569 

2570 # Expire gateway to clear cached relationships after bulk deletes 

2571 # This prevents SQLAlchemy from trying to re-delete already-deleted items 

2572 if stale_tool_ids or stale_resource_ids or stale_prompt_ids: 

2573 db.expire(gateway) 

2574 

2575 gateway.capabilities = capabilities 

2576 

2577 # Register capabilities for notification-driven actions 

2578 register_gateway_capabilities_for_notifications(gateway.id, capabilities) 

2579 

2580 gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names] # keep only still-valid rows 

2581 gateway.resources = [resource for resource in gateway.resources if resource.uri in new_resource_uris] # keep only still-valid rows 

2582 gateway.prompts = [prompt for prompt in gateway.prompts if prompt.original_name in new_prompt_names] # keep only still-valid rows 

2583 

2584 # Log cleanup results 

2585 tools_removed = len(stale_tool_ids) 

2586 resources_removed = len(stale_resource_ids) 

2587 prompts_removed = len(stale_prompt_ids) 

2588 

2589 if tools_removed > 0: 

2590 logger.info(f"Removed {tools_removed} tools no longer available during gateway reactivation") 

2591 if resources_removed > 0: 2591 ↛ 2592line 2591 didn't jump to line 2592 because the condition on line 2591 was never true

2592 logger.info(f"Removed {resources_removed} resources no longer available during gateway reactivation") 

2593 if prompts_removed > 0: 2593 ↛ 2594line 2593 didn't jump to line 2594 because the condition on line 2593 was never true

2594 logger.info(f"Removed {prompts_removed} prompts no longer available during gateway reactivation") 

2595 

2596 gateway.last_seen = datetime.now(timezone.utc) 

2597 

2598 # Add new items to database session in chunks to prevent lock escalation 

2599 chunk_size = 50 

2600 

2601 if tools_to_add: 

2602 for i in range(0, len(tools_to_add), chunk_size): 

2603 chunk = tools_to_add[i : i + chunk_size] 

2604 db.add_all(chunk) 

2605 db.flush() 

2606 if resources_to_add: 2606 ↛ 2607line 2606 didn't jump to line 2607 because the condition on line 2606 was never true

2607 for i in range(0, len(resources_to_add), chunk_size): 

2608 chunk = resources_to_add[i : i + chunk_size] 

2609 db.add_all(chunk) 

2610 db.flush() 

2611 if prompts_to_add: 2611 ↛ 2612line 2611 didn't jump to line 2612 because the condition on line 2611 was never true

2612 for i in range(0, len(prompts_to_add), chunk_size): 

2613 chunk = prompts_to_add[i : i + chunk_size] 

2614 db.add_all(chunk) 

2615 db.flush() 

2616 except Exception as e: 

2617 logger.warning(f"Failed to initialize reactivated gateway: {e}") 

2618 else: 

2619 self._active_gateways.discard(gateway.url) 

2620 

2621 db.commit() 

2622 db.refresh(gateway) 

2623 

2624 # Invalidate cache after status change 

2625 cache = _get_registry_cache() 

2626 await cache.invalidate_gateways() 

2627 

2628 # Notify Subscribers 

2629 if not gateway.enabled: 

2630 # Inactive 

2631 await self._notify_gateway_deactivated(gateway) 

2632 elif gateway.enabled and not gateway.reachable: 

2633 # Offline (Enabled but Unreachable) 

2634 await self._notify_gateway_offline(gateway) 

2635 else: 

2636 # Active (Enabled and Reachable) 

2637 await self._notify_gateway_activated(gateway) 

2638 

2639 # Bulk update tools - single UPDATE statement instead of N FOR UPDATE locks 

2640 # This prevents lock contention under high concurrent load 

2641 now = datetime.now(timezone.utc) 

2642 if only_update_reachable: 

2643 # Only update reachable status, keep enabled as-is 

2644 tools_result = db.execute(update(DbTool).where(DbTool.gateway_id == gateway_id).where(DbTool.reachable != reachable).values(reachable=reachable, updated_at=now)) 

2645 else: 

2646 # Update both enabled and reachable 

2647 tools_result = db.execute( 

2648 update(DbTool) 

2649 .where(DbTool.gateway_id == gateway_id) 

2650 .where(or_(DbTool.enabled != activate, DbTool.reachable != reachable)) 

2651 .values(enabled=activate, reachable=reachable, updated_at=now) 

2652 ) 

2653 tools_updated = tools_result.rowcount 

2654 

2655 # Commit tool updates 

2656 if tools_updated > 0: 

2657 db.commit() 

2658 

2659 # Invalidate tools cache once after bulk update 

2660 if tools_updated > 0: 

2661 await cache.invalidate_tools() 

2662 tool_lookup_cache = _get_tool_lookup_cache() 

2663 await tool_lookup_cache.invalidate_gateway(str(gateway.id)) 

2664 

2665 # Bulk update prompts when gateway is deactivated/activated (skip for reachability-only updates) 

2666 prompts_updated = 0 

2667 if not only_update_reachable: 

2668 prompts_result = db.execute(update(DbPrompt).where(DbPrompt.gateway_id == gateway_id).where(DbPrompt.enabled != activate).values(enabled=activate, updated_at=now)) 

2669 prompts_updated = prompts_result.rowcount 

2670 if prompts_updated > 0: 

2671 db.commit() 

2672 await cache.invalidate_prompts() 

2673 

2674 # Bulk update resources when gateway is deactivated/activated (skip for reachability-only updates) 

2675 resources_updated = 0 

2676 if not only_update_reachable: 

2677 resources_result = db.execute(update(DbResource).where(DbResource.gateway_id == gateway_id).where(DbResource.enabled != activate).values(enabled=activate, updated_at=now)) 

2678 resources_updated = resources_result.rowcount 

2679 if resources_updated > 0: 

2680 db.commit() 

2681 await cache.invalidate_resources() 

2682 

2683 logger.debug(f"Gateway {gateway.name} bulk state update: {tools_updated} tools, {prompts_updated} prompts, {resources_updated} resources") 

2684 

2685 logger.info(f"Gateway status: {gateway.name} - {'enabled' if activate else 'disabled'} and {'accessible' if reachable else 'inaccessible'}") 

2686 

2687 # Structured logging: Audit trail for gateway state change 

2688 audit_trail.log_action( 

2689 user_id=user_email or "system", 

2690 action="set_gateway_state", 

2691 resource_type="gateway", 

2692 resource_id=str(gateway.id), 

2693 resource_name=gateway.name, 

2694 user_email=user_email, 

2695 team_id=gateway.team_id, 

2696 new_values={ 

2697 "enabled": gateway.enabled, 

2698 "reachable": gateway.reachable, 

2699 }, 

2700 context={ 

2701 "action": "activate" if activate else "deactivate", 

2702 "only_update_reachable": only_update_reachable, 

2703 }, 

2704 db=db, 

2705 ) 

2706 

2707 # Structured logging: Log successful gateway state change 

2708 structured_logger.log( 

2709 level="INFO", 

2710 message=f"Gateway {'activated' if activate else 'deactivated'} successfully", 

2711 event_type="gateway_state_changed", 

2712 component="gateway_service", 

2713 user_email=user_email, 

2714 team_id=gateway.team_id, 

2715 resource_type="gateway", 

2716 resource_id=str(gateway.id), 

2717 custom_fields={ 

2718 "gateway_name": gateway.name, 

2719 "enabled": gateway.enabled, 

2720 "reachable": gateway.reachable, 

2721 }, 

2722 db=db, 

2723 ) 

2724 

2725 return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)).masked() 

2726 

2727 except PermissionError as e: 

2728 # Structured logging: Log permission error 

2729 structured_logger.log( 

2730 level="WARNING", 

2731 message="Gateway state change failed due to permission error", 

2732 event_type="gateway_state_change_permission_denied", 

2733 component="gateway_service", 

2734 user_email=user_email, 

2735 resource_type="gateway", 

2736 resource_id=gateway_id, 

2737 error=e, 

2738 db=db, 

2739 ) 

2740 raise e 

2741 except Exception as e: 

2742 db.rollback() 

2743 

2744 # Structured logging: Log generic gateway state change failure 

2745 structured_logger.log( 

2746 level="ERROR", 

2747 message="Gateway state change failed", 

2748 event_type="gateway_state_change_failed", 

2749 component="gateway_service", 

2750 user_email=user_email, 

2751 resource_type="gateway", 

2752 resource_id=gateway_id, 

2753 error=e, 

2754 db=db, 

2755 ) 

2756 raise GatewayError(f"Failed to set gateway state: {str(e)}") 

2757 

2758 async def _notify_gateway_updated(self, gateway: DbGateway) -> None: 

2759 """ 

2760 Notify subscribers of gateway update. 

2761 

2762 Args: 

2763 gateway: Gateway to update 

2764 """ 

2765 event = { 

2766 "type": "gateway_updated", 

2767 "data": { 

2768 "id": gateway.id, 

2769 "name": gateway.name, 

2770 "url": gateway.url, 

2771 "description": gateway.description, 

2772 "enabled": gateway.enabled, 

2773 }, 

2774 "timestamp": datetime.now(timezone.utc).isoformat(), 

2775 } 

2776 await self._publish_event(event) 

2777 

2778 async def delete_gateway(self, db: Session, gateway_id: str, user_email: Optional[str] = None) -> None: 

2779 """ 

2780 Delete a gateway by its ID. 

2781 

2782 Args: 

2783 db: Database session 

2784 gateway_id: Gateway ID 

2785 user_email: Email of user performing deletion (for ownership check) 

2786 

2787 Raises: 

2788 GatewayNotFoundError: If the gateway is not found 

2789 PermissionError: If user doesn't own the gateway 

2790 GatewayError: For other deletion errors 

2791 

2792 Examples: 

2793 >>> from mcpgateway.services.gateway_service import GatewayService 

2794 >>> from unittest.mock import MagicMock 

2795 >>> service = GatewayService() 

2796 >>> db = MagicMock() 

2797 >>> gateway = MagicMock() 

2798 >>> db.execute.return_value.scalar_one_or_none.return_value = gateway 

2799 >>> db.delete = MagicMock() 

2800 >>> db.commit = MagicMock() 

2801 >>> service._notify_gateway_deleted = MagicMock() 

2802 >>> import asyncio 

2803 >>> try: 

2804 ... asyncio.run(service.delete_gateway(db, 'gateway_id', 'user@example.com')) 

2805 ... except Exception: 

2806 ... pass 

2807 >>> 

2808 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs 

2809 >>> asyncio.run(service._http_client.aclose()) 

2810 """ 

2811 try: 

2812 # Find gateway with eager loading for deletion to avoid N+1 queries 

2813 gateway = db.execute( 

2814 select(DbGateway) 

2815 .options( 

2816 selectinload(DbGateway.tools), 

2817 selectinload(DbGateway.resources), 

2818 selectinload(DbGateway.prompts), 

2819 joinedload(DbGateway.email_team), 

2820 ) 

2821 .where(DbGateway.id == gateway_id) 

2822 ).scalar_one_or_none() 

2823 

2824 if not gateway: 

2825 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") 

2826 

2827 # Check ownership if user_email provided 

2828 if user_email: 

2829 # First-Party 

2830 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel 

2831 

2832 permission_service = PermissionService(db) 

2833 if not await permission_service.check_resource_ownership(user_email, gateway): 

2834 raise PermissionError("Only the owner can delete this gateway") 

2835 

2836 # Store gateway info for notification before deletion 

2837 gateway_info = {"id": gateway.id, "name": gateway.name, "url": gateway.url} 

2838 gateway_name = gateway.name 

2839 gateway_team_id = gateway.team_id 

2840 gateway_url = gateway.url # Store URL before expiring the object 

2841 

2842 # Manually delete children first to avoid FK constraint violations 

2843 # (passive_deletes=True means ORM won't auto-cascade, we must do it explicitly) 

2844 # Use chunking to avoid SQLite's 999 parameter limit for IN clauses 

2845 tool_ids = [t.id for t in gateway.tools] 

2846 resource_ids = [r.id for r in gateway.resources] 

2847 prompt_ids = [p.id for p in gateway.prompts] 

2848 

2849 # Delete tool children and tools 

2850 if tool_ids: 

2851 for i in range(0, len(tool_ids), 500): 

2852 chunk = tool_ids[i : i + 500] 

2853 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk))) 

2854 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk))) 

2855 db.execute(delete(DbTool).where(DbTool.id.in_(chunk))) 

2856 

2857 # Delete resource children and resources 

2858 if resource_ids: 

2859 for i in range(0, len(resource_ids), 500): 

2860 chunk = resource_ids[i : i + 500] 

2861 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk))) 

2862 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk))) 

2863 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk))) 

2864 db.execute(delete(DbResource).where(DbResource.id.in_(chunk))) 

2865 

2866 # Delete prompt children and prompts 

2867 if prompt_ids: 

2868 for i in range(0, len(prompt_ids), 500): 

2869 chunk = prompt_ids[i : i + 500] 

2870 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk))) 

2871 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk))) 

2872 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk))) 

2873 

2874 # Expire gateway to clear cached relationships after bulk deletes 

2875 db.expire(gateway) 

2876 

2877 # Use DELETE with rowcount check for database-agnostic atomic delete 

2878 # (RETURNING is not supported on MySQL/MariaDB) 

2879 stmt = delete(DbGateway).where(DbGateway.id == gateway_id) 

2880 result = db.execute(stmt) 

2881 if result.rowcount == 0: 2881 ↛ 2883line 2881 didn't jump to line 2883 because the condition on line 2881 was never true

2882 # Gateway was already deleted by another concurrent request 

2883 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") 

2884 

2885 db.commit() 

2886 

2887 # Invalidate cache after successful deletion 

2888 cache = _get_registry_cache() 

2889 await cache.invalidate_gateways() 

2890 tool_lookup_cache = _get_tool_lookup_cache() 

2891 await tool_lookup_cache.invalidate_gateway(str(gateway_id)) 

2892 # Also invalidate tags cache since gateway tags may have changed 

2893 # First-Party 

2894 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel 

2895 

2896 await admin_stats_cache.invalidate_tags() 

2897 

2898 # Update tracking 

2899 self._active_gateways.discard(gateway_url) 

2900 

2901 # Notify subscribers 

2902 await self._notify_gateway_deleted(gateway_info) 

2903 

2904 logger.info(f"Permanently deleted gateway: {gateway_name}") 

2905 

2906 # Structured logging: Audit trail for gateway deletion 

2907 audit_trail.log_action( 

2908 user_id=user_email or "system", 

2909 action="delete_gateway", 

2910 resource_type="gateway", 

2911 resource_id=str(gateway_info["id"]), 

2912 resource_name=gateway_name, 

2913 user_email=user_email, 

2914 team_id=gateway_team_id, 

2915 old_values={ 

2916 "name": gateway_name, 

2917 "url": gateway_info["url"], 

2918 }, 

2919 db=db, 

2920 ) 

2921 

2922 # Structured logging: Log successful gateway deletion 

2923 structured_logger.log( 

2924 level="INFO", 

2925 message="Gateway deleted successfully", 

2926 event_type="gateway_deleted", 

2927 component="gateway_service", 

2928 user_email=user_email, 

2929 team_id=gateway_team_id, 

2930 resource_type="gateway", 

2931 resource_id=str(gateway_info["id"]), 

2932 custom_fields={ 

2933 "gateway_name": gateway_name, 

2934 "gateway_url": gateway_info["url"], 

2935 }, 

2936 db=db, 

2937 ) 

2938 

2939 except PermissionError as pe: 

2940 db.rollback() 

2941 

2942 # Structured logging: Log permission error 

2943 structured_logger.log( 

2944 level="WARNING", 

2945 message="Gateway deletion failed due to permission error", 

2946 event_type="gateway_delete_permission_denied", 

2947 component="gateway_service", 

2948 user_email=user_email, 

2949 resource_type="gateway", 

2950 resource_id=gateway_id, 

2951 error=pe, 

2952 db=db, 

2953 ) 

2954 raise 

2955 except Exception as e: 

2956 db.rollback() 

2957 

2958 # Structured logging: Log generic gateway deletion failure 

2959 structured_logger.log( 

2960 level="ERROR", 

2961 message="Gateway deletion failed", 

2962 event_type="gateway_deletion_failed", 

2963 component="gateway_service", 

2964 user_email=user_email, 

2965 resource_type="gateway", 

2966 resource_id=gateway_id, 

2967 error=e, 

2968 db=db, 

2969 ) 

2970 raise GatewayError(f"Failed to delete gateway: {str(e)}") 

2971 

2972 async def forward_request( 

2973 self, 

2974 gateway_or_db, 

2975 method: str, 

2976 params: Optional[Dict[str, Any]] = None, 

2977 app_user_email: Optional[str] = None, 

2978 user_email: Optional[str] = None, 

2979 token_teams: Optional[List[str]] = None, 

2980 ) -> Any: # noqa: F811 # pylint: disable=function-redefined 

2981 """ 

2982 Forward a request to a gateway or multiple gateways. 

2983 

2984 This method handles two calling patterns: 

2985 1. forward_request(gateway, method, params) - Forward to a specific gateway 

2986 2. forward_request(db, method, params) - Forward to active gateways in the database 

2987 

2988 Args: 

2989 gateway_or_db: Either a DbGateway object or database Session 

2990 method: RPC method name 

2991 params: Optional method parameters 

2992 app_user_email: Optional app user email for OAuth token selection 

2993 user_email: Optional user email for team-based access control 

2994 token_teams: Optional list of team IDs from the token (None=unrestricted, []=public-only) 

2995 

2996 Returns: 

2997 Gateway response 

2998 

2999 Raises: 

3000 GatewayConnectionError: If forwarding fails 

3001 GatewayError: If gateway gave an error 

3002 """ 

3003 # Dispatch based on first parameter type 

3004 if hasattr(gateway_or_db, "execute"): 

3005 # This is a database session - forward to all active gateways 

3006 return await self._forward_request_to_all(gateway_or_db, method, params, app_user_email, user_email, token_teams) 

3007 # This is a gateway object - forward to specific gateway 

3008 return await self._forward_request_to_gateway(gateway_or_db, method, params, app_user_email) 

3009 

3010 async def _forward_request_to_gateway(self, gateway: DbGateway, method: str, params: Optional[Dict[str, Any]] = None, app_user_email: Optional[str] = None) -> Any: 

3011 """ 

3012 Forward a request to a specific gateway. 

3013 

3014 Args: 

3015 gateway: Gateway to forward to 

3016 method: RPC method name 

3017 params: Optional method parameters 

3018 app_user_email: Optional app user email for OAuth token selection 

3019 

3020 Returns: 

3021 Gateway response 

3022 

3023 Raises: 

3024 GatewayConnectionError: If forwarding fails 

3025 GatewayError: If gateway gave an error 

3026 """ 

3027 start_time = time.monotonic() 

3028 

3029 # Create trace span for gateway federation 

3030 with create_span( 

3031 "gateway.forward_request", 

3032 { 

3033 "gateway.name": gateway.name, 

3034 "gateway.id": str(gateway.id), 

3035 "gateway.url": gateway.url, 

3036 "rpc.method": method, 

3037 "rpc.service": "mcp-gateway", 

3038 "http.method": "POST", 

3039 "http.url": urljoin(gateway.url, "/rpc"), 

3040 "peer.service": gateway.name, 

3041 }, 

3042 ) as span: 

3043 if not gateway.enabled: 

3044 raise GatewayConnectionError(f"Cannot forward request to inactive gateway: {gateway.name}") 

3045 

3046 response = None # Initialize response to avoid UnboundLocalError 

3047 try: 

3048 # Build RPC request 

3049 request: Dict[str, Any] = {"jsonrpc": "2.0", "id": 1, "method": method} 

3050 if params: 

3051 request["params"] = params 

3052 if span: 3052 ↛ 3053line 3052 didn't jump to line 3053 because the condition on line 3052 was never true

3053 span.set_attribute("rpc.params_count", len(params)) 

3054 

3055 # Handle OAuth authentication for the specific gateway 

3056 headers: Dict[str, str] = {} 

3057 

3058 if getattr(gateway, "auth_type", None) == "oauth" and gateway.oauth_config: 

3059 try: 

3060 grant_type = gateway.oauth_config.get("grant_type", "client_credentials") 

3061 

3062 if grant_type == "client_credentials": 

3063 # Use OAuth manager to get access token for Client Credentials flow 

3064 access_token = await self.oauth_manager.get_access_token(gateway.oauth_config) 

3065 headers = {"Authorization": f"Bearer {access_token}"} 

3066 elif grant_type == "authorization_code": 3066 ↛ 3108line 3066 didn't jump to line 3108 because the condition on line 3066 was always true

3067 # For Authorization Code flow, try to get a stored token 

3068 if not app_user_email: 

3069 logger.warning(f"Skipping OAuth authorization code gateway {gateway.name} - user-specific tokens required but no user email provided") 

3070 raise GatewayConnectionError(f"OAuth authorization code gateway {gateway.name} requires user context") 

3071 

3072 # First-Party 

3073 from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel 

3074 

3075 # Get database session (this is a bit hacky but necessary for now) 

3076 db = next(get_db()) 

3077 try: 

3078 token_storage = TokenStorageService(db) 

3079 access_token = await token_storage.get_user_token(str(gateway.id), app_user_email) 

3080 if access_token: 

3081 headers = {"Authorization": f"Bearer {access_token}"} 

3082 else: 

3083 raise GatewayConnectionError(f"No valid OAuth token for user {app_user_email} and gateway {gateway.name}") 

3084 finally: 

3085 # Ensure close() always runs even if commit() fails 

3086 # Without this nested try/finally, a commit() failure (e.g., PgBouncer timeout) 

3087 # would skip close(), leaving the connection in "idle in transaction" state 

3088 try: 

3089 db.commit() # End read-only transaction cleanly before returning to pool 

3090 finally: 

3091 db.close() 

3092 except Exception as oauth_error: 

3093 raise GatewayConnectionError(f"Failed to obtain OAuth token for gateway {gateway.name}: {oauth_error}") 

3094 else: 

3095 # Handle non-OAuth authentication 

3096 auth_data = gateway.auth_value or {} 

3097 if isinstance(auth_data, str) and auth_data: 

3098 headers = decode_auth(auth_data) 

3099 elif isinstance(auth_data, dict) and auth_data: 

3100 headers = {str(k): str(v) for k, v in auth_data.items()} 

3101 else: 

3102 # No auth configured - send request without authentication 

3103 # SECURITY: Never send gateway admin credentials to remote servers 

3104 logger.warning(f"Gateway {gateway.name} has no authentication configured - sending unauthenticated request") 

3105 headers = {"Content-Type": "application/json"} 

3106 

3107 # Directly use the persistent HTTP client (no async with) 

3108 response = await self._http_client.post(urljoin(gateway.url, "/rpc"), json=request, headers=headers) 

3109 response.raise_for_status() 

3110 result = response.json() 

3111 

3112 # Update last seen timestamp using fresh DB session 

3113 # (gateway object may be detached from original session) 

3114 try: 

3115 with fresh_db_session() as update_db: 

3116 db_gateway = update_db.execute(select(DbGateway).where(DbGateway.id == gateway.id)).scalar_one_or_none() 

3117 if db_gateway: 

3118 db_gateway.last_seen = datetime.now(timezone.utc) 

3119 update_db.commit() 

3120 except Exception as update_error: 

3121 logger.warning(f"Failed to update last_seen for gateway {gateway.name}: {update_error}") 

3122 

3123 # Record success metrics 

3124 if span: 3124 ↛ 3125line 3124 didn't jump to line 3125 because the condition on line 3124 was never true

3125 span.set_attribute("http.status_code", response.status_code) 

3126 span.set_attribute("success", True) 

3127 span.set_attribute("duration.ms", (time.monotonic() - start_time) * 1000) 

3128 

3129 except Exception: 

3130 if span: 3130 ↛ 3131line 3130 didn't jump to line 3131 because the condition on line 3130 was never true

3131 span.set_attribute("http.status_code", getattr(response, "status_code", 0)) 

3132 raise GatewayConnectionError(f"Failed to forward request to {gateway.name}") 

3133 

3134 if "error" in result: 

3135 if span: 3135 ↛ 3136line 3135 didn't jump to line 3136 because the condition on line 3135 was never true

3136 span.set_attribute("rpc.error", True) 

3137 span.set_attribute("rpc.error.message", result["error"].get("message", "Unknown error")) 

3138 raise GatewayError(f"Gateway error: {result['error'].get('message')}") 

3139 

3140 return result.get("result") 

3141 

3142 async def _forward_request_to_all( 

3143 self, 

3144 db: Session, 

3145 method: str, 

3146 params: Optional[Dict[str, Any]] = None, 

3147 app_user_email: Optional[str] = None, 

3148 user_email: Optional[str] = None, 

3149 token_teams: Optional[List[str]] = None, 

3150 ) -> Any: 

3151 """ 

3152 Forward a request to all active gateways that can handle the method. 

3153 

3154 Args: 

3155 db: Database session 

3156 method: RPC method name 

3157 params: Optional method parameters 

3158 app_user_email: Optional app user email for OAuth token selection 

3159 user_email: Optional user email for team-based access control 

3160 token_teams: Optional list of team IDs from the token (None=unrestricted, []=public-only) 

3161 

3162 Returns: 

3163 Gateway response from the first successful gateway 

3164 

3165 Raises: 

3166 GatewayConnectionError: If no gateways can handle the request 

3167 """ 

3168 # ═══════════════════════════════════════════════════════════════════════════ 

3169 # PHASE 1: Fetch all required data before HTTP calls 

3170 # ═══════════════════════════════════════════════════════════════════════════ 

3171 

3172 # SECURITY: Apply team-based access control to gateway selection 

3173 # - token_teams is None: admin bypass (is_admin=true with explicit null teams) - sees all 

3174 # - token_teams is []: public-only access (missing teams or explicit empty) 

3175 # - token_teams is [...]: access to public + specified teams 

3176 query = select(DbGateway).where(DbGateway.enabled.is_(True)) 

3177 

3178 if token_teams is not None: 

3179 if len(token_teams) == 0: 

3180 # Public-only token: only access public gateways 

3181 query = query.where(DbGateway.visibility == "public") 

3182 else: 

3183 # Team-scoped token: public gateways + gateways in allowed teams 

3184 access_conditions = [ 

3185 DbGateway.visibility == "public", 

3186 and_(DbGateway.team_id.in_(token_teams), DbGateway.visibility.in_(["team", "public"])), 

3187 ] 

3188 # Also include private gateways owned by this user 

3189 if user_email: 3189 ↛ 3191line 3189 didn't jump to line 3191 because the condition on line 3189 was always true

3190 access_conditions.append(and_(DbGateway.owner_email == user_email, DbGateway.visibility == "private")) 

3191 query = query.where(or_(*access_conditions)) 

3192 

3193 active_gateways = db.execute(query).scalars().all() 

3194 

3195 if not active_gateways: 

3196 raise GatewayConnectionError("No active gateways available to forward request") 

3197 

3198 # Extract all gateway data to local variables before releasing DB connection 

3199 gateway_data_list: List[Dict[str, Any]] = [] 

3200 for gateway in active_gateways: 

3201 gw_data = { 

3202 "id": gateway.id, 

3203 "name": gateway.name, 

3204 "url": gateway.url, 

3205 "auth_type": getattr(gateway, "auth_type", None), 

3206 "auth_value": gateway.auth_value, 

3207 "oauth_config": gateway.oauth_config if hasattr(gateway, "oauth_config") else None, 

3208 } 

3209 gateway_data_list.append(gw_data) 

3210 

3211 # For OAuth authorization_code flow, we need to fetch tokens while session is open 

3212 # First-Party 

3213 from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel 

3214 

3215 for gw_data in gateway_data_list: 

3216 if gw_data["auth_type"] == "oauth" and gw_data["oauth_config"]: 

3217 grant_type = gw_data["oauth_config"].get("grant_type", "client_credentials") 

3218 if grant_type == "authorization_code" and app_user_email: 

3219 try: 

3220 token_storage = TokenStorageService(db) 

3221 access_token = await token_storage.get_user_token(str(gw_data["id"]), app_user_email) 

3222 gw_data["_oauth_token"] = access_token 

3223 except Exception as e: 

3224 logger.warning(f"Failed to get OAuth token for gateway {gw_data['name']}: {e}") 

3225 gw_data["_oauth_token"] = None 

3226 

3227 # ═══════════════════════════════════════════════════════════════════════════ 

3228 # CRITICAL: Release DB connection back to pool BEFORE making HTTP calls 

3229 # This prevents connection pool exhaustion during slow upstream requests. 

3230 # ═══════════════════════════════════════════════════════════════════════════ 

3231 db.commit() # End read-only transaction cleanly (commit not rollback to avoid inflating rollback stats) 

3232 db.close() 

3233 

3234 errors: List[str] = [] 

3235 

3236 # ═══════════════════════════════════════════════════════════════════════════ 

3237 # PHASE 2: Make HTTP calls (no DB connection held) 

3238 # ═══════════════════════════════════════════════════════════════════════════ 

3239 for gw_data in gateway_data_list: 

3240 try: 

3241 # Handle OAuth authentication for the specific gateway 

3242 headers: Dict[str, str] = {} 

3243 

3244 if gw_data["auth_type"] == "oauth" and gw_data["oauth_config"]: 

3245 try: 

3246 grant_type = gw_data["oauth_config"].get("grant_type", "client_credentials") 

3247 

3248 if grant_type == "client_credentials": 

3249 # Use OAuth manager to get access token for Client Credentials flow 

3250 access_token = await self.oauth_manager.get_access_token(gw_data["oauth_config"]) 

3251 headers = {"Authorization": f"Bearer {access_token}"} 

3252 elif grant_type == "authorization_code": 3252 ↛ 3279line 3252 didn't jump to line 3279 because the condition on line 3252 was always true

3253 # For Authorization Code flow, use pre-fetched token 

3254 if not app_user_email: 

3255 logger.warning(f"Skipping OAuth authorization code gateway {gw_data['name']} - user-specific tokens required but no user email provided") 

3256 continue 

3257 

3258 access_token = gw_data.get("_oauth_token") 

3259 if access_token: 3259 ↛ 3262line 3259 didn't jump to line 3262 because the condition on line 3259 was always true

3260 headers = {"Authorization": f"Bearer {access_token}"} 

3261 else: 

3262 logger.warning(f"No valid OAuth token for user {app_user_email} and gateway {gw_data['name']}") 

3263 continue 

3264 except Exception as oauth_error: 

3265 logger.warning(f"Failed to obtain OAuth token for gateway {gw_data['name']}: {oauth_error}") 

3266 errors.append(f"Gateway {gw_data['name']}: OAuth error - {str(oauth_error)}") 

3267 continue 

3268 else: 

3269 # Handle non-OAuth authentication 

3270 auth_data = gw_data["auth_value"] or {} 

3271 if isinstance(auth_data, str): 

3272 headers = decode_auth(auth_data) 

3273 elif isinstance(auth_data, dict): 3273 ↛ 3276line 3273 didn't jump to line 3276 because the condition on line 3273 was always true

3274 headers = {str(k): str(v) for k, v in auth_data.items()} 

3275 else: 

3276 headers = {} 

3277 

3278 # Build RPC request 

3279 request: Dict[str, Any] = {"jsonrpc": "2.0", "id": 1, "method": method} 

3280 if params: 3280 ↛ 3281line 3280 didn't jump to line 3281 because the condition on line 3280 was never true

3281 request["params"] = params 

3282 

3283 # Forward request with proper authentication headers 

3284 response = await self._http_client.post(urljoin(gw_data["url"], "/rpc"), json=request, headers=headers) 

3285 response.raise_for_status() 

3286 result = response.json() 

3287 

3288 # Check for RPC errors 

3289 if "error" in result: 

3290 errors.append(f"Gateway {gw_data['name']}: {result['error'].get('message', 'Unknown RPC error')}") 

3291 continue 

3292 

3293 # ═══════════════════════════════════════════════════════════════════════════ 

3294 # PHASE 3: Update last_seen using fresh DB session 

3295 # ═══════════════════════════════════════════════════════════════════════════ 

3296 try: 

3297 with fresh_db_session() as update_db: 

3298 db_gateway = update_db.execute(select(DbGateway).where(DbGateway.id == gw_data["id"])).scalar_one_or_none() 

3299 if db_gateway: 3299 ↛ 3306line 3299 didn't jump to line 3306

3300 db_gateway.last_seen = datetime.now(timezone.utc) 

3301 update_db.commit() 

3302 except Exception as update_error: 

3303 logger.warning(f"Failed to update last_seen for gateway {gw_data['name']}: {update_error}") 

3304 

3305 # Success - return the result 

3306 logger.info(f"Successfully forwarded request to gateway {gw_data['name']}") 

3307 return result.get("result") 

3308 

3309 except Exception as e: 

3310 error_msg = f"Gateway {gw_data['name']}: {str(e)}" 

3311 errors.append(error_msg) 

3312 logger.warning(f"Failed to forward request to gateway {gw_data['name']}: {e}") 

3313 continue 

3314 

3315 # If we get here, all gateways failed 

3316 error_summary = "; ".join(errors) 

3317 raise GatewayConnectionError(f"All gateways failed to handle request '{method}': {error_summary}") 

3318 

3319 async def _handle_gateway_failure(self, gateway: DbGateway) -> None: 

3320 """Tracks and handles gateway failures during health checks. 

3321 If the failure count exceeds the threshold, the gateway is deactivated. 

3322 

3323 Args: 

3324 gateway: The gateway object that failed its health check. 

3325 

3326 Returns: 

3327 None 

3328 

3329 Examples: 

3330 >>> from mcpgateway.services.gateway_service import GatewayService 

3331 >>> service = GatewayService() 

3332 >>> gateway = type('Gateway', (), { 

3333 ... 'id': 'gw1', 'name': 'test_gw', 'enabled': True, 'reachable': True 

3334 ... })() 

3335 >>> service._gateway_failure_counts = {} 

3336 >>> import asyncio 

3337 >>> # Test failure counting 

3338 >>> asyncio.run(service._handle_gateway_failure(gateway)) # doctest: +ELLIPSIS 

3339 >>> service._gateway_failure_counts['gw1'] >= 1 

3340 True 

3341 

3342 >>> # Test disabled gateway (no action) 

3343 >>> gateway.enabled = False 

3344 >>> old_count = service._gateway_failure_counts.get('gw1', 0) 

3345 >>> asyncio.run(service._handle_gateway_failure(gateway)) # doctest: +ELLIPSIS 

3346 >>> service._gateway_failure_counts.get('gw1', 0) == old_count 

3347 True 

3348 """ 

3349 if GW_FAILURE_THRESHOLD == -1: 

3350 return # Gateway failure action disabled 

3351 

3352 if not gateway.enabled: 

3353 return # No action needed for inactive gateways 

3354 

3355 if not gateway.reachable: 

3356 return # No action needed for unreachable gateways 

3357 

3358 count = self._gateway_failure_counts.get(gateway.id, 0) + 1 

3359 self._gateway_failure_counts[gateway.id] = count 

3360 

3361 logger.warning(f"Gateway {gateway.name} failed health check {count} time(s).") 

3362 

3363 if count >= GW_FAILURE_THRESHOLD: 3363 ↛ 3364line 3363 didn't jump to line 3364 because the condition on line 3363 was never true

3364 logger.error(f"Gateway {gateway.name} failed {GW_FAILURE_THRESHOLD} times. Deactivating...") 

3365 with cast(Any, SessionLocal)() as db: 

3366 await self.set_gateway_state(db, gateway.id, activate=True, reachable=False, only_update_reachable=True) 

3367 self._gateway_failure_counts[gateway.id] = 0 # Reset after deactivation 

3368 

3369 async def check_health_of_gateways(self, gateways: List[DbGateway], user_email: Optional[str] = None) -> bool: 

3370 """Check health of a batch of gateways. 

3371 

3372 Performs an asynchronous health-check for each gateway in `gateways` using 

3373 an Async HTTP client. The function handles different authentication 

3374 modes (OAuth client_credentials and authorization_code, and non-OAuth 

3375 auth headers). When a gateway uses the authorization_code flow, the 

3376 optional `user_email` is used to look up stored user tokens with 

3377 fresh_db_session(). On individual failures the service will record the 

3378 failure and call internal failure handling which may mark a gateway 

3379 unreachable or deactivate it after repeated failures. If a previously 

3380 unreachable gateway becomes healthy again the service will attempt to 

3381 update its reachable status. 

3382 

3383 NOTE: This method intentionally does NOT take a db parameter. 

3384 DB access uses fresh_db_session() only when needed, avoiding holding 

3385 connections during HTTP calls to MCP servers. 

3386 

3387 Args: 

3388 gateways: List of DbGateway objects to check. 

3389 user_email: Optional MCP gateway user email used to retrieve 

3390 stored OAuth tokens for gateways using the 

3391 "authorization_code" grant type. If not provided, authorization 

3392 code flows that require a user token will be treated as failed. 

3393 

3394 Returns: 

3395 bool: True when the health-check batch completes. This return 

3396 value indicates completion of the checks, not that every gateway 

3397 was healthy. Individual gateway failures are handled internally 

3398 (via _handle_gateway_failure and status updates). 

3399 

3400 Examples: 

3401 >>> from mcpgateway.services.gateway_service import GatewayService 

3402 >>> from unittest.mock import MagicMock 

3403 >>> service = GatewayService() 

3404 >>> gateways = [MagicMock()] 

3405 >>> gateways[0].ca_certificate = None 

3406 >>> import asyncio 

3407 >>> result = asyncio.run(service.check_health_of_gateways(gateways)) 

3408 >>> isinstance(result, bool) 

3409 True 

3410 

3411 >>> # Test empty gateway list 

3412 >>> empty_result = asyncio.run(service.check_health_of_gateways([])) 

3413 >>> empty_result 

3414 True 

3415 

3416 >>> # Test multiple gateways (basic smoke) 

3417 >>> multiple_gateways = [MagicMock(), MagicMock(), MagicMock()] 

3418 >>> for i, gw in enumerate(multiple_gateways): 

3419 ... gw.name = f"gateway_{i}" 

3420 ... gw.url = f"http://gateway{i}.example.com" 

3421 ... gw.transport = "SSE" 

3422 ... gw.enabled = True 

3423 ... gw.reachable = True 

3424 ... gw.auth_value = {} 

3425 ... gw.ca_certificate = None 

3426 >>> multi_result = asyncio.run(service.check_health_of_gateways(multiple_gateways)) 

3427 >>> isinstance(multi_result, bool) 

3428 True 

3429 """ 

3430 start_time = time.monotonic() 

3431 concurrency_limit = min(settings.max_concurrent_health_checks, max(10, os.cpu_count() * 5)) # adaptive concurrency 

3432 semaphore = asyncio.Semaphore(concurrency_limit) 

3433 

3434 async def limited_check(gateway: DbGateway): 

3435 """ 

3436 Checks the health of a single gateway while respecting a concurrency limit. 

3437 

3438 This function checks the health of the given database gateway, ensuring that 

3439 the number of concurrent checks does not exceed a predefined limit. The check 

3440 is performed asynchronously and uses a semaphore to manage concurrency. 

3441 

3442 Args: 

3443 gateway (DbGateway): The database gateway whose health is to be checked. 

3444 

3445 Raises: 

3446 Any exceptions raised during the health check will be propagated to the caller. 

3447 """ 

3448 async with semaphore: 

3449 try: 

3450 await asyncio.wait_for( 

3451 self._check_single_gateway_health(gateway, user_email), 

3452 timeout=settings.gateway_health_check_timeout, 

3453 ) 

3454 except asyncio.TimeoutError: 

3455 logger.warning(f"Gateway {getattr(gateway, 'name', 'unknown')} health check timed out after {settings.gateway_health_check_timeout}s") 

3456 # Treat timeout as a failed health check 

3457 await self._handle_gateway_failure(gateway) 

3458 

3459 # Create trace span for health check batch 

3460 with create_span("gateway.health_check_batch", {"gateway.count": len(gateways), "check.type": "health"}) as batch_span: 

3461 # Chunk processing to avoid overload 

3462 if not gateways: 

3463 return True 

3464 chunk_size = concurrency_limit 

3465 for i in range(0, len(gateways), chunk_size): 

3466 # batch will be a sublist of gateways from index i to i + chunk_size 

3467 batch = gateways[i : i + chunk_size] 

3468 

3469 # Each task is a health check for a gateway in the batch, excluding those with auth_type == "one_time_auth" 

3470 tasks = [limited_check(gw) for gw in batch if gw.auth_type != "one_time_auth"] 

3471 

3472 # Execute all health checks concurrently 

3473 await asyncio.gather(*tasks, return_exceptions=True) 

3474 await asyncio.sleep(0.05) # small pause prevents network saturation 

3475 

3476 elapsed = time.monotonic() - start_time 

3477 

3478 if batch_span: 

3479 batch_span.set_attribute("check.duration_ms", int(elapsed * 1000)) 

3480 batch_span.set_attribute("check.completed", True) 

3481 

3482 logger.debug(f"Health check batch completed for {len(gateways)} gateways in {elapsed:.2f}s") 

3483 

3484 return True 

3485 

3486 async def _check_single_gateway_health(self, gateway: DbGateway, user_email: Optional[str] = None) -> None: 

3487 """Check health of a single gateway. 

3488 

3489 NOTE: This method intentionally does NOT take a db parameter. 

3490 DB access uses fresh_db_session() only when needed, avoiding holding 

3491 connections during HTTP calls to MCP servers. 

3492 

3493 Args: 

3494 gateway: Gateway to check (may be detached from session) 

3495 user_email: Optional user email for OAuth token lookup 

3496 """ 

3497 # Extract gateway data upfront (gateway may be detached from session) 

3498 gateway_id = gateway.id 

3499 gateway_name = gateway.name 

3500 gateway_url = gateway.url 

3501 gateway_transport = gateway.transport 

3502 gateway_enabled = gateway.enabled 

3503 gateway_reachable = gateway.reachable 

3504 gateway_ca_certificate = gateway.ca_certificate 

3505 gateway_ca_certificate_sig = gateway.ca_certificate_sig 

3506 gateway_auth_type = gateway.auth_type 

3507 gateway_oauth_config = gateway.oauth_config 

3508 gateway_auth_value = gateway.auth_value 

3509 gateway_auth_query_params = gateway.auth_query_params 

3510 

3511 # Handle query_param auth - decrypt and apply to URL for health check 

3512 auth_query_params_decrypted: Optional[Dict[str, str]] = None 

3513 if gateway_auth_type == "query_param" and gateway_auth_query_params: 

3514 auth_query_params_decrypted = {} 

3515 for param_key, encrypted_value in gateway_auth_query_params.items(): 

3516 if encrypted_value: 3516 ↛ 3515line 3516 didn't jump to line 3515 because the condition on line 3516 was always true

3517 try: 

3518 decrypted = decode_auth(encrypted_value) 

3519 auth_query_params_decrypted[param_key] = decrypted.get(param_key, "") 

3520 except Exception: 

3521 logger.debug(f"Failed to decrypt query param '{param_key}' for health check") 

3522 if auth_query_params_decrypted: 3522 ↛ 3526line 3522 didn't jump to line 3526 because the condition on line 3522 was always true

3523 gateway_url = apply_query_param_auth(gateway_url, auth_query_params_decrypted) 

3524 

3525 # Sanitize URL for logging/telemetry (redacts sensitive query params) 

3526 gateway_url_sanitized = sanitize_url_for_logging(gateway_url, auth_query_params_decrypted) 

3527 

3528 # Create span for individual gateway health check 

3529 with create_span( 

3530 "gateway.health_check", 

3531 { 

3532 "gateway.name": gateway_name, 

3533 "gateway.id": str(gateway_id), 

3534 "gateway.url": gateway_url_sanitized, 

3535 "gateway.transport": gateway_transport, 

3536 "gateway.enabled": gateway_enabled, 

3537 "http.method": "GET", 

3538 "http.url": gateway_url_sanitized, 

3539 }, 

3540 ) as span: 

3541 valid = False 

3542 if gateway_ca_certificate: 

3543 if settings.enable_ed25519_signing: 3543 ↛ 3547line 3543 didn't jump to line 3547 because the condition on line 3543 was always true

3544 public_key_pem = settings.ed25519_public_key 

3545 valid = validate_signature(gateway_ca_certificate.encode(), gateway_ca_certificate_sig, public_key_pem) 

3546 else: 

3547 valid = True 

3548 if valid: 

3549 ssl_context = self.create_ssl_context(gateway_ca_certificate) 

3550 else: 

3551 ssl_context = None 

3552 

3553 def get_httpx_client_factory( 

3554 headers: dict[str, str] | None = None, 

3555 timeout: httpx.Timeout | None = None, 

3556 auth: httpx.Auth | None = None, 

3557 ) -> httpx.AsyncClient: 

3558 """Factory function to create httpx.AsyncClient with optional CA certificate. 

3559 

3560 Args: 

3561 headers: Optional headers for the client 

3562 timeout: Optional timeout for the client 

3563 auth: Optional auth for the client 

3564 

3565 Returns: 

3566 httpx.AsyncClient: Configured HTTPX async client 

3567 """ 

3568 return httpx.AsyncClient( 

3569 verify=ssl_context if ssl_context else get_default_verify(), 

3570 follow_redirects=True, 

3571 headers=headers, 

3572 timeout=timeout if timeout else get_http_timeout(), 

3573 auth=auth, 

3574 limits=httpx.Limits( 

3575 max_connections=settings.httpx_max_connections, 

3576 max_keepalive_connections=settings.httpx_max_keepalive_connections, 

3577 keepalive_expiry=settings.httpx_keepalive_expiry, 

3578 ), 

3579 ) 

3580 

3581 # Use isolated client for gateway health checks (each gateway may have custom CA cert) 

3582 # Use admin timeout for health checks (fail fast, don't wait 120s for slow upstreams) 

3583 # Pass ssl_context if present, otherwise let get_isolated_http_client use skip_ssl_verify setting 

3584 async with get_isolated_http_client(timeout=settings.httpx_admin_read_timeout, verify=ssl_context) as client: 

3585 logger.debug(f"Checking health of gateway: {gateway_name} ({gateway_url_sanitized})") 

3586 try: 

3587 # Handle different authentication types 

3588 headers = {} 

3589 

3590 if gateway_auth_type == "oauth" and gateway_oauth_config: 

3591 grant_type = gateway_oauth_config.get("grant_type", "client_credentials") 

3592 

3593 if grant_type == "authorization_code": 

3594 # For Authorization Code flow, try to get stored tokens 

3595 try: 

3596 # First-Party 

3597 from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel 

3598 

3599 # Use fresh session for OAuth token lookup 

3600 with fresh_db_session() as token_db: 

3601 token_storage = TokenStorageService(token_db) 

3602 

3603 # Get user-specific OAuth token 

3604 if not user_email: 

3605 if span: 3605 ↛ 3608line 3605 didn't jump to line 3608 because the condition on line 3605 was always true

3606 span.set_attribute("health.status", "unhealthy") 

3607 span.set_attribute("error.message", "User email required for OAuth token") 

3608 await self._handle_gateway_failure(gateway) 

3609 return 

3610 

3611 access_token = await token_storage.get_user_token(gateway_id, user_email) 

3612 

3613 if access_token: 

3614 headers["Authorization"] = f"Bearer {access_token}" 

3615 else: 

3616 if span: 3616 ↛ 3619line 3616 didn't jump to line 3619 because the condition on line 3616 was always true

3617 span.set_attribute("health.status", "unhealthy") 

3618 span.set_attribute("error.message", "No valid OAuth token for user") 

3619 await self._handle_gateway_failure(gateway) 

3620 return 

3621 except Exception as e: 

3622 logger.error(f"Failed to obtain stored OAuth token for gateway {gateway_name}: {e}") 

3623 if span: 3623 ↛ 3626line 3623 didn't jump to line 3626 because the condition on line 3623 was always true

3624 span.set_attribute("health.status", "unhealthy") 

3625 span.set_attribute("error.message", "Failed to obtain stored OAuth token") 

3626 await self._handle_gateway_failure(gateway) 

3627 return 

3628 else: 

3629 # For Client Credentials flow, get token directly 

3630 try: 

3631 access_token = await self.oauth_manager.get_access_token(gateway_oauth_config) 

3632 headers["Authorization"] = f"Bearer {access_token}" 

3633 except Exception as e: 

3634 if span: 3634 ↛ 3637line 3634 didn't jump to line 3637 because the condition on line 3634 was always true

3635 span.set_attribute("health.status", "unhealthy") 

3636 span.set_attribute("error.message", str(e)) 

3637 await self._handle_gateway_failure(gateway) 

3638 return 

3639 else: 

3640 # Handle non-OAuth authentication (existing logic) 

3641 auth_data = gateway_auth_value or {} 

3642 if isinstance(auth_data, str): 

3643 headers = decode_auth(auth_data) 

3644 elif isinstance(auth_data, dict): 

3645 headers = {str(k): str(v) for k, v in auth_data.items()} 

3646 else: 

3647 headers = {} 

3648 

3649 # Perform the GET and raise on 4xx/5xx 

3650 if (gateway_transport).lower() == "sse": 

3651 timeout = httpx.Timeout(settings.health_check_timeout) 

3652 async with client.stream("GET", gateway_url, headers=headers, timeout=timeout) as response: 

3653 # This will raise immediately if status is 4xx/5xx 

3654 response.raise_for_status() 

3655 if span: 3655 ↛ 3698line 3655 didn't jump to line 3698

3656 span.set_attribute("http.status_code", response.status_code) 

3657 elif (gateway_transport).lower() == "streamablehttp": 

3658 # Use session pool if enabled for faster health checks 

3659 use_pool = False 

3660 pool = None 

3661 if settings.mcp_session_pool_enabled: 3661 ↛ 3669line 3661 didn't jump to line 3669 because the condition on line 3661 was always true

3662 try: 

3663 pool = get_mcp_session_pool() 

3664 use_pool = True 

3665 except RuntimeError: 

3666 # Pool not initialized (e.g., in tests), fall back to per-call sessions 

3667 pass 

3668 

3669 if use_pool and pool is not None: 

3670 # Health checks are system operations, not user-driven. 

3671 # Use system identity to isolate from user sessions. 

3672 async with pool.session( 

3673 url=gateway_url, 

3674 headers=headers, 

3675 transport_type=TransportType.STREAMABLE_HTTP, 

3676 httpx_client_factory=get_httpx_client_factory, 

3677 user_identity="_system_health_check", 

3678 gateway_id=gateway_id, 

3679 ) as pooled: 

3680 # Optional explicit RPC verification (off by default for performance). 

3681 # Pool's internal staleness check handles health via _validate_session. 

3682 if settings.mcp_session_pool_explicit_health_rpc: 3682 ↛ 3698line 3682 didn't jump to line 3698

3683 await asyncio.wait_for( 

3684 pooled.session.list_tools(), 

3685 timeout=settings.health_check_timeout, 

3686 ) 

3687 else: 

3688 async with streamablehttp_client(url=gateway_url, headers=headers, timeout=settings.health_check_timeout, httpx_client_factory=get_httpx_client_factory) as ( 

3689 read_stream, 

3690 write_stream, 

3691 _get_session_id, 

3692 ): 

3693 async with ClientSession(read_stream, write_stream) as session: 

3694 # Initialize the session 

3695 response = await session.initialize() 

3696 

3697 # Reactivate gateway if it was previously inactive and health check passed now 

3698 if gateway_enabled and not gateway_reachable: 

3699 logger.info(f"Reactivating gateway: {gateway_name}, as it is healthy now") 

3700 with cast(Any, SessionLocal)() as status_db: 

3701 await self.set_gateway_state(status_db, gateway_id, activate=True, reachable=True, only_update_reachable=True) 

3702 

3703 # Update last_seen with fresh session (gateway object is detached) 

3704 try: 

3705 with fresh_db_session() as update_db: 

3706 db_gateway = update_db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none() 

3707 if db_gateway: 3707 ↛ 3714line 3707 didn't jump to line 3714

3708 db_gateway.last_seen = datetime.now(timezone.utc) 

3709 update_db.commit() 

3710 except Exception as update_error: 

3711 logger.warning(f"Failed to update last_seen for gateway {gateway_name}: {update_error}") 

3712 

3713 # Auto-refresh tools/resources/prompts if enabled 

3714 if settings.auto_refresh_servers: 

3715 try: 

3716 # Throttling: Check if refresh is needed based on last_refresh_at 

3717 refresh_needed = True 

3718 if gateway.last_refresh_at: 3718 ↛ 3736line 3718 didn't jump to line 3736 because the condition on line 3718 was always true

3719 # Default to config value if configured interval is missing 

3720 

3721 last_refresh = gateway.last_refresh_at 

3722 if last_refresh.tzinfo is None: 3722 ↛ 3723line 3722 didn't jump to line 3723 because the condition on line 3722 was never true

3723 last_refresh = last_refresh.replace(tzinfo=timezone.utc) 

3724 

3725 # Use per-gateway interval if set, otherwise fall back to global default 

3726 refresh_interval = getattr(settings, "gateway_auto_refresh_interval", 300) 

3727 if gateway.refresh_interval_seconds is not None: 3727 ↛ 3730line 3727 didn't jump to line 3730 because the condition on line 3727 was always true

3728 refresh_interval = gateway.refresh_interval_seconds 

3729 

3730 time_since_refresh = (datetime.now(timezone.utc) - last_refresh).total_seconds() 

3731 

3732 if time_since_refresh < refresh_interval: 

3733 refresh_needed = False 

3734 logger.debug(f"Skipping auto-refresh for {gateway_name}: last refreshed {int(time_since_refresh)}s ago") 

3735 

3736 if refresh_needed: 

3737 # Locking: Try to acquire lock to avoid conflict with manual refresh 

3738 lock = self._get_refresh_lock(gateway_id) 

3739 if not lock.locked(): 

3740 # Acquire lock to prevent concurrent manual refresh 

3741 async with lock: 

3742 await self._refresh_gateway_tools_resources_prompts( 

3743 gateway_id=gateway_id, 

3744 _user_email=user_email, 

3745 created_via="health_check", 

3746 pre_auth_headers=headers if headers else None, 

3747 gateway=gateway, 

3748 ) 

3749 else: 

3750 logger.debug(f"Skipping auto-refresh for {gateway_name}: lock held (likely manual refresh in progress)") 

3751 except Exception as refresh_error: 

3752 logger.warning(f"Failed to refresh tools for gateway {gateway_name}: {refresh_error}") 

3753 

3754 if span: 

3755 span.set_attribute("health.status", "healthy") 

3756 span.set_attribute("success", True) 

3757 

3758 except Exception as e: 

3759 if span: 

3760 span.set_attribute("health.status", "unhealthy") 

3761 span.set_attribute("error.message", str(e)) 

3762 

3763 # Set the logger as debug as this check happens for each interval 

3764 logger.debug(f"Health check failed for gateway {gateway_name}: {e}") 

3765 await self._handle_gateway_failure(gateway) 

3766 

3767 async def aggregate_capabilities(self, db: Session) -> Dict[str, Any]: 

3768 """ 

3769 Aggregate capabilities across all gateways. 

3770 

3771 Args: 

3772 db: Database session 

3773 

3774 Returns: 

3775 Dictionary of aggregated capabilities 

3776 

3777 Examples: 

3778 >>> from mcpgateway.services.gateway_service import GatewayService 

3779 >>> from unittest.mock import MagicMock 

3780 >>> service = GatewayService() 

3781 >>> db = MagicMock() 

3782 >>> gateway_mock = MagicMock() 

3783 >>> gateway_mock.capabilities = {"tools": {"listChanged": True}, "custom": {"feature": True}} 

3784 >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway_mock] 

3785 >>> import asyncio 

3786 >>> result = asyncio.run(service.aggregate_capabilities(db)) 

3787 >>> isinstance(result, dict) 

3788 True 

3789 >>> 'prompts' in result 

3790 True 

3791 >>> 'resources' in result 

3792 True 

3793 >>> 'tools' in result 

3794 True 

3795 >>> 'logging' in result 

3796 True 

3797 >>> result['prompts']['listChanged'] 

3798 True 

3799 >>> result['resources']['subscribe'] 

3800 True 

3801 >>> result['resources']['listChanged'] 

3802 True 

3803 >>> result['tools']['listChanged'] 

3804 True 

3805 >>> isinstance(result['logging'], dict) 

3806 True 

3807 

3808 >>> # Test with no gateways 

3809 >>> db.execute.return_value.scalars.return_value.all.return_value = [] 

3810 >>> empty_result = asyncio.run(service.aggregate_capabilities(db)) 

3811 >>> isinstance(empty_result, dict) 

3812 True 

3813 >>> 'tools' in empty_result 

3814 True 

3815 

3816 >>> # Test capability merging 

3817 >>> gateway1 = MagicMock() 

3818 >>> gateway1.capabilities = {"tools": {"feature1": True}} 

3819 >>> gateway2 = MagicMock() 

3820 >>> gateway2.capabilities = {"tools": {"feature2": True}} 

3821 >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway1, gateway2] 

3822 >>> merged_result = asyncio.run(service.aggregate_capabilities(db)) 

3823 >>> merged_result['tools']['listChanged'] # Default capability 

3824 True 

3825 """ 

3826 capabilities = { 

3827 "prompts": {"listChanged": True}, 

3828 "resources": {"subscribe": True, "listChanged": True}, 

3829 "tools": {"listChanged": True}, 

3830 "logging": {}, 

3831 } 

3832 

3833 # Get all active gateways 

3834 gateways = db.execute(select(DbGateway).where(DbGateway.enabled)).scalars().all() 

3835 

3836 # Combine capabilities 

3837 for gateway in gateways: 

3838 if gateway.capabilities: 

3839 for key, value in gateway.capabilities.items(): 

3840 if key not in capabilities: 

3841 capabilities[key] = value 

3842 elif isinstance(value, dict): 3842 ↛ 3839line 3842 didn't jump to line 3839 because the condition on line 3842 was always true

3843 capabilities[key].update(value) 

3844 

3845 return capabilities 

3846 

3847 async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]: 

3848 """Subscribe to gateway events. 

3849 

3850 Creates a new event queue and subscribes to gateway events. Events are 

3851 yielded as they are published. The subscription is automatically cleaned 

3852 up when the generator is closed or goes out of scope. 

3853 

3854 Yields: 

3855 Dict[str, Any]: Gateway event messages with 'type', 'data', and 'timestamp' fields 

3856 

3857 Examples: 

3858 >>> service = GatewayService() 

3859 >>> import asyncio 

3860 >>> from unittest.mock import MagicMock 

3861 >>> # Create a mock async generator for the event service 

3862 >>> async def mock_event_gen(): 

3863 ... yield {"type": "test_event", "data": "payload"} 

3864 >>> 

3865 >>> # Mock the event service to return our generator 

3866 >>> service._event_service = MagicMock() 

3867 >>> service._event_service.subscribe_events.return_value = mock_event_gen() 

3868 >>> 

3869 >>> # Test the subscription 

3870 >>> async def test_sub(): 

3871 ... async for event in service.subscribe_events(): 

3872 ... return event 

3873 >>> 

3874 >>> result = asyncio.run(test_sub()) 

3875 >>> result 

3876 {'type': 'test_event', 'data': 'payload'} 

3877 """ 

3878 async for event in self._event_service.subscribe_events(): 

3879 yield event 

3880 

3881 async def _initialize_gateway( 

3882 self, 

3883 url: str, 

3884 authentication: Optional[Dict[str, str]] = None, 

3885 transport: str = "SSE", 

3886 auth_type: Optional[str] = None, 

3887 oauth_config: Optional[Dict[str, Any]] = None, 

3888 ca_certificate: Optional[bytes] = None, 

3889 pre_auth_headers: Optional[Dict[str, str]] = None, 

3890 include_resources: bool = True, 

3891 include_prompts: bool = True, 

3892 auth_query_params: Optional[Dict[str, str]] = None, 

3893 oauth_auto_fetch_tool_flag: Optional[bool] = False, 

3894 ) -> tuple[Dict[str, Any], List[ToolCreate], List[ResourceCreate], List[PromptCreate]]: 

3895 """Initialize connection to a gateway and retrieve its capabilities. 

3896 

3897 Connects to an MCP gateway using the specified transport protocol, 

3898 performs the MCP handshake, and retrieves capabilities, tools, 

3899 resources, and prompts from the gateway. 

3900 

3901 Args: 

3902 url: Gateway URL to connect to 

3903 authentication: Optional authentication headers for the connection 

3904 transport: Transport protocol - "SSE" or "StreamableHTTP" 

3905 auth_type: Authentication type - "basic", "bearer", "headers", "oauth", "query_param" or None 

3906 oauth_config: OAuth configuration if auth_type is "oauth" 

3907 ca_certificate: CA certificate for SSL verification 

3908 pre_auth_headers: Pre-authenticated headers to skip OAuth token fetch (for reuse) 

3909 include_resources: Whether to include resources in the fetch 

3910 include_prompts: Whether to include prompts in the fetch 

3911 auth_query_params: Query param names for URL sanitization in error logs (decrypted values) 

3912 oauth_auto_fetch_tool_flag: Whether to skip the early return for OAuth Authorization Code flow. 

3913 When False (default), auth_code gateways return empty lists immediately (for health checks). 

3914 When True, attempts to connect even for auth_code gateways (for activation after user authorization). 

3915 

3916 Returns: 

3917 tuple[Dict[str, Any], List[ToolCreate], List[ResourceCreate], List[PromptCreate]]: 

3918 Capabilities dictionary, list of ToolCreate objects, list of ResourceCreate objects, and list of PromptCreate objects 

3919 

3920 Raises: 

3921 GatewayConnectionError: If connection or initialization fails 

3922 

3923 Examples: 

3924 >>> service = GatewayService() 

3925 >>> # Test parameter validation 

3926 >>> import asyncio 

3927 >>> from unittest.mock import AsyncMock 

3928 >>> # Avoid opening a real SSE connection in doctests (it can leak anyio streams on failure paths) 

3929 >>> service.connect_to_sse_server = AsyncMock(side_effect=GatewayConnectionError("boom")) 

3930 >>> async def test_params(): 

3931 ... try: 

3932 ... await service._initialize_gateway("hello//") 

3933 ... except Exception as e: 

3934 ... return isinstance(e, GatewayConnectionError) or "Failed" in str(e) 

3935 

3936 >>> asyncio.run(test_params()) 

3937 True 

3938 

3939 >>> # Test default parameters 

3940 >>> hasattr(service, '_initialize_gateway') 

3941 True 

3942 >>> import inspect 

3943 >>> sig = inspect.signature(service._initialize_gateway) 

3944 >>> sig.parameters['transport'].default 

3945 'SSE' 

3946 >>> sig.parameters['authentication'].default is None 

3947 True 

3948 >>> 

3949 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs 

3950 >>> asyncio.run(service._http_client.aclose()) 

3951 """ 

3952 try: 

3953 if authentication is None: 

3954 authentication = {} 

3955 

3956 # Use pre-authenticated headers if provided (avoids duplicate OAuth token fetch) 

3957 if pre_auth_headers: 

3958 authentication = pre_auth_headers 

3959 # Handle OAuth authentication 

3960 elif auth_type == "oauth" and oauth_config: 

3961 grant_type = oauth_config.get("grant_type", "client_credentials") 

3962 

3963 if grant_type == "authorization_code": 

3964 if not oauth_auto_fetch_tool_flag: 

3965 # For Authorization Code flow during health checks, we can't initialize immediately 

3966 # because we need user consent. Just store the configuration 

3967 # and let the user complete the OAuth flow later. 

3968 logger.info("""OAuth Authorization Code flow configured for gateway. User must complete authorization before gateway can be used.""") 

3969 # Don't try to get access token here - it will be obtained during tool invocation 

3970 authentication = {} 

3971 

3972 # Skip MCP server connection for Authorization Code flow 

3973 # Tools will be fetched after OAuth completion 

3974 return {}, [], [], [] 

3975 # When flag is True (activation), skip token fetch but try to connect 

3976 # This allows activation to proceed - actual auth happens during tool invocation 

3977 logger.debug("OAuth Authorization Code gateway activation - skipping token fetch") 

3978 elif grant_type == "client_credentials": 3978 ↛ 3988line 3978 didn't jump to line 3988 because the condition on line 3978 was always true

3979 # For Client Credentials flow, we can get the token immediately 

3980 try: 

3981 logger.debug("Obtaining OAuth access token for Client Credentials flow") 

3982 access_token = await self.oauth_manager.get_access_token(oauth_config) 

3983 authentication = {"Authorization": f"Bearer {access_token}"} 

3984 except Exception as e: 

3985 logger.error(f"Failed to obtain OAuth access token: {e}") 

3986 raise GatewayConnectionError(f"OAuth authentication failed: {str(e)}") 

3987 

3988 capabilities = {} 

3989 tools = [] 

3990 resources = [] 

3991 prompts = [] 

3992 if auth_type in ("basic", "bearer", "headers") and isinstance(authentication, str): 

3993 authentication = decode_auth(authentication) 

3994 if transport.lower() == "sse": 

3995 capabilities, tools, resources, prompts = await self.connect_to_sse_server(url, authentication, ca_certificate, include_prompts, include_resources, auth_query_params) 

3996 elif transport.lower() == "streamablehttp": 

3997 capabilities, tools, resources, prompts = await self.connect_to_streamablehttp_server(url, authentication, ca_certificate, include_prompts, include_resources, auth_query_params) 

3998 

3999 return capabilities, tools, resources, prompts 

4000 except Exception as e: 

4001 

4002 # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup 

4003 root_cause = e 

4004 if isinstance(e, BaseExceptionGroup): 

4005 while isinstance(root_cause, BaseExceptionGroup) and root_cause.exceptions: 

4006 root_cause = root_cause.exceptions[0] 

4007 sanitized_url = sanitize_url_for_logging(url, auth_query_params) 

4008 sanitized_error = sanitize_exception_message(str(root_cause), auth_query_params) 

4009 logger.error(f"Gateway initialization failed for {sanitized_url}: {sanitized_error}", exc_info=True) 

4010 raise GatewayConnectionError(f"Failed to initialize gateway at {sanitized_url}: {sanitized_error}") 

4011 

4012 def _get_gateways(self, include_inactive: bool = True) -> list[DbGateway]: 

4013 """Sync function for database operations (runs in thread). 

4014 

4015 Args: 

4016 include_inactive: Whether to include inactive gateways 

4017 

4018 Returns: 

4019 List[DbGateway]: List of active gateways 

4020 

4021 Examples: 

4022 >>> from unittest.mock import patch, MagicMock 

4023 >>> service = GatewayService() 

4024 >>> with patch('mcpgateway.services.gateway_service.SessionLocal') as mock_session: 

4025 ... mock_db = MagicMock() 

4026 ... mock_session.return_value.__enter__.return_value = mock_db 

4027 ... mock_db.execute.return_value.scalars.return_value.all.return_value = [] 

4028 ... result = service._get_gateways() 

4029 ... isinstance(result, list) 

4030 True 

4031 

4032 >>> # Test include_inactive parameter handling 

4033 >>> with patch('mcpgateway.services.gateway_service.SessionLocal') as mock_session: 

4034 ... mock_db = MagicMock() 

4035 ... mock_session.return_value.__enter__.return_value = mock_db 

4036 ... mock_db.execute.return_value.scalars.return_value.all.return_value = [] 

4037 ... result_active_only = service._get_gateways(include_inactive=False) 

4038 ... isinstance(result_active_only, list) 

4039 True 

4040 """ 

4041 with cast(Any, SessionLocal)() as db: 

4042 if include_inactive: 

4043 return db.execute(select(DbGateway)).scalars().all() 

4044 # Only return active gateways 

4045 return db.execute(select(DbGateway).where(DbGateway.enabled)).scalars().all() 

4046 

4047 def get_first_gateway_by_url(self, db: Session, url: str, team_id: Optional[str] = None, include_inactive: bool = False) -> Optional[GatewayRead]: 

4048 """Return the first DbGateway matching the given URL and optional team_id. 

4049 

4050 This is a synchronous helper intended for use from request handlers where 

4051 a simple DB lookup is needed. It normalizes the provided URL similar to 

4052 how gateways are stored and matches by the `url` column. If team_id is 

4053 provided, it restricts the search to that team. 

4054 

4055 Args: 

4056 db: Database session to use for the query 

4057 url: Gateway base URL to match (will be normalized) 

4058 team_id: Optional team id to restrict search 

4059 include_inactive: Whether to include inactive gateways 

4060 

4061 Returns: 

4062 Optional[DbGateway]: First matching gateway or None 

4063 """ 

4064 query = select(DbGateway).where(DbGateway.url == url) 

4065 if not include_inactive: 4065 ↛ 4067line 4065 didn't jump to line 4067 because the condition on line 4065 was always true

4066 query = query.where(DbGateway.enabled) 

4067 if team_id: 4067 ↛ 4068line 4067 didn't jump to line 4068 because the condition on line 4067 was never true

4068 query = query.where(DbGateway.team_id == team_id) 

4069 result = db.execute(query).scalars().first() 

4070 # Wrap the DB object in the GatewayRead schema for consistency with 

4071 # other service methods. Return None if no match found. 

4072 if result is None: 

4073 return None 

4074 return GatewayRead.model_validate(self._prepare_gateway_for_read(result)).masked() 

4075 

4076 async def _run_leader_heartbeat(self) -> None: 

4077 """Run leader heartbeat loop to keep leader key alive. 

4078 

4079 This runs independently from health checks to ensure the leader key 

4080 is refreshed frequently enough (every redis_leader_heartbeat_interval seconds) 

4081 to prevent expiration during long-running health check operations. 

4082 

4083 The loop exits if this instance loses leadership. 

4084 """ 

4085 while True: 

4086 try: 

4087 await asyncio.sleep(self._leader_heartbeat_interval) 

4088 

4089 if not self._redis_client: 

4090 return 

4091 

4092 # Check if we're still the leader 

4093 current_leader = await self._redis_client.get(self._leader_key) 

4094 if current_leader != self._instance_id: 

4095 logger.info("Lost Redis leadership, stopping heartbeat") 

4096 return 

4097 

4098 # Refresh the leader key TTL 

4099 await self._redis_client.expire(self._leader_key, self._leader_ttl) 

4100 logger.debug(f"Leader heartbeat: refreshed TTL to {self._leader_ttl}s") 

4101 

4102 except Exception as e: 

4103 logger.warning(f"Leader heartbeat error: {e}") 

4104 # Continue trying - the main health check loop will handle leadership loss 

4105 

4106 async def _run_health_checks(self, user_email: str) -> None: 

4107 """Run health checks periodically, 

4108 Uses Redis or FileLock - for multiple workers. 

4109 Uses simple health check for single worker mode. 

4110 

4111 NOTE: This method intentionally does NOT take a db parameter. 

4112 Health checks use fresh_db_session() only when DB access is needed, 

4113 avoiding holding connections during HTTP calls to MCP servers. 

4114 

4115 Args: 

4116 user_email: Email of the user for OAuth token lookup 

4117 

4118 Examples: 

4119 >>> service = GatewayService() 

4120 >>> service._health_check_interval = 0.1 # Short interval for testing 

4121 >>> service._redis_client = None 

4122 >>> import asyncio 

4123 >>> # Test that method exists and is callable 

4124 >>> callable(service._run_health_checks) 

4125 True 

4126 >>> # Test setup without actual execution (would run forever) 

4127 >>> hasattr(service, '_health_check_interval') 

4128 True 

4129 >>> service._health_check_interval == 0.1 

4130 True 

4131 """ 

4132 

4133 while True: 

4134 try: 

4135 if self._redis_client and settings.cache_type == "redis": 4135 ↛ 4138line 4135 didn't jump to line 4138 because the condition on line 4135 was never true

4136 # Redis-based leader check (async, decode_responses=True returns strings) 

4137 # Note: Leader key TTL refresh is handled by _run_leader_heartbeat task 

4138 current_leader = await self._redis_client.get(self._leader_key) 

4139 if current_leader != self._instance_id: 

4140 return 

4141 

4142 # Run health checks 

4143 gateways = await asyncio.to_thread(self._get_gateways) 

4144 if gateways: 

4145 await self.check_health_of_gateways(gateways, user_email) 

4146 

4147 await asyncio.sleep(self._health_check_interval) 

4148 

4149 elif settings.cache_type == "none": 

4150 try: 

4151 # For single worker mode, run health checks directly 

4152 gateways = await asyncio.to_thread(self._get_gateways) 

4153 if gateways: 4153 ↛ 4158line 4153 didn't jump to line 4158 because the condition on line 4153 was always true

4154 await self.check_health_of_gateways(gateways, user_email) 

4155 except Exception as e: 

4156 logger.error(f"Health check run failed: {str(e)}") 

4157 

4158 await asyncio.sleep(self._health_check_interval) 

4159 

4160 else: 

4161 # FileLock-based leader fallback 

4162 try: 

4163 self._file_lock.acquire(timeout=0) 

4164 logger.info("File lock acquired. Running health checks.") 

4165 

4166 while True: 

4167 gateways = await asyncio.to_thread(self._get_gateways) 

4168 if gateways: 4168 ↛ 4169line 4168 didn't jump to line 4169 because the condition on line 4168 was never true

4169 await self.check_health_of_gateways(gateways, user_email) 

4170 await asyncio.sleep(self._health_check_interval) 

4171 

4172 except Timeout: 

4173 logger.debug("File lock already held. Retrying later.") 

4174 await asyncio.sleep(self._health_check_interval) 

4175 

4176 except Exception as e: 

4177 logger.error(f"FileLock health check failed: {str(e)}") 

4178 

4179 finally: 

4180 if self._file_lock.is_locked: 4180 ↛ 4133line 4180 didn't jump to line 4133 because the condition on line 4180 was always true

4181 try: 

4182 self._file_lock.release() 

4183 logger.info("Released file lock.") 

4184 except Exception as e: 

4185 logger.warning(f"Failed to release file lock: {str(e)}") 

4186 

4187 except Exception as e: 

4188 logger.error(f"Unexpected error in health check loop: {str(e)}") 

4189 await asyncio.sleep(self._health_check_interval) 

4190 

4191 def _get_auth_headers(self) -> Dict[str, str]: 

4192 """Get default headers for gateway requests (no authentication). 

4193 

4194 SECURITY: This method intentionally does NOT include authentication credentials. 

4195 Each gateway should have its own auth_value configured. Never send this gateway's 

4196 admin credentials to remote servers. 

4197 

4198 Returns: 

4199 dict: Default headers without authentication 

4200 

4201 Examples: 

4202 >>> service = GatewayService() 

4203 >>> headers = service._get_auth_headers() 

4204 >>> isinstance(headers, dict) 

4205 True 

4206 >>> 'Content-Type' in headers 

4207 True 

4208 >>> headers['Content-Type'] 

4209 'application/json' 

4210 >>> 'Authorization' not in headers # No credentials leaked 

4211 True 

4212 """ 

4213 return {"Content-Type": "application/json"} 

4214 

4215 async def _notify_gateway_added(self, gateway: DbGateway) -> None: 

4216 """Notify subscribers of gateway addition. 

4217 

4218 Args: 

4219 gateway: Gateway to add 

4220 """ 

4221 event = { 

4222 "type": "gateway_added", 

4223 "data": { 

4224 "id": gateway.id, 

4225 "name": gateway.name, 

4226 "url": gateway.url, 

4227 "description": gateway.description, 

4228 "enabled": gateway.enabled, 

4229 }, 

4230 "timestamp": datetime.now(timezone.utc).isoformat(), 

4231 } 

4232 await self._publish_event(event) 

4233 

4234 async def _notify_gateway_activated(self, gateway: DbGateway) -> None: 

4235 """Notify subscribers of gateway activation. 

4236 

4237 Args: 

4238 gateway: Gateway to activate 

4239 """ 

4240 event = { 

4241 "type": "gateway_activated", 

4242 "data": { 

4243 "id": gateway.id, 

4244 "name": gateway.name, 

4245 "url": gateway.url, 

4246 "enabled": gateway.enabled, 

4247 "reachable": gateway.reachable, 

4248 }, 

4249 "timestamp": datetime.now(timezone.utc).isoformat(), 

4250 } 

4251 await self._publish_event(event) 

4252 

4253 async def _notify_gateway_deactivated(self, gateway: DbGateway) -> None: 

4254 """Notify subscribers of gateway deactivation. 

4255 

4256 Args: 

4257 gateway: Gateway database object 

4258 """ 

4259 event = { 

4260 "type": "gateway_deactivated", 

4261 "data": { 

4262 "id": gateway.id, 

4263 "name": gateway.name, 

4264 "url": gateway.url, 

4265 "enabled": gateway.enabled, 

4266 "reachable": gateway.reachable, 

4267 }, 

4268 "timestamp": datetime.now(timezone.utc).isoformat(), 

4269 } 

4270 await self._publish_event(event) 

4271 

4272 async def _notify_gateway_offline(self, gateway: DbGateway) -> None: 

4273 """ 

4274 Notify subscribers that gateway is offline (Enabled but Unreachable). 

4275 

4276 Args: 

4277 gateway: Gateway database object 

4278 """ 

4279 event = { 

4280 "type": "gateway_offline", 

4281 "data": { 

4282 "id": gateway.id, 

4283 "name": gateway.name, 

4284 "url": gateway.url, 

4285 "enabled": True, 

4286 "reachable": False, 

4287 }, 

4288 "timestamp": datetime.now(timezone.utc).isoformat(), 

4289 } 

4290 await self._publish_event(event) 

4291 

4292 async def _notify_gateway_deleted(self, gateway_info: Dict[str, Any]) -> None: 

4293 """Notify subscribers of gateway deletion. 

4294 

4295 Args: 

4296 gateway_info: Dict containing information about gateway to delete 

4297 """ 

4298 event = { 

4299 "type": "gateway_deleted", 

4300 "data": gateway_info, 

4301 "timestamp": datetime.now(timezone.utc).isoformat(), 

4302 } 

4303 await self._publish_event(event) 

4304 

4305 async def _notify_gateway_removed(self, gateway: DbGateway) -> None: 

4306 """Notify subscribers of gateway removal (deactivation). 

4307 

4308 Args: 

4309 gateway: Gateway to remove 

4310 """ 

4311 event = { 

4312 "type": "gateway_removed", 

4313 "data": {"id": gateway.id, "name": gateway.name, "enabled": gateway.enabled}, 

4314 "timestamp": datetime.now(timezone.utc).isoformat(), 

4315 } 

4316 await self._publish_event(event) 

4317 

4318 def convert_gateway_to_read(self, gateway: DbGateway) -> GatewayRead: 

4319 """Convert a DbGateway instance to a GatewayRead Pydantic model. 

4320 

4321 Args: 

4322 gateway: Gateway database object 

4323 

4324 Returns: 

4325 GatewayRead: Pydantic model instance 

4326 """ 

4327 gateway_dict = gateway.__dict__.copy() 

4328 gateway_dict.pop("_sa_instance_state", None) 

4329 

4330 # Ensure auth_value is properly encoded 

4331 if isinstance(gateway.auth_value, dict): 

4332 gateway_dict["auth_value"] = encode_auth(gateway.auth_value) 

4333 

4334 if gateway.tags: 

4335 # Check tags are list of strings or list of Dict[str, str] 

4336 if isinstance(gateway.tags[0], str): 

4337 # Convert tags from List[str] to List[Dict[str, str]] for GatewayRead 

4338 gateway_dict["tags"] = validate_tags_field(gateway.tags) 

4339 else: 

4340 gateway_dict["tags"] = gateway.tags 

4341 else: 

4342 gateway_dict["tags"] = [] 

4343 

4344 # Include metadata fields 

4345 gateway_dict["created_by"] = getattr(gateway, "created_by", None) 

4346 gateway_dict["modified_by"] = getattr(gateway, "modified_by", None) 

4347 gateway_dict["created_at"] = getattr(gateway, "created_at", None) 

4348 gateway_dict["updated_at"] = getattr(gateway, "updated_at", None) 

4349 gateway_dict["version"] = getattr(gateway, "version", None) 

4350 gateway_dict["team"] = getattr(gateway, "team", None) 

4351 

4352 return GatewayRead.model_validate(gateway_dict).masked() 

4353 

4354 def _prepare_gateway_for_read(self, gateway: DbGateway) -> DbGateway: 

4355 """DEPRECATED: Use convert_gateway_to_read instead. 

4356 

4357 Prepare a gateway object for GatewayRead validation. 

4358 

4359 Ensures auth_value is in the correct format (encoded string) for the schema. 

4360 Converts legacy List[str] tags to List[Dict[str, str]] format for GatewayRead schema. 

4361 

4362 Args: 

4363 gateway: Gateway database object 

4364 

4365 Returns: 

4366 Gateway object with properly formatted auth_value and tags 

4367 """ 

4368 # If auth_value is a dict, encode it to string for GatewayRead schema 

4369 if isinstance(gateway.auth_value, dict): 

4370 gateway.auth_value = encode_auth(gateway.auth_value) 

4371 

4372 # Handle legacy List[str] tags - convert to List[Dict[str, str]] for GatewayRead schema 

4373 if gateway.tags: 

4374 if isinstance(gateway.tags[0], str): 

4375 # Legacy format: convert to dict format 

4376 gateway.tags = validate_tags_field(gateway.tags) 

4377 

4378 return gateway 

4379 

4380 def _create_db_tool( 

4381 self, 

4382 tool: ToolCreate, 

4383 gateway: DbGateway, 

4384 created_by: Optional[str] = None, 

4385 created_from_ip: Optional[str] = None, 

4386 created_via: Optional[str] = None, 

4387 created_user_agent: Optional[str] = None, 

4388 ) -> DbTool: 

4389 """Create a DbTool with consistent federation metadata across all scenarios. 

4390 

4391 Args: 

4392 tool: Tool creation schema 

4393 gateway: Gateway database object 

4394 created_by: Username who created/updated this tool 

4395 created_from_ip: IP address of creator 

4396 created_via: Creation method (ui, api, federation, rediscovery) 

4397 created_user_agent: User agent of creation request 

4398 

4399 Returns: 

4400 DbTool: Consistently configured database tool object 

4401 """ 

4402 return DbTool( 

4403 original_name=tool.name, 

4404 custom_name=tool.name, 

4405 custom_name_slug=slugify(tool.name), 

4406 display_name=generate_display_name(tool.name), 

4407 url=gateway.url, 

4408 description=tool.description, 

4409 integration_type="MCP", # Gateway-discovered tools are MCP type 

4410 request_type=tool.request_type, 

4411 headers=tool.headers, 

4412 input_schema=tool.input_schema, 

4413 annotations=tool.annotations, 

4414 jsonpath_filter=tool.jsonpath_filter, 

4415 auth_type=gateway.auth_type, 

4416 auth_value=encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value, 

4417 # Federation metadata - consistent across all scenarios 

4418 created_by=created_by or "system", 

4419 created_from_ip=created_from_ip, 

4420 created_via=created_via or "federation", 

4421 created_user_agent=created_user_agent, 

4422 federation_source=gateway.name, 

4423 version=1, 

4424 # Inherit team assignment and visibility from gateway 

4425 team_id=gateway.team_id, 

4426 owner_email=gateway.owner_email, 

4427 visibility="public", # Federated tools should be public for discovery 

4428 ) 

4429 

4430 def _update_or_create_tools(self, db: Session, tools: List[Any], gateway: DbGateway, created_via: str) -> List[DbTool]: 

4431 """Helper to handle update-or-create logic for tools from MCP server. 

4432 

4433 Args: 

4434 db: Database session 

4435 tools: List of tools from MCP server 

4436 gateway: Gateway object 

4437 created_via: String indicating creation source ("oauth", "update", etc.) 

4438 

4439 Returns: 

4440 List of new tools to be added to the database 

4441 """ 

4442 if not tools: 

4443 return [] 

4444 

4445 tools_to_add = [] 

4446 

4447 # Batch fetch all existing tools for this gateway 

4448 tool_names = [tool.name for tool in tools if tool is not None] 

4449 if not tool_names: 4449 ↛ 4450line 4449 didn't jump to line 4450 because the condition on line 4449 was never true

4450 return [] 

4451 

4452 existing_tools_query = select(DbTool).where(DbTool.gateway_id == gateway.id, DbTool.original_name.in_(tool_names)) 

4453 existing_tools = db.execute(existing_tools_query).scalars().all() 

4454 existing_tools_map = {tool.original_name: tool for tool in existing_tools} 

4455 

4456 for tool in tools: 

4457 if tool is None: 

4458 logger.warning("Skipping None tool in tools list") 

4459 continue 

4460 

4461 try: 

4462 # Check if tool already exists for this gateway from the tools_map 

4463 existing_tool = existing_tools_map.get(tool.name) 

4464 if existing_tool: 

4465 # Update existing tool if there are changes 

4466 fields_to_update = False 

4467 

4468 # Check basic field changes 

4469 basic_fields_changed = ( 

4470 existing_tool.url != gateway.url or existing_tool.description != tool.description or existing_tool.integration_type != "MCP" or existing_tool.request_type != tool.request_type 

4471 ) 

4472 

4473 # Check schema and configuration changes 

4474 schema_fields_changed = ( 

4475 existing_tool.headers != tool.headers 

4476 or existing_tool.input_schema != tool.input_schema 

4477 or existing_tool.output_schema != tool.output_schema 

4478 or existing_tool.jsonpath_filter != tool.jsonpath_filter 

4479 ) 

4480 

4481 # Check authentication and visibility changes 

4482 auth_fields_changed = existing_tool.auth_type != gateway.auth_type or existing_tool.auth_value != gateway.auth_value or existing_tool.visibility != gateway.visibility 

4483 

4484 if basic_fields_changed or schema_fields_changed or auth_fields_changed: 4484 ↛ 4486line 4484 didn't jump to line 4486 because the condition on line 4484 was always true

4485 fields_to_update = True 

4486 if fields_to_update: 4486 ↛ 4456line 4486 didn't jump to line 4456 because the condition on line 4486 was always true

4487 existing_tool.url = gateway.url 

4488 existing_tool.description = tool.description 

4489 existing_tool.integration_type = "MCP" 

4490 existing_tool.request_type = tool.request_type 

4491 existing_tool.headers = tool.headers 

4492 existing_tool.input_schema = tool.input_schema 

4493 existing_tool.output_schema = tool.output_schema 

4494 existing_tool.jsonpath_filter = tool.jsonpath_filter 

4495 existing_tool.auth_type = gateway.auth_type 

4496 existing_tool.auth_value = gateway.auth_value 

4497 existing_tool.visibility = gateway.visibility 

4498 logger.debug(f"Updated existing tool: {tool.name}") 

4499 else: 

4500 # Create new tool if it doesn't exist 

4501 db_tool = self._create_db_tool( 

4502 tool=tool, 

4503 gateway=gateway, 

4504 created_by="system", 

4505 created_via=created_via, 

4506 ) 

4507 # Attach relationship to avoid NoneType during flush 

4508 db_tool.gateway = gateway 

4509 tools_to_add.append(db_tool) 

4510 logger.debug(f"Created new tool: {tool.name}") 

4511 except Exception as e: 

4512 logger.warning(f"Failed to process tool {getattr(tool, 'name', 'unknown')}: {e}") 

4513 continue 

4514 

4515 return tools_to_add 

4516 

4517 def _update_or_create_resources(self, db: Session, resources: List[Any], gateway: DbGateway, created_via: str) -> List[DbResource]: 

4518 """Helper to handle update-or-create logic for resources from MCP server. 

4519 

4520 Args: 

4521 db: Database session 

4522 resources: List of resources from MCP server 

4523 gateway: Gateway object 

4524 created_via: String indicating creation source ("oauth", "update", etc.) 

4525 

4526 Returns: 

4527 List of new resources to be added to the database 

4528 """ 

4529 if not resources: 

4530 return [] 

4531 

4532 resources_to_add = [] 

4533 

4534 # Batch fetch all existing resources for this gateway 

4535 resource_uris = [resource.uri for resource in resources if resource is not None] 

4536 if not resource_uris: 4536 ↛ 4537line 4536 didn't jump to line 4537 because the condition on line 4536 was never true

4537 return [] 

4538 

4539 existing_resources_query = select(DbResource).where(DbResource.gateway_id == gateway.id, DbResource.uri.in_(resource_uris)) 

4540 existing_resources = db.execute(existing_resources_query).scalars().all() 

4541 existing_resources_map = {resource.uri: resource for resource in existing_resources} 

4542 

4543 for resource in resources: 

4544 if resource is None: 

4545 logger.warning("Skipping None resource in resources list") 

4546 continue 

4547 

4548 try: 

4549 # Check if resource already exists for this gateway from the resources_map 

4550 existing_resource = existing_resources_map.get(resource.uri) 

4551 

4552 if existing_resource: 

4553 # Update existing resource if there are changes 

4554 fields_to_update = False 

4555 

4556 if ( 4556 ↛ 4565line 4556 didn't jump to line 4565 because the condition on line 4556 was always true

4557 existing_resource.name != resource.name 

4558 or existing_resource.description != resource.description 

4559 or existing_resource.mime_type != resource.mime_type 

4560 or existing_resource.uri_template != resource.uri_template 

4561 or existing_resource.visibility != gateway.visibility 

4562 ): 

4563 fields_to_update = True 

4564 

4565 if fields_to_update: 4565 ↛ 4543line 4565 didn't jump to line 4543 because the condition on line 4565 was always true

4566 existing_resource.name = resource.name 

4567 existing_resource.description = resource.description 

4568 existing_resource.mime_type = resource.mime_type 

4569 existing_resource.uri_template = resource.uri_template 

4570 existing_resource.visibility = gateway.visibility 

4571 logger.debug(f"Updated existing resource: {resource.uri}") 

4572 else: 

4573 # Create new resource if it doesn't exist 

4574 db_resource = DbResource( 

4575 uri=resource.uri, 

4576 name=resource.name, 

4577 description=resource.description, 

4578 mime_type=resource.mime_type, 

4579 uri_template=resource.uri_template, 

4580 gateway_id=gateway.id, 

4581 created_by="system", 

4582 created_via=created_via, 

4583 visibility=gateway.visibility, 

4584 ) 

4585 resources_to_add.append(db_resource) 

4586 logger.debug(f"Created new resource: {resource.uri}") 

4587 except Exception as e: 

4588 logger.warning(f"Failed to process resource {getattr(resource, 'uri', 'unknown')}: {e}") 

4589 continue 

4590 

4591 return resources_to_add 

4592 

4593 def _update_or_create_prompts(self, db: Session, prompts: List[Any], gateway: DbGateway, created_via: str) -> List[DbPrompt]: 

4594 """Helper to handle update-or-create logic for prompts from MCP server. 

4595 

4596 Args: 

4597 db: Database session 

4598 prompts: List of prompts from MCP server 

4599 gateway: Gateway object 

4600 created_via: String indicating creation source ("oauth", "update", etc.) 

4601 

4602 Returns: 

4603 List of new prompts to be added to the database 

4604 """ 

4605 if not prompts: 

4606 return [] 

4607 

4608 prompts_to_add = [] 

4609 

4610 # Batch fetch all existing prompts for this gateway 

4611 prompt_names = [prompt.name for prompt in prompts if prompt is not None] 

4612 if not prompt_names: 4612 ↛ 4613line 4612 didn't jump to line 4613 because the condition on line 4612 was never true

4613 return [] 

4614 

4615 existing_prompts_query = select(DbPrompt).where(DbPrompt.gateway_id == gateway.id, DbPrompt.original_name.in_(prompt_names)) 

4616 existing_prompts = db.execute(existing_prompts_query).scalars().all() 

4617 existing_prompts_map = {prompt.original_name: prompt for prompt in existing_prompts} 

4618 

4619 for prompt in prompts: 

4620 if prompt is None: 

4621 logger.warning("Skipping None prompt in prompts list") 

4622 continue 

4623 

4624 try: 

4625 # Check if resource already exists for this gateway from the prompts_map 

4626 existing_prompt = existing_prompts_map.get(prompt.name) 

4627 

4628 if existing_prompt: 

4629 # Update existing prompt if there are changes 

4630 fields_to_update = False 

4631 

4632 if ( 4632 ↛ 4639line 4632 didn't jump to line 4639 because the condition on line 4632 was always true

4633 existing_prompt.description != prompt.description 

4634 or existing_prompt.template != (prompt.template if hasattr(prompt, "template") else "") 

4635 or existing_prompt.visibility != gateway.visibility 

4636 ): 

4637 fields_to_update = True 

4638 

4639 if fields_to_update: 4639 ↛ 4619line 4639 didn't jump to line 4619 because the condition on line 4639 was always true

4640 existing_prompt.description = prompt.description 

4641 existing_prompt.template = prompt.template if hasattr(prompt, "template") else "" 

4642 existing_prompt.visibility = gateway.visibility 

4643 logger.debug(f"Updated existing prompt: {prompt.name}") 

4644 else: 

4645 # Create new prompt if it doesn't exist 

4646 db_prompt = DbPrompt( 

4647 name=prompt.name, 

4648 original_name=prompt.name, 

4649 custom_name=prompt.name, 

4650 display_name=prompt.name, 

4651 description=prompt.description, 

4652 template=prompt.template if hasattr(prompt, "template") else "", 

4653 argument_schema={}, # Use argument_schema instead of arguments 

4654 gateway_id=gateway.id, 

4655 created_by="system", 

4656 created_via=created_via, 

4657 visibility=gateway.visibility, 

4658 ) 

4659 db_prompt.gateway = gateway 

4660 prompts_to_add.append(db_prompt) 

4661 logger.debug(f"Created new prompt: {prompt.name}") 

4662 except Exception as e: 

4663 logger.warning(f"Failed to process prompt {getattr(prompt, 'name', 'unknown')}: {e}") 

4664 continue 

4665 

4666 return prompts_to_add 

4667 

4668 async def _refresh_gateway_tools_resources_prompts( 

4669 self, 

4670 gateway_id: str, 

4671 _user_email: Optional[str] = None, 

4672 created_via: str = "health_check", 

4673 pre_auth_headers: Optional[Dict[str, str]] = None, 

4674 gateway: Optional[DbGateway] = None, 

4675 include_resources: bool = True, 

4676 include_prompts: bool = True, 

4677 ) -> Dict[str, int]: 

4678 """Refresh tools, resources, and prompts for a gateway during health checks. 

4679 

4680 Fetches the latest tools/resources/prompts from the MCP server and syncs 

4681 with the database (add new, update changed, remove stale). Only performs 

4682 DB operations if actual changes are detected. 

4683 

4684 This method uses fresh_db_session() internally to avoid holding 

4685 connections during HTTP calls to MCP servers. 

4686 

4687 Args: 

4688 gateway_id: ID of the gateway to refresh 

4689 _user_email: Optional user email for OAuth token lookup (unused currently) 

4690 created_via: String indicating creation source (default: "health_check") 

4691 pre_auth_headers: Pre-authenticated headers from health check to avoid duplicate OAuth token fetch 

4692 gateway: Optional DbGateway object to avoid redundant DB lookup 

4693 include_resources: Whether to include resources in the refresh 

4694 include_prompts: Whether to include prompts in the refresh 

4695 

4696 Returns: 

4697 Dict with counts: {tools_added, tools_removed, resources_added, 

4698 resources_removed, prompts_added, prompts_removed} 

4699 

4700 Examples: 

4701 >>> from mcpgateway.services.gateway_service import GatewayService 

4702 >>> from unittest.mock import patch, MagicMock, AsyncMock 

4703 >>> import asyncio 

4704 

4705 >>> # Test gateway not found returns empty result 

4706 >>> service = GatewayService() 

4707 >>> mock_session = MagicMock() 

4708 >>> mock_session.execute.return_value.scalar_one_or_none.return_value = None 

4709 >>> with patch('mcpgateway.services.gateway_service.fresh_db_session') as mock_fresh: 

4710 ... mock_fresh.return_value.__enter__.return_value = mock_session 

4711 ... result = asyncio.run(service._refresh_gateway_tools_resources_prompts('gw-123')) 

4712 >>> result['tools_added'] == 0 and result['tools_removed'] == 0 

4713 True 

4714 >>> result['resources_added'] == 0 and result['resources_removed'] == 0 

4715 True 

4716 >>> result['success'] is True and result['error'] is None 

4717 True 

4718 

4719 >>> # Test disabled gateway returns empty result 

4720 >>> mock_gw = MagicMock() 

4721 >>> mock_gw.enabled = False 

4722 >>> mock_gw.reachable = True 

4723 >>> mock_gw.name = 'test_gw' 

4724 >>> mock_session.execute.return_value.scalar_one_or_none.return_value = mock_gw 

4725 >>> with patch('mcpgateway.services.gateway_service.fresh_db_session') as mock_fresh: 

4726 ... mock_fresh.return_value.__enter__.return_value = mock_session 

4727 ... result = asyncio.run(service._refresh_gateway_tools_resources_prompts('gw-123')) 

4728 >>> result['tools_added'] 

4729 0 

4730 

4731 >>> # Test unreachable gateway returns empty result 

4732 >>> mock_gw.enabled = True 

4733 >>> mock_gw.reachable = False 

4734 >>> with patch('mcpgateway.services.gateway_service.fresh_db_session') as mock_fresh: 

4735 ... mock_fresh.return_value.__enter__.return_value = mock_session 

4736 ... result = asyncio.run(service._refresh_gateway_tools_resources_prompts('gw-123')) 

4737 >>> result['tools_added'] 

4738 0 

4739 

4740 >>> # Test method is async and callable 

4741 >>> import inspect 

4742 >>> inspect.iscoroutinefunction(service._refresh_gateway_tools_resources_prompts) 

4743 True 

4744 >>> 

4745 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs 

4746 >>> asyncio.run(service._http_client.aclose()) 

4747 """ 

4748 result = { 

4749 "tools_added": 0, 

4750 "tools_removed": 0, 

4751 "resources_added": 0, 

4752 "resources_removed": 0, 

4753 "prompts_added": 0, 

4754 "prompts_removed": 0, 

4755 "tools_updated": 0, 

4756 "resources_updated": 0, 

4757 "prompts_updated": 0, 

4758 "success": True, 

4759 "error": None, 

4760 "validation_errors": [], 

4761 } 

4762 

4763 # Fetch gateway metadata only (no relationships needed for MCP call) 

4764 # Use provided gateway object if available to save a DB call 

4765 gateway_name = None 

4766 gateway_url = None 

4767 gateway_transport = None 

4768 gateway_auth_type = None 

4769 gateway_auth_value = None 

4770 gateway_oauth_config = None 

4771 gateway_ca_certificate = None 

4772 gateway_auth_query_params = None 

4773 

4774 if gateway: 

4775 if not gateway.enabled or not gateway.reachable: 

4776 logger.debug(f"Skipping tool refresh for disabled/unreachable gateway {gateway.name}") 

4777 return result 

4778 

4779 gateway_name = gateway.name 

4780 gateway_url = gateway.url 

4781 gateway_transport = gateway.transport 

4782 gateway_auth_type = gateway.auth_type 

4783 gateway_auth_value = gateway.auth_value 

4784 gateway_oauth_config = gateway.oauth_config 

4785 gateway_ca_certificate = gateway.ca_certificate 

4786 gateway_auth_query_params = gateway.auth_query_params 

4787 else: 

4788 with fresh_db_session() as db: 

4789 gateway_obj = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none() 

4790 

4791 if not gateway_obj: 

4792 logger.warning(f"Gateway {gateway_id} not found for tool refresh") 

4793 return result 

4794 

4795 if not gateway_obj.enabled or not gateway_obj.reachable: 

4796 logger.debug(f"Skipping tool refresh for disabled/unreachable gateway {gateway_obj.name}") 

4797 return result 

4798 

4799 # Extract metadata before session closes 

4800 gateway_name = gateway_obj.name 

4801 gateway_url = gateway_obj.url 

4802 gateway_transport = gateway_obj.transport 

4803 gateway_auth_type = gateway_obj.auth_type 

4804 gateway_auth_value = gateway_obj.auth_value 

4805 gateway_oauth_config = gateway_obj.oauth_config 

4806 gateway_ca_certificate = gateway_obj.ca_certificate 

4807 gateway_auth_query_params = gateway_obj.auth_query_params 

4808 

4809 # Handle query_param auth - decrypt and apply to URL for refresh 

4810 auth_query_params_decrypted: Optional[Dict[str, str]] = None 

4811 if gateway_auth_type == "query_param" and gateway_auth_query_params: 

4812 auth_query_params_decrypted = {} 

4813 for param_key, encrypted_value in gateway_auth_query_params.items(): 

4814 if encrypted_value: 4814 ↛ 4813line 4814 didn't jump to line 4813 because the condition on line 4814 was always true

4815 try: 

4816 decrypted = decode_auth(encrypted_value) 

4817 auth_query_params_decrypted[param_key] = decrypted.get(param_key, "") 

4818 except Exception: 

4819 logger.debug(f"Failed to decrypt query param '{param_key}' for tool refresh") 

4820 if auth_query_params_decrypted: 4820 ↛ 4824line 4820 didn't jump to line 4824 because the condition on line 4820 was always true

4821 gateway_url = apply_query_param_auth(gateway_url, auth_query_params_decrypted) 

4822 

4823 # Fetch tools/resources/prompts from MCP server (no DB connection held) 

4824 try: 

4825 _capabilities, tools, resources, prompts = await self._initialize_gateway( 

4826 url=gateway_url, 

4827 authentication=gateway_auth_value, 

4828 transport=gateway_transport, 

4829 auth_type=gateway_auth_type, 

4830 oauth_config=gateway_oauth_config, 

4831 ca_certificate=gateway_ca_certificate.encode() if gateway_ca_certificate else None, 

4832 pre_auth_headers=pre_auth_headers, 

4833 include_resources=include_resources, 

4834 include_prompts=include_prompts, 

4835 auth_query_params=auth_query_params_decrypted, 

4836 ) 

4837 except Exception as e: 

4838 logger.warning(f"Failed to fetch tools from gateway {gateway_name}: {e}") 

4839 result["success"] = False 

4840 result["error"] = str(e) 

4841 return result 

4842 

4843 # For authorization_code OAuth gateways, empty responses may indicate incomplete auth flow 

4844 # Skip only if it's an auth_code gateway with no data (user may not have completed authorization) 

4845 is_auth_code_gateway = gateway_oauth_config and isinstance(gateway_oauth_config, dict) and gateway_oauth_config.get("grant_type") == "authorization_code" 

4846 if not tools and not resources and not prompts and is_auth_code_gateway: 

4847 logger.debug(f"No tools/resources/prompts returned from auth_code gateway {gateway_name} (user may not have authorized)") 

4848 return result 

4849 

4850 # For non-auth_code gateways, empty responses are legitimate and will clear stale items 

4851 

4852 # Update database with fresh session 

4853 with fresh_db_session() as db: 

4854 # Fetch gateway with relationships for update/comparison 

4855 gateway = db.execute( 

4856 select(DbGateway) 

4857 .options( 

4858 selectinload(DbGateway.tools), 

4859 selectinload(DbGateway.resources), 

4860 selectinload(DbGateway.prompts), 

4861 ) 

4862 .where(DbGateway.id == gateway_id) 

4863 ).scalar_one_or_none() 

4864 

4865 if not gateway: 4865 ↛ 4866line 4865 didn't jump to line 4866 because the condition on line 4865 was never true

4866 result["success"] = False 

4867 result["error"] = f"Gateway {gateway_id} not found during refresh" 

4868 return result 

4869 

4870 new_tool_names = [tool.name for tool in tools] 

4871 new_resource_uris = [resource.uri for resource in resources] if include_resources else None 

4872 new_prompt_names = [prompt.name for prompt in prompts] if include_prompts else None 

4873 

4874 # Track dirty objects before update operations to count per-type updates 

4875 pending_tools_before = {obj for obj in db.dirty if isinstance(obj, DbTool)} 

4876 pending_resources_before = {obj for obj in db.dirty if isinstance(obj, DbResource)} 

4877 pending_prompts_before = {obj for obj in db.dirty if isinstance(obj, DbPrompt)} 

4878 

4879 # Update/create tools, resources, and prompts 

4880 tools_to_add = self._update_or_create_tools(db, tools, gateway, created_via) 

4881 resources_to_add = self._update_or_create_resources(db, resources, gateway, created_via) if include_resources else [] 

4882 prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, created_via) if include_prompts else [] 

4883 

4884 # Count per-type updates 

4885 result["tools_updated"] = len({obj for obj in db.dirty if isinstance(obj, DbTool)} - pending_tools_before) 

4886 result["resources_updated"] = len({obj for obj in db.dirty if isinstance(obj, DbResource)} - pending_resources_before) 

4887 result["prompts_updated"] = len({obj for obj in db.dirty if isinstance(obj, DbPrompt)} - pending_prompts_before) 

4888 

4889 # Only delete MCP-discovered items (not user-created entries) 

4890 # Excludes "api", "ui", None (legacy/user-created) to preserve user entries 

4891 mcp_created_via_values = {"MCP", "federation", "health_check", "manual_refresh", "oauth", "update"} 

4892 

4893 # Find and remove stale tools (only MCP-discovered ones) 

4894 stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names and tool.created_via in mcp_created_via_values] 

4895 if stale_tool_ids: 

4896 for i in range(0, len(stale_tool_ids), 500): 

4897 chunk = stale_tool_ids[i : i + 500] 

4898 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk))) 

4899 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk))) 

4900 db.execute(delete(DbTool).where(DbTool.id.in_(chunk))) 

4901 result["tools_removed"] = len(stale_tool_ids) 

4902 

4903 # Find and remove stale resources (only MCP-discovered ones, only if resources were fetched) 

4904 stale_resource_ids = [] 

4905 if new_resource_uris is not None: 4905 ↛ 4917line 4905 didn't jump to line 4917 because the condition on line 4905 was always true

4906 stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris and resource.created_via in mcp_created_via_values] 

4907 if stale_resource_ids: 4907 ↛ 4908line 4907 didn't jump to line 4908 because the condition on line 4907 was never true

4908 for i in range(0, len(stale_resource_ids), 500): 

4909 chunk = stale_resource_ids[i : i + 500] 

4910 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk))) 

4911 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk))) 

4912 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk))) 

4913 db.execute(delete(DbResource).where(DbResource.id.in_(chunk))) 

4914 result["resources_removed"] = len(stale_resource_ids) 

4915 

4916 # Find and remove stale prompts (only MCP-discovered ones, only if prompts were fetched) 

4917 stale_prompt_ids = [] 

4918 if new_prompt_names is not None: 4918 ↛ 4929line 4918 didn't jump to line 4929 because the condition on line 4918 was always true

4919 stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names and prompt.created_via in mcp_created_via_values] 

4920 if stale_prompt_ids: 4920 ↛ 4921line 4920 didn't jump to line 4921 because the condition on line 4920 was never true

4921 for i in range(0, len(stale_prompt_ids), 500): 

4922 chunk = stale_prompt_ids[i : i + 500] 

4923 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk))) 

4924 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk))) 

4925 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk))) 

4926 result["prompts_removed"] = len(stale_prompt_ids) 

4927 

4928 # Expire gateway if stale items were deleted 

4929 if stale_tool_ids or stale_resource_ids or stale_prompt_ids: 

4930 db.expire(gateway) 

4931 

4932 # Add new items in chunks 

4933 chunk_size = 50 

4934 if tools_to_add: 

4935 for i in range(0, len(tools_to_add), chunk_size): 

4936 chunk = tools_to_add[i : i + chunk_size] 

4937 db.add_all(chunk) 

4938 db.flush() 

4939 result["tools_added"] = len(tools_to_add) 

4940 

4941 if resources_to_add: 

4942 for i in range(0, len(resources_to_add), chunk_size): 

4943 chunk = resources_to_add[i : i + chunk_size] 

4944 db.add_all(chunk) 

4945 db.flush() 

4946 result["resources_added"] = len(resources_to_add) 

4947 

4948 if prompts_to_add: 

4949 for i in range(0, len(prompts_to_add), chunk_size): 

4950 chunk = prompts_to_add[i : i + chunk_size] 

4951 db.add_all(chunk) 

4952 db.flush() 

4953 result["prompts_added"] = len(prompts_to_add) 

4954 

4955 gateway.last_refresh_at = datetime.now(timezone.utc) 

4956 

4957 total_changes = ( 

4958 result["tools_added"] 

4959 + result["tools_removed"] 

4960 + result["tools_updated"] 

4961 + result["resources_added"] 

4962 + result["resources_removed"] 

4963 + result["resources_updated"] 

4964 + result["prompts_added"] 

4965 + result["prompts_removed"] 

4966 + result["prompts_updated"] 

4967 ) 

4968 

4969 has_changes = total_changes > 0 

4970 

4971 if has_changes: 

4972 db.commit() 

4973 logger.info( 

4974 f"Refreshed gateway {gateway_name}: " 

4975 f"tools(+{result['tools_added']}/-{result['tools_removed']}/~{result['tools_updated']}), " 

4976 f"resources(+{result['resources_added']}/-{result['resources_removed']}/~{result['resources_updated']}), " 

4977 f"prompts(+{result['prompts_added']}/-{result['prompts_removed']}/~{result['prompts_updated']})" 

4978 ) 

4979 

4980 # Invalidate caches per-type based on actual changes 

4981 cache = _get_registry_cache() 

4982 if result["tools_added"] > 0 or result["tools_removed"] > 0 or result["tools_updated"] > 0: 4982 ↛ 4984line 4982 didn't jump to line 4984 because the condition on line 4982 was always true

4983 await cache.invalidate_tools() 

4984 if result["resources_added"] > 0 or result["resources_removed"] > 0 or result["resources_updated"] > 0: 

4985 await cache.invalidate_resources() 

4986 if result["prompts_added"] > 0 or result["prompts_removed"] > 0 or result["prompts_updated"] > 0: 

4987 await cache.invalidate_prompts() 

4988 

4989 # Invalidate tool lookup cache for this gateway 

4990 tool_lookup_cache = _get_tool_lookup_cache() 

4991 await tool_lookup_cache.invalidate_gateway(str(gateway_id)) 

4992 else: 

4993 db.commit() 

4994 logger.debug(f"No changes detected during refresh of gateway {gateway_name}") 

4995 

4996 return result 

4997 

4998 def _get_refresh_lock(self, gateway_id: str) -> asyncio.Lock: 

4999 """Get or create a per-gateway refresh lock. 

5000 

5001 This ensures only one refresh operation can run for a given gateway at a time. 

5002 

5003 Args: 

5004 gateway_id: ID of the gateway to get the lock for 

5005 

5006 Returns: 

5007 asyncio.Lock: The lock for the specified gateway 

5008 

5009 Examples: 

5010 >>> from mcpgateway.services.gateway_service import GatewayService 

5011 >>> service = GatewayService() 

5012 >>> lock1 = service._get_refresh_lock('gw-123') 

5013 >>> lock2 = service._get_refresh_lock('gw-123') 

5014 >>> lock1 is lock2 

5015 True 

5016 >>> lock3 = service._get_refresh_lock('gw-456') 

5017 >>> lock1 is lock3 

5018 False 

5019 """ 

5020 if gateway_id not in self._refresh_locks: 

5021 self._refresh_locks[gateway_id] = asyncio.Lock() 

5022 return self._refresh_locks[gateway_id] 

5023 

5024 async def refresh_gateway_manually( 

5025 self, 

5026 gateway_id: str, 

5027 include_resources: bool = True, 

5028 include_prompts: bool = True, 

5029 user_email: Optional[str] = None, 

5030 request_headers: Optional[Dict[str, str]] = None, 

5031 ) -> Dict[str, Any]: 

5032 """Manually trigger a refresh of tools/resources/prompts for a gateway. 

5033 

5034 This method provides a public API for triggering an immediate refresh 

5035 of a gateway's tools, resources, and prompts from its MCP server. 

5036 It includes concurrency control via per-gateway locking. 

5037 

5038 Args: 

5039 gateway_id: Gateway ID to refresh 

5040 include_resources: Whether to include resources in the refresh 

5041 include_prompts: Whether to include prompts in the refresh 

5042 user_email: Email of the user triggering the refresh 

5043 request_headers: Optional request headers for passthrough authentication 

5044 

5045 Returns: 

5046 Dict with counts: {tools_added, tools_updated, tools_removed, 

5047 resources_added, resources_updated, resources_removed, 

5048 prompts_added, prompts_updated, prompts_removed, 

5049 validation_errors, duration_ms, refreshed_at} 

5050 

5051 Raises: 

5052 GatewayNotFoundError: If the gateway does not exist 

5053 GatewayError: If another refresh is already in progress for this gateway 

5054 

5055 Examples: 

5056 >>> from mcpgateway.services.gateway_service import GatewayService 

5057 >>> from unittest.mock import patch, MagicMock, AsyncMock 

5058 >>> import asyncio 

5059 

5060 >>> # Test method is async 

5061 >>> service = GatewayService() 

5062 >>> import inspect 

5063 >>> inspect.iscoroutinefunction(service.refresh_gateway_manually) 

5064 True 

5065 """ 

5066 start_time = time.monotonic() 

5067 

5068 pre_auth_headers = {} 

5069 

5070 # Check if gateway exists before acquiring lock 

5071 with fresh_db_session() as db: 

5072 gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none() 

5073 if not gateway: 

5074 raise GatewayNotFoundError(f"Gateway with ID '{gateway_id}' not found") 

5075 gateway_name = gateway.name 

5076 

5077 # Get passthrough headers if request headers provided 

5078 if request_headers: 

5079 pre_auth_headers = get_passthrough_headers(request_headers, {}, db, gateway) 

5080 

5081 lock = self._get_refresh_lock(gateway_id) 

5082 

5083 # Check if lock is already held (concurrent refresh in progress) 

5084 if lock.locked(): 

5085 raise GatewayError(f"Refresh already in progress for gateway {gateway_name}") 

5086 

5087 async with lock: 

5088 logger.info(f"Starting manual refresh for gateway {gateway_name} (ID: {gateway_id})") 

5089 

5090 result = await self._refresh_gateway_tools_resources_prompts( 

5091 gateway_id=gateway_id, 

5092 _user_email=user_email, 

5093 created_via="manual_refresh", 

5094 pre_auth_headers=pre_auth_headers, 

5095 gateway=gateway, 

5096 include_resources=include_resources, 

5097 include_prompts=include_prompts, 

5098 ) 

5099 # Note: last_refresh_at is updated inside _refresh_gateway_tools_resources_prompts on success 

5100 

5101 result["duration_ms"] = (time.monotonic() - start_time) * 1000 

5102 result["refreshed_at"] = datetime.now(timezone.utc) 

5103 

5104 log_level = logging.INFO if result.get("success", True) else logging.WARNING 

5105 status_msg = "succeeded" if result.get("success", True) else f"failed: {result.get('error')}" 

5106 

5107 logger.log( 

5108 log_level, 

5109 f"Manual refresh for gateway {gateway_id} {status_msg}. Stats: " 

5110 f"tools(+{result['tools_added']}/-{result['tools_removed']}), " 

5111 f"resources(+{result['resources_added']}/-{result['resources_removed']}), " 

5112 f"prompts(+{result['prompts_added']}/-{result['prompts_removed']}) " 

5113 f"in {result['duration_ms']:.2f}ms", 

5114 ) 

5115 

5116 return result 

5117 

5118 async def _publish_event(self, event: Dict[str, Any]) -> None: 

5119 """Publish event to all subscribers. 

5120 

5121 Args: 

5122 event: event dictionary 

5123 

5124 Examples: 

5125 >>> import asyncio 

5126 >>> from unittest.mock import AsyncMock 

5127 >>> service = GatewayService() 

5128 >>> # Mock the underlying event service 

5129 >>> service._event_service = AsyncMock() 

5130 >>> test_event = {"type": "test", "data": {}} 

5131 >>> 

5132 >>> asyncio.run(service._publish_event(test_event)) 

5133 >>> 

5134 >>> # Verify the event was passed to the event service 

5135 >>> service._event_service.publish_event.assert_awaited_with(test_event) 

5136 """ 

5137 await self._event_service.publish_event(event) 

5138 

5139 def _validate_tools(self, tools: list[dict[str, Any]], context: str = "default") -> tuple[list[ToolCreate], list[str]]: 

5140 """Validate tools individually with richer logging and error aggregation. 

5141 

5142 Args: 

5143 tools: list of tool dicts 

5144 context: caller context, e.g. "oauth" to tailor errors/messages 

5145 

5146 Returns: 

5147 tuple[list[ToolCreate], list[str]]: Tuple of (valid tools, validation errors) 

5148 

5149 Raises: 

5150 OAuthToolValidationError: If all tools fail validation in OAuth context 

5151 GatewayConnectionError: If all tools fail validation in default context 

5152 """ 

5153 valid_tools: list[ToolCreate] = [] 

5154 validation_errors: list[str] = [] 

5155 

5156 for i, tool_dict in enumerate(tools): 

5157 tool_name = tool_dict.get("name", f"unknown_tool_{i}") 

5158 try: 

5159 logger.debug(f"Validating tool: {tool_name}") 

5160 validated_tool = ToolCreate.model_validate(tool_dict) 

5161 valid_tools.append(validated_tool) 

5162 logger.debug(f"Tool '{tool_name}' validated successfully") 

5163 except ValidationError as e: 

5164 error_msg = f"Validation failed for tool '{tool_name}': {e.errors()}" 

5165 logger.error(error_msg) 

5166 logger.debug(f"Failed tool schema: {tool_dict}") 

5167 validation_errors.append(error_msg) 

5168 except ValueError as e: 

5169 if "JSON structure exceeds maximum depth" in str(e): 

5170 error_msg = f"Tool '{tool_name}' schema too deeply nested. " f"Current depth limit: {settings.validation_max_json_depth}" 

5171 logger.error(error_msg) 

5172 logger.warning("Consider increasing VALIDATION_MAX_JSON_DEPTH environment variable") 

5173 else: 

5174 error_msg = f"ValueError for tool '{tool_name}': {str(e)}" 

5175 logger.error(error_msg) 

5176 validation_errors.append(error_msg) 

5177 except Exception as e: # pragma: no cover - defensive 

5178 error_msg = f"Unexpected error validating tool '{tool_name}': {type(e).__name__}: {str(e)}" 

5179 logger.error(error_msg, exc_info=True) 

5180 validation_errors.append(error_msg) 

5181 

5182 if validation_errors: 

5183 logger.warning(f"Tool validation completed with {len(validation_errors)} error(s). " f"Successfully validated {len(valid_tools)} tool(s).") 

5184 for err in validation_errors[:3]: 

5185 logger.debug(f"Validation error: {err}") 

5186 

5187 if not valid_tools and validation_errors: 

5188 if context == "oauth": 

5189 raise OAuthToolValidationError(f"OAuth tool fetch failed: all {len(tools)} tools failed validation. " f"First error: {validation_errors[0][:200]}") 

5190 raise GatewayConnectionError(f"Failed to fetch tools: All {len(tools)} tools failed validation. " f"First error: {validation_errors[0][:200]}") 

5191 

5192 return valid_tools, validation_errors 

5193 

5194 async def _connect_to_sse_server_without_validation(self, server_url: str, authentication: Optional[Dict[str, str]] = None): 

5195 """Connect to an MCP server running with SSE transport, skipping URL validation. 

5196 

5197 This is used for OAuth-protected servers where we've already validated the token works. 

5198 

5199 Args: 

5200 server_url: The URL of the SSE MCP server to connect to. 

5201 authentication: Optional dictionary containing authentication headers. 

5202 

5203 Returns: 

5204 Tuple containing (capabilities, tools, resources, prompts) from the MCP server. 

5205 """ 

5206 if authentication is None: 

5207 authentication = {} 

5208 

5209 # Skip validation for OAuth servers - we already validated via OAuth flow 

5210 # Use async with for both sse_client and ClientSession 

5211 try: 

5212 async with sse_client(url=server_url, headers=authentication) as streams: 

5213 async with ClientSession(*streams) as session: 

5214 # Initialize the session 

5215 response = await session.initialize() 

5216 capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True) 

5217 logger.debug(f"Server capabilities: {capabilities}") 

5218 

5219 response = await session.list_tools() 

5220 tools = response.tools 

5221 tools = [tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools] 

5222 

5223 tools, _ = self._validate_tools(tools, context="oauth") 

5224 if tools: 

5225 logger.info(f"Fetched {len(tools)} tools from gateway") 

5226 # Fetch resources if supported 

5227 

5228 logger.debug(f"Checking for resources support: {capabilities.get('resources')}") 

5229 resources = [] 

5230 if capabilities.get("resources"): 5230 ↛ 5282line 5230 didn't jump to line 5282 because the condition on line 5230 was always true

5231 try: 

5232 response = await session.list_resources() 

5233 raw_resources = response.resources 

5234 for resource in raw_resources: 

5235 resource_data = resource.model_dump(by_alias=True, exclude_none=True) 

5236 # Convert AnyUrl to string if present 

5237 if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"): 

5238 resource_data["uri"] = str(resource_data["uri"]) 

5239 # Add default content if not present (will be fetched on demand) 

5240 if "content" not in resource_data: 5240 ↛ 5242line 5240 didn't jump to line 5242 because the condition on line 5240 was always true

5241 resource_data["content"] = "" 

5242 try: 

5243 resources.append(ResourceCreate.model_validate(resource_data)) 

5244 except Exception: 

5245 # If validation fails, create minimal resource 

5246 resources.append( 

5247 ResourceCreate( 

5248 uri=str(resource_data.get("uri", "")), 

5249 name=resource_data.get("name", ""), 

5250 description=resource_data.get("description"), 

5251 mime_type=resource_data.get("mimeType"), 

5252 uri_template=resource_data.get("uriTemplate") or None, 

5253 content="", 

5254 ) 

5255 ) 

5256 logger.info(f"Fetched {len(resources)} resources from gateway") 

5257 except Exception as e: 

5258 logger.warning(f"Failed to fetch resources: {e}") 

5259 

5260 # resource template URI 

5261 try: 

5262 response_templates = await session.list_resource_templates() 

5263 raw_resources_templates = response_templates.resourceTemplates 

5264 resource_templates = [] 

5265 for resource_template in raw_resources_templates: 

5266 resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True) 

5267 

5268 if "uriTemplate" in resource_template_data: # and hasattr(resource_template_data["uriTemplate"], "unicode_string"): 5268 ↛ 5272line 5268 didn't jump to line 5272 because the condition on line 5268 was always true

5269 resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"]) 

5270 resource_template_data["uri"] = str(resource_template_data["uriTemplate"]) 

5271 

5272 if "content" not in resource_template_data: 5272 ↛ 5275line 5272 didn't jump to line 5275 because the condition on line 5272 was always true

5273 resource_template_data["content"] = "" 

5274 

5275 resources.append(ResourceCreate.model_validate(resource_template_data)) 

5276 resource_templates.append(ResourceCreate.model_validate(resource_template_data)) 

5277 logger.info(f"Fetched {len(resource_templates)} resource templates from gateway") 

5278 except Exception as e: 

5279 logger.warning(f"Failed to fetch resource templates: {e}") 

5280 

5281 # Fetch prompts if supported 

5282 prompts = [] 

5283 logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}") 

5284 if capabilities.get("prompts"): 5284 ↛ 5308line 5284 didn't jump to line 5308 because the condition on line 5284 was always true

5285 try: 

5286 response = await session.list_prompts() 

5287 raw_prompts = response.prompts 

5288 for prompt in raw_prompts: 

5289 prompt_data = prompt.model_dump(by_alias=True, exclude_none=True) 

5290 # Add default template if not present 

5291 if "template" not in prompt_data: 5291 ↛ 5293line 5291 didn't jump to line 5293 because the condition on line 5291 was always true

5292 prompt_data["template"] = "" 

5293 try: 

5294 prompts.append(PromptCreate.model_validate(prompt_data)) 

5295 except Exception: 

5296 # If validation fails, create minimal prompt 

5297 prompts.append( 

5298 PromptCreate( 

5299 name=prompt_data.get("name", ""), 

5300 description=prompt_data.get("description"), 

5301 template=prompt_data.get("template", ""), 

5302 ) 

5303 ) 

5304 logger.info(f"Fetched {len(prompts)} prompts from gateway") 

5305 except Exception as e: 

5306 logger.warning(f"Failed to fetch prompts: {e}") 

5307 

5308 return capabilities, tools, resources, prompts 

5309 except Exception as e: 

5310 # Note: This function is for OAuth servers only, which don't use query param auth 

5311 # Still sanitize in case exception contains URL with static sensitive params 

5312 sanitized_url = sanitize_url_for_logging(server_url) 

5313 sanitized_error = sanitize_exception_message(str(e)) 

5314 logger.error(f"SSE connection error details: {type(e).__name__}: {sanitized_error}", exc_info=True) 

5315 raise GatewayConnectionError(f"Failed to connect to SSE server at {sanitized_url}: {sanitized_error}") 

5316 

5317 async def connect_to_sse_server( 

5318 self, 

5319 server_url: str, 

5320 authentication: Optional[Dict[str, str]] = None, 

5321 ca_certificate: Optional[bytes] = None, 

5322 include_prompts: bool = True, 

5323 include_resources: bool = True, 

5324 auth_query_params: Optional[Dict[str, str]] = None, 

5325 ): 

5326 """Connect to an MCP server running with SSE transport. 

5327 

5328 Args: 

5329 server_url: The URL of the SSE MCP server to connect to. 

5330 authentication: Optional dictionary containing authentication headers. 

5331 ca_certificate: Optional CA certificate for SSL verification. 

5332 include_prompts: Whether to fetch prompts from the server. 

5333 include_resources: Whether to fetch resources from the server. 

5334 auth_query_params: Query param names for URL sanitization in error logs. 

5335 

5336 Returns: 

5337 Tuple containing (capabilities, tools, resources, prompts) from the MCP server. 

5338 """ 

5339 if authentication is None: 

5340 authentication = {} 

5341 

5342 def get_httpx_client_factory( 

5343 headers: dict[str, str] | None = None, 

5344 timeout: httpx.Timeout | None = None, 

5345 auth: httpx.Auth | None = None, 

5346 ) -> httpx.AsyncClient: 

5347 """Factory function to create httpx.AsyncClient with optional CA certificate. 

5348 

5349 Args: 

5350 headers: Optional headers for the client 

5351 timeout: Optional timeout for the client 

5352 auth: Optional auth for the client 

5353 

5354 Returns: 

5355 httpx.AsyncClient: Configured HTTPX async client 

5356 """ 

5357 if ca_certificate: 

5358 ctx = self.create_ssl_context(ca_certificate) 

5359 else: 

5360 ctx = None 

5361 return httpx.AsyncClient( 

5362 verify=ctx if ctx else get_default_verify(), 

5363 follow_redirects=True, 

5364 headers=headers, 

5365 timeout=timeout if timeout else get_http_timeout(), 

5366 auth=auth, 

5367 limits=httpx.Limits( 

5368 max_connections=settings.httpx_max_connections, 

5369 max_keepalive_connections=settings.httpx_max_keepalive_connections, 

5370 keepalive_expiry=settings.httpx_keepalive_expiry, 

5371 ), 

5372 ) 

5373 

5374 # Use async with for both sse_client and ClientSession 

5375 async with sse_client(url=server_url, headers=authentication, httpx_client_factory=get_httpx_client_factory) as streams: 

5376 async with ClientSession(*streams) as session: 

5377 # Initialize the session 

5378 response = await session.initialize() 

5379 

5380 capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True) 

5381 logger.debug(f"Server capabilities: {capabilities}") 

5382 

5383 response = await session.list_tools() 

5384 tools = response.tools 

5385 tools = [tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools] 

5386 

5387 tools, _ = self._validate_tools(tools) 

5388 if tools: 

5389 logger.info(f"Fetched {len(tools)} tools from gateway") 

5390 # Fetch resources if supported 

5391 resources = [] 

5392 if include_resources: 5392 ↛ 5446line 5392 didn't jump to line 5446 because the condition on line 5392 was always true

5393 logger.debug(f"Checking for resources support: {capabilities.get('resources')}") 

5394 if capabilities.get("resources"): 

5395 try: 

5396 response = await session.list_resources() 

5397 raw_resources = response.resources 

5398 for resource in raw_resources: 

5399 resource_data = resource.model_dump(by_alias=True, exclude_none=True) 

5400 # Convert AnyUrl to string if present 

5401 if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"): 

5402 resource_data["uri"] = str(resource_data["uri"]) 

5403 # Add default content if not present (will be fetched on demand) 

5404 if "content" not in resource_data: 5404 ↛ 5406line 5404 didn't jump to line 5406 because the condition on line 5404 was always true

5405 resource_data["content"] = "" 

5406 try: 

5407 resources.append(ResourceCreate.model_validate(resource_data)) 

5408 except Exception: 

5409 # If validation fails, create minimal resource 

5410 resources.append( 

5411 ResourceCreate( 

5412 uri=str(resource_data.get("uri", "")), 

5413 name=resource_data.get("name", ""), 

5414 description=resource_data.get("description"), 

5415 mime_type=resource_data.get("mimeType"), 

5416 uri_template=resource_data.get("uriTemplate") or None, 

5417 content="", 

5418 ) 

5419 ) 

5420 logger.info(f"Fetched {len(resources)} resources from gateway") 

5421 except Exception as e: 

5422 logger.warning(f"Failed to fetch resources: {e}") 

5423 

5424 # resource template URI 

5425 try: 

5426 response_templates = await session.list_resource_templates() 

5427 raw_resources_templates = response_templates.resourceTemplates 

5428 resource_templates = [] 

5429 for resource_template in raw_resources_templates: 

5430 resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True) 

5431 

5432 if "uriTemplate" in resource_template_data: # and hasattr(resource_template_data["uriTemplate"], "unicode_string"): 5432 ↛ 5436line 5432 didn't jump to line 5436 because the condition on line 5432 was always true

5433 resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"]) 

5434 resource_template_data["uri"] = str(resource_template_data["uriTemplate"]) 

5435 

5436 if "content" not in resource_template_data: 5436 ↛ 5439line 5436 didn't jump to line 5439 because the condition on line 5436 was always true

5437 resource_template_data["content"] = "" 

5438 

5439 resources.append(ResourceCreate.model_validate(resource_template_data)) 

5440 resource_templates.append(ResourceCreate.model_validate(resource_template_data)) 

5441 logger.info(f"Fetched {len(raw_resources_templates)} resource templates from gateway") 

5442 except Exception as ei: 

5443 logger.warning(f"Failed to fetch resource templates: {ei}") 

5444 

5445 # Fetch prompts if supported 

5446 prompts = [] 

5447 if include_prompts: 5447 ↛ 5473line 5447 didn't jump to line 5473 because the condition on line 5447 was always true

5448 logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}") 

5449 if capabilities.get("prompts"): 

5450 try: 

5451 response = await session.list_prompts() 

5452 raw_prompts = response.prompts 

5453 for prompt in raw_prompts: 

5454 prompt_data = prompt.model_dump(by_alias=True, exclude_none=True) 

5455 # Add default template if not present 

5456 if "template" not in prompt_data: 

5457 prompt_data["template"] = "" 

5458 try: 

5459 prompts.append(PromptCreate.model_validate(prompt_data)) 

5460 except Exception: 

5461 # If validation fails, create minimal prompt 

5462 prompts.append( 

5463 PromptCreate( 

5464 name=prompt_data.get("name", ""), 

5465 description=prompt_data.get("description"), 

5466 template=prompt_data.get("template", ""), 

5467 ) 

5468 ) 

5469 logger.info(f"Fetched {len(prompts)} prompts from gateway") 

5470 except Exception as e: 

5471 logger.warning(f"Failed to fetch prompts: {e}") 

5472 

5473 return capabilities, tools, resources, prompts 

5474 sanitized_url = sanitize_url_for_logging(server_url, auth_query_params) 

5475 raise GatewayConnectionError(f"Failed to initialize gateway at {sanitized_url}: Connection could not be established") 

5476 

5477 async def connect_to_streamablehttp_server( 

5478 self, 

5479 server_url: str, 

5480 authentication: Optional[Dict[str, str]] = None, 

5481 ca_certificate: Optional[bytes] = None, 

5482 include_prompts: bool = True, 

5483 include_resources: bool = True, 

5484 auth_query_params: Optional[Dict[str, str]] = None, 

5485 ): 

5486 """Connect to an MCP server running with Streamable HTTP transport. 

5487 

5488 Args: 

5489 server_url: The URL of the Streamable HTTP MCP server to connect to. 

5490 authentication: Optional dictionary containing authentication headers. 

5491 ca_certificate: Optional CA certificate for SSL verification. 

5492 include_prompts: Whether to fetch prompts from the server. 

5493 include_resources: Whether to fetch resources from the server. 

5494 auth_query_params: Query param names for URL sanitization in error logs. 

5495 

5496 Returns: 

5497 Tuple containing (capabilities, tools, resources, prompts) from the MCP server. 

5498 """ 

5499 if authentication is None: 

5500 authentication = {} 

5501 

5502 # Use authentication directly instead 

5503 def get_httpx_client_factory( 

5504 headers: dict[str, str] | None = None, 

5505 timeout: httpx.Timeout | None = None, 

5506 auth: httpx.Auth | None = None, 

5507 ) -> httpx.AsyncClient: 

5508 """Factory function to create httpx.AsyncClient with optional CA certificate. 

5509 

5510 Args: 

5511 headers: Optional headers for the client 

5512 timeout: Optional timeout for the client 

5513 auth: Optional auth for the client 

5514 

5515 Returns: 

5516 httpx.AsyncClient: Configured HTTPX async client 

5517 """ 

5518 if ca_certificate: 5518 ↛ 5521line 5518 didn't jump to line 5521 because the condition on line 5518 was always true

5519 ctx = self.create_ssl_context(ca_certificate) 

5520 else: 

5521 ctx = None 

5522 return httpx.AsyncClient( 

5523 verify=ctx if ctx else get_default_verify(), 

5524 follow_redirects=True, 

5525 headers=headers, 

5526 timeout=timeout if timeout else get_http_timeout(), 

5527 auth=auth, 

5528 limits=httpx.Limits( 

5529 max_connections=settings.httpx_max_connections, 

5530 max_keepalive_connections=settings.httpx_max_keepalive_connections, 

5531 keepalive_expiry=settings.httpx_keepalive_expiry, 

5532 ), 

5533 ) 

5534 

5535 async with streamablehttp_client(url=server_url, headers=authentication, httpx_client_factory=get_httpx_client_factory) as (read_stream, write_stream, _get_session_id): 

5536 async with ClientSession(read_stream, write_stream) as session: 

5537 # Initialize the session 

5538 response = await session.initialize() 

5539 capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True) 

5540 logger.debug(f"Server capabilities: {capabilities}") 

5541 

5542 response = await session.list_tools() 

5543 tools = response.tools 

5544 tools = [tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools] 

5545 

5546 tools, _ = self._validate_tools(tools) 

5547 for tool in tools: 

5548 tool.request_type = "STREAMABLEHTTP" 

5549 if tools: 5549 ↛ 5553line 5549 didn't jump to line 5553 because the condition on line 5549 was always true

5550 logger.info(f"Fetched {len(tools)} tools from gateway") 

5551 

5552 # Fetch resources if supported 

5553 resources = [] 

5554 if include_resources: 5554 ↛ 5608line 5554 didn't jump to line 5608 because the condition on line 5554 was always true

5555 logger.debug(f"Checking for resources support: {capabilities.get('resources')}") 

5556 if capabilities.get("resources"): 

5557 try: 

5558 response = await session.list_resources() 

5559 raw_resources = response.resources 

5560 for resource in raw_resources: 

5561 resource_data = resource.model_dump(by_alias=True, exclude_none=True) 

5562 # Convert AnyUrl to string if present 

5563 if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"): 5563 ↛ 5566line 5563 didn't jump to line 5566 because the condition on line 5563 was always true

5564 resource_data["uri"] = str(resource_data["uri"]) 

5565 # Add default content if not present 

5566 if "content" not in resource_data: 5566 ↛ 5568line 5566 didn't jump to line 5568 because the condition on line 5566 was always true

5567 resource_data["content"] = "" 

5568 try: 

5569 resources.append(ResourceCreate.model_validate(resource_data)) 

5570 except Exception: 

5571 # If validation fails, create minimal resource 

5572 resources.append( 

5573 ResourceCreate( 

5574 uri=str(resource_data.get("uri", "")), 

5575 name=resource_data.get("name", ""), 

5576 description=resource_data.get("description"), 

5577 mime_type=resource_data.get("mimeType"), 

5578 uri_template=resource_data.get("uriTemplate") or None, 

5579 content="", 

5580 ) 

5581 ) 

5582 logger.info(f"Fetched {len(resources)} resources from gateway") 

5583 except Exception as e: 

5584 logger.warning(f"Failed to fetch resources: {e}") 

5585 

5586 # resource template URI 

5587 try: 

5588 response_templates = await session.list_resource_templates() 

5589 raw_resources_templates = response_templates.resourceTemplates 

5590 resource_templates = [] 

5591 for resource_template in raw_resources_templates: 

5592 resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True) 

5593 

5594 if "uriTemplate" in resource_template_data: # and hasattr(resource_template_data["uriTemplate"], "unicode_string"): 5594 ↛ 5598line 5594 didn't jump to line 5598 because the condition on line 5594 was always true

5595 resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"]) 

5596 resource_template_data["uri"] = str(resource_template_data["uriTemplate"]) 

5597 

5598 if "content" not in resource_template_data: 5598 ↛ 5601line 5598 didn't jump to line 5601 because the condition on line 5598 was always true

5599 resource_template_data["content"] = "" 

5600 

5601 resources.append(ResourceCreate.model_validate(resource_template_data)) 

5602 resource_templates.append(ResourceCreate.model_validate(resource_template_data)) 

5603 logger.info(f"Fetched {len(resource_templates)} resource templates from gateway") 

5604 except Exception as e: 

5605 logger.warning(f"Failed to fetch resource templates: {e}") 

5606 

5607 # Fetch prompts if supported 

5608 prompts = [] 

5609 if include_prompts: 5609 ↛ 5625line 5609 didn't jump to line 5625 because the condition on line 5609 was always true

5610 logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}") 

5611 if capabilities.get("prompts"): 

5612 try: 

5613 response = await session.list_prompts() 

5614 raw_prompts = response.prompts 

5615 for prompt in raw_prompts: 

5616 prompt_data = prompt.model_dump(by_alias=True, exclude_none=True) 

5617 # Add default template if not present 

5618 if "template" not in prompt_data: 5618 ↛ 5619line 5618 didn't jump to line 5619 because the condition on line 5618 was never true

5619 prompt_data["template"] = "" 

5620 prompts.append(PromptCreate.model_validate(prompt_data)) 

5621 logger.info(f"Fetched {len(prompts)} prompts from gateway") 

5622 except Exception as e: 

5623 logger.warning(f"Failed to fetch prompts: {e}") 

5624 

5625 return capabilities, tools, resources, prompts 

5626 sanitized_url = sanitize_url_for_logging(server_url, auth_query_params) 

5627 raise GatewayConnectionError(f"Failed to initialize gateway at {sanitized_url}: Connection could not be established") 

5628 

5629 

5630# Lazy singleton - created on first access, not at module import time. 

5631# This avoids instantiation when only exception classes are imported. 

5632_gateway_service_instance = None # pylint: disable=invalid-name 

5633 

5634 

5635def __getattr__(name: str): 

5636 """Module-level __getattr__ for lazy singleton creation. 

5637 

5638 Args: 

5639 name: The attribute name being accessed. 

5640 

5641 Returns: 

5642 The gateway_service singleton instance if name is "gateway_service". 

5643 

5644 Raises: 

5645 AttributeError: If the attribute name is not "gateway_service". 

5646 """ 

5647 global _gateway_service_instance # pylint: disable=global-statement 

5648 if name == "gateway_service": 

5649 if _gateway_service_instance is None: 

5650 _gateway_service_instance = GatewayService() 

5651 return _gateway_service_instance 

5652 raise AttributeError(f"module {__name__!r} has no attribute {name!r}")