Coverage for mcpgateway / services / gateway_service.py: 93%
2074 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2# pylint: disable=import-outside-toplevel,no-name-in-module
3"""Location: ./mcpgateway/services/gateway_service.py
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8Gateway Service Implementation.
9This module implements gateway federation according to the MCP specification.
10It handles:
11- Gateway discovery and registration
12- Capability aggregation
13- Health monitoring
14- Active/inactive gateway management
16Examples:
17 >>> from mcpgateway.services.gateway_service import GatewayService, GatewayError
18 >>> service = GatewayService()
19 >>> isinstance(service, GatewayService)
20 True
21 >>> hasattr(service, '_active_gateways')
22 True
23 >>> isinstance(service._active_gateways, set)
24 True
26 Test error classes:
27 >>> error = GatewayError("Test error")
28 >>> str(error)
29 'Test error'
30 >>> isinstance(error, Exception)
31 True
33 >>> conflict_error = GatewayNameConflictError("test_gw")
34 >>> "test_gw" in str(conflict_error)
35 True
36 >>> conflict_error.enabled
37 True
38 >>>
39 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
40 >>> import asyncio
41 >>> asyncio.run(service._http_client.aclose())
42"""
44# Standard
45import asyncio
46import binascii
47from datetime import datetime, timezone
48import logging
49import mimetypes
50import os
51import ssl
52import tempfile
53import time
54from typing import Any, AsyncGenerator, cast, Dict, List, Optional, Set, TYPE_CHECKING, Union
55from urllib.parse import urlparse, urlunparse
56import uuid
58# Third-Party
59from filelock import FileLock, Timeout
60import httpx
61from mcp import ClientSession
62from mcp.client.sse import sse_client
63from mcp.client.streamable_http import streamablehttp_client
64from pydantic import ValidationError
65from sqlalchemy import and_, delete, desc, or_, select, update
66from sqlalchemy.exc import IntegrityError
67from sqlalchemy.orm import joinedload, selectinload, Session
69try:
70 # Third-Party - check if redis is available
71 # Third-Party
72 import redis.asyncio as _aioredis # noqa: F401 # pylint: disable=unused-import
74 REDIS_AVAILABLE = True
75 del _aioredis # Only needed for availability check
76except ImportError:
77 REDIS_AVAILABLE = False
78 logging.info("Redis is not utilized in this environment.")
80# First-Party
81from mcpgateway.config import settings
82from mcpgateway.db import fresh_db_session
83from mcpgateway.db import Gateway as DbGateway
84from mcpgateway.db import get_for_update
85from mcpgateway.db import Prompt as DbPrompt
86from mcpgateway.db import PromptMetric
87from mcpgateway.db import Resource as DbResource
88from mcpgateway.db import ResourceMetric, ResourceSubscription, server_prompt_association, server_resource_association, server_tool_association, SessionLocal
89from mcpgateway.db import Tool as DbTool
90from mcpgateway.db import ToolMetric
91from mcpgateway.observability import create_span
92from mcpgateway.schemas import GatewayCreate, GatewayRead, GatewayUpdate, PromptCreate, ResourceCreate, ToolCreate
94# logging.getLogger("httpx").setLevel(logging.WARNING) # Disables httpx logs for regular health checks
95from mcpgateway.services.audit_trail_service import get_audit_trail_service
96from mcpgateway.services.base_service import BaseService
97from mcpgateway.services.encryption_service import protect_oauth_config_for_storage
98from mcpgateway.services.event_service import EventService
99from mcpgateway.services.http_client_service import get_default_verify, get_http_timeout, get_isolated_http_client
100from mcpgateway.services.logging_service import LoggingService
101from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, register_gateway_capabilities_for_notifications, TransportType
102from mcpgateway.services.oauth_manager import OAuthManager
103from mcpgateway.services.structured_logger import get_structured_logger
104from mcpgateway.services.team_management_service import TeamManagementService
105from mcpgateway.utils.create_slug import slugify
106from mcpgateway.utils.display_name import generate_display_name
107from mcpgateway.utils.pagination import unified_paginate
108from mcpgateway.utils.passthrough_headers import get_passthrough_headers
109from mcpgateway.utils.redis_client import get_redis_client
110from mcpgateway.utils.retry_manager import ResilientHttpClient
111from mcpgateway.utils.services_auth import decode_auth, encode_auth
112from mcpgateway.utils.sqlalchemy_modifier import json_contains_tag_expr
113from mcpgateway.utils.ssl_context_cache import get_cached_ssl_context
114from mcpgateway.utils.url_auth import apply_query_param_auth, sanitize_exception_message, sanitize_url_for_logging
115from mcpgateway.utils.validate_signature import validate_signature
116from mcpgateway.validation.tags import validate_tags_field
# Cache import (lazy to avoid circular dependencies)
# Module-level singleton holders; populated on first access by the
# _get_registry_cache() / _get_tool_lookup_cache() accessors below.
_REGISTRY_CACHE = None
_TOOL_LOOKUP_CACHE = None
def _get_registry_cache():
    """Return the process-wide registry cache, importing it on first use.

    The import is deferred to break a circular dependency at module load time.

    Returns:
        RegistryCache instance.
    """
    global _REGISTRY_CACHE  # pylint: disable=global-statement
    if _REGISTRY_CACHE is not None:
        return _REGISTRY_CACHE
    # First-Party
    from mcpgateway.cache.registry_cache import registry_cache  # pylint: disable=import-outside-toplevel

    _REGISTRY_CACHE = registry_cache
    return _REGISTRY_CACHE
def _get_tool_lookup_cache():
    """Return the process-wide tool lookup cache, importing it on first use.

    The import is deferred to break a circular dependency at module load time.

    Returns:
        ToolLookupCache instance.
    """
    global _TOOL_LOOKUP_CACHE  # pylint: disable=global-statement
    if _TOOL_LOOKUP_CACHE is not None:
        return _TOOL_LOOKUP_CACHE
    # First-Party
    from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache  # pylint: disable=import-outside-toplevel

    _TOOL_LOOKUP_CACHE = tool_lookup_cache
    return _TOOL_LOOKUP_CACHE
# Initialize logging service first
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

# Initialize structured logger and audit trail for gateway operations
structured_logger = get_structured_logger("gateway_service")
audit_trail = get_audit_trail_service()

# Health-monitoring tuning sourced from application settings:
#   GW_FAILURE_THRESHOLD   - consecutive failures before a gateway is considered unhealthy
#   GW_HEALTH_CHECK_INTERVAL - delay between health-check rounds
GW_FAILURE_THRESHOLD = settings.unhealthy_threshold
GW_HEALTH_CHECK_INTERVAL = settings.health_check_interval
class GatewayError(Exception):
    """Base class for gateway-related errors.

    Every exception raised by the gateway service derives from this type, so
    callers can catch the whole family with a single ``except GatewayError``.

    Examples:
        >>> error = GatewayError("Test error")
        >>> str(error)
        'Test error'
        >>> isinstance(error, Exception)
        True
    """
class GatewayNotFoundError(GatewayError):
    """Raised when a requested gateway is not found.

    Part of the GatewayError hierarchy, so generic gateway error handlers
    also catch this exception.

    Examples:
        >>> error = GatewayNotFoundError("Gateway not found")
        >>> str(error)
        'Gateway not found'
        >>> isinstance(error, GatewayError)
        True
    """
class GatewayNameConflictError(GatewayError):
    """Raised when a gateway name conflicts with existing (active or inactive) gateway.

    Args:
        name: The conflicting gateway name
        enabled: Whether the existing gateway is enabled
        gateway_id: ID of the existing gateway if available
        visibility: The visibility of the gateway ("public" or "team").

    Examples:
        >>> error = GatewayNameConflictError("test_gateway")
        >>> str(error)
        'Public Gateway already exists with name: test_gateway'
        >>> error.name
        'test_gateway'
        >>> error.enabled
        True
        >>> error.gateway_id is None
        True

        >>> error_inactive = GatewayNameConflictError("inactive_gw", enabled=False, gateway_id=123)
        >>> str(error_inactive)
        'Public Gateway already exists with name: inactive_gw (currently inactive, ID: 123)'
        >>> error_inactive.enabled
        False
        >>> error_inactive.gateway_id
        123
    """

    def __init__(self, name: str, enabled: bool = True, gateway_id: Optional[Union[int, str]] = None, visibility: Optional[str] = "public"):
        """Initialize the error with gateway information.

        Args:
            name: The conflicting gateway name
            enabled: Whether the existing gateway is enabled
            gateway_id: ID of the existing gateway if available. Callers pass both
                integer IDs and string primary keys (DbGateway ids are strings such
                as "abc-123"), so both types are accepted.
            visibility: The visibility of the gateway ("public" or "team").
        """
        self.name = name
        self.enabled = enabled
        self.gateway_id = gateway_id
        # Label reflects the scope in which the name collision occurred.
        if visibility == "team":
            vis_label = "Team-level"
        else:
            vis_label = "Public"
        message = f"{vis_label} Gateway already exists with name: {name}"
        if not enabled:
            # Include the ID so the caller can re-enable the inactive record
            # instead of creating a duplicate.
            message += f" (currently inactive, ID: {gateway_id})"
        super().__init__(message)
class GatewayDuplicateConflictError(GatewayError):
    """Raised when a gateway conflicts with an existing gateway (same URL + credentials).

    Registration is rejected when a gateway with the same URL and authentication
    credentials already exists within the relevant scope:
    - Public: Global uniqueness required across all public gateways.
    - Team: Uniqueness required within the same team.
    - Private: Uniqueness required for the same user, a user cannot have two private gateways with the same URL and credentials.

    Args:
        duplicate_gateway: The existing conflicting gateway (DbGateway instance).

    Examples:
        >>> # Public gateway conflict with the same URL and basic auth
        >>> existing_gw = DbGateway(url="https://api.example.com", id="abc-123", enabled=True, visibility="public", team_id=None, name="API Gateway", owner_email="alice@example.com")
        >>> error = GatewayDuplicateConflictError(
        ...     duplicate_gateway=existing_gw
        ... )
        >>> str(error)
        'The Server already exists in Public scope (Name: API Gateway, Status: active)'

        >>> # Team gateway conflict with the same URL and OAuth credentials
        >>> team_gw = DbGateway(url="https://api.example.com", id="def-456", enabled=False, visibility="team", team_id="engineering-team", name="API Gateway", owner_email="bob@example.com")
        >>> error = GatewayDuplicateConflictError(
        ...     duplicate_gateway=team_gw
        ... )
        >>> str(error)
        'The Server already exists in your Team (Name: API Gateway, Status: inactive). You may want to re-enable the existing gateway instead.'

        >>> # Private gateway conflict (same user cannot have two gateways with the same URL)
        >>> private_gw = DbGateway(url="https://api.example.com", id="ghi-789", enabled=True, visibility="private", team_id="none", name="API Gateway", owner_email="charlie@example.com")
        >>> error = GatewayDuplicateConflictError(
        ...     duplicate_gateway=private_gw
        ... )
        >>> str(error)
        'The Server already exists in "private" scope (Name: API Gateway, Status: active)'
    """

    def __init__(
        self,
        duplicate_gateway: "DbGateway",
    ):
        """Initialize the error with gateway information.

        Args:
            duplicate_gateway: The existing conflicting gateway (DbGateway instance)
        """
        # Mirror the conflicting gateway's identity onto the exception so
        # callers can report on it without re-querying the database.
        self.duplicate_gateway = duplicate_gateway
        self.url = duplicate_gateway.url
        self.gateway_id = duplicate_gateway.id
        self.enabled = duplicate_gateway.enabled
        self.visibility = duplicate_gateway.visibility
        self.team_id = duplicate_gateway.team_id
        self.name = duplicate_gateway.name

        # Human-readable scope: public and team get friendly labels; anything
        # else falls back to quoting the raw visibility value.
        scope_desc = f'"{self.visibility}" scope'
        if self.visibility == "public":
            scope_desc = "Public scope"
        elif self.visibility == "team" and self.team_id:
            scope_desc = "your Team"

        status = "active" if self.enabled else "inactive"

        message = f"The Server already exists in {scope_desc} (Name: {self.name}, Status: {status})"
        if not self.enabled:
            # Nudge the caller toward re-enabling rather than duplicating.
            message += ". You may want to re-enable the existing gateway instead."

        super().__init__(message)
class GatewayConnectionError(GatewayError):
    """Raised when gateway connection fails.

    Covers failures while reaching or initializing a remote MCP gateway
    (e.g. timeouts during gateway initialization).

    Examples:
        >>> error = GatewayConnectionError("Connection failed")
        >>> str(error)
        'Connection failed'
        >>> isinstance(error, GatewayError)
        True
    """
class OAuthToolValidationError(GatewayConnectionError):
    """Raised when tool validation fails during OAuth-driven fetch.

    Subclasses GatewayConnectionError so handlers for connection failures
    also catch OAuth tool-validation problems.
    """
333class GatewayService(BaseService): # pylint: disable=too-many-instance-attributes
334 """Service for managing federated gateways.
336 Handles:
337 - Gateway registration and health checks
338 - Capability negotiation
339 - Federation events
340 - Active/inactive status management
341 """
343 _visibility_model_cls = DbGateway
    def __init__(self) -> None:
        """Initialize the gateway service.

        Sets up the HTTP client, health-check state, sibling service singletons
        (tool/prompt/resource), OAuth manager, event service, and — depending on
        configuration — Redis leader-election state or a file-lock fallback.

        Examples:
            >>> from mcpgateway.services.gateway_service import GatewayService
            >>> from mcpgateway.services.event_service import EventService
            >>> from mcpgateway.utils.retry_manager import ResilientHttpClient
            >>> from mcpgateway.services.tool_service import ToolService
            >>> service = GatewayService()
            >>> isinstance(service._event_service, EventService)
            True
            >>> isinstance(service._http_client, ResilientHttpClient)
            True
            >>> service._health_check_interval == GW_HEALTH_CHECK_INTERVAL
            True
            >>> service._health_check_task is None
            True
            >>> isinstance(service._active_gateways, set)
            True
            >>> len(service._active_gateways)
            0
            >>> service._stream_response is None
            True
            >>> isinstance(service._pending_responses, dict)
            True
            >>> len(service._pending_responses)
            0
            >>> isinstance(service.tool_service, ToolService)
            True
            >>> isinstance(service._gateway_failure_counts, dict)
            True
            >>> len(service._gateway_failure_counts)
            0
            >>> hasattr(service, 'redis_url')
            True
            >>>
            >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
            >>> import asyncio
            >>> asyncio.run(service._http_client.aclose())
        """
        self._http_client = ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify})
        self._health_check_interval = GW_HEALTH_CHECK_INTERVAL
        self._health_check_task: Optional[asyncio.Task] = None
        self._active_gateways: Set[str] = set()  # Track active gateway URLs
        self._stream_response = None
        self._pending_responses = {}
        # Prefer using the globally-initialized singletons from the service modules
        # so events propagate via their initialized EventService/Redis clients.
        # Import lazily and fall back to creating local instances when the module-level
        # __getattr__ singletons are not yet available (e.g. circular import during
        # Gunicorn --preload).
        # First-Party
        try:
            # First-Party
            from mcpgateway.services.prompt_service import prompt_service
        except ImportError:
            # First-Party
            from mcpgateway.services.prompt_service import PromptService

            prompt_service = PromptService()
        try:
            # First-Party
            from mcpgateway.services.resource_service import resource_service
        except ImportError:
            # First-Party
            from mcpgateway.services.resource_service import ResourceService

            resource_service = ResourceService()
        try:
            # First-Party
            from mcpgateway.services.tool_service import tool_service
        except ImportError:
            # First-Party
            from mcpgateway.services.tool_service import ToolService

            tool_service = ToolService()

        self.tool_service = tool_service
        self.prompt_service = prompt_service
        self.resource_service = resource_service
        # Per-gateway consecutive health-check failure counters (keyed by gateway id/URL).
        self._gateway_failure_counts: dict[str, int] = {}
        self.oauth_manager = OAuthManager(request_timeout=int(os.getenv("OAUTH_REQUEST_TIMEOUT", "30")), max_retries=int(os.getenv("OAUTH_MAX_RETRIES", "3")))
        self._event_service = EventService(channel_name="mcpgateway:gateway_events")

        # Per-gateway refresh locks to prevent concurrent refreshes for the same gateway
        self._refresh_locks: Dict[str, asyncio.Lock] = {}

        # For health checks, we determine the leader instance.
        self.redis_url = settings.redis_url if settings.cache_type == "redis" else None

        # Initialize optional Redis client holder (set in initialize())
        self._redis_client: Optional[Any] = None

        # Leader election settings from config
        if self.redis_url and REDIS_AVAILABLE:
            self._instance_id = str(uuid.uuid4())  # Unique ID for this process
            self._leader_key = settings.redis_leader_key
            self._leader_ttl = settings.redis_leader_ttl
            self._leader_heartbeat_interval = settings.redis_leader_heartbeat_interval
            self._leader_heartbeat_task: Optional[asyncio.Task] = None

        # Always initialize file lock as fallback (used if Redis connection fails at runtime)
        if settings.cache_type != "none":
            temp_dir = tempfile.gettempdir()
            user_path = os.path.normpath(settings.filelock_name)
            if os.path.isabs(user_path):
                # Strip the drive/root so a user-supplied absolute path cannot escape temp_dir.
                user_path = os.path.relpath(user_path, start=os.path.splitdrive(user_path)[0] + os.sep)
            full_path = os.path.join(temp_dir, user_path)
            self._lock_path = full_path.replace("\\", "/")
            self._file_lock = FileLock(self._lock_path)
456 @staticmethod
457 def normalize_url(url: str) -> str:
458 """
459 Normalize a URL by ensuring it's properly formatted.
461 Special handling for localhost to prevent duplicates:
462 - Converts 127.0.0.1 to localhost for consistency
463 - Preserves all other domain names as-is for CDN/load balancer support
465 Args:
466 url (str): The URL to normalize.
468 Returns:
469 str: The normalized URL.
471 Examples:
472 >>> GatewayService.normalize_url('http://localhost:8080/path')
473 'http://localhost:8080/path'
474 >>> GatewayService.normalize_url('http://127.0.0.1:8080/path')
475 'http://localhost:8080/path'
476 >>> GatewayService.normalize_url('https://example.com/api')
477 'https://example.com/api'
478 """
479 parsed = urlparse(url)
480 hostname = parsed.hostname
482 # Special case: normalize 127.0.0.1 to localhost to prevent duplicates
483 # but preserve all other domains as-is for CDN/load balancer support
484 if hostname == "127.0.0.1":
485 netloc = "localhost"
486 if parsed.port:
487 netloc += f":{parsed.port}"
488 normalized = parsed._replace(netloc=netloc)
489 return str(urlunparse(normalized))
491 # For all other URLs, preserve the domain name
492 return url
    def create_ssl_context(self, ca_certificate: str) -> ssl.SSLContext:
        """Create an SSL context with the provided CA certificate.

        Uses caching to avoid repeated SSL context creation for the same certificate.

        Args:
            ca_certificate: CA certificate in PEM format

        Returns:
            ssl.SSLContext: Configured SSL context
        """
        # Thin delegation: the shared cache (ssl_context_cache) keys contexts by
        # certificate content, so repeated calls with the same CA reuse one context.
        return get_cached_ssl_context(ca_certificate)
    async def initialize(self) -> None:
        """Initialize the service and start health check if this instance is the leader.

        In Redis mode, a SET NX on the leader key elects exactly one instance to run
        health checks and the leadership heartbeat. Without Redis, the health-check
        task is always started and leadership is arbitrated inside it (file lock).

        Raises:
            ConnectionError: When redis ping fails
        """
        logger.info("Initializing gateway service")

        # Initialize event service with shared Redis client
        await self._event_service.initialize()

        # NOTE: We intentionally do NOT create a long-lived DB session here.
        # Health checks use fresh_db_session() only when DB access is actually needed,
        # avoiding holding connections during HTTP calls to MCP servers.

        user_email = settings.platform_admin_email

        # Get shared Redis client from factory
        if self.redis_url and REDIS_AVAILABLE:
            self._redis_client = await get_redis_client()

        if self._redis_client:
            # Check if Redis is available (ping already done by factory, but verify)
            try:
                await self._redis_client.ping()
            except Exception as e:
                raise ConnectionError(f"Redis ping failed: {e}") from e

            # SET with NX acts as a distributed lock: only one instance wins leadership.
            is_leader = await self._redis_client.set(self._leader_key, self._instance_id, ex=self._leader_ttl, nx=True)
            if is_leader:
                logger.info("Acquired Redis leadership. Starting health check and heartbeat tasks.")
                self._health_check_task = asyncio.create_task(self._run_health_checks(user_email))
                self._leader_heartbeat_task = asyncio.create_task(self._run_leader_heartbeat())
            # NOTE(review): non-leader Redis instances start no health-check task here;
            # presumably they rely on the leader's TTL expiring for failover — confirm.
        else:
            # Always create the health check task in filelock mode; leader check is handled inside.
            self._health_check_task = asyncio.create_task(self._run_health_checks(user_email))
    async def shutdown(self) -> None:
        """Shutdown the service.

        Cancels background tasks, atomically releases Redis leadership if this
        instance holds it, closes the HTTP client and event service, and clears
        the active-gateway registry.

        Examples:
            >>> service = GatewayService()
            >>> # Mock internal components
            >>> from unittest.mock import AsyncMock
            >>> service._event_service = AsyncMock()
            >>> service._active_gateways = {'test_gw'}
            >>> import asyncio
            >>> asyncio.run(service.shutdown())
            >>> # Verify event service shutdown was called
            >>> service._event_service.shutdown.assert_awaited_once()
            >>> len(service._active_gateways)
            0
        """
        if self._health_check_task:
            self._health_check_task.cancel()
            try:
                await self._health_check_task
            except asyncio.CancelledError:
                # Expected: the task was cancelled on purpose.
                pass

        # Cancel leader heartbeat task if running
        if getattr(self, "_leader_heartbeat_task", None):
            self._leader_heartbeat_task.cancel()
            try:
                await self._leader_heartbeat_task
            except asyncio.CancelledError:
                pass

        # Release Redis leadership atomically if we hold it
        if self._redis_client:
            try:
                # Lua script for atomic check-and-delete (only delete if we own the key)
                release_script = """
                if redis.call("get", KEYS[1]) == ARGV[1] then
                    return redis.call("del", KEYS[1])
                else
                    return 0
                end
                """
                result = await self._redis_client.eval(release_script, 1, self._leader_key, self._instance_id)
                if result:
                    logger.info("Released Redis leadership on shutdown")
            except Exception as e:
                # Best-effort: leadership will lapse via TTL if release fails.
                logger.warning(f"Failed to release Redis leader key on shutdown: {e}")

        await self._http_client.aclose()
        await self._event_service.shutdown()
        self._active_gateways.clear()
        logger.info("Gateway service shutdown complete")
597 def _check_gateway_uniqueness(
598 self,
599 db: Session,
600 url: str,
601 auth_value: Optional[Dict[str, str]],
602 oauth_config: Optional[Dict[str, Any]],
603 team_id: Optional[str],
604 owner_email: str,
605 visibility: str,
606 gateway_id: Optional[str] = None,
607 ) -> Optional[DbGateway]:
608 """
609 Check if a gateway with the same URL and credentials already exists.
611 Args:
612 db: Database session
613 url: Gateway URL (normalized)
614 auth_value: Decoded auth_value dict (not encrypted)
615 oauth_config: OAuth configuration dict
616 team_id: Team ID for team-scoped gateways
617 owner_email: Email of the gateway owner
618 visibility: Gateway visibility (public/team/private)
619 gateway_id: Optional gateway ID to exclude from check (for updates)
621 Returns:
622 DbGateway if duplicate found, None otherwise
623 """
624 # Build base query based on visibility
625 if visibility == "public":
626 query = db.query(DbGateway).filter(DbGateway.url == url, DbGateway.visibility == "public")
627 elif visibility == "team" and team_id:
628 query = db.query(DbGateway).filter(DbGateway.url == url, DbGateway.visibility == "team", DbGateway.team_id == team_id)
629 elif visibility == "private":
630 # Check for duplicates within the same user's private gateways
631 query = db.query(DbGateway).filter(DbGateway.url == url, DbGateway.visibility == "private", DbGateway.owner_email == owner_email) # Scoped to same user
632 else:
633 return None
635 # Exclude current gateway if updating
636 if gateway_id:
637 query = query.filter(DbGateway.id != gateway_id)
639 existing_gateways = query.all()
641 # Check each existing gateway
642 for existing in existing_gateways:
643 # Case 1: Both have OAuth config
644 if oauth_config and existing.oauth_config:
645 # Compare OAuth configs (exclude dynamic fields like tokens)
646 existing_oauth = existing.oauth_config or {}
647 new_oauth = oauth_config or {}
649 # Compare key OAuth fields
650 oauth_keys = ["grant_type", "client_id", "authorization_url", "token_url", "scope"]
651 if all(existing_oauth.get(k) == new_oauth.get(k) for k in oauth_keys):
652 return existing # Duplicate OAuth config found
654 # Case 2: Both have auth_value (need to decrypt and compare)
655 elif auth_value and existing.auth_value:
657 try:
658 # Decrypt existing auth_value
659 if isinstance(existing.auth_value, str):
660 existing_decoded = decode_auth(existing.auth_value)
662 elif isinstance(existing.auth_value, dict):
663 existing_decoded = existing.auth_value
665 else:
666 continue
668 # Compare decoded auth values
669 if auth_value == existing_decoded:
670 return existing # Duplicate credentials found
671 except Exception as e:
672 logger.warning(f"Failed to decode auth_value for comparison: {e}")
673 continue
675 # Case 3: Both have no auth (URL only, not allowed)
676 elif not auth_value and not oauth_config and not existing.auth_value and not existing.oauth_config:
677 return existing # Duplicate URL without credentials
679 return None # No duplicate found
681 async def register_gateway(
682 self,
683 db: Session,
684 gateway: GatewayCreate,
685 created_by: Optional[str] = None,
686 created_from_ip: Optional[str] = None,
687 created_via: Optional[str] = None,
688 created_user_agent: Optional[str] = None,
689 team_id: Optional[str] = None,
690 owner_email: Optional[str] = None,
691 visibility: Optional[str] = None,
692 initialize_timeout: Optional[float] = None,
693 ) -> GatewayRead:
694 """Register a new gateway.
696 Args:
697 db: Database session
698 gateway: Gateway creation schema
699 created_by: Username who created this gateway
700 created_from_ip: IP address of creator
701 created_via: Creation method (ui, api, federation)
702 created_user_agent: User agent of creation request
703 team_id (Optional[str]): Team ID to assign the gateway to.
704 owner_email (Optional[str]): Email of the user who owns this gateway.
705 visibility (Optional[str]): Gateway visibility level (private, team, public).
706 initialize_timeout (Optional[float]): Timeout in seconds for gateway initialization.
708 Returns:
709 Created gateway information
711 Raises:
712 GatewayNameConflictError: If gateway name already exists
713 GatewayConnectionError: If there was an error connecting to the gateway
714 ValueError: If required values are missing
715 RuntimeError: If there is an error during processing that is not covered by other exceptions
716 IntegrityError: If there is a database integrity error
717 BaseException: If an unexpected error occurs
719 Examples:
720 >>> from mcpgateway.services.gateway_service import GatewayService
721 >>> from unittest.mock import MagicMock
722 >>> service = GatewayService()
723 >>> db = MagicMock()
724 >>> gateway = MagicMock()
725 >>> db.execute.return_value.scalar_one_or_none.return_value = None
726 >>> db.add = MagicMock()
727 >>> db.commit = MagicMock()
728 >>> db.refresh = MagicMock()
729 >>> service._notify_gateway_added = MagicMock()
730 >>> import asyncio
731 >>> try:
732 ... asyncio.run(service.register_gateway(db, gateway))
733 ... except Exception:
734 ... pass
735 >>>
736 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
737 >>> asyncio.run(service._http_client.aclose())
738 """
739 visibility = "public" if visibility not in ("private", "team", "public") else visibility
740 try:
741 # # Check for name conflicts (both active and inactive)
742 # existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway.name)).scalar_one_or_none()
744 # if existing_gateway:
745 # raise GatewayNameConflictError(
746 # gateway.name,
747 # enabled=existing_gateway.enabled,
748 # gateway_id=existing_gateway.id,
749 # )
750 # Check for existing gateway with the same slug and visibility
751 slug_name = slugify(gateway.name)
752 if visibility.lower() == "public":
753 # Check for existing public gateway with the same slug (row-locked)
754 existing_gateway = get_for_update(
755 db,
756 DbGateway,
757 where=and_(DbGateway.slug == slug_name, DbGateway.visibility == "public"),
758 )
759 if existing_gateway:
760 raise GatewayNameConflictError(existing_gateway.slug, enabled=existing_gateway.enabled, gateway_id=existing_gateway.id, visibility=existing_gateway.visibility)
761 elif visibility.lower() == "team" and team_id:
762 # Check for existing team gateway with the same slug (row-locked)
763 existing_gateway = get_for_update(
764 db,
765 DbGateway,
766 where=and_(DbGateway.slug == slug_name, DbGateway.visibility == "team", DbGateway.team_id == team_id),
767 )
768 if existing_gateway:
769 raise GatewayNameConflictError(existing_gateway.slug, enabled=existing_gateway.enabled, gateway_id=existing_gateway.id, visibility=existing_gateway.visibility)
771 # Normalize the gateway URL
772 normalized_url = self.normalize_url(str(gateway.url))
774 decoded_auth_value = None
775 if gateway.auth_value:
776 if isinstance(gateway.auth_value, str):
777 try:
778 decoded_auth_value = decode_auth(gateway.auth_value)
779 except Exception as e:
780 logger.warning(f"Failed to decode provided auth_value: {e}")
781 decoded_auth_value = None
782 elif isinstance(gateway.auth_value, dict):
783 decoded_auth_value = gateway.auth_value
785 # Check for duplicate gateway
786 if not gateway.one_time_auth:
787 duplicate_gateway = self._check_gateway_uniqueness(
788 db=db, url=normalized_url, auth_value=decoded_auth_value, oauth_config=gateway.oauth_config, team_id=team_id, owner_email=owner_email, visibility=visibility
789 )
791 if duplicate_gateway:
792 raise GatewayDuplicateConflictError(duplicate_gateway=duplicate_gateway)
794 # Prevent URL-only gateways (no auth at all)
795 # if not decoded_auth_value and not gateway.oauth_config:
796 # raise ValueError(
797 # f"Gateway with URL '{normalized_url}' must have either auth_value or oauth_config. "
798 # "URL-only gateways are not allowed."
799 # )
801 auth_type = getattr(gateway, "auth_type", None)
802 # Support multiple custom headers
803 auth_value = getattr(gateway, "auth_value", {})
804 authentication_headers: Optional[Dict[str, str]] = None
806 # Handle query_param auth - encrypt and prepare for storage
807 auth_query_params_encrypted: Optional[Dict[str, str]] = None
808 auth_query_params_decrypted: Optional[Dict[str, str]] = None
809 init_url = normalized_url # URL to use for initialization
811 if auth_type == "query_param":
812 # Extract and encrypt query param auth
813 param_key = getattr(gateway, "auth_query_param_key", None)
814 param_value = getattr(gateway, "auth_query_param_value", None)
815 if param_key and param_value:
816 # Get the actual secret value
817 if hasattr(param_value, "get_secret_value"):
818 raw_value = param_value.get_secret_value()
819 else:
820 raw_value = str(param_value)
821 # Encrypt for storage
822 encrypted_value = encode_auth({param_key: raw_value})
823 auth_query_params_encrypted = {param_key: encrypted_value}
824 auth_query_params_decrypted = {param_key: raw_value}
825 # Append query params to URL for initialization
826 init_url = apply_query_param_auth(normalized_url, auth_query_params_decrypted)
827 # Query param auth doesn't use auth_value
828 auth_value = None
829 authentication_headers = None
831 elif hasattr(gateway, "auth_headers") and gateway.auth_headers:
832 # Convert list of {key, value} to dict
833 header_dict = {h["key"]: h["value"] for h in gateway.auth_headers if h.get("key")}
834 auth_value = header_dict # store plain dict, consistent with update path and DB column type
835 authentication_headers = {str(k): str(v) for k, v in header_dict.items()}
837 elif isinstance(auth_value, str) and auth_value:
838 # Decode persisted auth for initialization
839 decoded = decode_auth(auth_value)
840 authentication_headers = {str(k): str(v) for k, v in decoded.items()}
841 else:
842 authentication_headers = None
844 oauth_config = await protect_oauth_config_for_storage(getattr(gateway, "oauth_config", None))
845 ca_certificate = getattr(gateway, "ca_certificate", None)
847 # Check if gateway is in direct_proxy mode
848 gateway_mode = getattr(gateway, "gateway_mode", "cache")
850 if gateway_mode == "direct_proxy" and not settings.mcpgateway_direct_proxy_enabled:
851 raise GatewayError("direct_proxy gateway mode is disabled. Set MCPGATEWAY_DIRECT_PROXY_ENABLED=true to enable.")
853 if initialize_timeout is not None:
854 try:
855 capabilities, tools, resources, prompts = await asyncio.wait_for(
856 self._initialize_gateway(
857 init_url, # URL with query params if applicable
858 authentication_headers,
859 gateway.transport,
860 auth_type,
861 oauth_config,
862 ca_certificate,
863 auth_query_params=auth_query_params_decrypted,
864 ),
865 timeout=initialize_timeout,
866 )
867 except asyncio.TimeoutError as exc:
868 sanitized = sanitize_url_for_logging(init_url, auth_query_params_decrypted)
869 raise GatewayConnectionError(f"Gateway initialization timed out after {initialize_timeout}s for {sanitized}") from exc
870 else:
871 capabilities, tools, resources, prompts = await self._initialize_gateway(
872 init_url, # URL with query params if applicable
873 authentication_headers,
874 gateway.transport,
875 auth_type,
876 oauth_config,
877 ca_certificate,
878 auth_query_params=auth_query_params_decrypted,
879 )
881 if gateway.one_time_auth:
882 # For one-time auth, clear auth_type and auth_value after initialization
883 auth_type = "one_time_auth"
884 auth_value = None
885 oauth_config = None
887 # DbTool.auth_value is Mapped[Optional[str]] (Text), so encode the dict before
888 # storing it there. DbGateway.auth_value is Mapped[Optional[Dict]] (JSON) and
889 # receives the plain dict directly (see assignment above).
890 tool_auth_value = encode_auth(auth_value) if isinstance(auth_value, dict) else auth_value
892 tools = [
893 DbTool(
894 original_name=tool.name,
895 custom_name=tool.name,
896 custom_name_slug=slugify(tool.name),
897 display_name=generate_display_name(tool.name),
898 url=normalized_url,
899 original_description=tool.description,
900 description=tool.description,
901 integration_type="MCP", # Gateway-discovered tools are MCP type
902 request_type=tool.request_type,
903 headers=tool.headers,
904 input_schema=tool.input_schema,
905 output_schema=tool.output_schema,
906 annotations=tool.annotations,
907 jsonpath_filter=tool.jsonpath_filter,
908 auth_type=auth_type,
909 auth_value=tool_auth_value,
910 # Federation metadata
911 created_by=created_by or "system",
912 created_from_ip=created_from_ip,
913 created_via="federation", # These are federated tools
914 created_user_agent=created_user_agent,
915 federation_source=gateway.name,
916 version=1,
917 # Inherit team assignment from gateway
918 team_id=team_id,
919 owner_email=owner_email,
920 visibility=visibility,
921 )
922 for tool in tools
923 ]
925 # Create resource DB models with upsert logic for ORPHANED resources only
926 # Query for existing ORPHANED resources (gateway_id IS NULL or points to non-existent gateway)
927 # with same (team_id, owner_email, uri) to handle resources left behind from incomplete
928 # gateway deletions (e.g., issue #2341 crash scenarios).
929 # We only update orphaned resources - resources belonging to active gateways are not touched.
930 resource_uris = [r.uri for r in resources]
931 effective_owner = owner_email or created_by
933 # Build lookup map: (team_id, owner_email, uri) -> orphaned DbResource
934 # We query all resources matching our URIs, then filter to orphaned ones in Python
935 # to handle per-resource team/owner overrides correctly
936 orphaned_resources_map: Dict[tuple, DbResource] = {}
937 if resource_uris:
938 try:
939 # Get valid gateway IDs to identify orphaned resources
940 valid_gateway_ids = set(gw_id for (gw_id,) in db.execute(select(DbGateway.id)).all())
941 candidate_resources = db.execute(select(DbResource).where(DbResource.uri.in_(resource_uris))).scalars().all()
942 for res in candidate_resources:
943 # Only consider orphaned resources (no gateway or gateway doesn't exist)
944 is_orphaned = res.gateway_id is None or res.gateway_id not in valid_gateway_ids
945 if is_orphaned:
946 key = (res.team_id, res.owner_email, res.uri)
947 orphaned_resources_map[key] = res
948 if orphaned_resources_map:
949 logger.info(f"Found {len(orphaned_resources_map)} orphaned resources to reassign for gateway {gateway.name}")
950 except Exception as e:
951 # If orphan detection fails (e.g., in mocked tests), skip upsert and create new resources
952 # This is conservative - we won't accidentally reassign resources from active gateways
953 logger.debug(f"Orphan resource detection skipped: {e}")
955 db_resources = []
956 for r in resources:
957 mime_type = mimetypes.guess_type(r.uri)[0] or ("text/plain" if isinstance(r.content, str) else "application/octet-stream")
958 r_team_id = getattr(r, "team_id", None) or team_id
959 r_owner_email = getattr(r, "owner_email", None) or effective_owner
960 r_visibility = getattr(r, "visibility", None) or visibility
962 # Check if there's an orphaned resource with matching unique key
963 lookup_key = (r_team_id, r_owner_email, r.uri)
964 if lookup_key in orphaned_resources_map:
965 # Update orphaned resource - reassign to new gateway
966 existing = orphaned_resources_map[lookup_key]
967 existing.name = r.name
968 existing.description = r.description
969 existing.mime_type = mime_type
970 existing.uri_template = r.uri_template or None
971 existing.text_content = r.content if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) else None
972 existing.binary_content = (
973 r.content.encode() if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) else r.content if isinstance(r.content, bytes) else None
974 )
975 existing.size = len(r.content) if r.content else 0
976 existing.tags = getattr(r, "tags", []) or []
977 existing.federation_source = gateway.name
978 existing.modified_by = created_by
979 existing.modified_from_ip = created_from_ip
980 existing.modified_via = "federation"
981 existing.modified_user_agent = created_user_agent
982 existing.updated_at = datetime.now(timezone.utc)
983 existing.visibility = r_visibility
984 # Note: gateway_id will be set when gateway is created (relationship)
985 db_resources.append(existing)
986 else:
987 # Create new resource
988 db_resources.append(
989 DbResource(
990 uri=r.uri,
991 name=r.name,
992 description=r.description,
993 mime_type=mime_type,
994 uri_template=r.uri_template or None,
995 text_content=r.content if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) else None,
996 binary_content=(
997 r.content.encode()
998 if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str)
999 else r.content if isinstance(r.content, bytes) else None
1000 ),
1001 size=len(r.content) if r.content else 0,
1002 tags=getattr(r, "tags", []) or [],
1003 created_by=created_by or "system",
1004 created_from_ip=created_from_ip,
1005 created_via="federation",
1006 created_user_agent=created_user_agent,
1007 import_batch_id=None,
1008 federation_source=gateway.name,
1009 version=1,
1010 team_id=r_team_id,
1011 owner_email=r_owner_email,
1012 visibility=r_visibility,
1013 )
1014 )
1016 # Create prompt DB models with upsert logic for ORPHANED prompts only
1017 # Query for existing ORPHANED prompts (gateway_id IS NULL or points to non-existent gateway)
1018 # with same (team_id, owner_email, name) to handle prompts left behind from incomplete
1019 # gateway deletions. We only update orphaned prompts - prompts belonging to active gateways are not touched.
1020 prompt_names = [p.name for p in prompts]
1022 # Build lookup map: (team_id, owner_email, name) -> orphaned DbPrompt
1023 orphaned_prompts_map: Dict[tuple, DbPrompt] = {}
1024 if prompt_names:
1025 try:
1026 # Get valid gateway IDs to identify orphaned prompts
1027 valid_gateway_ids_for_prompts = set(gw_id for (gw_id,) in db.execute(select(DbGateway.id)).all())
1028 candidate_prompts = db.execute(select(DbPrompt).where(DbPrompt.name.in_(prompt_names))).scalars().all()
1029 for pmt in candidate_prompts:
1030 # Only consider orphaned prompts (no gateway or gateway doesn't exist)
1031 is_orphaned = pmt.gateway_id is None or pmt.gateway_id not in valid_gateway_ids_for_prompts
1032 if is_orphaned:
1033 key = (pmt.team_id, pmt.owner_email, pmt.name)
1034 orphaned_prompts_map[key] = pmt
1035 if orphaned_prompts_map:
1036 logger.info(f"Found {len(orphaned_prompts_map)} orphaned prompts to reassign for gateway {gateway.name}")
1037 except Exception as e:
1038 # If orphan detection fails (e.g., in mocked tests), skip upsert and create new prompts
1039 logger.debug(f"Orphan prompt detection skipped: {e}")
1041 db_prompts = []
1042 for prompt in prompts:
1043 # Prompts inherit team/owner from gateway (no per-prompt overrides)
1044 p_team_id = team_id
1045 p_owner_email = owner_email or effective_owner
1047 # Check if there's an orphaned prompt with matching unique key
1048 lookup_key = (p_team_id, p_owner_email, prompt.name)
1049 if lookup_key in orphaned_prompts_map:
1050 # Update orphaned prompt - reassign to new gateway
1051 existing = orphaned_prompts_map[lookup_key]
1052 existing.original_name = prompt.name
1053 existing.custom_name = prompt.name
1054 existing.display_name = prompt.name
1055 existing.description = prompt.description
1056 existing.template = prompt.template if hasattr(prompt, "template") else ""
1057 existing.federation_source = gateway.name
1058 existing.modified_by = created_by
1059 existing.modified_from_ip = created_from_ip
1060 existing.modified_via = "federation"
1061 existing.modified_user_agent = created_user_agent
1062 existing.updated_at = datetime.now(timezone.utc)
1063 existing.visibility = visibility
1064 # Note: gateway_id will be set when gateway is created (relationship)
1065 db_prompts.append(existing)
1066 else:
1067 # Create new prompt
1068 db_prompts.append(
1069 DbPrompt(
1070 name=prompt.name,
1071 original_name=prompt.name,
1072 custom_name=prompt.name,
1073 display_name=prompt.name,
1074 description=prompt.description,
1075 template=prompt.template if hasattr(prompt, "template") else "",
1076 argument_schema={}, # Use argument_schema instead of arguments
1077 # Federation metadata
1078 created_by=created_by or "system",
1079 created_from_ip=created_from_ip,
1080 created_via="federation", # These are federated prompts
1081 created_user_agent=created_user_agent,
1082 federation_source=gateway.name,
1083 version=1,
1084 # Inherit team assignment from gateway
1085 team_id=team_id,
1086 owner_email=owner_email,
1087 visibility=visibility,
1088 )
1089 )
1091 # Create DB model
1092 db_gateway = DbGateway(
1093 name=gateway.name,
1094 slug=slug_name,
1095 url=normalized_url,
1096 description=gateway.description,
1097 tags=gateway.tags or [],
1098 transport=gateway.transport,
1099 capabilities=capabilities,
1100 last_seen=datetime.now(timezone.utc),
1101 auth_type=auth_type,
1102 auth_value=auth_value,
1103 auth_query_params=auth_query_params_encrypted, # Encrypted query param auth
1104 oauth_config=oauth_config,
1105 passthrough_headers=gateway.passthrough_headers,
1106 tools=tools,
1107 resources=db_resources,
1108 prompts=db_prompts,
1109 # Gateway metadata
1110 created_by=created_by,
1111 created_from_ip=created_from_ip,
1112 created_via=created_via or "api",
1113 created_user_agent=created_user_agent,
1114 version=1,
1115 # Team scoping fields
1116 team_id=team_id,
1117 owner_email=owner_email,
1118 visibility=visibility,
1119 ca_certificate=gateway.ca_certificate,
1120 ca_certificate_sig=gateway.ca_certificate_sig,
1121 signing_algorithm=gateway.signing_algorithm,
1122 # Gateway mode configuration
1123 gateway_mode=gateway_mode,
1124 )
1126 # Add to DB
1127 db.add(db_gateway)
1128 db.flush() # Flush to get the ID without committing
1129 db.refresh(db_gateway)
1131 # Update tracking
1132 self._active_gateways.add(db_gateway.url)
1134 # Notify subscribers
1135 await self._notify_gateway_added(db_gateway)
1137 logger.info(f"Registered gateway: {gateway.name}")
1139 # Structured logging: Audit trail for gateway creation
1140 audit_trail.log_action(
1141 user_id=created_by or "system",
1142 action="create_gateway",
1143 resource_type="gateway",
1144 resource_id=str(db_gateway.id),
1145 resource_name=db_gateway.name,
1146 user_email=owner_email,
1147 team_id=team_id,
1148 client_ip=created_from_ip,
1149 user_agent=created_user_agent,
1150 new_values={
1151 "name": db_gateway.name,
1152 "url": db_gateway.url,
1153 "visibility": visibility,
1154 "transport": db_gateway.transport,
1155 "tools_count": len(tools),
1156 "resources_count": len(db_resources),
1157 "prompts_count": len(db_prompts),
1158 },
1159 context={
1160 "created_via": created_via,
1161 },
1162 db=db,
1163 )
1165 # Structured logging: Log successful gateway creation
1166 structured_logger.log(
1167 level="INFO",
1168 message="Gateway created successfully",
1169 event_type="gateway_created",
1170 component="gateway_service",
1171 user_id=created_by,
1172 user_email=owner_email,
1173 team_id=team_id,
1174 resource_type="gateway",
1175 resource_id=str(db_gateway.id),
1176 custom_fields={
1177 "gateway_name": db_gateway.name,
1178 "gateway_url": normalized_url,
1179 "visibility": visibility,
1180 "transport": db_gateway.transport,
1181 },
1182 )
1184 return GatewayRead.model_validate(self._prepare_gateway_for_read(db_gateway)).masked()
1185 except* GatewayConnectionError as ge: # pragma: no mutate
1186 if TYPE_CHECKING:
1187 ge: ExceptionGroup[GatewayConnectionError]
1188 logger.error(f"GatewayConnectionError in group: {ge.exceptions}")
1189 db.rollback()
1191 structured_logger.log(
1192 level="ERROR",
1193 message="Gateway creation failed due to connection error",
1194 event_type="gateway_creation_failed",
1195 component="gateway_service",
1196 user_id=created_by,
1197 user_email=owner_email,
1198 error=ge.exceptions[0],
1199 custom_fields={"gateway_name": gateway.name, "gateway_url": str(gateway.url)},
1200 )
1201 raise ge.exceptions[0]
1202 except* GatewayNameConflictError as gnce: # pragma: no mutate
1203 if TYPE_CHECKING:
1204 gnce: ExceptionGroup[GatewayNameConflictError]
1205 logger.error(f"GatewayNameConflictError in group: {gnce.exceptions}")
1206 db.rollback()
1208 structured_logger.log(
1209 level="WARNING",
1210 message="Gateway creation failed due to name conflict",
1211 event_type="gateway_name_conflict",
1212 component="gateway_service",
1213 user_id=created_by,
1214 user_email=owner_email,
1215 custom_fields={"gateway_name": gateway.name, "visibility": visibility},
1216 )
1217 raise gnce.exceptions[0]
1218 except* GatewayDuplicateConflictError as guce: # pragma: no mutate
1219 if TYPE_CHECKING:
1220 guce: ExceptionGroup[GatewayDuplicateConflictError]
1221 logger.error(f"GatewayDuplicateConflictError in group: {guce.exceptions}")
1222 db.rollback()
1224 structured_logger.log(
1225 level="WARNING",
1226 message="Gateway creation failed due to duplicate",
1227 event_type="gateway_duplicate_conflict",
1228 component="gateway_service",
1229 user_id=created_by,
1230 user_email=owner_email,
1231 custom_fields={"gateway_name": gateway.name},
1232 )
1233 raise guce.exceptions[0]
1234 except* ValueError as ve: # pragma: no mutate
1235 if TYPE_CHECKING:
1236 ve: ExceptionGroup[ValueError]
1237 logger.error(f"ValueErrors in group: {ve.exceptions}")
1238 db.rollback()
1240 structured_logger.log(
1241 level="ERROR",
1242 message="Gateway creation failed due to validation error",
1243 event_type="gateway_creation_failed",
1244 component="gateway_service",
1245 user_id=created_by,
1246 user_email=owner_email,
1247 error=ve.exceptions[0],
1248 custom_fields={"gateway_name": gateway.name},
1249 )
1250 raise ve.exceptions[0]
1251 except* RuntimeError as re: # pragma: no mutate
1252 if TYPE_CHECKING:
1253 re: ExceptionGroup[RuntimeError]
1254 logger.error(f"RuntimeErrors in group: {re.exceptions}")
1255 db.rollback()
1257 structured_logger.log(
1258 level="ERROR",
1259 message="Gateway creation failed due to runtime error",
1260 event_type="gateway_creation_failed",
1261 component="gateway_service",
1262 user_id=created_by,
1263 user_email=owner_email,
1264 error=re.exceptions[0],
1265 custom_fields={"gateway_name": gateway.name},
1266 )
1267 raise re.exceptions[0]
1268 except* IntegrityError as ie: # pragma: no mutate
1269 if TYPE_CHECKING:
1270 ie: ExceptionGroup[IntegrityError]
1271 logger.error(f"IntegrityErrors in group: {ie.exceptions}")
1272 db.rollback()
1274 structured_logger.log(
1275 level="ERROR",
1276 message="Gateway creation failed due to database integrity error",
1277 event_type="gateway_creation_failed",
1278 component="gateway_service",
1279 user_id=created_by,
1280 user_email=owner_email,
1281 error=ie.exceptions[0],
1282 custom_fields={"gateway_name": gateway.name},
1283 )
1284 raise ie.exceptions[0]
1285 except* BaseException as other: # catches every other sub-exception # pragma: no mutate
1286 if TYPE_CHECKING:
1287 other: ExceptionGroup[Exception]
1288 logger.error(f"Other grouped errors: {other.exceptions}")
1289 db.rollback()
1290 raise other.exceptions[0]
1292 async def fetch_tools_after_oauth(self, db: Session, gateway_id: str, app_user_email: str) -> Dict[str, Any]:
1293 """Fetch tools from MCP server after OAuth completion for Authorization Code flow.
1295 Args:
1296 db: Database session
1297 gateway_id: ID of the gateway to fetch tools for
1298 app_user_email: ContextForge user email for token retrieval
1300 Returns:
1301 Dict containing capabilities, tools, resources, and prompts
1303 Raises:
1304 GatewayConnectionError: If connection or OAuth fails
1305 """
1306 try:
1307 # Get the gateway with eager loading for sync operations to avoid N+1 queries
1308 gateway = db.execute(
1309 select(DbGateway)
1310 .options(
1311 selectinload(DbGateway.tools),
1312 selectinload(DbGateway.resources),
1313 selectinload(DbGateway.prompts),
1314 joinedload(DbGateway.email_team),
1315 )
1316 .where(DbGateway.id == gateway_id)
1317 ).scalar_one_or_none()
1319 if not gateway:
1320 raise ValueError(f"Gateway {gateway_id} not found")
1322 if not gateway.oauth_config:
1323 raise ValueError(f"Gateway {gateway_id} has no OAuth configuration")
1325 grant_type = gateway.oauth_config.get("grant_type")
1326 if grant_type != "authorization_code":
1327 raise ValueError(f"Gateway {gateway_id} is not using Authorization Code flow")
1329 # Get OAuth tokens for this gateway
1330 # First-Party
1331 from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel
1333 token_storage = TokenStorageService(db)
1335 # Get user-specific OAuth token
1336 if not app_user_email:
1337 raise GatewayConnectionError(f"User authentication required for OAuth gateway {gateway.name}")
1339 access_token = await token_storage.get_user_token(gateway.id, app_user_email)
1341 if not access_token:
1342 raise GatewayConnectionError(
1343 f"No OAuth tokens found for user {app_user_email} on gateway {gateway.name}. Please complete the OAuth authorization flow first at /oauth/authorize/{gateway.id}"
1344 )
1346 # Debug: Check if token was decrypted
1347 if access_token.startswith("Z0FBQUFBQm"): # Encrypted tokens start with this
1348 logger.error("OAuth token decryption may have failed before gateway initialization")
1349 else:
1350 logger.info("Using decrypted OAuth token for gateway %s", gateway.name)
1352 # Now connect to MCP server with the access token
1353 authentication = {"Authorization": f"Bearer {access_token}"}
1355 # Use the existing connection logic
1356 # Note: For OAuth servers, skip validation since we already validated via OAuth flow
1357 if gateway.transport.upper() == "SSE":
1358 capabilities, tools, resources, prompts = await self._connect_to_sse_server_without_validation(gateway.url, authentication)
1359 elif gateway.transport.upper() == "STREAMABLEHTTP":
1360 capabilities, tools, resources, prompts = await self.connect_to_streamablehttp_server(gateway.url, authentication)
1361 else:
1362 raise ValueError(f"Unsupported transport type: {gateway.transport}")
1364 # Handle tools, resources, and prompts using helper methods
1365 tools_to_add = self._update_or_create_tools(db, tools, gateway, "oauth")
1366 resources_to_add = self._update_or_create_resources(db, resources, gateway, "oauth")
1367 prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, "oauth")
1369 # Clean up items that are no longer available from the gateway
1370 new_tool_names = [tool.name for tool in tools]
1371 new_resource_uris = [resource.uri for resource in resources]
1372 new_prompt_names = [prompt.name for prompt in prompts]
1374 # Count items before cleanup for logging
1376 # Bulk delete tools that are no longer available from the gateway
1377 # Use chunking to avoid SQLite's 999 parameter limit for IN clauses
1378 stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names]
1379 if stale_tool_ids:
1380 # Delete child records first to avoid FK constraint violations
1381 for i in range(0, len(stale_tool_ids), 500):
1382 chunk = stale_tool_ids[i : i + 500]
1383 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
1384 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
1385 db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))
1387 # Bulk delete resources that are no longer available from the gateway
1388 stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris]
1389 if stale_resource_ids:
1390 # Delete child records first to avoid FK constraint violations
1391 for i in range(0, len(stale_resource_ids), 500):
1392 chunk = stale_resource_ids[i : i + 500]
1393 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
1394 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
1395 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
1396 db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))
1398 # Bulk delete prompts that are no longer available from the gateway
1399 stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names]
1400 if stale_prompt_ids:
1401 # Delete child records first to avoid FK constraint violations
1402 for i in range(0, len(stale_prompt_ids), 500):
1403 chunk = stale_prompt_ids[i : i + 500]
1404 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
1405 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
1406 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))
1408 # Expire gateway to clear cached relationships after bulk deletes
1409 # This prevents SQLAlchemy from trying to re-delete already-deleted items
1410 if stale_tool_ids or stale_resource_ids or stale_prompt_ids:
1411 db.expire(gateway)
1413 # Update gateway relationships to reflect deletions
1414 gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names]
1415 gateway.resources = [resource for resource in gateway.resources if resource.uri in new_resource_uris]
1416 gateway.prompts = [prompt for prompt in gateway.prompts if prompt.original_name in new_prompt_names]
1418 # Log cleanup results
1419 tools_removed = len(stale_tool_ids)
1420 resources_removed = len(stale_resource_ids)
1421 prompts_removed = len(stale_prompt_ids)
1423 if tools_removed > 0:
1424 logger.info(f"Removed {tools_removed} tools no longer available from gateway")
1425 if resources_removed > 0:
1426 logger.info(f"Removed {resources_removed} resources no longer available from gateway")
1427 if prompts_removed > 0:
1428 logger.info(f"Removed {prompts_removed} prompts no longer available from gateway")
1430 # Update gateway capabilities and last_seen
1431 gateway.capabilities = capabilities
1432 gateway.last_seen = datetime.now(timezone.utc)
1434 # Register capabilities for notification-driven actions
1435 register_gateway_capabilities_for_notifications(gateway.id, capabilities)
1437 # Add new items to DB in chunks to prevent lock escalation
1438 items_added = 0
1439 chunk_size = 50
1441 if tools_to_add:
1442 for i in range(0, len(tools_to_add), chunk_size):
1443 chunk = tools_to_add[i : i + chunk_size]
1444 db.add_all(chunk)
1445 db.flush() # Flush each chunk to avoid excessive memory usage
1446 items_added += len(tools_to_add)
1447 logger.info(f"Added {len(tools_to_add)} new tools to database")
1449 if resources_to_add:
1450 for i in range(0, len(resources_to_add), chunk_size):
1451 chunk = resources_to_add[i : i + chunk_size]
1452 db.add_all(chunk)
1453 db.flush()
1454 items_added += len(resources_to_add)
1455 logger.info(f"Added {len(resources_to_add)} new resources to database")
1457 if prompts_to_add:
1458 for i in range(0, len(prompts_to_add), chunk_size):
1459 chunk = prompts_to_add[i : i + chunk_size]
1460 db.add_all(chunk)
1461 db.flush()
1462 items_added += len(prompts_to_add)
1463 logger.info(f"Added {len(prompts_to_add)} new prompts to database")
1465 if items_added > 0:
1466 db.commit()
1467 logger.info(f"Total {items_added} new items added to database")
1468 else:
1469 logger.info("No new items to add to database")
1470 # Still commit to save any updates to existing items
1471 db.commit()
1473 cache = _get_registry_cache()
1474 await cache.invalidate_tools()
1475 await cache.invalidate_resources()
1476 await cache.invalidate_prompts()
1477 tool_lookup_cache = _get_tool_lookup_cache()
1478 await tool_lookup_cache.invalidate_gateway(str(gateway.id))
1479 # Also invalidate tags cache since tool/resource tags may have changed
1480 # First-Party
1481 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel
1483 await admin_stats_cache.invalidate_tags()
1485 return {"capabilities": capabilities, "tools": tools, "resources": resources, "prompts": prompts}
1487 except GatewayConnectionError as gce:
1488 db.rollback()
1489 # Surface validation or depth-related failures directly to the user
1490 logger.error(f"GatewayConnectionError during OAuth fetch for {gateway_id}: {gce}")
1491 raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(gce)}")
1492 except Exception as e:
1493 db.rollback()
1494 logger.error(f"Failed to fetch tools after OAuth for gateway {gateway_id}: {e}")
1495 raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(e)}")
1497 async def list_gateways(
1498 self,
1499 db: Session,
1500 include_inactive: bool = False,
1501 tags: Optional[List[str]] = None,
1502 cursor: Optional[str] = None,
1503 limit: Optional[int] = None,
1504 page: Optional[int] = None,
1505 per_page: Optional[int] = None,
1506 user_email: Optional[str] = None,
1507 team_id: Optional[str] = None,
1508 visibility: Optional[str] = None,
1509 token_teams: Optional[List[str]] = None,
1510 ) -> Union[tuple[List[GatewayRead], Optional[str]], Dict[str, Any]]:
1511 """List all registered gateways with cursor pagination and optional team filtering.
1513 Args:
1514 db: Database session
1515 include_inactive: Whether to include inactive gateways
1516 tags (Optional[List[str]]): Filter resources by tags. If provided, only resources with at least one matching tag will be returned.
1517 cursor: Cursor for pagination (encoded last created_at and id).
1518 limit: Maximum number of gateways to return. None for default, 0 for unlimited.
1519 page: Page number for page-based pagination (1-indexed). Mutually exclusive with cursor.
1520 per_page: Items per page for page-based pagination. Defaults to pagination_default_page_size.
1521 user_email: Email of user for team-based access control. None for no access control.
1522 team_id: Optional team ID to filter by specific team (requires user_email).
1523 visibility: Optional visibility filter (private, team, public) (requires user_email).
1524 token_teams: Optional list of team IDs from the token (None=unrestricted, []=public-only).
1526 Returns:
1527 If page is provided: Dict with {"data": [...], "pagination": {...}, "links": {...}}
1528 If cursor is provided or neither: tuple of (list of GatewayRead objects, next_cursor).
1530 Examples:
1531 >>> from mcpgateway.services.gateway_service import GatewayService
1532 >>> from unittest.mock import MagicMock, AsyncMock, patch
1533 >>> from mcpgateway.schemas import GatewayRead
1534 >>> import asyncio
1535 >>> service = GatewayService()
1536 >>> db = MagicMock()
1537 >>> gateway_obj = MagicMock()
1538 >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway_obj]
1539 >>> gateway_read_obj = MagicMock(spec=GatewayRead)
1540 >>> service.convert_gateway_to_read = MagicMock(return_value=gateway_read_obj)
1541 >>> # Mock the cache to bypass caching logic
1542 >>> with patch('mcpgateway.services.gateway_service._get_registry_cache') as mock_cache_factory:
1543 ... mock_cache = MagicMock()
1544 ... mock_cache.get = AsyncMock(return_value=None)
1545 ... mock_cache.set = AsyncMock(return_value=None)
1546 ... mock_cache.hash_filters = MagicMock(return_value="hash")
1547 ... mock_cache_factory.return_value = mock_cache
1548 ... gateways, cursor = asyncio.run(service.list_gateways(db))
1549 ... gateways == [gateway_read_obj] and cursor is None
1550 True
1552 >>> # Test empty result
1553 >>> db.execute.return_value.scalars.return_value.all.return_value = []
1554 >>> with patch('mcpgateway.services.gateway_service._get_registry_cache') as mock_cache_factory:
1555 ... mock_cache = MagicMock()
1556 ... mock_cache.get = AsyncMock(return_value=None)
1557 ... mock_cache.set = AsyncMock(return_value=None)
1558 ... mock_cache.hash_filters = MagicMock(return_value="hash")
1559 ... mock_cache_factory.return_value = mock_cache
1560 ... empty_result, cursor = asyncio.run(service.list_gateways(db))
1561 ... empty_result == [] and cursor is None
1562 True
1563 >>>
1564 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
1565 >>> asyncio.run(service._http_client.aclose())
1566 """
1567 # Check cache for first page only - only for public-only queries (no user/team filtering)
1568 # SECURITY: Only cache public-only results (token_teams=[]), never admin bypass or team-scoped
1569 cache = _get_registry_cache()
1570 is_public_only = token_teams is not None and len(token_teams) == 0
1571 use_cache = cursor is None and user_email is None and page is None and is_public_only
1572 if use_cache:
1573 filters_hash = cache.hash_filters(include_inactive=include_inactive, tags=sorted(tags) if tags else None)
1574 cached = await cache.get("gateways", filters_hash)
1575 if cached is not None:
1576 # Reconstruct GatewayRead objects from cached dicts
1577 # SECURITY: Always apply .masked() to ensure stale cache entries don't leak credentials
1578 cached_gateways = [GatewayRead.model_validate(g).masked() for g in cached["gateways"]]
1579 return (cached_gateways, cached.get("next_cursor"))
1581 # Build base query with ordering
1582 query = select(DbGateway).options(joinedload(DbGateway.email_team)).order_by(desc(DbGateway.created_at), desc(DbGateway.id))
1584 # Apply active/inactive filter
1585 if not include_inactive:
1586 query = query.where(DbGateway.enabled)
1588 query = await self._apply_access_control(query, db, user_email, token_teams, team_id)
1590 if visibility:
1591 query = query.where(DbGateway.visibility == visibility)
1593 # Add tag filtering if tags are provided (supports both List[str] and List[Dict] formats)
1594 if tags:
1595 query = query.where(json_contains_tag_expr(db, DbGateway.tags, tags, match_any=True))
1596 # Use unified pagination helper - handles both page and cursor pagination
1597 pag_result = await unified_paginate(
1598 db=db,
1599 query=query,
1600 page=page,
1601 per_page=per_page,
1602 cursor=cursor,
1603 limit=limit,
1604 base_url="/admin/gateways", # Used for page-based links
1605 query_params={"include_inactive": include_inactive} if include_inactive else {},
1606 )
1608 next_cursor = None
1609 # Extract gateways based on pagination type
1610 if page is not None:
1611 # Page-based: pag_result is a dict
1612 gateways_db = pag_result["data"]
1613 else:
1614 # Cursor-based: pag_result is a tuple
1615 gateways_db, next_cursor = pag_result
1617 db.commit() # Release transaction to avoid idle-in-transaction
1619 # Convert to GatewayRead (common for both pagination types)
1620 result = []
1621 for s in gateways_db:
1622 try:
1623 result.append(self.convert_gateway_to_read(s))
1624 except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
1625 logger.exception(f"Failed to convert gateway {getattr(s, 'id', 'unknown')} ({getattr(s, 'name', 'unknown')}): {e}")
1626 # Continue with remaining gateways instead of failing completely
1628 # Return appropriate format based on pagination type
1629 if page is not None:
1630 # Page-based format
1631 return {
1632 "data": result,
1633 "pagination": pag_result["pagination"],
1634 "links": pag_result["links"],
1635 }
1637 # Cursor-based format
1639 # Cache first page results - only for public-only queries (no user/team filtering)
1640 # SECURITY: Only cache public-only results (token_teams=[]), never admin bypass or team-scoped
1641 if cursor is None and user_email is None and is_public_only:
1642 try:
1643 cache_data = {"gateways": [s.model_dump(mode="json") for s in result], "next_cursor": next_cursor}
1644 await cache.set("gateways", cache_data, filters_hash)
1645 except AttributeError:
1646 pass # Skip caching if result objects don't support model_dump (e.g., in doctests)
1648 return (result, next_cursor)
1650 async def list_gateways_for_user(
1651 self, db: Session, user_email: str, team_id: Optional[str] = None, visibility: Optional[str] = None, include_inactive: bool = False, skip: int = 0, limit: int = 100
1652 ) -> List[GatewayRead]:
1653 """
1654 DEPRECATED: Use list_gateways() with user_email parameter instead.
1656 This method is maintained for backward compatibility but is no longer used.
1657 New code should call list_gateways() with user_email, team_id, and visibility parameters.
1659 List gateways user has access to with team filtering.
1661 Args:
1662 db: Database session
1663 user_email: Email of the user requesting gateways
1664 team_id: Optional team ID to filter by specific team
1665 visibility: Optional visibility filter (private, team, public)
1666 include_inactive: Whether to include inactive gateways
1667 skip: Number of gateways to skip for pagination
1668 limit: Maximum number of gateways to return
1670 Returns:
1671 List[GatewayRead]: Gateways the user has access to
1672 """
1673 # Build query following existing patterns from list_gateways()
1674 team_service = TeamManagementService(db)
1675 user_teams = await team_service.get_user_teams(user_email)
1676 team_ids = [team.id for team in user_teams]
1678 # Use joinedload to eager load email_team relationship (avoids N+1 queries)
1679 query = select(DbGateway).options(joinedload(DbGateway.email_team))
1681 # Apply active/inactive filter
1682 if not include_inactive:
1683 query = query.where(DbGateway.enabled.is_(True))
1685 if team_id:
1686 if team_id not in team_ids:
1687 return [] # No access to team
1689 access_conditions = []
1690 # Filter by specific team
1692 # Team-owned gateways (team-scoped gateways)
1693 access_conditions.append(and_(DbGateway.team_id == team_id, DbGateway.visibility.in_(["team", "public"])))
1695 access_conditions.append(and_(DbGateway.team_id == team_id, DbGateway.owner_email == user_email))
1697 # Also include global public gateways (no team_id) so public gateways are visible regardless of selected team
1698 access_conditions.append(DbGateway.visibility == "public")
1700 query = query.where(or_(*access_conditions))
1701 else:
1702 # Get user's accessible teams
1703 # Build access conditions following existing patterns
1704 access_conditions = []
1705 # 1. User's personal resources (owner_email matches)
1706 access_conditions.append(DbGateway.owner_email == user_email)
1707 # 2. Team resources where user is member
1708 if team_ids:
1709 access_conditions.append(and_(DbGateway.team_id.in_(team_ids), DbGateway.visibility.in_(["team", "public"])))
1710 # 3. Public resources (if visibility allows)
1711 access_conditions.append(DbGateway.visibility == "public")
1713 query = query.where(or_(*access_conditions))
1715 # Apply visibility filter if specified
1716 if visibility:
1717 query = query.where(DbGateway.visibility == visibility)
1719 # Apply pagination following existing patterns
1720 query = query.offset(skip).limit(limit)
1722 gateways = db.execute(query).scalars().all()
1724 db.commit() # Release transaction to avoid idle-in-transaction
1726 # Team names are loaded via joinedload(DbGateway.email_team)
1727 result = []
1728 for g in gateways:
1729 logger.info(f"Gateway: {g.team_id}, Team: {g.team}")
1730 result.append(GatewayRead.model_validate(self._prepare_gateway_for_read(g)).masked())
1731 return result
1733 async def update_gateway(
1734 self,
1735 db: Session,
1736 gateway_id: str,
1737 gateway_update: GatewayUpdate,
1738 modified_by: Optional[str] = None,
1739 modified_from_ip: Optional[str] = None,
1740 modified_via: Optional[str] = None,
1741 modified_user_agent: Optional[str] = None,
1742 include_inactive: bool = True,
1743 user_email: Optional[str] = None,
1744 ) -> Optional[GatewayRead]:
1745 """Update a gateway.
1747 Args:
1748 db: Database session
1749 gateway_id: Gateway ID to update
1750 gateway_update: Updated gateway data
1751 modified_by: Username of the person modifying the gateway
1752 modified_from_ip: IP address where the modification request originated
1753 modified_via: Source of modification (ui/api/import)
1754 modified_user_agent: User agent string from the modification request
1755 include_inactive: Whether to include inactive gateways
1756 user_email: Email of user performing update (for ownership check)
1758 Returns:
1759 Updated gateway information
1761 Raises:
1762 GatewayNotFoundError: If gateway not found
1763 PermissionError: If user doesn't own the gateway
1764 GatewayError: For other update errors
1765 GatewayNameConflictError: If gateway name conflict occurs
1766 IntegrityError: If there is a database integrity error
1767 ValidationError: If validation fails
1768 """
1769 try: # pylint: disable=too-many-nested-blocks
1770 # Acquire row lock and eager-load relationships while locked so
1771 # concurrent updates are serialized on Postgres.
1772 gateway = get_for_update(
1773 db,
1774 DbGateway,
1775 gateway_id,
1776 options=[
1777 selectinload(DbGateway.tools),
1778 selectinload(DbGateway.resources),
1779 selectinload(DbGateway.prompts),
1780 selectinload(DbGateway.email_team), # Use selectinload to avoid locking email_teams
1781 ],
1782 )
1783 if not gateway:
1784 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
1786 # Check ownership if user_email provided
1787 if user_email:
1788 # First-Party
1789 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel
1791 permission_service = PermissionService(db)
1792 if not await permission_service.check_resource_ownership(user_email, gateway):
1793 raise PermissionError("Only the owner can update this gateway")
1795 if gateway.enabled or include_inactive:
1796 # Check for name conflicts if name is being changed
1797 if gateway_update.name is not None and gateway_update.name != gateway.name:
1798 # existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway_update.name).where(DbGateway.id != gateway_id)).scalar_one_or_none()
1800 # if existing_gateway:
1801 # raise GatewayNameConflictError(
1802 # gateway_update.name,
1803 # enabled=existing_gateway.enabled,
1804 # gateway_id=existing_gateway.id,
1805 # )
1806 # Check for existing gateway with the same slug and visibility
1807 new_slug = slugify(gateway_update.name)
1808 if gateway_update.visibility is not None:
1809 vis = gateway_update.visibility
1810 else:
1811 vis = gateway.visibility
1812 if vis == "public":
1813 # Check for existing public gateway with the same slug (row-locked)
1814 existing_gateway = get_for_update(
1815 db,
1816 DbGateway,
1817 where=and_(DbGateway.slug == new_slug, DbGateway.visibility == "public", DbGateway.id != gateway_id),
1818 )
1819 if existing_gateway:
1820 raise GatewayNameConflictError(
1821 new_slug,
1822 enabled=existing_gateway.enabled,
1823 gateway_id=existing_gateway.id,
1824 visibility=existing_gateway.visibility,
1825 )
1826 elif vis == "team" and gateway.team_id:
1827 # Check for existing team gateway with the same slug (row-locked)
1828 existing_gateway = get_for_update(
1829 db,
1830 DbGateway,
1831 where=and_(DbGateway.slug == new_slug, DbGateway.visibility == "team", DbGateway.team_id == gateway.team_id, DbGateway.id != gateway_id),
1832 )
1833 if existing_gateway:
1834 raise GatewayNameConflictError(
1835 new_slug,
1836 enabled=existing_gateway.enabled,
1837 gateway_id=existing_gateway.id,
1838 visibility=existing_gateway.visibility,
1839 )
1840 # Check for existing gateway with the same URL and visibility
1841 normalized_url = ""
1842 if gateway_update.url is not None:
1843 normalized_url = self.normalize_url(str(gateway_update.url))
1844 else:
1845 normalized_url = None
1847 # Prepare decoded auth_value for uniqueness check
1848 decoded_auth_value = None
1849 if gateway_update.auth_value:
1850 if isinstance(gateway_update.auth_value, str):
1851 try:
1852 decoded_auth_value = decode_auth(gateway_update.auth_value)
1853 except Exception as e:
1854 logger.warning(f"Failed to decode provided auth_value: {e}")
1855 elif isinstance(gateway_update.auth_value, dict):
1856 decoded_auth_value = gateway_update.auth_value
1858 # Determine final values for uniqueness check
1859 final_auth_value = decoded_auth_value if gateway_update.auth_value is not None else (decode_auth(gateway.auth_value) if isinstance(gateway.auth_value, str) else gateway.auth_value)
1860 final_oauth_config = gateway_update.oauth_config if gateway_update.oauth_config is not None else gateway.oauth_config
1861 final_visibility = gateway_update.visibility if gateway_update.visibility is not None else gateway.visibility
1863 # Check for duplicates with updated credentials
1864 if not gateway_update.one_time_auth:
1865 duplicate_gateway = self._check_gateway_uniqueness(
1866 db=db,
1867 url=normalized_url,
1868 auth_value=final_auth_value,
1869 oauth_config=final_oauth_config,
1870 team_id=gateway.team_id,
1871 visibility=final_visibility,
1872 gateway_id=gateway_id, # Exclude current gateway from check
1873 owner_email=user_email,
1874 )
1876 if duplicate_gateway:
1877 raise GatewayDuplicateConflictError(duplicate_gateway=duplicate_gateway)
1879 # FIX for Issue #1025: Determine if URL actually changed before we update it
1880 # We need this early because we update gateway.url below, and need to know
1881 # if it actually changed to decide whether to re-fetch tools
1882 # tools/resoures/prompts are need to be re-fetched not only if URL changed , in case any update like authentication and visibility changed
1883 # url_changed = gateway_update.url is not None and self.normalize_url(str(gateway_update.url)) != gateway.url
1885 # Save original values BEFORE updating for change detection checks later
1886 original_url = gateway.url
1887 original_auth_type = gateway.auth_type
1889 # Update fields if provided
1890 if gateway_update.name is not None:
1891 gateway.name = gateway_update.name
1892 gateway.slug = slugify(gateway_update.name)
1893 if gateway_update.url is not None:
1894 # Normalize the updated URL
1895 gateway.url = self.normalize_url(str(gateway_update.url))
1896 if gateway_update.description is not None:
1897 gateway.description = gateway_update.description
1898 if gateway_update.transport is not None:
1899 gateway.transport = gateway_update.transport
1900 if gateway_update.tags is not None:
1901 gateway.tags = gateway_update.tags
1902 if gateway_update.visibility is not None:
1903 gateway.visibility = gateway_update.visibility
1904 # Propagate visibility to all linked items immediately so it
1905 # takes effect even when the upstream server is unreachable
1906 # and _initialize_gateway fails.
1907 for tool in gateway.tools:
1908 tool.visibility = gateway.visibility
1909 for resource in gateway.resources:
1910 resource.visibility = gateway.visibility
1911 for prompt in gateway.prompts:
1912 prompt.visibility = gateway.visibility
1913 if gateway_update.passthrough_headers is not None:
1914 if isinstance(gateway_update.passthrough_headers, list):
1915 gateway.passthrough_headers = gateway_update.passthrough_headers
1916 else:
1917 if isinstance(gateway_update.passthrough_headers, str):
1918 parsed: List[str] = [h.strip() for h in gateway_update.passthrough_headers.split(",") if h.strip()]
1919 gateway.passthrough_headers = parsed
1920 else:
1921 raise GatewayError("Invalid passthrough_headers format: must be list[str] or comma-separated string")
1923 logger.info("Updated passthrough_headers for gateway {gateway.id}: {gateway.passthrough_headers}")
1925 # Only update auth_type if explicitly provided in the update
1926 if gateway_update.auth_type is not None:
1927 gateway.auth_type = gateway_update.auth_type
1929 # If auth_type is empty, update the auth_value too
1930 if gateway_update.auth_type == "":
1931 gateway.auth_value = cast(Any, "")
1933 # Clear auth_query_params when switching away from query_param auth
1934 if original_auth_type == "query_param" and gateway_update.auth_type != "query_param":
1935 gateway.auth_query_params = None
1936 logger.debug(f"Cleared auth_query_params for gateway {gateway.id} (switched from query_param to {gateway_update.auth_type})")
1938 # if auth_type is not None and only then check auth_value
1939 # Handle OAuth configuration updates
1940 if gateway_update.oauth_config is not None:
1941 gateway.oauth_config = await protect_oauth_config_for_storage(gateway_update.oauth_config, existing_oauth_config=gateway.oauth_config)
1943 # Handle auth_value updates (both existing and new auth values)
1944 token = gateway_update.auth_token
1945 password = gateway_update.auth_password
1946 header_value = gateway_update.auth_header_value
1948 # Support multiple custom headers on update
1949 if hasattr(gateway_update, "auth_headers") and gateway_update.auth_headers:
1950 existing_auth_raw = getattr(gateway, "auth_value", {}) or {}
1951 if isinstance(existing_auth_raw, str):
1952 try:
1953 existing_auth = decode_auth(existing_auth_raw)
1954 except Exception:
1955 existing_auth = {}
1956 elif isinstance(existing_auth_raw, dict):
1957 existing_auth = existing_auth_raw
1958 else:
1959 existing_auth = {}
1961 header_dict: Dict[str, str] = {}
1962 for header in gateway_update.auth_headers:
1963 key = header.get("key")
1964 if not key:
1965 continue
1966 value = header.get("value", "")
1967 if value == settings.masked_auth_value and key in existing_auth:
1968 header_dict[key] = existing_auth[key]
1969 else:
1970 header_dict[key] = value
1971 gateway.auth_value = header_dict # Store as dict for DB JSON field
1972 elif settings.masked_auth_value not in (token, password, header_value):
1973 # Check if values differ from existing ones or if setting for first time
1974 decoded_auth = decode_auth(gateway_update.auth_value) if gateway_update.auth_value else {}
1975 current_auth = getattr(gateway, "auth_value", {}) or {}
1976 if current_auth != decoded_auth:
1977 gateway.auth_value = decoded_auth
1979 # Handle query_param auth updates with service-layer enforcement
1980 auth_query_params_decrypted: Optional[Dict[str, str]] = None
1981 init_url = gateway.url
1983 # Check if updating to query_param auth or updating existing query_param credentials
1984 # Use original_auth_type since gateway.auth_type may have been updated already
1985 is_switching_to_queryparam = gateway_update.auth_type == "query_param" and original_auth_type != "query_param"
1986 is_updating_queryparam_creds = original_auth_type == "query_param" and (gateway_update.auth_query_param_key is not None or gateway_update.auth_query_param_value is not None)
1987 is_url_changing = gateway_update.url is not None and self.normalize_url(str(gateway_update.url)) != original_url
1989 if is_switching_to_queryparam or is_updating_queryparam_creds or (is_url_changing and original_auth_type == "query_param"):
1990 # Service-layer enforcement: Check feature flag
1991 if not settings.insecure_allow_queryparam_auth:
1992 # Grandfather clause: Allow updates to existing query_param gateways
1993 # unless they're trying to change credentials
1994 if is_switching_to_queryparam or is_updating_queryparam_creds:
1995 raise ValueError("Query parameter authentication is disabled. " + "Set INSECURE_ALLOW_QUERYPARAM_AUTH=true to enable.")
1997 # Service-layer enforcement: Check host allowlist
1998 if settings.insecure_queryparam_auth_allowed_hosts:
1999 check_url = str(gateway_update.url) if gateway_update.url else gateway.url
2000 parsed = urlparse(check_url)
2001 hostname = (parsed.hostname or "").lower()
2002 if hostname not in settings.insecure_queryparam_auth_allowed_hosts:
2003 allowed = ", ".join(settings.insecure_queryparam_auth_allowed_hosts)
2004 raise ValueError(f"Host '{hostname}' is not in the allowed hosts for query param auth. Allowed: {allowed}")
2006 # Process query_param auth credentials
2007 param_key = getattr(gateway_update, "auth_query_param_key", None) or (next(iter(gateway.auth_query_params.keys()), None) if gateway.auth_query_params else None)
2008 param_value = getattr(gateway_update, "auth_query_param_value", None)
2010 # Get raw value from SecretStr if applicable
2011 raw_value: Optional[str] = None
2012 if param_value:
2013 if hasattr(param_value, "get_secret_value"):
2014 raw_value = param_value.get_secret_value()
2015 else:
2016 raw_value = str(param_value)
2018 # Check if the value is the masked placeholder - if so, keep existing value
2019 is_masked_placeholder = raw_value == settings.masked_auth_value
2021 if param_key:
2022 if raw_value and not is_masked_placeholder:
2023 # New value provided - encrypt for storage
2024 encrypted_value = encode_auth({param_key: raw_value})
2025 gateway.auth_query_params = {param_key: encrypted_value}
2026 auth_query_params_decrypted = {param_key: raw_value}
2027 elif gateway.auth_query_params:
2028 # Use existing encrypted value
2029 existing_encrypted = gateway.auth_query_params.get(param_key, "")
2030 if existing_encrypted:
2031 decrypted = decode_auth(existing_encrypted)
2032 auth_query_params_decrypted = {param_key: decrypted.get(param_key, "")}
2034 # Append query params to URL for initialization
2035 if auth_query_params_decrypted:
2036 init_url = apply_query_param_auth(gateway.url, auth_query_params_decrypted)
2038 # Update auth_type if switching
2039 if is_switching_to_queryparam:
2040 gateway.auth_type = "query_param"
2041 gateway.auth_value = None # Query param auth doesn't use auth_value
2043 elif gateway.auth_type == "query_param" and gateway.auth_query_params:
2044 # Existing query_param gateway without credential changes - decrypt for init
2045 first_key = next(iter(gateway.auth_query_params.keys()), None)
2046 if first_key:
2047 encrypted_value = gateway.auth_query_params.get(first_key, "")
2048 if encrypted_value:
2049 decrypted = decode_auth(encrypted_value)
2050 auth_query_params_decrypted = {first_key: decrypted.get(first_key, "")}
2051 init_url = apply_query_param_auth(gateway.url, auth_query_params_decrypted)
2053 # Try to reinitialize connection if URL actually changed
2054 # if url_changed:
2055 # Initialize empty lists in case initialization fails
2056 tools_to_add = []
2057 resources_to_add = []
2058 prompts_to_add = []
2060 try:
2061 ca_certificate = getattr(gateway, "ca_certificate", None)
2062 capabilities, tools, resources, prompts = await self._initialize_gateway(
2063 init_url,
2064 gateway.auth_value,
2065 gateway.transport,
2066 gateway.auth_type,
2067 gateway.oauth_config,
2068 ca_certificate,
2069 auth_query_params=auth_query_params_decrypted,
2070 )
2071 new_tool_names = [tool.name for tool in tools]
2072 new_resource_uris = [resource.uri for resource in resources]
2073 new_prompt_names = [prompt.name for prompt in prompts]
2075 if gateway_update.one_time_auth:
2076 # For one-time auth, clear auth_type and auth_value after initialization
2077 gateway.auth_type = "one_time_auth"
2078 gateway.auth_value = None
2079 gateway.oauth_config = None
2081 # Update tools using helper method
2082 tools_to_add = self._update_or_create_tools(db, tools, gateway, "update")
2084 # Update resources using helper method
2085 resources_to_add = self._update_or_create_resources(db, resources, gateway, "update")
2087 # Update prompts using helper method
2088 prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, "update")
2090 # Log newly added items
2091 items_added = len(tools_to_add) + len(resources_to_add) + len(prompts_to_add)
2092 if items_added > 0:
2093 if tools_to_add:
2094 logger.info(f"Added {len(tools_to_add)} new tools during gateway update")
2095 if resources_to_add:
2096 logger.info(f"Added {len(resources_to_add)} new resources during gateway update")
2097 if prompts_to_add:
2098 logger.info(f"Added {len(prompts_to_add)} new prompts during gateway update")
2099 logger.info(f"Total {items_added} new items added during gateway update")
2101 # Count items before cleanup for logging
2103 # Bulk delete tools that are no longer available from the gateway
2104 # Use chunking to avoid SQLite's 999 parameter limit for IN clauses
2105 stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names]
2106 if stale_tool_ids:
2107 # Delete child records first to avoid FK constraint violations
2108 for i in range(0, len(stale_tool_ids), 500):
2109 chunk = stale_tool_ids[i : i + 500]
2110 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
2111 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
2112 db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))
2114 # Bulk delete resources that are no longer available from the gateway
2115 stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris]
2116 if stale_resource_ids:
2117 # Delete child records first to avoid FK constraint violations
2118 for i in range(0, len(stale_resource_ids), 500):
2119 chunk = stale_resource_ids[i : i + 500]
2120 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
2121 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
2122 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
2123 db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))
2125 # Bulk delete prompts that are no longer available from the gateway
2126 stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names]
2127 if stale_prompt_ids:
2128 # Delete child records first to avoid FK constraint violations
2129 for i in range(0, len(stale_prompt_ids), 500):
2130 chunk = stale_prompt_ids[i : i + 500]
2131 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
2132 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
2133 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))
2135 # Expire gateway to clear cached relationships after bulk deletes
2136 # This prevents SQLAlchemy from trying to re-delete already-deleted items
2137 if stale_tool_ids or stale_resource_ids or stale_prompt_ids:
2138 db.expire(gateway)
2140 gateway.capabilities = capabilities
2142 # Register capabilities for notification-driven actions
2143 register_gateway_capabilities_for_notifications(gateway.id, capabilities)
2145 gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names] # keep only still-valid rows
2146 gateway.resources = [resource for resource in gateway.resources if resource.uri in new_resource_uris] # keep only still-valid rows
2147 gateway.prompts = [prompt for prompt in gateway.prompts if prompt.original_name in new_prompt_names] # keep only still-valid rows
2149 # Log cleanup results
2150 tools_removed = len(stale_tool_ids)
2151 resources_removed = len(stale_resource_ids)
2152 prompts_removed = len(stale_prompt_ids)
2154 if tools_removed > 0:
2155 logger.info(f"Removed {tools_removed} tools no longer available during gateway update")
2156 if resources_removed > 0:
2157 logger.info(f"Removed {resources_removed} resources no longer available during gateway update")
2158 if prompts_removed > 0:
2159 logger.info(f"Removed {prompts_removed} prompts no longer available during gateway update")
2161 gateway.last_seen = datetime.now(timezone.utc)
2163 # Add new items to database session in chunks to prevent lock escalation
2164 chunk_size = 50
2166 if tools_to_add:
2167 for i in range(0, len(tools_to_add), chunk_size):
2168 chunk = tools_to_add[i : i + chunk_size]
2169 db.add_all(chunk)
2170 db.flush()
2171 if resources_to_add:
2172 for i in range(0, len(resources_to_add), chunk_size):
2173 chunk = resources_to_add[i : i + chunk_size]
2174 db.add_all(chunk)
2175 db.flush()
2176 if prompts_to_add:
2177 for i in range(0, len(prompts_to_add), chunk_size):
2178 chunk = prompts_to_add[i : i + chunk_size]
2179 db.add_all(chunk)
2180 db.flush()
2182 # Update tracking with new URL
2183 self._active_gateways.discard(gateway.url)
2184 self._active_gateways.add(gateway.url)
2185 except Exception as e:
2186 logger.warning(f"Failed to initialize updated gateway: {e}")
2188 # Update tags if provided
2189 if gateway_update.tags is not None:
2190 gateway.tags = gateway_update.tags
2192 # Update gateway_mode if provided
2193 if hasattr(gateway_update, "gateway_mode") and gateway_update.gateway_mode is not None:
2194 if gateway_update.gateway_mode == "direct_proxy" and not settings.mcpgateway_direct_proxy_enabled:
2195 raise GatewayError("direct_proxy gateway mode is disabled. Set MCPGATEWAY_DIRECT_PROXY_ENABLED=true to enable.")
2196 gateway.gateway_mode = gateway_update.gateway_mode
2198 # Update metadata fields
2199 gateway.updated_at = datetime.now(timezone.utc)
2200 if modified_by:
2201 gateway.modified_by = modified_by
2202 if modified_from_ip:
2203 gateway.modified_from_ip = modified_from_ip
2204 if modified_via:
2205 gateway.modified_via = modified_via
2206 if modified_user_agent:
2207 gateway.modified_user_agent = modified_user_agent
2208 if hasattr(gateway, "version") and gateway.version is not None:
2209 gateway.version = gateway.version + 1
2210 else:
2211 gateway.version = 1
2213 db.commit()
2214 db.refresh(gateway)
2216 # Invalidate cache after successful update
2217 cache = _get_registry_cache()
2218 await cache.invalidate_gateways()
2219 tool_lookup_cache = _get_tool_lookup_cache()
2220 await tool_lookup_cache.invalidate_gateway(str(gateway.id))
2221 # Also invalidate tags cache since gateway tags may have changed
2222 # First-Party
2223 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel
2225 await admin_stats_cache.invalidate_tags()
2227 # Notify subscribers
2228 await self._notify_gateway_updated(gateway)
2230 logger.info(f"Updated gateway: {gateway.name}")
2232 # Structured logging: Audit trail for gateway update
2233 audit_trail.log_action(
2234 user_id=user_email or modified_by or "system",
2235 action="update_gateway",
2236 resource_type="gateway",
2237 resource_id=str(gateway.id),
2238 resource_name=gateway.name,
2239 user_email=user_email,
2240 team_id=gateway.team_id,
2241 client_ip=modified_from_ip,
2242 user_agent=modified_user_agent,
2243 new_values={
2244 "name": gateway.name,
2245 "url": gateway.url,
2246 "version": gateway.version,
2247 },
2248 context={
2249 "modified_via": modified_via,
2250 },
2251 db=db,
2252 )
2254 # Structured logging: Log successful gateway update
2255 structured_logger.log(
2256 level="INFO",
2257 message="Gateway updated successfully",
2258 event_type="gateway_updated",
2259 component="gateway_service",
2260 user_id=modified_by,
2261 user_email=user_email,
2262 team_id=gateway.team_id,
2263 resource_type="gateway",
2264 resource_id=str(gateway.id),
2265 custom_fields={
2266 "gateway_name": gateway.name,
2267 "version": gateway.version,
2268 },
2269 )
2271 return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)).masked()
2272 # Gateway is inactive and include_inactive is False → skip update, return None
2273 return None
2274 except GatewayNameConflictError as ge:
2275 logger.error(f"GatewayNameConflictError in group: {ge}")
2276 db.rollback()
2278 structured_logger.log(
2279 level="WARNING",
2280 message="Gateway update failed due to name conflict",
2281 event_type="gateway_name_conflict",
2282 component="gateway_service",
2283 user_email=user_email,
2284 resource_type="gateway",
2285 resource_id=gateway_id,
2286 error=ge,
2287 )
2288 raise ge
2289 except GatewayNotFoundError as gnfe:
2290 logger.error(f"GatewayNotFoundError: {gnfe}")
2291 db.rollback()
2293 structured_logger.log(
2294 level="ERROR",
2295 message="Gateway update failed - gateway not found",
2296 event_type="gateway_not_found",
2297 component="gateway_service",
2298 user_email=user_email,
2299 resource_type="gateway",
2300 resource_id=gateway_id,
2301 error=gnfe,
2302 )
2303 raise gnfe
2304 except IntegrityError as ie:
2305 logger.error(f"IntegrityErrors in group: {ie}")
2306 db.rollback()
2308 structured_logger.log(
2309 level="ERROR",
2310 message="Gateway update failed due to database integrity error",
2311 event_type="gateway_update_failed",
2312 component="gateway_service",
2313 user_email=user_email,
2314 resource_type="gateway",
2315 resource_id=gateway_id,
2316 error=ie,
2317 )
2318 raise ie
2319 except PermissionError as pe:
2320 db.rollback()
2322 structured_logger.log(
2323 level="WARNING",
2324 message="Gateway update failed due to permission error",
2325 event_type="gateway_update_permission_denied",
2326 component="gateway_service",
2327 user_email=user_email,
2328 resource_type="gateway",
2329 resource_id=gateway_id,
2330 error=pe,
2331 )
2332 raise
2333 except Exception as e:
2334 db.rollback()
2336 structured_logger.log(
2337 level="ERROR",
2338 message="Gateway update failed",
2339 event_type="gateway_update_failed",
2340 component="gateway_service",
2341 user_email=user_email,
2342 resource_type="gateway",
2343 resource_id=gateway_id,
2344 error=e,
2345 )
2346 raise GatewayError(f"Failed to update gateway: {str(e)}")
2348 async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool = True) -> GatewayRead:
2349 """Get a gateway by its ID.
2351 Args:
2352 db: Database session
2353 gateway_id: Gateway ID
2354 include_inactive: Whether to include inactive gateways
2356 Returns:
2357 GatewayRead object
2359 Raises:
2360 GatewayNotFoundError: If the gateway is not found
2362 Examples:
2363 >>> from unittest.mock import MagicMock
2364 >>> from mcpgateway.schemas import GatewayRead
2365 >>> service = GatewayService()
2366 >>> db = MagicMock()
2367 >>> gateway_mock = MagicMock()
2368 >>> gateway_mock.enabled = True
2369 >>> db.execute.return_value.scalar_one_or_none.return_value = gateway_mock
2370 >>> mocked_gateway_read = MagicMock()
2371 >>> mocked_gateway_read.masked.return_value = 'gateway_read'
2372 >>> GatewayRead.model_validate = MagicMock(return_value=mocked_gateway_read)
2373 >>> import asyncio
2374 >>> result = asyncio.run(service.get_gateway(db, 'gateway_id'))
2375 >>> result == 'gateway_read'
2376 True
2378 >>> # Test with inactive gateway but include_inactive=True
2379 >>> gateway_mock.enabled = False
2380 >>> result_inactive = asyncio.run(service.get_gateway(db, 'gateway_id', include_inactive=True))
2381 >>> result_inactive == 'gateway_read'
2382 True
2384 >>> # Test gateway not found
2385 >>> db.execute.return_value.scalar_one_or_none.return_value = None
2386 >>> try:
2387 ... asyncio.run(service.get_gateway(db, 'missing_id'))
2388 ... except GatewayNotFoundError as e:
2389 ... 'Gateway not found: missing_id' in str(e)
2390 True
2392 >>> # Test inactive gateway with include_inactive=False
2393 >>> gateway_mock.enabled = False
2394 >>> db.execute.return_value.scalar_one_or_none.return_value = gateway_mock
2395 >>> try:
2396 ... asyncio.run(service.get_gateway(db, 'gateway_id', include_inactive=False))
2397 ... except GatewayNotFoundError as e:
2398 ... 'Gateway not found: gateway_id' in str(e)
2399 True
2400 >>>
2401 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
2402 >>> asyncio.run(service._http_client.aclose())
2403 """
2404 # Use eager loading to avoid N+1 queries for relationships and team name
2405 gateway = db.execute(
2406 select(DbGateway)
2407 .options(
2408 selectinload(DbGateway.tools),
2409 selectinload(DbGateway.resources),
2410 selectinload(DbGateway.prompts),
2411 joinedload(DbGateway.email_team),
2412 )
2413 .where(DbGateway.id == gateway_id)
2414 ).scalar_one_or_none()
2416 if not gateway:
2417 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
2419 if gateway.enabled or include_inactive:
2420 # Structured logging: Log gateway view
2421 structured_logger.log(
2422 level="INFO",
2423 message="Gateway retrieved successfully",
2424 event_type="gateway_viewed",
2425 component="gateway_service",
2426 team_id=getattr(gateway, "team_id", None),
2427 resource_type="gateway",
2428 resource_id=str(gateway.id),
2429 custom_fields={
2430 "gateway_name": gateway.name,
2431 "gateway_url": gateway.url,
2432 "include_inactive": include_inactive,
2433 },
2434 )
2436 return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)).masked()
2438 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
    async def set_gateway_state(self, db: Session, gateway_id: str, activate: bool, reachable: bool = True, only_update_reachable: bool = False, user_email: Optional[str] = None) -> GatewayRead:
        """
        Set the activation status of a gateway.

        On activation (enabled and reachable) the gateway is re-initialized over
        the network: tools/resources/prompts are re-discovered, new items are
        added, and items the gateway no longer offers are bulk-deleted (child
        rows first to satisfy FK constraints, chunked to stay under SQLite's
        999-parameter IN-clause limit). Afterwards the gateway's child tool,
        prompt, and resource rows are bulk-updated to mirror the new
        enabled/reachable state and the relevant registry caches invalidated.

        Args:
            db: Database session
            gateway_id: Gateway ID
            activate: True to activate, False to deactivate
            reachable: Whether the gateway is reachable
            only_update_reachable: Only update reachable status
            user_email: Optional[str] The email of the user to check if the user has permission to modify.

        Returns:
            The updated GatewayRead object

        Raises:
            GatewayNotFoundError: If the gateway is not found
            GatewayError: For other errors
            PermissionError: If user doesn't own the agent.
        """
        try:
            # Eager-load collections for the gateway. Note: we don't use FOR UPDATE
            # here because _initialize_gateway does network I/O, and holding a row
            # lock during network calls would block other operations and risk timeouts.
            gateway = db.execute(
                select(DbGateway)
                .options(
                    selectinload(DbGateway.tools),
                    selectinload(DbGateway.resources),
                    selectinload(DbGateway.prompts),
                    joinedload(DbGateway.email_team),
                )
                .where(DbGateway.id == gateway_id)
            ).scalar_one_or_none()
            if not gateway:
                raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

            if user_email:
                # First-Party
                from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

                permission_service = PermissionService(db)
                if not await permission_service.check_resource_ownership(user_email, gateway):
                    raise PermissionError("Only the owner can activate the gateway" if activate else "Only the owner can deactivate the gateway")

            # Update status if it's different
            if (gateway.enabled != activate) or (gateway.reachable != reachable):
                gateway.enabled = activate
                gateway.reachable = reachable
                gateway.updated_at = datetime.now(timezone.utc)
                # Update tracking
                if activate and reachable:
                    self._active_gateways.add(gateway.url)

                    # Initialize empty lists in case initialization fails
                    tools_to_add = []
                    resources_to_add = []
                    prompts_to_add = []

                    # Try to initialize if activating
                    try:
                        # Handle query_param auth - decrypt and apply to URL
                        init_url = gateway.url
                        auth_query_params_decrypted: Optional[Dict[str, str]] = None
                        if gateway.auth_type == "query_param" and gateway.auth_query_params:
                            auth_query_params_decrypted = {}
                            for param_key, encrypted_value in gateway.auth_query_params.items():
                                if encrypted_value:
                                    try:
                                        decrypted = decode_auth(encrypted_value)
                                        auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
                                    except Exception:
                                        # Best-effort: a param that fails to decrypt is simply skipped
                                        logger.debug(f"Failed to decrypt query param '{param_key}' for gateway activation")
                            if auth_query_params_decrypted:
                                init_url = apply_query_param_auth(gateway.url, auth_query_params_decrypted)

                        capabilities, tools, resources, prompts = await self._initialize_gateway(
                            init_url, gateway.auth_value, gateway.transport, gateway.auth_type, gateway.oauth_config, auth_query_params=auth_query_params_decrypted, oauth_auto_fetch_tool_flag=True
                        )
                        new_tool_names = [tool.name for tool in tools]
                        new_resource_uris = [resource.uri for resource in resources]
                        new_prompt_names = [prompt.name for prompt in prompts]

                        # Update tools, resources, and prompts using helper methods
                        tools_to_add = self._update_or_create_tools(db, tools, gateway, "rediscovery")
                        resources_to_add = self._update_or_create_resources(db, resources, gateway, "rediscovery")
                        prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, "rediscovery")

                        # Log newly added items
                        items_added = len(tools_to_add) + len(resources_to_add) + len(prompts_to_add)
                        if items_added > 0:
                            if tools_to_add:
                                logger.info(f"Added {len(tools_to_add)} new tools during gateway reactivation")
                            if resources_to_add:
                                logger.info(f"Added {len(resources_to_add)} new resources during gateway reactivation")
                            if prompts_to_add:
                                logger.info(f"Added {len(prompts_to_add)} new prompts during gateway reactivation")
                            logger.info(f"Total {items_added} new items added during gateway reactivation")

                        # Count items before cleanup for logging

                        # Bulk delete tools that are no longer available from the gateway
                        # Use chunking to avoid SQLite's 999 parameter limit for IN clauses
                        stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names]
                        if stale_tool_ids:
                            # Delete child records first to avoid FK constraint violations
                            for i in range(0, len(stale_tool_ids), 500):
                                chunk = stale_tool_ids[i : i + 500]
                                db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
                                db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
                                db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))

                        # Bulk delete resources that are no longer available from the gateway
                        stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris]
                        if stale_resource_ids:
                            # Delete child records first to avoid FK constraint violations
                            for i in range(0, len(stale_resource_ids), 500):
                                chunk = stale_resource_ids[i : i + 500]
                                db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
                                db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
                                db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
                                db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))

                        # Bulk delete prompts that are no longer available from the gateway
                        stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names]
                        if stale_prompt_ids:
                            # Delete child records first to avoid FK constraint violations
                            for i in range(0, len(stale_prompt_ids), 500):
                                chunk = stale_prompt_ids[i : i + 500]
                                db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
                                db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
                                db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))

                        # Expire gateway to clear cached relationships after bulk deletes
                        # This prevents SQLAlchemy from trying to re-delete already-deleted items
                        if stale_tool_ids or stale_resource_ids or stale_prompt_ids:
                            db.expire(gateway)

                        gateway.capabilities = capabilities

                        # Register capabilities for notification-driven actions
                        register_gateway_capabilities_for_notifications(gateway.id, capabilities)

                        gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names]  # keep only still-valid rows
                        gateway.resources = [resource for resource in gateway.resources if resource.uri in new_resource_uris]  # keep only still-valid rows
                        gateway.prompts = [prompt for prompt in gateway.prompts if prompt.original_name in new_prompt_names]  # keep only still-valid rows

                        # Log cleanup results
                        tools_removed = len(stale_tool_ids)
                        resources_removed = len(stale_resource_ids)
                        prompts_removed = len(stale_prompt_ids)

                        if tools_removed > 0:
                            logger.info(f"Removed {tools_removed} tools no longer available during gateway reactivation")
                        if resources_removed > 0:
                            logger.info(f"Removed {resources_removed} resources no longer available during gateway reactivation")
                        if prompts_removed > 0:
                            logger.info(f"Removed {prompts_removed} prompts no longer available during gateway reactivation")

                        gateway.last_seen = datetime.now(timezone.utc)

                        # Add new items to database session in chunks to prevent lock escalation
                        chunk_size = 50

                        if tools_to_add:
                            for i in range(0, len(tools_to_add), chunk_size):
                                chunk = tools_to_add[i : i + chunk_size]
                                db.add_all(chunk)
                                db.flush()
                        if resources_to_add:
                            for i in range(0, len(resources_to_add), chunk_size):
                                chunk = resources_to_add[i : i + chunk_size]
                                db.add_all(chunk)
                                db.flush()
                        if prompts_to_add:
                            for i in range(0, len(prompts_to_add), chunk_size):
                                chunk = prompts_to_add[i : i + chunk_size]
                                db.add_all(chunk)
                                db.flush()
                    except Exception as e:
                        # Activation proceeds even if re-discovery fails; the gateway
                        # simply keeps whatever items it already had.
                        logger.warning(f"Failed to initialize reactivated gateway: {e}")
                else:
                    self._active_gateways.discard(gateway.url)

            db.commit()
            db.refresh(gateway)

            # Invalidate cache after status change
            cache = _get_registry_cache()
            await cache.invalidate_gateways()

            # Notify Subscribers
            if not gateway.enabled:
                # Inactive
                await self._notify_gateway_deactivated(gateway)
            elif gateway.enabled and not gateway.reachable:
                # Offline (Enabled but Unreachable)
                await self._notify_gateway_offline(gateway)
            else:
                # Active (Enabled and Reachable)
                await self._notify_gateway_activated(gateway)

            # Bulk update tools - single UPDATE statement instead of N FOR UPDATE locks
            # This prevents lock contention under high concurrent load
            now = datetime.now(timezone.utc)
            if only_update_reachable:
                # Only update reachable status, keep enabled as-is
                tools_result = db.execute(update(DbTool).where(DbTool.gateway_id == gateway_id).where(DbTool.reachable != reachable).values(reachable=reachable, updated_at=now))
            else:
                # Update both enabled and reachable
                tools_result = db.execute(
                    update(DbTool)
                    .where(DbTool.gateway_id == gateway_id)
                    .where(or_(DbTool.enabled != activate, DbTool.reachable != reachable))
                    .values(enabled=activate, reachable=reachable, updated_at=now)
                )
            tools_updated = tools_result.rowcount

            # Commit tool updates
            if tools_updated > 0:
                db.commit()

            # Invalidate tools cache once after bulk update
            if tools_updated > 0:
                await cache.invalidate_tools()
                tool_lookup_cache = _get_tool_lookup_cache()
                await tool_lookup_cache.invalidate_gateway(str(gateway.id))

            # Bulk update prompts when gateway is deactivated/activated (skip for reachability-only updates)
            prompts_updated = 0
            if not only_update_reachable:
                prompts_result = db.execute(update(DbPrompt).where(DbPrompt.gateway_id == gateway_id).where(DbPrompt.enabled != activate).values(enabled=activate, updated_at=now))
                prompts_updated = prompts_result.rowcount
                if prompts_updated > 0:
                    db.commit()
                    await cache.invalidate_prompts()

            # Bulk update resources when gateway is deactivated/activated (skip for reachability-only updates)
            resources_updated = 0
            if not only_update_reachable:
                resources_result = db.execute(update(DbResource).where(DbResource.gateway_id == gateway_id).where(DbResource.enabled != activate).values(enabled=activate, updated_at=now))
                resources_updated = resources_result.rowcount
                if resources_updated > 0:
                    db.commit()
                    await cache.invalidate_resources()

            logger.debug(f"Gateway {gateway.name} bulk state update: {tools_updated} tools, {prompts_updated} prompts, {resources_updated} resources")

            logger.info(f"Gateway status: {gateway.name} - {'enabled' if activate else 'disabled'} and {'accessible' if reachable else 'inaccessible'}")

            # Structured logging: Audit trail for gateway state change
            audit_trail.log_action(
                user_id=user_email or "system",
                action="set_gateway_state",
                resource_type="gateway",
                resource_id=str(gateway.id),
                resource_name=gateway.name,
                user_email=user_email,
                team_id=gateway.team_id,
                new_values={
                    "enabled": gateway.enabled,
                    "reachable": gateway.reachable,
                },
                context={
                    "action": "activate" if activate else "deactivate",
                    "only_update_reachable": only_update_reachable,
                },
                db=db,
            )

            # Structured logging: Log successful gateway state change
            structured_logger.log(
                level="INFO",
                message=f"Gateway {'activated' if activate else 'deactivated'} successfully",
                event_type="gateway_state_changed",
                component="gateway_service",
                user_email=user_email,
                team_id=gateway.team_id,
                resource_type="gateway",
                resource_id=str(gateway.id),
                custom_fields={
                    "gateway_name": gateway.name,
                    "enabled": gateway.enabled,
                    "reachable": gateway.reachable,
                },
            )

            return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)).masked()

        except PermissionError as e:
            db.rollback()

            # Structured logging: Log permission error
            structured_logger.log(
                level="WARNING",
                message="Gateway state change failed due to permission error",
                event_type="gateway_state_change_permission_denied",
                component="gateway_service",
                user_email=user_email,
                resource_type="gateway",
                resource_id=gateway_id,
                error=e,
            )
            raise e
        except Exception as e:
            db.rollback()

            # Structured logging: Log generic gateway state change failure
            structured_logger.log(
                level="ERROR",
                message="Gateway state change failed",
                event_type="gateway_state_change_failed",
                component="gateway_service",
                user_email=user_email,
                resource_type="gateway",
                resource_id=gateway_id,
                error=e,
            )
            raise GatewayError(f"Failed to set gateway state: {str(e)}")
2760 async def _notify_gateway_updated(self, gateway: DbGateway) -> None:
2761 """
2762 Notify subscribers of gateway update.
2764 Args:
2765 gateway: Gateway to update
2766 """
2767 event = {
2768 "type": "gateway_updated",
2769 "data": {
2770 "id": gateway.id,
2771 "name": gateway.name,
2772 "url": gateway.url,
2773 "description": gateway.description,
2774 "enabled": gateway.enabled,
2775 },
2776 "timestamp": datetime.now(timezone.utc).isoformat(),
2777 }
2778 await self._publish_event(event)
    async def delete_gateway(self, db: Session, gateway_id: str, user_email: Optional[str] = None) -> None:
        """
        Delete a gateway by its ID.

        Deletion is performed children-first (metrics, association rows,
        subscriptions, then tools/resources/prompts, then the gateway row) to
        avoid FK constraint violations, with IN-clause chunking to stay under
        SQLite's 999-parameter limit. The final gateway delete is checked via
        rowcount so a concurrent delete surfaces as GatewayNotFoundError.

        Args:
            db: Database session
            gateway_id: Gateway ID
            user_email: Email of user performing deletion (for ownership check)

        Raises:
            GatewayNotFoundError: If the gateway is not found
            PermissionError: If user doesn't own the gateway
            GatewayError: For other deletion errors

        Examples:
            >>> from mcpgateway.services.gateway_service import GatewayService
            >>> from unittest.mock import MagicMock
            >>> service = GatewayService()
            >>> db = MagicMock()
            >>> gateway = MagicMock()
            >>> db.execute.return_value.scalar_one_or_none.return_value = gateway
            >>> db.delete = MagicMock()
            >>> db.commit = MagicMock()
            >>> service._notify_gateway_deleted = MagicMock()
            >>> import asyncio
            >>> try:
            ...     asyncio.run(service.delete_gateway(db, 'gateway_id', 'user@example.com'))
            ... except Exception:
            ...     pass
            >>>
            >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
            >>> asyncio.run(service._http_client.aclose())
        """
        try:
            # Find gateway with eager loading for deletion to avoid N+1 queries
            gateway = db.execute(
                select(DbGateway)
                .options(
                    selectinload(DbGateway.tools),
                    selectinload(DbGateway.resources),
                    selectinload(DbGateway.prompts),
                    joinedload(DbGateway.email_team),
                )
                .where(DbGateway.id == gateway_id)
            ).scalar_one_or_none()

            if not gateway:
                raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

            # Check ownership if user_email provided
            if user_email:
                # First-Party
                from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

                permission_service = PermissionService(db)
                if not await permission_service.check_resource_ownership(user_email, gateway):
                    raise PermissionError("Only the owner can delete this gateway")

            # Store gateway info for notification before deletion
            # (the ORM object becomes unusable once expired/deleted below)
            gateway_info = {"id": gateway.id, "name": gateway.name, "url": gateway.url}
            gateway_name = gateway.name
            gateway_team_id = gateway.team_id
            gateway_url = gateway.url  # Store URL before expiring the object

            # Manually delete children first to avoid FK constraint violations
            # (passive_deletes=True means ORM won't auto-cascade, we must do it explicitly)
            # Use chunking to avoid SQLite's 999 parameter limit for IN clauses
            tool_ids = [t.id for t in gateway.tools]
            resource_ids = [r.id for r in gateway.resources]
            prompt_ids = [p.id for p in gateway.prompts]

            # Delete tool children and tools
            if tool_ids:
                for i in range(0, len(tool_ids), 500):
                    chunk = tool_ids[i : i + 500]
                    db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
                    db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
                    db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))

            # Delete resource children and resources
            if resource_ids:
                for i in range(0, len(resource_ids), 500):
                    chunk = resource_ids[i : i + 500]
                    db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
                    db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
                    db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
                    db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))

            # Delete prompt children and prompts
            if prompt_ids:
                for i in range(0, len(prompt_ids), 500):
                    chunk = prompt_ids[i : i + 500]
                    db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
                    db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
                    db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))

            # Expire gateway to clear cached relationships after bulk deletes
            db.expire(gateway)

            # Use DELETE with rowcount check for database-agnostic atomic delete
            # (RETURNING is not supported on MySQL/MariaDB)
            stmt = delete(DbGateway).where(DbGateway.id == gateway_id)
            result = db.execute(stmt)
            if result.rowcount == 0:
                # Gateway was already deleted by another concurrent request
                raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

            db.commit()

            # Invalidate cache after successful deletion
            cache = _get_registry_cache()
            await cache.invalidate_gateways()
            tool_lookup_cache = _get_tool_lookup_cache()
            await tool_lookup_cache.invalidate_gateway(str(gateway_id))
            # Also invalidate tags cache since gateway tags may have changed
            # First-Party
            from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

            await admin_stats_cache.invalidate_tags()

            # Update tracking
            self._active_gateways.discard(gateway_url)

            # Notify subscribers
            await self._notify_gateway_deleted(gateway_info)

            logger.info(f"Permanently deleted gateway: {gateway_name}")

            # Structured logging: Audit trail for gateway deletion
            audit_trail.log_action(
                user_id=user_email or "system",
                action="delete_gateway",
                resource_type="gateway",
                resource_id=str(gateway_info["id"]),
                resource_name=gateway_name,
                user_email=user_email,
                team_id=gateway_team_id,
                old_values={
                    "name": gateway_name,
                    "url": gateway_info["url"],
                },
                db=db,
            )

            # Structured logging: Log successful gateway deletion
            structured_logger.log(
                level="INFO",
                message="Gateway deleted successfully",
                event_type="gateway_deleted",
                component="gateway_service",
                user_email=user_email,
                team_id=gateway_team_id,
                resource_type="gateway",
                resource_id=str(gateway_info["id"]),
                custom_fields={
                    "gateway_name": gateway_name,
                    "gateway_url": gateway_info["url"],
                },
            )

        except PermissionError as pe:
            db.rollback()

            # Structured logging: Log permission error
            structured_logger.log(
                level="WARNING",
                message="Gateway deletion failed due to permission error",
                event_type="gateway_delete_permission_denied",
                component="gateway_service",
                user_email=user_email,
                resource_type="gateway",
                resource_id=gateway_id,
                error=pe,
            )
            raise
        except Exception as e:
            db.rollback()

            # Structured logging: Log generic gateway deletion failure
            structured_logger.log(
                level="ERROR",
                message="Gateway deletion failed",
                event_type="gateway_deletion_failed",
                component="gateway_service",
                user_email=user_email,
                resource_type="gateway",
                resource_id=gateway_id,
                error=e,
            )
            raise GatewayError(f"Failed to delete gateway: {str(e)}")
2971 async def _handle_gateway_failure(self, gateway: DbGateway) -> None:
2972 """Tracks and handles gateway failures during health checks.
2973 If the failure count exceeds the threshold, the gateway is deactivated.
2975 Args:
2976 gateway: The gateway object that failed its health check.
2978 Returns:
2979 None
2981 Examples:
2982 >>> from mcpgateway.services.gateway_service import GatewayService
2983 >>> service = GatewayService()
2984 >>> gateway = type('Gateway', (), {
2985 ... 'id': 'gw1', 'name': 'test_gw', 'enabled': True, 'reachable': True
2986 ... })()
2987 >>> service._gateway_failure_counts = {}
2988 >>> import asyncio
2989 >>> # Test failure counting
2990 >>> asyncio.run(service._handle_gateway_failure(gateway)) # doctest: +ELLIPSIS
2991 >>> service._gateway_failure_counts['gw1'] >= 1
2992 True
2994 >>> # Test disabled gateway (no action)
2995 >>> gateway.enabled = False
2996 >>> old_count = service._gateway_failure_counts.get('gw1', 0)
2997 >>> asyncio.run(service._handle_gateway_failure(gateway)) # doctest: +ELLIPSIS
2998 >>> service._gateway_failure_counts.get('gw1', 0) == old_count
2999 True
3000 """
3001 if GW_FAILURE_THRESHOLD == -1:
3002 return # Gateway failure action disabled
3004 if not gateway.enabled:
3005 return # No action needed for inactive gateways
3007 if not gateway.reachable:
3008 return # No action needed for unreachable gateways
3010 count = self._gateway_failure_counts.get(gateway.id, 0) + 1
3011 self._gateway_failure_counts[gateway.id] = count
3013 logger.warning(f"Gateway {gateway.name} failed health check {count} time(s).")
3015 if count >= GW_FAILURE_THRESHOLD:
3016 logger.error(f"Gateway {gateway.name} failed {GW_FAILURE_THRESHOLD} times. Deactivating...")
3017 with cast(Any, SessionLocal)() as db:
3018 await self.set_gateway_state(db, gateway.id, activate=True, reachable=False, only_update_reachable=True)
3019 self._gateway_failure_counts[gateway.id] = 0 # Reset after deactivation
3021 async def check_health_of_gateways(self, gateways: List[DbGateway], user_email: Optional[str] = None) -> bool:
3022 """Check health of a batch of gateways.
3024 Performs an asynchronous health-check for each gateway in `gateways` using
3025 an Async HTTP client. The function handles different authentication
3026 modes (OAuth client_credentials and authorization_code, and non-OAuth
3027 auth headers). When a gateway uses the authorization_code flow, the
3028 optional `user_email` is used to look up stored user tokens with
3029 fresh_db_session(). On individual failures the service will record the
3030 failure and call internal failure handling which may mark a gateway
3031 unreachable or deactivate it after repeated failures. If a previously
3032 unreachable gateway becomes healthy again the service will attempt to
3033 update its reachable status.
3035 NOTE: This method intentionally does NOT take a db parameter.
3036 DB access uses fresh_db_session() only when needed, avoiding holding
3037 connections during HTTP calls to MCP servers.
3039 Args:
3040 gateways: List of DbGateway objects to check.
3041 user_email: Optional MCP gateway user email used to retrieve
3042 stored OAuth tokens for gateways using the
3043 "authorization_code" grant type. If not provided, authorization
3044 code flows that require a user token will be treated as failed.
3046 Returns:
3047 bool: True when the health-check batch completes. This return
3048 value indicates completion of the checks, not that every gateway
3049 was healthy. Individual gateway failures are handled internally
3050 (via _handle_gateway_failure and status updates).
3052 Examples:
3053 >>> from mcpgateway.services.gateway_service import GatewayService
3054 >>> from unittest.mock import MagicMock
3055 >>> service = GatewayService()
3056 >>> gateways = [MagicMock()]
3057 >>> gateways[0].ca_certificate = None
3058 >>> import asyncio
3059 >>> result = asyncio.run(service.check_health_of_gateways(gateways))
3060 >>> isinstance(result, bool)
3061 True
3063 >>> # Test empty gateway list
3064 >>> empty_result = asyncio.run(service.check_health_of_gateways([]))
3065 >>> empty_result
3066 True
3068 >>> # Test multiple gateways (basic smoke)
3069 >>> multiple_gateways = [MagicMock(), MagicMock(), MagicMock()]
3070 >>> for i, gw in enumerate(multiple_gateways):
3071 ... gw.name = f"gateway_{i}"
3072 ... gw.url = f"http://gateway{i}.example.com"
3073 ... gw.transport = "SSE"
3074 ... gw.enabled = True
3075 ... gw.reachable = True
3076 ... gw.auth_value = {}
3077 ... gw.ca_certificate = None
3078 >>> multi_result = asyncio.run(service.check_health_of_gateways(multiple_gateways))
3079 >>> isinstance(multi_result, bool)
3080 True
3081 """
3082 start_time = time.monotonic()
3083 concurrency_limit = min(settings.max_concurrent_health_checks, max(10, os.cpu_count() * 5)) # adaptive concurrency
3084 semaphore = asyncio.Semaphore(concurrency_limit)
3086 async def limited_check(gateway: DbGateway):
3087 """
3088 Checks the health of a single gateway while respecting a concurrency limit.
3090 This function checks the health of the given database gateway, ensuring that
3091 the number of concurrent checks does not exceed a predefined limit. The check
3092 is performed asynchronously and uses a semaphore to manage concurrency.
3094 Args:
3095 gateway (DbGateway): The database gateway whose health is to be checked.
3097 Raises:
3098 Any exceptions raised during the health check will be propagated to the caller.
3099 """
3100 async with semaphore:
3101 try:
3102 await asyncio.wait_for(
3103 self._check_single_gateway_health(gateway, user_email),
3104 timeout=settings.gateway_health_check_timeout,
3105 )
3106 except asyncio.TimeoutError:
3107 logger.warning(f"Gateway {getattr(gateway, 'name', 'unknown')} health check timed out after {settings.gateway_health_check_timeout}s")
3108 # Treat timeout as a failed health check
3109 await self._handle_gateway_failure(gateway)
3111 # Create trace span for health check batch
3112 with create_span("gateway.health_check_batch", {"gateway.count": len(gateways), "check.type": "health"}) as batch_span:
3113 # Chunk processing to avoid overload
3114 if not gateways:
3115 return True
3116 chunk_size = concurrency_limit
3117 for i in range(0, len(gateways), chunk_size):
3118 # batch will be a sublist of gateways from index i to i + chunk_size
3119 batch = gateways[i : i + chunk_size]
3121 # Each task is a health check for a gateway in the batch, excluding those with auth_type == "one_time_auth"
3122 tasks = [limited_check(gw) for gw in batch if gw.auth_type != "one_time_auth"]
3124 # Execute all health checks concurrently
3125 await asyncio.gather(*tasks, return_exceptions=True)
3126 await asyncio.sleep(0.05) # small pause prevents network saturation
3128 elapsed = time.monotonic() - start_time
3130 if batch_span:
3131 batch_span.set_attribute("check.duration_ms", int(elapsed * 1000))
3132 batch_span.set_attribute("check.completed", True)
3134 logger.debug(f"Health check batch completed for {len(gateways)} gateways in {elapsed:.2f}s")
3136 return True
    async def _check_single_gateway_health(self, gateway: DbGateway, user_email: Optional[str] = None) -> None:
        """Check health of a single gateway.

        Probes the gateway over its configured transport ("sse" or
        "streamablehttp"), then on success reactivates it if needed, bumps
        ``last_seen``, and optionally auto-refreshes its tools/resources/prompts.
        Failures during the probe are routed to ``_handle_gateway_failure``
        rather than raised to the caller.

        NOTE: This method intentionally does NOT take a db parameter.
        DB access uses fresh_db_session() only when needed, avoiding holding
        connections during HTTP calls to MCP servers.

        Args:
            gateway: Gateway to check (may be detached from session)
            user_email: Optional user email for OAuth token lookup
        """
        # Extract gateway data upfront (gateway may be detached from session)
        gateway_id = gateway.id
        gateway_name = gateway.name
        gateway_url = gateway.url
        gateway_transport = gateway.transport
        gateway_enabled = gateway.enabled
        gateway_reachable = gateway.reachable
        gateway_ca_certificate = gateway.ca_certificate
        gateway_ca_certificate_sig = gateway.ca_certificate_sig
        gateway_auth_type = gateway.auth_type
        gateway_oauth_config = gateway.oauth_config
        gateway_auth_value = gateway.auth_value
        gateway_auth_query_params = gateway.auth_query_params

        # Handle query_param auth - decrypt and apply to URL for health check
        auth_query_params_decrypted: Optional[Dict[str, str]] = None
        if gateway_auth_type == "query_param" and gateway_auth_query_params:
            auth_query_params_decrypted = {}
            for param_key, encrypted_value in gateway_auth_query_params.items():
                if encrypted_value:
                    try:
                        decrypted = decode_auth(encrypted_value)
                        auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
                    except Exception:
                        # Best-effort: a param that fails to decrypt is simply omitted from the URL.
                        logger.debug(f"Failed to decrypt query param '{param_key}' for health check")
            if auth_query_params_decrypted:
                gateway_url = apply_query_param_auth(gateway_url, auth_query_params_decrypted)

        # Sanitize URL for logging/telemetry (redacts sensitive query params)
        gateway_url_sanitized = sanitize_url_for_logging(gateway_url, auth_query_params_decrypted)

        # Create span for individual gateway health check
        with create_span(
            "gateway.health_check",
            {
                "gateway.name": gateway_name,
                "gateway.id": str(gateway_id),
                "gateway.url": gateway_url_sanitized,
                "gateway.transport": gateway_transport,
                "gateway.enabled": gateway_enabled,
                "http.method": "GET",
                "http.url": gateway_url_sanitized,
            },
        ) as span:
            # Only trust a per-gateway CA certificate when its Ed25519 signature
            # verifies (or when signing is disabled, in which case it is accepted as-is).
            valid = False
            if gateway_ca_certificate:
                if settings.enable_ed25519_signing:
                    public_key_pem = settings.ed25519_public_key
                    valid = validate_signature(gateway_ca_certificate.encode(), gateway_ca_certificate_sig, public_key_pem)
                else:
                    valid = True
            if valid:
                ssl_context = self.create_ssl_context(gateway_ca_certificate)
            else:
                ssl_context = None

            def get_httpx_client_factory(
                headers: dict[str, str] | None = None,
                timeout: httpx.Timeout | None = None,
                auth: httpx.Auth | None = None,
            ) -> httpx.AsyncClient:
                """Factory function to create httpx.AsyncClient with optional CA certificate.

                Args:
                    headers: Optional headers for the client
                    timeout: Optional timeout for the client
                    auth: Optional auth for the client

                Returns:
                    httpx.AsyncClient: Configured HTTPX async client
                """
                return httpx.AsyncClient(
                    verify=ssl_context if ssl_context else get_default_verify(),
                    follow_redirects=True,
                    headers=headers,
                    timeout=timeout if timeout else get_http_timeout(),
                    auth=auth,
                    limits=httpx.Limits(
                        max_connections=settings.httpx_max_connections,
                        max_keepalive_connections=settings.httpx_max_keepalive_connections,
                        keepalive_expiry=settings.httpx_keepalive_expiry,
                    ),
                )

            # Use isolated client for gateway health checks (each gateway may have custom CA cert)
            # Use admin timeout for health checks (fail fast, don't wait 120s for slow upstreams)
            # Pass ssl_context if present, otherwise let get_isolated_http_client use skip_ssl_verify setting
            async with get_isolated_http_client(timeout=settings.httpx_admin_read_timeout, verify=ssl_context) as client:
                logger.debug(f"Checking health of gateway: {gateway_name} ({gateway_url_sanitized})")
                try:
                    # Handle different authentication types
                    headers = {}

                    if gateway_auth_type == "oauth" and gateway_oauth_config:
                        grant_type = gateway_oauth_config.get("grant_type", "client_credentials")

                        if grant_type == "authorization_code":
                            # For Authorization Code flow, try to get stored tokens
                            try:
                                # First-Party
                                from mcpgateway.services.token_storage_service import TokenStorageService  # pylint: disable=import-outside-toplevel

                                # Use fresh session for OAuth token lookup
                                with fresh_db_session() as token_db:
                                    token_storage = TokenStorageService(token_db)

                                    # Get user-specific OAuth token
                                    if not user_email:
                                        if span:
                                            span.set_attribute("health.status", "unhealthy")
                                            span.set_attribute("error.message", "User email required for OAuth token")
                                        await self._handle_gateway_failure(gateway)
                                        return

                                    access_token = await token_storage.get_user_token(gateway_id, user_email)

                                    if access_token:
                                        headers["Authorization"] = f"Bearer {access_token}"
                                    else:
                                        if span:
                                            span.set_attribute("health.status", "unhealthy")
                                            span.set_attribute("error.message", "No valid OAuth token for user")
                                        await self._handle_gateway_failure(gateway)
                                        return
                            except Exception as e:
                                logger.error(f"Failed to obtain stored OAuth token for gateway {gateway_name}: {e}")
                                if span:
                                    span.set_attribute("health.status", "unhealthy")
                                    span.set_attribute("error.message", "Failed to obtain stored OAuth token")
                                await self._handle_gateway_failure(gateway)
                                return
                        else:
                            # For Client Credentials flow, get token directly
                            try:
                                access_token = await self.oauth_manager.get_access_token(gateway_oauth_config)
                                headers["Authorization"] = f"Bearer {access_token}"
                            except Exception as e:
                                if span:
                                    span.set_attribute("health.status", "unhealthy")
                                    span.set_attribute("error.message", str(e))
                                await self._handle_gateway_failure(gateway)
                                return
                    else:
                        # Handle non-OAuth authentication (existing logic)
                        auth_data = gateway_auth_value or {}
                        if isinstance(auth_data, str):
                            headers = decode_auth(auth_data)
                        elif isinstance(auth_data, dict):
                            headers = {str(k): str(v) for k, v in auth_data.items()}
                        else:
                            headers = {}

                    # Perform the GET and raise on 4xx/5xx
                    if (gateway_transport).lower() == "sse":
                        timeout = httpx.Timeout(settings.health_check_timeout)
                        async with client.stream("GET", gateway_url, headers=headers, timeout=timeout) as response:
                            # This will raise immediately if status is 4xx/5xx
                            response.raise_for_status()
                            if span:
                                span.set_attribute("http.status_code", response.status_code)
                    elif (gateway_transport).lower() == "streamablehttp":
                        # Use session pool if enabled for faster health checks
                        use_pool = False
                        pool = None
                        if settings.mcp_session_pool_enabled:
                            try:
                                pool = get_mcp_session_pool()
                                use_pool = True
                            except RuntimeError:
                                # Pool not initialized (e.g., in tests), fall back to per-call sessions
                                pass

                        if use_pool and pool is not None:
                            # Health checks are system operations, not user-driven.
                            # Use system identity to isolate from user sessions.
                            async with pool.session(
                                url=gateway_url,
                                headers=headers,
                                transport_type=TransportType.STREAMABLE_HTTP,
                                httpx_client_factory=get_httpx_client_factory,
                                user_identity="_system_health_check",
                                gateway_id=gateway_id,
                            ) as pooled:
                                # Optional explicit RPC verification (off by default for performance).
                                # Pool's internal staleness check handles health via _validate_session.
                                if settings.mcp_session_pool_explicit_health_rpc:
                                    await asyncio.wait_for(
                                        pooled.session.list_tools(),
                                        timeout=settings.health_check_timeout,
                                    )
                        else:
                            async with streamablehttp_client(url=gateway_url, headers=headers, timeout=settings.health_check_timeout, httpx_client_factory=get_httpx_client_factory) as (
                                read_stream,
                                write_stream,
                                _get_session_id,
                            ):
                                async with ClientSession(read_stream, write_stream) as session:
                                    # Initialize the session
                                    response = await session.initialize()

                    # NOTE(review): transports other than "sse"/"streamablehttp" perform no
                    # probe at all and fall through to the healthy path below — confirm intended.

                    # Reactivate gateway if it was previously inactive and health check passed now
                    if gateway_enabled and not gateway_reachable:
                        logger.info(f"Reactivating gateway: {gateway_name}, as it is healthy now")
                        with cast(Any, SessionLocal)() as status_db:
                            await self.set_gateway_state(status_db, gateway_id, activate=True, reachable=True, only_update_reachable=True)

                    # Update last_seen with fresh session (gateway object is detached)
                    try:
                        with fresh_db_session() as update_db:
                            db_gateway = update_db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
                            if db_gateway:
                                db_gateway.last_seen = datetime.now(timezone.utc)
                                update_db.commit()
                    except Exception as update_error:
                        logger.warning(f"Failed to update last_seen for gateway {gateway_name}: {update_error}")

                    # Auto-refresh tools/resources/prompts if enabled
                    if settings.auto_refresh_servers:
                        try:
                            # Throttling: Check if refresh is needed based on last_refresh_at
                            # NOTE(review): reads the (possibly detached) ORM object directly here,
                            # unlike the snapshot taken at the top — confirm attribute access is
                            # safe on a detached instance.
                            refresh_needed = True
                            if gateway.last_refresh_at:
                                # Default to config value if configured interval is missing

                                last_refresh = gateway.last_refresh_at
                                if last_refresh.tzinfo is None:
                                    last_refresh = last_refresh.replace(tzinfo=timezone.utc)

                                # Use per-gateway interval if set, otherwise fall back to global default
                                refresh_interval = getattr(settings, "gateway_auto_refresh_interval", 300)
                                if gateway.refresh_interval_seconds is not None:
                                    refresh_interval = gateway.refresh_interval_seconds

                                time_since_refresh = (datetime.now(timezone.utc) - last_refresh).total_seconds()

                                if time_since_refresh < refresh_interval:
                                    refresh_needed = False
                                    logger.debug(f"Skipping auto-refresh for {gateway_name}: last refreshed {int(time_since_refresh)}s ago")

                            if refresh_needed:
                                # Locking: Try to acquire lock to avoid conflict with manual refresh
                                lock = self._get_refresh_lock(gateway_id)
                                if not lock.locked():
                                    # Acquire lock to prevent concurrent manual refresh
                                    async with lock:
                                        await self._refresh_gateway_tools_resources_prompts(
                                            gateway_id=gateway_id,
                                            _user_email=user_email,
                                            created_via="health_check",
                                            pre_auth_headers=headers if headers else None,
                                            gateway=gateway,
                                        )
                                else:
                                    logger.debug(f"Skipping auto-refresh for {gateway_name}: lock held (likely manual refresh in progress)")
                        except Exception as refresh_error:
                            logger.warning(f"Failed to refresh tools for gateway {gateway_name}: {refresh_error}")

                    if span:
                        span.set_attribute("health.status", "healthy")
                        span.set_attribute("success", True)

                except Exception as e:
                    if span:
                        span.set_attribute("health.status", "unhealthy")
                        span.set_attribute("error.message", str(e))

                    # Set the logger as debug as this check happens for each interval
                    logger.debug(f"Health check failed for gateway {gateway_name}: {e}")
                    await self._handle_gateway_failure(gateway)
3419 async def aggregate_capabilities(self, db: Session) -> Dict[str, Any]:
3420 """
3421 Aggregate capabilities across all gateways.
3423 Args:
3424 db: Database session
3426 Returns:
3427 Dictionary of aggregated capabilities
3429 Examples:
3430 >>> from mcpgateway.services.gateway_service import GatewayService
3431 >>> from unittest.mock import MagicMock
3432 >>> service = GatewayService()
3433 >>> db = MagicMock()
3434 >>> gateway_mock = MagicMock()
3435 >>> gateway_mock.capabilities = {"tools": {"listChanged": True}, "custom": {"feature": True}}
3436 >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway_mock]
3437 >>> import asyncio
3438 >>> result = asyncio.run(service.aggregate_capabilities(db))
3439 >>> isinstance(result, dict)
3440 True
3441 >>> 'prompts' in result
3442 True
3443 >>> 'resources' in result
3444 True
3445 >>> 'tools' in result
3446 True
3447 >>> 'logging' in result
3448 True
3449 >>> result['prompts']['listChanged']
3450 True
3451 >>> result['resources']['subscribe']
3452 True
3453 >>> result['resources']['listChanged']
3454 True
3455 >>> result['tools']['listChanged']
3456 True
3457 >>> isinstance(result['logging'], dict)
3458 True
3460 >>> # Test with no gateways
3461 >>> db.execute.return_value.scalars.return_value.all.return_value = []
3462 >>> empty_result = asyncio.run(service.aggregate_capabilities(db))
3463 >>> isinstance(empty_result, dict)
3464 True
3465 >>> 'tools' in empty_result
3466 True
3468 >>> # Test capability merging
3469 >>> gateway1 = MagicMock()
3470 >>> gateway1.capabilities = {"tools": {"feature1": True}}
3471 >>> gateway2 = MagicMock()
3472 >>> gateway2.capabilities = {"tools": {"feature2": True}}
3473 >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway1, gateway2]
3474 >>> merged_result = asyncio.run(service.aggregate_capabilities(db))
3475 >>> merged_result['tools']['listChanged'] # Default capability
3476 True
3477 """
3478 capabilities = {
3479 "prompts": {"listChanged": True},
3480 "resources": {"subscribe": True, "listChanged": True},
3481 "tools": {"listChanged": True},
3482 "logging": {},
3483 }
3485 # Get all active gateways
3486 gateways = db.execute(select(DbGateway).where(DbGateway.enabled)).scalars().all()
3488 # Combine capabilities
3489 for gateway in gateways:
3490 if gateway.capabilities:
3491 for key, value in gateway.capabilities.items():
3492 if key not in capabilities:
3493 capabilities[key] = value
3494 elif isinstance(value, dict):
3495 capabilities[key].update(value)
3497 return capabilities
3499 async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
3500 """Subscribe to gateway events.
3502 Creates a new event queue and subscribes to gateway events. Events are
3503 yielded as they are published. The subscription is automatically cleaned
3504 up when the generator is closed or goes out of scope.
3506 Yields:
3507 Dict[str, Any]: Gateway event messages with 'type', 'data', and 'timestamp' fields
3509 Examples:
3510 >>> service = GatewayService()
3511 >>> import asyncio
3512 >>> from unittest.mock import MagicMock
3513 >>> # Create a mock async generator for the event service
3514 >>> async def mock_event_gen():
3515 ... yield {"type": "test_event", "data": "payload"}
3516 >>>
3517 >>> # Mock the event service to return our generator
3518 >>> service._event_service = MagicMock()
3519 >>> service._event_service.subscribe_events.return_value = mock_event_gen()
3520 >>>
3521 >>> # Test the subscription
3522 >>> async def test_sub():
3523 ... async for event in service.subscribe_events():
3524 ... return event
3525 >>>
3526 >>> result = asyncio.run(test_sub())
3527 >>> result
3528 {'type': 'test_event', 'data': 'payload'}
3529 """
3530 async for event in self._event_service.subscribe_events():
3531 yield event
3533 async def _initialize_gateway(
3534 self,
3535 url: str,
3536 authentication: Optional[Dict[str, str]] = None,
3537 transport: str = "SSE",
3538 auth_type: Optional[str] = None,
3539 oauth_config: Optional[Dict[str, Any]] = None,
3540 ca_certificate: Optional[bytes] = None,
3541 pre_auth_headers: Optional[Dict[str, str]] = None,
3542 include_resources: bool = True,
3543 include_prompts: bool = True,
3544 auth_query_params: Optional[Dict[str, str]] = None,
3545 oauth_auto_fetch_tool_flag: Optional[bool] = False,
3546 ) -> tuple[Dict[str, Any], List[ToolCreate], List[ResourceCreate], List[PromptCreate]]:
3547 """Initialize connection to a gateway and retrieve its capabilities.
3549 Connects to an MCP gateway using the specified transport protocol,
3550 performs the MCP handshake, and retrieves capabilities, tools,
3551 resources, and prompts from the gateway.
3553 Args:
3554 url: Gateway URL to connect to
3555 authentication: Optional authentication headers for the connection
3556 transport: Transport protocol - "SSE" or "StreamableHTTP"
3557 auth_type: Authentication type - "basic", "bearer", "authheaders", "oauth", "query_param" or None
3558 oauth_config: OAuth configuration if auth_type is "oauth"
3559 ca_certificate: CA certificate for SSL verification
3560 pre_auth_headers: Pre-authenticated headers to skip OAuth token fetch (for reuse)
3561 include_resources: Whether to include resources in the fetch
3562 include_prompts: Whether to include prompts in the fetch
3563 auth_query_params: Query param names for URL sanitization in error logs (decrypted values)
3564 oauth_auto_fetch_tool_flag: Whether to skip the early return for OAuth Authorization Code flow.
3565 When False (default), auth_code gateways return empty lists immediately (for health checks).
3566 When True, attempts to connect even for auth_code gateways (for activation after user authorization).
3568 Returns:
3569 tuple[Dict[str, Any], List[ToolCreate], List[ResourceCreate], List[PromptCreate]]:
3570 Capabilities dictionary, list of ToolCreate objects, list of ResourceCreate objects, and list of PromptCreate objects
3572 Raises:
3573 GatewayConnectionError: If connection or initialization fails
3575 Examples:
3576 >>> service = GatewayService()
3577 >>> # Test parameter validation
3578 >>> import asyncio
3579 >>> from unittest.mock import AsyncMock
3580 >>> # Avoid opening a real SSE connection in doctests (it can leak anyio streams on failure paths)
3581 >>> service.connect_to_sse_server = AsyncMock(side_effect=GatewayConnectionError("boom"))
3582 >>> async def test_params():
3583 ... try:
3584 ... await service._initialize_gateway("hello//")
3585 ... except Exception as e:
3586 ... return isinstance(e, GatewayConnectionError) or "Failed" in str(e)
3588 >>> asyncio.run(test_params())
3589 True
3591 >>> # Test default parameters
3592 >>> hasattr(service, '_initialize_gateway')
3593 True
3594 >>> import inspect
3595 >>> sig = inspect.signature(service._initialize_gateway)
3596 >>> sig.parameters['transport'].default
3597 'SSE'
3598 >>> sig.parameters['authentication'].default is None
3599 True
3600 >>>
3601 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
3602 >>> asyncio.run(service._http_client.aclose())
3603 """
3604 try:
3605 if authentication is None:
3606 authentication = {}
3608 # Use pre-authenticated headers if provided (avoids duplicate OAuth token fetch)
3609 if pre_auth_headers:
3610 authentication = pre_auth_headers
3611 # Handle OAuth authentication
3612 elif auth_type == "oauth" and oauth_config:
3613 grant_type = oauth_config.get("grant_type", "client_credentials")
3615 if grant_type == "authorization_code":
3616 if not oauth_auto_fetch_tool_flag:
3617 # For Authorization Code flow during health checks, we can't initialize immediately
3618 # because we need user consent. Just store the configuration
3619 # and let the user complete the OAuth flow later.
3620 logger.info("""OAuth Authorization Code flow configured for gateway. User must complete authorization before gateway can be used.""")
3621 # Don't try to get access token here - it will be obtained during tool invocation
3622 authentication = {}
3624 # Skip MCP server connection for Authorization Code flow
3625 # Tools will be fetched after OAuth completion
3626 return {}, [], [], []
3627 # When flag is True (activation), skip token fetch but try to connect
3628 # This allows activation to proceed - actual auth happens during tool invocation
3629 logger.debug("OAuth Authorization Code gateway activation - skipping token fetch")
3630 elif grant_type == "client_credentials":
3631 # For Client Credentials flow, we can get the token immediately
3632 try:
3633 logger.debug("Obtaining OAuth access token for Client Credentials flow")
3634 access_token = await self.oauth_manager.get_access_token(oauth_config)
3635 authentication = {"Authorization": f"Bearer {access_token}"}
3636 except Exception as e:
3637 logger.error(f"Failed to obtain OAuth access token: {e}")
3638 raise GatewayConnectionError(f"OAuth authentication failed: {str(e)}")
3640 capabilities = {}
3641 tools = []
3642 resources = []
3643 prompts = []
3644 if auth_type in ("basic", "bearer", "authheaders") and isinstance(authentication, str):
3645 authentication = decode_auth(authentication)
3646 if transport.lower() == "sse":
3647 capabilities, tools, resources, prompts = await self.connect_to_sse_server(url, authentication, ca_certificate, include_prompts, include_resources, auth_query_params)
3648 elif transport.lower() == "streamablehttp":
3649 capabilities, tools, resources, prompts = await self.connect_to_streamablehttp_server(url, authentication, ca_certificate, include_prompts, include_resources, auth_query_params)
3651 return capabilities, tools, resources, prompts
3652 except Exception as e:
3654 # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup
3655 root_cause = e
3656 if isinstance(e, BaseExceptionGroup):
3657 while isinstance(root_cause, BaseExceptionGroup) and root_cause.exceptions:
3658 root_cause = root_cause.exceptions[0]
3659 sanitized_url = sanitize_url_for_logging(url, auth_query_params)
3660 sanitized_error = sanitize_exception_message(str(root_cause), auth_query_params)
3661 logger.error(f"Gateway initialization failed for {sanitized_url}: {sanitized_error}", exc_info=True)
3662 raise GatewayConnectionError(f"Failed to initialize gateway at {sanitized_url}: {sanitized_error}")
3664 def _get_gateways(self, include_inactive: bool = True) -> list[DbGateway]:
3665 """Sync function for database operations (runs in thread).
3667 Args:
3668 include_inactive: Whether to include inactive gateways
3670 Returns:
3671 List[DbGateway]: List of active gateways
3673 Examples:
3674 >>> from unittest.mock import patch, MagicMock
3675 >>> service = GatewayService()
3676 >>> with patch('mcpgateway.services.gateway_service.SessionLocal') as mock_session:
3677 ... mock_db = MagicMock()
3678 ... mock_session.return_value.__enter__.return_value = mock_db
3679 ... mock_db.execute.return_value.scalars.return_value.all.return_value = []
3680 ... result = service._get_gateways()
3681 ... isinstance(result, list)
3682 True
3684 >>> # Test include_inactive parameter handling
3685 >>> with patch('mcpgateway.services.gateway_service.SessionLocal') as mock_session:
3686 ... mock_db = MagicMock()
3687 ... mock_session.return_value.__enter__.return_value = mock_db
3688 ... mock_db.execute.return_value.scalars.return_value.all.return_value = []
3689 ... result_active_only = service._get_gateways(include_inactive=False)
3690 ... isinstance(result_active_only, list)
3691 True
3692 """
3693 with cast(Any, SessionLocal)() as db:
3694 if include_inactive:
3695 return db.execute(select(DbGateway)).scalars().all()
3696 # Only return active gateways
3697 return db.execute(select(DbGateway).where(DbGateway.enabled)).scalars().all()
3699 def get_first_gateway_by_url(self, db: Session, url: str, team_id: Optional[str] = None, include_inactive: bool = False) -> Optional[GatewayRead]:
3700 """Return the first DbGateway matching the given URL and optional team_id.
3702 This is a synchronous helper intended for use from request handlers where
3703 a simple DB lookup is needed. It normalizes the provided URL similar to
3704 how gateways are stored and matches by the `url` column. If team_id is
3705 provided, it restricts the search to that team.
3707 Args:
3708 db: Database session to use for the query
3709 url: Gateway base URL to match (will be normalized)
3710 team_id: Optional team id to restrict search
3711 include_inactive: Whether to include inactive gateways
3713 Returns:
3714 Optional[DbGateway]: First matching gateway or None
3715 """
3716 query = select(DbGateway).where(DbGateway.url == url)
3717 if not include_inactive:
3718 query = query.where(DbGateway.enabled)
3719 if team_id:
3720 query = query.where(DbGateway.team_id == team_id)
3721 result = db.execute(query).scalars().first()
3722 # Wrap the DB object in the GatewayRead schema for consistency with
3723 # other service methods. Return None if no match found.
3724 if result is None:
3725 return None
3726 return GatewayRead.model_validate(self._prepare_gateway_for_read(result)).masked()
3728 async def _run_leader_heartbeat(self) -> None:
3729 """Run leader heartbeat loop to keep leader key alive.
3731 This runs independently from health checks to ensure the leader key
3732 is refreshed frequently enough (every redis_leader_heartbeat_interval seconds)
3733 to prevent expiration during long-running health check operations.
3735 The loop exits if this instance loses leadership.
3736 """
3737 while True:
3738 try:
3739 await asyncio.sleep(self._leader_heartbeat_interval)
3741 if not self._redis_client:
3742 return
3744 # Check if we're still the leader
3745 current_leader = await self._redis_client.get(self._leader_key)
3746 if current_leader != self._instance_id:
3747 logger.info("Lost Redis leadership, stopping heartbeat")
3748 return
3750 # Refresh the leader key TTL
3751 await self._redis_client.expire(self._leader_key, self._leader_ttl)
3752 logger.debug(f"Leader heartbeat: refreshed TTL to {self._leader_ttl}s")
3754 except Exception as e:
3755 logger.warning(f"Leader heartbeat error: {e}")
3756 # Continue trying - the main health check loop will handle leadership loss
    async def _run_health_checks(self, user_email: str) -> None:
        """Run health checks periodically,
        Uses Redis or FileLock - for multiple workers.
        Uses simple health check for single worker mode.

        Three coordination modes, chosen per iteration:
        - Redis (cache_type == "redis"): only the current leader runs checks;
          a non-leader instance exits this loop entirely.
        - No cache (cache_type == "none"): single-worker mode, checks run directly.
        - Otherwise: FileLock-based leadership; the holder loops forever inside
          the inner `while True` until an exception, then releases the lock.

        NOTE: This method intentionally does NOT take a db parameter.
        Health checks use fresh_db_session() only when DB access is needed,
        avoiding holding connections during HTTP calls to MCP servers.

        Args:
            user_email: Email of the user for OAuth token lookup

        Examples:
            >>> service = GatewayService()
            >>> service._health_check_interval = 0.1  # Short interval for testing
            >>> service._redis_client = None
            >>> import asyncio
            >>> # Test that method exists and is callable
            >>> callable(service._run_health_checks)
            True
            >>> # Test setup without actual execution (would run forever)
            >>> hasattr(service, '_health_check_interval')
            True
            >>> service._health_check_interval == 0.1
            True
        """

        while True:
            try:
                if self._redis_client and settings.cache_type == "redis":
                    # Redis-based leader check (async, decode_responses=True returns strings)
                    # Note: Leader key TTL refresh is handled by _run_leader_heartbeat task
                    current_leader = await self._redis_client.get(self._leader_key)
                    if current_leader != self._instance_id:
                        # Not the leader: stop running health checks on this instance.
                        return

                    # Run health checks
                    gateways = await asyncio.to_thread(self._get_gateways)
                    if gateways:
                        await self.check_health_of_gateways(gateways, user_email)

                    await asyncio.sleep(self._health_check_interval)

                elif settings.cache_type == "none":
                    try:
                        # For single worker mode, run health checks directly
                        gateways = await asyncio.to_thread(self._get_gateways)
                        if gateways:
                            await self.check_health_of_gateways(gateways, user_email)
                    except Exception as e:
                        logger.error(f"Health check run failed: {str(e)}")

                    await asyncio.sleep(self._health_check_interval)

                else:
                    # FileLock-based leader fallback
                    try:
                        # timeout=0: fail immediately if another process holds the lock.
                        self._file_lock.acquire(timeout=0)
                        logger.info("File lock acquired. Running health checks.")

                        # Holder keeps checking until an exception breaks the loop;
                        # the finally block then releases the lock.
                        while True:
                            gateways = await asyncio.to_thread(self._get_gateways)
                            if gateways:
                                await self.check_health_of_gateways(gateways, user_email)
                            await asyncio.sleep(self._health_check_interval)

                    except Timeout:
                        logger.debug("File lock already held. Retrying later.")
                        await asyncio.sleep(self._health_check_interval)

                    except Exception as e:
                        logger.error(f"FileLock health check failed: {str(e)}")

                    finally:
                        if self._file_lock.is_locked:
                            try:
                                self._file_lock.release()
                                logger.info("Released file lock.")
                            except Exception as e:
                                logger.warning(f"Failed to release file lock: {str(e)}")

            except Exception as e:
                logger.error(f"Unexpected error in health check loop: {str(e)}")
                await asyncio.sleep(self._health_check_interval)
3843 def _get_auth_headers(self) -> Dict[str, str]:
3844 """Get default headers for gateway requests (no authentication).
3846 SECURITY: This method intentionally does NOT include authentication credentials.
3847 Each gateway should have its own auth_value configured. Never send this gateway's
3848 admin credentials to remote servers.
3850 Returns:
3851 dict: Default headers without authentication
3853 Examples:
3854 >>> service = GatewayService()
3855 >>> headers = service._get_auth_headers()
3856 >>> isinstance(headers, dict)
3857 True
3858 >>> 'Content-Type' in headers
3859 True
3860 >>> headers['Content-Type']
3861 'application/json'
3862 >>> 'Authorization' not in headers # No credentials leaked
3863 True
3864 """
3865 return {"Content-Type": "application/json"}
3867 async def _notify_gateway_added(self, gateway: DbGateway) -> None:
3868 """Notify subscribers of gateway addition.
3870 Args:
3871 gateway: Gateway to add
3872 """
3873 event = {
3874 "type": "gateway_added",
3875 "data": {
3876 "id": gateway.id,
3877 "name": gateway.name,
3878 "url": gateway.url,
3879 "description": gateway.description,
3880 "enabled": gateway.enabled,
3881 },
3882 "timestamp": datetime.now(timezone.utc).isoformat(),
3883 }
3884 await self._publish_event(event)
3886 async def _notify_gateway_activated(self, gateway: DbGateway) -> None:
3887 """Notify subscribers of gateway activation.
3889 Args:
3890 gateway: Gateway to activate
3891 """
3892 event = {
3893 "type": "gateway_activated",
3894 "data": {
3895 "id": gateway.id,
3896 "name": gateway.name,
3897 "url": gateway.url,
3898 "enabled": gateway.enabled,
3899 "reachable": gateway.reachable,
3900 },
3901 "timestamp": datetime.now(timezone.utc).isoformat(),
3902 }
3903 await self._publish_event(event)
3905 async def _notify_gateway_deactivated(self, gateway: DbGateway) -> None:
3906 """Notify subscribers of gateway deactivation.
3908 Args:
3909 gateway: Gateway database object
3910 """
3911 event = {
3912 "type": "gateway_deactivated",
3913 "data": {
3914 "id": gateway.id,
3915 "name": gateway.name,
3916 "url": gateway.url,
3917 "enabled": gateway.enabled,
3918 "reachable": gateway.reachable,
3919 },
3920 "timestamp": datetime.now(timezone.utc).isoformat(),
3921 }
3922 await self._publish_event(event)
3924 async def _notify_gateway_offline(self, gateway: DbGateway) -> None:
3925 """
3926 Notify subscribers that gateway is offline (Enabled but Unreachable).
3928 Args:
3929 gateway: Gateway database object
3930 """
3931 event = {
3932 "type": "gateway_offline",
3933 "data": {
3934 "id": gateway.id,
3935 "name": gateway.name,
3936 "url": gateway.url,
3937 "enabled": True,
3938 "reachable": False,
3939 },
3940 "timestamp": datetime.now(timezone.utc).isoformat(),
3941 }
3942 await self._publish_event(event)
3944 async def _notify_gateway_deleted(self, gateway_info: Dict[str, Any]) -> None:
3945 """Notify subscribers of gateway deletion.
3947 Args:
3948 gateway_info: Dict containing information about gateway to delete
3949 """
3950 event = {
3951 "type": "gateway_deleted",
3952 "data": gateway_info,
3953 "timestamp": datetime.now(timezone.utc).isoformat(),
3954 }
3955 await self._publish_event(event)
3957 async def _notify_gateway_removed(self, gateway: DbGateway) -> None:
3958 """Notify subscribers of gateway removal (deactivation).
3960 Args:
3961 gateway: Gateway to remove
3962 """
3963 event = {
3964 "type": "gateway_removed",
3965 "data": {"id": gateway.id, "name": gateway.name, "enabled": gateway.enabled},
3966 "timestamp": datetime.now(timezone.utc).isoformat(),
3967 }
3968 await self._publish_event(event)
3970 def convert_gateway_to_read(self, gateway: DbGateway) -> GatewayRead:
3971 """Convert a DbGateway instance to a GatewayRead Pydantic model.
3973 Args:
3974 gateway: Gateway database object
3976 Returns:
3977 GatewayRead: Pydantic model instance
3978 """
3979 gateway_dict = gateway.__dict__.copy()
3980 gateway_dict.pop("_sa_instance_state", None)
3982 # Ensure auth_value is properly encoded
3983 if isinstance(gateway.auth_value, dict):
3984 gateway_dict["auth_value"] = encode_auth(gateway.auth_value)
3986 if gateway.tags:
3987 # Check tags are list of strings or list of Dict[str, str]
3988 if isinstance(gateway.tags[0], str):
3989 # Convert tags from List[str] to List[Dict[str, str]] for GatewayRead
3990 gateway_dict["tags"] = validate_tags_field(gateway.tags)
3991 else:
3992 gateway_dict["tags"] = gateway.tags
3993 else:
3994 gateway_dict["tags"] = []
3996 # Include metadata fields
3997 gateway_dict["created_by"] = getattr(gateway, "created_by", None)
3998 gateway_dict["modified_by"] = getattr(gateway, "modified_by", None)
3999 gateway_dict["created_at"] = getattr(gateway, "created_at", None)
4000 gateway_dict["updated_at"] = getattr(gateway, "updated_at", None)
4001 gateway_dict["version"] = getattr(gateway, "version", None)
4002 gateway_dict["team"] = getattr(gateway, "team", None)
4004 return GatewayRead.model_validate(gateway_dict).masked()
4006 def _prepare_gateway_for_read(self, gateway: DbGateway) -> DbGateway:
4007 """DEPRECATED: Use convert_gateway_to_read instead.
4009 Prepare a gateway object for GatewayRead validation.
4011 Ensures auth_value is in the correct format (encoded string) for the schema.
4012 Converts legacy List[str] tags to List[Dict[str, str]] format for GatewayRead schema.
4014 Args:
4015 gateway: Gateway database object
4017 Returns:
4018 Gateway object with properly formatted auth_value and tags
4019 """
4020 # If auth_value is a dict, encode it to string for GatewayRead schema
4021 if isinstance(gateway.auth_value, dict):
4022 gateway.auth_value = encode_auth(gateway.auth_value)
4024 # Handle legacy List[str] tags - convert to List[Dict[str, str]] for GatewayRead schema
4025 if gateway.tags:
4026 if isinstance(gateway.tags[0], str):
4027 # Legacy format: convert to dict format
4028 gateway.tags = validate_tags_field(gateway.tags)
4030 return gateway
4032 def _create_db_tool(
4033 self,
4034 tool: ToolCreate,
4035 gateway: DbGateway,
4036 created_by: Optional[str] = None,
4037 created_from_ip: Optional[str] = None,
4038 created_via: Optional[str] = None,
4039 created_user_agent: Optional[str] = None,
4040 ) -> DbTool:
4041 """Create a DbTool with consistent federation metadata across all scenarios.
4043 Args:
4044 tool: Tool creation schema
4045 gateway: Gateway database object
4046 created_by: Username who created/updated this tool
4047 created_from_ip: IP address of creator
4048 created_via: Creation method (ui, api, federation, rediscovery)
4049 created_user_agent: User agent of creation request
4051 Returns:
4052 DbTool: Consistently configured database tool object
4053 """
4054 return DbTool(
4055 original_name=tool.name,
4056 custom_name=tool.name,
4057 custom_name_slug=slugify(tool.name),
4058 display_name=generate_display_name(tool.name),
4059 url=gateway.url,
4060 original_description=tool.description,
4061 description=tool.description,
4062 integration_type="MCP", # Gateway-discovered tools are MCP type
4063 request_type=tool.request_type,
4064 headers=tool.headers,
4065 input_schema=tool.input_schema,
4066 annotations=tool.annotations,
4067 jsonpath_filter=tool.jsonpath_filter,
4068 auth_type=gateway.auth_type,
4069 auth_value=encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value,
4070 # Federation metadata - consistent across all scenarios
4071 created_by=created_by or "system",
4072 created_from_ip=created_from_ip,
4073 created_via=created_via or "federation",
4074 created_user_agent=created_user_agent,
4075 federation_source=gateway.name,
4076 version=1,
4077 # Inherit team assignment and visibility from gateway
4078 team_id=gateway.team_id,
4079 owner_email=gateway.owner_email,
4080 visibility="public", # Federated tools should be public for discovery
4081 )
    def _update_or_create_tools(self, db: Session, tools: List[Any], gateway: DbGateway, created_via: str) -> List[DbTool]:
        """Helper to handle update-or-create logic for tools from MCP server.

        Existing tools (matched on ``original_name``) are updated in place when
        any tracked field differs from the upstream value; unknown tools are
        returned as new ``DbTool`` rows for the caller to add. Per-tool failures
        are logged and skipped so one bad tool does not abort the sync.

        Args:
            db: Database session
            tools: List of tools from MCP server
            gateway: Gateway object
            created_via: String indicating creation source ("oauth", "update", etc.)

        Returns:
            List of new tools to be added to the database
        """
        if not tools:
            return []

        tools_to_add = []

        # Batch fetch all existing tools for this gateway (one query instead of
        # one per tool).
        tool_names = [tool.name for tool in tools if tool is not None]
        if not tool_names:
            return []

        existing_tools_query = select(DbTool).where(DbTool.gateway_id == gateway.id, DbTool.original_name.in_(tool_names))
        existing_tools = db.execute(existing_tools_query).scalars().all()
        existing_tools_map = {tool.original_name: tool for tool in existing_tools}

        for tool in tools:
            if tool is None:
                logger.warning("Skipping None tool in tools list")
                continue

            try:
                # Check if tool already exists for this gateway from the tools_map
                existing_tool = existing_tools_map.get(tool.name)
                if existing_tool:
                    # Update existing tool if there are changes
                    fields_to_update = False

                    # Check basic field changes
                    # Compare against original_description (upstream value) rather than description
                    # (which may have been customized by the user)
                    basic_fields_changed = (
                        existing_tool.url != gateway.url
                        or existing_tool.original_description != tool.description
                        or existing_tool.integration_type != "MCP"
                        or existing_tool.request_type != tool.request_type
                    )

                    # Check schema and configuration changes
                    schema_fields_changed = (
                        existing_tool.headers != tool.headers
                        or existing_tool.input_schema != tool.input_schema
                        or existing_tool.output_schema != tool.output_schema
                        or existing_tool.jsonpath_filter != tool.jsonpath_filter
                    )

                    # Check authentication and visibility changes.
                    # DbTool.auth_value is Text (encoded str); DbGateway.auth_value is JSON (dict).
                    # encode_auth() uses a random nonce, so comparing ciphertext would always
                    # differ even when the plaintext hasn't changed. Compare on decoded
                    # (plaintext) values instead, and only encode on the write path.
                    # If decoding fails (legacy/corrupt data), fall back to direct comparison.
                    try:
                        gateway_auth_plain = gateway.auth_value if isinstance(gateway.auth_value, dict) else (decode_auth(gateway.auth_value) if gateway.auth_value else {})
                        existing_tool_auth_plain = decode_auth(existing_tool.auth_value) if existing_tool.auth_value else {}
                        auth_value_changed = existing_tool_auth_plain != gateway_auth_plain
                    except Exception:
                        # Ciphertext comparison: may report a spurious change due to
                        # the random nonce, which only costs an extra re-encode/write.
                        gateway_tool_auth_value = encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value
                        auth_value_changed = existing_tool.auth_value != gateway_tool_auth_value
                    auth_fields_changed = existing_tool.auth_type != gateway.auth_type or auth_value_changed or existing_tool.visibility != gateway.visibility

                    if basic_fields_changed or schema_fields_changed or auth_fields_changed:
                        fields_to_update = True
                    if fields_to_update:
                        existing_tool.url = gateway.url
                        # Only overwrite user-facing description if it hasn't been customized
                        # (mirrors original_name/custom_name pattern)
                        if existing_tool.description == existing_tool.original_description:
                            existing_tool.description = tool.description
                        existing_tool.original_description = tool.description
                        existing_tool.integration_type = "MCP"
                        existing_tool.request_type = tool.request_type
                        existing_tool.headers = tool.headers
                        existing_tool.input_schema = tool.input_schema
                        existing_tool.output_schema = tool.output_schema
                        existing_tool.jsonpath_filter = tool.jsonpath_filter
                        existing_tool.auth_type = gateway.auth_type
                        existing_tool.auth_value = encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value
                        existing_tool.visibility = gateway.visibility
                        logger.debug(f"Updated existing tool: {tool.name}")
                else:
                    # Create new tool if it doesn't exist
                    db_tool = self._create_db_tool(
                        tool=tool,
                        gateway=gateway,
                        created_by="system",
                        created_via=created_via,
                    )
                    # Attach relationship to avoid NoneType during flush
                    db_tool.gateway = gateway
                    tools_to_add.append(db_tool)
                    logger.debug(f"Created new tool: {tool.name}")
            except Exception as e:
                # Best-effort sync: log and continue with the remaining tools.
                logger.warning(f"Failed to process tool {getattr(tool, 'name', 'unknown')}: {e}")
                continue

        return tools_to_add
4191 def _update_or_create_resources(self, db: Session, resources: List[Any], gateway: DbGateway, created_via: str) -> List[DbResource]:
4192 """Helper to handle update-or-create logic for resources from MCP server.
4194 Args:
4195 db: Database session
4196 resources: List of resources from MCP server
4197 gateway: Gateway object
4198 created_via: String indicating creation source ("oauth", "update", etc.)
4200 Returns:
4201 List of new resources to be added to the database
4202 """
4203 if not resources:
4204 return []
4206 resources_to_add = []
4208 # Batch fetch all existing resources for this gateway
4209 resource_uris = [resource.uri for resource in resources if resource is not None]
4210 if not resource_uris:
4211 return []
4213 existing_resources_query = select(DbResource).where(DbResource.gateway_id == gateway.id, DbResource.uri.in_(resource_uris))
4214 existing_resources = db.execute(existing_resources_query).scalars().all()
4215 existing_resources_map = {resource.uri: resource for resource in existing_resources}
4217 for resource in resources:
4218 if resource is None:
4219 logger.warning("Skipping None resource in resources list")
4220 continue
4222 try:
4223 # Check if resource already exists for this gateway from the resources_map
4224 existing_resource = existing_resources_map.get(resource.uri)
4226 if existing_resource:
4227 # Update existing resource if there are changes
4228 fields_to_update = False
4230 if (
4231 existing_resource.name != resource.name
4232 or existing_resource.description != resource.description
4233 or existing_resource.mime_type != resource.mime_type
4234 or existing_resource.uri_template != resource.uri_template
4235 or existing_resource.visibility != gateway.visibility
4236 ):
4237 fields_to_update = True
4239 if fields_to_update:
4240 existing_resource.name = resource.name
4241 existing_resource.description = resource.description
4242 existing_resource.mime_type = resource.mime_type
4243 existing_resource.uri_template = resource.uri_template
4244 existing_resource.visibility = gateway.visibility
4245 logger.debug(f"Updated existing resource: {resource.uri}")
4246 else:
4247 # Create new resource if it doesn't exist
4248 db_resource = DbResource(
4249 uri=resource.uri,
4250 name=resource.name,
4251 description=resource.description,
4252 mime_type=resource.mime_type,
4253 uri_template=resource.uri_template,
4254 gateway_id=gateway.id,
4255 created_by="system",
4256 created_via=created_via,
4257 visibility=gateway.visibility,
4258 )
4259 resources_to_add.append(db_resource)
4260 logger.debug(f"Created new resource: {resource.uri}")
4261 except Exception as e:
4262 logger.warning(f"Failed to process resource {getattr(resource, 'uri', 'unknown')}: {e}")
4263 continue
4265 return resources_to_add
4267 def _update_or_create_prompts(self, db: Session, prompts: List[Any], gateway: DbGateway, created_via: str) -> List[DbPrompt]:
4268 """Helper to handle update-or-create logic for prompts from MCP server.
4270 Args:
4271 db: Database session
4272 prompts: List of prompts from MCP server
4273 gateway: Gateway object
4274 created_via: String indicating creation source ("oauth", "update", etc.)
4276 Returns:
4277 List of new prompts to be added to the database
4278 """
4279 if not prompts:
4280 return []
4282 prompts_to_add = []
4284 # Batch fetch all existing prompts for this gateway
4285 prompt_names = [prompt.name for prompt in prompts if prompt is not None]
4286 if not prompt_names:
4287 return []
4289 existing_prompts_query = select(DbPrompt).where(DbPrompt.gateway_id == gateway.id, DbPrompt.original_name.in_(prompt_names))
4290 existing_prompts = db.execute(existing_prompts_query).scalars().all()
4291 existing_prompts_map = {prompt.original_name: prompt for prompt in existing_prompts}
4293 for prompt in prompts:
4294 if prompt is None:
4295 logger.warning("Skipping None prompt in prompts list")
4296 continue
4298 try:
4299 # Check if resource already exists for this gateway from the prompts_map
4300 existing_prompt = existing_prompts_map.get(prompt.name)
4302 if existing_prompt:
4303 # Update existing prompt if there are changes
4304 fields_to_update = False
4306 if (
4307 existing_prompt.description != prompt.description
4308 or existing_prompt.template != (prompt.template if hasattr(prompt, "template") else "")
4309 or existing_prompt.visibility != gateway.visibility
4310 ):
4311 fields_to_update = True
4313 if fields_to_update:
4314 existing_prompt.description = prompt.description
4315 existing_prompt.template = prompt.template if hasattr(prompt, "template") else ""
4316 existing_prompt.visibility = gateway.visibility
4317 logger.debug(f"Updated existing prompt: {prompt.name}")
4318 else:
4319 # Create new prompt if it doesn't exist
4320 db_prompt = DbPrompt(
4321 name=prompt.name,
4322 original_name=prompt.name,
4323 custom_name=prompt.name,
4324 display_name=prompt.name,
4325 description=prompt.description,
4326 template=prompt.template if hasattr(prompt, "template") else "",
4327 argument_schema={}, # Use argument_schema instead of arguments
4328 gateway_id=gateway.id,
4329 created_by="system",
4330 created_via=created_via,
4331 visibility=gateway.visibility,
4332 )
4333 db_prompt.gateway = gateway
4334 prompts_to_add.append(db_prompt)
4335 logger.debug(f"Created new prompt: {prompt.name}")
4336 except Exception as e:
4337 logger.warning(f"Failed to process prompt {getattr(prompt, 'name', 'unknown')}: {e}")
4338 continue
4340 return prompts_to_add
    async def _refresh_gateway_tools_resources_prompts(
        self,
        gateway_id: str,
        _user_email: Optional[str] = None,
        created_via: str = "health_check",
        pre_auth_headers: Optional[Dict[str, str]] = None,
        gateway: Optional[DbGateway] = None,
        include_resources: bool = True,
        include_prompts: bool = True,
    ) -> Dict[str, Any]:
        """Refresh tools, resources, and prompts for a gateway during health checks.

        Fetches the latest tools/resources/prompts from the MCP server and syncs
        with the database (add new, update changed, remove stale). Only performs
        DB operations if actual changes are detected.

        This method uses fresh_db_session() internally to avoid holding
        connections during HTTP calls to MCP servers.

        Args:
            gateway_id: ID of the gateway to refresh
            _user_email: Optional user email for OAuth token lookup (unused currently)
            created_via: String indicating creation source (default: "health_check")
            pre_auth_headers: Pre-authenticated headers from health check to avoid duplicate OAuth token fetch
            gateway: Optional DbGateway object to avoid redundant DB lookup
            include_resources: Whether to include resources in the refresh
            include_prompts: Whether to include prompts in the refresh

        Returns:
            Dict with per-type counts (tools_added, tools_removed, tools_updated,
            resources_added, resources_removed, resources_updated, prompts_added,
            prompts_removed, prompts_updated) plus status fields: success (bool),
            error (Optional[str]), validation_errors (list).

        Examples:
            >>> from mcpgateway.services.gateway_service import GatewayService
            >>> from unittest.mock import patch, MagicMock, AsyncMock
            >>> import asyncio

            >>> # Test gateway not found returns empty result
            >>> service = GatewayService()
            >>> mock_session = MagicMock()
            >>> mock_session.execute.return_value.scalar_one_or_none.return_value = None
            >>> with patch('mcpgateway.services.gateway_service.fresh_db_session') as mock_fresh:
            ...     mock_fresh.return_value.__enter__.return_value = mock_session
            ...     result = asyncio.run(service._refresh_gateway_tools_resources_prompts('gw-123'))
            >>> result['tools_added'] == 0 and result['tools_removed'] == 0
            True
            >>> result['resources_added'] == 0 and result['resources_removed'] == 0
            True
            >>> result['success'] is True and result['error'] is None
            True

            >>> # Test disabled gateway returns empty result
            >>> mock_gw = MagicMock()
            >>> mock_gw.enabled = False
            >>> mock_gw.reachable = True
            >>> mock_gw.name = 'test_gw'
            >>> mock_session.execute.return_value.scalar_one_or_none.return_value = mock_gw
            >>> with patch('mcpgateway.services.gateway_service.fresh_db_session') as mock_fresh:
            ...     mock_fresh.return_value.__enter__.return_value = mock_session
            ...     result = asyncio.run(service._refresh_gateway_tools_resources_prompts('gw-123'))
            >>> result['tools_added']
            0

            >>> # Test unreachable gateway returns empty result
            >>> mock_gw.enabled = True
            >>> mock_gw.reachable = False
            >>> with patch('mcpgateway.services.gateway_service.fresh_db_session') as mock_fresh:
            ...     mock_fresh.return_value.__enter__.return_value = mock_session
            ...     result = asyncio.run(service._refresh_gateway_tools_resources_prompts('gw-123'))
            >>> result['tools_added']
            0

            >>> # Test method is async and callable
            >>> import inspect
            >>> inspect.iscoroutinefunction(service._refresh_gateway_tools_resources_prompts)
            True
            >>>
            >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
            >>> asyncio.run(service._http_client.aclose())
        """
        result = {
            "tools_added": 0,
            "tools_removed": 0,
            "resources_added": 0,
            "resources_removed": 0,
            "prompts_added": 0,
            "prompts_removed": 0,
            "tools_updated": 0,
            "resources_updated": 0,
            "prompts_updated": 0,
            "success": True,
            "error": None,
            "validation_errors": [],
        }

        # Fetch gateway metadata only (no relationships needed for MCP call)
        # Use provided gateway object if available to save a DB call
        gateway_name = None
        gateway_url = None
        gateway_transport = None
        gateway_auth_type = None
        gateway_auth_value = None
        gateway_oauth_config = None
        gateway_ca_certificate = None
        gateway_auth_query_params = None

        if gateway:
            # NOTE(review): the provided gateway may be a detached ORM instance
            # (its session is closed by the caller) — plain attribute reads here
            # rely on those attributes having been loaded already.
            if not gateway.enabled or not gateway.reachable:
                logger.debug(f"Skipping tool refresh for disabled/unreachable gateway {gateway.name}")
                return result

            gateway_name = gateway.name
            gateway_url = gateway.url
            gateway_transport = gateway.transport
            gateway_auth_type = gateway.auth_type
            gateway_auth_value = gateway.auth_value
            gateway_oauth_config = gateway.oauth_config
            gateway_ca_certificate = gateway.ca_certificate
            gateway_auth_query_params = gateway.auth_query_params
        else:
            with fresh_db_session() as db:
                gateway_obj = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()

                if not gateway_obj:
                    logger.warning(f"Gateway {gateway_id} not found for tool refresh")
                    return result

                if not gateway_obj.enabled or not gateway_obj.reachable:
                    logger.debug(f"Skipping tool refresh for disabled/unreachable gateway {gateway_obj.name}")
                    return result

                # Extract metadata before session closes
                gateway_name = gateway_obj.name
                gateway_url = gateway_obj.url
                gateway_transport = gateway_obj.transport
                gateway_auth_type = gateway_obj.auth_type
                gateway_auth_value = gateway_obj.auth_value
                gateway_oauth_config = gateway_obj.oauth_config
                gateway_ca_certificate = gateway_obj.ca_certificate
                gateway_auth_query_params = gateway_obj.auth_query_params

        # Handle query_param auth - decrypt and apply to URL for refresh
        auth_query_params_decrypted: Optional[Dict[str, str]] = None
        if gateway_auth_type == "query_param" and gateway_auth_query_params:
            auth_query_params_decrypted = {}
            for param_key, encrypted_value in gateway_auth_query_params.items():
                if encrypted_value:
                    try:
                        decrypted = decode_auth(encrypted_value)
                        auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
                    except Exception:
                        # Best-effort: a single undecryptable param is skipped.
                        logger.debug(f"Failed to decrypt query param '{param_key}' for tool refresh")
            if auth_query_params_decrypted:
                gateway_url = apply_query_param_auth(gateway_url, auth_query_params_decrypted)

        # Fetch tools/resources/prompts from MCP server (no DB connection held)
        try:
            _capabilities, tools, resources, prompts = await self._initialize_gateway(
                url=gateway_url,
                authentication=gateway_auth_value,
                transport=gateway_transport,
                auth_type=gateway_auth_type,
                oauth_config=gateway_oauth_config,
                ca_certificate=gateway_ca_certificate.encode() if gateway_ca_certificate else None,
                pre_auth_headers=pre_auth_headers,
                include_resources=include_resources,
                include_prompts=include_prompts,
                auth_query_params=auth_query_params_decrypted,
            )
        except Exception as e:
            logger.warning(f"Failed to fetch tools from gateway {gateway_name}: {e}")
            result["success"] = False
            result["error"] = str(e)
            return result

        # For authorization_code OAuth gateways, empty responses may indicate incomplete auth flow
        # Skip only if it's an auth_code gateway with no data (user may not have completed authorization)
        is_auth_code_gateway = gateway_oauth_config and isinstance(gateway_oauth_config, dict) and gateway_oauth_config.get("grant_type") == "authorization_code"
        if not tools and not resources and not prompts and is_auth_code_gateway:
            logger.debug(f"No tools/resources/prompts returned from auth_code gateway {gateway_name} (user may not have authorized)")
            return result

        # For non-auth_code gateways, empty responses are legitimate and will clear stale items

        # Update database with fresh session
        with fresh_db_session() as db:
            # Fetch gateway with relationships for update/comparison
            gateway = db.execute(
                select(DbGateway)
                .options(
                    selectinload(DbGateway.tools),
                    selectinload(DbGateway.resources),
                    selectinload(DbGateway.prompts),
                )
                .where(DbGateway.id == gateway_id)
            ).scalar_one_or_none()

            if not gateway:
                result["success"] = False
                result["error"] = f"Gateway {gateway_id} not found during refresh"
                return result

            new_tool_names = [tool.name for tool in tools]
            new_resource_uris = [resource.uri for resource in resources] if include_resources else None
            new_prompt_names = [prompt.name for prompt in prompts] if include_prompts else None

            # Track dirty objects before update operations to count per-type updates
            pending_tools_before = {obj for obj in db.dirty if isinstance(obj, DbTool)}
            pending_resources_before = {obj for obj in db.dirty if isinstance(obj, DbResource)}
            pending_prompts_before = {obj for obj in db.dirty if isinstance(obj, DbPrompt)}

            # Update/create tools, resources, and prompts
            tools_to_add = self._update_or_create_tools(db, tools, gateway, created_via)
            resources_to_add = self._update_or_create_resources(db, resources, gateway, created_via) if include_resources else []
            prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, created_via) if include_prompts else []

            # Count per-type updates (set difference against the pre-update
            # dirty snapshot isolates objects dirtied by the sync above)
            result["tools_updated"] = len({obj for obj in db.dirty if isinstance(obj, DbTool)} - pending_tools_before)
            result["resources_updated"] = len({obj for obj in db.dirty if isinstance(obj, DbResource)} - pending_resources_before)
            result["prompts_updated"] = len({obj for obj in db.dirty if isinstance(obj, DbPrompt)} - pending_prompts_before)

            # Only delete MCP-discovered items (not user-created entries)
            # Excludes "api", "ui", None (legacy/user-created) to preserve user entries
            mcp_created_via_values = {"MCP", "federation", "health_check", "manual_refresh", "oauth", "update"}

            # Find and remove stale tools (only MCP-discovered ones).
            # Deletes are chunked to 500 IDs to stay under SQL parameter limits.
            stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names and tool.created_via in mcp_created_via_values]
            if stale_tool_ids:
                for i in range(0, len(stale_tool_ids), 500):
                    chunk = stale_tool_ids[i : i + 500]
                    db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
                    db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
                    db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))
                result["tools_removed"] = len(stale_tool_ids)

            # Find and remove stale resources (only MCP-discovered ones, only if resources were fetched)
            stale_resource_ids = []
            if new_resource_uris is not None:
                stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris and resource.created_via in mcp_created_via_values]
                if stale_resource_ids:
                    for i in range(0, len(stale_resource_ids), 500):
                        chunk = stale_resource_ids[i : i + 500]
                        db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
                        db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
                        db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
                        db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))
                    result["resources_removed"] = len(stale_resource_ids)

            # Find and remove stale prompts (only MCP-discovered ones, only if prompts were fetched)
            stale_prompt_ids = []
            if new_prompt_names is not None:
                stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names and prompt.created_via in mcp_created_via_values]
                if stale_prompt_ids:
                    for i in range(0, len(stale_prompt_ids), 500):
                        chunk = stale_prompt_ids[i : i + 500]
                        db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
                        db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
                        db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))
                    result["prompts_removed"] = len(stale_prompt_ids)

            # Expire gateway if stale items were deleted, so its relationship
            # collections are reloaded rather than served stale from identity map
            if stale_tool_ids or stale_resource_ids or stale_prompt_ids:
                db.expire(gateway)

            # Add new items in chunks
            chunk_size = 50
            if tools_to_add:
                for i in range(0, len(tools_to_add), chunk_size):
                    chunk = tools_to_add[i : i + chunk_size]
                    db.add_all(chunk)
                    db.flush()
                result["tools_added"] = len(tools_to_add)

            if resources_to_add:
                for i in range(0, len(resources_to_add), chunk_size):
                    chunk = resources_to_add[i : i + chunk_size]
                    db.add_all(chunk)
                    db.flush()
                result["resources_added"] = len(resources_to_add)

            if prompts_to_add:
                for i in range(0, len(prompts_to_add), chunk_size):
                    chunk = prompts_to_add[i : i + chunk_size]
                    db.add_all(chunk)
                    db.flush()
                result["prompts_added"] = len(prompts_to_add)

            gateway.last_refresh_at = datetime.now(timezone.utc)

            total_changes = (
                result["tools_added"]
                + result["tools_removed"]
                + result["tools_updated"]
                + result["resources_added"]
                + result["resources_removed"]
                + result["resources_updated"]
                + result["prompts_added"]
                + result["prompts_removed"]
                + result["prompts_updated"]
            )

            has_changes = total_changes > 0

            if has_changes:
                db.commit()
                logger.info(
                    f"Refreshed gateway {gateway_name}: "
                    f"tools(+{result['tools_added']}/-{result['tools_removed']}/~{result['tools_updated']}), "
                    f"resources(+{result['resources_added']}/-{result['resources_removed']}/~{result['resources_updated']}), "
                    f"prompts(+{result['prompts_added']}/-{result['prompts_removed']}/~{result['prompts_updated']})"
                )

                # Invalidate caches per-type based on actual changes
                cache = _get_registry_cache()
                if result["tools_added"] > 0 or result["tools_removed"] > 0 or result["tools_updated"] > 0:
                    await cache.invalidate_tools()
                if result["resources_added"] > 0 or result["resources_removed"] > 0 or result["resources_updated"] > 0:
                    await cache.invalidate_resources()
                if result["prompts_added"] > 0 or result["prompts_removed"] > 0 or result["prompts_updated"] > 0:
                    await cache.invalidate_prompts()

                # Invalidate tool lookup cache for this gateway
                tool_lookup_cache = _get_tool_lookup_cache()
                await tool_lookup_cache.invalidate_gateway(str(gateway_id))
            else:
                # Commit anyway so last_refresh_at persists even with no changes
                db.commit()
                logger.debug(f"No changes detected during refresh of gateway {gateway_name}")

        return result
4672 def _get_refresh_lock(self, gateway_id: str) -> asyncio.Lock:
4673 """Get or create a per-gateway refresh lock.
4675 This ensures only one refresh operation can run for a given gateway at a time.
4677 Args:
4678 gateway_id: ID of the gateway to get the lock for
4680 Returns:
4681 asyncio.Lock: The lock for the specified gateway
4683 Examples:
4684 >>> from mcpgateway.services.gateway_service import GatewayService
4685 >>> service = GatewayService()
4686 >>> lock1 = service._get_refresh_lock('gw-123')
4687 >>> lock2 = service._get_refresh_lock('gw-123')
4688 >>> lock1 is lock2
4689 True
4690 >>> lock3 = service._get_refresh_lock('gw-456')
4691 >>> lock1 is lock3
4692 False
4693 """
4694 if gateway_id not in self._refresh_locks:
4695 self._refresh_locks[gateway_id] = asyncio.Lock()
4696 return self._refresh_locks[gateway_id]
    async def refresh_gateway_manually(
        self,
        gateway_id: str,
        include_resources: bool = True,
        include_prompts: bool = True,
        user_email: Optional[str] = None,
        request_headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """Manually trigger a refresh of tools/resources/prompts for a gateway.

        This method provides a public API for triggering an immediate refresh
        of a gateway's tools, resources, and prompts from its MCP server.
        It includes concurrency control via per-gateway locking.

        Args:
            gateway_id: Gateway ID to refresh
            include_resources: Whether to include resources in the refresh
            include_prompts: Whether to include prompts in the refresh
            user_email: Email of the user triggering the refresh
            request_headers: Optional request headers for passthrough authentication

        Returns:
            Dict with counts: {tools_added, tools_updated, tools_removed,
            resources_added, resources_updated, resources_removed,
            prompts_added, prompts_updated, prompts_removed,
            validation_errors, duration_ms, refreshed_at}

        Raises:
            GatewayNotFoundError: If the gateway does not exist
            GatewayError: If another refresh is already in progress for this gateway

        Examples:
            >>> from mcpgateway.services.gateway_service import GatewayService
            >>> from unittest.mock import patch, MagicMock, AsyncMock
            >>> import asyncio

            >>> # Test method is async
            >>> service = GatewayService()
            >>> import inspect
            >>> inspect.iscoroutinefunction(service.refresh_gateway_manually)
            True
        """
        start_time = time.monotonic()

        pre_auth_headers = {}

        # Check if gateway exists before acquiring lock
        with fresh_db_session() as db:
            gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
            if not gateway:
                raise GatewayNotFoundError(f"Gateway with ID '{gateway_id}' not found")
            gateway_name = gateway.name

            # Get passthrough headers if request headers provided
            if request_headers:
                pre_auth_headers = get_passthrough_headers(request_headers, {}, db, gateway)

        # NOTE(review): `gateway` is detached once the session above closes but
        # is still passed to the refresh below, which reads plain attributes
        # from it — presumably safe because they were loaded; verify
        # expire_on_commit behavior of fresh_db_session.
        lock = self._get_refresh_lock(gateway_id)

        # Fast-fail if a refresh is already running. There is no await between
        # this check and the acquire below, so on a single event loop the two
        # steps cannot interleave with another coroutine.
        if lock.locked():
            raise GatewayError(f"Refresh already in progress for gateway {gateway_name}")

        async with lock:
            logger.info(f"Starting manual refresh for gateway {gateway_name} (ID: {gateway_id})")

            result = await self._refresh_gateway_tools_resources_prompts(
                gateway_id=gateway_id,
                _user_email=user_email,
                created_via="manual_refresh",
                pre_auth_headers=pre_auth_headers,
                gateway=gateway,
                include_resources=include_resources,
                include_prompts=include_prompts,
            )
            # Note: last_refresh_at is updated inside _refresh_gateway_tools_resources_prompts on success

            result["duration_ms"] = (time.monotonic() - start_time) * 1000
            result["refreshed_at"] = datetime.now(timezone.utc)

            log_level = logging.INFO if result.get("success", True) else logging.WARNING
            status_msg = "succeeded" if result.get("success", True) else f"failed: {result.get('error')}"

            logger.log(
                log_level,
                f"Manual refresh for gateway {gateway_id} {status_msg}. Stats: "
                f"tools(+{result['tools_added']}/-{result['tools_removed']}), "
                f"resources(+{result['resources_added']}/-{result['resources_removed']}), "
                f"prompts(+{result['prompts_added']}/-{result['prompts_removed']}) "
                f"in {result['duration_ms']:.2f}ms",
            )

            return result
4792 async def _publish_event(self, event: Dict[str, Any]) -> None:
4793 """Publish event to all subscribers.
4795 Args:
4796 event: event dictionary
4798 Examples:
4799 >>> import asyncio
4800 >>> from unittest.mock import AsyncMock
4801 >>> service = GatewayService()
4802 >>> # Mock the underlying event service
4803 >>> service._event_service = AsyncMock()
4804 >>> test_event = {"type": "test", "data": {}}
4805 >>>
4806 >>> asyncio.run(service._publish_event(test_event))
4807 >>>
4808 >>> # Verify the event was passed to the event service
4809 >>> service._event_service.publish_event.assert_awaited_with(test_event)
4810 """
4811 await self._event_service.publish_event(event)
4813 def _validate_tools(self, tools: list[dict[str, Any]], context: str = "default") -> tuple[list[ToolCreate], list[str]]:
4814 """Validate tools individually with richer logging and error aggregation.
4816 Args:
4817 tools: list of tool dicts
4818 context: caller context, e.g. "oauth" to tailor errors/messages
4820 Returns:
4821 tuple[list[ToolCreate], list[str]]: Tuple of (valid tools, validation errors)
4823 Raises:
4824 OAuthToolValidationError: If all tools fail validation in OAuth context
4825 GatewayConnectionError: If all tools fail validation in default context
4826 """
4827 valid_tools: list[ToolCreate] = []
4828 validation_errors: list[str] = []
4830 for i, tool_dict in enumerate(tools):
4831 tool_name = tool_dict.get("name", f"unknown_tool_{i}")
4832 try:
4833 logger.debug(f"Validating tool: {tool_name}")
4834 validated_tool = ToolCreate.model_validate(tool_dict)
4835 valid_tools.append(validated_tool)
4836 logger.debug(f"Tool '{tool_name}' validated successfully")
4837 except ValidationError as e:
4838 error_msg = f"Validation failed for tool '{tool_name}': {e.errors()}"
4839 logger.error(error_msg)
4840 logger.debug(f"Failed tool schema: {tool_dict}")
4841 validation_errors.append(error_msg)
4842 except ValueError as e:
4843 if "JSON structure exceeds maximum depth" in str(e):
4844 error_msg = f"Tool '{tool_name}' schema too deeply nested. " f"Current depth limit: {settings.validation_max_json_depth}"
4845 logger.error(error_msg)
4846 logger.warning("Consider increasing VALIDATION_MAX_JSON_DEPTH environment variable")
4847 else:
4848 error_msg = f"ValueError for tool '{tool_name}': {str(e)}"
4849 logger.error(error_msg)
4850 validation_errors.append(error_msg)
4851 except Exception as e: # pragma: no cover - defensive
4852 error_msg = f"Unexpected error validating tool '{tool_name}': {type(e).__name__}: {str(e)}"
4853 logger.error(error_msg, exc_info=True)
4854 validation_errors.append(error_msg)
4856 if validation_errors:
4857 logger.warning(f"Tool validation completed with {len(validation_errors)} error(s). " f"Successfully validated {len(valid_tools)} tool(s).")
4858 for err in validation_errors[:3]:
4859 logger.debug(f"Validation error: {err}")
4861 if not valid_tools and validation_errors:
4862 if context == "oauth":
4863 raise OAuthToolValidationError(f"OAuth tool fetch failed: all {len(tools)} tools failed validation. " f"First error: {validation_errors[0][:200]}")
4864 raise GatewayConnectionError(f"Failed to fetch tools: All {len(tools)} tools failed validation. " f"First error: {validation_errors[0][:200]}")
4866 return valid_tools, validation_errors
4868 async def _connect_to_sse_server_without_validation(self, server_url: str, authentication: Optional[Dict[str, str]] = None):
4869 """Connect to an MCP server running with SSE transport, skipping URL validation.
4871 This is used for OAuth-protected servers where we've already validated the token works.
4873 Args:
4874 server_url: The URL of the SSE MCP server to connect to.
4875 authentication: Optional dictionary containing authentication headers.
4877 Returns:
4878 Tuple containing (capabilities, tools, resources, prompts) from the MCP server.
4879 """
4880 if authentication is None:
4881 authentication = {}
4883 # Skip validation for OAuth servers - we already validated via OAuth flow
4884 # Use async with for both sse_client and ClientSession
4885 try:
4886 async with sse_client(url=server_url, headers=authentication) as streams:
4887 async with ClientSession(*streams) as session:
4888 # Initialize the session
4889 response = await session.initialize()
4890 capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
4891 logger.debug(f"Server capabilities: {capabilities}")
4893 response = await session.list_tools()
4894 tools = response.tools
4895 tools = [tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools]
4897 tools, _ = self._validate_tools(tools, context="oauth")
4898 if tools:
4899 logger.info(f"Fetched {len(tools)} tools from gateway")
4900 # Fetch resources if supported
4902 logger.debug(f"Checking for resources support: {capabilities.get('resources')}")
4903 resources = []
4904 if capabilities.get("resources"):
4905 try:
4906 response = await session.list_resources()
4907 raw_resources = response.resources
4908 for resource in raw_resources:
4909 resource_data = resource.model_dump(by_alias=True, exclude_none=True)
4910 # Convert AnyUrl to string if present
4911 if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"):
4912 resource_data["uri"] = str(resource_data["uri"])
4913 # Add default content if not present (will be fetched on demand)
4914 if "content" not in resource_data:
4915 resource_data["content"] = ""
4916 try:
4917 resources.append(ResourceCreate.model_validate(resource_data))
4918 except Exception:
4919 # If validation fails, create minimal resource
4920 resources.append(
4921 ResourceCreate(
4922 uri=str(resource_data.get("uri", "")),
4923 name=resource_data.get("name", ""),
4924 description=resource_data.get("description"),
4925 mime_type=resource_data.get("mimeType"),
4926 uri_template=resource_data.get("uriTemplate") or None,
4927 content="",
4928 )
4929 )
4930 logger.info(f"Fetched {len(resources)} resources from gateway")
4931 except Exception as e:
4932 logger.warning(f"Failed to fetch resources: {e}")
4934 # resource template URI
4935 try:
4936 response_templates = await session.list_resource_templates()
4937 raw_resources_templates = response_templates.resourceTemplates
4938 resource_templates = []
4939 for resource_template in raw_resources_templates:
4940 resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True)
4942 if "uriTemplate" in resource_template_data: # and hasattr(resource_template_data["uriTemplate"], "unicode_string"):
4943 resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"])
4944 resource_template_data["uri"] = str(resource_template_data["uriTemplate"])
4946 if "content" not in resource_template_data:
4947 resource_template_data["content"] = ""
4949 resources.append(ResourceCreate.model_validate(resource_template_data))
4950 resource_templates.append(ResourceCreate.model_validate(resource_template_data))
4951 logger.info(f"Fetched {len(resource_templates)} resource templates from gateway")
4952 except Exception as e:
4953 logger.warning(f"Failed to fetch resource templates: {e}")
4955 # Fetch prompts if supported
4956 prompts = []
4957 logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}")
4958 if capabilities.get("prompts"):
4959 try:
4960 response = await session.list_prompts()
4961 raw_prompts = response.prompts
4962 for prompt in raw_prompts:
4963 prompt_data = prompt.model_dump(by_alias=True, exclude_none=True)
4964 # Add default template if not present
4965 if "template" not in prompt_data:
4966 prompt_data["template"] = ""
4967 try:
4968 prompts.append(PromptCreate.model_validate(prompt_data))
4969 except Exception:
4970 # If validation fails, create minimal prompt
4971 prompts.append(
4972 PromptCreate(
4973 name=prompt_data.get("name", ""),
4974 description=prompt_data.get("description"),
4975 template=prompt_data.get("template", ""),
4976 )
4977 )
4978 logger.info(f"Fetched {len(prompts)} prompts from gateway")
4979 except Exception as e:
4980 logger.warning(f"Failed to fetch prompts: {e}")
4982 return capabilities, tools, resources, prompts
4983 except Exception as e:
4984 # Note: This function is for OAuth servers only, which don't use query param auth
4985 # Still sanitize in case exception contains URL with static sensitive params
4986 sanitized_url = sanitize_url_for_logging(server_url)
4987 sanitized_error = sanitize_exception_message(str(e))
4988 logger.error(f"SSE connection error details: {type(e).__name__}: {sanitized_error}", exc_info=True)
4989 raise GatewayConnectionError(f"Failed to connect to SSE server at {sanitized_url}: {sanitized_error}")
4991 async def connect_to_sse_server(
4992 self,
4993 server_url: str,
4994 authentication: Optional[Dict[str, str]] = None,
4995 ca_certificate: Optional[bytes] = None,
4996 include_prompts: bool = True,
4997 include_resources: bool = True,
4998 auth_query_params: Optional[Dict[str, str]] = None,
4999 ):
5000 """Connect to an MCP server running with SSE transport.
5002 Args:
5003 server_url: The URL of the SSE MCP server to connect to.
5004 authentication: Optional dictionary containing authentication headers.
5005 ca_certificate: Optional CA certificate for SSL verification.
5006 include_prompts: Whether to fetch prompts from the server.
5007 include_resources: Whether to fetch resources from the server.
5008 auth_query_params: Query param names for URL sanitization in error logs.
5010 Returns:
5011 Tuple containing (capabilities, tools, resources, prompts) from the MCP server.
5012 """
5013 if authentication is None:
5014 authentication = {}
5016 def get_httpx_client_factory(
5017 headers: dict[str, str] | None = None,
5018 timeout: httpx.Timeout | None = None,
5019 auth: httpx.Auth | None = None,
5020 ) -> httpx.AsyncClient:
5021 """Factory function to create httpx.AsyncClient with optional CA certificate.
5023 Args:
5024 headers: Optional headers for the client
5025 timeout: Optional timeout for the client
5026 auth: Optional auth for the client
5028 Returns:
5029 httpx.AsyncClient: Configured HTTPX async client
5030 """
5031 if ca_certificate:
5032 ctx = self.create_ssl_context(ca_certificate)
5033 else:
5034 ctx = None
5035 return httpx.AsyncClient(
5036 verify=ctx if ctx else get_default_verify(),
5037 follow_redirects=True,
5038 headers=headers,
5039 timeout=timeout if timeout else get_http_timeout(),
5040 auth=auth,
5041 limits=httpx.Limits(
5042 max_connections=settings.httpx_max_connections,
5043 max_keepalive_connections=settings.httpx_max_keepalive_connections,
5044 keepalive_expiry=settings.httpx_keepalive_expiry,
5045 ),
5046 )
5048 # Use async with for both sse_client and ClientSession
5049 async with sse_client(url=server_url, headers=authentication, httpx_client_factory=get_httpx_client_factory) as streams:
5050 async with ClientSession(*streams) as session:
5051 # Initialize the session
5052 response = await session.initialize()
5054 capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
5055 logger.debug(f"Server capabilities: {capabilities}")
5057 response = await session.list_tools()
5058 tools = response.tools
5059 tools = [tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools]
5061 tools, _ = self._validate_tools(tools)
5062 if tools:
5063 logger.info(f"Fetched {len(tools)} tools from gateway")
5064 # Fetch resources if supported
5065 resources = []
5066 if include_resources:
5067 logger.debug(f"Checking for resources support: {capabilities.get('resources')}")
5068 if capabilities.get("resources"):
5069 try:
5070 response = await session.list_resources()
5071 raw_resources = response.resources
5072 for resource in raw_resources:
5073 resource_data = resource.model_dump(by_alias=True, exclude_none=True)
5074 # Convert AnyUrl to string if present
5075 if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"):
5076 resource_data["uri"] = str(resource_data["uri"])
5077 # Add default content if not present (will be fetched on demand)
5078 if "content" not in resource_data:
5079 resource_data["content"] = ""
5080 try:
5081 resources.append(ResourceCreate.model_validate(resource_data))
5082 except Exception:
5083 # If validation fails, create minimal resource
5084 resources.append(
5085 ResourceCreate(
5086 uri=str(resource_data.get("uri", "")),
5087 name=resource_data.get("name", ""),
5088 description=resource_data.get("description"),
5089 mime_type=resource_data.get("mimeType"),
5090 uri_template=resource_data.get("uriTemplate") or None,
5091 content="",
5092 )
5093 )
5094 logger.info(f"Fetched {len(resources)} resources from gateway")
5095 except Exception as e:
5096 logger.warning(f"Failed to fetch resources: {e}")
5098 # resource template URI
5099 try:
5100 response_templates = await session.list_resource_templates()
5101 raw_resources_templates = response_templates.resourceTemplates
5102 resource_templates = []
5103 for resource_template in raw_resources_templates:
5104 resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True)
5106 if "uriTemplate" in resource_template_data: # and hasattr(resource_template_data["uriTemplate"], "unicode_string"):
5107 resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"])
5108 resource_template_data["uri"] = str(resource_template_data["uriTemplate"])
5110 if "content" not in resource_template_data:
5111 resource_template_data["content"] = ""
5113 resources.append(ResourceCreate.model_validate(resource_template_data))
5114 resource_templates.append(ResourceCreate.model_validate(resource_template_data))
5115 logger.info(f"Fetched {len(raw_resources_templates)} resource templates from gateway")
5116 except Exception as ei:
5117 logger.warning(f"Failed to fetch resource templates: {ei}")
5119 # Fetch prompts if supported
5120 prompts = []
5121 if include_prompts:
5122 logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}")
5123 if capabilities.get("prompts"):
5124 try:
5125 response = await session.list_prompts()
5126 raw_prompts = response.prompts
5127 for prompt in raw_prompts:
5128 prompt_data = prompt.model_dump(by_alias=True, exclude_none=True)
5129 # Add default template if not present
5130 if "template" not in prompt_data:
5131 prompt_data["template"] = ""
5132 try:
5133 prompts.append(PromptCreate.model_validate(prompt_data))
5134 except Exception:
5135 # If validation fails, create minimal prompt
5136 prompts.append(
5137 PromptCreate(
5138 name=prompt_data.get("name", ""),
5139 description=prompt_data.get("description"),
5140 template=prompt_data.get("template", ""),
5141 )
5142 )
5143 logger.info(f"Fetched {len(prompts)} prompts from gateway")
5144 except Exception as e:
5145 logger.warning(f"Failed to fetch prompts: {e}")
5147 return capabilities, tools, resources, prompts
5148 sanitized_url = sanitize_url_for_logging(server_url, auth_query_params)
5149 raise GatewayConnectionError(f"Failed to initialize gateway at {sanitized_url}: Connection could not be established")
5151 async def connect_to_streamablehttp_server(
5152 self,
5153 server_url: str,
5154 authentication: Optional[Dict[str, str]] = None,
5155 ca_certificate: Optional[bytes] = None,
5156 include_prompts: bool = True,
5157 include_resources: bool = True,
5158 auth_query_params: Optional[Dict[str, str]] = None,
5159 ):
5160 """Connect to an MCP server running with Streamable HTTP transport.
5162 Args:
5163 server_url: The URL of the Streamable HTTP MCP server to connect to.
5164 authentication: Optional dictionary containing authentication headers.
5165 ca_certificate: Optional CA certificate for SSL verification.
5166 include_prompts: Whether to fetch prompts from the server.
5167 include_resources: Whether to fetch resources from the server.
5168 auth_query_params: Query param names for URL sanitization in error logs.
5170 Returns:
5171 Tuple containing (capabilities, tools, resources, prompts) from the MCP server.
5172 """
5173 if authentication is None:
5174 authentication = {}
5176 # Use authentication directly instead
5177 def get_httpx_client_factory(
5178 headers: dict[str, str] | None = None,
5179 timeout: httpx.Timeout | None = None,
5180 auth: httpx.Auth | None = None,
5181 ) -> httpx.AsyncClient:
5182 """Factory function to create httpx.AsyncClient with optional CA certificate.
5184 Args:
5185 headers: Optional headers for the client
5186 timeout: Optional timeout for the client
5187 auth: Optional auth for the client
5189 Returns:
5190 httpx.AsyncClient: Configured HTTPX async client
5191 """
5192 if ca_certificate:
5193 ctx = self.create_ssl_context(ca_certificate)
5194 else:
5195 ctx = None
5196 return httpx.AsyncClient(
5197 verify=ctx if ctx else get_default_verify(),
5198 follow_redirects=True,
5199 headers=headers,
5200 timeout=timeout if timeout else get_http_timeout(),
5201 auth=auth,
5202 limits=httpx.Limits(
5203 max_connections=settings.httpx_max_connections,
5204 max_keepalive_connections=settings.httpx_max_keepalive_connections,
5205 keepalive_expiry=settings.httpx_keepalive_expiry,
5206 ),
5207 )
5209 async with streamablehttp_client(url=server_url, headers=authentication, httpx_client_factory=get_httpx_client_factory) as (read_stream, write_stream, _get_session_id):
5210 async with ClientSession(read_stream, write_stream) as session:
5211 # Initialize the session
5212 response = await session.initialize()
5213 capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
5214 logger.debug(f"Server capabilities: {capabilities}")
5216 response = await session.list_tools()
5217 tools = response.tools
5218 tools = [tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools]
5220 tools, _ = self._validate_tools(tools)
5221 for tool in tools:
5222 tool.request_type = "STREAMABLEHTTP"
5223 if tools:
5224 logger.info(f"Fetched {len(tools)} tools from gateway")
5226 # Fetch resources if supported
5227 resources = []
5228 if include_resources:
5229 logger.debug(f"Checking for resources support: {capabilities.get('resources')}")
5230 if capabilities.get("resources"):
5231 try:
5232 response = await session.list_resources()
5233 raw_resources = response.resources
5234 for resource in raw_resources:
5235 resource_data = resource.model_dump(by_alias=True, exclude_none=True)
5236 # Convert AnyUrl to string if present
5237 if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"):
5238 resource_data["uri"] = str(resource_data["uri"])
5239 # Add default content if not present
5240 if "content" not in resource_data:
5241 resource_data["content"] = ""
5242 try:
5243 resources.append(ResourceCreate.model_validate(resource_data))
5244 except Exception:
5245 # If validation fails, create minimal resource
5246 resources.append(
5247 ResourceCreate(
5248 uri=str(resource_data.get("uri", "")),
5249 name=resource_data.get("name", ""),
5250 description=resource_data.get("description"),
5251 mime_type=resource_data.get("mimeType"),
5252 uri_template=resource_data.get("uriTemplate") or None,
5253 content="",
5254 )
5255 )
5256 logger.info(f"Fetched {len(resources)} resources from gateway")
5257 except Exception as e:
5258 logger.warning(f"Failed to fetch resources: {e}")
5260 # resource template URI
5261 try:
5262 response_templates = await session.list_resource_templates()
5263 raw_resources_templates = response_templates.resourceTemplates
5264 resource_templates = []
5265 for resource_template in raw_resources_templates:
5266 resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True)
5268 if "uriTemplate" in resource_template_data: # and hasattr(resource_template_data["uriTemplate"], "unicode_string"):
5269 resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"])
5270 resource_template_data["uri"] = str(resource_template_data["uriTemplate"])
5272 if "content" not in resource_template_data:
5273 resource_template_data["content"] = ""
5275 resources.append(ResourceCreate.model_validate(resource_template_data))
5276 resource_templates.append(ResourceCreate.model_validate(resource_template_data))
5277 logger.info(f"Fetched {len(resource_templates)} resource templates from gateway")
5278 except Exception as e:
5279 logger.warning(f"Failed to fetch resource templates: {e}")
5281 # Fetch prompts if supported
5282 prompts = []
5283 if include_prompts:
5284 logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}")
5285 if capabilities.get("prompts"):
5286 try:
5287 response = await session.list_prompts()
5288 raw_prompts = response.prompts
5289 for prompt in raw_prompts:
5290 prompt_data = prompt.model_dump(by_alias=True, exclude_none=True)
5291 # Add default template if not present
5292 if "template" not in prompt_data:
5293 prompt_data["template"] = ""
5294 prompts.append(PromptCreate.model_validate(prompt_data))
5295 logger.info(f"Fetched {len(prompts)} prompts from gateway")
5296 except Exception as e:
5297 logger.warning(f"Failed to fetch prompts: {e}")
5299 return capabilities, tools, resources, prompts
5300 sanitized_url = sanitize_url_for_logging(server_url, auth_query_params)
5301 raise GatewayConnectionError(f"Failed to initialize gateway at {sanitized_url}: Connection could not be established")
# Lazy singleton: the GatewayService instance is built on first attribute
# access rather than at import time, so importing only the exception
# classes from this module never constructs a service.
_gateway_service_instance = None  # pylint: disable=invalid-name


def __getattr__(name: str):
    """Lazily create and return the module-level ``gateway_service`` singleton.

    Args:
        name: The attribute name being accessed on this module.

    Returns:
        The shared GatewayService instance when ``name`` is "gateway_service".

    Raises:
        AttributeError: For any attribute name other than "gateway_service".
    """
    global _gateway_service_instance  # pylint: disable=global-statement
    if name != "gateway_service":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    if _gateway_service_instance is None:
        _gateway_service_instance = GatewayService()
    return _gateway_service_instance