Coverage for mcpgateway / services / gateway_service.py: 90%
2210 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2# pylint: disable=import-outside-toplevel,no-name-in-module
3"""Location: ./mcpgateway/services/gateway_service.py
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8Gateway Service Implementation.
9This module implements gateway federation according to the MCP specification.
10It handles:
11- Gateway discovery and registration
12- Request forwarding
13- Capability aggregation
14- Health monitoring
15- Active/inactive gateway management
17Examples:
18 >>> from mcpgateway.services.gateway_service import GatewayService, GatewayError
19 >>> service = GatewayService()
20 >>> isinstance(service, GatewayService)
21 True
22 >>> hasattr(service, '_active_gateways')
23 True
24 >>> isinstance(service._active_gateways, set)
25 True
27 Test error classes:
28 >>> error = GatewayError("Test error")
29 >>> str(error)
30 'Test error'
31 >>> isinstance(error, Exception)
32 True
34 >>> conflict_error = GatewayNameConflictError("test_gw")
35 >>> "test_gw" in str(conflict_error)
36 True
37 >>> conflict_error.enabled
38 True
39 >>>
40 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
41 >>> import asyncio
42 >>> asyncio.run(service._http_client.aclose())
43"""
45# Standard
46import asyncio
47import binascii
48from datetime import datetime, timezone
49import logging
50import mimetypes
51import os
52import ssl
53import tempfile
54import time
55from typing import Any, AsyncGenerator, cast, Dict, List, Optional, Set, TYPE_CHECKING, Union
56from urllib.parse import urljoin, urlparse, urlunparse
57import uuid
59# Third-Party
60from filelock import FileLock, Timeout
61import httpx
62from mcp import ClientSession
63from mcp.client.sse import sse_client
64from mcp.client.streamable_http import streamablehttp_client
65from pydantic import ValidationError
66from sqlalchemy import and_, delete, desc, or_, select, update
67from sqlalchemy.exc import IntegrityError
68from sqlalchemy.orm import joinedload, selectinload, Session
70try:
71 # Third-Party - check if redis is available
72 # Third-Party
73 import redis.asyncio as _aioredis # noqa: F401 # pylint: disable=unused-import
75 REDIS_AVAILABLE = True
76 del _aioredis # Only needed for availability check
77except ImportError:
78 REDIS_AVAILABLE = False
79 logging.info("Redis is not utilized in this environment.")
81# First-Party
82from mcpgateway.config import settings
83from mcpgateway.db import fresh_db_session
84from mcpgateway.db import Gateway as DbGateway
85from mcpgateway.db import get_db, get_for_update
86from mcpgateway.db import Prompt as DbPrompt
87from mcpgateway.db import PromptMetric
88from mcpgateway.db import Resource as DbResource
89from mcpgateway.db import ResourceMetric, ResourceSubscription, server_prompt_association, server_resource_association, server_tool_association, SessionLocal
90from mcpgateway.db import Tool as DbTool
91from mcpgateway.db import ToolMetric
92from mcpgateway.observability import create_span
93from mcpgateway.schemas import GatewayCreate, GatewayRead, GatewayUpdate, PromptCreate, ResourceCreate, ToolCreate
95# logging.getLogger("httpx").setLevel(logging.WARNING) # Disables httpx logs for regular health checks
96from mcpgateway.services.audit_trail_service import get_audit_trail_service
97from mcpgateway.services.event_service import EventService
98from mcpgateway.services.http_client_service import get_default_verify, get_http_timeout, get_isolated_http_client
99from mcpgateway.services.logging_service import LoggingService
100from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, register_gateway_capabilities_for_notifications, TransportType
101from mcpgateway.services.oauth_manager import OAuthManager
102from mcpgateway.services.structured_logger import get_structured_logger
103from mcpgateway.services.team_management_service import TeamManagementService
104from mcpgateway.utils.create_slug import slugify
105from mcpgateway.utils.display_name import generate_display_name
106from mcpgateway.utils.pagination import unified_paginate
107from mcpgateway.utils.passthrough_headers import get_passthrough_headers
108from mcpgateway.utils.redis_client import get_redis_client
109from mcpgateway.utils.retry_manager import ResilientHttpClient
110from mcpgateway.utils.services_auth import decode_auth, encode_auth
111from mcpgateway.utils.sqlalchemy_modifier import json_contains_tag_expr
112from mcpgateway.utils.ssl_context_cache import get_cached_ssl_context
113from mcpgateway.utils.url_auth import apply_query_param_auth, sanitize_exception_message, sanitize_url_for_logging
114from mcpgateway.utils.validate_signature import validate_signature
115from mcpgateway.validation.tags import validate_tags_field
117# Cache import (lazy to avoid circular dependencies)
118_REGISTRY_CACHE = None
119_TOOL_LOOKUP_CACHE = None
122def _get_registry_cache():
123 """Get registry cache singleton lazily.
125 Returns:
126 RegistryCache instance.
127 """
128 global _REGISTRY_CACHE # pylint: disable=global-statement
129 if _REGISTRY_CACHE is None:
130 # First-Party
131 from mcpgateway.cache.registry_cache import registry_cache # pylint: disable=import-outside-toplevel
133 _REGISTRY_CACHE = registry_cache
134 return _REGISTRY_CACHE
137def _get_tool_lookup_cache():
138 """Get tool lookup cache singleton lazily.
140 Returns:
141 ToolLookupCache instance.
142 """
143 global _TOOL_LOOKUP_CACHE # pylint: disable=global-statement
144 if _TOOL_LOOKUP_CACHE is None:
145 # First-Party
146 from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache # pylint: disable=import-outside-toplevel
148 _TOOL_LOOKUP_CACHE = tool_lookup_cache
149 return _TOOL_LOOKUP_CACHE
152# Initialize logging service first
153logging_service = LoggingService()
154logger = logging_service.get_logger(__name__)
156# Initialize structured logger and audit trail for gateway operations
157structured_logger = get_structured_logger("gateway_service")
158audit_trail = get_audit_trail_service()
161GW_FAILURE_THRESHOLD = settings.unhealthy_threshold
162GW_HEALTH_CHECK_INTERVAL = settings.health_check_interval
165class GatewayError(Exception):
166 """Base class for gateway-related errors.
168 Examples:
169 >>> error = GatewayError("Test error")
170 >>> str(error)
171 'Test error'
172 >>> isinstance(error, Exception)
173 True
174 """
177class GatewayNotFoundError(GatewayError):
178 """Raised when a requested gateway is not found.
180 Examples:
181 >>> error = GatewayNotFoundError("Gateway not found")
182 >>> str(error)
183 'Gateway not found'
184 >>> isinstance(error, GatewayError)
185 True
186 """
189class GatewayNameConflictError(GatewayError):
190 """Raised when a gateway name conflicts with existing (active or inactive) gateway.
192 Args:
193 name: The conflicting gateway name
194 enabled: Whether the existing gateway is enabled
195 gateway_id: ID of the existing gateway if available
196 visibility: The visibility of the gateway ("public" or "team").
198 Examples:
199 >>> error = GatewayNameConflictError("test_gateway")
200 >>> str(error)
201 'Public Gateway already exists with name: test_gateway'
202 >>> error.name
203 'test_gateway'
204 >>> error.enabled
205 True
206 >>> error.gateway_id is None
207 True
209 >>> error_inactive = GatewayNameConflictError("inactive_gw", enabled=False, gateway_id=123)
210 >>> str(error_inactive)
211 'Public Gateway already exists with name: inactive_gw (currently inactive, ID: 123)'
212 >>> error_inactive.enabled
213 False
214 >>> error_inactive.gateway_id
215 123
216 """
218 def __init__(self, name: str, enabled: bool = True, gateway_id: Optional[int] = None, visibility: Optional[str] = "public"):
219 """Initialize the error with gateway information.
221 Args:
222 name: The conflicting gateway name
223 enabled: Whether the existing gateway is enabled
224 gateway_id: ID of the existing gateway if available
225 visibility: The visibility of the gateway ("public" or "team").
226 """
227 self.name = name
228 self.enabled = enabled
229 self.gateway_id = gateway_id
230 if visibility == "team":
231 vis_label = "Team-level"
232 else:
233 vis_label = "Public"
234 message = f"{vis_label} Gateway already exists with name: {name}"
235 if not enabled:
236 message += f" (currently inactive, ID: {gateway_id})"
237 super().__init__(message)
240class GatewayDuplicateConflictError(GatewayError):
241 """Raised when a gateway conflicts with an existing gateway (same URL + credentials).
243 This error is raised when attempting to register a gateway with a URL and
244 authentication credentials that already exist within the same scope:
245 - Public: Global uniqueness required across all public gateways.
246 - Team: Uniqueness required within the same team.
247 - Private: Uniqueness required for the same user, a user cannot have two private gateways with the same URL and credentials.
249 Args:
250 duplicate_gateway: The existing conflicting gateway (DbGateway instance).
252 Examples:
253 >>> # Public gateway conflict with the same URL and basic auth
254 >>> existing_gw = DbGateway(url="https://api.example.com", id="abc-123", enabled=True, visibility="public", team_id=None, name="API Gateway", owner_email="alice@example.com")
255 >>> error = GatewayDuplicateConflictError(
256 ... duplicate_gateway=existing_gw
257 ... )
258 >>> str(error)
259 'The Server already exists in Public scope (Name: API Gateway, Status: active)'
261 >>> # Team gateway conflict with the same URL and OAuth credentials
262 >>> team_gw = DbGateway(url="https://api.example.com", id="def-456", enabled=False, visibility="team", team_id="engineering-team", name="API Gateway", owner_email="bob@example.com")
263 >>> error = GatewayDuplicateConflictError(
264 ... duplicate_gateway=team_gw
265 ... )
266 >>> str(error)
267 'The Server already exists in your Team (Name: API Gateway, Status: inactive). You may want to re-enable the existing gateway instead.'
269 >>> # Private gateway conflict (same user cannot have two gateways with the same URL)
270 >>> private_gw = DbGateway(url="https://api.example.com", id="ghi-789", enabled=True, visibility="private", team_id="none", name="API Gateway", owner_email="charlie@example.com")
271 >>> error = GatewayDuplicateConflictError(
272 ... duplicate_gateway=private_gw
273 ... )
274 >>> str(error)
275 'The Server already exists in "private" scope (Name: API Gateway, Status: active)'
276 """
278 def __init__(
279 self,
280 duplicate_gateway: "DbGateway",
281 ):
282 """Initialize the error with gateway information.
284 Args:
285 duplicate_gateway: The existing conflicting gateway (DbGateway instance)
286 """
287 self.duplicate_gateway = duplicate_gateway
288 self.url = duplicate_gateway.url
289 self.gateway_id = duplicate_gateway.id
290 self.enabled = duplicate_gateway.enabled
291 self.visibility = duplicate_gateway.visibility
292 self.team_id = duplicate_gateway.team_id
293 self.name = duplicate_gateway.name
295 # Build scope description
296 if self.visibility == "public":
297 scope_desc = "Public scope"
298 elif self.visibility == "team" and self.team_id:
299 scope_desc = "your Team"
300 else:
301 scope_desc = f'"{self.visibility}" scope'
303 # Build status description
304 status = "active" if self.enabled else "inactive"
306 # Construct error message
307 message = f"The Server already exists in {scope_desc} " f"(Name: {self.name}, Status: {status})"
309 # Add helpful hint for inactive gateways
310 if not self.enabled:
311 message += ". You may want to re-enable the existing gateway instead."
313 super().__init__(message)
316class GatewayConnectionError(GatewayError):
317 """Raised when gateway connection fails.
319 Examples:
320 >>> error = GatewayConnectionError("Connection failed")
321 >>> str(error)
322 'Connection failed'
323 >>> isinstance(error, GatewayError)
324 True
325 """
328class OAuthToolValidationError(GatewayConnectionError):
329 """Raised when tool validation fails during OAuth-driven fetch."""
332class GatewayService: # pylint: disable=too-many-instance-attributes
333 """Service for managing federated gateways.
335 Handles:
336 - Gateway registration and health checks
337 - Request forwarding
338 - Capability negotiation
339 - Federation events
340 - Active/inactive status management
341 """
343 def __init__(self) -> None:
344 """Initialize the gateway service.
346 Examples:
347 >>> from mcpgateway.services.gateway_service import GatewayService
348 >>> from mcpgateway.services.event_service import EventService
349 >>> from mcpgateway.utils.retry_manager import ResilientHttpClient
350 >>> from mcpgateway.services.tool_service import ToolService
351 >>> service = GatewayService()
352 >>> isinstance(service._event_service, EventService)
353 True
354 >>> isinstance(service._http_client, ResilientHttpClient)
355 True
356 >>> service._health_check_interval == GW_HEALTH_CHECK_INTERVAL
357 True
358 >>> service._health_check_task is None
359 True
360 >>> isinstance(service._active_gateways, set)
361 True
362 >>> len(service._active_gateways)
363 0
364 >>> service._stream_response is None
365 True
366 >>> isinstance(service._pending_responses, dict)
367 True
368 >>> len(service._pending_responses)
369 0
370 >>> isinstance(service.tool_service, ToolService)
371 True
372 >>> isinstance(service._gateway_failure_counts, dict)
373 True
374 >>> len(service._gateway_failure_counts)
375 0
376 >>> hasattr(service, 'redis_url')
377 True
378 >>>
379 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
380 >>> import asyncio
381 >>> asyncio.run(service._http_client.aclose())
382 """
383 self._http_client = ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify})
384 self._health_check_interval = GW_HEALTH_CHECK_INTERVAL
385 self._health_check_task: Optional[asyncio.Task] = None
386 self._active_gateways: Set[str] = set() # Track active gateway URLs
387 self._stream_response = None
388 self._pending_responses = {}
389 # Prefer using the globally-initialized singletons from mcpgateway.main
390 # (created at application startup). Import lazily to avoid circular
391 # import issues during module import time. Fall back to creating
392 # local instances if the singletons are not available.
393 # Use the globally-exported singletons from the service modules so
394 # events propagate via their initialized EventService/Redis clients.
395 # First-Party
396 from mcpgateway.services.prompt_service import prompt_service
397 from mcpgateway.services.resource_service import resource_service
398 from mcpgateway.services.tool_service import tool_service
400 self.tool_service = tool_service
401 self.prompt_service = prompt_service
402 self.resource_service = resource_service
403 self._gateway_failure_counts: dict[str, int] = {}
404 self.oauth_manager = OAuthManager(request_timeout=int(os.getenv("OAUTH_REQUEST_TIMEOUT", "30")), max_retries=int(os.getenv("OAUTH_MAX_RETRIES", "3")))
405 self._event_service = EventService(channel_name="mcpgateway:gateway_events")
407 # Per-gateway refresh locks to prevent concurrent refreshes for the same gateway
408 self._refresh_locks: Dict[str, asyncio.Lock] = {}
410 # For health checks, we determine the leader instance.
411 self.redis_url = settings.redis_url if settings.cache_type == "redis" else None
413 # Initialize optional Redis client holder (set in initialize())
414 self._redis_client: Optional[Any] = None
416 # Leader election settings from config
417 if self.redis_url and REDIS_AVAILABLE:
418 self._instance_id = str(uuid.uuid4()) # Unique ID for this process
419 self._leader_key = settings.redis_leader_key
420 self._leader_ttl = settings.redis_leader_ttl
421 self._leader_heartbeat_interval = settings.redis_leader_heartbeat_interval
422 self._leader_heartbeat_task: Optional[asyncio.Task] = None
424 # Always initialize file lock as fallback (used if Redis connection fails at runtime)
425 if settings.cache_type != "none":
426 temp_dir = tempfile.gettempdir()
427 user_path = os.path.normpath(settings.filelock_name)
428 if os.path.isabs(user_path):
429 user_path = os.path.relpath(user_path, start=os.path.splitdrive(user_path)[0] + os.sep)
430 full_path = os.path.join(temp_dir, user_path)
431 self._lock_path = full_path.replace("\\", "/")
432 self._file_lock = FileLock(self._lock_path)
434 @staticmethod
435 def normalize_url(url: str) -> str:
436 """
437 Normalize a URL by ensuring it's properly formatted.
439 Special handling for localhost to prevent duplicates:
440 - Converts 127.0.0.1 to localhost for consistency
441 - Preserves all other domain names as-is for CDN/load balancer support
443 Args:
444 url (str): The URL to normalize.
446 Returns:
447 str: The normalized URL.
449 Examples:
450 >>> GatewayService.normalize_url('http://localhost:8080/path')
451 'http://localhost:8080/path'
452 >>> GatewayService.normalize_url('http://127.0.0.1:8080/path')
453 'http://localhost:8080/path'
454 >>> GatewayService.normalize_url('https://example.com/api')
455 'https://example.com/api'
456 """
457 parsed = urlparse(url)
458 hostname = parsed.hostname
460 # Special case: normalize 127.0.0.1 to localhost to prevent duplicates
461 # but preserve all other domains as-is for CDN/load balancer support
462 if hostname == "127.0.0.1":
463 netloc = "localhost"
464 if parsed.port:
465 netloc += f":{parsed.port}"
466 normalized = parsed._replace(netloc=netloc)
467 return str(urlunparse(normalized))
469 # For all other URLs, preserve the domain name
470 return url
472 def create_ssl_context(self, ca_certificate: str) -> ssl.SSLContext:
473 """Create an SSL context with the provided CA certificate.
475 Uses caching to avoid repeated SSL context creation for the same certificate.
477 Args:
478 ca_certificate: CA certificate in PEM format
480 Returns:
481 ssl.SSLContext: Configured SSL context
482 """
483 return get_cached_ssl_context(ca_certificate)
485 async def initialize(self) -> None:
486 """Initialize the service and start health check if this instance is the leader.
488 Raises:
489 ConnectionError: When redis ping fails
490 """
491 logger.info("Initializing gateway service")
493 # Initialize event service with shared Redis client
494 await self._event_service.initialize()
496 # NOTE: We intentionally do NOT create a long-lived DB session here.
497 # Health checks use fresh_db_session() only when DB access is actually needed,
498 # avoiding holding connections during HTTP calls to MCP servers.
500 user_email = settings.platform_admin_email
502 # Get shared Redis client from factory
503 if self.redis_url and REDIS_AVAILABLE:
504 self._redis_client = await get_redis_client()
506 if self._redis_client:
507 # Check if Redis is available (ping already done by factory, but verify)
508 try:
509 await self._redis_client.ping()
510 except Exception as e:
511 raise ConnectionError(f"Redis ping failed: {e}") from e
513 is_leader = await self._redis_client.set(self._leader_key, self._instance_id, ex=self._leader_ttl, nx=True)
514 if is_leader: 514 ↛ exitline 514 didn't return from function 'initialize' because the condition on line 514 was always true
515 logger.info("Acquired Redis leadership. Starting health check and heartbeat tasks.")
516 self._health_check_task = asyncio.create_task(self._run_health_checks(user_email))
517 self._leader_heartbeat_task = asyncio.create_task(self._run_leader_heartbeat())
518 else:
519 # Always create the health check task in filelock mode; leader check is handled inside.
520 self._health_check_task = asyncio.create_task(self._run_health_checks(user_email))
522 async def shutdown(self) -> None:
523 """Shutdown the service.
525 Examples:
526 >>> service = GatewayService()
527 >>> # Mock internal components
528 >>> from unittest.mock import AsyncMock
529 >>> service._event_service = AsyncMock()
530 >>> service._active_gateways = {'test_gw'}
531 >>> import asyncio
532 >>> asyncio.run(service.shutdown())
533 >>> # Verify event service shutdown was called
534 >>> service._event_service.shutdown.assert_awaited_once()
535 >>> len(service._active_gateways)
536 0
537 """
538 if self._health_check_task:
539 self._health_check_task.cancel()
540 try:
541 await self._health_check_task
542 except asyncio.CancelledError:
543 pass
545 # Cancel leader heartbeat task if running
546 if getattr(self, "_leader_heartbeat_task", None):
547 self._leader_heartbeat_task.cancel()
548 try:
549 await self._leader_heartbeat_task
550 except asyncio.CancelledError:
551 pass
553 # Release Redis leadership atomically if we hold it
554 if self._redis_client:
555 try:
556 # Lua script for atomic check-and-delete (only delete if we own the key)
557 release_script = """
558 if redis.call("get", KEYS[1]) == ARGV[1] then
559 return redis.call("del", KEYS[1])
560 else
561 return 0
562 end
563 """
564 result = await self._redis_client.eval(release_script, 1, self._leader_key, self._instance_id)
565 if result: 565 ↛ 570line 565 didn't jump to line 570 because the condition on line 565 was always true
566 logger.info("Released Redis leadership on shutdown")
567 except Exception as e:
568 logger.warning(f"Failed to release Redis leader key on shutdown: {e}")
570 await self._http_client.aclose()
571 await self._event_service.shutdown()
572 self._active_gateways.clear()
573 logger.info("Gateway service shutdown complete")
575 def _check_gateway_uniqueness(
576 self,
577 db: Session,
578 url: str,
579 auth_value: Optional[Dict[str, str]],
580 oauth_config: Optional[Dict[str, Any]],
581 team_id: Optional[str],
582 owner_email: str,
583 visibility: str,
584 gateway_id: Optional[str] = None,
585 ) -> Optional[DbGateway]:
586 """
587 Check if a gateway with the same URL and credentials already exists.
589 Args:
590 db: Database session
591 url: Gateway URL (normalized)
592 auth_value: Decoded auth_value dict (not encrypted)
593 oauth_config: OAuth configuration dict
594 team_id: Team ID for team-scoped gateways
595 owner_email: Email of the gateway owner
596 visibility: Gateway visibility (public/team/private)
597 gateway_id: Optional gateway ID to exclude from check (for updates)
599 Returns:
600 DbGateway if duplicate found, None otherwise
601 """
602 # Build base query based on visibility
603 if visibility == "public":
604 query = db.query(DbGateway).filter(DbGateway.url == url, DbGateway.visibility == "public")
605 elif visibility == "team" and team_id:
606 query = db.query(DbGateway).filter(DbGateway.url == url, DbGateway.visibility == "team", DbGateway.team_id == team_id)
607 elif visibility == "private":
608 # Check for duplicates within the same user's private gateways
609 query = db.query(DbGateway).filter(DbGateway.url == url, DbGateway.visibility == "private", DbGateway.owner_email == owner_email) # Scoped to same user
610 else:
611 return None
613 # Exclude current gateway if updating
614 if gateway_id:
615 query = query.filter(DbGateway.id != gateway_id)
617 existing_gateways = query.all()
619 # Check each existing gateway
620 for existing in existing_gateways:
621 # Case 1: Both have OAuth config
622 if oauth_config and existing.oauth_config:
623 # Compare OAuth configs (exclude dynamic fields like tokens)
624 existing_oauth = existing.oauth_config or {}
625 new_oauth = oauth_config or {}
627 # Compare key OAuth fields
628 oauth_keys = ["grant_type", "client_id", "authorization_url", "token_url", "scope"]
629 if all(existing_oauth.get(k) == new_oauth.get(k) for k in oauth_keys): 629 ↛ 620line 629 didn't jump to line 620 because the condition on line 629 was always true
630 return existing # Duplicate OAuth config found
632 # Case 2: Both have auth_value (need to decrypt and compare)
633 elif auth_value and existing.auth_value:
635 try:
636 # Decrypt existing auth_value
637 if isinstance(existing.auth_value, str):
638 existing_decoded = decode_auth(existing.auth_value)
640 elif isinstance(existing.auth_value, dict):
641 existing_decoded = existing.auth_value
643 else:
644 continue
646 # Compare decoded auth values
647 if auth_value == existing_decoded: 647 ↛ 620line 647 didn't jump to line 620 because the condition on line 647 was always true
648 return existing # Duplicate credentials found
649 except Exception as e:
650 logger.warning(f"Failed to decode auth_value for comparison: {e}")
651 continue
653 # Case 3: Both have no auth (URL only, not allowed)
654 elif not auth_value and not oauth_config and not existing.auth_value and not existing.oauth_config: 654 ↛ 620line 654 didn't jump to line 620 because the condition on line 654 was always true
655 return existing # Duplicate URL without credentials
657 return None # No duplicate found
659 async def register_gateway(
660 self,
661 db: Session,
662 gateway: GatewayCreate,
663 created_by: Optional[str] = None,
664 created_from_ip: Optional[str] = None,
665 created_via: Optional[str] = None,
666 created_user_agent: Optional[str] = None,
667 team_id: Optional[str] = None,
668 owner_email: Optional[str] = None,
669 visibility: Optional[str] = None,
670 initialize_timeout: Optional[float] = None,
671 ) -> GatewayRead:
672 """Register a new gateway.
674 Args:
675 db: Database session
676 gateway: Gateway creation schema
677 created_by: Username who created this gateway
678 created_from_ip: IP address of creator
679 created_via: Creation method (ui, api, federation)
680 created_user_agent: User agent of creation request
681 team_id (Optional[str]): Team ID to assign the gateway to.
682 owner_email (Optional[str]): Email of the user who owns this gateway.
683 visibility (Optional[str]): Gateway visibility level (private, team, public).
684 initialize_timeout (Optional[float]): Timeout in seconds for gateway initialization.
686 Returns:
687 Created gateway information
689 Raises:
690 GatewayNameConflictError: If gateway name already exists
691 GatewayConnectionError: If there was an error connecting to the gateway
692 ValueError: If required values are missing
693 RuntimeError: If there is an error during processing that is not covered by other exceptions
694 IntegrityError: If there is a database integrity error
695 BaseException: If an unexpected error occurs
697 Examples:
698 >>> from mcpgateway.services.gateway_service import GatewayService
699 >>> from unittest.mock import MagicMock
700 >>> service = GatewayService()
701 >>> db = MagicMock()
702 >>> gateway = MagicMock()
703 >>> db.execute.return_value.scalar_one_or_none.return_value = None
704 >>> db.add = MagicMock()
705 >>> db.commit = MagicMock()
706 >>> db.refresh = MagicMock()
707 >>> service._notify_gateway_added = MagicMock()
708 >>> import asyncio
709 >>> try:
710 ... asyncio.run(service.register_gateway(db, gateway))
711 ... except Exception:
712 ... pass
713 >>>
714 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
715 >>> asyncio.run(service._http_client.aclose())
716 """
717 visibility = "public" if visibility not in ("private", "team", "public") else visibility
718 try:
719 # # Check for name conflicts (both active and inactive)
720 # existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway.name)).scalar_one_or_none()
722 # if existing_gateway:
723 # raise GatewayNameConflictError(
724 # gateway.name,
725 # enabled=existing_gateway.enabled,
726 # gateway_id=existing_gateway.id,
727 # )
728 # Check for existing gateway with the same slug and visibility
729 slug_name = slugify(gateway.name)
730 if visibility.lower() == "public":
731 # Check for existing public gateway with the same slug (row-locked)
732 existing_gateway = get_for_update(
733 db,
734 DbGateway,
735 where=and_(DbGateway.slug == slug_name, DbGateway.visibility == "public"),
736 )
737 if existing_gateway:
738 raise GatewayNameConflictError(existing_gateway.slug, enabled=existing_gateway.enabled, gateway_id=existing_gateway.id, visibility=existing_gateway.visibility)
739 elif visibility.lower() == "team" and team_id: 739 ↛ 750line 739 didn't jump to line 750 because the condition on line 739 was always true
740 # Check for existing team gateway with the same slug (row-locked)
741 existing_gateway = get_for_update(
742 db,
743 DbGateway,
744 where=and_(DbGateway.slug == slug_name, DbGateway.visibility == "team", DbGateway.team_id == team_id),
745 )
746 if existing_gateway: 746 ↛ 750line 746 didn't jump to line 750 because the condition on line 746 was always true
747 raise GatewayNameConflictError(existing_gateway.slug, enabled=existing_gateway.enabled, gateway_id=existing_gateway.id, visibility=existing_gateway.visibility)
749 # Normalize the gateway URL
750 normalized_url = self.normalize_url(str(gateway.url))
752 decoded_auth_value = None
753 if gateway.auth_value:
754 if isinstance(gateway.auth_value, str):
755 try:
756 decoded_auth_value = decode_auth(gateway.auth_value)
757 except Exception as e:
758 logger.warning(f"Failed to decode provided auth_value: {e}")
759 decoded_auth_value = None
760 elif isinstance(gateway.auth_value, dict): 760 ↛ 764line 760 didn't jump to line 764 because the condition on line 760 was always true
761 decoded_auth_value = gateway.auth_value
763 # Check for duplicate gateway
764 if not gateway.one_time_auth:
765 duplicate_gateway = self._check_gateway_uniqueness(
766 db=db, url=normalized_url, auth_value=decoded_auth_value, oauth_config=gateway.oauth_config, team_id=team_id, owner_email=owner_email, visibility=visibility
767 )
769 if duplicate_gateway:
770 raise GatewayDuplicateConflictError(duplicate_gateway=duplicate_gateway)
772 # Prevent URL-only gateways (no auth at all)
773 # if not decoded_auth_value and not gateway.oauth_config:
774 # raise ValueError(
775 # f"Gateway with URL '{normalized_url}' must have either auth_value or oauth_config. "
776 # "URL-only gateways are not allowed."
777 # )
779 auth_type = getattr(gateway, "auth_type", None)
780 # Support multiple custom headers
781 auth_value = getattr(gateway, "auth_value", {})
782 authentication_headers: Optional[Dict[str, str]] = None
784 # Handle query_param auth - encrypt and prepare for storage
785 auth_query_params_encrypted: Optional[Dict[str, str]] = None
786 auth_query_params_decrypted: Optional[Dict[str, str]] = None
787 init_url = normalized_url # URL to use for initialization
789 if auth_type == "query_param":
790 # Extract and encrypt query param auth
791 param_key = getattr(gateway, "auth_query_param_key", None)
792 param_value = getattr(gateway, "auth_query_param_value", None)
793 if param_key and param_value: 793 ↛ 823line 793 didn't jump to line 823 because the condition on line 793 was always true
794 # Get the actual secret value
795 if hasattr(param_value, "get_secret_value"):
796 raw_value = param_value.get_secret_value()
797 else:
798 raw_value = str(param_value)
799 # Encrypt for storage
800 encrypted_value = encode_auth({param_key: raw_value})
801 auth_query_params_encrypted = {param_key: encrypted_value}
802 auth_query_params_decrypted = {param_key: raw_value}
803 # Append query params to URL for initialization
804 init_url = apply_query_param_auth(normalized_url, auth_query_params_decrypted)
805 # Query param auth doesn't use auth_value
806 auth_value = None
807 authentication_headers = None
809 elif hasattr(gateway, "auth_headers") and gateway.auth_headers:
810 # Convert list of {key, value} to dict
811 header_dict = {h["key"]: h["value"] for h in gateway.auth_headers if h.get("key")}
812 # Keep encoded form for persistence, but pass raw headers for initialization
813 auth_value = encode_auth(header_dict) # Encode the dict for consistency
814 authentication_headers = {str(k): str(v) for k, v in header_dict.items()}
816 elif isinstance(auth_value, str) and auth_value:
817 # Decode persisted auth for initialization
818 decoded = decode_auth(auth_value)
819 authentication_headers = {str(k): str(v) for k, v in decoded.items()}
820 else:
821 authentication_headers = None
823 oauth_config = getattr(gateway, "oauth_config", None)
824 ca_certificate = getattr(gateway, "ca_certificate", None)
825 if initialize_timeout is not None:
826 try:
827 capabilities, tools, resources, prompts = await asyncio.wait_for(
828 self._initialize_gateway(
829 init_url, # URL with query params if applicable
830 authentication_headers,
831 gateway.transport,
832 auth_type,
833 oauth_config,
834 ca_certificate,
835 auth_query_params=auth_query_params_decrypted,
836 ),
837 timeout=initialize_timeout,
838 )
839 except asyncio.TimeoutError as exc:
840 sanitized = sanitize_url_for_logging(init_url, auth_query_params_decrypted)
841 raise GatewayConnectionError(f"Gateway initialization timed out after {initialize_timeout}s for {sanitized}") from exc
842 else:
843 capabilities, tools, resources, prompts = await self._initialize_gateway(
844 init_url, # URL with query params if applicable
845 authentication_headers,
846 gateway.transport,
847 auth_type,
848 oauth_config,
849 ca_certificate,
850 auth_query_params=auth_query_params_decrypted,
851 )
853 if gateway.one_time_auth:
854 # For one-time auth, clear auth_type and auth_value after initialization
855 auth_type = "one_time_auth"
856 auth_value = None
857 oauth_config = None
859 tools = [
860 DbTool(
861 original_name=tool.name,
862 custom_name=tool.name,
863 custom_name_slug=slugify(tool.name),
864 display_name=generate_display_name(tool.name),
865 url=normalized_url,
866 description=tool.description,
867 integration_type="MCP", # Gateway-discovered tools are MCP type
868 request_type=tool.request_type,
869 headers=tool.headers,
870 input_schema=tool.input_schema,
871 output_schema=tool.output_schema,
872 annotations=tool.annotations,
873 jsonpath_filter=tool.jsonpath_filter,
874 auth_type=auth_type,
875 auth_value=auth_value,
876 # Federation metadata
877 created_by=created_by or "system",
878 created_from_ip=created_from_ip,
879 created_via="federation", # These are federated tools
880 created_user_agent=created_user_agent,
881 federation_source=gateway.name,
882 version=1,
883 # Inherit team assignment from gateway
884 team_id=team_id,
885 owner_email=owner_email,
886 visibility=visibility,
887 )
888 for tool in tools
889 ]
891 # Create resource DB models with upsert logic for ORPHANED resources only
892 # Query for existing ORPHANED resources (gateway_id IS NULL or points to non-existent gateway)
893 # with same (team_id, owner_email, uri) to handle resources left behind from incomplete
894 # gateway deletions (e.g., issue #2341 crash scenarios).
895 # We only update orphaned resources - resources belonging to active gateways are not touched.
896 resource_uris = [r.uri for r in resources]
897 effective_owner = owner_email or created_by
899 # Build lookup map: (team_id, owner_email, uri) -> orphaned DbResource
900 # We query all resources matching our URIs, then filter to orphaned ones in Python
901 # to handle per-resource team/owner overrides correctly
902 orphaned_resources_map: Dict[tuple, DbResource] = {}
903 if resource_uris:
904 try:
905 # Get valid gateway IDs to identify orphaned resources
906 valid_gateway_ids = set(gw_id for (gw_id,) in db.execute(select(DbGateway.id)).all())
907 candidate_resources = db.execute(select(DbResource).where(DbResource.uri.in_(resource_uris))).scalars().all()
908 for res in candidate_resources:
909 # Only consider orphaned resources (no gateway or gateway doesn't exist)
910 is_orphaned = res.gateway_id is None or res.gateway_id not in valid_gateway_ids
911 if is_orphaned: 911 ↛ 908line 911 didn't jump to line 908 because the condition on line 911 was always true
912 key = (res.team_id, res.owner_email, res.uri)
913 orphaned_resources_map[key] = res
914 if orphaned_resources_map:
915 logger.info(f"Found {len(orphaned_resources_map)} orphaned resources to reassign for gateway {gateway.name}")
916 except Exception as e:
917 # If orphan detection fails (e.g., in mocked tests), skip upsert and create new resources
918 # This is conservative - we won't accidentally reassign resources from active gateways
919 logger.debug(f"Orphan resource detection skipped: {e}")
921 db_resources = []
922 for r in resources:
923 mime_type = mimetypes.guess_type(r.uri)[0] or ("text/plain" if isinstance(r.content, str) else "application/octet-stream")
924 r_team_id = getattr(r, "team_id", None) or team_id
925 r_owner_email = getattr(r, "owner_email", None) or effective_owner
926 r_visibility = getattr(r, "visibility", None) or visibility
928 # Check if there's an orphaned resource with matching unique key
929 lookup_key = (r_team_id, r_owner_email, r.uri)
930 if lookup_key in orphaned_resources_map:
931 # Update orphaned resource - reassign to new gateway
932 existing = orphaned_resources_map[lookup_key]
933 existing.name = r.name
934 existing.description = r.description
935 existing.mime_type = mime_type
936 existing.uri_template = r.uri_template or None
937 existing.text_content = r.content if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) else None
938 existing.binary_content = (
939 r.content.encode() if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) else r.content if isinstance(r.content, bytes) else None
940 )
941 existing.size = len(r.content) if r.content else 0
942 existing.tags = getattr(r, "tags", []) or []
943 existing.federation_source = gateway.name
944 existing.modified_by = created_by
945 existing.modified_from_ip = created_from_ip
946 existing.modified_via = "federation"
947 existing.modified_user_agent = created_user_agent
948 existing.updated_at = datetime.now(timezone.utc)
949 existing.visibility = r_visibility
950 # Note: gateway_id will be set when gateway is created (relationship)
951 db_resources.append(existing)
952 else:
953 # Create new resource
954 db_resources.append(
955 DbResource(
956 uri=r.uri,
957 name=r.name,
958 description=r.description,
959 mime_type=mime_type,
960 uri_template=r.uri_template or None,
961 text_content=r.content if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) else None,
962 binary_content=(
963 r.content.encode()
964 if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str)
965 else r.content if isinstance(r.content, bytes) else None
966 ),
967 size=len(r.content) if r.content else 0,
968 tags=getattr(r, "tags", []) or [],
969 created_by=created_by or "system",
970 created_from_ip=created_from_ip,
971 created_via="federation",
972 created_user_agent=created_user_agent,
973 import_batch_id=None,
974 federation_source=gateway.name,
975 version=1,
976 team_id=r_team_id,
977 owner_email=r_owner_email,
978 visibility=r_visibility,
979 )
980 )
982 # Create prompt DB models with upsert logic for ORPHANED prompts only
983 # Query for existing ORPHANED prompts (gateway_id IS NULL or points to non-existent gateway)
984 # with same (team_id, owner_email, name) to handle prompts left behind from incomplete
985 # gateway deletions. We only update orphaned prompts - prompts belonging to active gateways are not touched.
986 prompt_names = [p.name for p in prompts]
988 # Build lookup map: (team_id, owner_email, name) -> orphaned DbPrompt
989 orphaned_prompts_map: Dict[tuple, DbPrompt] = {}
990 if prompt_names:
991 try:
992 # Get valid gateway IDs to identify orphaned prompts
993 valid_gateway_ids_for_prompts = set(gw_id for (gw_id,) in db.execute(select(DbGateway.id)).all())
994 candidate_prompts = db.execute(select(DbPrompt).where(DbPrompt.name.in_(prompt_names))).scalars().all()
995 for pmt in candidate_prompts:
996 # Only consider orphaned prompts (no gateway or gateway doesn't exist)
997 is_orphaned = pmt.gateway_id is None or pmt.gateway_id not in valid_gateway_ids_for_prompts
998 if is_orphaned: 998 ↛ 995line 998 didn't jump to line 995 because the condition on line 998 was always true
999 key = (pmt.team_id, pmt.owner_email, pmt.name)
1000 orphaned_prompts_map[key] = pmt
1001 if orphaned_prompts_map:
1002 logger.info(f"Found {len(orphaned_prompts_map)} orphaned prompts to reassign for gateway {gateway.name}")
1003 except Exception as e:
1004 # If orphan detection fails (e.g., in mocked tests), skip upsert and create new prompts
1005 logger.debug(f"Orphan prompt detection skipped: {e}")
1007 db_prompts = []
1008 for prompt in prompts:
1009 # Prompts inherit team/owner from gateway (no per-prompt overrides)
1010 p_team_id = team_id
1011 p_owner_email = owner_email or effective_owner
1013 # Check if there's an orphaned prompt with matching unique key
1014 lookup_key = (p_team_id, p_owner_email, prompt.name)
1015 if lookup_key in orphaned_prompts_map:
1016 # Update orphaned prompt - reassign to new gateway
1017 existing = orphaned_prompts_map[lookup_key]
1018 existing.original_name = prompt.name
1019 existing.custom_name = prompt.name
1020 existing.display_name = prompt.name
1021 existing.description = prompt.description
1022 existing.template = prompt.template if hasattr(prompt, "template") else ""
1023 existing.federation_source = gateway.name
1024 existing.modified_by = created_by
1025 existing.modified_from_ip = created_from_ip
1026 existing.modified_via = "federation"
1027 existing.modified_user_agent = created_user_agent
1028 existing.updated_at = datetime.now(timezone.utc)
1029 existing.visibility = visibility
1030 # Note: gateway_id will be set when gateway is created (relationship)
1031 db_prompts.append(existing)
1032 else:
1033 # Create new prompt
1034 db_prompts.append(
1035 DbPrompt(
1036 name=prompt.name,
1037 original_name=prompt.name,
1038 custom_name=prompt.name,
1039 display_name=prompt.name,
1040 description=prompt.description,
1041 template=prompt.template if hasattr(prompt, "template") else "",
1042 argument_schema={}, # Use argument_schema instead of arguments
1043 # Federation metadata
1044 created_by=created_by or "system",
1045 created_from_ip=created_from_ip,
1046 created_via="federation", # These are federated prompts
1047 created_user_agent=created_user_agent,
1048 federation_source=gateway.name,
1049 version=1,
1050 # Inherit team assignment from gateway
1051 team_id=team_id,
1052 owner_email=owner_email,
1053 visibility=visibility,
1054 )
1055 )
1057 # Create DB model
1058 db_gateway = DbGateway(
1059 name=gateway.name,
1060 slug=slug_name,
1061 url=normalized_url,
1062 description=gateway.description,
1063 tags=gateway.tags or [],
1064 transport=gateway.transport,
1065 capabilities=capabilities,
1066 last_seen=datetime.now(timezone.utc),
1067 auth_type=auth_type,
1068 auth_value=auth_value,
1069 auth_query_params=auth_query_params_encrypted, # Encrypted query param auth
1070 oauth_config=oauth_config,
1071 passthrough_headers=gateway.passthrough_headers,
1072 tools=tools,
1073 resources=db_resources,
1074 prompts=db_prompts,
1075 # Gateway metadata
1076 created_by=created_by,
1077 created_from_ip=created_from_ip,
1078 created_via=created_via or "api",
1079 created_user_agent=created_user_agent,
1080 version=1,
1081 # Team scoping fields
1082 team_id=team_id,
1083 owner_email=owner_email,
1084 visibility=visibility,
1085 ca_certificate=gateway.ca_certificate,
1086 ca_certificate_sig=gateway.ca_certificate_sig,
1087 signing_algorithm=gateway.signing_algorithm,
1088 )
1090 # Add to DB
1091 db.add(db_gateway)
1092 db.flush() # Flush to get the ID without committing
1093 db.refresh(db_gateway)
1095 # Update tracking
1096 self._active_gateways.add(db_gateway.url)
1098 # Notify subscribers
1099 await self._notify_gateway_added(db_gateway)
1101 logger.info(f"Registered gateway: {gateway.name}")
1103 # Structured logging: Audit trail for gateway creation
1104 audit_trail.log_action(
1105 user_id=created_by or "system",
1106 action="create_gateway",
1107 resource_type="gateway",
1108 resource_id=str(db_gateway.id),
1109 resource_name=db_gateway.name,
1110 user_email=owner_email,
1111 team_id=team_id,
1112 client_ip=created_from_ip,
1113 user_agent=created_user_agent,
1114 new_values={
1115 "name": db_gateway.name,
1116 "url": db_gateway.url,
1117 "visibility": visibility,
1118 "transport": db_gateway.transport,
1119 "tools_count": len(tools),
1120 "resources_count": len(db_resources),
1121 "prompts_count": len(db_prompts),
1122 },
1123 context={
1124 "created_via": created_via,
1125 },
1126 db=db,
1127 )
1129 # Structured logging: Log successful gateway creation
1130 structured_logger.log(
1131 level="INFO",
1132 message="Gateway created successfully",
1133 event_type="gateway_created",
1134 component="gateway_service",
1135 user_id=created_by,
1136 user_email=owner_email,
1137 team_id=team_id,
1138 resource_type="gateway",
1139 resource_id=str(db_gateway.id),
1140 custom_fields={
1141 "gateway_name": db_gateway.name,
1142 "gateway_url": normalized_url,
1143 "visibility": visibility,
1144 "transport": db_gateway.transport,
1145 },
1146 db=db,
1147 )
1149 return GatewayRead.model_validate(self._prepare_gateway_for_read(db_gateway)).masked()
1150 except* GatewayConnectionError as ge: # pragma: no mutate
1151 if TYPE_CHECKING:
1152 ge: ExceptionGroup[GatewayConnectionError]
1153 logger.error(f"GatewayConnectionError in group: {ge.exceptions}")
1155 structured_logger.log(
1156 level="ERROR",
1157 message="Gateway creation failed due to connection error",
1158 event_type="gateway_creation_failed",
1159 component="gateway_service",
1160 user_id=created_by,
1161 user_email=owner_email,
1162 error=ge.exceptions[0],
1163 custom_fields={"gateway_name": gateway.name, "gateway_url": str(gateway.url)},
1164 db=db,
1165 )
1166 raise ge.exceptions[0]
1167 except* GatewayNameConflictError as gnce: # pragma: no mutate
1168 if TYPE_CHECKING:
1169 gnce: ExceptionGroup[GatewayNameConflictError]
1170 logger.error(f"GatewayNameConflictError in group: {gnce.exceptions}")
1172 structured_logger.log(
1173 level="WARNING",
1174 message="Gateway creation failed due to name conflict",
1175 event_type="gateway_name_conflict",
1176 component="gateway_service",
1177 user_id=created_by,
1178 user_email=owner_email,
1179 custom_fields={"gateway_name": gateway.name, "visibility": visibility},
1180 db=db,
1181 )
1182 raise gnce.exceptions[0]
1183 except* GatewayDuplicateConflictError as guce: # pragma: no mutate
1184 if TYPE_CHECKING:
1185 guce: ExceptionGroup[GatewayDuplicateConflictError]
1186 logger.error(f"GatewayDuplicateConflictError in group: {guce.exceptions}")
1188 structured_logger.log(
1189 level="WARNING",
1190 message="Gateway creation failed due to duplicate",
1191 event_type="gateway_duplicate_conflict",
1192 component="gateway_service",
1193 user_id=created_by,
1194 user_email=owner_email,
1195 custom_fields={"gateway_name": gateway.name},
1196 db=db,
1197 )
1198 raise guce.exceptions[0]
1199 except* ValueError as ve: # pragma: no mutate
1200 if TYPE_CHECKING:
1201 ve: ExceptionGroup[ValueError]
1202 logger.error(f"ValueErrors in group: {ve.exceptions}")
1204 structured_logger.log(
1205 level="ERROR",
1206 message="Gateway creation failed due to validation error",
1207 event_type="gateway_creation_failed",
1208 component="gateway_service",
1209 user_id=created_by,
1210 user_email=owner_email,
1211 error=ve.exceptions[0],
1212 custom_fields={"gateway_name": gateway.name},
1213 db=db,
1214 )
1215 raise ve.exceptions[0]
1216 except* RuntimeError as re: # pragma: no mutate
1217 if TYPE_CHECKING:
1218 re: ExceptionGroup[RuntimeError]
1219 logger.error(f"RuntimeErrors in group: {re.exceptions}")
1221 structured_logger.log(
1222 level="ERROR",
1223 message="Gateway creation failed due to runtime error",
1224 event_type="gateway_creation_failed",
1225 component="gateway_service",
1226 user_id=created_by,
1227 user_email=owner_email,
1228 error=re.exceptions[0],
1229 custom_fields={"gateway_name": gateway.name},
1230 db=db,
1231 )
1232 raise re.exceptions[0]
1233 except* IntegrityError as ie: # pragma: no mutate
1234 if TYPE_CHECKING:
1235 ie: ExceptionGroup[IntegrityError]
1236 logger.error(f"IntegrityErrors in group: {ie.exceptions}")
1238 structured_logger.log(
1239 level="ERROR",
1240 message="Gateway creation failed due to database integrity error",
1241 event_type="gateway_creation_failed",
1242 component="gateway_service",
1243 user_id=created_by,
1244 user_email=owner_email,
1245 error=ie.exceptions[0],
1246 custom_fields={"gateway_name": gateway.name},
1247 db=db,
1248 )
1249 raise ie.exceptions[0]
1250 except* BaseException as other: # catches every other sub-exception # pragma: no mutate
1251 if TYPE_CHECKING:
1252 other: ExceptionGroup[Exception]
1253 logger.error(f"Other grouped errors: {other.exceptions}")
1254 raise other.exceptions[0]
1256 async def fetch_tools_after_oauth(self, db: Session, gateway_id: str, app_user_email: str) -> Dict[str, Any]:
1257 """Fetch tools from MCP server after OAuth completion for Authorization Code flow.
1259 Args:
1260 db: Database session
1261 gateway_id: ID of the gateway to fetch tools for
1262 app_user_email: MCP Gateway user email for token retrieval
1264 Returns:
1265 Dict containing capabilities, tools, resources, and prompts
1267 Raises:
1268 GatewayConnectionError: If connection or OAuth fails
1269 """
1270 try:
1271 # Get the gateway with eager loading for sync operations to avoid N+1 queries
1272 gateway = db.execute(
1273 select(DbGateway)
1274 .options(
1275 selectinload(DbGateway.tools),
1276 selectinload(DbGateway.resources),
1277 selectinload(DbGateway.prompts),
1278 joinedload(DbGateway.email_team),
1279 )
1280 .where(DbGateway.id == gateway_id)
1281 ).scalar_one_or_none()
1283 if not gateway:
1284 raise ValueError(f"Gateway {gateway_id} not found")
1286 if not gateway.oauth_config:
1287 raise ValueError(f"Gateway {gateway_id} has no OAuth configuration")
1289 grant_type = gateway.oauth_config.get("grant_type")
1290 if grant_type != "authorization_code":
1291 raise ValueError(f"Gateway {gateway_id} is not using Authorization Code flow")
1293 # Get OAuth tokens for this gateway
1294 # First-Party
1295 from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel
1297 token_storage = TokenStorageService(db)
1299 # Get user-specific OAuth token
1300 if not app_user_email:
1301 raise GatewayConnectionError(f"User authentication required for OAuth gateway {gateway.name}")
1303 access_token = await token_storage.get_user_token(gateway.id, app_user_email)
1305 if not access_token:
1306 raise GatewayConnectionError(
1307 f"No OAuth tokens found for user {app_user_email} on gateway {gateway.name}. Please complete the OAuth authorization flow first at /oauth/authorize/{gateway.id}"
1308 )
1310 # Debug: Check if token was decrypted
1311 if access_token.startswith("Z0FBQUFBQm"): # Encrypted tokens start with this
1312 logger.error(f"Token appears to be encrypted! Encryption service may have failed. Token length: {len(access_token)}")
1313 else:
1314 logger.info(f"Using decrypted OAuth token for {gateway.name} (length: {len(access_token)})")
1316 # Now connect to MCP server with the access token
1317 authentication = {"Authorization": f"Bearer {access_token}"}
1319 # Use the existing connection logic
1320 # Note: For OAuth servers, skip validation since we already validated via OAuth flow
1321 if gateway.transport.upper() == "SSE":
1322 capabilities, tools, resources, prompts = await self._connect_to_sse_server_without_validation(gateway.url, authentication)
1323 elif gateway.transport.upper() == "STREAMABLEHTTP":
1324 capabilities, tools, resources, prompts = await self.connect_to_streamablehttp_server(gateway.url, authentication)
1325 else:
1326 raise ValueError(f"Unsupported transport type: {gateway.transport}")
1328 # Handle tools, resources, and prompts using helper methods
1329 tools_to_add = self._update_or_create_tools(db, tools, gateway, "oauth")
1330 resources_to_add = self._update_or_create_resources(db, resources, gateway, "oauth")
1331 prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, "oauth")
1333 # Clean up items that are no longer available from the gateway
1334 new_tool_names = [tool.name for tool in tools]
1335 new_resource_uris = [resource.uri for resource in resources]
1336 new_prompt_names = [prompt.name for prompt in prompts]
1338 # Count items before cleanup for logging
1340 # Bulk delete tools that are no longer available from the gateway
1341 # Use chunking to avoid SQLite's 999 parameter limit for IN clauses
1342 stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names]
1343 if stale_tool_ids:
1344 # Delete child records first to avoid FK constraint violations
1345 for i in range(0, len(stale_tool_ids), 500):
1346 chunk = stale_tool_ids[i : i + 500]
1347 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
1348 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
1349 db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))
1351 # Bulk delete resources that are no longer available from the gateway
1352 stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris]
1353 if stale_resource_ids:
1354 # Delete child records first to avoid FK constraint violations
1355 for i in range(0, len(stale_resource_ids), 500):
1356 chunk = stale_resource_ids[i : i + 500]
1357 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
1358 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
1359 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
1360 db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))
1362 # Bulk delete prompts that are no longer available from the gateway
1363 stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names]
1364 if stale_prompt_ids:
1365 # Delete child records first to avoid FK constraint violations
1366 for i in range(0, len(stale_prompt_ids), 500):
1367 chunk = stale_prompt_ids[i : i + 500]
1368 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
1369 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
1370 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))
1372 # Expire gateway to clear cached relationships after bulk deletes
1373 # This prevents SQLAlchemy from trying to re-delete already-deleted items
1374 if stale_tool_ids or stale_resource_ids or stale_prompt_ids:
1375 db.expire(gateway)
1377 # Update gateway relationships to reflect deletions
1378 gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names]
1379 gateway.resources = [resource for resource in gateway.resources if resource.uri in new_resource_uris]
1380 gateway.prompts = [prompt for prompt in gateway.prompts if prompt.original_name in new_prompt_names]
1382 # Log cleanup results
1383 tools_removed = len(stale_tool_ids)
1384 resources_removed = len(stale_resource_ids)
1385 prompts_removed = len(stale_prompt_ids)
1387 if tools_removed > 0:
1388 logger.info(f"Removed {tools_removed} tools no longer available from gateway")
1389 if resources_removed > 0:
1390 logger.info(f"Removed {resources_removed} resources no longer available from gateway")
1391 if prompts_removed > 0:
1392 logger.info(f"Removed {prompts_removed} prompts no longer available from gateway")
1394 # Update gateway capabilities and last_seen
1395 gateway.capabilities = capabilities
1396 gateway.last_seen = datetime.now(timezone.utc)
1398 # Register capabilities for notification-driven actions
1399 register_gateway_capabilities_for_notifications(gateway.id, capabilities)
1401 # Add new items to DB in chunks to prevent lock escalation
1402 items_added = 0
1403 chunk_size = 50
1405 if tools_to_add:
1406 for i in range(0, len(tools_to_add), chunk_size):
1407 chunk = tools_to_add[i : i + chunk_size]
1408 db.add_all(chunk)
1409 db.flush() # Flush each chunk to avoid excessive memory usage
1410 items_added += len(tools_to_add)
1411 logger.info(f"Added {len(tools_to_add)} new tools to database")
1413 if resources_to_add:
1414 for i in range(0, len(resources_to_add), chunk_size):
1415 chunk = resources_to_add[i : i + chunk_size]
1416 db.add_all(chunk)
1417 db.flush()
1418 items_added += len(resources_to_add)
1419 logger.info(f"Added {len(resources_to_add)} new resources to database")
1421 if prompts_to_add:
1422 for i in range(0, len(prompts_to_add), chunk_size):
1423 chunk = prompts_to_add[i : i + chunk_size]
1424 db.add_all(chunk)
1425 db.flush()
1426 items_added += len(prompts_to_add)
1427 logger.info(f"Added {len(prompts_to_add)} new prompts to database")
1429 if items_added > 0:
1430 db.commit()
1431 logger.info(f"Total {items_added} new items added to database")
1432 else:
1433 logger.info("No new items to add to database")
1434 # Still commit to save any updates to existing items
1435 db.commit()
1437 cache = _get_registry_cache()
1438 await cache.invalidate_tools()
1439 await cache.invalidate_resources()
1440 await cache.invalidate_prompts()
1441 tool_lookup_cache = _get_tool_lookup_cache()
1442 await tool_lookup_cache.invalidate_gateway(str(gateway.id))
1443 # Also invalidate tags cache since tool/resource tags may have changed
1444 # First-Party
1445 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel
1447 await admin_stats_cache.invalidate_tags()
1449 return {"capabilities": capabilities, "tools": tools, "resources": resources, "prompts": prompts}
1451 except GatewayConnectionError as gce:
1452 # Surface validation or depth-related failures directly to the user
1453 logger.error(f"GatewayConnectionError during OAuth fetch for {gateway_id}: {gce}")
1454 raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(gce)}")
1455 except Exception as e:
1456 logger.error(f"Failed to fetch tools after OAuth for gateway {gateway_id}: {e}")
1457 raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(e)}")
1459 async def list_gateways(
1460 self,
1461 db: Session,
1462 include_inactive: bool = False,
1463 tags: Optional[List[str]] = None,
1464 cursor: Optional[str] = None,
1465 limit: Optional[int] = None,
1466 page: Optional[int] = None,
1467 per_page: Optional[int] = None,
1468 user_email: Optional[str] = None,
1469 team_id: Optional[str] = None,
1470 visibility: Optional[str] = None,
1471 token_teams: Optional[List[str]] = None,
1472 ) -> Union[tuple[List[GatewayRead], Optional[str]], Dict[str, Any]]:
1473 """List all registered gateways with cursor pagination and optional team filtering.
1475 Args:
1476 db: Database session
1477 include_inactive: Whether to include inactive gateways
1478 tags (Optional[List[str]]): Filter resources by tags. If provided, only resources with at least one matching tag will be returned.
1479 cursor: Cursor for pagination (encoded last created_at and id).
1480 limit: Maximum number of gateways to return. None for default, 0 for unlimited.
1481 page: Page number for page-based pagination (1-indexed). Mutually exclusive with cursor.
1482 per_page: Items per page for page-based pagination. Defaults to pagination_default_page_size.
1483 user_email: Email of user for team-based access control. None for no access control.
1484 team_id: Optional team ID to filter by specific team (requires user_email).
1485 visibility: Optional visibility filter (private, team, public) (requires user_email).
1486 token_teams: Optional list of team IDs from the token (None=unrestricted, []=public-only).
1488 Returns:
1489 If page is provided: Dict with {"data": [...], "pagination": {...}, "links": {...}}
1490 If cursor is provided or neither: tuple of (list of GatewayRead objects, next_cursor).
1492 Examples:
1493 >>> from mcpgateway.services.gateway_service import GatewayService
1494 >>> from unittest.mock import MagicMock, AsyncMock, patch
1495 >>> from mcpgateway.schemas import GatewayRead
1496 >>> import asyncio
1497 >>> service = GatewayService()
1498 >>> db = MagicMock()
1499 >>> gateway_obj = MagicMock()
1500 >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway_obj]
1501 >>> gateway_read_obj = MagicMock(spec=GatewayRead)
1502 >>> service.convert_gateway_to_read = MagicMock(return_value=gateway_read_obj)
1503 >>> # Mock the cache to bypass caching logic
1504 >>> with patch('mcpgateway.services.gateway_service._get_registry_cache') as mock_cache_factory:
1505 ... mock_cache = MagicMock()
1506 ... mock_cache.get = AsyncMock(return_value=None)
1507 ... mock_cache.set = AsyncMock(return_value=None)
1508 ... mock_cache.hash_filters = MagicMock(return_value="hash")
1509 ... mock_cache_factory.return_value = mock_cache
1510 ... gateways, cursor = asyncio.run(service.list_gateways(db))
1511 ... gateways == [gateway_read_obj] and cursor is None
1512 True
1514 >>> # Test empty result
1515 >>> db.execute.return_value.scalars.return_value.all.return_value = []
1516 >>> with patch('mcpgateway.services.gateway_service._get_registry_cache') as mock_cache_factory:
1517 ... mock_cache = MagicMock()
1518 ... mock_cache.get = AsyncMock(return_value=None)
1519 ... mock_cache.set = AsyncMock(return_value=None)
1520 ... mock_cache.hash_filters = MagicMock(return_value="hash")
1521 ... mock_cache_factory.return_value = mock_cache
1522 ... empty_result, cursor = asyncio.run(service.list_gateways(db))
1523 ... empty_result == [] and cursor is None
1524 True
1525 >>>
1526 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
1527 >>> asyncio.run(service._http_client.aclose())
1528 """
1529 # Check cache for first page only - only for public-only queries (no user/team filtering)
1530 # SECURITY: Only cache public-only results (token_teams=[]), never admin bypass or team-scoped
1531 cache = _get_registry_cache()
1532 is_public_only = token_teams is not None and len(token_teams) == 0
1533 use_cache = cursor is None and user_email is None and page is None and is_public_only
1534 if use_cache:
1535 filters_hash = cache.hash_filters(include_inactive=include_inactive, tags=sorted(tags) if tags else None)
1536 cached = await cache.get("gateways", filters_hash)
1537 if cached is not None:
1538 # Reconstruct GatewayRead objects from cached dicts
1539 # SECURITY: Always apply .masked() to ensure stale cache entries don't leak credentials
1540 cached_gateways = [GatewayRead.model_validate(g).masked() for g in cached["gateways"]]
1541 return (cached_gateways, cached.get("next_cursor"))
1543 # Build base query with ordering
1544 query = select(DbGateway).options(joinedload(DbGateway.email_team)).order_by(desc(DbGateway.created_at), desc(DbGateway.id))
1546 # Apply active/inactive filter
1547 if not include_inactive:
1548 query = query.where(DbGateway.enabled)
1550 # SECURITY: Apply token-based access control based on normalized token_teams
1551 # - token_teams is None: admin bypass (is_admin=true with explicit null teams) - sees all
1552 # - token_teams is []: public-only access (missing teams or explicit empty)
1553 # - token_teams is [...]: access to specified teams + public + user's own
1554 if token_teams is not None:
1555 if len(token_teams) == 0:
1556 # Public-only token: only access public gateways
1557 query = query.where(DbGateway.visibility == "public")
1558 else:
1559 # Team-scoped token: public gateways + gateways in allowed teams + user's own
1560 access_conditions = [
1561 DbGateway.visibility == "public",
1562 and_(DbGateway.team_id.in_(token_teams), DbGateway.visibility.in_(["team", "public"])),
1563 ]
1564 if user_email:
1565 access_conditions.append(and_(DbGateway.owner_email == user_email, DbGateway.visibility == "private"))
1566 query = query.where(or_(*access_conditions))
1568 if visibility:
1569 query = query.where(DbGateway.visibility == visibility)
1571 # Apply team-based access control if user_email is provided (and no token_teams filtering)
1572 elif user_email:
1573 team_service = TeamManagementService(db)
1574 user_teams = await team_service.get_user_teams(user_email)
1575 team_ids = [team.id for team in user_teams]
1577 if team_id:
1578 # User requesting specific team - verify access
1579 if team_id not in team_ids:
1580 return ([], None)
1581 access_conditions = [
1582 and_(DbGateway.team_id == team_id, DbGateway.visibility.in_(["team", "public"])),
1583 and_(DbGateway.team_id == team_id, DbGateway.owner_email == user_email),
1584 ]
1585 query = query.where(or_(*access_conditions))
1586 else:
1587 # General access: user's gateways + public gateways + team gateways
1588 access_conditions = [
1589 DbGateway.owner_email == user_email,
1590 DbGateway.visibility == "public",
1591 ]
1592 if team_ids:
1593 access_conditions.append(and_(DbGateway.team_id.in_(team_ids), DbGateway.visibility.in_(["team", "public"])))
1594 query = query.where(or_(*access_conditions))
1596 if visibility:
1597 query = query.where(DbGateway.visibility == visibility)
1599 # Add tag filtering if tags are provided (supports both List[str] and List[Dict] formats)
1600 if tags:
1601 query = query.where(json_contains_tag_expr(db, DbGateway.tags, tags, match_any=True))
1602 # Use unified pagination helper - handles both page and cursor pagination
1603 pag_result = await unified_paginate(
1604 db=db,
1605 query=query,
1606 page=page,
1607 per_page=per_page,
1608 cursor=cursor,
1609 limit=limit,
1610 base_url="/admin/gateways", # Used for page-based links
1611 query_params={"include_inactive": include_inactive} if include_inactive else {},
1612 )
1614 next_cursor = None
1615 # Extract gateways based on pagination type
1616 if page is not None:
1617 # Page-based: pag_result is a dict
1618 gateways_db = pag_result["data"]
1619 else:
1620 # Cursor-based: pag_result is a tuple
1621 gateways_db, next_cursor = pag_result
1623 db.commit() # Release transaction to avoid idle-in-transaction
1625 # Convert to GatewayRead (common for both pagination types)
1626 result = []
1627 for s in gateways_db:
1628 try:
1629 result.append(self.convert_gateway_to_read(s))
1630 except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
1631 logger.exception(f"Failed to convert gateway {getattr(s, 'id', 'unknown')} ({getattr(s, 'name', 'unknown')}): {e}")
1632 # Continue with remaining gateways instead of failing completely
1634 # Return appropriate format based on pagination type
1635 if page is not None:
1636 # Page-based format
1637 return {
1638 "data": result,
1639 "pagination": pag_result["pagination"],
1640 "links": pag_result["links"],
1641 }
1643 # Cursor-based format
1645 # Cache first page results - only for public-only queries (no user/team filtering)
1646 # SECURITY: Only cache public-only results (token_teams=[]), never admin bypass or team-scoped
1647 if cursor is None and user_email is None and is_public_only:
1648 try:
1649 cache_data = {"gateways": [s.model_dump(mode="json") for s in result], "next_cursor": next_cursor}
1650 await cache.set("gateways", cache_data, filters_hash)
1651 except AttributeError:
1652 pass # Skip caching if result objects don't support model_dump (e.g., in doctests)
1654 return (result, next_cursor)
1656 async def list_gateways_for_user(
1657 self, db: Session, user_email: str, team_id: Optional[str] = None, visibility: Optional[str] = None, include_inactive: bool = False, skip: int = 0, limit: int = 100
1658 ) -> List[GatewayRead]:
1659 """
1660 DEPRECATED: Use list_gateways() with user_email parameter instead.
1662 This method is maintained for backward compatibility but is no longer used.
1663 New code should call list_gateways() with user_email, team_id, and visibility parameters.
1665 List gateways user has access to with team filtering.
1667 Args:
1668 db: Database session
1669 user_email: Email of the user requesting gateways
1670 team_id: Optional team ID to filter by specific team
1671 visibility: Optional visibility filter (private, team, public)
1672 include_inactive: Whether to include inactive gateways
1673 skip: Number of gateways to skip for pagination
1674 limit: Maximum number of gateways to return
1676 Returns:
1677 List[GatewayRead]: Gateways the user has access to
1678 """
1679 # Build query following existing patterns from list_gateways()
1680 team_service = TeamManagementService(db)
1681 user_teams = await team_service.get_user_teams(user_email)
1682 team_ids = [team.id for team in user_teams]
1684 # Use joinedload to eager load email_team relationship (avoids N+1 queries)
1685 query = select(DbGateway).options(joinedload(DbGateway.email_team))
1687 # Apply active/inactive filter
1688 if not include_inactive:
1689 query = query.where(DbGateway.enabled.is_(True))
1691 if team_id:
1692 if team_id not in team_ids:
1693 return [] # No access to team
1695 access_conditions = []
1696 # Filter by specific team
1698 # Team-owned gateways (team-scoped gateways)
1699 access_conditions.append(and_(DbGateway.team_id == team_id, DbGateway.visibility.in_(["team", "public"])))
1701 access_conditions.append(and_(DbGateway.team_id == team_id, DbGateway.owner_email == user_email))
1703 # Also include global public gateways (no team_id) so public gateways are visible regardless of selected team
1704 access_conditions.append(DbGateway.visibility == "public")
1706 query = query.where(or_(*access_conditions))
1707 else:
1708 # Get user's accessible teams
1709 # Build access conditions following existing patterns
1710 access_conditions = []
1711 # 1. User's personal resources (owner_email matches)
1712 access_conditions.append(DbGateway.owner_email == user_email)
1713 # 2. Team resources where user is member
1714 if team_ids: 1714 ↛ 1715line 1714 didn't jump to line 1715 because the condition on line 1714 was never true
1715 access_conditions.append(and_(DbGateway.team_id.in_(team_ids), DbGateway.visibility.in_(["team", "public"])))
1716 # 3. Public resources (if visibility allows)
1717 access_conditions.append(DbGateway.visibility == "public")
1719 query = query.where(or_(*access_conditions))
1721 # Apply visibility filter if specified
1722 if visibility:
1723 query = query.where(DbGateway.visibility == visibility)
1725 # Apply pagination following existing patterns
1726 query = query.offset(skip).limit(limit)
1728 gateways = db.execute(query).scalars().all()
1730 db.commit() # Release transaction to avoid idle-in-transaction
1732 # Team names are loaded via joinedload(DbGateway.email_team)
1733 result = []
1734 for g in gateways:
1735 logger.info(f"Gateway: {g.team_id}, Team: {g.team}")
1736 result.append(GatewayRead.model_validate(self._prepare_gateway_for_read(g)).masked())
1737 return result
1739 async def update_gateway(
1740 self,
1741 db: Session,
1742 gateway_id: str,
1743 gateway_update: GatewayUpdate,
1744 modified_by: Optional[str] = None,
1745 modified_from_ip: Optional[str] = None,
1746 modified_via: Optional[str] = None,
1747 modified_user_agent: Optional[str] = None,
1748 include_inactive: bool = True,
1749 user_email: Optional[str] = None,
1750 ) -> GatewayRead:
1751 """Update a gateway.
1753 Args:
1754 db: Database session
1755 gateway_id: Gateway ID to update
1756 gateway_update: Updated gateway data
1757 modified_by: Username of the person modifying the gateway
1758 modified_from_ip: IP address where the modification request originated
1759 modified_via: Source of modification (ui/api/import)
1760 modified_user_agent: User agent string from the modification request
1761 include_inactive: Whether to include inactive gateways
1762 user_email: Email of user performing update (for ownership check)
1764 Returns:
1765 Updated gateway information
1767 Raises:
1768 GatewayNotFoundError: If gateway not found
1769 PermissionError: If user doesn't own the gateway
1770 GatewayError: For other update errors
1771 GatewayNameConflictError: If gateway name conflict occurs
1772 IntegrityError: If there is a database integrity error
1773 ValidationError: If validation fails
1774 """
1775 try: # pylint: disable=too-many-nested-blocks
1776 # Acquire row lock and eager-load relationships while locked so
1777 # concurrent updates are serialized on Postgres.
1778 gateway = get_for_update(
1779 db,
1780 DbGateway,
1781 gateway_id,
1782 options=[
1783 selectinload(DbGateway.tools),
1784 selectinload(DbGateway.resources),
1785 selectinload(DbGateway.prompts),
1786 selectinload(DbGateway.email_team), # Use selectinload to avoid locking email_teams
1787 ],
1788 )
1789 if not gateway:
1790 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
1792 # Check ownership if user_email provided
1793 if user_email:
1794 # First-Party
1795 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel
1797 permission_service = PermissionService(db)
1798 if not await permission_service.check_resource_ownership(user_email, gateway): 1798 ↛ 1801line 1798 didn't jump to line 1801 because the condition on line 1798 was always true
1799 raise PermissionError("Only the owner can update this gateway")
1801 if gateway.enabled or include_inactive:
1802 # Check for name conflicts if name is being changed
1803 if gateway_update.name is not None and gateway_update.name != gateway.name:
1804 # existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway_update.name).where(DbGateway.id != gateway_id)).scalar_one_or_none()
1806 # if existing_gateway:
1807 # raise GatewayNameConflictError(
1808 # gateway_update.name,
1809 # enabled=existing_gateway.enabled,
1810 # gateway_id=existing_gateway.id,
1811 # )
1812 # Check for existing gateway with the same slug and visibility
1813 new_slug = slugify(gateway_update.name)
1814 if gateway_update.visibility is not None:
1815 vis = gateway_update.visibility
1816 else:
1817 vis = gateway.visibility
1818 if vis == "public":
1819 # Check for existing public gateway with the same slug (row-locked)
1820 existing_gateway = get_for_update(
1821 db,
1822 DbGateway,
1823 where=and_(DbGateway.slug == new_slug, DbGateway.visibility == "public", DbGateway.id != gateway_id),
1824 )
1825 if existing_gateway:
1826 raise GatewayNameConflictError(
1827 new_slug,
1828 enabled=existing_gateway.enabled,
1829 gateway_id=existing_gateway.id,
1830 visibility=existing_gateway.visibility,
1831 )
1832 elif vis == "team" and gateway.team_id: 1832 ↛ 1834line 1832 didn't jump to line 1834 because the condition on line 1832 was never true
1833 # Check for existing team gateway with the same slug (row-locked)
1834 existing_gateway = get_for_update(
1835 db,
1836 DbGateway,
1837 where=and_(DbGateway.slug == new_slug, DbGateway.visibility == "team", DbGateway.team_id == gateway.team_id, DbGateway.id != gateway_id),
1838 )
1839 if existing_gateway:
1840 raise GatewayNameConflictError(
1841 new_slug,
1842 enabled=existing_gateway.enabled,
1843 gateway_id=existing_gateway.id,
1844 visibility=existing_gateway.visibility,
1845 )
1846 # Check for existing gateway with the same URL and visibility
1847 normalized_url = ""
1848 if gateway_update.url is not None:
1849 normalized_url = self.normalize_url(str(gateway_update.url))
1850 else:
1851 normalized_url = None
1853 # Prepare decoded auth_value for uniqueness check
1854 decoded_auth_value = None
1855 if gateway_update.auth_value:
1856 if isinstance(gateway_update.auth_value, str): 1856 ↛ 1861line 1856 didn't jump to line 1861 because the condition on line 1856 was always true
1857 try:
1858 decoded_auth_value = decode_auth(gateway_update.auth_value)
1859 except Exception as e:
1860 logger.warning(f"Failed to decode provided auth_value: {e}")
1861 elif isinstance(gateway_update.auth_value, dict):
1862 decoded_auth_value = gateway_update.auth_value
1864 # Determine final values for uniqueness check
1865 final_auth_value = decoded_auth_value if gateway_update.auth_value is not None else (decode_auth(gateway.auth_value) if isinstance(gateway.auth_value, str) else gateway.auth_value)
1866 final_oauth_config = gateway_update.oauth_config if gateway_update.oauth_config is not None else gateway.oauth_config
1867 final_visibility = gateway_update.visibility if gateway_update.visibility is not None else gateway.visibility
1869 # Check for duplicates with updated credentials
1870 if not gateway_update.one_time_auth: 1870 ↛ 1892line 1870 didn't jump to line 1892 because the condition on line 1870 was always true
1871 duplicate_gateway = self._check_gateway_uniqueness(
1872 db=db,
1873 url=normalized_url,
1874 auth_value=final_auth_value,
1875 oauth_config=final_oauth_config,
1876 team_id=gateway.team_id,
1877 visibility=final_visibility,
1878 gateway_id=gateway_id, # Exclude current gateway from check
1879 owner_email=user_email,
1880 )
1882 if duplicate_gateway: 1882 ↛ 1883line 1882 didn't jump to line 1883 because the condition on line 1882 was never true
1883 raise GatewayDuplicateConflictError(duplicate_gateway=duplicate_gateway)
1885 # FIX for Issue #1025: Determine if URL actually changed before we update it
1886 # We need this early because we update gateway.url below, and need to know
1887 # if it actually changed to decide whether to re-fetch tools
1888 # tools/resoures/prompts are need to be re-fetched not only if URL changed , in case any update like authentication and visibility changed
1889 # url_changed = gateway_update.url is not None and self.normalize_url(str(gateway_update.url)) != gateway.url
1891 # Save original values BEFORE updating for change detection checks later
1892 original_url = gateway.url
1893 original_auth_type = gateway.auth_type
1895 # Update fields if provided
1896 if gateway_update.name is not None:
1897 gateway.name = gateway_update.name
1898 gateway.slug = slugify(gateway_update.name)
1899 if gateway_update.url is not None:
1900 # Normalize the updated URL
1901 gateway.url = self.normalize_url(str(gateway_update.url))
1902 if gateway_update.description is not None:
1903 gateway.description = gateway_update.description
1904 if gateway_update.transport is not None:
1905 gateway.transport = gateway_update.transport
1906 if gateway_update.tags is not None:
1907 gateway.tags = gateway_update.tags
1908 if gateway_update.visibility is not None:
1909 gateway.visibility = gateway_update.visibility
1910 if gateway_update.visibility is not None:
1911 gateway.visibility = gateway_update.visibility
1912 if gateway_update.passthrough_headers is not None:
1913 if isinstance(gateway_update.passthrough_headers, list):
1914 gateway.passthrough_headers = gateway_update.passthrough_headers
1915 else:
1916 if isinstance(gateway_update.passthrough_headers, str):
1917 parsed: List[str] = [h.strip() for h in gateway_update.passthrough_headers.split(",") if h.strip()]
1918 gateway.passthrough_headers = parsed
1919 else:
1920 raise GatewayError("Invalid passthrough_headers format: must be list[str] or comma-separated string")
1922 logger.info("Updated passthrough_headers for gateway {gateway.id}: {gateway.passthrough_headers}")
1924 # Only update auth_type if explicitly provided in the update
1925 if gateway_update.auth_type is not None:
1926 gateway.auth_type = gateway_update.auth_type
1928 # If auth_type is empty, update the auth_value too
1929 if gateway_update.auth_type == "":
1930 gateway.auth_value = cast(Any, "")
1932 # Clear auth_query_params when switching away from query_param auth
1933 if original_auth_type == "query_param" and gateway_update.auth_type != "query_param":
1934 gateway.auth_query_params = None
1935 logger.debug(f"Cleared auth_query_params for gateway {gateway.id} (switched from query_param to {gateway_update.auth_type})")
1937 # if auth_type is not None and only then check auth_value
1938 # Handle OAuth configuration updates
1939 if gateway_update.oauth_config is not None:
1940 gateway.oauth_config = gateway_update.oauth_config
1942 # Handle auth_value updates (both existing and new auth values)
1943 token = gateway_update.auth_token
1944 password = gateway_update.auth_password
1945 header_value = gateway_update.auth_header_value
1947 # Support multiple custom headers on update
1948 if hasattr(gateway_update, "auth_headers") and gateway_update.auth_headers:
1949 existing_auth_raw = getattr(gateway, "auth_value", {}) or {}
1950 if isinstance(existing_auth_raw, str):
1951 try:
1952 existing_auth = decode_auth(existing_auth_raw)
1953 except Exception:
1954 existing_auth = {}
1955 elif isinstance(existing_auth_raw, dict): 1955 ↛ 1958line 1955 didn't jump to line 1958 because the condition on line 1955 was always true
1956 existing_auth = existing_auth_raw
1957 else:
1958 existing_auth = {}
1960 header_dict: Dict[str, str] = {}
1961 for header in gateway_update.auth_headers:
1962 key = header.get("key")
1963 if not key: 1963 ↛ 1964line 1963 didn't jump to line 1964 because the condition on line 1963 was never true
1964 continue
1965 value = header.get("value", "")
1966 if value == settings.masked_auth_value and key in existing_auth:
1967 header_dict[key] = existing_auth[key]
1968 else:
1969 header_dict[key] = value
1970 gateway.auth_value = header_dict # Store as dict for DB JSON field
1971 elif settings.masked_auth_value not in (token, password, header_value):
1972 # Check if values differ from existing ones or if setting for first time
1973 decoded_auth = decode_auth(gateway_update.auth_value) if gateway_update.auth_value else {}
1974 current_auth = getattr(gateway, "auth_value", {}) or {}
1975 if current_auth != decoded_auth:
1976 gateway.auth_value = decoded_auth
1978 # Handle query_param auth updates with service-layer enforcement
1979 auth_query_params_decrypted: Optional[Dict[str, str]] = None
1980 init_url = gateway.url
1982 # Check if updating to query_param auth or updating existing query_param credentials
1983 # Use original_auth_type since gateway.auth_type may have been updated already
1984 is_switching_to_queryparam = gateway_update.auth_type == "query_param" and original_auth_type != "query_param"
1985 is_updating_queryparam_creds = original_auth_type == "query_param" and (gateway_update.auth_query_param_key is not None or gateway_update.auth_query_param_value is not None)
1986 is_url_changing = gateway_update.url is not None and self.normalize_url(str(gateway_update.url)) != original_url
1988 if is_switching_to_queryparam or is_updating_queryparam_creds or (is_url_changing and original_auth_type == "query_param"):
1989 # Service-layer enforcement: Check feature flag
1990 if not settings.insecure_allow_queryparam_auth:
1991 # Grandfather clause: Allow updates to existing query_param gateways
1992 # unless they're trying to change credentials
1993 if is_switching_to_queryparam or is_updating_queryparam_creds:
1994 raise ValueError("Query parameter authentication is disabled. " + "Set INSECURE_ALLOW_QUERYPARAM_AUTH=true to enable.")
1996 # Service-layer enforcement: Check host allowlist
1997 if settings.insecure_queryparam_auth_allowed_hosts:
1998 check_url = str(gateway_update.url) if gateway_update.url else gateway.url
1999 parsed = urlparse(check_url)
2000 hostname = (parsed.hostname or "").lower()
2001 if hostname not in settings.insecure_queryparam_auth_allowed_hosts:
2002 allowed = ", ".join(settings.insecure_queryparam_auth_allowed_hosts)
2003 raise ValueError(f"Host '{hostname}' is not in the allowed hosts for query param auth. Allowed: {allowed}")
2005 # Process query_param auth credentials
2006 param_key = getattr(gateway_update, "auth_query_param_key", None) or (next(iter(gateway.auth_query_params.keys()), None) if gateway.auth_query_params else None)
2007 param_value = getattr(gateway_update, "auth_query_param_value", None)
2009 # Get raw value from SecretStr if applicable
2010 raw_value: Optional[str] = None
2011 if param_value:
2012 if hasattr(param_value, "get_secret_value"):
2013 raw_value = param_value.get_secret_value()
2014 else:
2015 raw_value = str(param_value)
2017 # Check if the value is the masked placeholder - if so, keep existing value
2018 is_masked_placeholder = raw_value == settings.masked_auth_value
2020 if param_key: 2020 ↛ 2038line 2020 didn't jump to line 2038 because the condition on line 2020 was always true
2021 if raw_value and not is_masked_placeholder:
2022 # New value provided - encrypt for storage
2023 encrypted_value = encode_auth({param_key: raw_value})
2024 gateway.auth_query_params = {param_key: encrypted_value}
2025 auth_query_params_decrypted = {param_key: raw_value}
2026 elif gateway.auth_query_params: 2026 ↛ 2034line 2026 didn't jump to line 2034 because the condition on line 2026 was always true
2027 # Use existing encrypted value
2028 existing_encrypted = gateway.auth_query_params.get(param_key, "")
2029 if existing_encrypted: 2029 ↛ 2034line 2029 didn't jump to line 2034 because the condition on line 2029 was always true
2030 decrypted = decode_auth(existing_encrypted)
2031 auth_query_params_decrypted = {param_key: decrypted.get(param_key, "")}
2033 # Append query params to URL for initialization
2034 if auth_query_params_decrypted: 2034 ↛ 2038line 2034 didn't jump to line 2038 because the condition on line 2034 was always true
2035 init_url = apply_query_param_auth(gateway.url, auth_query_params_decrypted)
2037 # Update auth_type if switching
2038 if is_switching_to_queryparam:
2039 gateway.auth_type = "query_param"
2040 gateway.auth_value = None # Query param auth doesn't use auth_value
2042 elif gateway.auth_type == "query_param" and gateway.auth_query_params: 2042 ↛ 2044line 2042 didn't jump to line 2044 because the condition on line 2042 was never true
2043 # Existing query_param gateway without credential changes - decrypt for init
2044 first_key = next(iter(gateway.auth_query_params.keys()), None)
2045 if first_key:
2046 encrypted_value = gateway.auth_query_params.get(first_key, "")
2047 if encrypted_value:
2048 decrypted = decode_auth(encrypted_value)
2049 auth_query_params_decrypted = {first_key: decrypted.get(first_key, "")}
2050 init_url = apply_query_param_auth(gateway.url, auth_query_params_decrypted)
2052 # Try to reinitialize connection if URL actually changed
2053 # if url_changed:
2054 # Initialize empty lists in case initialization fails
2055 tools_to_add = []
2056 resources_to_add = []
2057 prompts_to_add = []
2059 try:
2060 ca_certificate = getattr(gateway, "ca_certificate", None)
2061 capabilities, tools, resources, prompts = await self._initialize_gateway(
2062 init_url,
2063 gateway.auth_value,
2064 gateway.transport,
2065 gateway.auth_type,
2066 gateway.oauth_config,
2067 ca_certificate,
2068 auth_query_params=auth_query_params_decrypted,
2069 )
2070 new_tool_names = [tool.name for tool in tools]
2071 new_resource_uris = [resource.uri for resource in resources]
2072 new_prompt_names = [prompt.name for prompt in prompts]
2074 if gateway_update.one_time_auth: 2074 ↛ 2076line 2074 didn't jump to line 2076 because the condition on line 2074 was never true
2075 # For one-time auth, clear auth_type and auth_value after initialization
2076 gateway.auth_type = "one_time_auth"
2077 gateway.auth_value = None
2078 gateway.oauth_config = None
2080 # Update tools using helper method
2081 tools_to_add = self._update_or_create_tools(db, tools, gateway, "update")
2083 # Update resources using helper method
2084 resources_to_add = self._update_or_create_resources(db, resources, gateway, "update")
2086 # Update prompts using helper method
2087 prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, "update")
2089 # Log newly added items
2090 items_added = len(tools_to_add) + len(resources_to_add) + len(prompts_to_add)
2091 if items_added > 0:
2092 if tools_to_add: 2092 ↛ 2094line 2092 didn't jump to line 2094 because the condition on line 2092 was always true
2093 logger.info(f"Added {len(tools_to_add)} new tools during gateway update")
2094 if resources_to_add: 2094 ↛ 2095line 2094 didn't jump to line 2095 because the condition on line 2094 was never true
2095 logger.info(f"Added {len(resources_to_add)} new resources during gateway update")
2096 if prompts_to_add: 2096 ↛ 2097line 2096 didn't jump to line 2097 because the condition on line 2096 was never true
2097 logger.info(f"Added {len(prompts_to_add)} new prompts during gateway update")
2098 logger.info(f"Total {items_added} new items added during gateway update")
2100 # Count items before cleanup for logging
2102 # Bulk delete tools that are no longer available from the gateway
2103 # Use chunking to avoid SQLite's 999 parameter limit for IN clauses
2104 stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names]
2105 if stale_tool_ids:
2106 # Delete child records first to avoid FK constraint violations
2107 for i in range(0, len(stale_tool_ids), 500):
2108 chunk = stale_tool_ids[i : i + 500]
2109 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
2110 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
2111 db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))
2113 # Bulk delete resources that are no longer available from the gateway
2114 stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris]
2115 if stale_resource_ids: 2115 ↛ 2117line 2115 didn't jump to line 2117 because the condition on line 2115 was never true
2116 # Delete child records first to avoid FK constraint violations
2117 for i in range(0, len(stale_resource_ids), 500):
2118 chunk = stale_resource_ids[i : i + 500]
2119 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
2120 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
2121 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
2122 db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))
2124 # Bulk delete prompts that are no longer available from the gateway
2125 stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names]
2126 if stale_prompt_ids: 2126 ↛ 2128line 2126 didn't jump to line 2128 because the condition on line 2126 was never true
2127 # Delete child records first to avoid FK constraint violations
2128 for i in range(0, len(stale_prompt_ids), 500):
2129 chunk = stale_prompt_ids[i : i + 500]
2130 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
2131 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
2132 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))
2134 # Expire gateway to clear cached relationships after bulk deletes
2135 # This prevents SQLAlchemy from trying to re-delete already-deleted items
2136 if stale_tool_ids or stale_resource_ids or stale_prompt_ids:
2137 db.expire(gateway)
2139 gateway.capabilities = capabilities
2141 # Register capabilities for notification-driven actions
2142 register_gateway_capabilities_for_notifications(gateway.id, capabilities)
2144 gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names] # keep only still-valid rows
2145 gateway.resources = [resource for resource in gateway.resources if resource.uri in new_resource_uris] # keep only still-valid rows
2146 gateway.prompts = [prompt for prompt in gateway.prompts if prompt.original_name in new_prompt_names] # keep only still-valid rows
2148 # Log cleanup results
2149 tools_removed = len(stale_tool_ids)
2150 resources_removed = len(stale_resource_ids)
2151 prompts_removed = len(stale_prompt_ids)
2153 if tools_removed > 0:
2154 logger.info(f"Removed {tools_removed} tools no longer available during gateway update")
2155 if resources_removed > 0: 2155 ↛ 2156line 2155 didn't jump to line 2156 because the condition on line 2155 was never true
2156 logger.info(f"Removed {resources_removed} resources no longer available during gateway update")
2157 if prompts_removed > 0: 2157 ↛ 2158line 2157 didn't jump to line 2158 because the condition on line 2157 was never true
2158 logger.info(f"Removed {prompts_removed} prompts no longer available during gateway update")
2160 gateway.last_seen = datetime.now(timezone.utc)
2162 # Add new items to database session in chunks to prevent lock escalation
2163 chunk_size = 50
2165 if tools_to_add:
2166 for i in range(0, len(tools_to_add), chunk_size):
2167 chunk = tools_to_add[i : i + chunk_size]
2168 db.add_all(chunk)
2169 db.flush()
2170 if resources_to_add: 2170 ↛ 2171line 2170 didn't jump to line 2171 because the condition on line 2170 was never true
2171 for i in range(0, len(resources_to_add), chunk_size):
2172 chunk = resources_to_add[i : i + chunk_size]
2173 db.add_all(chunk)
2174 db.flush()
2175 if prompts_to_add: 2175 ↛ 2176line 2175 didn't jump to line 2176 because the condition on line 2175 was never true
2176 for i in range(0, len(prompts_to_add), chunk_size):
2177 chunk = prompts_to_add[i : i + chunk_size]
2178 db.add_all(chunk)
2179 db.flush()
2181 # Update tracking with new URL
2182 self._active_gateways.discard(gateway.url)
2183 self._active_gateways.add(gateway.url)
2184 except Exception as e:
2185 logger.warning(f"Failed to initialize updated gateway: {e}")
2187 # Update tags if provided
2188 if gateway_update.tags is not None:
2189 gateway.tags = gateway_update.tags
2191 # Update metadata fields
2192 gateway.updated_at = datetime.now(timezone.utc)
2193 if modified_by:
2194 gateway.modified_by = modified_by
2195 if modified_from_ip:
2196 gateway.modified_from_ip = modified_from_ip
2197 if modified_via:
2198 gateway.modified_via = modified_via
2199 if modified_user_agent:
2200 gateway.modified_user_agent = modified_user_agent
2201 if hasattr(gateway, "version") and gateway.version is not None:
2202 gateway.version = gateway.version + 1
2203 else:
2204 gateway.version = 1
2206 db.commit()
2207 db.refresh(gateway)
2209 # Invalidate cache after successful update
2210 cache = _get_registry_cache()
2211 await cache.invalidate_gateways()
2212 tool_lookup_cache = _get_tool_lookup_cache()
2213 await tool_lookup_cache.invalidate_gateway(str(gateway.id))
2214 # Also invalidate tags cache since gateway tags may have changed
2215 # First-Party
2216 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel
2218 await admin_stats_cache.invalidate_tags()
2220 # Notify subscribers
2221 await self._notify_gateway_updated(gateway)
2223 logger.info(f"Updated gateway: {gateway.name}")
2225 # Structured logging: Audit trail for gateway update
2226 audit_trail.log_action(
2227 user_id=user_email or modified_by or "system",
2228 action="update_gateway",
2229 resource_type="gateway",
2230 resource_id=str(gateway.id),
2231 resource_name=gateway.name,
2232 user_email=user_email,
2233 team_id=gateway.team_id,
2234 client_ip=modified_from_ip,
2235 user_agent=modified_user_agent,
2236 new_values={
2237 "name": gateway.name,
2238 "url": gateway.url,
2239 "version": gateway.version,
2240 },
2241 context={
2242 "modified_via": modified_via,
2243 },
2244 db=db,
2245 )
2247 # Structured logging: Log successful gateway update
2248 structured_logger.log(
2249 level="INFO",
2250 message="Gateway updated successfully",
2251 event_type="gateway_updated",
2252 component="gateway_service",
2253 user_id=modified_by,
2254 user_email=user_email,
2255 team_id=gateway.team_id,
2256 resource_type="gateway",
2257 resource_id=str(gateway.id),
2258 custom_fields={
2259 "gateway_name": gateway.name,
2260 "version": gateway.version,
2261 },
2262 db=db,
2263 )
2265 return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)).masked()
2266 # Gateway is inactive and include_inactive is False → skip update, return None
2267 return None
2268 except GatewayNameConflictError as ge:
2269 logger.error(f"GatewayNameConflictError in group: {ge}")
2271 structured_logger.log(
2272 level="WARNING",
2273 message="Gateway update failed due to name conflict",
2274 event_type="gateway_name_conflict",
2275 component="gateway_service",
2276 user_email=user_email,
2277 resource_type="gateway",
2278 resource_id=gateway_id,
2279 error=ge,
2280 db=db,
2281 )
2282 raise ge
2283 except GatewayNotFoundError as gnfe:
2284 logger.error(f"GatewayNotFoundError: {gnfe}")
2286 structured_logger.log(
2287 level="ERROR",
2288 message="Gateway update failed - gateway not found",
2289 event_type="gateway_not_found",
2290 component="gateway_service",
2291 user_email=user_email,
2292 resource_type="gateway",
2293 resource_id=gateway_id,
2294 error=gnfe,
2295 db=db,
2296 )
2297 raise gnfe
2298 except IntegrityError as ie:
2299 logger.error(f"IntegrityErrors in group: {ie}")
2301 structured_logger.log(
2302 level="ERROR",
2303 message="Gateway update failed due to database integrity error",
2304 event_type="gateway_update_failed",
2305 component="gateway_service",
2306 user_email=user_email,
2307 resource_type="gateway",
2308 resource_id=gateway_id,
2309 error=ie,
2310 db=db,
2311 )
2312 raise ie
2313 except PermissionError as pe:
2314 db.rollback()
2316 structured_logger.log(
2317 level="WARNING",
2318 message="Gateway update failed due to permission error",
2319 event_type="gateway_update_permission_denied",
2320 component="gateway_service",
2321 user_email=user_email,
2322 resource_type="gateway",
2323 resource_id=gateway_id,
2324 error=pe,
2325 db=db,
2326 )
2327 raise
2328 except Exception as e:
2329 db.rollback()
2331 structured_logger.log(
2332 level="ERROR",
2333 message="Gateway update failed",
2334 event_type="gateway_update_failed",
2335 component="gateway_service",
2336 user_email=user_email,
2337 resource_type="gateway",
2338 resource_id=gateway_id,
2339 error=e,
2340 db=db,
2341 )
2342 raise GatewayError(f"Failed to update gateway: {str(e)}")
2344 async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool = True) -> GatewayRead:
2345 """Get a gateway by its ID.
2347 Args:
2348 db: Database session
2349 gateway_id: Gateway ID
2350 include_inactive: Whether to include inactive gateways
2352 Returns:
2353 GatewayRead object
2355 Raises:
2356 GatewayNotFoundError: If the gateway is not found
2358 Examples:
2359 >>> from unittest.mock import MagicMock
2360 >>> from mcpgateway.schemas import GatewayRead
2361 >>> service = GatewayService()
2362 >>> db = MagicMock()
2363 >>> gateway_mock = MagicMock()
2364 >>> gateway_mock.enabled = True
2365 >>> db.execute.return_value.scalar_one_or_none.return_value = gateway_mock
2366 >>> mocked_gateway_read = MagicMock()
2367 >>> mocked_gateway_read.masked.return_value = 'gateway_read'
2368 >>> GatewayRead.model_validate = MagicMock(return_value=mocked_gateway_read)
2369 >>> import asyncio
2370 >>> result = asyncio.run(service.get_gateway(db, 'gateway_id'))
2371 >>> result == 'gateway_read'
2372 True
2374 >>> # Test with inactive gateway but include_inactive=True
2375 >>> gateway_mock.enabled = False
2376 >>> result_inactive = asyncio.run(service.get_gateway(db, 'gateway_id', include_inactive=True))
2377 >>> result_inactive == 'gateway_read'
2378 True
2380 >>> # Test gateway not found
2381 >>> db.execute.return_value.scalar_one_or_none.return_value = None
2382 >>> try:
2383 ... asyncio.run(service.get_gateway(db, 'missing_id'))
2384 ... except GatewayNotFoundError as e:
2385 ... 'Gateway not found: missing_id' in str(e)
2386 True
2388 >>> # Test inactive gateway with include_inactive=False
2389 >>> gateway_mock.enabled = False
2390 >>> db.execute.return_value.scalar_one_or_none.return_value = gateway_mock
2391 >>> try:
2392 ... asyncio.run(service.get_gateway(db, 'gateway_id', include_inactive=False))
2393 ... except GatewayNotFoundError as e:
2394 ... 'Gateway not found: gateway_id' in str(e)
2395 True
2396 >>>
2397 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
2398 >>> asyncio.run(service._http_client.aclose())
2399 """
2400 # Use eager loading to avoid N+1 queries for relationships and team name
2401 gateway = db.execute(
2402 select(DbGateway)
2403 .options(
2404 selectinload(DbGateway.tools),
2405 selectinload(DbGateway.resources),
2406 selectinload(DbGateway.prompts),
2407 joinedload(DbGateway.email_team),
2408 )
2409 .where(DbGateway.id == gateway_id)
2410 ).scalar_one_or_none()
2412 if not gateway:
2413 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
2415 if gateway.enabled or include_inactive:
2416 # Structured logging: Log gateway view
2417 structured_logger.log(
2418 level="INFO",
2419 message="Gateway retrieved successfully",
2420 event_type="gateway_viewed",
2421 component="gateway_service",
2422 team_id=getattr(gateway, "team_id", None),
2423 resource_type="gateway",
2424 resource_id=str(gateway.id),
2425 custom_fields={
2426 "gateway_name": gateway.name,
2427 "gateway_url": gateway.url,
2428 "include_inactive": include_inactive,
2429 },
2430 db=db,
2431 )
2433 return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)).masked()
2435 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
2437 async def set_gateway_state(self, db: Session, gateway_id: str, activate: bool, reachable: bool = True, only_update_reachable: bool = False, user_email: Optional[str] = None) -> GatewayRead:
2438 """
2439 Set the activation status of a gateway.
2441 Args:
2442 db: Database session
2443 gateway_id: Gateway ID
2444 activate: True to activate, False to deactivate
2445 reachable: Whether the gateway is reachable
2446 only_update_reachable: Only update reachable status
2447 user_email: Optional[str] The email of the user to check if the user has permission to modify.
2449 Returns:
2450 The updated GatewayRead object
2452 Raises:
2453 GatewayNotFoundError: If the gateway is not found
2454 GatewayError: For other errors
2455 PermissionError: If user doesn't own the agent.
2456 """
2457 try:
2458 # Eager-load collections for the gateway. Note: we don't use FOR UPDATE
2459 # here because _initialize_gateway does network I/O, and holding a row
2460 # lock during network calls would block other operations and risk timeouts.
2461 gateway = db.execute(
2462 select(DbGateway)
2463 .options(
2464 selectinload(DbGateway.tools),
2465 selectinload(DbGateway.resources),
2466 selectinload(DbGateway.prompts),
2467 joinedload(DbGateway.email_team),
2468 )
2469 .where(DbGateway.id == gateway_id)
2470 ).scalar_one_or_none()
2471 if not gateway:
2472 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
2474 if user_email:
2475 # First-Party
2476 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel
2478 permission_service = PermissionService(db)
2479 if not await permission_service.check_resource_ownership(user_email, gateway): 2479 ↛ 2483line 2479 didn't jump to line 2483 because the condition on line 2479 was always true
2480 raise PermissionError("Only the owner can activate the gateway" if activate else "Only the owner can deactivate the gateway")
2482 # Update status if it's different
2483 if (gateway.enabled != activate) or (gateway.reachable != reachable):
2484 gateway.enabled = activate
2485 gateway.reachable = reachable
2486 gateway.updated_at = datetime.now(timezone.utc)
2487 # Update tracking
2488 if activate and reachable:
2489 self._active_gateways.add(gateway.url)
2491 # Initialize empty lists in case initialization fails
2492 tools_to_add = []
2493 resources_to_add = []
2494 prompts_to_add = []
2496 # Try to initialize if activating
2497 try:
2498 # Handle query_param auth - decrypt and apply to URL
2499 init_url = gateway.url
2500 auth_query_params_decrypted: Optional[Dict[str, str]] = None
2501 if gateway.auth_type == "query_param" and gateway.auth_query_params:
2502 auth_query_params_decrypted = {}
2503 for param_key, encrypted_value in gateway.auth_query_params.items():
2504 if encrypted_value: 2504 ↛ 2503line 2504 didn't jump to line 2503 because the condition on line 2504 was always true
2505 try:
2506 decrypted = decode_auth(encrypted_value)
2507 auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
2508 except Exception:
2509 logger.debug(f"Failed to decrypt query param '{param_key}' for gateway activation")
2510 if auth_query_params_decrypted:
2511 init_url = apply_query_param_auth(gateway.url, auth_query_params_decrypted)
2513 capabilities, tools, resources, prompts = await self._initialize_gateway(
2514 init_url, gateway.auth_value, gateway.transport, gateway.auth_type, gateway.oauth_config, auth_query_params=auth_query_params_decrypted, oauth_auto_fetch_tool_flag=True
2515 )
2516 new_tool_names = [tool.name for tool in tools]
2517 new_resource_uris = [resource.uri for resource in resources]
2518 new_prompt_names = [prompt.name for prompt in prompts]
2520 # Update tools, resources, and prompts using helper methods
2521 tools_to_add = self._update_or_create_tools(db, tools, gateway, "rediscovery")
2522 resources_to_add = self._update_or_create_resources(db, resources, gateway, "rediscovery")
2523 prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, "rediscovery")
2525 # Log newly added items
2526 items_added = len(tools_to_add) + len(resources_to_add) + len(prompts_to_add)
2527 if items_added > 0:
2528 if tools_to_add: 2528 ↛ 2530line 2528 didn't jump to line 2530 because the condition on line 2528 was always true
2529 logger.info(f"Added {len(tools_to_add)} new tools during gateway reactivation")
2530 if resources_to_add: 2530 ↛ 2531line 2530 didn't jump to line 2531 because the condition on line 2530 was never true
2531 logger.info(f"Added {len(resources_to_add)} new resources during gateway reactivation")
2532 if prompts_to_add: 2532 ↛ 2533line 2532 didn't jump to line 2533 because the condition on line 2532 was never true
2533 logger.info(f"Added {len(prompts_to_add)} new prompts during gateway reactivation")
2534 logger.info(f"Total {items_added} new items added during gateway reactivation")
2536 # Count items before cleanup for logging
2538 # Bulk delete tools that are no longer available from the gateway
2539 # Use chunking to avoid SQLite's 999 parameter limit for IN clauses
2540 stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names]
2541 if stale_tool_ids:
2542 # Delete child records first to avoid FK constraint violations
2543 for i in range(0, len(stale_tool_ids), 500):
2544 chunk = stale_tool_ids[i : i + 500]
2545 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
2546 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
2547 db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))
2549 # Bulk delete resources that are no longer available from the gateway
2550 stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris]
2551 if stale_resource_ids: 2551 ↛ 2553line 2551 didn't jump to line 2553 because the condition on line 2551 was never true
2552 # Delete child records first to avoid FK constraint violations
2553 for i in range(0, len(stale_resource_ids), 500):
2554 chunk = stale_resource_ids[i : i + 500]
2555 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
2556 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
2557 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
2558 db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))
2560 # Bulk delete prompts that are no longer available from the gateway
2561 stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names]
2562 if stale_prompt_ids: 2562 ↛ 2564line 2562 didn't jump to line 2564 because the condition on line 2562 was never true
2563 # Delete child records first to avoid FK constraint violations
2564 for i in range(0, len(stale_prompt_ids), 500):
2565 chunk = stale_prompt_ids[i : i + 500]
2566 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
2567 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
2568 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))
2570 # Expire gateway to clear cached relationships after bulk deletes
2571 # This prevents SQLAlchemy from trying to re-delete already-deleted items
2572 if stale_tool_ids or stale_resource_ids or stale_prompt_ids:
2573 db.expire(gateway)
2575 gateway.capabilities = capabilities
2577 # Register capabilities for notification-driven actions
2578 register_gateway_capabilities_for_notifications(gateway.id, capabilities)
2580 gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names] # keep only still-valid rows
2581 gateway.resources = [resource for resource in gateway.resources if resource.uri in new_resource_uris] # keep only still-valid rows
2582 gateway.prompts = [prompt for prompt in gateway.prompts if prompt.original_name in new_prompt_names] # keep only still-valid rows
2584 # Log cleanup results
2585 tools_removed = len(stale_tool_ids)
2586 resources_removed = len(stale_resource_ids)
2587 prompts_removed = len(stale_prompt_ids)
2589 if tools_removed > 0:
2590 logger.info(f"Removed {tools_removed} tools no longer available during gateway reactivation")
2591 if resources_removed > 0: 2591 ↛ 2592line 2591 didn't jump to line 2592 because the condition on line 2591 was never true
2592 logger.info(f"Removed {resources_removed} resources no longer available during gateway reactivation")
2593 if prompts_removed > 0: 2593 ↛ 2594line 2593 didn't jump to line 2594 because the condition on line 2593 was never true
2594 logger.info(f"Removed {prompts_removed} prompts no longer available during gateway reactivation")
2596 gateway.last_seen = datetime.now(timezone.utc)
2598 # Add new items to database session in chunks to prevent lock escalation
2599 chunk_size = 50
2601 if tools_to_add:
2602 for i in range(0, len(tools_to_add), chunk_size):
2603 chunk = tools_to_add[i : i + chunk_size]
2604 db.add_all(chunk)
2605 db.flush()
2606 if resources_to_add: 2606 ↛ 2607line 2606 didn't jump to line 2607 because the condition on line 2606 was never true
2607 for i in range(0, len(resources_to_add), chunk_size):
2608 chunk = resources_to_add[i : i + chunk_size]
2609 db.add_all(chunk)
2610 db.flush()
2611 if prompts_to_add: 2611 ↛ 2612line 2611 didn't jump to line 2612 because the condition on line 2611 was never true
2612 for i in range(0, len(prompts_to_add), chunk_size):
2613 chunk = prompts_to_add[i : i + chunk_size]
2614 db.add_all(chunk)
2615 db.flush()
2616 except Exception as e:
2617 logger.warning(f"Failed to initialize reactivated gateway: {e}")
2618 else:
2619 self._active_gateways.discard(gateway.url)
2621 db.commit()
2622 db.refresh(gateway)
2624 # Invalidate cache after status change
2625 cache = _get_registry_cache()
2626 await cache.invalidate_gateways()
2628 # Notify Subscribers
2629 if not gateway.enabled:
2630 # Inactive
2631 await self._notify_gateway_deactivated(gateway)
2632 elif gateway.enabled and not gateway.reachable:
2633 # Offline (Enabled but Unreachable)
2634 await self._notify_gateway_offline(gateway)
2635 else:
2636 # Active (Enabled and Reachable)
2637 await self._notify_gateway_activated(gateway)
2639 # Bulk update tools - single UPDATE statement instead of N FOR UPDATE locks
2640 # This prevents lock contention under high concurrent load
2641 now = datetime.now(timezone.utc)
2642 if only_update_reachable:
2643 # Only update reachable status, keep enabled as-is
2644 tools_result = db.execute(update(DbTool).where(DbTool.gateway_id == gateway_id).where(DbTool.reachable != reachable).values(reachable=reachable, updated_at=now))
2645 else:
2646 # Update both enabled and reachable
2647 tools_result = db.execute(
2648 update(DbTool)
2649 .where(DbTool.gateway_id == gateway_id)
2650 .where(or_(DbTool.enabled != activate, DbTool.reachable != reachable))
2651 .values(enabled=activate, reachable=reachable, updated_at=now)
2652 )
2653 tools_updated = tools_result.rowcount
2655 # Commit tool updates
2656 if tools_updated > 0:
2657 db.commit()
2659 # Invalidate tools cache once after bulk update
2660 if tools_updated > 0:
2661 await cache.invalidate_tools()
2662 tool_lookup_cache = _get_tool_lookup_cache()
2663 await tool_lookup_cache.invalidate_gateway(str(gateway.id))
2665 # Bulk update prompts when gateway is deactivated/activated (skip for reachability-only updates)
2666 prompts_updated = 0
2667 if not only_update_reachable:
2668 prompts_result = db.execute(update(DbPrompt).where(DbPrompt.gateway_id == gateway_id).where(DbPrompt.enabled != activate).values(enabled=activate, updated_at=now))
2669 prompts_updated = prompts_result.rowcount
2670 if prompts_updated > 0:
2671 db.commit()
2672 await cache.invalidate_prompts()
2674 # Bulk update resources when gateway is deactivated/activated (skip for reachability-only updates)
2675 resources_updated = 0
2676 if not only_update_reachable:
2677 resources_result = db.execute(update(DbResource).where(DbResource.gateway_id == gateway_id).where(DbResource.enabled != activate).values(enabled=activate, updated_at=now))
2678 resources_updated = resources_result.rowcount
2679 if resources_updated > 0:
2680 db.commit()
2681 await cache.invalidate_resources()
2683 logger.debug(f"Gateway {gateway.name} bulk state update: {tools_updated} tools, {prompts_updated} prompts, {resources_updated} resources")
2685 logger.info(f"Gateway status: {gateway.name} - {'enabled' if activate else 'disabled'} and {'accessible' if reachable else 'inaccessible'}")
2687 # Structured logging: Audit trail for gateway state change
2688 audit_trail.log_action(
2689 user_id=user_email or "system",
2690 action="set_gateway_state",
2691 resource_type="gateway",
2692 resource_id=str(gateway.id),
2693 resource_name=gateway.name,
2694 user_email=user_email,
2695 team_id=gateway.team_id,
2696 new_values={
2697 "enabled": gateway.enabled,
2698 "reachable": gateway.reachable,
2699 },
2700 context={
2701 "action": "activate" if activate else "deactivate",
2702 "only_update_reachable": only_update_reachable,
2703 },
2704 db=db,
2705 )
2707 # Structured logging: Log successful gateway state change
2708 structured_logger.log(
2709 level="INFO",
2710 message=f"Gateway {'activated' if activate else 'deactivated'} successfully",
2711 event_type="gateway_state_changed",
2712 component="gateway_service",
2713 user_email=user_email,
2714 team_id=gateway.team_id,
2715 resource_type="gateway",
2716 resource_id=str(gateway.id),
2717 custom_fields={
2718 "gateway_name": gateway.name,
2719 "enabled": gateway.enabled,
2720 "reachable": gateway.reachable,
2721 },
2722 db=db,
2723 )
2725 return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)).masked()
2727 except PermissionError as e:
2728 # Structured logging: Log permission error
2729 structured_logger.log(
2730 level="WARNING",
2731 message="Gateway state change failed due to permission error",
2732 event_type="gateway_state_change_permission_denied",
2733 component="gateway_service",
2734 user_email=user_email,
2735 resource_type="gateway",
2736 resource_id=gateway_id,
2737 error=e,
2738 db=db,
2739 )
2740 raise e
2741 except Exception as e:
2742 db.rollback()
2744 # Structured logging: Log generic gateway state change failure
2745 structured_logger.log(
2746 level="ERROR",
2747 message="Gateway state change failed",
2748 event_type="gateway_state_change_failed",
2749 component="gateway_service",
2750 user_email=user_email,
2751 resource_type="gateway",
2752 resource_id=gateway_id,
2753 error=e,
2754 db=db,
2755 )
2756 raise GatewayError(f"Failed to set gateway state: {str(e)}")
2758 async def _notify_gateway_updated(self, gateway: DbGateway) -> None:
2759 """
2760 Notify subscribers of gateway update.
2762 Args:
2763 gateway: Gateway to update
2764 """
2765 event = {
2766 "type": "gateway_updated",
2767 "data": {
2768 "id": gateway.id,
2769 "name": gateway.name,
2770 "url": gateway.url,
2771 "description": gateway.description,
2772 "enabled": gateway.enabled,
2773 },
2774 "timestamp": datetime.now(timezone.utc).isoformat(),
2775 }
2776 await self._publish_event(event)
2778 async def delete_gateway(self, db: Session, gateway_id: str, user_email: Optional[str] = None) -> None:
2779 """
2780 Delete a gateway by its ID.
2782 Args:
2783 db: Database session
2784 gateway_id: Gateway ID
2785 user_email: Email of user performing deletion (for ownership check)
2787 Raises:
2788 GatewayNotFoundError: If the gateway is not found
2789 PermissionError: If user doesn't own the gateway
2790 GatewayError: For other deletion errors
2792 Examples:
2793 >>> from mcpgateway.services.gateway_service import GatewayService
2794 >>> from unittest.mock import MagicMock
2795 >>> service = GatewayService()
2796 >>> db = MagicMock()
2797 >>> gateway = MagicMock()
2798 >>> db.execute.return_value.scalar_one_or_none.return_value = gateway
2799 >>> db.delete = MagicMock()
2800 >>> db.commit = MagicMock()
2801 >>> service._notify_gateway_deleted = MagicMock()
2802 >>> import asyncio
2803 >>> try:
2804 ... asyncio.run(service.delete_gateway(db, 'gateway_id', 'user@example.com'))
2805 ... except Exception:
2806 ... pass
2807 >>>
2808 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
2809 >>> asyncio.run(service._http_client.aclose())
2810 """
2811 try:
2812 # Find gateway with eager loading for deletion to avoid N+1 queries
2813 gateway = db.execute(
2814 select(DbGateway)
2815 .options(
2816 selectinload(DbGateway.tools),
2817 selectinload(DbGateway.resources),
2818 selectinload(DbGateway.prompts),
2819 joinedload(DbGateway.email_team),
2820 )
2821 .where(DbGateway.id == gateway_id)
2822 ).scalar_one_or_none()
2824 if not gateway:
2825 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
2827 # Check ownership if user_email provided
2828 if user_email:
2829 # First-Party
2830 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel
2832 permission_service = PermissionService(db)
2833 if not await permission_service.check_resource_ownership(user_email, gateway):
2834 raise PermissionError("Only the owner can delete this gateway")
2836 # Store gateway info for notification before deletion
2837 gateway_info = {"id": gateway.id, "name": gateway.name, "url": gateway.url}
2838 gateway_name = gateway.name
2839 gateway_team_id = gateway.team_id
2840 gateway_url = gateway.url # Store URL before expiring the object
2842 # Manually delete children first to avoid FK constraint violations
2843 # (passive_deletes=True means ORM won't auto-cascade, we must do it explicitly)
2844 # Use chunking to avoid SQLite's 999 parameter limit for IN clauses
2845 tool_ids = [t.id for t in gateway.tools]
2846 resource_ids = [r.id for r in gateway.resources]
2847 prompt_ids = [p.id for p in gateway.prompts]
2849 # Delete tool children and tools
2850 if tool_ids:
2851 for i in range(0, len(tool_ids), 500):
2852 chunk = tool_ids[i : i + 500]
2853 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
2854 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
2855 db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))
2857 # Delete resource children and resources
2858 if resource_ids:
2859 for i in range(0, len(resource_ids), 500):
2860 chunk = resource_ids[i : i + 500]
2861 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
2862 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
2863 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
2864 db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))
2866 # Delete prompt children and prompts
2867 if prompt_ids:
2868 for i in range(0, len(prompt_ids), 500):
2869 chunk = prompt_ids[i : i + 500]
2870 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
2871 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
2872 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))
2874 # Expire gateway to clear cached relationships after bulk deletes
2875 db.expire(gateway)
2877 # Use DELETE with rowcount check for database-agnostic atomic delete
2878 # (RETURNING is not supported on MySQL/MariaDB)
2879 stmt = delete(DbGateway).where(DbGateway.id == gateway_id)
2880 result = db.execute(stmt)
2881 if result.rowcount == 0: 2881 ↛ 2883line 2881 didn't jump to line 2883 because the condition on line 2881 was never true
2882 # Gateway was already deleted by another concurrent request
2883 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
2885 db.commit()
2887 # Invalidate cache after successful deletion
2888 cache = _get_registry_cache()
2889 await cache.invalidate_gateways()
2890 tool_lookup_cache = _get_tool_lookup_cache()
2891 await tool_lookup_cache.invalidate_gateway(str(gateway_id))
2892 # Also invalidate tags cache since gateway tags may have changed
2893 # First-Party
2894 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel
2896 await admin_stats_cache.invalidate_tags()
2898 # Update tracking
2899 self._active_gateways.discard(gateway_url)
2901 # Notify subscribers
2902 await self._notify_gateway_deleted(gateway_info)
2904 logger.info(f"Permanently deleted gateway: {gateway_name}")
2906 # Structured logging: Audit trail for gateway deletion
2907 audit_trail.log_action(
2908 user_id=user_email or "system",
2909 action="delete_gateway",
2910 resource_type="gateway",
2911 resource_id=str(gateway_info["id"]),
2912 resource_name=gateway_name,
2913 user_email=user_email,
2914 team_id=gateway_team_id,
2915 old_values={
2916 "name": gateway_name,
2917 "url": gateway_info["url"],
2918 },
2919 db=db,
2920 )
2922 # Structured logging: Log successful gateway deletion
2923 structured_logger.log(
2924 level="INFO",
2925 message="Gateway deleted successfully",
2926 event_type="gateway_deleted",
2927 component="gateway_service",
2928 user_email=user_email,
2929 team_id=gateway_team_id,
2930 resource_type="gateway",
2931 resource_id=str(gateway_info["id"]),
2932 custom_fields={
2933 "gateway_name": gateway_name,
2934 "gateway_url": gateway_info["url"],
2935 },
2936 db=db,
2937 )
2939 except PermissionError as pe:
2940 db.rollback()
2942 # Structured logging: Log permission error
2943 structured_logger.log(
2944 level="WARNING",
2945 message="Gateway deletion failed due to permission error",
2946 event_type="gateway_delete_permission_denied",
2947 component="gateway_service",
2948 user_email=user_email,
2949 resource_type="gateway",
2950 resource_id=gateway_id,
2951 error=pe,
2952 db=db,
2953 )
2954 raise
2955 except Exception as e:
2956 db.rollback()
2958 # Structured logging: Log generic gateway deletion failure
2959 structured_logger.log(
2960 level="ERROR",
2961 message="Gateway deletion failed",
2962 event_type="gateway_deletion_failed",
2963 component="gateway_service",
2964 user_email=user_email,
2965 resource_type="gateway",
2966 resource_id=gateway_id,
2967 error=e,
2968 db=db,
2969 )
2970 raise GatewayError(f"Failed to delete gateway: {str(e)}")
2972 async def forward_request(
2973 self,
2974 gateway_or_db,
2975 method: str,
2976 params: Optional[Dict[str, Any]] = None,
2977 app_user_email: Optional[str] = None,
2978 user_email: Optional[str] = None,
2979 token_teams: Optional[List[str]] = None,
2980 ) -> Any: # noqa: F811 # pylint: disable=function-redefined
2981 """
2982 Forward a request to a gateway or multiple gateways.
2984 This method handles two calling patterns:
2985 1. forward_request(gateway, method, params) - Forward to a specific gateway
2986 2. forward_request(db, method, params) - Forward to active gateways in the database
2988 Args:
2989 gateway_or_db: Either a DbGateway object or database Session
2990 method: RPC method name
2991 params: Optional method parameters
2992 app_user_email: Optional app user email for OAuth token selection
2993 user_email: Optional user email for team-based access control
2994 token_teams: Optional list of team IDs from the token (None=unrestricted, []=public-only)
2996 Returns:
2997 Gateway response
2999 Raises:
3000 GatewayConnectionError: If forwarding fails
3001 GatewayError: If gateway gave an error
3002 """
3003 # Dispatch based on first parameter type
3004 if hasattr(gateway_or_db, "execute"):
3005 # This is a database session - forward to all active gateways
3006 return await self._forward_request_to_all(gateway_or_db, method, params, app_user_email, user_email, token_teams)
3007 # This is a gateway object - forward to specific gateway
3008 return await self._forward_request_to_gateway(gateway_or_db, method, params, app_user_email)
3010 async def _forward_request_to_gateway(self, gateway: DbGateway, method: str, params: Optional[Dict[str, Any]] = None, app_user_email: Optional[str] = None) -> Any:
3011 """
3012 Forward a request to a specific gateway.
3014 Args:
3015 gateway: Gateway to forward to
3016 method: RPC method name
3017 params: Optional method parameters
3018 app_user_email: Optional app user email for OAuth token selection
3020 Returns:
3021 Gateway response
3023 Raises:
3024 GatewayConnectionError: If forwarding fails
3025 GatewayError: If gateway gave an error
3026 """
3027 start_time = time.monotonic()
3029 # Create trace span for gateway federation
3030 with create_span(
3031 "gateway.forward_request",
3032 {
3033 "gateway.name": gateway.name,
3034 "gateway.id": str(gateway.id),
3035 "gateway.url": gateway.url,
3036 "rpc.method": method,
3037 "rpc.service": "mcp-gateway",
3038 "http.method": "POST",
3039 "http.url": urljoin(gateway.url, "/rpc"),
3040 "peer.service": gateway.name,
3041 },
3042 ) as span:
3043 if not gateway.enabled:
3044 raise GatewayConnectionError(f"Cannot forward request to inactive gateway: {gateway.name}")
3046 response = None # Initialize response to avoid UnboundLocalError
3047 try:
3048 # Build RPC request
3049 request: Dict[str, Any] = {"jsonrpc": "2.0", "id": 1, "method": method}
3050 if params:
3051 request["params"] = params
3052 if span: 3052 ↛ 3053line 3052 didn't jump to line 3053 because the condition on line 3052 was never true
3053 span.set_attribute("rpc.params_count", len(params))
3055 # Handle OAuth authentication for the specific gateway
3056 headers: Dict[str, str] = {}
3058 if getattr(gateway, "auth_type", None) == "oauth" and gateway.oauth_config:
3059 try:
3060 grant_type = gateway.oauth_config.get("grant_type", "client_credentials")
3062 if grant_type == "client_credentials":
3063 # Use OAuth manager to get access token for Client Credentials flow
3064 access_token = await self.oauth_manager.get_access_token(gateway.oauth_config)
3065 headers = {"Authorization": f"Bearer {access_token}"}
3066 elif grant_type == "authorization_code": 3066 ↛ 3108line 3066 didn't jump to line 3108 because the condition on line 3066 was always true
3067 # For Authorization Code flow, try to get a stored token
3068 if not app_user_email:
3069 logger.warning(f"Skipping OAuth authorization code gateway {gateway.name} - user-specific tokens required but no user email provided")
3070 raise GatewayConnectionError(f"OAuth authorization code gateway {gateway.name} requires user context")
3072 # First-Party
3073 from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel
3075 # Get database session (this is a bit hacky but necessary for now)
3076 db = next(get_db())
3077 try:
3078 token_storage = TokenStorageService(db)
3079 access_token = await token_storage.get_user_token(str(gateway.id), app_user_email)
3080 if access_token:
3081 headers = {"Authorization": f"Bearer {access_token}"}
3082 else:
3083 raise GatewayConnectionError(f"No valid OAuth token for user {app_user_email} and gateway {gateway.name}")
3084 finally:
3085 # Ensure close() always runs even if commit() fails
3086 # Without this nested try/finally, a commit() failure (e.g., PgBouncer timeout)
3087 # would skip close(), leaving the connection in "idle in transaction" state
3088 try:
3089 db.commit() # End read-only transaction cleanly before returning to pool
3090 finally:
3091 db.close()
3092 except Exception as oauth_error:
3093 raise GatewayConnectionError(f"Failed to obtain OAuth token for gateway {gateway.name}: {oauth_error}")
3094 else:
3095 # Handle non-OAuth authentication
3096 auth_data = gateway.auth_value or {}
3097 if isinstance(auth_data, str) and auth_data:
3098 headers = decode_auth(auth_data)
3099 elif isinstance(auth_data, dict) and auth_data:
3100 headers = {str(k): str(v) for k, v in auth_data.items()}
3101 else:
3102 # No auth configured - send request without authentication
3103 # SECURITY: Never send gateway admin credentials to remote servers
3104 logger.warning(f"Gateway {gateway.name} has no authentication configured - sending unauthenticated request")
3105 headers = {"Content-Type": "application/json"}
3107 # Directly use the persistent HTTP client (no async with)
3108 response = await self._http_client.post(urljoin(gateway.url, "/rpc"), json=request, headers=headers)
3109 response.raise_for_status()
3110 result = response.json()
3112 # Update last seen timestamp using fresh DB session
3113 # (gateway object may be detached from original session)
3114 try:
3115 with fresh_db_session() as update_db:
3116 db_gateway = update_db.execute(select(DbGateway).where(DbGateway.id == gateway.id)).scalar_one_or_none()
3117 if db_gateway:
3118 db_gateway.last_seen = datetime.now(timezone.utc)
3119 update_db.commit()
3120 except Exception as update_error:
3121 logger.warning(f"Failed to update last_seen for gateway {gateway.name}: {update_error}")
3123 # Record success metrics
3124 if span: 3124 ↛ 3125line 3124 didn't jump to line 3125 because the condition on line 3124 was never true
3125 span.set_attribute("http.status_code", response.status_code)
3126 span.set_attribute("success", True)
3127 span.set_attribute("duration.ms", (time.monotonic() - start_time) * 1000)
3129 except Exception:
3130 if span: 3130 ↛ 3131line 3130 didn't jump to line 3131 because the condition on line 3130 was never true
3131 span.set_attribute("http.status_code", getattr(response, "status_code", 0))
3132 raise GatewayConnectionError(f"Failed to forward request to {gateway.name}")
3134 if "error" in result:
3135 if span: 3135 ↛ 3136line 3135 didn't jump to line 3136 because the condition on line 3135 was never true
3136 span.set_attribute("rpc.error", True)
3137 span.set_attribute("rpc.error.message", result["error"].get("message", "Unknown error"))
3138 raise GatewayError(f"Gateway error: {result['error'].get('message')}")
3140 return result.get("result")
3142 async def _forward_request_to_all(
3143 self,
3144 db: Session,
3145 method: str,
3146 params: Optional[Dict[str, Any]] = None,
3147 app_user_email: Optional[str] = None,
3148 user_email: Optional[str] = None,
3149 token_teams: Optional[List[str]] = None,
3150 ) -> Any:
3151 """
3152 Forward a request to all active gateways that can handle the method.
3154 Args:
3155 db: Database session
3156 method: RPC method name
3157 params: Optional method parameters
3158 app_user_email: Optional app user email for OAuth token selection
3159 user_email: Optional user email for team-based access control
3160 token_teams: Optional list of team IDs from the token (None=unrestricted, []=public-only)
3162 Returns:
3163 Gateway response from the first successful gateway
3165 Raises:
3166 GatewayConnectionError: If no gateways can handle the request
3167 """
3168 # ═══════════════════════════════════════════════════════════════════════════
3169 # PHASE 1: Fetch all required data before HTTP calls
3170 # ═══════════════════════════════════════════════════════════════════════════
3172 # SECURITY: Apply team-based access control to gateway selection
3173 # - token_teams is None: admin bypass (is_admin=true with explicit null teams) - sees all
3174 # - token_teams is []: public-only access (missing teams or explicit empty)
3175 # - token_teams is [...]: access to public + specified teams
3176 query = select(DbGateway).where(DbGateway.enabled.is_(True))
3178 if token_teams is not None:
3179 if len(token_teams) == 0:
3180 # Public-only token: only access public gateways
3181 query = query.where(DbGateway.visibility == "public")
3182 else:
3183 # Team-scoped token: public gateways + gateways in allowed teams
3184 access_conditions = [
3185 DbGateway.visibility == "public",
3186 and_(DbGateway.team_id.in_(token_teams), DbGateway.visibility.in_(["team", "public"])),
3187 ]
3188 # Also include private gateways owned by this user
3189 if user_email: 3189 ↛ 3191line 3189 didn't jump to line 3191 because the condition on line 3189 was always true
3190 access_conditions.append(and_(DbGateway.owner_email == user_email, DbGateway.visibility == "private"))
3191 query = query.where(or_(*access_conditions))
3193 active_gateways = db.execute(query).scalars().all()
3195 if not active_gateways:
3196 raise GatewayConnectionError("No active gateways available to forward request")
3198 # Extract all gateway data to local variables before releasing DB connection
3199 gateway_data_list: List[Dict[str, Any]] = []
3200 for gateway in active_gateways:
3201 gw_data = {
3202 "id": gateway.id,
3203 "name": gateway.name,
3204 "url": gateway.url,
3205 "auth_type": getattr(gateway, "auth_type", None),
3206 "auth_value": gateway.auth_value,
3207 "oauth_config": gateway.oauth_config if hasattr(gateway, "oauth_config") else None,
3208 }
3209 gateway_data_list.append(gw_data)
3211 # For OAuth authorization_code flow, we need to fetch tokens while session is open
3212 # First-Party
3213 from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel
3215 for gw_data in gateway_data_list:
3216 if gw_data["auth_type"] == "oauth" and gw_data["oauth_config"]:
3217 grant_type = gw_data["oauth_config"].get("grant_type", "client_credentials")
3218 if grant_type == "authorization_code" and app_user_email:
3219 try:
3220 token_storage = TokenStorageService(db)
3221 access_token = await token_storage.get_user_token(str(gw_data["id"]), app_user_email)
3222 gw_data["_oauth_token"] = access_token
3223 except Exception as e:
3224 logger.warning(f"Failed to get OAuth token for gateway {gw_data['name']}: {e}")
3225 gw_data["_oauth_token"] = None
3227 # ═══════════════════════════════════════════════════════════════════════════
3228 # CRITICAL: Release DB connection back to pool BEFORE making HTTP calls
3229 # This prevents connection pool exhaustion during slow upstream requests.
3230 # ═══════════════════════════════════════════════════════════════════════════
3231 db.commit() # End read-only transaction cleanly (commit not rollback to avoid inflating rollback stats)
3232 db.close()
3234 errors: List[str] = []
3236 # ═══════════════════════════════════════════════════════════════════════════
3237 # PHASE 2: Make HTTP calls (no DB connection held)
3238 # ═══════════════════════════════════════════════════════════════════════════
3239 for gw_data in gateway_data_list:
3240 try:
3241 # Handle OAuth authentication for the specific gateway
3242 headers: Dict[str, str] = {}
3244 if gw_data["auth_type"] == "oauth" and gw_data["oauth_config"]:
3245 try:
3246 grant_type = gw_data["oauth_config"].get("grant_type", "client_credentials")
3248 if grant_type == "client_credentials":
3249 # Use OAuth manager to get access token for Client Credentials flow
3250 access_token = await self.oauth_manager.get_access_token(gw_data["oauth_config"])
3251 headers = {"Authorization": f"Bearer {access_token}"}
3252 elif grant_type == "authorization_code": 3252 ↛ 3279line 3252 didn't jump to line 3279 because the condition on line 3252 was always true
3253 # For Authorization Code flow, use pre-fetched token
3254 if not app_user_email:
3255 logger.warning(f"Skipping OAuth authorization code gateway {gw_data['name']} - user-specific tokens required but no user email provided")
3256 continue
3258 access_token = gw_data.get("_oauth_token")
3259 if access_token: 3259 ↛ 3262line 3259 didn't jump to line 3262 because the condition on line 3259 was always true
3260 headers = {"Authorization": f"Bearer {access_token}"}
3261 else:
3262 logger.warning(f"No valid OAuth token for user {app_user_email} and gateway {gw_data['name']}")
3263 continue
3264 except Exception as oauth_error:
3265 logger.warning(f"Failed to obtain OAuth token for gateway {gw_data['name']}: {oauth_error}")
3266 errors.append(f"Gateway {gw_data['name']}: OAuth error - {str(oauth_error)}")
3267 continue
3268 else:
3269 # Handle non-OAuth authentication
3270 auth_data = gw_data["auth_value"] or {}
3271 if isinstance(auth_data, str):
3272 headers = decode_auth(auth_data)
3273 elif isinstance(auth_data, dict): 3273 ↛ 3276line 3273 didn't jump to line 3276 because the condition on line 3273 was always true
3274 headers = {str(k): str(v) for k, v in auth_data.items()}
3275 else:
3276 headers = {}
3278 # Build RPC request
3279 request: Dict[str, Any] = {"jsonrpc": "2.0", "id": 1, "method": method}
3280 if params: 3280 ↛ 3281line 3280 didn't jump to line 3281 because the condition on line 3280 was never true
3281 request["params"] = params
3283 # Forward request with proper authentication headers
3284 response = await self._http_client.post(urljoin(gw_data["url"], "/rpc"), json=request, headers=headers)
3285 response.raise_for_status()
3286 result = response.json()
3288 # Check for RPC errors
3289 if "error" in result:
3290 errors.append(f"Gateway {gw_data['name']}: {result['error'].get('message', 'Unknown RPC error')}")
3291 continue
3293 # ═══════════════════════════════════════════════════════════════════════════
3294 # PHASE 3: Update last_seen using fresh DB session
3295 # ═══════════════════════════════════════════════════════════════════════════
3296 try:
3297 with fresh_db_session() as update_db:
3298 db_gateway = update_db.execute(select(DbGateway).where(DbGateway.id == gw_data["id"])).scalar_one_or_none()
3299 if db_gateway: 3299 ↛ 3306line 3299 didn't jump to line 3306
3300 db_gateway.last_seen = datetime.now(timezone.utc)
3301 update_db.commit()
3302 except Exception as update_error:
3303 logger.warning(f"Failed to update last_seen for gateway {gw_data['name']}: {update_error}")
3305 # Success - return the result
3306 logger.info(f"Successfully forwarded request to gateway {gw_data['name']}")
3307 return result.get("result")
3309 except Exception as e:
3310 error_msg = f"Gateway {gw_data['name']}: {str(e)}"
3311 errors.append(error_msg)
3312 logger.warning(f"Failed to forward request to gateway {gw_data['name']}: {e}")
3313 continue
3315 # If we get here, all gateways failed
3316 error_summary = "; ".join(errors)
3317 raise GatewayConnectionError(f"All gateways failed to handle request '{method}': {error_summary}")
3319 async def _handle_gateway_failure(self, gateway: DbGateway) -> None:
3320 """Tracks and handles gateway failures during health checks.
3321 If the failure count exceeds the threshold, the gateway is deactivated.
3323 Args:
3324 gateway: The gateway object that failed its health check.
3326 Returns:
3327 None
3329 Examples:
3330 >>> from mcpgateway.services.gateway_service import GatewayService
3331 >>> service = GatewayService()
3332 >>> gateway = type('Gateway', (), {
3333 ... 'id': 'gw1', 'name': 'test_gw', 'enabled': True, 'reachable': True
3334 ... })()
3335 >>> service._gateway_failure_counts = {}
3336 >>> import asyncio
3337 >>> # Test failure counting
3338 >>> asyncio.run(service._handle_gateway_failure(gateway)) # doctest: +ELLIPSIS
3339 >>> service._gateway_failure_counts['gw1'] >= 1
3340 True
3342 >>> # Test disabled gateway (no action)
3343 >>> gateway.enabled = False
3344 >>> old_count = service._gateway_failure_counts.get('gw1', 0)
3345 >>> asyncio.run(service._handle_gateway_failure(gateway)) # doctest: +ELLIPSIS
3346 >>> service._gateway_failure_counts.get('gw1', 0) == old_count
3347 True
3348 """
3349 if GW_FAILURE_THRESHOLD == -1:
3350 return # Gateway failure action disabled
3352 if not gateway.enabled:
3353 return # No action needed for inactive gateways
3355 if not gateway.reachable:
3356 return # No action needed for unreachable gateways
3358 count = self._gateway_failure_counts.get(gateway.id, 0) + 1
3359 self._gateway_failure_counts[gateway.id] = count
3361 logger.warning(f"Gateway {gateway.name} failed health check {count} time(s).")
3363 if count >= GW_FAILURE_THRESHOLD: 3363 ↛ 3364line 3363 didn't jump to line 3364 because the condition on line 3363 was never true
3364 logger.error(f"Gateway {gateway.name} failed {GW_FAILURE_THRESHOLD} times. Deactivating...")
3365 with cast(Any, SessionLocal)() as db:
3366 await self.set_gateway_state(db, gateway.id, activate=True, reachable=False, only_update_reachable=True)
3367 self._gateway_failure_counts[gateway.id] = 0 # Reset after deactivation
3369 async def check_health_of_gateways(self, gateways: List[DbGateway], user_email: Optional[str] = None) -> bool:
3370 """Check health of a batch of gateways.
3372 Performs an asynchronous health-check for each gateway in `gateways` using
3373 an Async HTTP client. The function handles different authentication
3374 modes (OAuth client_credentials and authorization_code, and non-OAuth
3375 auth headers). When a gateway uses the authorization_code flow, the
3376 optional `user_email` is used to look up stored user tokens with
3377 fresh_db_session(). On individual failures the service will record the
3378 failure and call internal failure handling which may mark a gateway
3379 unreachable or deactivate it after repeated failures. If a previously
3380 unreachable gateway becomes healthy again the service will attempt to
3381 update its reachable status.
3383 NOTE: This method intentionally does NOT take a db parameter.
3384 DB access uses fresh_db_session() only when needed, avoiding holding
3385 connections during HTTP calls to MCP servers.
3387 Args:
3388 gateways: List of DbGateway objects to check.
3389 user_email: Optional MCP gateway user email used to retrieve
3390 stored OAuth tokens for gateways using the
3391 "authorization_code" grant type. If not provided, authorization
3392 code flows that require a user token will be treated as failed.
3394 Returns:
3395 bool: True when the health-check batch completes. This return
3396 value indicates completion of the checks, not that every gateway
3397 was healthy. Individual gateway failures are handled internally
3398 (via _handle_gateway_failure and status updates).
3400 Examples:
3401 >>> from mcpgateway.services.gateway_service import GatewayService
3402 >>> from unittest.mock import MagicMock
3403 >>> service = GatewayService()
3404 >>> gateways = [MagicMock()]
3405 >>> gateways[0].ca_certificate = None
3406 >>> import asyncio
3407 >>> result = asyncio.run(service.check_health_of_gateways(gateways))
3408 >>> isinstance(result, bool)
3409 True
3411 >>> # Test empty gateway list
3412 >>> empty_result = asyncio.run(service.check_health_of_gateways([]))
3413 >>> empty_result
3414 True
3416 >>> # Test multiple gateways (basic smoke)
3417 >>> multiple_gateways = [MagicMock(), MagicMock(), MagicMock()]
3418 >>> for i, gw in enumerate(multiple_gateways):
3419 ... gw.name = f"gateway_{i}"
3420 ... gw.url = f"http://gateway{i}.example.com"
3421 ... gw.transport = "SSE"
3422 ... gw.enabled = True
3423 ... gw.reachable = True
3424 ... gw.auth_value = {}
3425 ... gw.ca_certificate = None
3426 >>> multi_result = asyncio.run(service.check_health_of_gateways(multiple_gateways))
3427 >>> isinstance(multi_result, bool)
3428 True
3429 """
3430 start_time = time.monotonic()
3431 concurrency_limit = min(settings.max_concurrent_health_checks, max(10, os.cpu_count() * 5)) # adaptive concurrency
3432 semaphore = asyncio.Semaphore(concurrency_limit)
3434 async def limited_check(gateway: DbGateway):
3435 """
3436 Checks the health of a single gateway while respecting a concurrency limit.
3438 This function checks the health of the given database gateway, ensuring that
3439 the number of concurrent checks does not exceed a predefined limit. The check
3440 is performed asynchronously and uses a semaphore to manage concurrency.
3442 Args:
3443 gateway (DbGateway): The database gateway whose health is to be checked.
3445 Raises:
3446 Any exceptions raised during the health check will be propagated to the caller.
3447 """
3448 async with semaphore:
3449 try:
3450 await asyncio.wait_for(
3451 self._check_single_gateway_health(gateway, user_email),
3452 timeout=settings.gateway_health_check_timeout,
3453 )
3454 except asyncio.TimeoutError:
3455 logger.warning(f"Gateway {getattr(gateway, 'name', 'unknown')} health check timed out after {settings.gateway_health_check_timeout}s")
3456 # Treat timeout as a failed health check
3457 await self._handle_gateway_failure(gateway)
3459 # Create trace span for health check batch
3460 with create_span("gateway.health_check_batch", {"gateway.count": len(gateways), "check.type": "health"}) as batch_span:
3461 # Chunk processing to avoid overload
3462 if not gateways:
3463 return True
3464 chunk_size = concurrency_limit
3465 for i in range(0, len(gateways), chunk_size):
3466 # batch will be a sublist of gateways from index i to i + chunk_size
3467 batch = gateways[i : i + chunk_size]
3469 # Each task is a health check for a gateway in the batch, excluding those with auth_type == "one_time_auth"
3470 tasks = [limited_check(gw) for gw in batch if gw.auth_type != "one_time_auth"]
3472 # Execute all health checks concurrently
3473 await asyncio.gather(*tasks, return_exceptions=True)
3474 await asyncio.sleep(0.05) # small pause prevents network saturation
3476 elapsed = time.monotonic() - start_time
3478 if batch_span:
3479 batch_span.set_attribute("check.duration_ms", int(elapsed * 1000))
3480 batch_span.set_attribute("check.completed", True)
3482 logger.debug(f"Health check batch completed for {len(gateways)} gateways in {elapsed:.2f}s")
3484 return True
3486 async def _check_single_gateway_health(self, gateway: DbGateway, user_email: Optional[str] = None) -> None:
3487 """Check health of a single gateway.
3489 NOTE: This method intentionally does NOT take a db parameter.
3490 DB access uses fresh_db_session() only when needed, avoiding holding
3491 connections during HTTP calls to MCP servers.
3493 Args:
3494 gateway: Gateway to check (may be detached from session)
3495 user_email: Optional user email for OAuth token lookup
3496 """
3497 # Extract gateway data upfront (gateway may be detached from session)
3498 gateway_id = gateway.id
3499 gateway_name = gateway.name
3500 gateway_url = gateway.url
3501 gateway_transport = gateway.transport
3502 gateway_enabled = gateway.enabled
3503 gateway_reachable = gateway.reachable
3504 gateway_ca_certificate = gateway.ca_certificate
3505 gateway_ca_certificate_sig = gateway.ca_certificate_sig
3506 gateway_auth_type = gateway.auth_type
3507 gateway_oauth_config = gateway.oauth_config
3508 gateway_auth_value = gateway.auth_value
3509 gateway_auth_query_params = gateway.auth_query_params
3511 # Handle query_param auth - decrypt and apply to URL for health check
3512 auth_query_params_decrypted: Optional[Dict[str, str]] = None
3513 if gateway_auth_type == "query_param" and gateway_auth_query_params:
3514 auth_query_params_decrypted = {}
3515 for param_key, encrypted_value in gateway_auth_query_params.items():
3516 if encrypted_value: 3516 ↛ 3515line 3516 didn't jump to line 3515 because the condition on line 3516 was always true
3517 try:
3518 decrypted = decode_auth(encrypted_value)
3519 auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
3520 except Exception:
3521 logger.debug(f"Failed to decrypt query param '{param_key}' for health check")
3522 if auth_query_params_decrypted: 3522 ↛ 3526line 3522 didn't jump to line 3526 because the condition on line 3522 was always true
3523 gateway_url = apply_query_param_auth(gateway_url, auth_query_params_decrypted)
3525 # Sanitize URL for logging/telemetry (redacts sensitive query params)
3526 gateway_url_sanitized = sanitize_url_for_logging(gateway_url, auth_query_params_decrypted)
3528 # Create span for individual gateway health check
3529 with create_span(
3530 "gateway.health_check",
3531 {
3532 "gateway.name": gateway_name,
3533 "gateway.id": str(gateway_id),
3534 "gateway.url": gateway_url_sanitized,
3535 "gateway.transport": gateway_transport,
3536 "gateway.enabled": gateway_enabled,
3537 "http.method": "GET",
3538 "http.url": gateway_url_sanitized,
3539 },
3540 ) as span:
3541 valid = False
3542 if gateway_ca_certificate:
3543 if settings.enable_ed25519_signing: 3543 ↛ 3547line 3543 didn't jump to line 3547 because the condition on line 3543 was always true
3544 public_key_pem = settings.ed25519_public_key
3545 valid = validate_signature(gateway_ca_certificate.encode(), gateway_ca_certificate_sig, public_key_pem)
3546 else:
3547 valid = True
3548 if valid:
3549 ssl_context = self.create_ssl_context(gateway_ca_certificate)
3550 else:
3551 ssl_context = None
3553 def get_httpx_client_factory(
3554 headers: dict[str, str] | None = None,
3555 timeout: httpx.Timeout | None = None,
3556 auth: httpx.Auth | None = None,
3557 ) -> httpx.AsyncClient:
3558 """Factory function to create httpx.AsyncClient with optional CA certificate.
3560 Args:
3561 headers: Optional headers for the client
3562 timeout: Optional timeout for the client
3563 auth: Optional auth for the client
3565 Returns:
3566 httpx.AsyncClient: Configured HTTPX async client
3567 """
3568 return httpx.AsyncClient(
3569 verify=ssl_context if ssl_context else get_default_verify(),
3570 follow_redirects=True,
3571 headers=headers,
3572 timeout=timeout if timeout else get_http_timeout(),
3573 auth=auth,
3574 limits=httpx.Limits(
3575 max_connections=settings.httpx_max_connections,
3576 max_keepalive_connections=settings.httpx_max_keepalive_connections,
3577 keepalive_expiry=settings.httpx_keepalive_expiry,
3578 ),
3579 )
3581 # Use isolated client for gateway health checks (each gateway may have custom CA cert)
3582 # Use admin timeout for health checks (fail fast, don't wait 120s for slow upstreams)
3583 # Pass ssl_context if present, otherwise let get_isolated_http_client use skip_ssl_verify setting
3584 async with get_isolated_http_client(timeout=settings.httpx_admin_read_timeout, verify=ssl_context) as client:
3585 logger.debug(f"Checking health of gateway: {gateway_name} ({gateway_url_sanitized})")
3586 try:
3587 # Handle different authentication types
3588 headers = {}
3590 if gateway_auth_type == "oauth" and gateway_oauth_config:
3591 grant_type = gateway_oauth_config.get("grant_type", "client_credentials")
3593 if grant_type == "authorization_code":
3594 # For Authorization Code flow, try to get stored tokens
3595 try:
3596 # First-Party
3597 from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel
3599 # Use fresh session for OAuth token lookup
3600 with fresh_db_session() as token_db:
3601 token_storage = TokenStorageService(token_db)
3603 # Get user-specific OAuth token
3604 if not user_email:
3605 if span: 3605 ↛ 3608line 3605 didn't jump to line 3608 because the condition on line 3605 was always true
3606 span.set_attribute("health.status", "unhealthy")
3607 span.set_attribute("error.message", "User email required for OAuth token")
3608 await self._handle_gateway_failure(gateway)
3609 return
3611 access_token = await token_storage.get_user_token(gateway_id, user_email)
3613 if access_token:
3614 headers["Authorization"] = f"Bearer {access_token}"
3615 else:
3616 if span: 3616 ↛ 3619line 3616 didn't jump to line 3619 because the condition on line 3616 was always true
3617 span.set_attribute("health.status", "unhealthy")
3618 span.set_attribute("error.message", "No valid OAuth token for user")
3619 await self._handle_gateway_failure(gateway)
3620 return
3621 except Exception as e:
3622 logger.error(f"Failed to obtain stored OAuth token for gateway {gateway_name}: {e}")
3623 if span: 3623 ↛ 3626line 3623 didn't jump to line 3626 because the condition on line 3623 was always true
3624 span.set_attribute("health.status", "unhealthy")
3625 span.set_attribute("error.message", "Failed to obtain stored OAuth token")
3626 await self._handle_gateway_failure(gateway)
3627 return
3628 else:
3629 # For Client Credentials flow, get token directly
3630 try:
3631 access_token = await self.oauth_manager.get_access_token(gateway_oauth_config)
3632 headers["Authorization"] = f"Bearer {access_token}"
3633 except Exception as e:
3634 if span: 3634 ↛ 3637line 3634 didn't jump to line 3637 because the condition on line 3634 was always true
3635 span.set_attribute("health.status", "unhealthy")
3636 span.set_attribute("error.message", str(e))
3637 await self._handle_gateway_failure(gateway)
3638 return
3639 else:
3640 # Handle non-OAuth authentication (existing logic)
3641 auth_data = gateway_auth_value or {}
3642 if isinstance(auth_data, str):
3643 headers = decode_auth(auth_data)
3644 elif isinstance(auth_data, dict):
3645 headers = {str(k): str(v) for k, v in auth_data.items()}
3646 else:
3647 headers = {}
3649 # Perform the GET and raise on 4xx/5xx
3650 if (gateway_transport).lower() == "sse":
3651 timeout = httpx.Timeout(settings.health_check_timeout)
3652 async with client.stream("GET", gateway_url, headers=headers, timeout=timeout) as response:
3653 # This will raise immediately if status is 4xx/5xx
3654 response.raise_for_status()
3655 if span: 3655 ↛ 3698line 3655 didn't jump to line 3698
3656 span.set_attribute("http.status_code", response.status_code)
3657 elif (gateway_transport).lower() == "streamablehttp":
3658 # Use session pool if enabled for faster health checks
3659 use_pool = False
3660 pool = None
3661 if settings.mcp_session_pool_enabled: 3661 ↛ 3669line 3661 didn't jump to line 3669 because the condition on line 3661 was always true
3662 try:
3663 pool = get_mcp_session_pool()
3664 use_pool = True
3665 except RuntimeError:
3666 # Pool not initialized (e.g., in tests), fall back to per-call sessions
3667 pass
3669 if use_pool and pool is not None:
3670 # Health checks are system operations, not user-driven.
3671 # Use system identity to isolate from user sessions.
3672 async with pool.session(
3673 url=gateway_url,
3674 headers=headers,
3675 transport_type=TransportType.STREAMABLE_HTTP,
3676 httpx_client_factory=get_httpx_client_factory,
3677 user_identity="_system_health_check",
3678 gateway_id=gateway_id,
3679 ) as pooled:
3680 # Optional explicit RPC verification (off by default for performance).
3681 # Pool's internal staleness check handles health via _validate_session.
3682 if settings.mcp_session_pool_explicit_health_rpc: 3682 ↛ 3698line 3682 didn't jump to line 3698
3683 await asyncio.wait_for(
3684 pooled.session.list_tools(),
3685 timeout=settings.health_check_timeout,
3686 )
3687 else:
3688 async with streamablehttp_client(url=gateway_url, headers=headers, timeout=settings.health_check_timeout, httpx_client_factory=get_httpx_client_factory) as (
3689 read_stream,
3690 write_stream,
3691 _get_session_id,
3692 ):
3693 async with ClientSession(read_stream, write_stream) as session:
3694 # Initialize the session
3695 response = await session.initialize()
3697 # Reactivate gateway if it was previously inactive and health check passed now
3698 if gateway_enabled and not gateway_reachable:
3699 logger.info(f"Reactivating gateway: {gateway_name}, as it is healthy now")
3700 with cast(Any, SessionLocal)() as status_db:
3701 await self.set_gateway_state(status_db, gateway_id, activate=True, reachable=True, only_update_reachable=True)
3703 # Update last_seen with fresh session (gateway object is detached)
3704 try:
3705 with fresh_db_session() as update_db:
3706 db_gateway = update_db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
3707 if db_gateway: 3707 ↛ 3714line 3707 didn't jump to line 3714
3708 db_gateway.last_seen = datetime.now(timezone.utc)
3709 update_db.commit()
3710 except Exception as update_error:
3711 logger.warning(f"Failed to update last_seen for gateway {gateway_name}: {update_error}")
3713 # Auto-refresh tools/resources/prompts if enabled
3714 if settings.auto_refresh_servers:
3715 try:
3716 # Throttling: Check if refresh is needed based on last_refresh_at
3717 refresh_needed = True
3718 if gateway.last_refresh_at: 3718 ↛ 3736line 3718 didn't jump to line 3736 because the condition on line 3718 was always true
3719 # Default to config value if configured interval is missing
3721 last_refresh = gateway.last_refresh_at
3722 if last_refresh.tzinfo is None: 3722 ↛ 3723line 3722 didn't jump to line 3723 because the condition on line 3722 was never true
3723 last_refresh = last_refresh.replace(tzinfo=timezone.utc)
3725 # Use per-gateway interval if set, otherwise fall back to global default
3726 refresh_interval = getattr(settings, "gateway_auto_refresh_interval", 300)
3727 if gateway.refresh_interval_seconds is not None: 3727 ↛ 3730line 3727 didn't jump to line 3730 because the condition on line 3727 was always true
3728 refresh_interval = gateway.refresh_interval_seconds
3730 time_since_refresh = (datetime.now(timezone.utc) - last_refresh).total_seconds()
3732 if time_since_refresh < refresh_interval:
3733 refresh_needed = False
3734 logger.debug(f"Skipping auto-refresh for {gateway_name}: last refreshed {int(time_since_refresh)}s ago")
3736 if refresh_needed:
3737 # Locking: Try to acquire lock to avoid conflict with manual refresh
3738 lock = self._get_refresh_lock(gateway_id)
3739 if not lock.locked():
3740 # Acquire lock to prevent concurrent manual refresh
3741 async with lock:
3742 await self._refresh_gateway_tools_resources_prompts(
3743 gateway_id=gateway_id,
3744 _user_email=user_email,
3745 created_via="health_check",
3746 pre_auth_headers=headers if headers else None,
3747 gateway=gateway,
3748 )
3749 else:
3750 logger.debug(f"Skipping auto-refresh for {gateway_name}: lock held (likely manual refresh in progress)")
3751 except Exception as refresh_error:
3752 logger.warning(f"Failed to refresh tools for gateway {gateway_name}: {refresh_error}")
3754 if span:
3755 span.set_attribute("health.status", "healthy")
3756 span.set_attribute("success", True)
3758 except Exception as e:
3759 if span:
3760 span.set_attribute("health.status", "unhealthy")
3761 span.set_attribute("error.message", str(e))
3763 # Set the logger as debug as this check happens for each interval
3764 logger.debug(f"Health check failed for gateway {gateway_name}: {e}")
3765 await self._handle_gateway_failure(gateway)
3767 async def aggregate_capabilities(self, db: Session) -> Dict[str, Any]:
3768 """
3769 Aggregate capabilities across all gateways.
3771 Args:
3772 db: Database session
3774 Returns:
3775 Dictionary of aggregated capabilities
3777 Examples:
3778 >>> from mcpgateway.services.gateway_service import GatewayService
3779 >>> from unittest.mock import MagicMock
3780 >>> service = GatewayService()
3781 >>> db = MagicMock()
3782 >>> gateway_mock = MagicMock()
3783 >>> gateway_mock.capabilities = {"tools": {"listChanged": True}, "custom": {"feature": True}}
3784 >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway_mock]
3785 >>> import asyncio
3786 >>> result = asyncio.run(service.aggregate_capabilities(db))
3787 >>> isinstance(result, dict)
3788 True
3789 >>> 'prompts' in result
3790 True
3791 >>> 'resources' in result
3792 True
3793 >>> 'tools' in result
3794 True
3795 >>> 'logging' in result
3796 True
3797 >>> result['prompts']['listChanged']
3798 True
3799 >>> result['resources']['subscribe']
3800 True
3801 >>> result['resources']['listChanged']
3802 True
3803 >>> result['tools']['listChanged']
3804 True
3805 >>> isinstance(result['logging'], dict)
3806 True
3808 >>> # Test with no gateways
3809 >>> db.execute.return_value.scalars.return_value.all.return_value = []
3810 >>> empty_result = asyncio.run(service.aggregate_capabilities(db))
3811 >>> isinstance(empty_result, dict)
3812 True
3813 >>> 'tools' in empty_result
3814 True
3816 >>> # Test capability merging
3817 >>> gateway1 = MagicMock()
3818 >>> gateway1.capabilities = {"tools": {"feature1": True}}
3819 >>> gateway2 = MagicMock()
3820 >>> gateway2.capabilities = {"tools": {"feature2": True}}
3821 >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway1, gateway2]
3822 >>> merged_result = asyncio.run(service.aggregate_capabilities(db))
3823 >>> merged_result['tools']['listChanged'] # Default capability
3824 True
3825 """
3826 capabilities = {
3827 "prompts": {"listChanged": True},
3828 "resources": {"subscribe": True, "listChanged": True},
3829 "tools": {"listChanged": True},
3830 "logging": {},
3831 }
3833 # Get all active gateways
3834 gateways = db.execute(select(DbGateway).where(DbGateway.enabled)).scalars().all()
3836 # Combine capabilities
3837 for gateway in gateways:
3838 if gateway.capabilities:
3839 for key, value in gateway.capabilities.items():
3840 if key not in capabilities:
3841 capabilities[key] = value
3842 elif isinstance(value, dict): 3842 ↛ 3839line 3842 didn't jump to line 3839 because the condition on line 3842 was always true
3843 capabilities[key].update(value)
3845 return capabilities
3847 async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
3848 """Subscribe to gateway events.
3850 Creates a new event queue and subscribes to gateway events. Events are
3851 yielded as they are published. The subscription is automatically cleaned
3852 up when the generator is closed or goes out of scope.
3854 Yields:
3855 Dict[str, Any]: Gateway event messages with 'type', 'data', and 'timestamp' fields
3857 Examples:
3858 >>> service = GatewayService()
3859 >>> import asyncio
3860 >>> from unittest.mock import MagicMock
3861 >>> # Create a mock async generator for the event service
3862 >>> async def mock_event_gen():
3863 ... yield {"type": "test_event", "data": "payload"}
3864 >>>
3865 >>> # Mock the event service to return our generator
3866 >>> service._event_service = MagicMock()
3867 >>> service._event_service.subscribe_events.return_value = mock_event_gen()
3868 >>>
3869 >>> # Test the subscription
3870 >>> async def test_sub():
3871 ... async for event in service.subscribe_events():
3872 ... return event
3873 >>>
3874 >>> result = asyncio.run(test_sub())
3875 >>> result
3876 {'type': 'test_event', 'data': 'payload'}
3877 """
3878 async for event in self._event_service.subscribe_events():
3879 yield event
3881 async def _initialize_gateway(
3882 self,
3883 url: str,
3884 authentication: Optional[Dict[str, str]] = None,
3885 transport: str = "SSE",
3886 auth_type: Optional[str] = None,
3887 oauth_config: Optional[Dict[str, Any]] = None,
3888 ca_certificate: Optional[bytes] = None,
3889 pre_auth_headers: Optional[Dict[str, str]] = None,
3890 include_resources: bool = True,
3891 include_prompts: bool = True,
3892 auth_query_params: Optional[Dict[str, str]] = None,
3893 oauth_auto_fetch_tool_flag: Optional[bool] = False,
3894 ) -> tuple[Dict[str, Any], List[ToolCreate], List[ResourceCreate], List[PromptCreate]]:
3895 """Initialize connection to a gateway and retrieve its capabilities.
3897 Connects to an MCP gateway using the specified transport protocol,
3898 performs the MCP handshake, and retrieves capabilities, tools,
3899 resources, and prompts from the gateway.
3901 Args:
3902 url: Gateway URL to connect to
3903 authentication: Optional authentication headers for the connection
3904 transport: Transport protocol - "SSE" or "StreamableHTTP"
3905 auth_type: Authentication type - "basic", "bearer", "headers", "oauth", "query_param" or None
3906 oauth_config: OAuth configuration if auth_type is "oauth"
3907 ca_certificate: CA certificate for SSL verification
3908 pre_auth_headers: Pre-authenticated headers to skip OAuth token fetch (for reuse)
3909 include_resources: Whether to include resources in the fetch
3910 include_prompts: Whether to include prompts in the fetch
3911 auth_query_params: Query param names for URL sanitization in error logs (decrypted values)
3912 oauth_auto_fetch_tool_flag: Whether to skip the early return for OAuth Authorization Code flow.
3913 When False (default), auth_code gateways return empty lists immediately (for health checks).
3914 When True, attempts to connect even for auth_code gateways (for activation after user authorization).
3916 Returns:
3917 tuple[Dict[str, Any], List[ToolCreate], List[ResourceCreate], List[PromptCreate]]:
3918 Capabilities dictionary, list of ToolCreate objects, list of ResourceCreate objects, and list of PromptCreate objects
3920 Raises:
3921 GatewayConnectionError: If connection or initialization fails
3923 Examples:
3924 >>> service = GatewayService()
3925 >>> # Test parameter validation
3926 >>> import asyncio
3927 >>> from unittest.mock import AsyncMock
3928 >>> # Avoid opening a real SSE connection in doctests (it can leak anyio streams on failure paths)
3929 >>> service.connect_to_sse_server = AsyncMock(side_effect=GatewayConnectionError("boom"))
3930 >>> async def test_params():
3931 ... try:
3932 ... await service._initialize_gateway("hello//")
3933 ... except Exception as e:
3934 ... return isinstance(e, GatewayConnectionError) or "Failed" in str(e)
3936 >>> asyncio.run(test_params())
3937 True
3939 >>> # Test default parameters
3940 >>> hasattr(service, '_initialize_gateway')
3941 True
3942 >>> import inspect
3943 >>> sig = inspect.signature(service._initialize_gateway)
3944 >>> sig.parameters['transport'].default
3945 'SSE'
3946 >>> sig.parameters['authentication'].default is None
3947 True
3948 >>>
3949 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
3950 >>> asyncio.run(service._http_client.aclose())
3951 """
3952 try:
3953 if authentication is None:
3954 authentication = {}
3956 # Use pre-authenticated headers if provided (avoids duplicate OAuth token fetch)
3957 if pre_auth_headers:
3958 authentication = pre_auth_headers
3959 # Handle OAuth authentication
3960 elif auth_type == "oauth" and oauth_config:
3961 grant_type = oauth_config.get("grant_type", "client_credentials")
3963 if grant_type == "authorization_code":
3964 if not oauth_auto_fetch_tool_flag:
3965 # For Authorization Code flow during health checks, we can't initialize immediately
3966 # because we need user consent. Just store the configuration
3967 # and let the user complete the OAuth flow later.
3968 logger.info("""OAuth Authorization Code flow configured for gateway. User must complete authorization before gateway can be used.""")
3969 # Don't try to get access token here - it will be obtained during tool invocation
3970 authentication = {}
3972 # Skip MCP server connection for Authorization Code flow
3973 # Tools will be fetched after OAuth completion
3974 return {}, [], [], []
3975 # When flag is True (activation), skip token fetch but try to connect
3976 # This allows activation to proceed - actual auth happens during tool invocation
3977 logger.debug("OAuth Authorization Code gateway activation - skipping token fetch")
3978 elif grant_type == "client_credentials": 3978 ↛ 3988line 3978 didn't jump to line 3988 because the condition on line 3978 was always true
3979 # For Client Credentials flow, we can get the token immediately
3980 try:
3981 logger.debug("Obtaining OAuth access token for Client Credentials flow")
3982 access_token = await self.oauth_manager.get_access_token(oauth_config)
3983 authentication = {"Authorization": f"Bearer {access_token}"}
3984 except Exception as e:
3985 logger.error(f"Failed to obtain OAuth access token: {e}")
3986 raise GatewayConnectionError(f"OAuth authentication failed: {str(e)}")
3988 capabilities = {}
3989 tools = []
3990 resources = []
3991 prompts = []
3992 if auth_type in ("basic", "bearer", "headers") and isinstance(authentication, str):
3993 authentication = decode_auth(authentication)
3994 if transport.lower() == "sse":
3995 capabilities, tools, resources, prompts = await self.connect_to_sse_server(url, authentication, ca_certificate, include_prompts, include_resources, auth_query_params)
3996 elif transport.lower() == "streamablehttp":
3997 capabilities, tools, resources, prompts = await self.connect_to_streamablehttp_server(url, authentication, ca_certificate, include_prompts, include_resources, auth_query_params)
3999 return capabilities, tools, resources, prompts
4000 except Exception as e:
4002 # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup
4003 root_cause = e
4004 if isinstance(e, BaseExceptionGroup):
4005 while isinstance(root_cause, BaseExceptionGroup) and root_cause.exceptions:
4006 root_cause = root_cause.exceptions[0]
4007 sanitized_url = sanitize_url_for_logging(url, auth_query_params)
4008 sanitized_error = sanitize_exception_message(str(root_cause), auth_query_params)
4009 logger.error(f"Gateway initialization failed for {sanitized_url}: {sanitized_error}", exc_info=True)
4010 raise GatewayConnectionError(f"Failed to initialize gateway at {sanitized_url}: {sanitized_error}")
4012 def _get_gateways(self, include_inactive: bool = True) -> list[DbGateway]:
4013 """Sync function for database operations (runs in thread).
4015 Args:
4016 include_inactive: Whether to include inactive gateways
4018 Returns:
4019 List[DbGateway]: List of active gateways
4021 Examples:
4022 >>> from unittest.mock import patch, MagicMock
4023 >>> service = GatewayService()
4024 >>> with patch('mcpgateway.services.gateway_service.SessionLocal') as mock_session:
4025 ... mock_db = MagicMock()
4026 ... mock_session.return_value.__enter__.return_value = mock_db
4027 ... mock_db.execute.return_value.scalars.return_value.all.return_value = []
4028 ... result = service._get_gateways()
4029 ... isinstance(result, list)
4030 True
4032 >>> # Test include_inactive parameter handling
4033 >>> with patch('mcpgateway.services.gateway_service.SessionLocal') as mock_session:
4034 ... mock_db = MagicMock()
4035 ... mock_session.return_value.__enter__.return_value = mock_db
4036 ... mock_db.execute.return_value.scalars.return_value.all.return_value = []
4037 ... result_active_only = service._get_gateways(include_inactive=False)
4038 ... isinstance(result_active_only, list)
4039 True
4040 """
4041 with cast(Any, SessionLocal)() as db:
4042 if include_inactive:
4043 return db.execute(select(DbGateway)).scalars().all()
4044 # Only return active gateways
4045 return db.execute(select(DbGateway).where(DbGateway.enabled)).scalars().all()
4047 def get_first_gateway_by_url(self, db: Session, url: str, team_id: Optional[str] = None, include_inactive: bool = False) -> Optional[GatewayRead]:
4048 """Return the first DbGateway matching the given URL and optional team_id.
4050 This is a synchronous helper intended for use from request handlers where
4051 a simple DB lookup is needed. It normalizes the provided URL similar to
4052 how gateways are stored and matches by the `url` column. If team_id is
4053 provided, it restricts the search to that team.
4055 Args:
4056 db: Database session to use for the query
4057 url: Gateway base URL to match (will be normalized)
4058 team_id: Optional team id to restrict search
4059 include_inactive: Whether to include inactive gateways
4061 Returns:
4062 Optional[DbGateway]: First matching gateway or None
4063 """
4064 query = select(DbGateway).where(DbGateway.url == url)
4065 if not include_inactive: 4065 ↛ 4067line 4065 didn't jump to line 4067 because the condition on line 4065 was always true
4066 query = query.where(DbGateway.enabled)
4067 if team_id: 4067 ↛ 4068line 4067 didn't jump to line 4068 because the condition on line 4067 was never true
4068 query = query.where(DbGateway.team_id == team_id)
4069 result = db.execute(query).scalars().first()
4070 # Wrap the DB object in the GatewayRead schema for consistency with
4071 # other service methods. Return None if no match found.
4072 if result is None:
4073 return None
4074 return GatewayRead.model_validate(self._prepare_gateway_for_read(result)).masked()
4076 async def _run_leader_heartbeat(self) -> None:
4077 """Run leader heartbeat loop to keep leader key alive.
4079 This runs independently from health checks to ensure the leader key
4080 is refreshed frequently enough (every redis_leader_heartbeat_interval seconds)
4081 to prevent expiration during long-running health check operations.
4083 The loop exits if this instance loses leadership.
4084 """
4085 while True:
4086 try:
4087 await asyncio.sleep(self._leader_heartbeat_interval)
4089 if not self._redis_client:
4090 return
4092 # Check if we're still the leader
4093 current_leader = await self._redis_client.get(self._leader_key)
4094 if current_leader != self._instance_id:
4095 logger.info("Lost Redis leadership, stopping heartbeat")
4096 return
4098 # Refresh the leader key TTL
4099 await self._redis_client.expire(self._leader_key, self._leader_ttl)
4100 logger.debug(f"Leader heartbeat: refreshed TTL to {self._leader_ttl}s")
4102 except Exception as e:
4103 logger.warning(f"Leader heartbeat error: {e}")
4104 # Continue trying - the main health check loop will handle leadership loss
4106 async def _run_health_checks(self, user_email: str) -> None:
4107 """Run health checks periodically,
4108 Uses Redis or FileLock - for multiple workers.
4109 Uses simple health check for single worker mode.
4111 NOTE: This method intentionally does NOT take a db parameter.
4112 Health checks use fresh_db_session() only when DB access is needed,
4113 avoiding holding connections during HTTP calls to MCP servers.
4115 Args:
4116 user_email: Email of the user for OAuth token lookup
4118 Examples:
4119 >>> service = GatewayService()
4120 >>> service._health_check_interval = 0.1 # Short interval for testing
4121 >>> service._redis_client = None
4122 >>> import asyncio
4123 >>> # Test that method exists and is callable
4124 >>> callable(service._run_health_checks)
4125 True
4126 >>> # Test setup without actual execution (would run forever)
4127 >>> hasattr(service, '_health_check_interval')
4128 True
4129 >>> service._health_check_interval == 0.1
4130 True
4131 """
4133 while True:
4134 try:
4135 if self._redis_client and settings.cache_type == "redis": 4135 ↛ 4138line 4135 didn't jump to line 4138 because the condition on line 4135 was never true
4136 # Redis-based leader check (async, decode_responses=True returns strings)
4137 # Note: Leader key TTL refresh is handled by _run_leader_heartbeat task
4138 current_leader = await self._redis_client.get(self._leader_key)
4139 if current_leader != self._instance_id:
4140 return
4142 # Run health checks
4143 gateways = await asyncio.to_thread(self._get_gateways)
4144 if gateways:
4145 await self.check_health_of_gateways(gateways, user_email)
4147 await asyncio.sleep(self._health_check_interval)
4149 elif settings.cache_type == "none":
4150 try:
4151 # For single worker mode, run health checks directly
4152 gateways = await asyncio.to_thread(self._get_gateways)
4153 if gateways: 4153 ↛ 4158line 4153 didn't jump to line 4158 because the condition on line 4153 was always true
4154 await self.check_health_of_gateways(gateways, user_email)
4155 except Exception as e:
4156 logger.error(f"Health check run failed: {str(e)}")
4158 await asyncio.sleep(self._health_check_interval)
4160 else:
4161 # FileLock-based leader fallback
4162 try:
4163 self._file_lock.acquire(timeout=0)
4164 logger.info("File lock acquired. Running health checks.")
4166 while True:
4167 gateways = await asyncio.to_thread(self._get_gateways)
4168 if gateways: 4168 ↛ 4169line 4168 didn't jump to line 4169 because the condition on line 4168 was never true
4169 await self.check_health_of_gateways(gateways, user_email)
4170 await asyncio.sleep(self._health_check_interval)
4172 except Timeout:
4173 logger.debug("File lock already held. Retrying later.")
4174 await asyncio.sleep(self._health_check_interval)
4176 except Exception as e:
4177 logger.error(f"FileLock health check failed: {str(e)}")
4179 finally:
4180 if self._file_lock.is_locked: 4180 ↛ 4133line 4180 didn't jump to line 4133 because the condition on line 4180 was always true
4181 try:
4182 self._file_lock.release()
4183 logger.info("Released file lock.")
4184 except Exception as e:
4185 logger.warning(f"Failed to release file lock: {str(e)}")
4187 except Exception as e:
4188 logger.error(f"Unexpected error in health check loop: {str(e)}")
4189 await asyncio.sleep(self._health_check_interval)
4191 def _get_auth_headers(self) -> Dict[str, str]:
4192 """Get default headers for gateway requests (no authentication).
4194 SECURITY: This method intentionally does NOT include authentication credentials.
4195 Each gateway should have its own auth_value configured. Never send this gateway's
4196 admin credentials to remote servers.
4198 Returns:
4199 dict: Default headers without authentication
4201 Examples:
4202 >>> service = GatewayService()
4203 >>> headers = service._get_auth_headers()
4204 >>> isinstance(headers, dict)
4205 True
4206 >>> 'Content-Type' in headers
4207 True
4208 >>> headers['Content-Type']
4209 'application/json'
4210 >>> 'Authorization' not in headers # No credentials leaked
4211 True
4212 """
4213 return {"Content-Type": "application/json"}
4215 async def _notify_gateway_added(self, gateway: DbGateway) -> None:
4216 """Notify subscribers of gateway addition.
4218 Args:
4219 gateway: Gateway to add
4220 """
4221 event = {
4222 "type": "gateway_added",
4223 "data": {
4224 "id": gateway.id,
4225 "name": gateway.name,
4226 "url": gateway.url,
4227 "description": gateway.description,
4228 "enabled": gateway.enabled,
4229 },
4230 "timestamp": datetime.now(timezone.utc).isoformat(),
4231 }
4232 await self._publish_event(event)
4234 async def _notify_gateway_activated(self, gateway: DbGateway) -> None:
4235 """Notify subscribers of gateway activation.
4237 Args:
4238 gateway: Gateway to activate
4239 """
4240 event = {
4241 "type": "gateway_activated",
4242 "data": {
4243 "id": gateway.id,
4244 "name": gateway.name,
4245 "url": gateway.url,
4246 "enabled": gateway.enabled,
4247 "reachable": gateway.reachable,
4248 },
4249 "timestamp": datetime.now(timezone.utc).isoformat(),
4250 }
4251 await self._publish_event(event)
4253 async def _notify_gateway_deactivated(self, gateway: DbGateway) -> None:
4254 """Notify subscribers of gateway deactivation.
4256 Args:
4257 gateway: Gateway database object
4258 """
4259 event = {
4260 "type": "gateway_deactivated",
4261 "data": {
4262 "id": gateway.id,
4263 "name": gateway.name,
4264 "url": gateway.url,
4265 "enabled": gateway.enabled,
4266 "reachable": gateway.reachable,
4267 },
4268 "timestamp": datetime.now(timezone.utc).isoformat(),
4269 }
4270 await self._publish_event(event)
4272 async def _notify_gateway_offline(self, gateway: DbGateway) -> None:
4273 """
4274 Notify subscribers that gateway is offline (Enabled but Unreachable).
4276 Args:
4277 gateway: Gateway database object
4278 """
4279 event = {
4280 "type": "gateway_offline",
4281 "data": {
4282 "id": gateway.id,
4283 "name": gateway.name,
4284 "url": gateway.url,
4285 "enabled": True,
4286 "reachable": False,
4287 },
4288 "timestamp": datetime.now(timezone.utc).isoformat(),
4289 }
4290 await self._publish_event(event)
4292 async def _notify_gateway_deleted(self, gateway_info: Dict[str, Any]) -> None:
4293 """Notify subscribers of gateway deletion.
4295 Args:
4296 gateway_info: Dict containing information about gateway to delete
4297 """
4298 event = {
4299 "type": "gateway_deleted",
4300 "data": gateway_info,
4301 "timestamp": datetime.now(timezone.utc).isoformat(),
4302 }
4303 await self._publish_event(event)
4305 async def _notify_gateway_removed(self, gateway: DbGateway) -> None:
4306 """Notify subscribers of gateway removal (deactivation).
4308 Args:
4309 gateway: Gateway to remove
4310 """
4311 event = {
4312 "type": "gateway_removed",
4313 "data": {"id": gateway.id, "name": gateway.name, "enabled": gateway.enabled},
4314 "timestamp": datetime.now(timezone.utc).isoformat(),
4315 }
4316 await self._publish_event(event)
4318 def convert_gateway_to_read(self, gateway: DbGateway) -> GatewayRead:
4319 """Convert a DbGateway instance to a GatewayRead Pydantic model.
4321 Args:
4322 gateway: Gateway database object
4324 Returns:
4325 GatewayRead: Pydantic model instance
4326 """
4327 gateway_dict = gateway.__dict__.copy()
4328 gateway_dict.pop("_sa_instance_state", None)
4330 # Ensure auth_value is properly encoded
4331 if isinstance(gateway.auth_value, dict):
4332 gateway_dict["auth_value"] = encode_auth(gateway.auth_value)
4334 if gateway.tags:
4335 # Check tags are list of strings or list of Dict[str, str]
4336 if isinstance(gateway.tags[0], str):
4337 # Convert tags from List[str] to List[Dict[str, str]] for GatewayRead
4338 gateway_dict["tags"] = validate_tags_field(gateway.tags)
4339 else:
4340 gateway_dict["tags"] = gateway.tags
4341 else:
4342 gateway_dict["tags"] = []
4344 # Include metadata fields
4345 gateway_dict["created_by"] = getattr(gateway, "created_by", None)
4346 gateway_dict["modified_by"] = getattr(gateway, "modified_by", None)
4347 gateway_dict["created_at"] = getattr(gateway, "created_at", None)
4348 gateway_dict["updated_at"] = getattr(gateway, "updated_at", None)
4349 gateway_dict["version"] = getattr(gateway, "version", None)
4350 gateway_dict["team"] = getattr(gateway, "team", None)
4352 return GatewayRead.model_validate(gateway_dict).masked()
4354 def _prepare_gateway_for_read(self, gateway: DbGateway) -> DbGateway:
4355 """DEPRECATED: Use convert_gateway_to_read instead.
4357 Prepare a gateway object for GatewayRead validation.
4359 Ensures auth_value is in the correct format (encoded string) for the schema.
4360 Converts legacy List[str] tags to List[Dict[str, str]] format for GatewayRead schema.
4362 Args:
4363 gateway: Gateway database object
4365 Returns:
4366 Gateway object with properly formatted auth_value and tags
4367 """
4368 # If auth_value is a dict, encode it to string for GatewayRead schema
4369 if isinstance(gateway.auth_value, dict):
4370 gateway.auth_value = encode_auth(gateway.auth_value)
4372 # Handle legacy List[str] tags - convert to List[Dict[str, str]] for GatewayRead schema
4373 if gateway.tags:
4374 if isinstance(gateway.tags[0], str):
4375 # Legacy format: convert to dict format
4376 gateway.tags = validate_tags_field(gateway.tags)
4378 return gateway
4380 def _create_db_tool(
4381 self,
4382 tool: ToolCreate,
4383 gateway: DbGateway,
4384 created_by: Optional[str] = None,
4385 created_from_ip: Optional[str] = None,
4386 created_via: Optional[str] = None,
4387 created_user_agent: Optional[str] = None,
4388 ) -> DbTool:
4389 """Create a DbTool with consistent federation metadata across all scenarios.
4391 Args:
4392 tool: Tool creation schema
4393 gateway: Gateway database object
4394 created_by: Username who created/updated this tool
4395 created_from_ip: IP address of creator
4396 created_via: Creation method (ui, api, federation, rediscovery)
4397 created_user_agent: User agent of creation request
4399 Returns:
4400 DbTool: Consistently configured database tool object
4401 """
4402 return DbTool(
4403 original_name=tool.name,
4404 custom_name=tool.name,
4405 custom_name_slug=slugify(tool.name),
4406 display_name=generate_display_name(tool.name),
4407 url=gateway.url,
4408 description=tool.description,
4409 integration_type="MCP", # Gateway-discovered tools are MCP type
4410 request_type=tool.request_type,
4411 headers=tool.headers,
4412 input_schema=tool.input_schema,
4413 annotations=tool.annotations,
4414 jsonpath_filter=tool.jsonpath_filter,
4415 auth_type=gateway.auth_type,
4416 auth_value=encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value,
4417 # Federation metadata - consistent across all scenarios
4418 created_by=created_by or "system",
4419 created_from_ip=created_from_ip,
4420 created_via=created_via or "federation",
4421 created_user_agent=created_user_agent,
4422 federation_source=gateway.name,
4423 version=1,
4424 # Inherit team assignment and visibility from gateway
4425 team_id=gateway.team_id,
4426 owner_email=gateway.owner_email,
4427 visibility="public", # Federated tools should be public for discovery
4428 )
4430 def _update_or_create_tools(self, db: Session, tools: List[Any], gateway: DbGateway, created_via: str) -> List[DbTool]:
4431 """Helper to handle update-or-create logic for tools from MCP server.
4433 Args:
4434 db: Database session
4435 tools: List of tools from MCP server
4436 gateway: Gateway object
4437 created_via: String indicating creation source ("oauth", "update", etc.)
4439 Returns:
4440 List of new tools to be added to the database
4441 """
4442 if not tools:
4443 return []
4445 tools_to_add = []
4447 # Batch fetch all existing tools for this gateway
4448 tool_names = [tool.name for tool in tools if tool is not None]
4449 if not tool_names: 4449 ↛ 4450line 4449 didn't jump to line 4450 because the condition on line 4449 was never true
4450 return []
4452 existing_tools_query = select(DbTool).where(DbTool.gateway_id == gateway.id, DbTool.original_name.in_(tool_names))
4453 existing_tools = db.execute(existing_tools_query).scalars().all()
4454 existing_tools_map = {tool.original_name: tool for tool in existing_tools}
4456 for tool in tools:
4457 if tool is None:
4458 logger.warning("Skipping None tool in tools list")
4459 continue
4461 try:
4462 # Check if tool already exists for this gateway from the tools_map
4463 existing_tool = existing_tools_map.get(tool.name)
4464 if existing_tool:
4465 # Update existing tool if there are changes
4466 fields_to_update = False
4468 # Check basic field changes
4469 basic_fields_changed = (
4470 existing_tool.url != gateway.url or existing_tool.description != tool.description or existing_tool.integration_type != "MCP" or existing_tool.request_type != tool.request_type
4471 )
4473 # Check schema and configuration changes
4474 schema_fields_changed = (
4475 existing_tool.headers != tool.headers
4476 or existing_tool.input_schema != tool.input_schema
4477 or existing_tool.output_schema != tool.output_schema
4478 or existing_tool.jsonpath_filter != tool.jsonpath_filter
4479 )
4481 # Check authentication and visibility changes
4482 auth_fields_changed = existing_tool.auth_type != gateway.auth_type or existing_tool.auth_value != gateway.auth_value or existing_tool.visibility != gateway.visibility
4484 if basic_fields_changed or schema_fields_changed or auth_fields_changed: 4484 ↛ 4486line 4484 didn't jump to line 4486 because the condition on line 4484 was always true
4485 fields_to_update = True
4486 if fields_to_update: 4486 ↛ 4456line 4486 didn't jump to line 4456 because the condition on line 4486 was always true
4487 existing_tool.url = gateway.url
4488 existing_tool.description = tool.description
4489 existing_tool.integration_type = "MCP"
4490 existing_tool.request_type = tool.request_type
4491 existing_tool.headers = tool.headers
4492 existing_tool.input_schema = tool.input_schema
4493 existing_tool.output_schema = tool.output_schema
4494 existing_tool.jsonpath_filter = tool.jsonpath_filter
4495 existing_tool.auth_type = gateway.auth_type
4496 existing_tool.auth_value = gateway.auth_value
4497 existing_tool.visibility = gateway.visibility
4498 logger.debug(f"Updated existing tool: {tool.name}")
4499 else:
4500 # Create new tool if it doesn't exist
4501 db_tool = self._create_db_tool(
4502 tool=tool,
4503 gateway=gateway,
4504 created_by="system",
4505 created_via=created_via,
4506 )
4507 # Attach relationship to avoid NoneType during flush
4508 db_tool.gateway = gateway
4509 tools_to_add.append(db_tool)
4510 logger.debug(f"Created new tool: {tool.name}")
4511 except Exception as e:
4512 logger.warning(f"Failed to process tool {getattr(tool, 'name', 'unknown')}: {e}")
4513 continue
4515 return tools_to_add
4517 def _update_or_create_resources(self, db: Session, resources: List[Any], gateway: DbGateway, created_via: str) -> List[DbResource]:
4518 """Helper to handle update-or-create logic for resources from MCP server.
4520 Args:
4521 db: Database session
4522 resources: List of resources from MCP server
4523 gateway: Gateway object
4524 created_via: String indicating creation source ("oauth", "update", etc.)
4526 Returns:
4527 List of new resources to be added to the database
4528 """
4529 if not resources:
4530 return []
4532 resources_to_add = []
4534 # Batch fetch all existing resources for this gateway
4535 resource_uris = [resource.uri for resource in resources if resource is not None]
4536 if not resource_uris: 4536 ↛ 4537line 4536 didn't jump to line 4537 because the condition on line 4536 was never true
4537 return []
4539 existing_resources_query = select(DbResource).where(DbResource.gateway_id == gateway.id, DbResource.uri.in_(resource_uris))
4540 existing_resources = db.execute(existing_resources_query).scalars().all()
4541 existing_resources_map = {resource.uri: resource for resource in existing_resources}
4543 for resource in resources:
4544 if resource is None:
4545 logger.warning("Skipping None resource in resources list")
4546 continue
4548 try:
4549 # Check if resource already exists for this gateway from the resources_map
4550 existing_resource = existing_resources_map.get(resource.uri)
4552 if existing_resource:
4553 # Update existing resource if there are changes
4554 fields_to_update = False
4556 if ( 4556 ↛ 4565line 4556 didn't jump to line 4565 because the condition on line 4556 was always true
4557 existing_resource.name != resource.name
4558 or existing_resource.description != resource.description
4559 or existing_resource.mime_type != resource.mime_type
4560 or existing_resource.uri_template != resource.uri_template
4561 or existing_resource.visibility != gateway.visibility
4562 ):
4563 fields_to_update = True
4565 if fields_to_update: 4565 ↛ 4543line 4565 didn't jump to line 4543 because the condition on line 4565 was always true
4566 existing_resource.name = resource.name
4567 existing_resource.description = resource.description
4568 existing_resource.mime_type = resource.mime_type
4569 existing_resource.uri_template = resource.uri_template
4570 existing_resource.visibility = gateway.visibility
4571 logger.debug(f"Updated existing resource: {resource.uri}")
4572 else:
4573 # Create new resource if it doesn't exist
4574 db_resource = DbResource(
4575 uri=resource.uri,
4576 name=resource.name,
4577 description=resource.description,
4578 mime_type=resource.mime_type,
4579 uri_template=resource.uri_template,
4580 gateway_id=gateway.id,
4581 created_by="system",
4582 created_via=created_via,
4583 visibility=gateway.visibility,
4584 )
4585 resources_to_add.append(db_resource)
4586 logger.debug(f"Created new resource: {resource.uri}")
4587 except Exception as e:
4588 logger.warning(f"Failed to process resource {getattr(resource, 'uri', 'unknown')}: {e}")
4589 continue
4591 return resources_to_add
4593 def _update_or_create_prompts(self, db: Session, prompts: List[Any], gateway: DbGateway, created_via: str) -> List[DbPrompt]:
4594 """Helper to handle update-or-create logic for prompts from MCP server.
4596 Args:
4597 db: Database session
4598 prompts: List of prompts from MCP server
4599 gateway: Gateway object
4600 created_via: String indicating creation source ("oauth", "update", etc.)
4602 Returns:
4603 List of new prompts to be added to the database
4604 """
4605 if not prompts:
4606 return []
4608 prompts_to_add = []
4610 # Batch fetch all existing prompts for this gateway
4611 prompt_names = [prompt.name for prompt in prompts if prompt is not None]
4612 if not prompt_names: 4612 ↛ 4613line 4612 didn't jump to line 4613 because the condition on line 4612 was never true
4613 return []
4615 existing_prompts_query = select(DbPrompt).where(DbPrompt.gateway_id == gateway.id, DbPrompt.original_name.in_(prompt_names))
4616 existing_prompts = db.execute(existing_prompts_query).scalars().all()
4617 existing_prompts_map = {prompt.original_name: prompt for prompt in existing_prompts}
4619 for prompt in prompts:
4620 if prompt is None:
4621 logger.warning("Skipping None prompt in prompts list")
4622 continue
4624 try:
4625 # Check if resource already exists for this gateway from the prompts_map
4626 existing_prompt = existing_prompts_map.get(prompt.name)
4628 if existing_prompt:
4629 # Update existing prompt if there are changes
4630 fields_to_update = False
4632 if ( 4632 ↛ 4639line 4632 didn't jump to line 4639 because the condition on line 4632 was always true
4633 existing_prompt.description != prompt.description
4634 or existing_prompt.template != (prompt.template if hasattr(prompt, "template") else "")
4635 or existing_prompt.visibility != gateway.visibility
4636 ):
4637 fields_to_update = True
4639 if fields_to_update: 4639 ↛ 4619line 4639 didn't jump to line 4619 because the condition on line 4639 was always true
4640 existing_prompt.description = prompt.description
4641 existing_prompt.template = prompt.template if hasattr(prompt, "template") else ""
4642 existing_prompt.visibility = gateway.visibility
4643 logger.debug(f"Updated existing prompt: {prompt.name}")
4644 else:
4645 # Create new prompt if it doesn't exist
4646 db_prompt = DbPrompt(
4647 name=prompt.name,
4648 original_name=prompt.name,
4649 custom_name=prompt.name,
4650 display_name=prompt.name,
4651 description=prompt.description,
4652 template=prompt.template if hasattr(prompt, "template") else "",
4653 argument_schema={}, # Use argument_schema instead of arguments
4654 gateway_id=gateway.id,
4655 created_by="system",
4656 created_via=created_via,
4657 visibility=gateway.visibility,
4658 )
4659 db_prompt.gateway = gateway
4660 prompts_to_add.append(db_prompt)
4661 logger.debug(f"Created new prompt: {prompt.name}")
4662 except Exception as e:
4663 logger.warning(f"Failed to process prompt {getattr(prompt, 'name', 'unknown')}: {e}")
4664 continue
4666 return prompts_to_add
4668 async def _refresh_gateway_tools_resources_prompts(
4669 self,
4670 gateway_id: str,
4671 _user_email: Optional[str] = None,
4672 created_via: str = "health_check",
4673 pre_auth_headers: Optional[Dict[str, str]] = None,
4674 gateway: Optional[DbGateway] = None,
4675 include_resources: bool = True,
4676 include_prompts: bool = True,
4677 ) -> Dict[str, int]:
4678 """Refresh tools, resources, and prompts for a gateway during health checks.
4680 Fetches the latest tools/resources/prompts from the MCP server and syncs
4681 with the database (add new, update changed, remove stale). Only performs
4682 DB operations if actual changes are detected.
4684 This method uses fresh_db_session() internally to avoid holding
4685 connections during HTTP calls to MCP servers.
4687 Args:
4688 gateway_id: ID of the gateway to refresh
4689 _user_email: Optional user email for OAuth token lookup (unused currently)
4690 created_via: String indicating creation source (default: "health_check")
4691 pre_auth_headers: Pre-authenticated headers from health check to avoid duplicate OAuth token fetch
4692 gateway: Optional DbGateway object to avoid redundant DB lookup
4693 include_resources: Whether to include resources in the refresh
4694 include_prompts: Whether to include prompts in the refresh
4696 Returns:
4697 Dict with counts: {tools_added, tools_removed, resources_added,
4698 resources_removed, prompts_added, prompts_removed}
4700 Examples:
4701 >>> from mcpgateway.services.gateway_service import GatewayService
4702 >>> from unittest.mock import patch, MagicMock, AsyncMock
4703 >>> import asyncio
4705 >>> # Test gateway not found returns empty result
4706 >>> service = GatewayService()
4707 >>> mock_session = MagicMock()
4708 >>> mock_session.execute.return_value.scalar_one_or_none.return_value = None
4709 >>> with patch('mcpgateway.services.gateway_service.fresh_db_session') as mock_fresh:
4710 ... mock_fresh.return_value.__enter__.return_value = mock_session
4711 ... result = asyncio.run(service._refresh_gateway_tools_resources_prompts('gw-123'))
4712 >>> result['tools_added'] == 0 and result['tools_removed'] == 0
4713 True
4714 >>> result['resources_added'] == 0 and result['resources_removed'] == 0
4715 True
4716 >>> result['success'] is True and result['error'] is None
4717 True
4719 >>> # Test disabled gateway returns empty result
4720 >>> mock_gw = MagicMock()
4721 >>> mock_gw.enabled = False
4722 >>> mock_gw.reachable = True
4723 >>> mock_gw.name = 'test_gw'
4724 >>> mock_session.execute.return_value.scalar_one_or_none.return_value = mock_gw
4725 >>> with patch('mcpgateway.services.gateway_service.fresh_db_session') as mock_fresh:
4726 ... mock_fresh.return_value.__enter__.return_value = mock_session
4727 ... result = asyncio.run(service._refresh_gateway_tools_resources_prompts('gw-123'))
4728 >>> result['tools_added']
4729 0
4731 >>> # Test unreachable gateway returns empty result
4732 >>> mock_gw.enabled = True
4733 >>> mock_gw.reachable = False
4734 >>> with patch('mcpgateway.services.gateway_service.fresh_db_session') as mock_fresh:
4735 ... mock_fresh.return_value.__enter__.return_value = mock_session
4736 ... result = asyncio.run(service._refresh_gateway_tools_resources_prompts('gw-123'))
4737 >>> result['tools_added']
4738 0
4740 >>> # Test method is async and callable
4741 >>> import inspect
4742 >>> inspect.iscoroutinefunction(service._refresh_gateway_tools_resources_prompts)
4743 True
4744 >>>
4745 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
4746 >>> asyncio.run(service._http_client.aclose())
4747 """
4748 result = {
4749 "tools_added": 0,
4750 "tools_removed": 0,
4751 "resources_added": 0,
4752 "resources_removed": 0,
4753 "prompts_added": 0,
4754 "prompts_removed": 0,
4755 "tools_updated": 0,
4756 "resources_updated": 0,
4757 "prompts_updated": 0,
4758 "success": True,
4759 "error": None,
4760 "validation_errors": [],
4761 }
4763 # Fetch gateway metadata only (no relationships needed for MCP call)
4764 # Use provided gateway object if available to save a DB call
4765 gateway_name = None
4766 gateway_url = None
4767 gateway_transport = None
4768 gateway_auth_type = None
4769 gateway_auth_value = None
4770 gateway_oauth_config = None
4771 gateway_ca_certificate = None
4772 gateway_auth_query_params = None
4774 if gateway:
4775 if not gateway.enabled or not gateway.reachable:
4776 logger.debug(f"Skipping tool refresh for disabled/unreachable gateway {gateway.name}")
4777 return result
4779 gateway_name = gateway.name
4780 gateway_url = gateway.url
4781 gateway_transport = gateway.transport
4782 gateway_auth_type = gateway.auth_type
4783 gateway_auth_value = gateway.auth_value
4784 gateway_oauth_config = gateway.oauth_config
4785 gateway_ca_certificate = gateway.ca_certificate
4786 gateway_auth_query_params = gateway.auth_query_params
4787 else:
4788 with fresh_db_session() as db:
4789 gateway_obj = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
4791 if not gateway_obj:
4792 logger.warning(f"Gateway {gateway_id} not found for tool refresh")
4793 return result
4795 if not gateway_obj.enabled or not gateway_obj.reachable:
4796 logger.debug(f"Skipping tool refresh for disabled/unreachable gateway {gateway_obj.name}")
4797 return result
4799 # Extract metadata before session closes
4800 gateway_name = gateway_obj.name
4801 gateway_url = gateway_obj.url
4802 gateway_transport = gateway_obj.transport
4803 gateway_auth_type = gateway_obj.auth_type
4804 gateway_auth_value = gateway_obj.auth_value
4805 gateway_oauth_config = gateway_obj.oauth_config
4806 gateway_ca_certificate = gateway_obj.ca_certificate
4807 gateway_auth_query_params = gateway_obj.auth_query_params
4809 # Handle query_param auth - decrypt and apply to URL for refresh
4810 auth_query_params_decrypted: Optional[Dict[str, str]] = None
4811 if gateway_auth_type == "query_param" and gateway_auth_query_params:
4812 auth_query_params_decrypted = {}
4813 for param_key, encrypted_value in gateway_auth_query_params.items():
4814 if encrypted_value: 4814 ↛ 4813line 4814 didn't jump to line 4813 because the condition on line 4814 was always true
4815 try:
4816 decrypted = decode_auth(encrypted_value)
4817 auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
4818 except Exception:
4819 logger.debug(f"Failed to decrypt query param '{param_key}' for tool refresh")
4820 if auth_query_params_decrypted: 4820 ↛ 4824line 4820 didn't jump to line 4824 because the condition on line 4820 was always true
4821 gateway_url = apply_query_param_auth(gateway_url, auth_query_params_decrypted)
4823 # Fetch tools/resources/prompts from MCP server (no DB connection held)
4824 try:
4825 _capabilities, tools, resources, prompts = await self._initialize_gateway(
4826 url=gateway_url,
4827 authentication=gateway_auth_value,
4828 transport=gateway_transport,
4829 auth_type=gateway_auth_type,
4830 oauth_config=gateway_oauth_config,
4831 ca_certificate=gateway_ca_certificate.encode() if gateway_ca_certificate else None,
4832 pre_auth_headers=pre_auth_headers,
4833 include_resources=include_resources,
4834 include_prompts=include_prompts,
4835 auth_query_params=auth_query_params_decrypted,
4836 )
4837 except Exception as e:
4838 logger.warning(f"Failed to fetch tools from gateway {gateway_name}: {e}")
4839 result["success"] = False
4840 result["error"] = str(e)
4841 return result
4843 # For authorization_code OAuth gateways, empty responses may indicate incomplete auth flow
4844 # Skip only if it's an auth_code gateway with no data (user may not have completed authorization)
4845 is_auth_code_gateway = gateway_oauth_config and isinstance(gateway_oauth_config, dict) and gateway_oauth_config.get("grant_type") == "authorization_code"
4846 if not tools and not resources and not prompts and is_auth_code_gateway:
4847 logger.debug(f"No tools/resources/prompts returned from auth_code gateway {gateway_name} (user may not have authorized)")
4848 return result
4850 # For non-auth_code gateways, empty responses are legitimate and will clear stale items
4852 # Update database with fresh session
4853 with fresh_db_session() as db:
4854 # Fetch gateway with relationships for update/comparison
4855 gateway = db.execute(
4856 select(DbGateway)
4857 .options(
4858 selectinload(DbGateway.tools),
4859 selectinload(DbGateway.resources),
4860 selectinload(DbGateway.prompts),
4861 )
4862 .where(DbGateway.id == gateway_id)
4863 ).scalar_one_or_none()
4865 if not gateway: 4865 ↛ 4866line 4865 didn't jump to line 4866 because the condition on line 4865 was never true
4866 result["success"] = False
4867 result["error"] = f"Gateway {gateway_id} not found during refresh"
4868 return result
4870 new_tool_names = [tool.name for tool in tools]
4871 new_resource_uris = [resource.uri for resource in resources] if include_resources else None
4872 new_prompt_names = [prompt.name for prompt in prompts] if include_prompts else None
4874 # Track dirty objects before update operations to count per-type updates
4875 pending_tools_before = {obj for obj in db.dirty if isinstance(obj, DbTool)}
4876 pending_resources_before = {obj for obj in db.dirty if isinstance(obj, DbResource)}
4877 pending_prompts_before = {obj for obj in db.dirty if isinstance(obj, DbPrompt)}
4879 # Update/create tools, resources, and prompts
4880 tools_to_add = self._update_or_create_tools(db, tools, gateway, created_via)
4881 resources_to_add = self._update_or_create_resources(db, resources, gateway, created_via) if include_resources else []
4882 prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, created_via) if include_prompts else []
4884 # Count per-type updates
4885 result["tools_updated"] = len({obj for obj in db.dirty if isinstance(obj, DbTool)} - pending_tools_before)
4886 result["resources_updated"] = len({obj for obj in db.dirty if isinstance(obj, DbResource)} - pending_resources_before)
4887 result["prompts_updated"] = len({obj for obj in db.dirty if isinstance(obj, DbPrompt)} - pending_prompts_before)
4889 # Only delete MCP-discovered items (not user-created entries)
4890 # Excludes "api", "ui", None (legacy/user-created) to preserve user entries
4891 mcp_created_via_values = {"MCP", "federation", "health_check", "manual_refresh", "oauth", "update"}
4893 # Find and remove stale tools (only MCP-discovered ones)
4894 stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names and tool.created_via in mcp_created_via_values]
4895 if stale_tool_ids:
4896 for i in range(0, len(stale_tool_ids), 500):
4897 chunk = stale_tool_ids[i : i + 500]
4898 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
4899 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
4900 db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))
4901 result["tools_removed"] = len(stale_tool_ids)
4903 # Find and remove stale resources (only MCP-discovered ones, only if resources were fetched)
4904 stale_resource_ids = []
4905 if new_resource_uris is not None: 4905 ↛ 4917line 4905 didn't jump to line 4917 because the condition on line 4905 was always true
4906 stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris and resource.created_via in mcp_created_via_values]
4907 if stale_resource_ids: 4907 ↛ 4908line 4907 didn't jump to line 4908 because the condition on line 4907 was never true
4908 for i in range(0, len(stale_resource_ids), 500):
4909 chunk = stale_resource_ids[i : i + 500]
4910 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
4911 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
4912 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
4913 db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))
4914 result["resources_removed"] = len(stale_resource_ids)
4916 # Find and remove stale prompts (only MCP-discovered ones, only if prompts were fetched)
4917 stale_prompt_ids = []
4918 if new_prompt_names is not None: 4918 ↛ 4929line 4918 didn't jump to line 4929 because the condition on line 4918 was always true
4919 stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names and prompt.created_via in mcp_created_via_values]
4920 if stale_prompt_ids: 4920 ↛ 4921line 4920 didn't jump to line 4921 because the condition on line 4920 was never true
4921 for i in range(0, len(stale_prompt_ids), 500):
4922 chunk = stale_prompt_ids[i : i + 500]
4923 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
4924 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
4925 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))
4926 result["prompts_removed"] = len(stale_prompt_ids)
4928 # Expire gateway if stale items were deleted
4929 if stale_tool_ids or stale_resource_ids or stale_prompt_ids:
4930 db.expire(gateway)
4932 # Add new items in chunks
4933 chunk_size = 50
4934 if tools_to_add:
4935 for i in range(0, len(tools_to_add), chunk_size):
4936 chunk = tools_to_add[i : i + chunk_size]
4937 db.add_all(chunk)
4938 db.flush()
4939 result["tools_added"] = len(tools_to_add)
4941 if resources_to_add:
4942 for i in range(0, len(resources_to_add), chunk_size):
4943 chunk = resources_to_add[i : i + chunk_size]
4944 db.add_all(chunk)
4945 db.flush()
4946 result["resources_added"] = len(resources_to_add)
4948 if prompts_to_add:
4949 for i in range(0, len(prompts_to_add), chunk_size):
4950 chunk = prompts_to_add[i : i + chunk_size]
4951 db.add_all(chunk)
4952 db.flush()
4953 result["prompts_added"] = len(prompts_to_add)
4955 gateway.last_refresh_at = datetime.now(timezone.utc)
4957 total_changes = (
4958 result["tools_added"]
4959 + result["tools_removed"]
4960 + result["tools_updated"]
4961 + result["resources_added"]
4962 + result["resources_removed"]
4963 + result["resources_updated"]
4964 + result["prompts_added"]
4965 + result["prompts_removed"]
4966 + result["prompts_updated"]
4967 )
4969 has_changes = total_changes > 0
4971 if has_changes:
4972 db.commit()
4973 logger.info(
4974 f"Refreshed gateway {gateway_name}: "
4975 f"tools(+{result['tools_added']}/-{result['tools_removed']}/~{result['tools_updated']}), "
4976 f"resources(+{result['resources_added']}/-{result['resources_removed']}/~{result['resources_updated']}), "
4977 f"prompts(+{result['prompts_added']}/-{result['prompts_removed']}/~{result['prompts_updated']})"
4978 )
4980 # Invalidate caches per-type based on actual changes
4981 cache = _get_registry_cache()
4982 if result["tools_added"] > 0 or result["tools_removed"] > 0 or result["tools_updated"] > 0: 4982 ↛ 4984line 4982 didn't jump to line 4984 because the condition on line 4982 was always true
4983 await cache.invalidate_tools()
4984 if result["resources_added"] > 0 or result["resources_removed"] > 0 or result["resources_updated"] > 0:
4985 await cache.invalidate_resources()
4986 if result["prompts_added"] > 0 or result["prompts_removed"] > 0 or result["prompts_updated"] > 0:
4987 await cache.invalidate_prompts()
4989 # Invalidate tool lookup cache for this gateway
4990 tool_lookup_cache = _get_tool_lookup_cache()
4991 await tool_lookup_cache.invalidate_gateway(str(gateway_id))
4992 else:
4993 db.commit()
4994 logger.debug(f"No changes detected during refresh of gateway {gateway_name}")
4996 return result
4998 def _get_refresh_lock(self, gateway_id: str) -> asyncio.Lock:
4999 """Get or create a per-gateway refresh lock.
5001 This ensures only one refresh operation can run for a given gateway at a time.
5003 Args:
5004 gateway_id: ID of the gateway to get the lock for
5006 Returns:
5007 asyncio.Lock: The lock for the specified gateway
5009 Examples:
5010 >>> from mcpgateway.services.gateway_service import GatewayService
5011 >>> service = GatewayService()
5012 >>> lock1 = service._get_refresh_lock('gw-123')
5013 >>> lock2 = service._get_refresh_lock('gw-123')
5014 >>> lock1 is lock2
5015 True
5016 >>> lock3 = service._get_refresh_lock('gw-456')
5017 >>> lock1 is lock3
5018 False
5019 """
5020 if gateway_id not in self._refresh_locks:
5021 self._refresh_locks[gateway_id] = asyncio.Lock()
5022 return self._refresh_locks[gateway_id]
5024 async def refresh_gateway_manually(
5025 self,
5026 gateway_id: str,
5027 include_resources: bool = True,
5028 include_prompts: bool = True,
5029 user_email: Optional[str] = None,
5030 request_headers: Optional[Dict[str, str]] = None,
5031 ) -> Dict[str, Any]:
5032 """Manually trigger a refresh of tools/resources/prompts for a gateway.
5034 This method provides a public API for triggering an immediate refresh
5035 of a gateway's tools, resources, and prompts from its MCP server.
5036 It includes concurrency control via per-gateway locking.
5038 Args:
5039 gateway_id: Gateway ID to refresh
5040 include_resources: Whether to include resources in the refresh
5041 include_prompts: Whether to include prompts in the refresh
5042 user_email: Email of the user triggering the refresh
5043 request_headers: Optional request headers for passthrough authentication
5045 Returns:
5046 Dict with counts: {tools_added, tools_updated, tools_removed,
5047 resources_added, resources_updated, resources_removed,
5048 prompts_added, prompts_updated, prompts_removed,
5049 validation_errors, duration_ms, refreshed_at}
5051 Raises:
5052 GatewayNotFoundError: If the gateway does not exist
5053 GatewayError: If another refresh is already in progress for this gateway
5055 Examples:
5056 >>> from mcpgateway.services.gateway_service import GatewayService
5057 >>> from unittest.mock import patch, MagicMock, AsyncMock
5058 >>> import asyncio
5060 >>> # Test method is async
5061 >>> service = GatewayService()
5062 >>> import inspect
5063 >>> inspect.iscoroutinefunction(service.refresh_gateway_manually)
5064 True
5065 """
5066 start_time = time.monotonic()
5068 pre_auth_headers = {}
5070 # Check if gateway exists before acquiring lock
5071 with fresh_db_session() as db:
5072 gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
5073 if not gateway:
5074 raise GatewayNotFoundError(f"Gateway with ID '{gateway_id}' not found")
5075 gateway_name = gateway.name
5077 # Get passthrough headers if request headers provided
5078 if request_headers:
5079 pre_auth_headers = get_passthrough_headers(request_headers, {}, db, gateway)
5081 lock = self._get_refresh_lock(gateway_id)
5083 # Check if lock is already held (concurrent refresh in progress)
5084 if lock.locked():
5085 raise GatewayError(f"Refresh already in progress for gateway {gateway_name}")
5087 async with lock:
5088 logger.info(f"Starting manual refresh for gateway {gateway_name} (ID: {gateway_id})")
5090 result = await self._refresh_gateway_tools_resources_prompts(
5091 gateway_id=gateway_id,
5092 _user_email=user_email,
5093 created_via="manual_refresh",
5094 pre_auth_headers=pre_auth_headers,
5095 gateway=gateway,
5096 include_resources=include_resources,
5097 include_prompts=include_prompts,
5098 )
5099 # Note: last_refresh_at is updated inside _refresh_gateway_tools_resources_prompts on success
5101 result["duration_ms"] = (time.monotonic() - start_time) * 1000
5102 result["refreshed_at"] = datetime.now(timezone.utc)
5104 log_level = logging.INFO if result.get("success", True) else logging.WARNING
5105 status_msg = "succeeded" if result.get("success", True) else f"failed: {result.get('error')}"
5107 logger.log(
5108 log_level,
5109 f"Manual refresh for gateway {gateway_id} {status_msg}. Stats: "
5110 f"tools(+{result['tools_added']}/-{result['tools_removed']}), "
5111 f"resources(+{result['resources_added']}/-{result['resources_removed']}), "
5112 f"prompts(+{result['prompts_added']}/-{result['prompts_removed']}) "
5113 f"in {result['duration_ms']:.2f}ms",
5114 )
5116 return result
5118 async def _publish_event(self, event: Dict[str, Any]) -> None:
5119 """Publish event to all subscribers.
5121 Args:
5122 event: event dictionary
5124 Examples:
5125 >>> import asyncio
5126 >>> from unittest.mock import AsyncMock
5127 >>> service = GatewayService()
5128 >>> # Mock the underlying event service
5129 >>> service._event_service = AsyncMock()
5130 >>> test_event = {"type": "test", "data": {}}
5131 >>>
5132 >>> asyncio.run(service._publish_event(test_event))
5133 >>>
5134 >>> # Verify the event was passed to the event service
5135 >>> service._event_service.publish_event.assert_awaited_with(test_event)
5136 """
5137 await self._event_service.publish_event(event)
5139 def _validate_tools(self, tools: list[dict[str, Any]], context: str = "default") -> tuple[list[ToolCreate], list[str]]:
5140 """Validate tools individually with richer logging and error aggregation.
5142 Args:
5143 tools: list of tool dicts
5144 context: caller context, e.g. "oauth" to tailor errors/messages
5146 Returns:
5147 tuple[list[ToolCreate], list[str]]: Tuple of (valid tools, validation errors)
5149 Raises:
5150 OAuthToolValidationError: If all tools fail validation in OAuth context
5151 GatewayConnectionError: If all tools fail validation in default context
5152 """
5153 valid_tools: list[ToolCreate] = []
5154 validation_errors: list[str] = []
5156 for i, tool_dict in enumerate(tools):
5157 tool_name = tool_dict.get("name", f"unknown_tool_{i}")
5158 try:
5159 logger.debug(f"Validating tool: {tool_name}")
5160 validated_tool = ToolCreate.model_validate(tool_dict)
5161 valid_tools.append(validated_tool)
5162 logger.debug(f"Tool '{tool_name}' validated successfully")
5163 except ValidationError as e:
5164 error_msg = f"Validation failed for tool '{tool_name}': {e.errors()}"
5165 logger.error(error_msg)
5166 logger.debug(f"Failed tool schema: {tool_dict}")
5167 validation_errors.append(error_msg)
5168 except ValueError as e:
5169 if "JSON structure exceeds maximum depth" in str(e):
5170 error_msg = f"Tool '{tool_name}' schema too deeply nested. " f"Current depth limit: {settings.validation_max_json_depth}"
5171 logger.error(error_msg)
5172 logger.warning("Consider increasing VALIDATION_MAX_JSON_DEPTH environment variable")
5173 else:
5174 error_msg = f"ValueError for tool '{tool_name}': {str(e)}"
5175 logger.error(error_msg)
5176 validation_errors.append(error_msg)
5177 except Exception as e: # pragma: no cover - defensive
5178 error_msg = f"Unexpected error validating tool '{tool_name}': {type(e).__name__}: {str(e)}"
5179 logger.error(error_msg, exc_info=True)
5180 validation_errors.append(error_msg)
5182 if validation_errors:
5183 logger.warning(f"Tool validation completed with {len(validation_errors)} error(s). " f"Successfully validated {len(valid_tools)} tool(s).")
5184 for err in validation_errors[:3]:
5185 logger.debug(f"Validation error: {err}")
5187 if not valid_tools and validation_errors:
5188 if context == "oauth":
5189 raise OAuthToolValidationError(f"OAuth tool fetch failed: all {len(tools)} tools failed validation. " f"First error: {validation_errors[0][:200]}")
5190 raise GatewayConnectionError(f"Failed to fetch tools: All {len(tools)} tools failed validation. " f"First error: {validation_errors[0][:200]}")
5192 return valid_tools, validation_errors
5194 async def _connect_to_sse_server_without_validation(self, server_url: str, authentication: Optional[Dict[str, str]] = None):
5195 """Connect to an MCP server running with SSE transport, skipping URL validation.
5197 This is used for OAuth-protected servers where we've already validated the token works.
5199 Args:
5200 server_url: The URL of the SSE MCP server to connect to.
5201 authentication: Optional dictionary containing authentication headers.
5203 Returns:
5204 Tuple containing (capabilities, tools, resources, prompts) from the MCP server.
5205 """
5206 if authentication is None:
5207 authentication = {}
5209 # Skip validation for OAuth servers - we already validated via OAuth flow
5210 # Use async with for both sse_client and ClientSession
5211 try:
5212 async with sse_client(url=server_url, headers=authentication) as streams:
5213 async with ClientSession(*streams) as session:
5214 # Initialize the session
5215 response = await session.initialize()
5216 capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
5217 logger.debug(f"Server capabilities: {capabilities}")
5219 response = await session.list_tools()
5220 tools = response.tools
5221 tools = [tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools]
5223 tools, _ = self._validate_tools(tools, context="oauth")
5224 if tools:
5225 logger.info(f"Fetched {len(tools)} tools from gateway")
5226 # Fetch resources if supported
5228 logger.debug(f"Checking for resources support: {capabilities.get('resources')}")
5229 resources = []
5230 if capabilities.get("resources"): 5230 ↛ 5282line 5230 didn't jump to line 5282 because the condition on line 5230 was always true
5231 try:
5232 response = await session.list_resources()
5233 raw_resources = response.resources
5234 for resource in raw_resources:
5235 resource_data = resource.model_dump(by_alias=True, exclude_none=True)
5236 # Convert AnyUrl to string if present
5237 if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"):
5238 resource_data["uri"] = str(resource_data["uri"])
5239 # Add default content if not present (will be fetched on demand)
5240 if "content" not in resource_data: 5240 ↛ 5242line 5240 didn't jump to line 5242 because the condition on line 5240 was always true
5241 resource_data["content"] = ""
5242 try:
5243 resources.append(ResourceCreate.model_validate(resource_data))
5244 except Exception:
5245 # If validation fails, create minimal resource
5246 resources.append(
5247 ResourceCreate(
5248 uri=str(resource_data.get("uri", "")),
5249 name=resource_data.get("name", ""),
5250 description=resource_data.get("description"),
5251 mime_type=resource_data.get("mimeType"),
5252 uri_template=resource_data.get("uriTemplate") or None,
5253 content="",
5254 )
5255 )
5256 logger.info(f"Fetched {len(resources)} resources from gateway")
5257 except Exception as e:
5258 logger.warning(f"Failed to fetch resources: {e}")
5260 # resource template URI
5261 try:
5262 response_templates = await session.list_resource_templates()
5263 raw_resources_templates = response_templates.resourceTemplates
5264 resource_templates = []
5265 for resource_template in raw_resources_templates:
5266 resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True)
5268 if "uriTemplate" in resource_template_data: # and hasattr(resource_template_data["uriTemplate"], "unicode_string"): 5268 ↛ 5272line 5268 didn't jump to line 5272 because the condition on line 5268 was always true
5269 resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"])
5270 resource_template_data["uri"] = str(resource_template_data["uriTemplate"])
5272 if "content" not in resource_template_data: 5272 ↛ 5275line 5272 didn't jump to line 5275 because the condition on line 5272 was always true
5273 resource_template_data["content"] = ""
5275 resources.append(ResourceCreate.model_validate(resource_template_data))
5276 resource_templates.append(ResourceCreate.model_validate(resource_template_data))
5277 logger.info(f"Fetched {len(resource_templates)} resource templates from gateway")
5278 except Exception as e:
5279 logger.warning(f"Failed to fetch resource templates: {e}")
5281 # Fetch prompts if supported
5282 prompts = []
5283 logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}")
5284 if capabilities.get("prompts"): 5284 ↛ 5308line 5284 didn't jump to line 5308 because the condition on line 5284 was always true
5285 try:
5286 response = await session.list_prompts()
5287 raw_prompts = response.prompts
5288 for prompt in raw_prompts:
5289 prompt_data = prompt.model_dump(by_alias=True, exclude_none=True)
5290 # Add default template if not present
5291 if "template" not in prompt_data: 5291 ↛ 5293line 5291 didn't jump to line 5293 because the condition on line 5291 was always true
5292 prompt_data["template"] = ""
5293 try:
5294 prompts.append(PromptCreate.model_validate(prompt_data))
5295 except Exception:
5296 # If validation fails, create minimal prompt
5297 prompts.append(
5298 PromptCreate(
5299 name=prompt_data.get("name", ""),
5300 description=prompt_data.get("description"),
5301 template=prompt_data.get("template", ""),
5302 )
5303 )
5304 logger.info(f"Fetched {len(prompts)} prompts from gateway")
5305 except Exception as e:
5306 logger.warning(f"Failed to fetch prompts: {e}")
5308 return capabilities, tools, resources, prompts
5309 except Exception as e:
5310 # Note: This function is for OAuth servers only, which don't use query param auth
5311 # Still sanitize in case exception contains URL with static sensitive params
5312 sanitized_url = sanitize_url_for_logging(server_url)
5313 sanitized_error = sanitize_exception_message(str(e))
5314 logger.error(f"SSE connection error details: {type(e).__name__}: {sanitized_error}", exc_info=True)
5315 raise GatewayConnectionError(f"Failed to connect to SSE server at {sanitized_url}: {sanitized_error}")
5317 async def connect_to_sse_server(
5318 self,
5319 server_url: str,
5320 authentication: Optional[Dict[str, str]] = None,
5321 ca_certificate: Optional[bytes] = None,
5322 include_prompts: bool = True,
5323 include_resources: bool = True,
5324 auth_query_params: Optional[Dict[str, str]] = None,
5325 ):
5326 """Connect to an MCP server running with SSE transport.
5328 Args:
5329 server_url: The URL of the SSE MCP server to connect to.
5330 authentication: Optional dictionary containing authentication headers.
5331 ca_certificate: Optional CA certificate for SSL verification.
5332 include_prompts: Whether to fetch prompts from the server.
5333 include_resources: Whether to fetch resources from the server.
5334 auth_query_params: Query param names for URL sanitization in error logs.
5336 Returns:
5337 Tuple containing (capabilities, tools, resources, prompts) from the MCP server.
5338 """
5339 if authentication is None:
5340 authentication = {}
5342 def get_httpx_client_factory(
5343 headers: dict[str, str] | None = None,
5344 timeout: httpx.Timeout | None = None,
5345 auth: httpx.Auth | None = None,
5346 ) -> httpx.AsyncClient:
5347 """Factory function to create httpx.AsyncClient with optional CA certificate.
5349 Args:
5350 headers: Optional headers for the client
5351 timeout: Optional timeout for the client
5352 auth: Optional auth for the client
5354 Returns:
5355 httpx.AsyncClient: Configured HTTPX async client
5356 """
5357 if ca_certificate:
5358 ctx = self.create_ssl_context(ca_certificate)
5359 else:
5360 ctx = None
5361 return httpx.AsyncClient(
5362 verify=ctx if ctx else get_default_verify(),
5363 follow_redirects=True,
5364 headers=headers,
5365 timeout=timeout if timeout else get_http_timeout(),
5366 auth=auth,
5367 limits=httpx.Limits(
5368 max_connections=settings.httpx_max_connections,
5369 max_keepalive_connections=settings.httpx_max_keepalive_connections,
5370 keepalive_expiry=settings.httpx_keepalive_expiry,
5371 ),
5372 )
5374 # Use async with for both sse_client and ClientSession
5375 async with sse_client(url=server_url, headers=authentication, httpx_client_factory=get_httpx_client_factory) as streams:
5376 async with ClientSession(*streams) as session:
5377 # Initialize the session
5378 response = await session.initialize()
5380 capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
5381 logger.debug(f"Server capabilities: {capabilities}")
5383 response = await session.list_tools()
5384 tools = response.tools
5385 tools = [tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools]
5387 tools, _ = self._validate_tools(tools)
5388 if tools:
5389 logger.info(f"Fetched {len(tools)} tools from gateway")
5390 # Fetch resources if supported
5391 resources = []
5392 if include_resources: 5392 ↛ 5446line 5392 didn't jump to line 5446 because the condition on line 5392 was always true
5393 logger.debug(f"Checking for resources support: {capabilities.get('resources')}")
5394 if capabilities.get("resources"):
5395 try:
5396 response = await session.list_resources()
5397 raw_resources = response.resources
5398 for resource in raw_resources:
5399 resource_data = resource.model_dump(by_alias=True, exclude_none=True)
5400 # Convert AnyUrl to string if present
5401 if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"):
5402 resource_data["uri"] = str(resource_data["uri"])
5403 # Add default content if not present (will be fetched on demand)
5404 if "content" not in resource_data: 5404 ↛ 5406line 5404 didn't jump to line 5406 because the condition on line 5404 was always true
5405 resource_data["content"] = ""
5406 try:
5407 resources.append(ResourceCreate.model_validate(resource_data))
5408 except Exception:
5409 # If validation fails, create minimal resource
5410 resources.append(
5411 ResourceCreate(
5412 uri=str(resource_data.get("uri", "")),
5413 name=resource_data.get("name", ""),
5414 description=resource_data.get("description"),
5415 mime_type=resource_data.get("mimeType"),
5416 uri_template=resource_data.get("uriTemplate") or None,
5417 content="",
5418 )
5419 )
5420 logger.info(f"Fetched {len(resources)} resources from gateway")
5421 except Exception as e:
5422 logger.warning(f"Failed to fetch resources: {e}")
5424 # resource template URI
5425 try:
5426 response_templates = await session.list_resource_templates()
5427 raw_resources_templates = response_templates.resourceTemplates
5428 resource_templates = []
5429 for resource_template in raw_resources_templates:
5430 resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True)
5432 if "uriTemplate" in resource_template_data: # and hasattr(resource_template_data["uriTemplate"], "unicode_string"): 5432 ↛ 5436line 5432 didn't jump to line 5436 because the condition on line 5432 was always true
5433 resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"])
5434 resource_template_data["uri"] = str(resource_template_data["uriTemplate"])
5436 if "content" not in resource_template_data: 5436 ↛ 5439line 5436 didn't jump to line 5439 because the condition on line 5436 was always true
5437 resource_template_data["content"] = ""
5439 resources.append(ResourceCreate.model_validate(resource_template_data))
5440 resource_templates.append(ResourceCreate.model_validate(resource_template_data))
5441 logger.info(f"Fetched {len(raw_resources_templates)} resource templates from gateway")
5442 except Exception as ei:
5443 logger.warning(f"Failed to fetch resource templates: {ei}")
5445 # Fetch prompts if supported
5446 prompts = []
5447 if include_prompts: 5447 ↛ 5473line 5447 didn't jump to line 5473 because the condition on line 5447 was always true
5448 logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}")
5449 if capabilities.get("prompts"):
5450 try:
5451 response = await session.list_prompts()
5452 raw_prompts = response.prompts
5453 for prompt in raw_prompts:
5454 prompt_data = prompt.model_dump(by_alias=True, exclude_none=True)
5455 # Add default template if not present
5456 if "template" not in prompt_data:
5457 prompt_data["template"] = ""
5458 try:
5459 prompts.append(PromptCreate.model_validate(prompt_data))
5460 except Exception:
5461 # If validation fails, create minimal prompt
5462 prompts.append(
5463 PromptCreate(
5464 name=prompt_data.get("name", ""),
5465 description=prompt_data.get("description"),
5466 template=prompt_data.get("template", ""),
5467 )
5468 )
5469 logger.info(f"Fetched {len(prompts)} prompts from gateway")
5470 except Exception as e:
5471 logger.warning(f"Failed to fetch prompts: {e}")
5473 return capabilities, tools, resources, prompts
5474 sanitized_url = sanitize_url_for_logging(server_url, auth_query_params)
5475 raise GatewayConnectionError(f"Failed to initialize gateway at {sanitized_url}: Connection could not be established")
5477 async def connect_to_streamablehttp_server(
5478 self,
5479 server_url: str,
5480 authentication: Optional[Dict[str, str]] = None,
5481 ca_certificate: Optional[bytes] = None,
5482 include_prompts: bool = True,
5483 include_resources: bool = True,
5484 auth_query_params: Optional[Dict[str, str]] = None,
5485 ):
5486 """Connect to an MCP server running with Streamable HTTP transport.
5488 Args:
5489 server_url: The URL of the Streamable HTTP MCP server to connect to.
5490 authentication: Optional dictionary containing authentication headers.
5491 ca_certificate: Optional CA certificate for SSL verification.
5492 include_prompts: Whether to fetch prompts from the server.
5493 include_resources: Whether to fetch resources from the server.
5494 auth_query_params: Query param names for URL sanitization in error logs.
5496 Returns:
5497 Tuple containing (capabilities, tools, resources, prompts) from the MCP server.
5498 """
5499 if authentication is None:
5500 authentication = {}
5502 # Use authentication directly instead
5503 def get_httpx_client_factory(
5504 headers: dict[str, str] | None = None,
5505 timeout: httpx.Timeout | None = None,
5506 auth: httpx.Auth | None = None,
5507 ) -> httpx.AsyncClient:
5508 """Factory function to create httpx.AsyncClient with optional CA certificate.
5510 Args:
5511 headers: Optional headers for the client
5512 timeout: Optional timeout for the client
5513 auth: Optional auth for the client
5515 Returns:
5516 httpx.AsyncClient: Configured HTTPX async client
5517 """
5518 if ca_certificate: 5518 ↛ 5521line 5518 didn't jump to line 5521 because the condition on line 5518 was always true
5519 ctx = self.create_ssl_context(ca_certificate)
5520 else:
5521 ctx = None
5522 return httpx.AsyncClient(
5523 verify=ctx if ctx else get_default_verify(),
5524 follow_redirects=True,
5525 headers=headers,
5526 timeout=timeout if timeout else get_http_timeout(),
5527 auth=auth,
5528 limits=httpx.Limits(
5529 max_connections=settings.httpx_max_connections,
5530 max_keepalive_connections=settings.httpx_max_keepalive_connections,
5531 keepalive_expiry=settings.httpx_keepalive_expiry,
5532 ),
5533 )
5535 async with streamablehttp_client(url=server_url, headers=authentication, httpx_client_factory=get_httpx_client_factory) as (read_stream, write_stream, _get_session_id):
5536 async with ClientSession(read_stream, write_stream) as session:
5537 # Initialize the session
5538 response = await session.initialize()
5539 capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
5540 logger.debug(f"Server capabilities: {capabilities}")
5542 response = await session.list_tools()
5543 tools = response.tools
5544 tools = [tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools]
5546 tools, _ = self._validate_tools(tools)
5547 for tool in tools:
5548 tool.request_type = "STREAMABLEHTTP"
5549 if tools: 5549 ↛ 5553line 5549 didn't jump to line 5553 because the condition on line 5549 was always true
5550 logger.info(f"Fetched {len(tools)} tools from gateway")
5552 # Fetch resources if supported
5553 resources = []
5554 if include_resources: 5554 ↛ 5608line 5554 didn't jump to line 5608 because the condition on line 5554 was always true
5555 logger.debug(f"Checking for resources support: {capabilities.get('resources')}")
5556 if capabilities.get("resources"):
5557 try:
5558 response = await session.list_resources()
5559 raw_resources = response.resources
5560 for resource in raw_resources:
5561 resource_data = resource.model_dump(by_alias=True, exclude_none=True)
5562 # Convert AnyUrl to string if present
5563 if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"): 5563 ↛ 5566line 5563 didn't jump to line 5566 because the condition on line 5563 was always true
5564 resource_data["uri"] = str(resource_data["uri"])
5565 # Add default content if not present
5566 if "content" not in resource_data: 5566 ↛ 5568line 5566 didn't jump to line 5568 because the condition on line 5566 was always true
5567 resource_data["content"] = ""
5568 try:
5569 resources.append(ResourceCreate.model_validate(resource_data))
5570 except Exception:
5571 # If validation fails, create minimal resource
5572 resources.append(
5573 ResourceCreate(
5574 uri=str(resource_data.get("uri", "")),
5575 name=resource_data.get("name", ""),
5576 description=resource_data.get("description"),
5577 mime_type=resource_data.get("mimeType"),
5578 uri_template=resource_data.get("uriTemplate") or None,
5579 content="",
5580 )
5581 )
5582 logger.info(f"Fetched {len(resources)} resources from gateway")
5583 except Exception as e:
5584 logger.warning(f"Failed to fetch resources: {e}")
5586 # resource template URI
5587 try:
5588 response_templates = await session.list_resource_templates()
5589 raw_resources_templates = response_templates.resourceTemplates
5590 resource_templates = []
5591 for resource_template in raw_resources_templates:
5592 resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True)
5594 if "uriTemplate" in resource_template_data: # and hasattr(resource_template_data["uriTemplate"], "unicode_string"): 5594 ↛ 5598line 5594 didn't jump to line 5598 because the condition on line 5594 was always true
5595 resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"])
5596 resource_template_data["uri"] = str(resource_template_data["uriTemplate"])
5598 if "content" not in resource_template_data: 5598 ↛ 5601line 5598 didn't jump to line 5601 because the condition on line 5598 was always true
5599 resource_template_data["content"] = ""
5601 resources.append(ResourceCreate.model_validate(resource_template_data))
5602 resource_templates.append(ResourceCreate.model_validate(resource_template_data))
5603 logger.info(f"Fetched {len(resource_templates)} resource templates from gateway")
5604 except Exception as e:
5605 logger.warning(f"Failed to fetch resource templates: {e}")
5607 # Fetch prompts if supported
5608 prompts = []
5609 if include_prompts: 5609 ↛ 5625line 5609 didn't jump to line 5625 because the condition on line 5609 was always true
5610 logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}")
5611 if capabilities.get("prompts"):
5612 try:
5613 response = await session.list_prompts()
5614 raw_prompts = response.prompts
5615 for prompt in raw_prompts:
5616 prompt_data = prompt.model_dump(by_alias=True, exclude_none=True)
5617 # Add default template if not present
5618 if "template" not in prompt_data: 5618 ↛ 5619line 5618 didn't jump to line 5619 because the condition on line 5618 was never true
5619 prompt_data["template"] = ""
5620 prompts.append(PromptCreate.model_validate(prompt_data))
5621 logger.info(f"Fetched {len(prompts)} prompts from gateway")
5622 except Exception as e:
5623 logger.warning(f"Failed to fetch prompts: {e}")
5625 return capabilities, tools, resources, prompts
5626 sanitized_url = sanitize_url_for_logging(server_url, auth_query_params)
5627 raise GatewayConnectionError(f"Failed to initialize gateway at {sanitized_url}: Connection could not be established")
5630# Lazy singleton - created on first access, not at module import time.
5631# This avoids instantiation when only exception classes are imported.
5632_gateway_service_instance = None # pylint: disable=invalid-name
5635def __getattr__(name: str):
5636 """Module-level __getattr__ for lazy singleton creation.
5638 Args:
5639 name: The attribute name being accessed.
5641 Returns:
5642 The gateway_service singleton instance if name is "gateway_service".
5644 Raises:
5645 AttributeError: If the attribute name is not "gateway_service".
5646 """
5647 global _gateway_service_instance # pylint: disable=global-statement
5648 if name == "gateway_service":
5649 if _gateway_service_instance is None:
5650 _gateway_service_instance = GatewayService()
5651 return _gateway_service_instance
5652 raise AttributeError(f"module {__name__!r} has no attribute {name!r}")