Coverage for mcpgateway / services / gateway_service.py: 93%
2306 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2# pylint: disable=import-outside-toplevel,no-name-in-module
3"""Location: ./mcpgateway/services/gateway_service.py
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8Gateway Service Implementation.
9This module implements gateway federation according to the MCP specification.
10It handles:
11- Gateway discovery and registration
12- Capability aggregation
13- Health monitoring
14- Active/inactive gateway management
16Examples:
17 >>> from mcpgateway.services.gateway_service import GatewayService, GatewayError
18 >>> service = GatewayService()
19 >>> isinstance(service, GatewayService)
20 True
21 >>> hasattr(service, '_active_gateways')
22 True
23 >>> isinstance(service._active_gateways, set)
24 True
26 Test error classes:
27 >>> error = GatewayError("Test error")
28 >>> str(error)
29 'Test error'
30 >>> isinstance(error, Exception)
31 True
33 >>> conflict_error = GatewayNameConflictError("test_gw")
34 >>> "test_gw" in str(conflict_error)
35 True
36 >>> conflict_error.enabled
37 True
38 >>>
39 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
40 >>> import asyncio
41 >>> asyncio.run(service._http_client.aclose())
42"""
44# Standard
45import asyncio
46import binascii
47from datetime import datetime, timezone
48import logging
49import mimetypes
50import os
51import ssl
52import tempfile
53import time
54from typing import Any, AsyncGenerator, cast, Dict, List, Optional, Set, TYPE_CHECKING, Union
55from urllib.parse import urlparse, urlunparse
56import uuid
58# Third-Party
59import anyio
60from filelock import FileLock, Timeout
61import httpx
62from mcp import ClientSession
63from mcp.client.sse import sse_client
64from mcp.client.streamable_http import streamablehttp_client
65from pydantic import ValidationError
66from sqlalchemy import and_, delete, desc, or_, select, update
67from sqlalchemy.exc import IntegrityError
68from sqlalchemy.orm import joinedload, selectinload, Session
70try:
71 # Third-Party - check if redis is available
72 # Third-Party
73 import redis.asyncio as _aioredis # noqa: F401 # pylint: disable=unused-import
75 REDIS_AVAILABLE = True
76 del _aioredis # Only needed for availability check
77except ImportError:
78 REDIS_AVAILABLE = False
79 logging.info("Redis is not utilized in this environment.")
81# First-Party
82from mcpgateway.common.validators import SecurityValidator
83from mcpgateway.config import settings
84from mcpgateway.db import EmailTeam as DbEmailTeam
85from mcpgateway.db import EmailTeamMember as DbEmailTeamMember
86from mcpgateway.db import fresh_db_session
87from mcpgateway.db import Gateway as DbGateway
88from mcpgateway.db import get_for_update
89from mcpgateway.db import Prompt as DbPrompt
90from mcpgateway.db import PromptMetric
91from mcpgateway.db import Resource as DbResource
92from mcpgateway.db import ResourceMetric, ResourceSubscription, server_prompt_association, server_resource_association, server_tool_association, SessionLocal
93from mcpgateway.db import Tool as DbTool
94from mcpgateway.db import ToolMetric
95from mcpgateway.observability import create_span, set_span_attribute, set_span_error
96from mcpgateway.schemas import GatewayCreate, GatewayRead, GatewayUpdate, PromptCreate, ResourceCreate, ToolCreate
98# logging.getLogger("httpx").setLevel(logging.WARNING) # Disables httpx logs for regular health checks
99from mcpgateway.services.audit_trail_service import get_audit_trail_service
100from mcpgateway.services.base_service import BaseService
101from mcpgateway.services.encryption_service import get_encryption_service, protect_oauth_config_for_storage
102from mcpgateway.services.event_service import EventService
103from mcpgateway.services.http_client_service import get_default_verify, get_http_timeout, get_isolated_http_client
104from mcpgateway.services.logging_service import LoggingService
105from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, register_gateway_capabilities_for_notifications, TransportType
106from mcpgateway.services.oauth_manager import OAuthManager
107from mcpgateway.services.structured_logger import get_structured_logger
108from mcpgateway.services.team_management_service import TeamManagementService
109from mcpgateway.utils.create_slug import slugify
110from mcpgateway.utils.display_name import generate_display_name
111from mcpgateway.utils.pagination import unified_paginate
112from mcpgateway.utils.passthrough_headers import get_passthrough_headers
113from mcpgateway.utils.redis_client import get_redis_client
114from mcpgateway.utils.retry_manager import ResilientHttpClient
115from mcpgateway.utils.services_auth import decode_auth, encode_auth
116from mcpgateway.utils.sqlalchemy_modifier import json_contains_tag_expr
117from mcpgateway.utils.ssl_context_cache import get_cached_ssl_context
118from mcpgateway.utils.url_auth import apply_query_param_auth, sanitize_exception_message, sanitize_url_for_logging
119from mcpgateway.utils.validate_signature import validate_signature
120from mcpgateway.validation.tags import validate_tags_field
123def _resolve_tool_title(tool) -> Optional[str]:
124 """Resolve the display title for a tool per MCP spec precedence.
126 MCP 2025-11-25: "Display name precedence order is: title,
127 annotations.title, then name."
129 1. ``tool.title`` — top-level ``BaseMetadata`` field (canonical).
130 2. ``tool.annotations.title`` — ``ToolAnnotations`` (legacy fallback).
131 3. ``None`` if neither is available (caller may fall back to ``name``).
133 All return paths are guarded with ``isinstance(str)`` so the function
134 never leaks non-string values from mock objects or malformed payloads.
136 Args:
137 tool: An object representing a tool. It may define a top-level
138 ``title`` attribute and/or an ``annotations`` attribute
139 (``ToolAnnotations`` model or ``dict``).
141 Returns:
142 Optional[str]: The resolved title string if found, otherwise None.
144 Examples:
145 >>> class Tool:
146 ... def __init__(self, title=None, annotations=None):
147 ... self.title = title
148 ... self.annotations = annotations
149 ...
150 >>> # 1. top-level title takes precedence
151 >>> tool = Tool(title="Top Level", annotations={"title": "Annotated"})
152 >>> _resolve_tool_title(tool)
153 'Top Level'
155 >>> # 2. Fallback to annotations.title
156 >>> tool = Tool(annotations={"title": "Annotated"})
157 >>> _resolve_tool_title(tool)
158 'Annotated'
160 >>> # 3. No title available
161 >>> tool = Tool()
162 >>> _resolve_tool_title(tool) is None
163 True
165 >>> # 4. annotations is not a dict
166 >>> tool = Tool(title="Top Level", annotations="invalid")
167 >>> _resolve_tool_title(tool)
168 'Top Level'
169 """
170 # MCP spec: "Display name precedence order is: title, annotations.title, then name."
171 title = getattr(tool, "title", None)
172 if isinstance(title, str):
173 return title
174 annotations = getattr(tool, "annotations", None)
175 if annotations is not None:
176 if isinstance(annotations, dict):
177 ann_title = annotations.get("title")
178 else:
179 ann_title = getattr(annotations, "title", None)
180 if isinstance(ann_title, str):
181 return ann_title
182 return None
# Cache import (lazy to avoid circular dependencies)
# Module-level holders for the registry and tool-lookup cache singletons.
# Both start as None. They are populated on the first call to
# _get_registry_cache() / _get_tool_lookup_cache() respectively, which import
# the cache modules at call time rather than at module import.
_REGISTRY_CACHE = None
_TOOL_LOOKUP_CACHE = None
def _get_registry_cache():
    """Return the shared ``registry_cache`` singleton, importing it on first use.

    The import is deferred until the first call. The result is then memoized
    in the module-level ``_REGISTRY_CACHE`` so later calls return it without
    re-importing.

    Returns:
        RegistryCache instance.
    """
    global _REGISTRY_CACHE  # pylint: disable=global-statement
    if _REGISTRY_CACHE is not None:
        return _REGISTRY_CACHE

    # First-Party
    from mcpgateway.cache.registry_cache import registry_cache  # pylint: disable=import-outside-toplevel

    _REGISTRY_CACHE = registry_cache
    return _REGISTRY_CACHE
def _get_tool_lookup_cache():
    """Return the shared ``tool_lookup_cache`` singleton, importing it on first use.

    The import is deferred until the first call. The result is then memoized
    in the module-level ``_TOOL_LOOKUP_CACHE`` so later calls return it
    without re-importing.

    Returns:
        ToolLookupCache instance.
    """
    global _TOOL_LOOKUP_CACHE  # pylint: disable=global-statement
    if _TOOL_LOOKUP_CACHE is not None:
        return _TOOL_LOOKUP_CACHE

    # First-Party
    from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache  # pylint: disable=import-outside-toplevel

    _TOOL_LOOKUP_CACHE = tool_lookup_cache
    return _TOOL_LOOKUP_CACHE
# Initialize logging service first
logging_service = LoggingService()
# Module logger, obtained from the LoggingService instance above.
logger = logging_service.get_logger(__name__)

# Initialize structured logger and audit trail for gateway operations
structured_logger = get_structured_logger("gateway_service")
audit_trail = get_audit_trail_service()

# Health-check tuning is captured from settings once, at import time.
# Later changes to settings do not update these module constants.
# GW_FAILURE_THRESHOLD mirrors settings.unhealthy_threshold. Presumably it is
# the count of failed checks before a gateway is treated as unhealthy; its
# consumer is not visible here, so confirm.
GW_FAILURE_THRESHOLD = settings.unhealthy_threshold
# GW_HEALTH_CHECK_INTERVAL seeds GatewayService._health_check_interval.
GW_HEALTH_CHECK_INTERVAL = settings.health_check_interval
class GatewayError(Exception):
    """Root of the gateway service exception hierarchy.

    The gateway-specific errors in this module derive from this class, so
    catching ``GatewayError`` handles any of them. The message given to the
    constructor is returned unchanged by ``str()``.

    Examples:
        >>> err = GatewayError("Test error")
        >>> str(err)
        'Test error'
        >>> isinstance(err, Exception)
        True
    """
class GatewayNotFoundError(GatewayError):
    """Signals that a requested gateway could not be found.

    Being a ``GatewayError`` subclass, it is also caught by handlers for the
    base gateway error type.

    Examples:
        >>> err = GatewayNotFoundError("Gateway not found")
        >>> str(err)
        'Gateway not found'
        >>> isinstance(err, GatewayError)
        True
    """
class GatewayNameConflictError(GatewayError):
    """Raised when a gateway name conflicts with existing (active or inactive) gateway.

    The message is prefixed with a label derived from the conflicting
    gateway's visibility:

    - ``"Team-level"`` for ``"team"``;
    - ``"Private"`` for ``"private"``;
    - ``"Public"`` for anything else, including ``None``.

    Args:
        name: The conflicting gateway name
        enabled: Whether the existing gateway is enabled
        gateway_id: ID of the existing gateway if available
        visibility: The visibility of the gateway ("public", "team" or "private").

    Examples:
        >>> error = GatewayNameConflictError("test_gateway")
        >>> str(error)
        'Public Gateway already exists with name: test_gateway'
        >>> error.name
        'test_gateway'
        >>> error.enabled
        True
        >>> error.gateway_id is None
        True

        >>> error_inactive = GatewayNameConflictError("inactive_gw", enabled=False, gateway_id=123)
        >>> str(error_inactive)
        'Public Gateway already exists with name: inactive_gw (currently inactive, ID: 123)'
        >>> error_inactive.enabled
        False
        >>> error_inactive.gateway_id
        123

        >>> str(GatewayNameConflictError("team_gw", visibility="team"))
        'Team-level Gateway already exists with name: team_gw'
        >>> str(GatewayNameConflictError("my_gw", visibility="private"))
        'Private Gateway already exists with name: my_gw'
    """

    def __init__(self, name: str, enabled: bool = True, gateway_id: Optional[Union[int, str]] = None, visibility: Optional[str] = "public"):
        """Initialize the error with gateway information.

        Args:
            name: The conflicting gateway name
            enabled: Whether the existing gateway is enabled
            gateway_id: ID of the existing gateway if available (int or string ID)
            visibility: The visibility of the gateway ("public", "team" or "private").
        """
        self.name = name
        self.enabled = enabled
        self.gateway_id = gateway_id
        self.visibility = visibility
        # Map each known visibility to its label. Unknown values (and None)
        # fall back to "Public" so the default message stays unchanged, while
        # private gateways are no longer mislabelled as public.
        vis_label = {"team": "Team-level", "private": "Private"}.get(visibility, "Public")
        message = f"{vis_label} Gateway already exists with name: {name}"
        if not enabled:
            message += f" (currently inactive, ID: {gateway_id})"
        super().__init__(message)
class GatewayDuplicateConflictError(GatewayError):
    """Signals that an equivalent gateway (same URL and credentials) already exists.

    Uniqueness is enforced per scope:

    - Public: unique across all public gateways.
    - Team: unique within the owning team.
    - Private: unique per owner. A user cannot register two private gateways
      with the same URL and credentials.

    The message names the scope and the existing gateway's status. When that
    gateway is inactive, it also suggests re-enabling it instead.

    Args:
        duplicate_gateway: The existing conflicting gateway (DbGateway instance).

    Examples:
        >>> # Public gateway conflict with the same URL and basic auth
        >>> existing_gw = DbGateway(url="https://api.example.com", id="abc-123", enabled=True, visibility="public", team_id=None, name="API Gateway", owner_email="alice@example.com")
        >>> error = GatewayDuplicateConflictError(
        ...     duplicate_gateway=existing_gw
        ... )
        >>> str(error)
        'The Server already exists in Public scope (Name: API Gateway, Status: active)'

        >>> # Team gateway conflict with the same URL and OAuth credentials
        >>> team_gw = DbGateway(url="https://api.example.com", id="def-456", enabled=False, visibility="team", team_id="engineering-team", name="API Gateway", owner_email="bob@example.com")
        >>> error = GatewayDuplicateConflictError(
        ...     duplicate_gateway=team_gw
        ... )
        >>> str(error)
        'The Server already exists in your Team (Name: API Gateway, Status: inactive). You may want to re-enable the existing gateway instead.'

        >>> # Private gateway conflict (same user cannot have two gateways with the same URL)
        >>> private_gw = DbGateway(url="https://api.example.com", id="ghi-789", enabled=True, visibility="private", team_id="none", name="API Gateway", owner_email="charlie@example.com")
        >>> error = GatewayDuplicateConflictError(
        ...     duplicate_gateway=private_gw
        ... )
        >>> str(error)
        'The Server already exists in "private" scope (Name: API Gateway, Status: active)'
    """

    def __init__(
        self,
        duplicate_gateway: "DbGateway",
    ):
        """Capture the conflicting gateway's details and build the message.

        Args:
            duplicate_gateway: The existing conflicting gateway (DbGateway instance)
        """
        self.duplicate_gateway = duplicate_gateway
        self.url = duplicate_gateway.url
        self.gateway_id = duplicate_gateway.id
        self.enabled = duplicate_gateway.enabled
        self.visibility = duplicate_gateway.visibility
        self.team_id = duplicate_gateway.team_id
        self.name = duplicate_gateway.name

        # A "team" gateway without a team id falls through to the generic form.
        if self.visibility == "public":
            where = "Public scope"
        elif self.visibility == "team" and self.team_id:
            where = "your Team"
        else:
            where = f'"{self.visibility}" scope'

        state = "active" if self.enabled else "inactive"
        pieces = [f"The Server already exists in {where} (Name: {self.name}, Status: {state})"]
        if not self.enabled:
            # Point the user at the inactive duplicate rather than a new registration.
            pieces.append(". You may want to re-enable the existing gateway instead.")

        super().__init__("".join(pieces))
class GatewayConnectionError(GatewayError):
    """Signals that connecting to a gateway failed.

    Examples:
        >>> err = GatewayConnectionError("Connection failed")
        >>> str(err)
        'Connection failed'
        >>> isinstance(err, GatewayError)
        True
    """
class OAuthToolValidationError(GatewayConnectionError):
    """Signals that tool validation failed while fetching tools via an OAuth flow.

    Subclasses ``GatewayConnectionError``, so handlers catching connection
    errors also catch this one.
    """
def _validate_gateway_team_assignment(db: Session, user_email: Optional[str], target_team_id: Optional[str]) -> None:
    """Ensure a gateway may be assigned to ``target_team_id``.

    Checks run in order:

    1. A team id was supplied.
    2. The team row exists.
    3. Only when ``user_email`` is given: the user has an active membership
       with role ``"owner"`` in that team.

    Args:
        db: Database session used for membership checks.
        user_email: Requesting user email. When omitted, ownership checks are skipped.
        target_team_id: Team identifier to validate.

    Raises:
        ValueError: If no team id is given, the team does not exist, or the
            caller is not an active owner of the team.
    """
    if not target_team_id:
        raise ValueError("Cannot set visibility to 'team' without a team_id")

    team_row = db.query(DbEmailTeam).filter(DbEmailTeam.id == target_team_id).first()
    if not team_row:
        raise ValueError(f"Team {target_team_id} not found")

    # Without a user identity there is nothing further to authorize.
    if user_email:
        owner_filters = (
            DbEmailTeamMember.team_id == target_team_id,
            DbEmailTeamMember.user_email == user_email,
            DbEmailTeamMember.is_active,
            DbEmailTeamMember.role == "owner",
        )
        owner_row = db.query(DbEmailTeamMember).filter(*owner_filters).first()
        if not owner_row:
            raise ValueError("User membership in team not sufficient for this update.")
430class GatewayService(BaseService): # pylint: disable=too-many-instance-attributes
431 """Service for managing federated gateways.
433 Handles:
434 - Gateway registration and health checks
435 - Capability negotiation
436 - Federation events
437 - Active/inactive status management
438 """
440 _visibility_model_cls = DbGateway
def __init__(self) -> None:
    """Initialize the gateway service.

    Sets up:

    - the shared federation HTTP client;
    - health-check bookkeeping;
    - references to the prompt/resource/tool services;
    - the OAuth manager;
    - the gateway event channel;
    - depending on configuration, Redis leader-election settings and a
      file-lock fallback.

    Background tasks (health checks, leader heartbeat, follower election)
    are not started here. ``initialize()`` creates them.

    Attributes that exist only conditionally:

    - ``_instance_id``, ``_leader_key``, ``_leader_ttl``,
      ``_leader_heartbeat_interval``, ``_leader_heartbeat_task`` and
      ``_follower_election_task`` are set only when
      ``settings.cache_type == "redis"`` (so ``redis_url`` is set) and the
      ``redis`` package is importable.
    - ``_lock_path`` and ``_file_lock`` are set only when
      ``settings.cache_type != "none"``.

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> from mcpgateway.services.event_service import EventService
        >>> from mcpgateway.utils.retry_manager import ResilientHttpClient
        >>> from mcpgateway.services.tool_service import ToolService
        >>> service = GatewayService()
        >>> isinstance(service._event_service, EventService)
        True
        >>> isinstance(service._http_client, ResilientHttpClient)
        True
        >>> service._health_check_interval == GW_HEALTH_CHECK_INTERVAL
        True
        >>> service._health_check_task is None
        True
        >>> isinstance(service._active_gateways, set)
        True
        >>> len(service._active_gateways)
        0
        >>> service._stream_response is None
        True
        >>> isinstance(service._pending_responses, dict)
        True
        >>> len(service._pending_responses)
        0
        >>> isinstance(service.tool_service, ToolService)
        True
        >>> isinstance(service._gateway_failure_counts, dict)
        True
        >>> len(service._gateway_failure_counts)
        0
        >>> hasattr(service, 'redis_url')
        True
        >>>
        >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
        >>> import asyncio
        >>> asyncio.run(service._http_client.aclose())
    """
    # Outbound HTTP client for federation calls. TLS verification is turned
    # off only when settings.skip_ssl_verify is set.
    self._http_client = ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify})
    self._health_check_interval = GW_HEALTH_CHECK_INTERVAL
    # Handle for the background health-check loop; created in initialize().
    self._health_check_task: Optional[asyncio.Task] = None
    self._active_gateways: Set[str] = set()  # Track active gateway URLs
    self._stream_response = None
    self._pending_responses = {}
    # Hot/cold server classification service (initialized in initialize())
    self._classification_service: Optional[Any] = None
    # Prefer using the globally-initialized singletons from the service modules
    # so events propagate via their initialized EventService/Redis clients.
    # Import lazily and fall back to creating local instances when the module-level
    # __getattr__ singletons are not yet available (e.g. circular import during
    # Gunicorn --preload).
    # First-Party
    try:
        # First-Party
        from mcpgateway.services.prompt_service import prompt_service
    except ImportError:
        # First-Party
        from mcpgateway.services.prompt_service import PromptService

        prompt_service = PromptService()
    try:
        # First-Party
        from mcpgateway.services.resource_service import resource_service
    except ImportError:
        # First-Party
        from mcpgateway.services.resource_service import ResourceService

        resource_service = ResourceService()
    try:
        # First-Party
        from mcpgateway.services.tool_service import tool_service
    except ImportError:
        # First-Party
        from mcpgateway.services.tool_service import ToolService

        tool_service = ToolService()

    self.tool_service = tool_service
    self.prompt_service = prompt_service
    self.resource_service = resource_service
    # Per-gateway health-check failure counters. Presumably keyed by gateway
    # id/URL; they are populated outside this constructor, so confirm.
    self._gateway_failure_counts: dict[str, int] = {}
    # OAuth timeout/retry limits come from environment variables (defaults
    # "30" and "3"). A non-integer value raises ValueError here.
    self.oauth_manager = OAuthManager(request_timeout=int(os.getenv("OAUTH_REQUEST_TIMEOUT", "30")), max_retries=int(os.getenv("OAUTH_MAX_RETRIES", "3")))
    self._event_service = EventService(channel_name="mcpgateway:gateway_events")

    # Per-gateway refresh locks to prevent concurrent refreshes for the same gateway
    self._refresh_locks: Dict[str, asyncio.Lock] = {}

    # For health checks, we determine the leader instance.
    # Redis is only used for leader election when the cache backend is Redis.
    self.redis_url = settings.redis_url if settings.cache_type == "redis" else None

    # Initialize optional Redis client holder (set in initialize())
    self._redis_client: Optional[Any] = None

    # Leader election settings from config
    if self.redis_url and REDIS_AVAILABLE:
        self._instance_id = str(uuid.uuid4())  # Unique ID for this process
        self._leader_key = settings.redis_leader_key
        self._leader_ttl = settings.redis_leader_ttl
        self._leader_heartbeat_interval = settings.redis_leader_heartbeat_interval
        self._leader_heartbeat_task: Optional[asyncio.Task] = None
        self._follower_election_task: Optional[asyncio.Task] = None

        # Log instance mapping for debugging
        logger.info(f"Instance started: instance_id={self._instance_id}, port={settings.port}, pid={os.getpid()}")

    # Always initialize file lock as fallback (used if Redis connection fails at runtime)
    # NOTE(review): despite "Always", this is skipped when cache_type == "none",
    # leaving _lock_path/_file_lock undefined. Callers must account for that.
    if settings.cache_type != "none":
        temp_dir = tempfile.gettempdir()
        user_path = os.path.normpath(settings.filelock_name)
        # An absolute configured path is re-rooted under the temp dir by making
        # it relative to its drive root.
        # NOTE(review): a relative name with leading ".." segments survives
        # normpath and can still resolve outside temp_dir. Confirm acceptable.
        if os.path.isabs(user_path):
            user_path = os.path.relpath(user_path, start=os.path.splitdrive(user_path)[0] + os.sep)
        full_path = os.path.join(temp_dir, user_path)
        # Use forward slashes regardless of platform separator.
        self._lock_path = full_path.replace("\\", "/")
        self._file_lock = FileLock(self._lock_path)
559 @staticmethod
560 def normalize_url(url: str) -> str:
561 """
562 Normalize a URL by ensuring it's properly formatted.
564 Special handling for localhost to prevent duplicates:
565 - Converts 127.0.0.1 to localhost for consistency
566 - Preserves all other domain names as-is for CDN/load balancer support
568 Args:
569 url (str): The URL to normalize.
571 Returns:
572 str: The normalized URL.
574 Examples:
575 >>> GatewayService.normalize_url('http://localhost:8080/path')
576 'http://localhost:8080/path'
577 >>> GatewayService.normalize_url('http://127.0.0.1:8080/path')
578 'http://localhost:8080/path'
579 >>> GatewayService.normalize_url('https://example.com/api')
580 'https://example.com/api'
581 """
582 parsed = urlparse(url)
583 hostname = parsed.hostname
585 # Special case: normalize 127.0.0.1 to localhost to prevent duplicates
586 # but preserve all other domains as-is for CDN/load balancer support
587 if hostname == "127.0.0.1":
588 netloc = "localhost"
589 if parsed.port:
590 netloc += f":{parsed.port}"
591 normalized = parsed._replace(netloc=netloc)
592 return str(urlunparse(normalized))
594 # For all other URLs, preserve the domain name
595 return url
597 @staticmethod
598 async def _encrypt_client_key(client_key: Optional[str]) -> Optional[str]:
599 """Encrypt a client private key for storage.
601 Args:
602 client_key: Plaintext client private key or None.
604 Returns:
605 Encrypted client key or None if input is None/empty.
606 """
607 if not client_key:
608 return None
609 encryption = get_encryption_service(settings.auth_encryption_secret)
610 if encryption.is_encrypted(client_key):
611 return client_key
612 return await encryption.encrypt_secret_async(client_key)
def create_ssl_context(self, ca_certificate: str) -> ssl.SSLContext:
    """Return an SSL context that trusts the given CA certificate.

    Delegates to ``get_cached_ssl_context`` so repeated requests for the same
    certificate can reuse a previously built context.

    Args:
        ca_certificate: CA certificate in PEM format

    Returns:
        ssl.SSLContext: Configured SSL context
    """
    context = get_cached_ssl_context(ca_certificate)
    return context
async def initialize(self) -> None:
    """Start the service's background work.

    Order of operations:

    1. Initialize the event service.
    2. If Redis is configured and importable, fetch the shared Redis client
       and verify it with ``ping``.
    3. Choose the task set:
       - No Redis client: always start the health-check loop (file-lock
         mode).
       - Leadership acquired via ``SET NX``: start the health-check loop and
         the leader heartbeat.
       - Otherwise: start the follower election loop.
    4. Optionally start the hot/cold classification service.

    Raises:
        ConnectionError: When redis ping fails
    """
    logger.info("Initializing gateway service")

    # Initialize event service with shared Redis client
    await self._event_service.initialize()

    # NOTE: We intentionally do NOT create a long-lived DB session here.
    # Health checks use fresh_db_session() only when DB access is actually needed,
    # avoiding holding connections during HTTP calls to MCP servers.

    admin_email = settings.platform_admin_email

    # Get shared Redis client from factory
    if self.redis_url and REDIS_AVAILABLE:
        self._redis_client = await get_redis_client()

    redis_client = self._redis_client
    if not redis_client:
        # No Redis available - always create the health check task in filelock mode
        self._health_check_task = asyncio.create_task(self._run_health_checks(admin_email))
    else:
        # Check if Redis is available (ping already done by factory, but verify)
        try:
            await redis_client.ping()
        except Exception as e:
            raise ConnectionError(f"Redis ping failed: {e}") from e

        acquired = await redis_client.set(self._leader_key, self._instance_id, ex=self._leader_ttl, nx=True)
        if acquired:
            logger.info("Acquired Redis leadership. Starting health check and heartbeat tasks.")
            self._health_check_task = asyncio.create_task(self._run_health_checks(admin_email))
            self._leader_heartbeat_task = asyncio.create_task(self._run_leader_heartbeat())
        else:
            logger.info("Did not acquire leadership. Starting follower election loop.")
            self._follower_election_task = asyncio.create_task(self._run_follower_election(admin_email))

    # Initialize hot/cold classification service (if enabled)
    if settings.hot_cold_classification_enabled:
        # First-Party
        from mcpgateway.services.server_classification_service import ServerClassificationService

        self._classification_service = ServerClassificationService(redis_client=self._redis_client)
        await self._classification_service.start()
        logger.info("Hot/cold classification service initialized")
async def shutdown(self) -> None:
    """Shutdown the service.

    Teardown order:

    1. Follower-election task.
    2. Health-check task.
    3. Hot/cold classification service.
    4. Leader-heartbeat task.
    5. Redis leadership, released only if this instance still owns the key.
    6. Shared HTTP client.
    7. Event service.
    8. The in-memory set of active gateways is cleared.

    Each background task is cancelled and then awaited. Awaiting a task that
    had already died with an exception re-raises that exception. Such errors
    are logged and teardown continues, so the HTTP client and event service
    are still closed.

    Examples:
        >>> service = GatewayService()
        >>> # Mock internal components
        >>> from unittest.mock import AsyncMock
        >>> service._event_service = AsyncMock()
        >>> service._active_gateways = {'test_gw'}
        >>> import asyncio
        >>> asyncio.run(service.shutdown())
        >>> # Verify event service shutdown was called
        >>> service._event_service.shutdown.assert_awaited_once()
        >>> len(service._active_gateways)
        0
    """

    async def _cancel_and_wait(task: Optional[asyncio.Task], label: str) -> None:
        """Cancel *task* (if any) and wait for it, logging non-cancellation failures.

        Args:
            task: Task handle to stop, or None when the task was never started.
            label: Human-readable task name used in the warning log.
        """
        if not task:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # A task that already crashed re-raises its error when awaited.
            # Log it instead of letting it abort the remaining teardown.
            logger.warning(f"{label} task ended with an error during shutdown: {e}")

    # Cancel follower election FIRST to prevent it from spawning new
    # health-check / heartbeat tasks while we are tearing down.
    await _cancel_and_wait(getattr(self, "_follower_election_task", None), "Follower election")

    # Now safe to cancel health-check and heartbeat. The handles are read only
    # after follower election has finished, so we always cancel whichever task
    # the attribute currently points to (it may have been overwritten just
    # before cancellation).
    await _cancel_and_wait(self._health_check_task, "Health check")

    # Stop classification service
    if self._classification_service:
        await self._classification_service.stop()
        logger.info("Classification service stopped")

    # Cancel leader heartbeat task if running
    await _cancel_and_wait(getattr(self, "_leader_heartbeat_task", None), "Leader heartbeat")

    # Release Redis leadership atomically if we hold it
    if self._redis_client:
        try:
            # Lua script for atomic check-and-delete (only delete if we own the key)
            release_script = """
            if redis.call("get", KEYS[1]) == ARGV[1] then
                return redis.call("del", KEYS[1])
            else
                return 0
            end
            """
            result = await self._redis_client.eval(release_script, 1, self._leader_key, self._instance_id)
            if result:
                logger.info("Released Redis leadership on shutdown")
        except Exception as e:
            logger.warning(f"Failed to release Redis leader key on shutdown: {e}")

    await self._http_client.aclose()
    await self._event_service.shutdown()
    self._active_gateways.clear()
    logger.info("Gateway service shutdown complete")
def _check_gateway_uniqueness(
    self,
    db: Session,
    url: str,
    auth_value: Optional[Dict[str, str]],
    oauth_config: Optional[Dict[str, Any]],
    team_id: Optional[str],
    owner_email: str,
    visibility: str,
    gateway_id: Optional[str] = None,
) -> Optional[DbGateway]:
    """
    Find an existing gateway in the same scope with the same URL and credentials.

    Scope by visibility:

    - ``"public"``: all public gateways.
    - ``"team"``: the given team (requires ``team_id``).
    - ``"private"``: the owner's private gateways.

    Any other combination skips the check and returns ``None``.

    A candidate with the same URL counts as a duplicate when:

    1. both sides have OAuth config and ``grant_type``, ``client_id``,
       ``authorization_url``, ``token_url`` and ``scope`` all match;
    2. otherwise, both sides have ``auth_value`` and the decoded values are
       equal (decode failures are logged and the candidate skipped);
    3. otherwise, neither side has any credentials at all.

    Args:
        db: Database session
        url: Gateway URL (normalized)
        auth_value: Decoded auth_value dict (not encrypted)
        oauth_config: OAuth configuration dict
        team_id: Team ID for team-scoped gateways
        owner_email: Email of the gateway owner
        visibility: Gateway visibility (public/team/private)
        gateway_id: Optional gateway ID to exclude from check (for updates)

    Returns:
        DbGateway if duplicate found, None otherwise
    """
    if visibility == "public":
        scope = (DbGateway.visibility == "public",)
    elif visibility == "team" and team_id:
        scope = (DbGateway.visibility == "team", DbGateway.team_id == team_id)
    elif visibility == "private":
        # Scoped to the same user's private gateways
        scope = (DbGateway.visibility == "private", DbGateway.owner_email == owner_email)
    else:
        return None

    query = db.query(DbGateway).filter(DbGateway.url == url, *scope)
    # When updating, ignore the gateway being updated
    if gateway_id:
        query = query.filter(DbGateway.id != gateway_id)

    # Token-like dynamic fields are deliberately excluded from the comparison.
    oauth_fields = ("grant_type", "client_id", "authorization_url", "token_url", "scope")

    for candidate in query.all():
        if oauth_config and candidate.oauth_config:
            if all(candidate.oauth_config.get(key) == oauth_config.get(key) for key in oauth_fields):
                return candidate  # Duplicate OAuth config found
            continue

        if auth_value and candidate.auth_value:
            stored = candidate.auth_value
            try:
                if isinstance(stored, str):
                    stored = decode_auth(stored)
                elif not isinstance(stored, dict):
                    continue
                if auth_value == stored:
                    return candidate  # Duplicate credentials found
            except Exception as e:
                logger.warning(f"Failed to decode auth_value for comparison: {e}")
            continue

        if not (auth_value or oauth_config or candidate.auth_value or candidate.oauth_config):
            return candidate  # Duplicate URL without credentials

    return None  # No duplicate found
831 async def register_gateway(
832 self,
833 db: Session,
834 gateway: GatewayCreate,
835 created_by: Optional[str] = None,
836 created_from_ip: Optional[str] = None,
837 created_via: Optional[str] = None,
838 created_user_agent: Optional[str] = None,
839 team_id: Optional[str] = None,
840 owner_email: Optional[str] = None,
841 visibility: Optional[str] = None,
842 initialize_timeout: Optional[float] = None,
843 ) -> GatewayRead:
844 """Register a new gateway.
846 Args:
847 db: Database session
848 gateway: Gateway creation schema
849 created_by: Username who created this gateway
850 created_from_ip: IP address of creator
851 created_via: Creation method (ui, api, federation)
852 created_user_agent: User agent of creation request
853 team_id (Optional[str]): Team ID to assign the gateway to.
854 owner_email (Optional[str]): Email of the user who owns this gateway.
855 visibility (Optional[str]): Gateway visibility level (private, team, public).
856 initialize_timeout (Optional[float]): Timeout in seconds for gateway initialization.
858 Returns:
859 Created gateway information
861 Raises:
862 GatewayNameConflictError: If gateway name already exists
863 GatewayConnectionError: If there was an error connecting to the gateway
864 ValueError: If required values are missing
865 RuntimeError: If there is an error during processing that is not covered by other exceptions
866 IntegrityError: If there is a database integrity error
867 BaseException: If an unexpected error occurs
869 Examples:
870 >>> from mcpgateway.services.gateway_service import GatewayService
871 >>> from unittest.mock import MagicMock
872 >>> service = GatewayService()
873 >>> db = MagicMock()
874 >>> gateway = MagicMock()
875 >>> db.execute.return_value.scalar_one_or_none.return_value = None
876 >>> db.add = MagicMock()
877 >>> db.commit = MagicMock()
878 >>> db.refresh = MagicMock()
879 >>> service._notify_gateway_added = MagicMock()
880 >>> import asyncio
881 >>> try:
882 ... asyncio.run(service.register_gateway(db, gateway))
883 ... except Exception:
884 ... pass
885 >>>
886 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
887 >>> asyncio.run(service._http_client.aclose())
888 """
889 visibility = "public" if visibility not in ("private", "team", "public") else visibility
890 try:
891 # # Check for name conflicts (both active and inactive)
892 # existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway.name)).scalar_one_or_none()
894 # if existing_gateway:
895 # raise GatewayNameConflictError(
896 # gateway.name,
897 # enabled=existing_gateway.enabled,
898 # gateway_id=existing_gateway.id,
899 # )
900 # Check for existing gateway with the same slug and visibility
901 slug_name = slugify(gateway.name)
902 if visibility.lower() == "public":
903 # Check for existing public gateway with the same slug (row-locked)
904 existing_gateway = get_for_update(
905 db,
906 DbGateway,
907 where=and_(DbGateway.slug == slug_name, DbGateway.visibility == "public"),
908 )
909 if existing_gateway:
910 raise GatewayNameConflictError(existing_gateway.slug, enabled=existing_gateway.enabled, gateway_id=existing_gateway.id, visibility=existing_gateway.visibility)
911 elif visibility.lower() == "team" and team_id:
912 # Check for existing team gateway with the same slug (row-locked)
913 existing_gateway = get_for_update(
914 db,
915 DbGateway,
916 where=and_(DbGateway.slug == slug_name, DbGateway.visibility == "team", DbGateway.team_id == team_id),
917 )
918 if existing_gateway:
919 raise GatewayNameConflictError(existing_gateway.slug, enabled=existing_gateway.enabled, gateway_id=existing_gateway.id, visibility=existing_gateway.visibility)
921 # Normalize the gateway URL
922 normalized_url = self.normalize_url(str(gateway.url))
924 decoded_auth_value = None
925 if gateway.auth_value:
926 if isinstance(gateway.auth_value, str):
927 try:
928 decoded_auth_value = decode_auth(gateway.auth_value)
929 except Exception as e:
930 logger.warning(f"Failed to decode provided auth_value: {e}")
931 decoded_auth_value = None
932 elif isinstance(gateway.auth_value, dict):
933 decoded_auth_value = gateway.auth_value
935 # Check for duplicate gateway
936 if not gateway.one_time_auth:
937 duplicate_gateway = self._check_gateway_uniqueness(
938 db=db, url=normalized_url, auth_value=decoded_auth_value, oauth_config=gateway.oauth_config, team_id=team_id, owner_email=owner_email, visibility=visibility
939 )
941 if duplicate_gateway:
942 raise GatewayDuplicateConflictError(duplicate_gateway=duplicate_gateway)
944 # Prevent URL-only gateways (no auth at all)
945 # if not decoded_auth_value and not gateway.oauth_config:
946 # raise ValueError(
947 # f"Gateway with URL '{normalized_url}' must have either auth_value or oauth_config. "
948 # "URL-only gateways are not allowed."
949 # )
951 auth_type = getattr(gateway, "auth_type", None)
952 # Support multiple custom headers
953 auth_value = getattr(gateway, "auth_value", {})
954 authentication_headers: Optional[Dict[str, str]] = None
956 # Handle query_param auth - encrypt and prepare for storage
957 auth_query_params_encrypted: Optional[Dict[str, str]] = None
958 auth_query_params_decrypted: Optional[Dict[str, str]] = None
959 init_url = normalized_url # URL to use for initialization
961 if auth_type == "query_param":
962 # Extract and encrypt query param auth
963 param_key = getattr(gateway, "auth_query_param_key", None)
964 param_value = getattr(gateway, "auth_query_param_value", None)
965 if param_key and param_value:
966 # Get the actual secret value
967 if hasattr(param_value, "get_secret_value"):
968 raw_value = param_value.get_secret_value()
969 else:
970 raw_value = str(param_value)
971 # Encrypt for storage
972 encrypted_value = encode_auth({param_key: raw_value})
973 auth_query_params_encrypted = {param_key: encrypted_value}
974 auth_query_params_decrypted = {param_key: raw_value}
975 # Append query params to URL for initialization
976 init_url = apply_query_param_auth(normalized_url, auth_query_params_decrypted)
977 # Query param auth doesn't use auth_value
978 auth_value = None
979 authentication_headers = None
981 elif hasattr(gateway, "auth_headers") and gateway.auth_headers:
982 # Convert list of {key, value} to dict
983 header_dict = {h["key"]: h["value"] for h in gateway.auth_headers if h.get("key")}
984 auth_value = header_dict # store plain dict, consistent with update path and DB column type
985 authentication_headers = {str(k): str(v) for k, v in header_dict.items()}
987 elif isinstance(auth_value, str) and auth_value:
988 # Decode persisted auth for initialization
989 decoded = decode_auth(auth_value)
990 authentication_headers = {str(k): str(v) for k, v in decoded.items()}
991 else:
992 authentication_headers = None
994 oauth_config = await protect_oauth_config_for_storage(getattr(gateway, "oauth_config", None))
995 ca_certificate = getattr(gateway, "ca_certificate", None)
996 init_client_cert = getattr(gateway, "client_cert", None)
997 init_client_key = getattr(gateway, "client_key", None)
999 # Check if gateway is in direct_proxy mode
1000 gateway_mode = getattr(gateway, "gateway_mode", "cache")
1002 if gateway_mode == "direct_proxy" and not settings.mcpgateway_direct_proxy_enabled:
1003 raise GatewayError("direct_proxy gateway mode is disabled. Set MCPGATEWAY_DIRECT_PROXY_ENABLED=true to enable.")
1005 if initialize_timeout is not None:
1006 try:
1007 capabilities, tools, resources, prompts = await asyncio.wait_for(
1008 self._initialize_gateway(
1009 init_url, # URL with query params if applicable
1010 authentication_headers,
1011 gateway.transport,
1012 auth_type,
1013 oauth_config,
1014 ca_certificate,
1015 auth_query_params=auth_query_params_decrypted,
1016 client_cert=init_client_cert,
1017 client_key=init_client_key,
1018 ),
1019 timeout=initialize_timeout,
1020 )
1021 except asyncio.TimeoutError as exc:
1022 sanitized = sanitize_url_for_logging(init_url, auth_query_params_decrypted)
1023 raise GatewayConnectionError(f"Gateway initialization timed out after {initialize_timeout}s for {sanitized}") from exc
1024 else:
1025 capabilities, tools, resources, prompts = await self._initialize_gateway(
1026 init_url, # URL with query params if applicable
1027 authentication_headers,
1028 gateway.transport,
1029 auth_type,
1030 oauth_config,
1031 ca_certificate,
1032 auth_query_params=auth_query_params_decrypted,
1033 client_cert=init_client_cert,
1034 client_key=init_client_key,
1035 )
1037 if gateway.one_time_auth:
1038 # For one-time auth, clear auth_type and auth_value after initialization
1039 auth_type = "one_time_auth"
1040 auth_value = None
1041 oauth_config = None
1043 # DbTool.auth_value is Mapped[Optional[str]] (Text), so encode the dict before
1044 # storing it there. DbGateway.auth_value is Mapped[Optional[Dict]] (JSON) and
1045 # receives the plain dict directly (see assignment above).
1046 tool_auth_value = encode_auth(auth_value) if isinstance(auth_value, dict) else auth_value
1048 tools = [
1049 DbTool(
1050 original_name=tool.name,
1051 custom_name=tool.name,
1052 custom_name_slug=slugify(tool.name),
1053 display_name=generate_display_name(tool.name),
1054 title=_resolve_tool_title(tool),
1055 url=normalized_url,
1056 original_description=tool.description,
1057 description=tool.description,
1058 integration_type="MCP", # Gateway-discovered tools are MCP type
1059 request_type=tool.request_type,
1060 headers=tool.headers,
1061 input_schema=tool.input_schema,
1062 output_schema=tool.output_schema,
1063 annotations=tool.annotations,
1064 jsonpath_filter=tool.jsonpath_filter,
1065 auth_type=auth_type,
1066 auth_value=tool_auth_value,
1067 # Federation metadata
1068 created_by=created_by or "system",
1069 created_from_ip=created_from_ip,
1070 created_via="federation", # These are federated tools
1071 created_user_agent=created_user_agent,
1072 federation_source=gateway.name,
1073 version=1,
1074 # Inherit team assignment from gateway
1075 team_id=team_id,
1076 owner_email=owner_email,
1077 visibility=visibility,
1078 )
1079 for tool in tools
1080 ]
1082 # Create resource DB models with upsert logic for ORPHANED resources only
1083 # Query for existing ORPHANED resources (gateway_id IS NULL or points to non-existent gateway)
1084 # with same (team_id, owner_email, uri) to handle resources left behind from incomplete
1085 # gateway deletions (e.g., issue #2341 crash scenarios).
1086 # We only update orphaned resources - resources belonging to active gateways are not touched.
1087 resource_uris = [r.uri for r in resources]
1088 effective_owner = owner_email or created_by
1090 # Build lookup map: (team_id, owner_email, uri) -> orphaned DbResource
1091 # We query all resources matching our URIs, then filter to orphaned ones in Python
1092 # to handle per-resource team/owner overrides correctly
1093 orphaned_resources_map: Dict[tuple, DbResource] = {}
1094 if resource_uris:
1095 try:
1096 # Get valid gateway IDs to identify orphaned resources
1097 valid_gateway_ids = set(gw_id for (gw_id,) in db.execute(select(DbGateway.id)).all())
1098 candidate_resources = db.execute(select(DbResource).where(DbResource.uri.in_(resource_uris))).scalars().all()
1099 for res in candidate_resources:
1100 # Only consider orphaned resources (no gateway or gateway doesn't exist)
1101 is_orphaned = res.gateway_id is None or res.gateway_id not in valid_gateway_ids
1102 if is_orphaned:
1103 key = (res.team_id, res.owner_email, res.uri)
1104 orphaned_resources_map[key] = res
1105 if orphaned_resources_map:
1106 logger.info(f"Found {len(orphaned_resources_map)} orphaned resources to reassign for gateway {SecurityValidator.sanitize_log_message(gateway.name)}")
1107 except Exception as e:
1108 # If orphan detection fails (e.g., in mocked tests), skip upsert and create new resources
1109 # This is conservative - we won't accidentally reassign resources from active gateways
1110 logger.debug(f"Orphan resource detection skipped: {e}")
1112 db_resources = []
1113 for r in resources:
1114 mime_type = mimetypes.guess_type(r.uri)[0] or ("text/plain" if isinstance(r.content, str) else "application/octet-stream")
1115 r_team_id = getattr(r, "team_id", None) or team_id
1116 r_owner_email = getattr(r, "owner_email", None) or effective_owner
1117 r_visibility = getattr(r, "visibility", None) or visibility
1119 # Check if there's an orphaned resource with matching unique key
1120 lookup_key = (r_team_id, r_owner_email, r.uri)
1121 if lookup_key in orphaned_resources_map:
1122 # Update orphaned resource - reassign to new gateway
1123 existing = orphaned_resources_map[lookup_key]
1124 existing.name = r.name
1125 existing.description = r.description
1126 existing.mime_type = mime_type
1127 existing.uri_template = r.uri_template or None
1128 existing.text_content = r.content if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) else None
1129 existing.binary_content = (
1130 r.content.encode() if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) else r.content if isinstance(r.content, bytes) else None
1131 )
1132 existing.size = len(r.content) if r.content else 0
1133 existing.title = getattr(r, "title", None)
1134 existing.tags = getattr(r, "tags", []) or []
1135 existing.federation_source = gateway.name
1136 existing.modified_by = created_by
1137 existing.modified_from_ip = created_from_ip
1138 existing.modified_via = "federation"
1139 existing.modified_user_agent = created_user_agent
1140 existing.updated_at = datetime.now(timezone.utc)
1141 existing.visibility = r_visibility
1142 # Note: gateway_id will be set when gateway is created (relationship)
1143 db_resources.append(existing)
1144 else:
1145 # Create new resource
1146 db_resources.append(
1147 DbResource(
1148 uri=r.uri,
1149 name=r.name,
1150 title=getattr(r, "title", None),
1151 description=r.description,
1152 mime_type=mime_type,
1153 uri_template=r.uri_template or None,
1154 text_content=r.content if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str) else None,
1155 binary_content=(
1156 r.content.encode()
1157 if (mime_type.startswith("text/") or isinstance(r.content, str)) and isinstance(r.content, str)
1158 else r.content if isinstance(r.content, bytes) else None
1159 ),
1160 size=len(r.content) if r.content else 0,
1161 tags=getattr(r, "tags", []) or [],
1162 created_by=created_by or "system",
1163 created_from_ip=created_from_ip,
1164 created_via="federation",
1165 created_user_agent=created_user_agent,
1166 import_batch_id=None,
1167 federation_source=gateway.name,
1168 version=1,
1169 team_id=r_team_id,
1170 owner_email=r_owner_email,
1171 visibility=r_visibility,
1172 )
1173 )
1175 # Create prompt DB models with upsert logic for ORPHANED prompts only
1176 # Query for existing ORPHANED prompts (gateway_id IS NULL or points to non-existent gateway)
1177 # with same (team_id, owner_email, name) to handle prompts left behind from incomplete
1178 # gateway deletions. We only update orphaned prompts - prompts belonging to active gateways are not touched.
1179 prompt_names = [p.name for p in prompts]
1181 # Build lookup map: (team_id, owner_email, name) -> orphaned DbPrompt
1182 orphaned_prompts_map: Dict[tuple, DbPrompt] = {}
1183 if prompt_names:
1184 try:
1185 # Get valid gateway IDs to identify orphaned prompts
1186 valid_gateway_ids_for_prompts = set(gw_id for (gw_id,) in db.execute(select(DbGateway.id)).all())
1187 candidate_prompts = db.execute(select(DbPrompt).where(DbPrompt.name.in_(prompt_names))).scalars().all()
1188 for pmt in candidate_prompts:
1189 # Only consider orphaned prompts (no gateway or gateway doesn't exist)
1190 is_orphaned = pmt.gateway_id is None or pmt.gateway_id not in valid_gateway_ids_for_prompts
1191 if is_orphaned:
1192 key = (pmt.team_id, pmt.owner_email, pmt.name)
1193 orphaned_prompts_map[key] = pmt
1194 if orphaned_prompts_map:
1195 logger.info(f"Found {len(orphaned_prompts_map)} orphaned prompts to reassign for gateway {SecurityValidator.sanitize_log_message(gateway.name)}")
1196 except Exception as e:
1197 # If orphan detection fails (e.g., in mocked tests), skip upsert and create new prompts
1198 logger.debug(f"Orphan prompt detection skipped: {e}")
1200 db_prompts = []
1201 for prompt in prompts:
1202 # Prompts inherit team/owner from gateway (no per-prompt overrides)
1203 p_team_id = team_id
1204 p_owner_email = owner_email or effective_owner
1206 # Check if there's an orphaned prompt with matching unique key
1207 lookup_key = (p_team_id, p_owner_email, prompt.name)
1208 if lookup_key in orphaned_prompts_map:
1209 # Update orphaned prompt - reassign to new gateway
1210 existing = orphaned_prompts_map[lookup_key]
1211 existing.original_name = prompt.name
1212 existing.custom_name = prompt.name
1213 existing.display_name = prompt.name
1214 existing.title = getattr(prompt, "title", None)
1215 existing.description = prompt.description
1216 existing.template = prompt.template if hasattr(prompt, "template") else ""
1217 existing.argument_schema = self._build_prompt_argument_schema(prompt)
1218 existing.federation_source = gateway.name
1219 existing.modified_by = created_by
1220 existing.modified_from_ip = created_from_ip
1221 existing.modified_via = "federation"
1222 existing.modified_user_agent = created_user_agent
1223 existing.updated_at = datetime.now(timezone.utc)
1224 existing.visibility = visibility
1225 # Note: gateway_id will be set when gateway is created (relationship)
1226 db_prompts.append(existing)
1227 else:
1228 # Create new prompt
1229 db_prompts.append(
1230 DbPrompt(
1231 name=prompt.name,
1232 original_name=prompt.name,
1233 custom_name=prompt.name,
1234 display_name=prompt.name,
1235 title=getattr(prompt, "title", None),
1236 description=prompt.description,
1237 template=prompt.template if hasattr(prompt, "template") else "",
1238 argument_schema=self._build_prompt_argument_schema(prompt),
1239 # Federation metadata
1240 created_by=created_by or "system",
1241 created_from_ip=created_from_ip,
1242 created_via="federation", # These are federated prompts
1243 created_user_agent=created_user_agent,
1244 federation_source=gateway.name,
1245 version=1,
1246 # Inherit team assignment from gateway
1247 team_id=team_id,
1248 owner_email=owner_email,
1249 visibility=visibility,
1250 )
1251 )
1253 # Create DB model
1254 db_gateway = DbGateway(
1255 name=gateway.name,
1256 slug=slug_name,
1257 url=normalized_url,
1258 description=gateway.description,
1259 tags=gateway.tags or [],
1260 transport=gateway.transport,
1261 capabilities=capabilities,
1262 last_seen=datetime.now(timezone.utc),
1263 auth_type=auth_type,
1264 auth_value=auth_value,
1265 auth_query_params=auth_query_params_encrypted, # Encrypted query param auth
1266 oauth_config=oauth_config,
1267 passthrough_headers=gateway.passthrough_headers,
1268 tools=tools,
1269 resources=db_resources,
1270 prompts=db_prompts,
1271 # Gateway metadata
1272 created_by=created_by,
1273 created_from_ip=created_from_ip,
1274 created_via=created_via or "api",
1275 created_user_agent=created_user_agent,
1276 version=1,
1277 # Team scoping fields
1278 team_id=team_id,
1279 owner_email=owner_email,
1280 visibility=visibility,
1281 ca_certificate=gateway.ca_certificate,
1282 ca_certificate_sig=gateway.ca_certificate_sig,
1283 signing_algorithm=gateway.signing_algorithm,
1284 # mTLS client certificate/key
1285 client_cert=getattr(gateway, "client_cert", None),
1286 client_key=await self._encrypt_client_key(getattr(gateway, "client_key", None)),
1287 # Gateway mode configuration
1288 gateway_mode=gateway_mode,
1289 )
1291 # Add to DB and commit immediately so tools/resources/prompts are visible
1292 # to other workers before the HTTP response reaches the client.
1293 # Without this, clients issuing follow-up requests (e.g., manual refresh)
1294 # can hit a different worker that hasn't seen the uncommitted data yet.
1295 db.add(db_gateway)
1296 db.commit()
1297 db.refresh(db_gateway)
1299 # Update tracking
1300 self._active_gateways.add(db_gateway.url)
1302 # Notify subscribers
1303 await self._notify_gateway_added(db_gateway)
1305 # Invalidate caches so other workers see the new gateway and its tools/resources/prompts
1306 cache = _get_registry_cache()
1307 await cache.invalidate_gateways()
1308 await cache.invalidate_tools()
1309 await cache.invalidate_resources()
1310 await cache.invalidate_prompts()
1311 tool_lookup_cache = _get_tool_lookup_cache()
1312 await tool_lookup_cache.invalidate_gateway(str(db_gateway.id))
1313 # First-Party
1314 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel
1316 await admin_stats_cache.invalidate_tags()
1318 # Invalidate loopback passthrough cache when a new gateway has passthrough headers (#3640)
1319 if gateway.passthrough_headers:
1320 # First-Party
1321 from mcpgateway.utils.passthrough_headers import invalidate_passthrough_header_caches # pylint: disable=import-outside-toplevel
1323 invalidate_passthrough_header_caches()
1325 logger.info(f"Registered gateway: {SecurityValidator.sanitize_log_message(gateway.name)}")
1327 # Structured logging: Audit trail for gateway creation
1328 audit_trail.log_action(
1329 user_id=created_by or "system",
1330 action="create_gateway",
1331 resource_type="gateway",
1332 resource_id=str(db_gateway.id),
1333 resource_name=db_gateway.name,
1334 user_email=owner_email,
1335 team_id=team_id,
1336 client_ip=created_from_ip,
1337 user_agent=created_user_agent,
1338 new_values={
1339 "name": db_gateway.name,
1340 "url": db_gateway.url,
1341 "visibility": visibility,
1342 "transport": db_gateway.transport,
1343 "tools_count": len(tools),
1344 "resources_count": len(db_resources),
1345 "prompts_count": len(db_prompts),
1346 },
1347 context={
1348 "created_via": created_via,
1349 },
1350 db=db,
1351 )
1353 # Structured logging: Log successful gateway creation
1354 structured_logger.log(
1355 level="INFO",
1356 message="Gateway created successfully",
1357 event_type="gateway_created",
1358 component="gateway_service",
1359 user_id=created_by,
1360 user_email=owner_email,
1361 team_id=team_id,
1362 resource_type="gateway",
1363 resource_id=str(db_gateway.id),
1364 custom_fields={
1365 "gateway_name": db_gateway.name,
1366 "gateway_url": normalized_url,
1367 "visibility": visibility,
1368 "transport": db_gateway.transport,
1369 },
1370 )
1372 return self.convert_gateway_to_read(db_gateway)
1373 except* GatewayConnectionError as ge: # pragma: no mutate
1374 if TYPE_CHECKING:
1375 ge: ExceptionGroup[GatewayConnectionError]
1376 logger.error(f"GatewayConnectionError in group: {ge.exceptions}")
1377 db.rollback()
1379 structured_logger.log(
1380 level="ERROR",
1381 message="Gateway creation failed due to connection error",
1382 event_type="gateway_creation_failed",
1383 component="gateway_service",
1384 user_id=created_by,
1385 user_email=owner_email,
1386 error=ge.exceptions[0],
1387 custom_fields={"gateway_name": gateway.name, "gateway_url": str(gateway.url)},
1388 )
1389 raise ge.exceptions[0]
1390 except* GatewayNameConflictError as gnce: # pragma: no mutate
1391 if TYPE_CHECKING:
1392 gnce: ExceptionGroup[GatewayNameConflictError]
1393 logger.error(f"GatewayNameConflictError in group: {gnce.exceptions}")
1394 db.rollback()
1396 structured_logger.log(
1397 level="WARNING",
1398 message="Gateway creation failed due to name conflict",
1399 event_type="gateway_name_conflict",
1400 component="gateway_service",
1401 user_id=created_by,
1402 user_email=owner_email,
1403 custom_fields={"gateway_name": gateway.name, "visibility": visibility},
1404 )
1405 raise gnce.exceptions[0]
1406 except* GatewayDuplicateConflictError as guce: # pragma: no mutate
1407 if TYPE_CHECKING:
1408 guce: ExceptionGroup[GatewayDuplicateConflictError]
1409 logger.error(f"GatewayDuplicateConflictError in group: {guce.exceptions}")
1410 db.rollback()
1412 structured_logger.log(
1413 level="WARNING",
1414 message="Gateway creation failed due to duplicate",
1415 event_type="gateway_duplicate_conflict",
1416 component="gateway_service",
1417 user_id=created_by,
1418 user_email=owner_email,
1419 custom_fields={"gateway_name": gateway.name},
1420 )
1421 raise guce.exceptions[0]
1422 except* ValueError as ve: # pragma: no mutate
1423 if TYPE_CHECKING:
1424 ve: ExceptionGroup[ValueError]
1425 logger.error(f"ValueErrors in group: {ve.exceptions}")
1426 db.rollback()
1428 structured_logger.log(
1429 level="ERROR",
1430 message="Gateway creation failed due to validation error",
1431 event_type="gateway_creation_failed",
1432 component="gateway_service",
1433 user_id=created_by,
1434 user_email=owner_email,
1435 error=ve.exceptions[0],
1436 custom_fields={"gateway_name": gateway.name},
1437 )
1438 raise ve.exceptions[0]
1439 except* RuntimeError as re: # pragma: no mutate
1440 if TYPE_CHECKING:
1441 re: ExceptionGroup[RuntimeError]
1442 logger.error(f"RuntimeErrors in group: {re.exceptions}")
1443 db.rollback()
1445 structured_logger.log(
1446 level="ERROR",
1447 message="Gateway creation failed due to runtime error",
1448 event_type="gateway_creation_failed",
1449 component="gateway_service",
1450 user_id=created_by,
1451 user_email=owner_email,
1452 error=re.exceptions[0],
1453 custom_fields={"gateway_name": gateway.name},
1454 )
1455 raise re.exceptions[0]
1456 except* IntegrityError as ie: # pragma: no mutate
1457 if TYPE_CHECKING:
1458 ie: ExceptionGroup[IntegrityError]
1459 logger.error(f"IntegrityErrors in group: {ie.exceptions}")
1460 db.rollback()
1462 structured_logger.log(
1463 level="ERROR",
1464 message="Gateway creation failed due to database integrity error",
1465 event_type="gateway_creation_failed",
1466 component="gateway_service",
1467 user_id=created_by,
1468 user_email=owner_email,
1469 error=ie.exceptions[0],
1470 custom_fields={"gateway_name": gateway.name},
1471 )
1472 raise ie.exceptions[0]
1473 except* BaseException as other: # catches every other sub-exception # pragma: no mutate
1474 if TYPE_CHECKING:
1475 other: ExceptionGroup[Exception]
1476 logger.error(f"Other grouped errors: {other.exceptions}")
1477 db.rollback()
1478 raise other.exceptions[0]
1480 async def fetch_tools_after_oauth(self, db: Session, gateway_id: str, app_user_email: str) -> Dict[str, Any]:
1481 """Fetch tools from MCP server after OAuth completion for Authorization Code flow.
1483 Args:
1484 db: Database session
1485 gateway_id: ID of the gateway to fetch tools for
1486 app_user_email: ContextForge user email for token retrieval
1488 Returns:
1489 Dict containing capabilities, tools, resources, and prompts
1491 Raises:
1492 GatewayConnectionError: If connection or OAuth fails
1493 """
1494 try:
1495 # Get the gateway with eager loading for sync operations to avoid N+1 queries
1496 gateway = db.execute(
1497 select(DbGateway)
1498 .options(
1499 selectinload(DbGateway.tools),
1500 selectinload(DbGateway.resources),
1501 selectinload(DbGateway.prompts),
1502 joinedload(DbGateway.email_team),
1503 )
1504 .where(DbGateway.id == gateway_id)
1505 ).scalar_one_or_none()
1507 if not gateway:
1508 raise ValueError(f"Gateway {gateway_id} not found")
1510 if not gateway.oauth_config:
1511 raise ValueError(f"Gateway {gateway_id} has no OAuth configuration")
1513 grant_type = gateway.oauth_config.get("grant_type")
1514 if grant_type != "authorization_code":
1515 raise ValueError(f"Gateway {gateway_id} is not using Authorization Code flow")
1517 # Get OAuth tokens for this gateway
1518 # First-Party
1519 from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel
1521 token_storage = TokenStorageService(db)
1523 # Get user-specific OAuth token
1524 if not app_user_email:
1525 raise GatewayConnectionError(f"User authentication required for OAuth gateway {gateway.name}")
1527 access_token = await token_storage.get_user_token(gateway.id, app_user_email)
1529 if not access_token:
1530 raise GatewayConnectionError(
1531 f"No OAuth tokens found for user {app_user_email} on gateway {gateway.name}. Please complete the OAuth authorization flow first at /oauth/authorize/{gateway.id}"
1532 )
1534 # Debug: Check if token was decrypted
1535 if access_token.startswith("Z0FBQUFBQm"): # Encrypted tokens start with this
1536 logger.error("OAuth token decryption may have failed before gateway initialization")
1537 else:
1538 logger.info("Using decrypted OAuth token for gateway %s", gateway.name)
1540 # Now connect to MCP server with the access token
1541 authentication = {"Authorization": f"Bearer {access_token}"}
1543 # Use the existing connection logic
1544 # Note: For OAuth servers, skip validation since we already validated via OAuth flow
1545 if gateway.transport.upper() == "SSE":
1546 capabilities, tools, resources, prompts = await self._connect_to_sse_server_without_validation(gateway.url, authentication)
1547 elif gateway.transport.upper() == "STREAMABLEHTTP":
1548 capabilities, tools, resources, prompts = await self.connect_to_streamablehttp_server(gateway.url, authentication)
1549 else:
1550 raise ValueError(f"Unsupported transport type: {gateway.transport}")
1552 # Handle tools, resources, and prompts using helper methods
1553 tools_to_add = self._update_or_create_tools(db, tools, gateway, "oauth")
1554 resources_to_add = self._update_or_create_resources(db, resources, gateway, "oauth")
1555 prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, "oauth")
1557 # Clean up items that are no longer available from the gateway
1558 new_tool_names = [tool.name for tool in tools]
1559 new_resource_uris = [resource.uri for resource in resources]
1560 new_prompt_names = [prompt.name for prompt in prompts]
1562 # Count items before cleanup for logging
1564 # Bulk delete tools that are no longer available from the gateway
1565 # Use chunking to avoid SQLite's 999 parameter limit for IN clauses
1566 stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names]
1567 if stale_tool_ids:
1568 # Delete child records first to avoid FK constraint violations
1569 for i in range(0, len(stale_tool_ids), 500):
1570 chunk = stale_tool_ids[i : i + 500]
1571 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
1572 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
1573 db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))
1575 # Bulk delete resources that are no longer available from the gateway
1576 stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris]
1577 if stale_resource_ids:
1578 # Delete child records first to avoid FK constraint violations
1579 for i in range(0, len(stale_resource_ids), 500):
1580 chunk = stale_resource_ids[i : i + 500]
1581 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
1582 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
1583 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
1584 db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))
1586 # Bulk delete prompts that are no longer available from the gateway
1587 stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names]
1588 if stale_prompt_ids:
1589 # Delete child records first to avoid FK constraint violations
1590 for i in range(0, len(stale_prompt_ids), 500):
1591 chunk = stale_prompt_ids[i : i + 500]
1592 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
1593 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
1594 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))
1596 # Expire gateway to clear cached relationships after bulk deletes
1597 # This prevents SQLAlchemy from trying to re-delete already-deleted items
1598 if stale_tool_ids or stale_resource_ids or stale_prompt_ids:
1599 db.expire(gateway)
1601 # Update gateway relationships to reflect deletions
1602 gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names]
1603 gateway.resources = [resource for resource in gateway.resources if resource.uri in new_resource_uris]
1604 gateway.prompts = [prompt for prompt in gateway.prompts if prompt.original_name in new_prompt_names]
1606 # Log cleanup results
1607 tools_removed = len(stale_tool_ids)
1608 resources_removed = len(stale_resource_ids)
1609 prompts_removed = len(stale_prompt_ids)
1611 if tools_removed > 0:
1612 logger.info(f"Removed {tools_removed} tools no longer available from gateway")
1613 if resources_removed > 0:
1614 logger.info(f"Removed {resources_removed} resources no longer available from gateway")
1615 if prompts_removed > 0:
1616 logger.info(f"Removed {prompts_removed} prompts no longer available from gateway")
1618 # Update gateway capabilities and last_seen
1619 gateway.capabilities = capabilities
1620 gateway.last_seen = datetime.now(timezone.utc)
1622 # Register capabilities for notification-driven actions
1623 register_gateway_capabilities_for_notifications(gateway.id, capabilities)
1625 # Add new items to DB in chunks to prevent lock escalation
1626 items_added = 0
1627 chunk_size = 50
1629 if tools_to_add:
1630 for i in range(0, len(tools_to_add), chunk_size):
1631 chunk = tools_to_add[i : i + chunk_size]
1632 db.add_all(chunk)
1633 db.flush() # Flush each chunk to avoid excessive memory usage
1634 items_added += len(tools_to_add)
1635 logger.info(f"Added {len(tools_to_add)} new tools to database")
1637 if resources_to_add:
1638 for i in range(0, len(resources_to_add), chunk_size):
1639 chunk = resources_to_add[i : i + chunk_size]
1640 db.add_all(chunk)
1641 db.flush()
1642 items_added += len(resources_to_add)
1643 logger.info(f"Added {len(resources_to_add)} new resources to database")
1645 if prompts_to_add:
1646 for i in range(0, len(prompts_to_add), chunk_size):
1647 chunk = prompts_to_add[i : i + chunk_size]
1648 db.add_all(chunk)
1649 db.flush()
1650 items_added += len(prompts_to_add)
1651 logger.info(f"Added {len(prompts_to_add)} new prompts to database")
1653 if items_added > 0:
1654 db.commit()
1655 logger.info(f"Total {items_added} new items added to database")
1656 else:
1657 logger.info("No new items to add to database")
1658 # Still commit to save any updates to existing items
1659 db.commit()
1661 cache = _get_registry_cache()
1662 await cache.invalidate_tools()
1663 await cache.invalidate_resources()
1664 await cache.invalidate_prompts()
1665 tool_lookup_cache = _get_tool_lookup_cache()
1666 await tool_lookup_cache.invalidate_gateway(str(gateway.id))
1667 # Also invalidate tags cache since tool/resource tags may have changed
1668 # First-Party
1669 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel
1671 await admin_stats_cache.invalidate_tags()
1673 return {"capabilities": capabilities, "tools": tools, "resources": resources, "prompts": prompts}
1675 except GatewayConnectionError as gce:
1676 db.rollback()
1677 # Surface validation or depth-related failures directly to the user
1678 logger.error(f"GatewayConnectionError during OAuth fetch for {SecurityValidator.sanitize_log_message(gateway_id)}: {gce}")
1679 raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(gce)}")
1680 except Exception as e:
1681 db.rollback()
1682 logger.error(f"Failed to fetch tools after OAuth for gateway {SecurityValidator.sanitize_log_message(gateway_id)}: {e}")
1683 raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(e)}")
async def list_gateways(
    self,
    db: Session,
    include_inactive: bool = False,
    tags: Optional[List[str]] = None,
    cursor: Optional[str] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    per_page: Optional[int] = None,
    user_email: Optional[str] = None,
    team_id: Optional[str] = None,
    visibility: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
) -> Union[tuple[List[GatewayRead], Optional[str]], Dict[str, Any]]:
    """List all registered gateways with cursor pagination and optional team filtering.

    Only the first cursor page of a public-only query (``token_teams == []`` and no
    ``user_email``) is read from / written to the registry cache. Admin-bypass
    (``token_teams is None``) and team-scoped listings always query the database.

    Args:
        db: Database session
        include_inactive: Whether to include inactive gateways
        tags (Optional[List[str]]): Filter resources by tags. If provided, only resources with at least one matching tag will be returned.
        cursor: Cursor for pagination (encoded last created_at and id).
        limit: Maximum number of gateways to return. None for default, 0 for unlimited.
        page: Page number for page-based pagination (1-indexed). Mutually exclusive with cursor.
        per_page: Items per page for page-based pagination. Defaults to pagination_default_page_size.
        user_email: Email of user for team-based access control. None for no access control.
        team_id: Optional team ID to filter by specific team (requires user_email).
        visibility: Optional visibility filter (private, team, public) (requires user_email).
        token_teams: Optional list of team IDs from the token (None=unrestricted, []=public-only).

    Returns:
        If page is provided: Dict with {"data": [...], "pagination": {...}, "links": {...}}
        If cursor is provided or neither: tuple of (list of GatewayRead objects, next_cursor).

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> from unittest.mock import MagicMock, AsyncMock, patch
        >>> from mcpgateway.schemas import GatewayRead
        >>> import asyncio
        >>> service = GatewayService()
        >>> db = MagicMock()
        >>> gateway_obj = MagicMock()
        >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway_obj]
        >>> gateway_read_obj = MagicMock(spec=GatewayRead)
        >>> service.convert_gateway_to_read = MagicMock(return_value=gateway_read_obj)
        >>> # Mock the cache to bypass caching logic
        >>> with patch('mcpgateway.services.gateway_service._get_registry_cache') as mock_cache_factory:
        ...     mock_cache = MagicMock()
        ...     mock_cache.get = AsyncMock(return_value=None)
        ...     mock_cache.set = AsyncMock(return_value=None)
        ...     mock_cache.hash_filters = MagicMock(return_value="hash")
        ...     mock_cache_factory.return_value = mock_cache
        ...     gateways, cursor = asyncio.run(service.list_gateways(db))
        ...     gateways == [gateway_read_obj] and cursor is None
        True

        >>> # Test empty result
        >>> db.execute.return_value.scalars.return_value.all.return_value = []
        >>> with patch('mcpgateway.services.gateway_service._get_registry_cache') as mock_cache_factory:
        ...     mock_cache = MagicMock()
        ...     mock_cache.get = AsyncMock(return_value=None)
        ...     mock_cache.set = AsyncMock(return_value=None)
        ...     mock_cache.hash_filters = MagicMock(return_value="hash")
        ...     mock_cache_factory.return_value = mock_cache
        ...     empty_result, cursor = asyncio.run(service.list_gateways(db))
        ...     empty_result == [] and cursor is None
        True
        >>>
        >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
        >>> asyncio.run(service._http_client.aclose())
    """
    cache = _get_registry_cache()
    # SECURITY: Only cache public-only results (token_teams=[]), never admin bypass
    # (token_teams=None) or team-scoped queries.
    is_public_only = token_teams is not None and len(token_teams) == 0
    # Only the first cursor page is cached; page-based requests return early below.
    use_cache = cursor is None and user_email is None and page is None and is_public_only
    filters_hash: Optional[str] = None
    if use_cache:
        # The key must cover every argument that shapes the first page:
        # ``limit`` changes the page size and ``team_id`` is forwarded to access
        # control, so omitting either would serve one request's page to another.
        # Tags are sorted with ``key=str`` because the tag filter also accepts
        # List[Dict] tags, which plain ``sorted`` cannot order (TypeError). For
        # plain strings the resulting order is identical to ``sorted(tags)``.
        filters_hash = cache.hash_filters(
            include_inactive=include_inactive,
            tags=sorted(tags, key=str) if tags else None,
            visibility=visibility,
            team_id=team_id,
            limit=limit,
        )
        cached = await cache.get("gateways", filters_hash)
        if cached is not None:
            # Reconstruct GatewayRead objects from cached dicts.
            # SECURITY: Always apply .masked() so stale cache entries never leak credentials.
            cached_gateways = [GatewayRead.model_validate(g).masked() for g in cached["gateways"]]
            return (cached_gateways, cached.get("next_cursor"))

    # Newest first; id breaks ties so cursor pagination is stable.
    query = select(DbGateway).options(joinedload(DbGateway.email_team)).order_by(desc(DbGateway.created_at), desc(DbGateway.id))

    if not include_inactive:
        query = query.where(DbGateway.enabled)

    query = await self._apply_access_control(query, db, user_email, token_teams, team_id)

    if visibility:
        query = query.where(DbGateway.visibility == visibility)

    # Tag filtering supports both List[str] and List[Dict] formats.
    if tags:
        query = query.where(json_contains_tag_expr(db, DbGateway.tags, tags, match_any=True))

    # Unified pagination helper handles both page- and cursor-based pagination.
    pag_result = await unified_paginate(
        db=db,
        query=query,
        page=page,
        per_page=per_page,
        cursor=cursor,
        limit=limit,
        base_url="/admin/gateways",  # Used for page-based links
        query_params={"include_inactive": include_inactive} if include_inactive else {},
    )

    next_cursor = None
    if page is not None:
        # Page-based: pag_result is a dict
        gateways_db = pag_result["data"]
    else:
        # Cursor-based: pag_result is a tuple
        gateways_db, next_cursor = pag_result

    db.commit()  # Release transaction to avoid idle-in-transaction

    # Convert to GatewayRead; a single malformed row must not break the whole listing.
    result = []
    for s in gateways_db:
        try:
            result.append(self.convert_gateway_to_read(s))
        except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
            logger.exception(f"Failed to convert gateway {getattr(s, 'id', 'unknown')} ({getattr(s, 'name', 'unknown')}): {e}")

    if page is not None:
        return {
            "data": result,
            "pagination": pag_result["pagination"],
            "links": pag_result["links"],
        }

    # Cursor-based format: populate the cache for the same public-only first-page
    # queries that were eligible for a cache read above.
    if use_cache:
        try:
            cache_data = {"gateways": [s.model_dump(mode="json") for s in result], "next_cursor": next_cursor}
            await cache.set("gateways", cache_data, filters_hash)
        except AttributeError:
            pass  # Skip caching if result objects don't support model_dump (e.g., in doctests)

    return (result, next_cursor)
async def list_gateways_for_user(
    self, db: Session, user_email: str, team_id: Optional[str] = None, visibility: Optional[str] = None, include_inactive: bool = False, skip: int = 0, limit: int = 100
) -> List[GatewayRead]:
    """
    DEPRECATED: Use list_gateways() with user_email parameter instead.

    This method is maintained for backward compatibility but is no longer used.
    New code should call list_gateways() with user_email, team_id, and visibility parameters.

    List gateways user has access to with team filtering.

    Results are ordered newest first (created_at, then id) so that
    ``skip``/``limit`` paging returns stable, non-overlapping pages. Gateways
    that fail conversion to ``GatewayRead`` are logged and skipped, matching
    ``list_gateways``.

    Args:
        db: Database session
        user_email: Email of the user requesting gateways
        team_id: Optional team ID to filter by specific team
        visibility: Optional visibility filter (private, team, public)
        include_inactive: Whether to include inactive gateways
        skip: Number of gateways to skip for pagination
        limit: Maximum number of gateways to return

    Returns:
        List[GatewayRead]: Gateways the user has access to (empty if the user
        is not a member of the requested ``team_id``)
    """
    team_service = TeamManagementService(db)
    user_teams = await team_service.get_user_teams(user_email)
    team_ids = [team.id for team in user_teams]

    # A specific team was requested but the user is not a member: no access.
    if team_id and team_id not in team_ids:
        return []

    # Use joinedload to eager load email_team relationship (avoids N+1 queries)
    query = select(DbGateway).options(joinedload(DbGateway.email_team))

    if not include_inactive:
        query = query.where(DbGateway.enabled.is_(True))

    if team_id:
        access_conditions = [
            # Team-owned gateways shared with the team
            and_(DbGateway.team_id == team_id, DbGateway.visibility.in_(["team", "public"])),
            # The user's own gateways within that team
            and_(DbGateway.team_id == team_id, DbGateway.owner_email == user_email),
            # Public gateways stay visible regardless of the selected team
            DbGateway.visibility == "public",
        ]
    else:
        # 1. User's personal gateways
        access_conditions = [DbGateway.owner_email == user_email]
        # 2. Team gateways in any team the user belongs to
        if team_ids:
            access_conditions.append(and_(DbGateway.team_id.in_(team_ids), DbGateway.visibility.in_(["team", "public"])))
        # 3. Public gateways
        access_conditions.append(DbGateway.visibility == "public")

    query = query.where(or_(*access_conditions))

    if visibility:
        query = query.where(DbGateway.visibility == visibility)

    # OFFSET/LIMIT without ORDER BY lets the database return rows in any order,
    # so pages could overlap or skip rows. Use the same ordering as list_gateways().
    query = query.order_by(desc(DbGateway.created_at), desc(DbGateway.id)).offset(skip).limit(limit)

    gateways = db.execute(query).scalars().all()

    db.commit()  # Release transaction to avoid idle-in-transaction

    # Team names are loaded via joinedload(DbGateway.email_team)
    result: List[GatewayRead] = []
    for g in gateways:
        # Per-row diagnostics belong at debug level, not info.
        logger.debug(f"Gateway: {SecurityValidator.sanitize_log_message(g.team_id)}, Team: {g.team}")
        try:
            result.append(self.convert_gateway_to_read(g))
        except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
            logger.exception(f"Failed to convert gateway {getattr(g, 'id', 'unknown')} ({getattr(g, 'name', 'unknown')}): {e}")
    return result
1921 async def update_gateway(
1922 self,
1923 db: Session,
1924 gateway_id: str,
1925 gateway_update: GatewayUpdate,
1926 modified_by: Optional[str] = None,
1927 modified_from_ip: Optional[str] = None,
1928 modified_via: Optional[str] = None,
1929 modified_user_agent: Optional[str] = None,
1930 include_inactive: bool = True,
1931 user_email: Optional[str] = None,
1932 ) -> Optional[GatewayRead]:
1933 """Update a gateway.
1935 Args:
1936 db: Database session
1937 gateway_id: Gateway ID to update
1938 gateway_update: Updated gateway data
1939 modified_by: Username of the person modifying the gateway
1940 modified_from_ip: IP address where the modification request originated
1941 modified_via: Source of modification (ui/api/import)
1942 modified_user_agent: User agent string from the modification request
1943 include_inactive: Whether to include inactive gateways
1944 user_email: Email of user performing update (for ownership check)
1946 Returns:
1947 Updated gateway information
1949 Raises:
1950 GatewayNotFoundError: If gateway not found
1951 PermissionError: If user doesn't own the gateway
1952 GatewayError: For other update errors
1953 GatewayNameConflictError: If gateway name conflict occurs
1954 IntegrityError: If there is a database integrity error
1955 ValidationError: If validation fails
1956 """
1957 try: # pylint: disable=too-many-nested-blocks
1958 # Acquire row lock and eager-load relationships while locked so
1959 # concurrent updates are serialized on Postgres.
1960 gateway = get_for_update(
1961 db,
1962 DbGateway,
1963 gateway_id,
1964 options=[
1965 selectinload(DbGateway.tools),
1966 selectinload(DbGateway.resources),
1967 selectinload(DbGateway.prompts),
1968 selectinload(DbGateway.email_team), # Use selectinload to avoid locking email_teams
1969 ],
1970 )
1971 if not gateway:
1972 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
1974 # Check ownership if user_email provided
1975 if user_email:
1976 # First-Party
1977 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel
1979 permission_service = PermissionService(db)
1980 if not await permission_service.check_resource_ownership(user_email, gateway):
1981 raise PermissionError("Only the owner can update this gateway")
1983 if gateway.enabled or include_inactive:
1984 # Check for name conflicts if name is being changed
1985 if gateway_update.name is not None and gateway_update.name != gateway.name:
1986 # existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway_update.name).where(DbGateway.id != gateway_id)).scalar_one_or_none()
1988 # if existing_gateway:
1989 # raise GatewayNameConflictError(
1990 # gateway_update.name,
1991 # enabled=existing_gateway.enabled,
1992 # gateway_id=existing_gateway.id,
1993 # )
1994 # Check for existing gateway with the same slug and visibility
1995 new_slug = slugify(gateway_update.name)
1996 if gateway_update.visibility is not None:
1997 vis = gateway_update.visibility
1998 else:
1999 vis = gateway.visibility
2000 if vis == "public":
2001 # Check for existing public gateway with the same slug (row-locked)
2002 existing_gateway = get_for_update(
2003 db,
2004 DbGateway,
2005 where=and_(DbGateway.slug == new_slug, DbGateway.visibility == "public", DbGateway.id != gateway_id),
2006 )
2007 if existing_gateway:
2008 raise GatewayNameConflictError(
2009 new_slug,
2010 enabled=existing_gateway.enabled,
2011 gateway_id=existing_gateway.id,
2012 visibility=existing_gateway.visibility,
2013 )
2014 elif vis == "team" and gateway.team_id:
2015 # Check for existing team gateway with the same slug (row-locked)
2016 existing_gateway = get_for_update(
2017 db,
2018 DbGateway,
2019 where=and_(DbGateway.slug == new_slug, DbGateway.visibility == "team", DbGateway.team_id == gateway.team_id, DbGateway.id != gateway_id),
2020 )
2021 if existing_gateway:
2022 raise GatewayNameConflictError(
2023 new_slug,
2024 enabled=existing_gateway.enabled,
2025 gateway_id=existing_gateway.id,
2026 visibility=existing_gateway.visibility,
2027 )
2028 # Check for existing gateway with the same URL and visibility
2029 normalized_url = ""
2030 if gateway_update.url is not None:
2031 normalized_url = self.normalize_url(str(gateway_update.url))
2032 else:
2033 normalized_url = None
2035 # Prepare decoded auth_value for uniqueness check
2036 decoded_auth_value = None
2037 if gateway_update.auth_value:
2038 if isinstance(gateway_update.auth_value, str):
2039 try:
2040 decoded_auth_value = decode_auth(gateway_update.auth_value)
2041 except Exception as e:
2042 logger.warning(f"Failed to decode provided auth_value: {e}")
2043 elif isinstance(gateway_update.auth_value, dict):
2044 decoded_auth_value = gateway_update.auth_value
2046 # Determine final values for uniqueness check
2047 final_auth_value = decoded_auth_value if gateway_update.auth_value is not None else (decode_auth(gateway.auth_value) if isinstance(gateway.auth_value, str) else gateway.auth_value)
2048 final_oauth_config = gateway_update.oauth_config if gateway_update.oauth_config is not None else gateway.oauth_config
2049 final_visibility = gateway_update.visibility if gateway_update.visibility is not None else gateway.visibility
2051 # Check for duplicates with updated credentials
2052 if not gateway_update.one_time_auth:
2053 duplicate_gateway = self._check_gateway_uniqueness(
2054 db=db,
2055 url=normalized_url,
2056 auth_value=final_auth_value,
2057 oauth_config=final_oauth_config,
2058 team_id=gateway.team_id,
2059 visibility=final_visibility,
2060 gateway_id=gateway_id, # Exclude current gateway from check
2061 owner_email=user_email,
2062 )
2064 if duplicate_gateway:
2065 raise GatewayDuplicateConflictError(duplicate_gateway=duplicate_gateway)
2067 # FIX for Issue #1025: Determine if URL actually changed before we update it
2068 # We need this early because we update gateway.url below, and need to know
2069 # if it actually changed to decide whether to re-fetch tools
2070 # tools/resoures/prompts are need to be re-fetched not only if URL changed , in case any update like authentication and visibility changed
2071 # url_changed = gateway_update.url is not None and self.normalize_url(str(gateway_update.url)) != gateway.url
2073 # Save original values BEFORE updating for change detection checks later
2074 original_url = gateway.url
2075 original_auth_type = gateway.auth_type
2077 # Update fields if provided
2078 if gateway_update.name is not None:
2079 gateway.name = gateway_update.name
2080 gateway.slug = slugify(gateway_update.name)
2081 if gateway_update.url is not None:
2082 # Normalize the updated URL
2083 gateway.url = self.normalize_url(str(gateway_update.url))
2084 if gateway_update.description is not None:
2085 gateway.description = gateway_update.description
2086 if gateway_update.transport is not None:
2087 gateway.transport = gateway_update.transport
2088 if gateway_update.tags is not None:
2089 gateway.tags = gateway_update.tags
2090 if gateway_update.visibility is not None:
2091 old_visibility = gateway.visibility
2092 # Validate visibility transitions
2093 if gateway_update.visibility == "team":
2094 target_team_id = gateway_update.team_id if gateway_update.team_id is not None else gateway.team_id
2095 _validate_gateway_team_assignment(db, user_email, target_team_id)
2096 gateway.visibility = gateway_update.visibility
2097 # Propagate visibility to all linked items immediately so it
2098 # takes effect even when the upstream server is unreachable
2099 # and _initialize_gateway fails.
2100 # Only update items that inherited the old gateway visibility;
2101 # preserve per-item overrides (e.g. a resource set to "team"
2102 # while the gateway was "public").
2103 for tool in gateway.tools:
2104 if tool.visibility == old_visibility:
2105 tool.visibility = gateway.visibility
2106 for resource in gateway.resources:
2107 if resource.visibility == old_visibility:
2108 resource.visibility = gateway.visibility
2109 for prompt in gateway.prompts:
2110 if prompt.visibility == old_visibility:
2111 prompt.visibility = gateway.visibility
2112 if gateway_update.passthrough_headers is not None:
2113 if isinstance(gateway_update.passthrough_headers, list):
2114 gateway.passthrough_headers = gateway_update.passthrough_headers
2115 else:
2116 if isinstance(gateway_update.passthrough_headers, str):
2117 parsed: List[str] = [h.strip() for h in gateway_update.passthrough_headers.split(",") if h.strip()]
2118 gateway.passthrough_headers = parsed
2119 else:
2120 raise GatewayError("Invalid passthrough_headers format: must be list[str] or comma-separated string")
2122 logger.info("Updated passthrough_headers for gateway {gateway.id}: {gateway.passthrough_headers}")
2124 # Update team assignment if provided, validating ownership
2125 if gateway_update.team_id is not None:
2126 if gateway_update.team_id != gateway.team_id:
2127 _validate_gateway_team_assignment(db, user_email, gateway_update.team_id)
2128 gateway.team_id = gateway_update.team_id
2130 # Update CA certificate fields if provided
2131 if getattr(gateway_update, "ca_certificate", None) is not None:
2132 gateway.ca_certificate = gateway_update.ca_certificate
2133 if getattr(gateway_update, "ca_certificate_sig", None) is not None:
2134 gateway.ca_certificate_sig = gateway_update.ca_certificate_sig
2135 if getattr(gateway_update, "signing_algorithm", None) is not None:
2136 gateway.signing_algorithm = gateway_update.signing_algorithm
2138 # Update mTLS client certificate/key if provided
2139 if getattr(gateway_update, "client_cert", None) is not None:
2140 gateway.client_cert = gateway_update.client_cert
2141 if getattr(gateway_update, "client_key", None) is not None:
2142 if gateway_update.client_key == settings.masked_auth_value:
2143 pass # Preserve existing encrypted value
2144 else:
2145 gateway.client_key = await self._encrypt_client_key(gateway_update.client_key)
2147 # Only update auth_type if explicitly provided in the update
2148 if gateway_update.auth_type is not None:
2149 gateway.auth_type = gateway_update.auth_type
2151 # If auth_type is empty, update the auth_value too
2152 if gateway_update.auth_type == "":
2153 gateway.auth_value = cast(Any, "")
2155 # Clear auth_query_params when switching away from query_param auth
2156 if original_auth_type == "query_param" and gateway_update.auth_type != "query_param":
2157 gateway.auth_query_params = None
2158 logger.debug(f"Cleared auth_query_params for gateway {SecurityValidator.sanitize_log_message(gateway.id)} (switched from query_param to {gateway_update.auth_type})")
2160 # if auth_type is not None and only then check auth_value
2161 # Handle OAuth configuration updates
2162 if gateway_update.oauth_config is not None:
2163 gateway.oauth_config = await protect_oauth_config_for_storage(gateway_update.oauth_config, existing_oauth_config=gateway.oauth_config)
2165 # Handle auth_value updates (both existing and new auth values)
2166 token = gateway_update.auth_token
2167 password = gateway_update.auth_password
2168 header_value = gateway_update.auth_header_value
2170 # Support multiple custom headers on update
2171 if hasattr(gateway_update, "auth_headers") and gateway_update.auth_headers:
2172 existing_auth_raw = getattr(gateway, "auth_value", {}) or {}
2173 if isinstance(existing_auth_raw, str):
2174 try:
2175 existing_auth = decode_auth(existing_auth_raw)
2176 except Exception:
2177 existing_auth = {}
2178 elif isinstance(existing_auth_raw, dict):
2179 existing_auth = existing_auth_raw
2180 else:
2181 existing_auth = {}
2183 header_dict: Dict[str, str] = {}
2184 for header in gateway_update.auth_headers:
2185 key = header.get("key")
2186 if not key:
2187 continue
2188 value = header.get("value", "")
2189 if value == settings.masked_auth_value and key in existing_auth:
2190 header_dict[key] = existing_auth[key]
2191 else:
2192 header_dict[key] = value
2193 gateway.auth_value = header_dict # Store as dict for DB JSON field
2194 elif settings.masked_auth_value not in (token, password, header_value):
2195 # Check if values differ from existing ones or if setting for first time
2196 decoded_auth = decode_auth(gateway_update.auth_value) if gateway_update.auth_value else {}
2197 current_auth = getattr(gateway, "auth_value", {}) or {}
2198 if current_auth != decoded_auth:
2199 gateway.auth_value = decoded_auth
2201 # Handle query_param auth updates with service-layer enforcement
2202 auth_query_params_decrypted: Optional[Dict[str, str]] = None
2203 init_url = gateway.url
2205 # Check if updating to query_param auth or updating existing query_param credentials
2206 # Use original_auth_type since gateway.auth_type may have been updated already
2207 is_switching_to_queryparam = gateway_update.auth_type == "query_param" and original_auth_type != "query_param"
2208 is_updating_queryparam_creds = original_auth_type == "query_param" and (gateway_update.auth_query_param_key is not None or gateway_update.auth_query_param_value is not None)
2209 is_url_changing = gateway_update.url is not None and self.normalize_url(str(gateway_update.url)) != original_url
2211 if is_switching_to_queryparam or is_updating_queryparam_creds or (is_url_changing and original_auth_type == "query_param"):
2212 # Service-layer enforcement: Check feature flag
2213 if not settings.insecure_allow_queryparam_auth:
2214 # Grandfather clause: Allow updates to existing query_param gateways
2215 # unless they're trying to change credentials
2216 if is_switching_to_queryparam or is_updating_queryparam_creds:
2217 raise ValueError("Query parameter authentication is disabled. " + "Set INSECURE_ALLOW_QUERYPARAM_AUTH=true to enable.")
2219 # Service-layer enforcement: Check host allowlist
2220 if settings.insecure_queryparam_auth_allowed_hosts:
2221 check_url = str(gateway_update.url) if gateway_update.url else gateway.url
2222 parsed = urlparse(check_url)
2223 hostname = (parsed.hostname or "").lower()
2224 if hostname not in settings.insecure_queryparam_auth_allowed_hosts:
2225 allowed = ", ".join(settings.insecure_queryparam_auth_allowed_hosts)
2226 raise ValueError(f"Host '{hostname}' is not in the allowed hosts for query param auth. Allowed: {allowed}")
2228 # Process query_param auth credentials
2229 param_key = getattr(gateway_update, "auth_query_param_key", None) or (next(iter(gateway.auth_query_params.keys()), None) if gateway.auth_query_params else None)
2230 param_value = getattr(gateway_update, "auth_query_param_value", None)
2232 # Get raw value from SecretStr if applicable
2233 raw_value: Optional[str] = None
2234 if param_value:
2235 if hasattr(param_value, "get_secret_value"):
2236 raw_value = param_value.get_secret_value()
2237 else:
2238 raw_value = str(param_value)
2240 # Check if the value is the masked placeholder - if so, keep existing value
2241 is_masked_placeholder = raw_value == settings.masked_auth_value
2243 if param_key:
2244 if raw_value and not is_masked_placeholder:
2245 # New value provided - encrypt for storage
2246 encrypted_value = encode_auth({param_key: raw_value})
2247 gateway.auth_query_params = {param_key: encrypted_value}
2248 auth_query_params_decrypted = {param_key: raw_value}
2249 elif gateway.auth_query_params:
2250 # Use existing encrypted value
2251 existing_encrypted = gateway.auth_query_params.get(param_key, "")
2252 if existing_encrypted:
2253 decrypted = decode_auth(existing_encrypted)
2254 auth_query_params_decrypted = {param_key: decrypted.get(param_key, "")}
2256 # Append query params to URL for initialization
2257 if auth_query_params_decrypted:
2258 init_url = apply_query_param_auth(gateway.url, auth_query_params_decrypted)
2260 # Update auth_type if switching
2261 if is_switching_to_queryparam:
2262 gateway.auth_type = "query_param"
2263 gateway.auth_value = None # Query param auth doesn't use auth_value
2265 elif gateway.auth_type == "query_param" and gateway.auth_query_params:
2266 # Existing query_param gateway without credential changes - decrypt for init
2267 first_key = next(iter(gateway.auth_query_params.keys()), None)
2268 if first_key:
2269 encrypted_value = gateway.auth_query_params.get(first_key, "")
2270 if encrypted_value:
2271 decrypted = decode_auth(encrypted_value)
2272 auth_query_params_decrypted = {first_key: decrypted.get(first_key, "")}
2273 init_url = apply_query_param_auth(gateway.url, auth_query_params_decrypted)
2275 # Try to reinitialize connection if URL actually changed
2276 # if url_changed:
2277 # Initialize empty lists in case initialization fails
2278 tools_to_add = []
2279 resources_to_add = []
2280 prompts_to_add = []
2281 reinit_succeeded = False
2283 try:
2284 ca_certificate = getattr(gateway, "ca_certificate", None)
2285 update_client_cert = getattr(gateway, "client_cert", None)
2286 update_client_key = getattr(gateway, "client_key", None)
2287 # Decrypt client_key for initialization (stored encrypted)
2288 if update_client_key:
2289 try:
2290 _enc = get_encryption_service(settings.auth_encryption_secret)
2291 update_client_key = _enc.decrypt_secret_or_plaintext(update_client_key)
2292 except Exception:
2293 logger.debug("client_key decryption skipped during gateway re-init")
2294 capabilities, tools, resources, prompts = await self._initialize_gateway(
2295 init_url,
2296 gateway.auth_value,
2297 gateway.transport,
2298 gateway.auth_type,
2299 gateway.oauth_config,
2300 ca_certificate,
2301 auth_query_params=auth_query_params_decrypted,
2302 client_cert=update_client_cert,
2303 client_key=update_client_key,
2304 )
2305 new_tool_names = [tool.name for tool in tools]
2306 new_resource_uris = [resource.uri for resource in resources]
2307 new_prompt_names = [prompt.name for prompt in prompts]
2309 if gateway_update.one_time_auth:
2310 # For one-time auth, clear auth_type and auth_value after initialization
2311 gateway.auth_type = "one_time_auth"
2312 gateway.auth_value = None
2313 gateway.oauth_config = None
2315 # Update tools using helper method — only propagate visibility
2316 # when the user explicitly changed it in this request
2317 _vis_changed = gateway_update.visibility is not None
2318 tools_to_add = self._update_or_create_tools(db, tools, gateway, "update", update_visibility=_vis_changed)
2320 # Update resources using helper method
2321 resources_to_add = self._update_or_create_resources(db, resources, gateway, "update", update_visibility=_vis_changed)
2323 # Update prompts using helper method
2324 prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, "update", update_visibility=_vis_changed)
2326 # Log newly added items
2327 items_added = len(tools_to_add) + len(resources_to_add) + len(prompts_to_add)
2328 if items_added > 0:
2329 if tools_to_add:
2330 logger.info(f"Added {len(tools_to_add)} new tools during gateway update")
2331 if resources_to_add:
2332 logger.info(f"Added {len(resources_to_add)} new resources during gateway update")
2333 if prompts_to_add:
2334 logger.info(f"Added {len(prompts_to_add)} new prompts during gateway update")
2335 logger.info(f"Total {items_added} new items added during gateway update")
2337 # Count items before cleanup for logging
2339 # Bulk delete tools that are no longer available from the gateway
2340 # Use chunking to avoid SQLite's 999 parameter limit for IN clauses
2341 stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names]
2342 if stale_tool_ids:
2343 # Delete child records first to avoid FK constraint violations
2344 for i in range(0, len(stale_tool_ids), 500):
2345 chunk = stale_tool_ids[i : i + 500]
2346 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
2347 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
2348 db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))
2350 # Bulk delete resources that are no longer available from the gateway
2351 stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris]
2352 if stale_resource_ids:
2353 # Delete child records first to avoid FK constraint violations
2354 for i in range(0, len(stale_resource_ids), 500):
2355 chunk = stale_resource_ids[i : i + 500]
2356 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
2357 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
2358 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
2359 db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))
2361 # Bulk delete prompts that are no longer available from the gateway
2362 stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names]
2363 if stale_prompt_ids:
2364 # Delete child records first to avoid FK constraint violations
2365 for i in range(0, len(stale_prompt_ids), 500):
2366 chunk = stale_prompt_ids[i : i + 500]
2367 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
2368 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
2369 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))
2371 # Expire gateway to clear cached relationships after bulk deletes
2372 # This prevents SQLAlchemy from trying to re-delete already-deleted items
2373 if stale_tool_ids or stale_resource_ids or stale_prompt_ids:
2374 db.expire(gateway)
2376 gateway.capabilities = capabilities
2378 # Register capabilities for notification-driven actions
2379 register_gateway_capabilities_for_notifications(gateway.id, capabilities)
2381 gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names] # keep only still-valid rows
2382 gateway.resources = [resource for resource in gateway.resources if resource.uri in new_resource_uris] # keep only still-valid rows
2383 gateway.prompts = [prompt for prompt in gateway.prompts if prompt.original_name in new_prompt_names] # keep only still-valid rows
2385 # Log cleanup results
2386 tools_removed = len(stale_tool_ids)
2387 resources_removed = len(stale_resource_ids)
2388 prompts_removed = len(stale_prompt_ids)
2390 if tools_removed > 0:
2391 logger.info(f"Removed {tools_removed} tools no longer available during gateway update")
2392 if resources_removed > 0:
2393 logger.info(f"Removed {resources_removed} resources no longer available during gateway update")
2394 if prompts_removed > 0:
2395 logger.info(f"Removed {prompts_removed} prompts no longer available during gateway update")
2397 gateway.last_seen = datetime.now(timezone.utc)
2399 # Add new items to database session in chunks to prevent lock escalation
2400 chunk_size = 50
2402 if tools_to_add:
2403 for i in range(0, len(tools_to_add), chunk_size):
2404 chunk = tools_to_add[i : i + chunk_size]
2405 db.add_all(chunk)
2406 db.flush()
2407 if resources_to_add:
2408 for i in range(0, len(resources_to_add), chunk_size):
2409 chunk = resources_to_add[i : i + chunk_size]
2410 db.add_all(chunk)
2411 db.flush()
2412 if prompts_to_add:
2413 for i in range(0, len(prompts_to_add), chunk_size):
2414 chunk = prompts_to_add[i : i + chunk_size]
2415 db.add_all(chunk)
2416 db.flush()
2418 # Update tracking with new URL
2419 self._active_gateways.discard(gateway.url)
2420 self._active_gateways.add(gateway.url)
2421 reinit_succeeded = True
2422 except Exception as e:
2423 logger.warning(f"Failed to initialize updated gateway: {e}")
2424 reinit_succeeded = False
2426 # Update tags if provided
2427 if gateway_update.tags is not None:
2428 gateway.tags = gateway_update.tags
2430 # Update gateway_mode if provided
2431 if hasattr(gateway_update, "gateway_mode") and gateway_update.gateway_mode is not None:
2432 if gateway_update.gateway_mode == "direct_proxy" and not settings.mcpgateway_direct_proxy_enabled:
2433 raise GatewayError("direct_proxy gateway mode is disabled. Set MCPGATEWAY_DIRECT_PROXY_ENABLED=true to enable.")
2434 gateway.gateway_mode = gateway_update.gateway_mode
2436 # Update metadata fields
2437 gateway.updated_at = datetime.now(timezone.utc)
2438 if modified_by:
2439 gateway.modified_by = modified_by
2440 if modified_from_ip:
2441 gateway.modified_from_ip = modified_from_ip
2442 if modified_via:
2443 gateway.modified_via = modified_via
2444 if modified_user_agent:
2445 gateway.modified_user_agent = modified_user_agent
2446 if hasattr(gateway, "version") and gateway.version is not None:
2447 gateway.version = gateway.version + 1
2448 else:
2449 gateway.version = 1
2451 db.commit()
2452 db.refresh(gateway)
2454 # Invalidate cache after successful update
2455 cache = _get_registry_cache()
2456 await cache.invalidate_gateways()
2457 tool_lookup_cache = _get_tool_lookup_cache()
2458 await tool_lookup_cache.invalidate_gateway(str(gateway.id))
2459 # Also invalidate tags cache since gateway tags may have changed
2460 # First-Party
2461 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel
2463 await admin_stats_cache.invalidate_tags()
2465 # Advance hot/cold poll schedule only after successful tool re-init
2466 if reinit_succeeded and self._classification_service and gateway.url:
2467 try:
2468 await self._classification_service.mark_poll_completed(gateway.url, "tool_discovery", gateway_id=str(gateway.id))
2469 except Exception as poll_ts_err:
2470 logger.debug(f"Best-effort tool_discovery poll timestamp update failed: {poll_ts_err}")
2472 # Invalidate loopback passthrough cache when gateway headers change (#3640)
2473 if gateway_update.passthrough_headers is not None:
2474 # First-Party
2475 from mcpgateway.utils.passthrough_headers import invalidate_passthrough_header_caches # pylint: disable=import-outside-toplevel
2477 invalidate_passthrough_header_caches()
2479 # Notify subscribers
2480 await self._notify_gateway_updated(gateway)
2482 logger.info(f"Updated gateway: {SecurityValidator.sanitize_log_message(gateway.name)}")
2484 # Structured logging: Audit trail for gateway update
2485 audit_trail.log_action(
2486 user_id=user_email or modified_by or "system",
2487 action="update_gateway",
2488 resource_type="gateway",
2489 resource_id=str(gateway.id),
2490 resource_name=gateway.name,
2491 user_email=user_email,
2492 team_id=gateway.team_id,
2493 client_ip=modified_from_ip,
2494 user_agent=modified_user_agent,
2495 new_values={
2496 "name": gateway.name,
2497 "url": gateway.url,
2498 "version": gateway.version,
2499 },
2500 context={
2501 "modified_via": modified_via,
2502 },
2503 db=db,
2504 )
2506 # Structured logging: Log successful gateway update
2507 structured_logger.log(
2508 level="INFO",
2509 message="Gateway updated successfully",
2510 event_type="gateway_updated",
2511 component="gateway_service",
2512 user_id=modified_by,
2513 user_email=user_email,
2514 team_id=gateway.team_id,
2515 resource_type="gateway",
2516 resource_id=str(gateway.id),
2517 custom_fields={
2518 "gateway_name": gateway.name,
2519 "version": gateway.version,
2520 },
2521 )
2523 return self.convert_gateway_to_read(gateway)
2524 # Gateway is inactive and include_inactive is False → skip update, return None
2525 return None
2526 except GatewayNameConflictError as ge:
2527 logger.error(f"GatewayNameConflictError in group: {ge}")
2528 db.rollback()
2530 structured_logger.log(
2531 level="WARNING",
2532 message="Gateway update failed due to name conflict",
2533 event_type="gateway_name_conflict",
2534 component="gateway_service",
2535 user_email=user_email,
2536 resource_type="gateway",
2537 resource_id=gateway_id,
2538 error=ge,
2539 )
2540 raise ge
2541 except GatewayNotFoundError as gnfe:
2542 logger.error(f"GatewayNotFoundError: {gnfe}")
2543 db.rollback()
2545 structured_logger.log(
2546 level="ERROR",
2547 message="Gateway update failed - gateway not found",
2548 event_type="gateway_not_found",
2549 component="gateway_service",
2550 user_email=user_email,
2551 resource_type="gateway",
2552 resource_id=gateway_id,
2553 error=gnfe,
2554 )
2555 raise gnfe
2556 except IntegrityError as ie:
2557 logger.error(f"IntegrityErrors in group: {ie}")
2558 db.rollback()
2560 structured_logger.log(
2561 level="ERROR",
2562 message="Gateway update failed due to database integrity error",
2563 event_type="gateway_update_failed",
2564 component="gateway_service",
2565 user_email=user_email,
2566 resource_type="gateway",
2567 resource_id=gateway_id,
2568 error=ie,
2569 )
2570 raise ie
2571 except PermissionError as pe:
2572 db.rollback()
2574 structured_logger.log(
2575 level="WARNING",
2576 message="Gateway update failed due to permission error",
2577 event_type="gateway_update_permission_denied",
2578 component="gateway_service",
2579 user_email=user_email,
2580 resource_type="gateway",
2581 resource_id=gateway_id,
2582 error=pe,
2583 )
2584 raise
2585 except Exception as e:
2586 db.rollback()
2588 structured_logger.log(
2589 level="ERROR",
2590 message="Gateway update failed",
2591 event_type="gateway_update_failed",
2592 component="gateway_service",
2593 user_email=user_email,
2594 resource_type="gateway",
2595 resource_id=gateway_id,
2596 error=e,
2597 )
2598 raise GatewayError(f"Failed to update gateway: {str(e)}")
2600 async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool = True) -> GatewayRead:
2601 """Get a gateway by its ID.
2603 Args:
2604 db: Database session
2605 gateway_id: Gateway ID
2606 include_inactive: Whether to include inactive gateways
2608 Returns:
2609 GatewayRead object
2611 Raises:
2612 GatewayNotFoundError: If the gateway is not found
2614 Examples:
2615 >>> from unittest.mock import MagicMock
2616 >>> from mcpgateway.schemas import GatewayRead
2617 >>> service = GatewayService()
2618 >>> db = MagicMock()
2619 >>> gateway_mock = MagicMock()
2620 >>> gateway_mock.enabled = True
2621 >>> db.execute.return_value.scalar_one_or_none.return_value = gateway_mock
2622 >>> mocked_gateway_read = MagicMock()
2623 >>> mocked_gateway_read.masked.return_value = 'gateway_read'
2624 >>> GatewayRead.model_validate = MagicMock(return_value=mocked_gateway_read)
2625 >>> import asyncio
2626 >>> result = asyncio.run(service.get_gateway(db, 'gateway_id'))
2627 >>> result == 'gateway_read'
2628 True
2630 >>> # Test with inactive gateway but include_inactive=True
2631 >>> gateway_mock.enabled = False
2632 >>> result_inactive = asyncio.run(service.get_gateway(db, 'gateway_id', include_inactive=True))
2633 >>> result_inactive == 'gateway_read'
2634 True
2636 >>> # Test gateway not found
2637 >>> db.execute.return_value.scalar_one_or_none.return_value = None
2638 >>> try:
2639 ... asyncio.run(service.get_gateway(db, 'missing_id'))
2640 ... except GatewayNotFoundError as e:
2641 ... 'Gateway not found: missing_id' in str(e)
2642 True
2644 >>> # Test inactive gateway with include_inactive=False
2645 >>> gateway_mock.enabled = False
2646 >>> db.execute.return_value.scalar_one_or_none.return_value = gateway_mock
2647 >>> try:
2648 ... asyncio.run(service.get_gateway(db, 'gateway_id', include_inactive=False))
2649 ... except GatewayNotFoundError as e:
2650 ... 'Gateway not found: gateway_id' in str(e)
2651 True
2652 >>>
2653 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
2654 >>> asyncio.run(service._http_client.aclose())
2655 """
2656 # Use eager loading to avoid N+1 queries for relationships and team name
2657 gateway = db.execute(
2658 select(DbGateway)
2659 .options(
2660 selectinload(DbGateway.tools),
2661 selectinload(DbGateway.resources),
2662 selectinload(DbGateway.prompts),
2663 joinedload(DbGateway.email_team),
2664 )
2665 .where(DbGateway.id == gateway_id)
2666 ).scalar_one_or_none()
2668 if not gateway:
2669 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
2671 if gateway.enabled or include_inactive:
2672 # Structured logging: Log gateway view
2673 structured_logger.log(
2674 level="INFO",
2675 message="Gateway retrieved successfully",
2676 event_type="gateway_viewed",
2677 component="gateway_service",
2678 team_id=getattr(gateway, "team_id", None),
2679 resource_type="gateway",
2680 resource_id=str(gateway.id),
2681 custom_fields={
2682 "gateway_name": gateway.name,
2683 "gateway_url": gateway.url,
2684 "include_inactive": include_inactive,
2685 },
2686 )
2688 return self.convert_gateway_to_read(gateway)
2690 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
2692 async def set_gateway_state(self, db: Session, gateway_id: str, activate: bool, reachable: bool = True, only_update_reachable: bool = False, user_email: Optional[str] = None) -> GatewayRead:
2693 """
2694 Set the activation status of a gateway.
2696 Args:
2697 db: Database session
2698 gateway_id: Gateway ID
2699 activate: True to activate, False to deactivate
2700 reachable: Whether the gateway is reachable
2701 only_update_reachable: Only update reachable status
2702 user_email: Optional[str] The email of the user to check if the user has permission to modify.
2704 Returns:
2705 The updated GatewayRead object
2707 Raises:
2708 GatewayNotFoundError: If the gateway is not found
2709 GatewayError: For other errors
2710 PermissionError: If user doesn't own the agent.
2711 """
2712 try:
2713 # Eager-load collections for the gateway. Note: we don't use FOR UPDATE
2714 # here because _initialize_gateway does network I/O, and holding a row
2715 # lock during network calls would block other operations and risk timeouts.
2716 gateway = db.execute(
2717 select(DbGateway)
2718 .options(
2719 selectinload(DbGateway.tools),
2720 selectinload(DbGateway.resources),
2721 selectinload(DbGateway.prompts),
2722 joinedload(DbGateway.email_team),
2723 )
2724 .where(DbGateway.id == gateway_id)
2725 ).scalar_one_or_none()
2726 if not gateway:
2727 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
2729 if user_email:
2730 # First-Party
2731 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel
2733 permission_service = PermissionService(db)
2734 if not await permission_service.check_resource_ownership(user_email, gateway):
2735 raise PermissionError("Only the owner can activate the gateway" if activate else "Only the owner can deactivate the gateway")
2737 # Update status if it's different
2738 if (gateway.enabled != activate) or (gateway.reachable != reachable):
2739 gateway.enabled = activate
2740 gateway.reachable = reachable
2741 gateway.updated_at = datetime.now(timezone.utc)
2742 # Update tracking
2743 if activate and reachable:
2744 self._active_gateways.add(gateway.url)
2746 # Initialize empty lists in case initialization fails
2747 tools_to_add = []
2748 resources_to_add = []
2749 prompts_to_add = []
2751 # Try to initialize if activating
2752 try:
2753 # Handle query_param auth - decrypt and apply to URL
2754 init_url = gateway.url
2755 auth_query_params_decrypted: Optional[Dict[str, str]] = None
2756 if gateway.auth_type == "query_param" and gateway.auth_query_params:
2757 auth_query_params_decrypted = {}
2758 for param_key, encrypted_value in gateway.auth_query_params.items():
2759 if encrypted_value:
2760 try:
2761 decrypted = decode_auth(encrypted_value)
2762 auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
2763 except Exception:
2764 logger.debug(f"Failed to decrypt query param '{param_key}' for gateway activation")
2765 if auth_query_params_decrypted:
2766 init_url = apply_query_param_auth(gateway.url, auth_query_params_decrypted)
2768 act_client_cert = getattr(gateway, "client_cert", None)
2769 act_client_key = getattr(gateway, "client_key", None)
2770 if act_client_key:
2771 try:
2772 _enc = get_encryption_service(settings.auth_encryption_secret)
2773 act_client_key = _enc.decrypt_secret_or_plaintext(act_client_key)
2774 except Exception:
2775 logger.debug("client_key decryption skipped during gateway activation")
2776 capabilities, tools, resources, prompts = await self._initialize_gateway(
2777 init_url,
2778 gateway.auth_value,
2779 gateway.transport,
2780 gateway.auth_type,
2781 gateway.oauth_config,
2782 auth_query_params=auth_query_params_decrypted,
2783 oauth_auto_fetch_tool_flag=True,
2784 client_cert=act_client_cert,
2785 client_key=act_client_key,
2786 )
2787 new_tool_names = [tool.name for tool in tools]
2788 new_resource_uris = [resource.uri for resource in resources]
2789 new_prompt_names = [prompt.name for prompt in prompts]
2791 # Update tools, resources, and prompts using helper methods
2792 tools_to_add = self._update_or_create_tools(db, tools, gateway, "rediscovery")
2793 resources_to_add = self._update_or_create_resources(db, resources, gateway, "rediscovery")
2794 prompts_to_add = self._update_or_create_prompts(db, prompts, gateway, "rediscovery")
2796 # Log newly added items
2797 items_added = len(tools_to_add) + len(resources_to_add) + len(prompts_to_add)
2798 if items_added > 0:
2799 if tools_to_add:
2800 logger.info(f"Added {len(tools_to_add)} new tools during gateway reactivation")
2801 if resources_to_add:
2802 logger.info(f"Added {len(resources_to_add)} new resources during gateway reactivation")
2803 if prompts_to_add:
2804 logger.info(f"Added {len(prompts_to_add)} new prompts during gateway reactivation")
2805 logger.info(f"Total {items_added} new items added during gateway reactivation")
2807 # Count items before cleanup for logging
2809 # For authorization_code OAuth gateways, empty responses may indicate
2810 # a missing auth token rather than genuine removal of all items.
2811 # Skip stale cleanup to prevent destructive deletion of tools,
2812 # resources, prompts, and their virtual server associations.
2813 # Mirrors the guard in _refresh_gateway_tools_resources_prompts.
2814 is_auth_code_gateway = gateway.oauth_config and isinstance(gateway.oauth_config, dict) and gateway.oauth_config.get("grant_type") == "authorization_code"
2815 skip_stale_cleanup = not tools and not resources and not prompts and is_auth_code_gateway
2816 if skip_stale_cleanup:
2817 logger.debug(f"Empty response from auth_code gateway {gateway.name} during reactivation, preserving existing items")
2819 # Bulk delete tools that are no longer available from the gateway
2820 # Use chunking to avoid SQLite's 999 parameter limit for IN clauses
2821 stale_tool_ids = [tool.id for tool in gateway.tools if tool.original_name not in new_tool_names] if not skip_stale_cleanup else []
2822 if stale_tool_ids:
2823 # Delete child records first to avoid FK constraint violations
2824 for i in range(0, len(stale_tool_ids), 500):
2825 chunk = stale_tool_ids[i : i + 500]
2826 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
2827 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
2828 db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))
2830 # Bulk delete resources that are no longer available from the gateway
2831 stale_resource_ids = [resource.id for resource in gateway.resources if resource.uri not in new_resource_uris] if not skip_stale_cleanup else []
2832 if stale_resource_ids:
2833 # Delete child records first to avoid FK constraint violations
2834 for i in range(0, len(stale_resource_ids), 500):
2835 chunk = stale_resource_ids[i : i + 500]
2836 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
2837 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
2838 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
2839 db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))
2841 # Bulk delete prompts that are no longer available from the gateway
2842 stale_prompt_ids = [prompt.id for prompt in gateway.prompts if prompt.original_name not in new_prompt_names] if not skip_stale_cleanup else []
2843 if stale_prompt_ids:
2844 # Delete child records first to avoid FK constraint violations
2845 for i in range(0, len(stale_prompt_ids), 500):
2846 chunk = stale_prompt_ids[i : i + 500]
2847 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
2848 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
2849 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))
2851 # Expire gateway to clear cached relationships after bulk deletes
2852 # This prevents SQLAlchemy from trying to re-delete already-deleted items
2853 if stale_tool_ids or stale_resource_ids or stale_prompt_ids:
2854 db.expire(gateway)
2856 gateway.capabilities = capabilities
2858 # Register capabilities for notification-driven actions
2859 register_gateway_capabilities_for_notifications(gateway.id, capabilities)
2861 if not skip_stale_cleanup:
2862 gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names] # keep only still-valid rows
2863 gateway.resources = [resource for resource in gateway.resources if resource.uri in new_resource_uris] # keep only still-valid rows
2864 gateway.prompts = [prompt for prompt in gateway.prompts if prompt.original_name in new_prompt_names] # keep only still-valid rows
2866 # Log cleanup results
2867 tools_removed = len(stale_tool_ids)
2868 resources_removed = len(stale_resource_ids)
2869 prompts_removed = len(stale_prompt_ids)
2871 if tools_removed > 0:
2872 logger.info(f"Removed {tools_removed} tools no longer available during gateway reactivation")
2873 if resources_removed > 0:
2874 logger.info(f"Removed {resources_removed} resources no longer available during gateway reactivation")
2875 if prompts_removed > 0:
2876 logger.info(f"Removed {prompts_removed} prompts no longer available during gateway reactivation")
2878 gateway.last_seen = datetime.now(timezone.utc)
2880 # Add new items to database session in chunks to prevent lock escalation
2881 chunk_size = 50
2883 if tools_to_add:
2884 for i in range(0, len(tools_to_add), chunk_size):
2885 chunk = tools_to_add[i : i + chunk_size]
2886 db.add_all(chunk)
2887 db.flush()
2888 if resources_to_add:
2889 for i in range(0, len(resources_to_add), chunk_size):
2890 chunk = resources_to_add[i : i + chunk_size]
2891 db.add_all(chunk)
2892 db.flush()
2893 if prompts_to_add:
2894 for i in range(0, len(prompts_to_add), chunk_size):
2895 chunk = prompts_to_add[i : i + chunk_size]
2896 db.add_all(chunk)
2897 db.flush()
2898 except Exception as e:
2899 logger.warning(f"Failed to initialize reactivated gateway: {e}")
2900 else:
2901 self._active_gateways.discard(gateway.url)
2903 db.commit()
2904 db.refresh(gateway)
2906 # Invalidate cache after status change
2907 cache = _get_registry_cache()
2908 await cache.invalidate_gateways()
2910 # Notify Subscribers
2911 if not gateway.enabled:
2912 # Inactive
2913 await self._notify_gateway_deactivated(gateway)
2914 elif gateway.enabled and not gateway.reachable:
2915 # Offline (Enabled but Unreachable)
2916 await self._notify_gateway_offline(gateway)
2917 else:
2918 # Active (Enabled and Reachable)
2919 await self._notify_gateway_activated(gateway)
2921 # Bulk update tools - single UPDATE statement instead of N FOR UPDATE locks
2922 # This prevents lock contention under high concurrent load
2923 now = datetime.now(timezone.utc)
2924 if only_update_reachable:
2925 # Only update reachable status, keep enabled as-is
2926 tools_result = db.execute(update(DbTool).where(DbTool.gateway_id == gateway_id).where(DbTool.reachable != reachable).values(reachable=reachable, updated_at=now))
2927 else:
2928 # Update both enabled and reachable
2929 tools_result = db.execute(
2930 update(DbTool)
2931 .where(DbTool.gateway_id == gateway_id)
2932 .where(or_(DbTool.enabled != activate, DbTool.reachable != reachable))
2933 .values(enabled=activate, reachable=reachable, updated_at=now)
2934 )
2935 tools_updated = tools_result.rowcount
2937 # Commit tool updates
2938 if tools_updated > 0:
2939 db.commit()
2941 # Invalidate tools cache once after bulk update
2942 if tools_updated > 0:
2943 await cache.invalidate_tools()
2944 tool_lookup_cache = _get_tool_lookup_cache()
2945 await tool_lookup_cache.invalidate_gateway(str(gateway.id))
2947 # Bulk update prompts when gateway is deactivated/activated (skip for reachability-only updates)
2948 prompts_updated = 0
2949 if not only_update_reachable:
2950 prompts_result = db.execute(update(DbPrompt).where(DbPrompt.gateway_id == gateway_id).where(DbPrompt.enabled != activate).values(enabled=activate, updated_at=now))
2951 prompts_updated = prompts_result.rowcount
2952 if prompts_updated > 0:
2953 db.commit()
2954 await cache.invalidate_prompts()
2956 # Bulk update resources when gateway is deactivated/activated (skip for reachability-only updates)
2957 resources_updated = 0
2958 if not only_update_reachable:
2959 resources_result = db.execute(update(DbResource).where(DbResource.gateway_id == gateway_id).where(DbResource.enabled != activate).values(enabled=activate, updated_at=now))
2960 resources_updated = resources_result.rowcount
2961 if resources_updated > 0:
2962 db.commit()
2963 await cache.invalidate_resources()
2965 logger.debug(f"Gateway {SecurityValidator.sanitize_log_message(gateway.name)} bulk state update: {tools_updated} tools, {prompts_updated} prompts, {resources_updated} resources")
2967 logger.info(f"Gateway status: {SecurityValidator.sanitize_log_message(gateway.name)} - {'enabled' if activate else 'disabled'} and {'accessible' if reachable else 'inaccessible'}")
2969 # Structured logging: Audit trail for gateway state change
2970 audit_trail.log_action(
2971 user_id=user_email or "system",
2972 action="set_gateway_state",
2973 resource_type="gateway",
2974 resource_id=str(gateway.id),
2975 resource_name=gateway.name,
2976 user_email=user_email,
2977 team_id=gateway.team_id,
2978 new_values={
2979 "enabled": gateway.enabled,
2980 "reachable": gateway.reachable,
2981 },
2982 context={
2983 "action": "activate" if activate else "deactivate",
2984 "only_update_reachable": only_update_reachable,
2985 },
2986 db=db,
2987 )
2989 # Structured logging: Log successful gateway state change
2990 structured_logger.log(
2991 level="INFO",
2992 message=f"Gateway {'activated' if activate else 'deactivated'} successfully",
2993 event_type="gateway_state_changed",
2994 component="gateway_service",
2995 user_email=user_email,
2996 team_id=gateway.team_id,
2997 resource_type="gateway",
2998 resource_id=str(gateway.id),
2999 custom_fields={
3000 "gateway_name": gateway.name,
3001 "enabled": gateway.enabled,
3002 "reachable": gateway.reachable,
3003 },
3004 )
3006 return self.convert_gateway_to_read(gateway)
3008 except PermissionError as e:
3009 db.rollback()
3011 # Structured logging: Log permission error
3012 structured_logger.log(
3013 level="WARNING",
3014 message="Gateway state change failed due to permission error",
3015 event_type="gateway_state_change_permission_denied",
3016 component="gateway_service",
3017 user_email=user_email,
3018 resource_type="gateway",
3019 resource_id=gateway_id,
3020 error=e,
3021 )
3022 raise e
3023 except Exception as e:
3024 db.rollback()
3026 # Structured logging: Log generic gateway state change failure
3027 structured_logger.log(
3028 level="ERROR",
3029 message="Gateway state change failed",
3030 event_type="gateway_state_change_failed",
3031 component="gateway_service",
3032 user_email=user_email,
3033 resource_type="gateway",
3034 resource_id=gateway_id,
3035 error=e,
3036 )
3037 raise GatewayError(f"Failed to set gateway state: {str(e)}")
3039 async def _notify_gateway_updated(self, gateway: DbGateway) -> None:
3040 """
3041 Notify subscribers of gateway update.
3043 Args:
3044 gateway: Gateway to update
3045 """
3046 event = {
3047 "type": "gateway_updated",
3048 "data": {
3049 "id": gateway.id,
3050 "name": gateway.name,
3051 "url": gateway.url,
3052 "description": gateway.description,
3053 "enabled": gateway.enabled,
3054 },
3055 "timestamp": datetime.now(timezone.utc).isoformat(),
3056 }
3057 await self._publish_event(event)
3059 async def delete_gateway(self, db: Session, gateway_id: str, user_email: Optional[str] = None) -> None:
3060 """
3061 Delete a gateway by its ID.
3063 Args:
3064 db: Database session
3065 gateway_id: Gateway ID
3066 user_email: Email of user performing deletion (for ownership check)
3068 Raises:
3069 GatewayNotFoundError: If the gateway is not found
3070 PermissionError: If user doesn't own the gateway
3071 GatewayError: For other deletion errors
3073 Examples:
3074 >>> from mcpgateway.services.gateway_service import GatewayService
3075 >>> from unittest.mock import MagicMock
3076 >>> service = GatewayService()
3077 >>> db = MagicMock()
3078 >>> gateway = MagicMock()
3079 >>> db.execute.return_value.scalar_one_or_none.return_value = gateway
3080 >>> db.delete = MagicMock()
3081 >>> db.commit = MagicMock()
3082 >>> service._notify_gateway_deleted = MagicMock()
3083 >>> import asyncio
3084 >>> try:
3085 ... asyncio.run(service.delete_gateway(db, 'gateway_id', 'user@example.com'))
3086 ... except Exception:
3087 ... pass
3088 >>>
3089 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
3090 >>> asyncio.run(service._http_client.aclose())
3091 """
3092 try:
3093 # Find gateway with eager loading for deletion to avoid N+1 queries
3094 gateway = db.execute(
3095 select(DbGateway)
3096 .options(
3097 selectinload(DbGateway.tools),
3098 selectinload(DbGateway.resources),
3099 selectinload(DbGateway.prompts),
3100 joinedload(DbGateway.email_team),
3101 )
3102 .where(DbGateway.id == gateway_id)
3103 ).scalar_one_or_none()
3105 if not gateway:
3106 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
3108 # Check ownership if user_email provided
3109 if user_email:
3110 # First-Party
3111 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel
3113 permission_service = PermissionService(db)
3114 if not await permission_service.check_resource_ownership(user_email, gateway):
3115 raise PermissionError("Only the owner can delete this gateway")
3117 # Store gateway info for notification before deletion
3118 gateway_info = {"id": gateway.id, "name": gateway.name, "url": gateway.url}
3119 gateway_name = gateway.name
3120 gateway_team_id = gateway.team_id
3121 gateway_url = gateway.url # Store URL before expiring the object
3123 # Manually delete children first to avoid FK constraint violations
3124 # (passive_deletes=True means ORM won't auto-cascade, we must do it explicitly)
3125 # Use chunking to avoid SQLite's 999 parameter limit for IN clauses
3126 tool_ids = [t.id for t in gateway.tools]
3127 resource_ids = [r.id for r in gateway.resources]
3128 prompt_ids = [p.id for p in gateway.prompts]
3130 # Delete tool children and tools
3131 if tool_ids:
3132 for i in range(0, len(tool_ids), 500):
3133 chunk = tool_ids[i : i + 500]
3134 db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
3135 db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
3136 db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))
3138 # Delete resource children and resources
3139 if resource_ids:
3140 for i in range(0, len(resource_ids), 500):
3141 chunk = resource_ids[i : i + 500]
3142 db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
3143 db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
3144 db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
3145 db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))
3147 # Delete prompt children and prompts
3148 if prompt_ids:
3149 for i in range(0, len(prompt_ids), 500):
3150 chunk = prompt_ids[i : i + 500]
3151 db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
3152 db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
3153 db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))
3155 # Expire gateway to clear cached relationships after bulk deletes
3156 db.expire(gateway)
3158 # Use DELETE with rowcount check for database-agnostic atomic delete
3159 stmt = delete(DbGateway).where(DbGateway.id == gateway_id)
3160 result = db.execute(stmt)
3161 if result.rowcount == 0:
3162 # Gateway was already deleted by another concurrent request
3163 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
3165 db.commit()
3167 # Invalidate cache after successful deletion
3168 cache = _get_registry_cache()
3169 await cache.invalidate_gateways()
3170 tool_lookup_cache = _get_tool_lookup_cache()
3171 await tool_lookup_cache.invalidate_gateway(str(gateway_id))
3172 # Also invalidate tags cache since gateway tags may have changed
3173 # First-Party
3174 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel
3176 await admin_stats_cache.invalidate_tags()
3178 # Invalidate loopback passthrough cache when a gateway is deleted (#3640)
3179 # First-Party
3180 from mcpgateway.utils.passthrough_headers import invalidate_passthrough_header_caches # pylint: disable=import-outside-toplevel
3182 invalidate_passthrough_header_caches()
3184 # Update tracking
3185 self._active_gateways.discard(gateway_url)
3187 # Notify subscribers
3188 await self._notify_gateway_deleted(gateway_info)
3190 logger.info(f"Permanently deleted gateway: {gateway_name}")
3192 # Structured logging: Audit trail for gateway deletion
3193 audit_trail.log_action(
3194 user_id=user_email or "system",
3195 action="delete_gateway",
3196 resource_type="gateway",
3197 resource_id=str(gateway_info["id"]),
3198 resource_name=gateway_name,
3199 user_email=user_email,
3200 team_id=gateway_team_id,
3201 old_values={
3202 "name": gateway_name,
3203 "url": gateway_info["url"],
3204 },
3205 db=db,
3206 )
3208 # Structured logging: Log successful gateway deletion
3209 structured_logger.log(
3210 level="INFO",
3211 message="Gateway deleted successfully",
3212 event_type="gateway_deleted",
3213 component="gateway_service",
3214 user_email=user_email,
3215 team_id=gateway_team_id,
3216 resource_type="gateway",
3217 resource_id=str(gateway_info["id"]),
3218 custom_fields={
3219 "gateway_name": gateway_name,
3220 "gateway_url": gateway_info["url"],
3221 },
3222 )
3224 except PermissionError as pe:
3225 db.rollback()
3227 # Structured logging: Log permission error
3228 structured_logger.log(
3229 level="WARNING",
3230 message="Gateway deletion failed due to permission error",
3231 event_type="gateway_delete_permission_denied",
3232 component="gateway_service",
3233 user_email=user_email,
3234 resource_type="gateway",
3235 resource_id=gateway_id,
3236 error=pe,
3237 )
3238 raise
3239 except Exception as e:
3240 db.rollback()
3242 # Structured logging: Log generic gateway deletion failure
3243 structured_logger.log(
3244 level="ERROR",
3245 message="Gateway deletion failed",
3246 event_type="gateway_deletion_failed",
3247 component="gateway_service",
3248 user_email=user_email,
3249 resource_type="gateway",
3250 resource_id=gateway_id,
3251 error=e,
3252 )
3253 raise GatewayError(f"Failed to delete gateway: {str(e)}")
async def _handle_gateway_failure(self, gateway: DbGateway) -> None:
    """Record one failed health check for ``gateway`` and act once the threshold is reached.

    Every call increments ``self._gateway_failure_counts[gateway.id]``. When that
    counter reaches ``GW_FAILURE_THRESHOLD``, the gateway is flagged unreachable via
    ``set_gateway_state(..., activate=True, reachable=False, only_update_reachable=True)``.
    The gateway itself stays enabled, and the counter is set back to 0.

    Nothing is recorded when ``GW_FAILURE_THRESHOLD`` is -1 (failure handling
    disabled), when the gateway is disabled, or when it is already unreachable.

    Args:
        gateway: The gateway object that failed its health check.

    Returns:
        None

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> service = GatewayService()
        >>> gateway = type('Gateway', (), {
        ...     'id': 'gw1', 'name': 'test_gw', 'enabled': True, 'reachable': True
        ... })()
        >>> service._gateway_failure_counts = {}
        >>> import asyncio
        >>> # Test failure counting
        >>> asyncio.run(service._handle_gateway_failure(gateway))  # doctest: +ELLIPSIS
        >>> service._gateway_failure_counts['gw1'] >= 1
        True

        >>> # Test disabled gateway (no action)
        >>> gateway.enabled = False
        >>> old_count = service._gateway_failure_counts.get('gw1', 0)
        >>> asyncio.run(service._handle_gateway_failure(gateway))  # doctest: +ELLIPSIS
        >>> service._gateway_failure_counts.get('gw1', 0) == old_count
        True
    """
    # Feature switched off, gateway disabled, or already flagged unreachable: nothing to track.
    if GW_FAILURE_THRESHOLD == -1 or not gateway.enabled or not gateway.reachable:
        return

    gateway_id = gateway.id
    failures = self._gateway_failure_counts.get(gateway_id, 0) + 1
    self._gateway_failure_counts[gateway_id] = failures

    safe_name = SecurityValidator.sanitize_log_message(gateway.name)
    logger.warning(f"Gateway {safe_name} failed health check {failures} time(s).")

    if failures < GW_FAILURE_THRESHOLD:
        return

    logger.error(f"Gateway {safe_name} failed {GW_FAILURE_THRESHOLD} times. Deactivating...")
    # Only the reachable flag is changed; activate=True keeps the gateway enabled.
    with cast(Any, SessionLocal)() as db:
        await self.set_gateway_state(db, gateway_id, activate=True, reachable=False, only_update_reachable=True)
    # Start counting afresh once the gateway has been flagged unreachable.
    self._gateway_failure_counts[gateway_id] = 0
async def check_health_of_gateways(self, gateways: List[DbGateway], user_email: Optional[str] = None) -> bool:
    """Check health of a batch of gateways.

    Performs an asynchronous health check for each gateway in `gateways` using
    ``_check_single_gateway_health``. Checks are bounded by a semaphore, and each
    one is limited to ``settings.gateway_health_check_timeout`` seconds; a timeout
    counts as a failure. Gateways with ``auth_type == "one_time_auth"`` are skipped.
    When a gateway uses the authorization_code flow, the optional `user_email` is
    used to look up stored user tokens.

    Concurrency is ``min(settings.max_concurrent_health_checks, max(10, cpus * 5))``,
    clamped to at least 1. ``cpus`` falls back to 1 when ``os.cpu_count()`` cannot
    determine the CPU count. Any exception that escapes an individual check is
    logged rather than silently discarded.

    NOTE: This method intentionally does NOT take a db parameter.
    DB access uses fresh_db_session() only when needed, avoiding holding
    connections during HTTP calls to MCP servers.

    Args:
        gateways: List of DbGateway objects to check.
        user_email: Optional MCP gateway user email used to retrieve
            stored OAuth tokens for gateways using the
            "authorization_code" grant type. If not provided, authorization
            code flows that require a user token will be treated as failed.

    Returns:
        bool: True when the health-check batch completes. This return
        value indicates completion of the checks, not that every gateway
        was healthy. Individual gateway failures are handled internally
        (via _handle_gateway_failure and status updates).

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> from unittest.mock import MagicMock
        >>> service = GatewayService()
        >>> gateways = [MagicMock()]
        >>> gateways[0].ca_certificate = None
        >>> import asyncio
        >>> result = asyncio.run(service.check_health_of_gateways(gateways))
        >>> isinstance(result, bool)
        True

        >>> # Test empty gateway list
        >>> empty_result = asyncio.run(service.check_health_of_gateways([]))
        >>> empty_result
        True

        >>> # Test multiple gateways (basic smoke)
        >>> multiple_gateways = [MagicMock(), MagicMock(), MagicMock()]
        >>> for i, gw in enumerate(multiple_gateways):
        ...     gw.name = f"gateway_{i}"
        ...     gw.url = f"http://gateway{i}.example.com"
        ...     gw.transport = "SSE"
        ...     gw.enabled = True
        ...     gw.reachable = True
        ...     gw.auth_value = {}
        ...     gw.ca_certificate = None
        >>> multi_result = asyncio.run(service.check_health_of_gateways(multiple_gateways))
        >>> isinstance(multi_result, bool)
        True
    """
    start_time = time.monotonic()
    # os.cpu_count() returns None when the CPU count cannot be determined.
    cpu_count = os.cpu_count() or 1
    # Adaptive concurrency, clamped to >= 1 so Semaphore() and the chunking range() step stay valid.
    concurrency_limit = max(1, min(settings.max_concurrent_health_checks, max(10, cpu_count * 5)))
    semaphore = asyncio.Semaphore(concurrency_limit)

    async def limited_check(gateway: DbGateway):
        """
        Checks the health of a single gateway while respecting a concurrency limit.

        This function checks the health of the given database gateway, ensuring that
        the number of concurrent checks does not exceed a predefined limit. The check
        is performed asynchronously and uses a semaphore to manage concurrency.

        Args:
            gateway (DbGateway): The database gateway whose health is to be checked.

        Raises:
            Any exceptions raised during the health check will be propagated to the caller.
        """
        async with semaphore:
            try:
                await asyncio.wait_for(
                    self._check_single_gateway_health(gateway, user_email),
                    timeout=settings.gateway_health_check_timeout,
                )
            except asyncio.TimeoutError:
                logger.warning(f"Gateway {getattr(gateway, 'name', 'unknown')} health check timed out after {settings.gateway_health_check_timeout}s")
                # Treat timeout as a failed health check
                await self._handle_gateway_failure(gateway)

    # Create trace span for health check batch
    with create_span("gateway.health_check_batch", {"gateway.count": len(gateways), "check.type": "health"}) as batch_span:
        if not gateways:
            return True

        # Chunk processing to avoid overload
        chunk_size = concurrency_limit
        for i in range(0, len(gateways), chunk_size):
            batch = gateways[i : i + chunk_size]

            # One-time-auth gateways have no reusable credentials to probe with.
            checked = [gw for gw in batch if gw.auth_type != "one_time_auth"]

            # Execute all health checks concurrently; one failing check must not abort the batch.
            results = await asyncio.gather(*(limited_check(gw) for gw in checked), return_exceptions=True)
            for gw, outcome in zip(checked, results):
                if isinstance(outcome, BaseException):
                    # The exception text may embed URLs or credentials, so only its type is logged at warning level.
                    gw_label = SecurityValidator.sanitize_log_message(str(getattr(gw, "name", "unknown")))
                    logger.warning(f"Health check for gateway {gw_label} raised an unexpected {type(outcome).__name__}")
                    logger.debug(f"Unexpected health check error for gateway {gw_label}: {outcome}")

            await asyncio.sleep(0.05)  # small pause prevents network saturation

        elapsed = time.monotonic() - start_time

        if batch_span:
            set_span_attribute(batch_span, "check.duration_ms", int(elapsed * 1000))
            set_span_attribute(batch_span, "check.completed", True)

        logger.debug(f"Health check batch completed for {len(gateways)} gateways in {elapsed:.2f}s")

        return True
async def _check_single_gateway_health(self, gateway: DbGateway, user_email: Optional[str] = None) -> None:
    """Check health of a single gateway.

    Builds request headers from the gateway's auth configuration: an OAuth
    client_credentials token, a stored authorization_code token for
    ``user_email``, or static ``auth_value`` headers. It then probes the
    gateway according to its transport:

    - ``SSE``: a streamed GET that raises on 4xx/5xx.
    - ``StreamableHTTP``: a pooled MCP session when the session pool is enabled
      and initialised (optionally followed by an explicit ``list_tools`` RPC);
      otherwise a fresh MCP ``initialize`` handshake.

    On success:

    - the gateway's accumulated failure count is cleared, so
      ``_handle_gateway_failure`` counts consecutive failures only;
    - a previously unreachable gateway is marked reachable again;
    - ``last_seen`` is updated;
    - if ``settings.auto_refresh_servers`` is on and a refresh is due, the
      gateway's tools/resources/prompts are refreshed.

    Any failure is recorded via ``_handle_gateway_failure``.

    NOTE(review): transports other than SSE/StreamableHTTP are not probed and
    fall through to the success path; confirm this is intended.

    NOTE: This method intentionally does NOT take a db parameter.
    DB access uses fresh_db_session() only when needed, avoiding holding
    connections during HTTP calls to MCP servers.

    Args:
        gateway: Gateway to check (may be detached from session)
        user_email: Optional user email for OAuth token lookup
    """
    # Extract gateway data upfront (gateway may be detached from session)
    gateway_id = gateway.id
    gateway_name = gateway.name
    gateway_url = gateway.url
    gateway_transport = gateway.transport
    gateway_enabled = gateway.enabled
    gateway_reachable = gateway.reachable
    gateway_ca_certificate = gateway.ca_certificate
    gateway_ca_certificate_sig = gateway.ca_certificate_sig
    gateway_auth_type = gateway.auth_type
    gateway_oauth_config = gateway.oauth_config
    gateway_auth_value = gateway.auth_value
    gateway_auth_query_params = gateway.auth_query_params
    health_client_cert = getattr(gateway, "client_cert", None)
    health_client_key = getattr(gateway, "client_key", None)

    # Handle query_param auth - decrypt and apply to URL for health check
    auth_query_params_decrypted: Optional[Dict[str, str]] = None
    # Preserve the base URL (without auth query params) for classification lookups.
    # Classification uses Gateway.url from the DB, so poll-state keys must match.
    gateway_base_url = gateway_url
    if gateway_auth_type == "query_param" and gateway_auth_query_params:
        auth_query_params_decrypted = {}
        for param_key, encrypted_value in gateway_auth_query_params.items():
            if encrypted_value:
                try:
                    decrypted = decode_auth(encrypted_value)
                    auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
                except Exception:
                    logger.debug(f"Failed to decrypt query param '{param_key}' for health check")
        if auth_query_params_decrypted:
            gateway_url = apply_query_param_auth(gateway_url, auth_query_params_decrypted)

    # Sanitize URL for logging/telemetry (redacts sensitive query params)
    gateway_url_sanitized = sanitize_url_for_logging(gateway_url, auth_query_params_decrypted)

    # NOTE: Health checks always run regardless of hot/cold classification.
    # Classification only gates auto-refresh (tool discovery), not health monitoring.
    # Skipping health checks would blind the gateway to outages on cold servers.

    # Create span for individual gateway health check
    with create_span(
        "gateway.health_check",
        {
            "gateway.name": gateway_name,
            "gateway.id": str(gateway_id),
            "gateway.url": gateway_url_sanitized,
            "gateway.transport": gateway_transport,
            "gateway.enabled": gateway_enabled,
            "http.method": "GET",
            "http.url": gateway_url_sanitized,
        },
    ) as span:
        # The CA certificate is only trusted when its signature verifies (or signing is disabled).
        valid = False
        if gateway_ca_certificate:
            if settings.enable_ed25519_signing:
                public_key_pem = settings.ed25519_public_key
                valid = validate_signature(gateway_ca_certificate.encode(), gateway_ca_certificate_sig, public_key_pem)
            else:
                valid = True

        # Decrypt client_key for health check mTLS
        _hc_client_key = health_client_key
        if _hc_client_key:
            try:
                _enc = get_encryption_service(settings.auth_encryption_secret)
                _hc_client_key = _enc.decrypt_secret_or_plaintext(_hc_client_key)
            except Exception:
                logger.debug("client_key decryption skipped during health check")

        if gateway_url and gateway_url.lower().startswith("http://"):
            ssl_context = None
        elif valid and gateway_ca_certificate:
            ssl_context = get_cached_ssl_context(gateway_ca_certificate, client_cert=health_client_cert, client_key=_hc_client_key)
        else:
            ssl_context = None

        def get_httpx_client_factory(
            headers: dict[str, str] | None = None,
            timeout: httpx.Timeout | None = None,
            auth: httpx.Auth | None = None,
        ) -> httpx.AsyncClient:
            """Factory function to create httpx.AsyncClient with optional CA certificate.

            Args:
                headers: Optional headers for the client
                timeout: Optional timeout for the client
                auth: Optional auth for the client

            Returns:
                httpx.AsyncClient: Configured HTTPX async client
            """
            return httpx.AsyncClient(
                verify=ssl_context if ssl_context else get_default_verify(),
                follow_redirects=True,
                headers=headers,
                timeout=timeout if timeout else get_http_timeout(),
                auth=auth,
                limits=httpx.Limits(
                    max_connections=settings.httpx_max_connections,
                    max_keepalive_connections=settings.httpx_max_keepalive_connections,
                    keepalive_expiry=settings.httpx_keepalive_expiry,
                ),
            )

        # Use isolated client for gateway health checks (each gateway may have custom CA cert)
        # Use admin timeout for health checks (fail fast, don't wait 120s for slow upstreams)
        # Pass ssl_context if present, otherwise let get_isolated_http_client use skip_ssl_verify setting
        async with get_isolated_http_client(timeout=settings.httpx_admin_read_timeout, verify=ssl_context) as client:
            logger.debug(f"Checking health of gateway: {gateway_name} ({gateway_url_sanitized})")
            try:
                # Handle different authentication types
                headers: Dict[str, str] = {}

                if gateway_auth_type == "oauth" and gateway_oauth_config:
                    grant_type = gateway_oauth_config.get("grant_type", "client_credentials")

                    if grant_type == "authorization_code":
                        # A stored per-user token is required. Without a user there is nothing to
                        # look up, so fail fast instead of opening a DB session.
                        if not user_email:
                            if span:
                                set_span_attribute(span, "health.status", "unhealthy")
                                set_span_error(span, "User email required for OAuth token")
                            await self._handle_gateway_failure(gateway)
                            return

                        try:
                            # First-Party
                            from mcpgateway.services.token_storage_service import TokenStorageService  # pylint: disable=import-outside-toplevel

                            # Keep the session scoped to the token lookup so it is released
                            # before failure handling or any further network I/O.
                            with fresh_db_session() as token_db:
                                token_storage = TokenStorageService(token_db)
                                access_token = await token_storage.get_user_token(gateway_id, user_email)
                        except Exception as e:
                            logger.error(f"Failed to obtain stored OAuth token for gateway {gateway_name}: {e}")
                            if span:
                                set_span_attribute(span, "health.status", "unhealthy")
                                set_span_error(span, "Failed to obtain stored OAuth token")
                            await self._handle_gateway_failure(gateway)
                            return

                        if not access_token:
                            if span:
                                set_span_attribute(span, "health.status", "unhealthy")
                                set_span_error(span, "No valid OAuth token for user")
                            await self._handle_gateway_failure(gateway)
                            return

                        headers["Authorization"] = f"Bearer {access_token}"
                    else:
                        # For Client Credentials flow, get token directly
                        try:
                            access_token = await self.oauth_manager.get_access_token(gateway_oauth_config)
                            headers["Authorization"] = f"Bearer {access_token}"
                        except Exception as e:
                            if span:
                                set_span_attribute(span, "health.status", "unhealthy")
                                set_span_error(span, e)
                            await self._handle_gateway_failure(gateway)
                            return
                else:
                    # Handle non-OAuth authentication (existing logic)
                    auth_data = gateway_auth_value or {}
                    if isinstance(auth_data, str):
                        headers = decode_auth(auth_data)
                    elif isinstance(auth_data, dict):
                        headers = {str(k): str(v) for k, v in auth_data.items()}
                    else:
                        headers = {}

                # Perform the probe; any exception below is treated as a failed health check.
                if gateway_transport.lower() == "sse":
                    timeout = httpx.Timeout(settings.health_check_timeout)
                    async with client.stream("GET", gateway_url, headers=headers, timeout=timeout) as response:
                        # This will raise immediately if status is 4xx/5xx
                        response.raise_for_status()
                        if span:
                            set_span_attribute(span, "http.status_code", response.status_code)
                elif gateway_transport.lower() == "streamablehttp":
                    # Use session pool if enabled for faster health checks
                    use_pool = False
                    pool = None
                    if settings.mcp_session_pool_enabled:
                        try:
                            pool = get_mcp_session_pool()
                            use_pool = True
                        except RuntimeError:
                            # Pool not initialized (e.g., in tests), fall back to per-call sessions
                            pass

                    if use_pool and pool is not None:
                        # Health checks are system operations, not user-driven.
                        # Use system identity to isolate from user sessions.
                        async with pool.session(
                            url=gateway_url,
                            headers=headers,
                            transport_type=TransportType.STREAMABLE_HTTP,
                            httpx_client_factory=get_httpx_client_factory,
                            user_identity="_system_health_check",
                            gateway_id=gateway_id,
                        ) as pooled:
                            # Optional explicit RPC verification (off by default for performance).
                            # Pool's internal staleness check handles health via _validate_session.
                            if settings.mcp_session_pool_explicit_health_rpc:
                                with anyio.fail_after(settings.health_check_timeout):
                                    await pooled.session.list_tools()
                    else:
                        async with streamablehttp_client(url=gateway_url, headers=headers, timeout=settings.health_check_timeout, httpx_client_factory=get_httpx_client_factory) as (
                            read_stream,
                            write_stream,
                            _get_session_id,
                        ):
                            async with ClientSession(read_stream, write_stream) as session:
                                # A successful MCP handshake is the health signal; the result is not needed.
                                await session.initialize()

                # Health check passed: clear accumulated failures so that only consecutive
                # failures count towards GW_FAILURE_THRESHOLD in _handle_gateway_failure.
                self._gateway_failure_counts.pop(gateway_id, None)

                # Reactivate gateway if it was previously inactive and health check passed now
                if gateway_enabled and not gateway_reachable:
                    logger.info(f"Reactivating gateway: {gateway_name}, as it is healthy now")
                    with cast(Any, SessionLocal)() as status_db:
                        await self.set_gateway_state(status_db, gateway_id, activate=True, reachable=True, only_update_reachable=True)

                # Update last_seen with fresh session (gateway object is detached)
                try:
                    with fresh_db_session() as update_db:
                        db_gateway = update_db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
                        if db_gateway:
                            db_gateway.last_seen = datetime.now(timezone.utc)
                            update_db.commit()
                except Exception as update_error:
                    logger.warning(f"Failed to update last_seen for gateway {gateway_name}: {update_error}")

                # Auto-refresh tools/resources/prompts if enabled
                should_auto_refresh = False
                if settings.auto_refresh_servers:
                    # Hot/cold classification: Check if this server should have tools refreshed now
                    if self._classification_service:
                        try:
                            should_auto_refresh = await self._classification_service.should_poll_server(gateway_base_url, "tool_discovery", gateway_id=str(gateway_id))
                            if not should_auto_refresh:
                                logger.debug(f"Skipping auto-refresh for {SecurityValidator.sanitize_log_message(gateway_name)}: " f"not yet due based on hot/cold classification")
                        except Exception as e:
                            # Fail open: proceed with auto-refresh if classification check fails
                            logger.warning(f"Classification check failed for {SecurityValidator.sanitize_log_message(gateway_name)}, proceeding with auto-refresh (fail-open): {e}")
                            should_auto_refresh = True
                    else:
                        should_auto_refresh = True

                if should_auto_refresh:
                    try:
                        # Throttling: Check if refresh is needed based on last_refresh_at
                        refresh_needed = True
                        if gateway.last_refresh_at:
                            last_refresh = gateway.last_refresh_at
                            # Treat naive timestamps as UTC so they compare with datetime.now(timezone.utc)
                            if last_refresh.tzinfo is None:
                                last_refresh = last_refresh.replace(tzinfo=timezone.utc)

                            # Use per-gateway interval if set, otherwise fall back to global default
                            refresh_interval = getattr(settings, "gateway_auto_refresh_interval", 300)
                            if gateway.refresh_interval_seconds is not None:
                                refresh_interval = gateway.refresh_interval_seconds

                            time_since_refresh = (datetime.now(timezone.utc) - last_refresh).total_seconds()

                            if time_since_refresh < refresh_interval:
                                refresh_needed = False
                                logger.debug(f"Skipping auto-refresh for {gateway_name}: last refreshed {int(time_since_refresh)}s ago")

                        if refresh_needed:
                            # Locking: Try to acquire lock to avoid conflict with manual refresh
                            lock = self._get_refresh_lock(gateway_id)
                            if not lock.locked():
                                # Acquire lock to prevent concurrent manual refresh
                                async with lock:
                                    await self._refresh_gateway_tools_resources_prompts(
                                        gateway_id=gateway_id,
                                        _user_email=user_email,
                                        created_via="health_check",
                                        pre_auth_headers=headers if headers else None,
                                        gateway=gateway,
                                    )
                                    # mark_poll_completed is called inside _refresh_gateway_tools_resources_prompts
                            else:
                                logger.debug(f"Skipping auto-refresh for {gateway_name}: lock held (likely manual refresh in progress)")
                    except Exception as refresh_error:
                        logger.warning(f"Failed to refresh tools for gateway {gateway_name}: {refresh_error}")

                if span:
                    set_span_attribute(span, "health.status", "healthy")
                    set_span_attribute(span, "success", True)

            except Exception as e:
                if span:
                    set_span_attribute(span, "health.status", "unhealthy")
                    set_span_error(span, e)

                # Set the logger as debug as this check happens for each interval
                logger.debug(f"Health check failed for gateway {gateway_name}: {e}")
                await self._handle_gateway_failure(gateway)
async def aggregate_capabilities(self, db: Session) -> Dict[str, Any]:
    """
    Aggregate capabilities across all enabled gateways.

    Starts from a baseline (prompts/resources/tools ``listChanged``, resource
    ``subscribe``, empty ``logging``) and merges the ``capabilities`` reported by
    every enabled gateway:

    - a key not yet present is added (dict values are shallow-copied);
    - a dict value for a key that already holds a dict is merged with
      ``dict.update``, so later gateways win on conflicting sub-keys;
    - a dict value for a key that holds a non-dict value replaces it with a copy;
    - any other value for an existing key is ignored.

    Dicts are copied before anything is merged into them, so the gateways' own
    (ORM-loaded) ``capabilities`` mappings are never mutated. Nested values
    below the first level are shared, not copied.

    Args:
        db: Database session

    Returns:
        Dictionary of aggregated capabilities

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> from unittest.mock import MagicMock
        >>> service = GatewayService()
        >>> db = MagicMock()
        >>> gateway_mock = MagicMock()
        >>> gateway_mock.capabilities = {"tools": {"listChanged": True}, "custom": {"feature": True}}
        >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway_mock]
        >>> import asyncio
        >>> result = asyncio.run(service.aggregate_capabilities(db))
        >>> isinstance(result, dict)
        True
        >>> 'prompts' in result
        True
        >>> 'resources' in result
        True
        >>> 'tools' in result
        True
        >>> 'logging' in result
        True
        >>> result['prompts']['listChanged']
        True
        >>> result['resources']['subscribe']
        True
        >>> result['resources']['listChanged']
        True
        >>> result['tools']['listChanged']
        True
        >>> isinstance(result['logging'], dict)
        True

        >>> # Test with no gateways
        >>> db.execute.return_value.scalars.return_value.all.return_value = []
        >>> empty_result = asyncio.run(service.aggregate_capabilities(db))
        >>> isinstance(empty_result, dict)
        True
        >>> 'tools' in empty_result
        True

        >>> # Test capability merging
        >>> gateway1 = MagicMock()
        >>> gateway1.capabilities = {"tools": {"feature1": True}}
        >>> gateway2 = MagicMock()
        >>> gateway2.capabilities = {"tools": {"feature2": True}}
        >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway1, gateway2]
        >>> merged_result = asyncio.run(service.aggregate_capabilities(db))
        >>> merged_result['tools']['listChanged']  # Default capability
        True

        >>> # Merging never mutates a gateway's own capabilities
        >>> g1 = MagicMock()
        >>> g1.capabilities = {"custom": {"a": 1}}
        >>> g2 = MagicMock()
        >>> g2.capabilities = {"custom": {"b": 2}}
        >>> db.execute.return_value.scalars.return_value.all.return_value = [g1, g2]
        >>> asyncio.run(service.aggregate_capabilities(db))['custom'] == {'a': 1, 'b': 2}
        True
        >>> g1.capabilities
        {'custom': {'a': 1}}

        >>> # A dict value after a non-dict value for the same key replaces it
        >>> g3 = MagicMock()
        >>> g3.capabilities = {"experimental": True}
        >>> g4 = MagicMock()
        >>> g4.capabilities = {"experimental": {"x": 1}}
        >>> db.execute.return_value.scalars.return_value.all.return_value = [g3, g4]
        >>> asyncio.run(service.aggregate_capabilities(db))['experimental']
        {'x': 1}
    """
    capabilities: Dict[str, Any] = {
        "prompts": {"listChanged": True},
        "resources": {"subscribe": True, "listChanged": True},
        "tools": {"listChanged": True},
        "logging": {},
    }

    # Get all active gateways
    gateways = db.execute(select(DbGateway).where(DbGateway.enabled)).scalars().all()

    # Combine capabilities
    for gateway in gateways:
        if not gateway.capabilities:
            continue
        for key, value in gateway.capabilities.items():
            if isinstance(value, dict):
                existing = capabilities.get(key)
                if isinstance(existing, dict):
                    # `existing` is always a dict owned by this function (baseline or a copy).
                    existing.update(value)
                else:
                    # Store a copy so later merges can never mutate this gateway's mapping.
                    capabilities[key] = dict(value)
            elif key not in capabilities:
                capabilities[key] = value

    return capabilities
async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
    """Stream gateway events to the caller as they are published.

    This is a thin pass-through. It iterates
    ``self._event_service.subscribe_events()`` and re-yields every event
    unchanged. Because this is an async generator, the underlying subscription
    is only requested once the returned generator is first iterated. The event
    payload shape is defined by the event service (presumably ``type``/``data``
    plus a timestamp; confirm there).

    Yields:
        Dict[str, Any]: Each event produced by the event service, in order.

    Examples:
        >>> service = GatewayService()
        >>> import asyncio
        >>> from unittest.mock import MagicMock
        >>> # Create a mock async generator for the event service
        >>> async def mock_event_gen():
        ...     yield {"type": "test_event", "data": "payload"}
        >>>
        >>> # Mock the event service to return our generator
        >>> service._event_service = MagicMock()
        >>> service._event_service.subscribe_events.return_value = mock_event_gen()
        >>>
        >>> # Test the subscription
        >>> async def test_sub():
        ...     async for event in service.subscribe_events():
        ...         return event
        >>>
        >>> result = asyncio.run(test_sub())
        >>> result
        {'type': 'test_event', 'data': 'payload'}
    """
    event_stream = self._event_service.subscribe_events()
    async for published in event_stream:
        yield published
3852 async def _initialize_gateway(
3853 self,
3854 url: str,
3855 authentication: Optional[Dict[str, str]] = None,
3856 transport: str = "SSE",
3857 auth_type: Optional[str] = None,
3858 oauth_config: Optional[Dict[str, Any]] = None,
3859 ca_certificate: Optional[bytes] = None,
3860 pre_auth_headers: Optional[Dict[str, str]] = None,
3861 include_resources: bool = True,
3862 include_prompts: bool = True,
3863 auth_query_params: Optional[Dict[str, str]] = None,
3864 oauth_auto_fetch_tool_flag: Optional[bool] = False,
3865 client_cert: Optional[str] = None,
3866 client_key: Optional[str] = None,
3867 ) -> tuple[Dict[str, Any], List[ToolCreate], List[ResourceCreate], List[PromptCreate]]:
3868 """Initialize connection to a gateway and retrieve its capabilities.
3870 Connects to an MCP gateway using the specified transport protocol,
3871 performs the MCP handshake, and retrieves capabilities, tools,
3872 resources, and prompts from the gateway.
3874 Args:
3875 url: Gateway URL to connect to
3876 authentication: Optional authentication headers for the connection
3877 transport: Transport protocol - "SSE" or "StreamableHTTP"
3878 auth_type: Authentication type - "basic", "bearer", "authheaders", "oauth", "query_param" or None
3879 oauth_config: OAuth configuration if auth_type is "oauth"
3880 ca_certificate: CA certificate for SSL verification
3881 pre_auth_headers: Pre-authenticated headers to skip OAuth token fetch (for reuse)
3882 include_resources: Whether to include resources in the fetch
3883 include_prompts: Whether to include prompts in the fetch
3884 auth_query_params: Query param names for URL sanitization in error logs (decrypted values)
3885 oauth_auto_fetch_tool_flag: Whether to skip the early return for OAuth Authorization Code flow.
3886 When False (default), auth_code gateways return empty lists immediately (for health checks).
3887 When True, attempts to connect even for auth_code gateways (for activation after user authorization).
3888 client_cert: Optional client certificate path or PEM for mTLS
3889 client_key: Optional client private key path or PEM for mTLS
3891 Returns:
3892 tuple[Dict[str, Any], List[ToolCreate], List[ResourceCreate], List[PromptCreate]]:
3893 Capabilities dictionary, list of ToolCreate objects, list of ResourceCreate objects, and list of PromptCreate objects
3895 Raises:
3896 GatewayConnectionError: If connection or initialization fails
3898 Examples:
3899 >>> service = GatewayService()
3900 >>> # Test parameter validation
3901 >>> import asyncio
3902 >>> from unittest.mock import AsyncMock
3903 >>> # Avoid opening a real SSE connection in doctests (it can leak anyio streams on failure paths)
3904 >>> service.connect_to_sse_server = AsyncMock(side_effect=GatewayConnectionError("boom"))
3905 >>> async def test_params():
3906 ... try:
3907 ... await service._initialize_gateway("hello//")
3908 ... except Exception as e:
3909 ... return isinstance(e, GatewayConnectionError) or "Failed" in str(e)
3911 >>> asyncio.run(test_params())
3912 True
3914 >>> # Test default parameters
3915 >>> hasattr(service, '_initialize_gateway')
3916 True
3917 >>> import inspect
3918 >>> sig = inspect.signature(service._initialize_gateway)
3919 >>> sig.parameters['transport'].default
3920 'SSE'
3921 >>> sig.parameters['authentication'].default is None
3922 True
3923 >>>
3924 >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
3925 >>> asyncio.run(service._http_client.aclose())
3926 """
3927 try:
3928 if authentication is None:
3929 authentication = {}
3931 # Use pre-authenticated headers if provided (avoids duplicate OAuth token fetch)
3932 if pre_auth_headers:
3933 authentication = pre_auth_headers
3934 # Handle OAuth authentication
3935 elif auth_type == "oauth" and oauth_config:
3936 grant_type = oauth_config.get("grant_type", "client_credentials")
3938 if grant_type == "authorization_code":
3939 if not oauth_auto_fetch_tool_flag:
3940 # For Authorization Code flow during health checks, we can't initialize immediately
3941 # because we need user consent. Just store the configuration
3942 # and let the user complete the OAuth flow later.
3943 logger.info("""OAuth Authorization Code flow configured for gateway. User must complete authorization before gateway can be used.""")
3944 # Don't try to get access token here - it will be obtained during tool invocation
3945 authentication = {}
3947 # Skip MCP server connection for Authorization Code flow
3948 # Tools will be fetched after OAuth completion
3949 return {}, [], [], []
3950 # When flag is True (activation), skip token fetch but try to connect
3951 # This allows activation to proceed - actual auth happens during tool invocation
3952 logger.debug("OAuth Authorization Code gateway activation - skipping token fetch")
3953 elif grant_type == "client_credentials":
3954 # For Client Credentials flow, we can get the token immediately
3955 try:
3956 logger.debug("Obtaining OAuth access token for Client Credentials flow")
3957 access_token = await self.oauth_manager.get_access_token(oauth_config)
3958 authentication = {"Authorization": f"Bearer {access_token}"}
3959 except Exception as e:
3960 logger.error(f"Failed to obtain OAuth access token: {e}")
3961 raise GatewayConnectionError(f"OAuth authentication failed: {str(e)}")
3963 capabilities = {}
3964 tools = []
3965 resources = []
3966 prompts = []
3967 if auth_type in ("basic", "bearer", "authheaders") and isinstance(authentication, str):
3968 authentication = decode_auth(authentication)
3969 if transport.lower() == "sse":
3970 capabilities, tools, resources, prompts = await self.connect_to_sse_server(
3971 url, authentication, ca_certificate, include_prompts, include_resources, auth_query_params, client_cert=client_cert, client_key=client_key
3972 )
3973 elif transport.lower() == "streamablehttp":
3974 capabilities, tools, resources, prompts = await self.connect_to_streamablehttp_server(
3975 url, authentication, ca_certificate, include_prompts, include_resources, auth_query_params, client_cert=client_cert, client_key=client_key
3976 )
3978 return capabilities, tools, resources, prompts
3979 except Exception as e:
3981 # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup
3982 root_cause = e
3983 if isinstance(e, BaseExceptionGroup):
3984 while isinstance(root_cause, BaseExceptionGroup) and root_cause.exceptions:
3985 root_cause = root_cause.exceptions[0]
3986 sanitized_url = sanitize_url_for_logging(url, auth_query_params)
3987 raw_error = str(root_cause) or type(root_cause).__name__
3988 sanitized_error = sanitize_exception_message(raw_error, auth_query_params)
3989 logger.error(f"Gateway initialization failed for {sanitized_url}: {sanitized_error}", exc_info=True)
3990 raise GatewayConnectionError(f"Failed to initialize gateway at {sanitized_url}: {sanitized_error}")
def _get_gateways(self, include_inactive: bool = True) -> list[DbGateway]:
    """Load gateways from the database using a short-lived synchronous session.

    Intended to be run in a worker thread (health checks call it through
    ``asyncio.to_thread``) so the event loop is not blocked by the query.

    Args:
        include_inactive: When True (the default) every gateway is returned;
            when False only gateways whose ``enabled`` flag is set are returned.

    Returns:
        list[DbGateway]: The matching gateways (all of them by default, not only
        active ones). The session is closed when the ``with`` block exits, so the
        returned objects are presumably detached — unloaded lazy attributes may not
        be accessible afterwards (NOTE(review): confirm against SessionLocal config).

    Examples:
        >>> from unittest.mock import patch, MagicMock
        >>> service = GatewayService()
        >>> with patch('mcpgateway.services.gateway_service.SessionLocal') as mock_session:
        ...     mock_db = MagicMock()
        ...     mock_session.return_value.__enter__.return_value = mock_db
        ...     mock_db.execute.return_value.scalars.return_value.all.return_value = []
        ...     result = service._get_gateways()
        ...     isinstance(result, list)
        True

        >>> # Test include_inactive parameter handling
        >>> with patch('mcpgateway.services.gateway_service.SessionLocal') as mock_session:
        ...     mock_db = MagicMock()
        ...     mock_session.return_value.__enter__.return_value = mock_db
        ...     mock_db.execute.return_value.scalars.return_value.all.return_value = []
        ...     result_active_only = service._get_gateways(include_inactive=False)
        ...     isinstance(result_active_only, list)
        True
    """
    query = select(DbGateway)
    if not include_inactive:
        # Only return active gateways
        query = query.where(DbGateway.enabled)
    with cast(Any, SessionLocal)() as db:
        # scalars().all() is typed as a Sequence; materialize a real list to honour the return hint.
        return list(db.execute(query).scalars().all())
def get_first_gateway_by_url(self, db: Session, url: str, team_id: Optional[str] = None, include_inactive: bool = False) -> Optional[GatewayRead]:
    """Return the first gateway whose stored URL equals ``url``, as a ``GatewayRead``.

    This is a synchronous helper for request handlers that need a simple DB lookup.
    The URL is compared verbatim against the ``url`` column. No normalization is
    applied here, so callers must pass the URL in the same form gateways are stored in.

    Args:
        db: Database session to use for the query.
        url: Gateway base URL to match exactly.
        team_id: When truthy, only gateways whose ``team_id`` equals this value are considered.
        include_inactive: When False (the default), only enabled gateways are considered.

    Returns:
        Optional[GatewayRead]: The first match converted with ``convert_gateway_to_read``,
        or None when nothing matches. The query has no ORDER BY, so if several gateways
        match, the database decides which one comes first.
    """
    query = select(DbGateway).where(DbGateway.url == url)
    if not include_inactive:
        query = query.where(DbGateway.enabled)
    if team_id:
        query = query.where(DbGateway.team_id == team_id)
    result = db.execute(query).scalars().first()
    # Wrap the DB object in the GatewayRead schema for consistency with
    # other service methods. Return None if no match found.
    if result is None:
        return None
    return self.convert_gateway_to_read(result)
async def _run_leader_heartbeat(self) -> None:
    """Keep this instance's Redis leader key alive for as long as it stays leader.

    Every ``self._leader_heartbeat_interval`` seconds the leader key is read. If it
    still holds ``self._instance_id``, its TTL is reset to ``self._leader_ttl``.
    The loop ends in one of three ways:

    * the key names a different instance: a follower election is started and the loop returns;
    * three consecutive iterations find no Redis client: the loop returns
      (no follower election is started in this case);
    * three consecutive iterations raise: a follower election is started and the loop returns.

    A successful refresh resets the failure counter.
    """
    failure_limit = 3
    failures = 0

    while True:
        try:
            await asyncio.sleep(self._leader_heartbeat_interval)

            if self._redis_client:
                leader_now = await self._redis_client.get(self._leader_key)
                if leader_now == self._instance_id:
                    # Still the leader: extend the key's lifetime.
                    await self._redis_client.expire(self._leader_key, self._leader_ttl)
                    logger.debug(f"Leader heartbeat: refreshed TTL to {self._leader_ttl}s")
                    failures = 0
                    continue
                logger.info("Lost Redis leadership, stopping heartbeat")
                self._start_follower_election()
                return

            logger.warning("Redis client unavailable in heartbeat")
            failures += 1
            if failures >= failure_limit:
                logger.error("Lost Redis connection, stopping heartbeat")
                return

        except Exception as exc:
            failures += 1
            logger.warning(f"Leader heartbeat error (failure {failures}/{failure_limit}): {exc}")
            if failures < failure_limit:
                continue
            logger.error("Too many consecutive heartbeat failures, starting follower election")
            self._start_follower_election()
            return
def _start_follower_election(self) -> None:
    """Spawn the follower-election task unless one is already in flight.

    A new ``_run_follower_election`` task is created, with
    ``settings.platform_admin_email`` as the user email, when no task has been
    recorded yet or the recorded one has finished. Otherwise this is a no-op.
    """
    previous = self._follower_election_task
    election_running = previous is not None and not previous.done()
    if not election_running:
        self._follower_election_task = asyncio.create_task(self._run_follower_election(settings.platform_admin_email))
async def _run_follower_election(self, user_email: str) -> None:
    """Poll Redis until this instance manages to claim leadership.

    Every ``max(1, self._leader_ttl // 3)`` seconds the method tries an atomic
    ``SET key value NX EX ttl`` on ``self._leader_key``. When that succeeds:

    * any still-running health-check and heartbeat tasks from an earlier
      leadership period are cancelled;
    * fresh ``_run_health_checks`` and ``_run_leader_heartbeat`` tasks are started;
    * the method returns.

    A missing Redis client and any exception are logged, and the loop keeps polling.

    Args:
        user_email: Email of the user for OAuth token lookup, forwarded to the
            health-check task.
    """
    poll_seconds = max(1, self._leader_ttl // 3)  # Poll at 1/3 of TTL

    while True:
        try:
            await asyncio.sleep(poll_seconds)

            if not self._redis_client:
                logger.warning("Redis client unavailable, cannot attempt election.")
                continue

            acquired = await self._redis_client.set(self._leader_key, self._instance_id, ex=self._leader_ttl, nx=True)
            if not acquired:
                continue

            logger.info("Acquired Redis leadership via follower election. Starting health check and heartbeat.")

            # Stop leftovers from a previous term so they do not run next to the new loops.
            stale_health = self._health_check_task
            if stale_health and not stale_health.done():
                stale_health.cancel()
            stale_heartbeat = getattr(self, "_leader_heartbeat_task", None)
            if stale_heartbeat and not stale_heartbeat.done():
                stale_heartbeat.cancel()

            self._health_check_task = asyncio.create_task(self._run_health_checks(user_email))
            self._leader_heartbeat_task = asyncio.create_task(self._run_leader_heartbeat())
            return  # Now running as leader; the follower loop is done.

        except Exception as exc:
            logger.warning(f"Follower election error: {exc}", exc_info=True)
async def _run_health_checks(self, user_email: str) -> None:
    """Run gateway health checks periodically.

    The coordination mode is chosen on every iteration:

    - Redis client present and ``settings.cache_type == "redis"``: only the
      instance whose id is stored under ``self._leader_key`` runs checks. Any
      other instance returns. The key's TTL is refreshed by ``_run_leader_heartbeat``.
    - ``settings.cache_type == "none"``: single-worker mode. Checks run directly,
      and errors are logged.
    - Anything else, including Redis configured but no client available: a
      ``FileLock`` elects one worker, which keeps the lock and runs checks in an
      inner loop. Workers that cannot get the lock retry after one interval. If
      the lock holder fails, the lock is released and the worker backs off for
      one interval before retrying.

    NOTE: This method intentionally does NOT take a db parameter.
    Health checks use fresh_db_session() only when DB access is needed,
    avoiding holding connections during HTTP calls to MCP servers.

    Args:
        user_email: Email of the user for OAuth token lookup

    Examples:
        >>> service = GatewayService()
        >>> service._health_check_interval = 0.1  # Short interval for testing
        >>> service._redis_client = None
        >>> import asyncio
        >>> # Test that method exists and is callable
        >>> callable(service._run_health_checks)
        True
        >>> # Test setup without actual execution (would run forever)
        >>> hasattr(service, '_health_check_interval')
        True
        >>> service._health_check_interval == 0.1
        True
    """

    while True:
        try:
            if self._redis_client and settings.cache_type == "redis":
                # Redis-based leader check (async, decode_responses=True returns strings)
                # Note: Leader key TTL refresh is handled by _run_leader_heartbeat task
                current_leader = await self._redis_client.get(self._leader_key)
                if current_leader != self._instance_id:
                    return

                # Run health checks
                gateways = await asyncio.to_thread(self._get_gateways)
                if gateways:
                    await self.check_health_of_gateways(gateways, user_email)

                await asyncio.sleep(self._health_check_interval)

            elif settings.cache_type == "none":
                try:
                    # For single worker mode, run health checks directly
                    gateways = await asyncio.to_thread(self._get_gateways)
                    if gateways:
                        await self.check_health_of_gateways(gateways, user_email)
                except Exception as e:
                    logger.error(f"Health check run failed: {str(e)}")

                await asyncio.sleep(self._health_check_interval)

            else:
                # FileLock-based leader fallback
                lock_run_failed = False
                try:
                    self._file_lock.acquire(timeout=0)
                    logger.info("File lock acquired. Running health checks.")

                    while True:
                        gateways = await asyncio.to_thread(self._get_gateways)
                        if gateways:
                            await self.check_health_of_gateways(gateways, user_email)
                        await asyncio.sleep(self._health_check_interval)

                except Timeout:
                    logger.debug("File lock already held. Retrying later.")
                    await asyncio.sleep(self._health_check_interval)

                except Exception as e:
                    logger.error(f"FileLock health check failed: {str(e)}")
                    lock_run_failed = True

                finally:
                    if self._file_lock.is_locked:
                        try:
                            self._file_lock.release()
                            logger.info("Released file lock.")
                        except Exception as e:
                            logger.warning(f"Failed to release file lock: {str(e)}")

                if lock_run_failed:
                    # Without this back-off, a persistent failure (e.g. the DB being down)
                    # would make the outer loop re-acquire the lock and retry immediately,
                    # hammering the DB and the logs. The sleep happens after the lock is
                    # released, so another worker may take over in the meantime.
                    await asyncio.sleep(self._health_check_interval)

        except Exception as e:
            logger.error(f"Unexpected error in health check loop: {str(e)}")
            await asyncio.sleep(self._health_check_interval)
def _get_auth_headers(self) -> Dict[str, str]:
    """Build the baseline headers sent with gateway requests.

    SECURITY: the result deliberately carries no credentials. Each gateway is
    expected to supply its own auth_value, and this gateway's admin credentials
    must never be forwarded to remote servers.

    Returns:
        dict: A new dict holding only the JSON ``Content-Type`` header.

    Examples:
        >>> service = GatewayService()
        >>> headers = service._get_auth_headers()
        >>> isinstance(headers, dict)
        True
        >>> 'Content-Type' in headers
        True
        >>> headers['Content-Type']
        'application/json'
        >>> 'Authorization' not in headers  # No credentials leaked
        True
    """
    headers: Dict[str, str] = dict()
    headers["Content-Type"] = "application/json"
    return headers
async def _notify_gateway_added(self, gateway: DbGateway) -> None:
    """Publish a ``gateway_added`` event for a newly registered gateway.

    The payload carries the gateway's id, name, url, description and enabled
    flag. It is wrapped with a UTC ISO-8601 timestamp and handed to
    ``self._publish_event``.

    Args:
        gateway: The gateway that was added.
    """
    payload = dict(
        id=gateway.id,
        name=gateway.name,
        url=gateway.url,
        description=gateway.description,
        enabled=gateway.enabled,
    )
    stamp = datetime.now(timezone.utc).isoformat()
    await self._publish_event({"type": "gateway_added", "data": payload, "timestamp": stamp})
async def _notify_gateway_activated(self, gateway: DbGateway) -> None:
    """Publish a ``gateway_activated`` event describing the gateway's current state.

    The payload holds the gateway's id, name, url, enabled and reachable values.
    It is sent through ``self._publish_event`` with a UTC ISO-8601 timestamp.

    Args:
        gateway: The gateway that was activated.
    """
    payload = {}
    payload["id"] = gateway.id
    payload["name"] = gateway.name
    payload["url"] = gateway.url
    payload["enabled"] = gateway.enabled
    payload["reachable"] = gateway.reachable
    await self._publish_event(
        dict(
            type="gateway_activated",
            data=payload,
            timestamp=datetime.now(timezone.utc).isoformat(),
        )
    )
async def _notify_gateway_deactivated(self, gateway: DbGateway) -> None:
    """Publish a ``gateway_deactivated`` event for the given gateway.

    The event data mirrors the gateway's id, name, url, enabled and reachable
    attributes as they are at call time, and a UTC ISO-8601 timestamp is attached.

    Args:
        gateway: Gateway database object that was deactivated.
    """
    fields = ("id", "name", "url", "enabled", "reachable")
    payload = {field: getattr(gateway, field) for field in fields}
    event = {
        "type": "gateway_deactivated",
        "data": payload,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    await self._publish_event(event)
async def _notify_gateway_offline(self, gateway: DbGateway) -> None:
    """
    Publish a ``gateway_offline`` event: the gateway is enabled but not reachable.

    Only id, name and url are read from the gateway. The ``enabled``/``reachable``
    flags in the payload are always ``True``/``False``, whatever the object holds.

    Args:
        gateway: Gateway database object that went offline.
    """
    stamp = datetime.now(timezone.utc).isoformat()
    offline_state = {
        "id": gateway.id,
        "name": gateway.name,
        "url": gateway.url,
        "enabled": True,
        "reachable": False,
    }
    await self._publish_event({"type": "gateway_offline", "data": offline_state, "timestamp": stamp})
async def _notify_gateway_deleted(self, gateway_info: Dict[str, Any]) -> None:
    """Publish a ``gateway_deleted`` event.

    The caller-supplied ``gateway_info`` dict is used as the event data
    unchanged (the same object, not a copy), alongside a UTC ISO-8601 timestamp.

    Args:
        gateway_info: Dict containing information about the deleted gateway.
    """
    await self._publish_event(
        dict(
            type="gateway_deleted",
            data=gateway_info,
            timestamp=datetime.now(timezone.utc).isoformat(),
        )
    )
async def _notify_gateway_removed(self, gateway: DbGateway) -> None:
    """Publish a ``gateway_removed`` event (used for deactivation-style removal).

    The payload is limited to the gateway's id, name and enabled flag, with a
    UTC ISO-8601 timestamp attached.

    Args:
        gateway: The gateway being removed.
    """
    summary = dict(id=gateway.id, name=gateway.name, enabled=gateway.enabled)
    when = datetime.now(timezone.utc).isoformat()
    event = dict(type="gateway_removed", data=summary, timestamp=when)
    await self._publish_event(event)
def convert_gateway_to_read(self, gateway: DbGateway) -> GatewayRead:
    """Turn a ``DbGateway`` ORM instance into a ``GatewayRead`` model.

    The already-loaded instance attributes are used as the base, minus
    SQLAlchemy's ``_sa_instance_state``. The following are then normalized:

    * a dict ``auth_value`` is encoded with ``encode_auth``;
    * string tags are run through ``validate_tags_field`` (only the first tag's
      type is inspected), while empty or missing tags become ``[]``;
    * audit/metadata attributes and ``team`` are filled in via ``getattr``;
    * ``tool_count`` is taken from an already-loaded ``tools`` relationship,
      or set to 0 when that relationship has not been loaded.

    Args:
        gateway: Gateway database object.

    Returns:
        GatewayRead: The validated model after ``.masked()`` has been applied.
    """
    fields = {key: value for key, value in gateway.__dict__.items() if key != "_sa_instance_state"}

    # Ensure auth_value is properly encoded
    raw_auth = gateway.auth_value
    if isinstance(raw_auth, dict):
        fields["auth_value"] = encode_auth(raw_auth)

    raw_tags = gateway.tags
    if not raw_tags:
        fields["tags"] = []
    elif isinstance(raw_tags[0], str):
        # List[str] -> List[Dict[str, str]] as GatewayRead expects
        fields["tags"] = validate_tags_field(raw_tags)
    else:
        fields["tags"] = raw_tags

    # Include metadata fields
    for attr_name in ("created_by", "modified_by", "created_at", "updated_at", "version", "team"):
        fields[attr_name] = getattr(gateway, attr_name, None)

    # Read from __dict__ so an unloaded relationship is not lazily fetched here.
    loaded_tools = gateway.__dict__.get("tools")
    fields["tool_count"] = 0 if loaded_tools is None else len(loaded_tools)

    return GatewayRead.model_validate(fields).masked()
def _create_db_tool(
    self,
    tool: ToolCreate,
    gateway: DbGateway,
    created_by: Optional[str] = None,
    created_from_ip: Optional[str] = None,
    created_via: Optional[str] = None,
    created_user_agent: Optional[str] = None,
) -> DbTool:
    """Create a DbTool with consistent federation metadata across all scenarios.

    The tool inherits the gateway's URL, auth settings, team and owner. Its
    visibility is the tool's own ``visibility`` when that is set, otherwise the
    gateway's. ``output_schema`` is copied as well, so a freshly created tool
    already matches what ``_update_or_create_tools`` compares and writes on
    later refreshes.

    Args:
        tool: Tool creation schema
        gateway: Gateway database object
        created_by: Username who created/updated this tool (defaults to "system")
        created_from_ip: IP address of creator
        created_via: Creation method (ui, api, federation, rediscovery); defaults to "federation"
        created_user_agent: User agent of creation request

    Returns:
        DbTool: Consistently configured database tool object (not added to any session)
    """
    return DbTool(
        original_name=tool.name,
        custom_name=tool.name,
        custom_name_slug=slugify(tool.name),
        display_name=generate_display_name(tool.name),
        title=_resolve_tool_title(tool),
        url=gateway.url,
        original_description=tool.description,
        description=tool.description,
        integration_type="MCP",  # Gateway-discovered tools are MCP type
        request_type=tool.request_type,
        headers=tool.headers,
        input_schema=tool.input_schema,
        # Previously omitted, so new tools lacked their output schema until a later refresh.
        output_schema=getattr(tool, "output_schema", None),
        annotations=tool.annotations,
        jsonpath_filter=tool.jsonpath_filter,
        auth_type=gateway.auth_type,
        # DbTool.auth_value stores the encoded string form; the gateway may hold a dict.
        auth_value=encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value,
        # Federation metadata - consistent across all scenarios
        created_by=created_by or "system",
        created_from_ip=created_from_ip,
        created_via=created_via or "federation",
        created_user_agent=created_user_agent,
        federation_source=gateway.name,
        version=1,
        # Inherit team assignment from gateway; respect per-tool visibility if set
        team_id=gateway.team_id,
        owner_email=gateway.owner_email,
        visibility=getattr(tool, "visibility", None) or gateway.visibility,
    )
def _update_or_create_tools(self, db: Session, tools: List[Any], gateway: DbGateway, created_via: str, update_visibility: bool = False) -> List[DbTool]:
    """Sync a gateway's upstream tool list into the database: update known tools, build new ones.

    Existing tools are looked up in one batched query, matched on ``gateway_id`` and
    ``original_name``. They are updated in place when any upstream-controlled field
    differs; a user-customized ``description`` is preserved. Tools not yet stored are
    returned as new ``DbTool`` objects, and this helper does not call ``db.add`` itself.
    ``None`` entries are skipped. A failure while processing one tool is logged and
    does not abort the others.

    Args:
        db: Database session used for the lookup query
        tools: List of tools from MCP server
        gateway: Gateway object whose url and auth settings are copied onto the tools
        created_via: String indicating creation source ("oauth", "update", etc.)
        update_visibility: When True, copy each incoming tool's own ``visibility``
            (when it is set) onto the matching existing tool

    Returns:
        List of new tools to be added to the database. If the server reports the same
        new tool name more than once, only the first occurrence is returned.
    """
    if not tools:
        return []

    # Batch fetch all existing tools for this gateway
    tool_names = [tool.name for tool in tools if tool is not None]
    if not tool_names:
        return []

    existing_tools_query = select(DbTool).where(DbTool.gateway_id == gateway.id, DbTool.original_name.in_(tool_names))
    existing_tools = db.execute(existing_tools_query).scalars().all()
    existing_tools_map = {existing.original_name: existing for existing in existing_tools}

    tools_to_add = []
    # Names created during this call. An upstream server listing the same new tool
    # twice would otherwise produce two identical DbTool objects in one batch.
    new_tool_names = set()

    for tool in tools:
        if tool is None:
            logger.warning("Skipping None tool in tools list")
            continue

        try:
            existing_tool = existing_tools_map.get(tool.name)
            if existing_tool:
                resolved_title = _resolve_tool_title(tool)

                # Compare against original_description (upstream value) rather than description
                # (which may have been customized by the user)
                basic_fields_changed = (
                    existing_tool.url != gateway.url
                    or existing_tool.original_description != tool.description
                    or existing_tool.integration_type != "MCP"
                    or existing_tool.request_type != tool.request_type
                )

                # Check schema and configuration changes
                schema_fields_changed = (
                    existing_tool.headers != tool.headers
                    or existing_tool.input_schema != tool.input_schema
                    or existing_tool.output_schema != tool.output_schema
                    or existing_tool.jsonpath_filter != tool.jsonpath_filter
                )

                # Check authentication and visibility changes.
                # DbTool.auth_value is Text (encoded str); DbGateway.auth_value is JSON (dict).
                # encode_auth() uses a random nonce, so comparing ciphertext would always
                # differ even when the plaintext hasn't changed. Compare on decoded
                # (plaintext) values instead, and only encode on the write path.
                # If decoding fails (legacy/corrupt data), fall back to direct comparison.
                try:
                    gateway_auth_plain = gateway.auth_value if isinstance(gateway.auth_value, dict) else (decode_auth(gateway.auth_value) if gateway.auth_value else {})
                    existing_tool_auth_plain = decode_auth(existing_tool.auth_value) if existing_tool.auth_value else {}
                    auth_value_changed = existing_tool_auth_plain != gateway_auth_plain
                except Exception:
                    gateway_tool_auth_value = encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value
                    auth_value_changed = existing_tool.auth_value != gateway_tool_auth_value

                upstream_tool_visibility = getattr(tool, "visibility", None)
                auth_fields_changed = (
                    existing_tool.auth_type != gateway.auth_type
                    or auth_value_changed
                    or (update_visibility and upstream_tool_visibility is not None and existing_tool.visibility != upstream_tool_visibility)
                )

                title_changed = existing_tool.title != resolved_title

                if basic_fields_changed or schema_fields_changed or auth_fields_changed or title_changed:
                    existing_tool.url = gateway.url
                    # Only overwrite user-facing description if it hasn't been customized
                    # (mirrors original_name/custom_name pattern)
                    if existing_tool.description == existing_tool.original_description:
                        existing_tool.description = tool.description
                    existing_tool.original_description = tool.description
                    existing_tool.integration_type = "MCP"
                    existing_tool.request_type = tool.request_type
                    existing_tool.headers = tool.headers
                    existing_tool.input_schema = tool.input_schema
                    existing_tool.output_schema = tool.output_schema
                    existing_tool.jsonpath_filter = tool.jsonpath_filter
                    existing_tool.title = resolved_title
                    existing_tool.auth_type = gateway.auth_type
                    existing_tool.auth_value = encode_auth(gateway.auth_value) if isinstance(gateway.auth_value, dict) else gateway.auth_value
                    if update_visibility and upstream_tool_visibility is not None:
                        existing_tool.visibility = upstream_tool_visibility
                    logger.debug(f"Updated existing tool: {tool.name}")
            elif tool.name in new_tool_names:
                logger.warning(f"Skipping duplicate tool in tools list: {tool.name}")
            else:
                # Create new tool if it doesn't exist
                db_tool = self._create_db_tool(
                    tool=tool,
                    gateway=gateway,
                    created_by="system",
                    created_via=created_via,
                )
                # Attach relationship to avoid NoneType during flush
                db_tool.gateway = gateway
                tools_to_add.append(db_tool)
                new_tool_names.add(tool.name)
                logger.debug(f"Created new tool: {tool.name}")
        except Exception as e:
            logger.warning(f"Failed to process tool {getattr(tool, 'name', 'unknown')}: {e}")

    return tools_to_add
def _update_or_create_resources(self, db: Session, resources: List[Any], gateway: DbGateway, created_via: str, update_visibility: bool = False) -> List[DbResource]:
    """Sync a gateway's upstream resource list into the database: update known resources, build new ones.

    Existing resources are fetched in one batched query, matched on ``gateway_id`` and
    ``uri``. They are updated in place when name, description, MIME type, URI template
    or title differ; the upstream value always wins. Resources not yet stored are
    returned as new ``DbResource`` objects, and this helper does not call ``db.add``
    itself. ``None`` entries are skipped. A failure while processing one resource is
    logged without aborting the rest.

    Args:
        db: Database session used for the lookup query
        resources: List of resources from MCP server
        gateway: Gateway object that owns the resources
        created_via: String indicating creation source ("oauth", "update", etc.)
        update_visibility: When True, copy each incoming resource's own ``visibility``
            (when it is set) onto the matching existing resource

    Returns:
        List of new resources to be added to the database. If the server reports the
        same new URI more than once, only the first occurrence is returned.
    """
    if not resources:
        return []

    # Batch fetch all existing resources for this gateway
    resource_uris = [resource.uri for resource in resources if resource is not None]
    if not resource_uris:
        return []

    existing_resources_query = select(DbResource).where(DbResource.gateway_id == gateway.id, DbResource.uri.in_(resource_uris))
    existing_resources = db.execute(existing_resources_query).scalars().all()
    existing_resources_map = {existing.uri: existing for existing in existing_resources}

    resources_to_add = []
    # URIs created during this call, so a duplicated upstream entry does not yield two objects.
    new_resource_uris = set()

    for resource in resources:
        if resource is None:
            logger.warning("Skipping None resource in resources list")
            continue

        try:
            existing_resource = existing_resources_map.get(resource.uri)

            if existing_resource:
                upstream_visibility = getattr(resource, "visibility", None)
                upstream_title = getattr(resource, "title", None)
                if (
                    existing_resource.name != resource.name
                    or existing_resource.description != resource.description
                    or existing_resource.mime_type != resource.mime_type
                    or existing_resource.uri_template != resource.uri_template
                    or (update_visibility and upstream_visibility is not None and existing_resource.visibility != upstream_visibility)
                    or existing_resource.title != upstream_title
                ):
                    existing_resource.name = resource.name
                    existing_resource.description = resource.description
                    existing_resource.mime_type = resource.mime_type
                    existing_resource.uri_template = resource.uri_template
                    existing_resource.title = upstream_title
                    if update_visibility and upstream_visibility is not None:
                        existing_resource.visibility = upstream_visibility
                    logger.debug(f"Updated existing resource: {resource.uri}")
            elif resource.uri in new_resource_uris:
                logger.warning(f"Skipping duplicate resource in resources list: {resource.uri}")
            else:
                # Create new resource if it doesn't exist
                db_resource = DbResource(
                    uri=resource.uri,
                    name=resource.name,
                    title=getattr(resource, "title", None),
                    description=resource.description,
                    mime_type=resource.mime_type,
                    uri_template=resource.uri_template,
                    gateway_id=gateway.id,
                    created_by="system",
                    created_via=created_via,
                    visibility=getattr(resource, "visibility", None) or gateway.visibility,
                )
                resources_to_add.append(db_resource)
                new_resource_uris.add(resource.uri)
                logger.debug(f"Created new resource: {resource.uri}")
        except Exception as e:
            logger.warning(f"Failed to process resource {getattr(resource, 'uri', 'unknown')}: {e}")

    return resources_to_add
@staticmethod
def _build_prompt_argument_schema(prompt: Any) -> Dict[str, Any]:
    """Build a JSON-schema-compatible argument_schema dict from a prompt's arguments list.

    The MCP protocol's ``prompts/list`` response includes argument metadata
    (name, description, required) on each prompt. This helper converts that
    list into the internal ``argument_schema`` structure expected by
    ``DbPrompt`` so that the UI and API can surface the arguments correctly.
    Every argument is typed as ``"string"``.

    Each argument may be an object exposing ``name`` / ``description`` /
    ``required`` attributes, or a plain dict with those keys. A name marked
    required more than once appears only once in ``required``, because JSON
    Schema requires that array's items to be unique. If a name repeats, its
    last entry determines the property.

    Args:
        prompt: A PromptCreate (or any object with an ``arguments`` attribute).
            A missing or ``None`` ``arguments`` yields an empty schema.

    Returns:
        Dict with ``type``, ``properties``, and ``required`` keys.
    """
    schema: Dict[str, Any] = {"type": "object", "properties": {}, "required": []}
    for arg in getattr(prompt, "arguments", []) or []:
        if isinstance(arg, dict):
            name = arg["name"]
            description = arg.get("description")
            required = arg.get("required", False)
        else:
            name = arg.name
            description = getattr(arg, "description", None)
            required = getattr(arg, "required", False)

        prop: Dict[str, Any] = {"type": "string"}
        if description:
            prop["description"] = description
        schema["properties"][name] = prop
        if required and name not in schema["required"]:
            schema["required"].append(name)
    return schema
def _update_or_create_prompts(self, db: Session, prompts: List[Any], gateway: DbGateway, created_via: str, update_visibility: bool = False) -> List[DbPrompt]:
    """Sync a gateway's upstream prompt list into the database: update known prompts, build new ones.

    Existing prompts are fetched in one batched query, matched on ``gateway_id`` and
    ``original_name``. They are updated in place when description, template,
    argument schema or title differ. The argument schema comes from
    ``_build_prompt_argument_schema``, and a prompt without a ``template``
    attribute is treated as having an empty template. Prompts not yet stored are
    returned as new ``DbPrompt`` objects, and this helper does not call ``db.add``
    itself. ``None`` entries are skipped. A failure while processing one prompt is
    logged without aborting the rest.

    Args:
        db: Database session used for the lookup query
        prompts: List of prompts from MCP server
        gateway: Gateway object that owns the prompts
        created_via: String indicating creation source ("oauth", "update", etc.)
        update_visibility: When True, copy each incoming prompt's own ``visibility``
            (when it is set) onto the matching existing prompt

    Returns:
        List of new prompts to be added to the database. If the server reports the
        same new prompt name more than once, only the first occurrence is returned.
    """
    if not prompts:
        return []

    # Batch fetch all existing prompts for this gateway
    prompt_names = [prompt.name for prompt in prompts if prompt is not None]
    if not prompt_names:
        return []

    existing_prompts_query = select(DbPrompt).where(DbPrompt.gateway_id == gateway.id, DbPrompt.original_name.in_(prompt_names))
    existing_prompts = db.execute(existing_prompts_query).scalars().all()
    existing_prompts_map = {existing.original_name: existing for existing in existing_prompts}

    prompts_to_add = []
    # Names created during this call, so a duplicated upstream entry does not yield two objects.
    new_prompt_names = set()

    for prompt in prompts:
        if prompt is None:
            logger.warning("Skipping None prompt in prompts list")
            continue

        try:
            # Check if prompt already exists for this gateway from the prompts_map
            existing_prompt = existing_prompts_map.get(prompt.name)
            upstream_template = prompt.template if hasattr(prompt, "template") else ""
            argument_schema = self._build_prompt_argument_schema(prompt)

            if existing_prompt:
                upstream_prompt_visibility = getattr(prompt, "visibility", None)
                upstream_title = getattr(prompt, "title", None)
                if (
                    existing_prompt.description != prompt.description
                    or existing_prompt.template != upstream_template
                    or (update_visibility and upstream_prompt_visibility is not None and existing_prompt.visibility != upstream_prompt_visibility)
                    or (existing_prompt.argument_schema or {}) != argument_schema
                    or existing_prompt.title != upstream_title
                ):
                    existing_prompt.description = prompt.description
                    existing_prompt.template = upstream_template
                    existing_prompt.argument_schema = argument_schema
                    existing_prompt.title = upstream_title
                    if update_visibility and upstream_prompt_visibility is not None:
                        existing_prompt.visibility = upstream_prompt_visibility
                    logger.debug(f"Updated existing prompt: {prompt.name}")
            elif prompt.name in new_prompt_names:
                logger.warning(f"Skipping duplicate prompt in prompts list: {prompt.name}")
            else:
                # Create new prompt if it doesn't exist
                db_prompt = DbPrompt(
                    name=prompt.name,
                    original_name=prompt.name,
                    custom_name=prompt.name,
                    display_name=prompt.name,
                    title=getattr(prompt, "title", None),
                    description=prompt.description,
                    template=upstream_template,
                    argument_schema=argument_schema,
                    gateway_id=gateway.id,
                    created_by="system",
                    created_via=created_via,
                    visibility=getattr(prompt, "visibility", None) or gateway.visibility,
                )
                db_prompt.gateway = gateway
                prompts_to_add.append(db_prompt)
                new_prompt_names.add(prompt.name)
                logger.debug(f"Created new prompt: {prompt.name}")
        except Exception as e:
            logger.warning(f"Failed to process prompt {getattr(prompt, 'name', 'unknown')}: {e}")

    return prompts_to_add
async def _refresh_gateway_tools_resources_prompts(
    self,
    gateway_id: str,
    _user_email: Optional[str] = None,
    created_via: str = "health_check",
    pre_auth_headers: Optional[Dict[str, str]] = None,
    gateway: Optional[DbGateway] = None,
    include_resources: bool = True,
    include_prompts: bool = True,
) -> Dict[str, Any]:
    """Refresh tools, resources, and prompts for a gateway during health checks.

    Fetches the latest tools/resources/prompts from the MCP server and syncs
    with the database (add new, update changed, remove stale). Only items whose
    ``created_via`` marks them as MCP-discovered are deleted, so entries created
    through the API/UI survive a refresh.

    This method uses fresh_db_session() internally to avoid holding
    connections during HTTP calls to MCP servers.

    Args:
        gateway_id: ID of the gateway to refresh
        _user_email: Optional user email for OAuth token lookup (unused currently)
        created_via: String indicating creation source (default: "health_check")
        pre_auth_headers: Pre-authenticated headers from health check to avoid duplicate OAuth token fetch
        gateway: Optional DbGateway object to avoid redundant DB lookup
        include_resources: Whether to include resources in the refresh
        include_prompts: Whether to include prompts in the refresh

    Returns:
        Dict with counts ``{tools,resources,prompts}_{added,removed,updated}``
        plus ``success`` (bool), ``error`` (sanitized message or None) and
        ``validation_errors`` (list).

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> from unittest.mock import patch, MagicMock, AsyncMock
        >>> import asyncio

        >>> # Test gateway not found returns empty result
        >>> service = GatewayService()
        >>> mock_session = MagicMock()
        >>> mock_session.execute.return_value.scalar_one_or_none.return_value = None
        >>> with patch('mcpgateway.services.gateway_service.fresh_db_session') as mock_fresh:
        ...     mock_fresh.return_value.__enter__.return_value = mock_session
        ...     result = asyncio.run(service._refresh_gateway_tools_resources_prompts('gw-123'))
        >>> result['tools_added'] == 0 and result['tools_removed'] == 0
        True
        >>> result['resources_added'] == 0 and result['resources_removed'] == 0
        True
        >>> result['success'] is True and result['error'] is None
        True

        >>> # Test disabled gateway returns empty result
        >>> mock_gw = MagicMock()
        >>> mock_gw.enabled = False
        >>> mock_gw.reachable = True
        >>> mock_gw.name = 'test_gw'
        >>> mock_session.execute.return_value.scalar_one_or_none.return_value = mock_gw
        >>> with patch('mcpgateway.services.gateway_service.fresh_db_session') as mock_fresh:
        ...     mock_fresh.return_value.__enter__.return_value = mock_session
        ...     result = asyncio.run(service._refresh_gateway_tools_resources_prompts('gw-123'))
        >>> result['tools_added']
        0

        >>> # Test unreachable gateway returns empty result
        >>> mock_gw.enabled = True
        >>> mock_gw.reachable = False
        >>> with patch('mcpgateway.services.gateway_service.fresh_db_session') as mock_fresh:
        ...     mock_fresh.return_value.__enter__.return_value = mock_session
        ...     result = asyncio.run(service._refresh_gateway_tools_resources_prompts('gw-123'))
        >>> result['tools_added']
        0

        >>> # Test method is async and callable
        >>> import inspect
        >>> inspect.iscoroutinefunction(service._refresh_gateway_tools_resources_prompts)
        True
        >>>
        >>> # Cleanup long-lived clients created by the service to avoid ResourceWarnings in doctest runs
        >>> asyncio.run(service._http_client.aclose())
    """
    result: Dict[str, Any] = {
        "tools_added": 0,
        "tools_removed": 0,
        "resources_added": 0,
        "resources_removed": 0,
        "prompts_added": 0,
        "prompts_removed": 0,
        "tools_updated": 0,
        "resources_updated": 0,
        "prompts_updated": 0,
        "success": True,
        "error": None,
        "validation_errors": [],
    }

    # Fetch gateway metadata only (no relationships needed for MCP call).
    # Use the provided gateway object if available to save a DB call.
    gateway_name = None
    gateway_url = None
    gateway_transport = None
    gateway_auth_type = None
    gateway_auth_value = None
    gateway_oauth_config = None
    gateway_ca_certificate = None
    gateway_auth_query_params = None
    refresh_client_cert = None
    refresh_client_key = None

    if gateway:
        if not gateway.enabled or not gateway.reachable:
            logger.debug(f"Skipping tool refresh for disabled/unreachable gateway {SecurityValidator.sanitize_log_message(gateway.name)}")
            return result

        gateway_name = gateway.name
        gateway_url = gateway.url
        gateway_transport = gateway.transport
        gateway_auth_type = gateway.auth_type
        gateway_auth_value = gateway.auth_value
        gateway_oauth_config = gateway.oauth_config
        gateway_ca_certificate = gateway.ca_certificate
        gateway_auth_query_params = gateway.auth_query_params
        refresh_client_cert = getattr(gateway, "client_cert", None)
        refresh_client_key = getattr(gateway, "client_key", None)
    else:
        with fresh_db_session() as db:
            gateway_obj = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()

            if not gateway_obj:
                logger.warning(f"Gateway {SecurityValidator.sanitize_log_message(gateway_id)} not found for tool refresh")
                return result

            if not gateway_obj.enabled or not gateway_obj.reachable:
                # Sanitized like the provided-gateway branch above.
                logger.debug(f"Skipping tool refresh for disabled/unreachable gateway {SecurityValidator.sanitize_log_message(gateway_obj.name)}")
                return result

            # Extract metadata before the session closes.
            gateway_name = gateway_obj.name
            gateway_url = gateway_obj.url
            gateway_transport = gateway_obj.transport
            gateway_auth_type = gateway_obj.auth_type
            gateway_auth_value = gateway_obj.auth_value
            gateway_oauth_config = gateway_obj.oauth_config
            gateway_ca_certificate = gateway_obj.ca_certificate
            gateway_auth_query_params = gateway_obj.auth_query_params
            refresh_client_cert = getattr(gateway_obj, "client_cert", None)
            refresh_client_key = getattr(gateway_obj, "client_key", None)

    # Preserve base URL before auth mutation for classification poll-state keys.
    gateway_base_url = gateway_url

    # Handle query_param auth: decrypt and apply to the URL for this refresh.
    auth_query_params_decrypted: Optional[Dict[str, str]] = None
    if gateway_auth_type == "query_param" and gateway_auth_query_params:
        auth_query_params_decrypted = {}
        for param_key, encrypted_value in gateway_auth_query_params.items():
            if not encrypted_value:
                continue
            try:
                decrypted = decode_auth(encrypted_value)
                auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
            except Exception:
                logger.debug(f"Failed to decrypt query param '{param_key}' for tool refresh")
        if auth_query_params_decrypted:
            gateway_url = apply_query_param_auth(gateway_url, auth_query_params_decrypted)

    # Fetch tools/resources/prompts from the MCP server (no DB connection held).
    try:
        # Decrypt client_key for refresh initialization; on failure the stored
        # value is passed through unchanged (best effort).
        _refresh_key = refresh_client_key
        if _refresh_key:
            try:
                _enc = get_encryption_service(settings.auth_encryption_secret)
                _refresh_key = _enc.decrypt_secret_or_plaintext(_refresh_key)
            except Exception:
                logger.debug("client_key decryption skipped during gateway refresh")
        _capabilities, tools, resources, prompts = await self._initialize_gateway(
            url=gateway_url,
            authentication=gateway_auth_value,
            transport=gateway_transport,
            auth_type=gateway_auth_type,
            oauth_config=gateway_oauth_config,
            ca_certificate=gateway_ca_certificate.encode() if gateway_ca_certificate else None,
            pre_auth_headers=pre_auth_headers,
            include_resources=include_resources,
            include_prompts=include_prompts,
            auth_query_params=auth_query_params_decrypted,
            client_cert=refresh_client_cert,
            client_key=_refresh_key,
        )
    except Exception as e:
        # gateway_url may embed decrypted query-param credentials, so scrub the
        # message before it is logged or returned to the caller (same helper
        # used for SSE connection errors in this service).
        sanitized_error = sanitize_exception_message(str(e))
        logger.warning(f"Failed to fetch tools from gateway {gateway_name}: {sanitized_error}")
        result["success"] = False
        result["error"] = sanitized_error
        return result

    # For authorization_code OAuth gateways, an empty response may mean the user
    # has not completed authorization yet, so skip syncing in that case only.
    is_auth_code_gateway = gateway_oauth_config and isinstance(gateway_oauth_config, dict) and gateway_oauth_config.get("grant_type") == "authorization_code"
    if not tools and not resources and not prompts and is_auth_code_gateway:
        logger.debug(f"No tools/resources/prompts returned from auth_code gateway {gateway_name} (user may not have authorized)")
        return result

    # For non-auth_code gateways, empty responses are legitimate and clear stale items.
    with fresh_db_session() as db:
        # Fetch gateway with relationships for update/comparison.
        db_gateway = db.execute(
            select(DbGateway)
            .options(
                selectinload(DbGateway.tools),
                selectinload(DbGateway.resources),
                selectinload(DbGateway.prompts),
            )
            .where(DbGateway.id == gateway_id)
        ).scalar_one_or_none()

        if not db_gateway:
            result["success"] = False
            result["error"] = f"Gateway {gateway_id} not found during refresh"
            return result

        # Sets give O(1) membership checks in the stale-item scans below.
        new_tool_names = {tool.name for tool in tools}
        new_resource_uris = {resource.uri for resource in resources} if include_resources else None
        new_prompt_names = {prompt.name for prompt in prompts} if include_prompts else None

        # Snapshot dirty objects so updates made by the helpers below can be counted per type.
        pending_tools_before = {obj for obj in db.dirty if isinstance(obj, DbTool)}
        pending_resources_before = {obj for obj in db.dirty if isinstance(obj, DbResource)}
        pending_prompts_before = {obj for obj in db.dirty if isinstance(obj, DbPrompt)}

        tools_to_add = self._update_or_create_tools(db, tools, db_gateway, created_via)
        resources_to_add = self._update_or_create_resources(db, resources, db_gateway, created_via) if include_resources else []
        prompts_to_add = self._update_or_create_prompts(db, prompts, db_gateway, created_via) if include_prompts else []

        result["tools_updated"] = len({obj for obj in db.dirty if isinstance(obj, DbTool)} - pending_tools_before)
        result["resources_updated"] = len({obj for obj in db.dirty if isinstance(obj, DbResource)} - pending_resources_before)
        result["prompts_updated"] = len({obj for obj in db.dirty if isinstance(obj, DbPrompt)} - pending_prompts_before)

        # Only delete MCP-discovered items (not user-created entries).
        # Excludes "api", "ui", None (legacy/user-created) to preserve user entries.
        mcp_created_via_values = {"MCP", "federation", "health_check", "manual_refresh", "oauth", "update"}
        delete_chunk_size = 500

        stale_tool_ids = [tool.id for tool in db_gateway.tools if tool.original_name not in new_tool_names and tool.created_via in mcp_created_via_values]
        for i in range(0, len(stale_tool_ids), delete_chunk_size):
            chunk = stale_tool_ids[i : i + delete_chunk_size]
            db.execute(delete(ToolMetric).where(ToolMetric.tool_id.in_(chunk)))
            db.execute(delete(server_tool_association).where(server_tool_association.c.tool_id.in_(chunk)))
            db.execute(delete(DbTool).where(DbTool.id.in_(chunk)))
        result["tools_removed"] = len(stale_tool_ids)

        # Resources/prompts are only pruned when they were actually fetched.
        stale_resource_ids = []
        if new_resource_uris is not None:
            stale_resource_ids = [resource.id for resource in db_gateway.resources if resource.uri not in new_resource_uris and resource.created_via in mcp_created_via_values]
            for i in range(0, len(stale_resource_ids), delete_chunk_size):
                chunk = stale_resource_ids[i : i + delete_chunk_size]
                db.execute(delete(ResourceMetric).where(ResourceMetric.resource_id.in_(chunk)))
                db.execute(delete(server_resource_association).where(server_resource_association.c.resource_id.in_(chunk)))
                db.execute(delete(ResourceSubscription).where(ResourceSubscription.resource_id.in_(chunk)))
                db.execute(delete(DbResource).where(DbResource.id.in_(chunk)))
            result["resources_removed"] = len(stale_resource_ids)

        stale_prompt_ids = []
        if new_prompt_names is not None:
            stale_prompt_ids = [prompt.id for prompt in db_gateway.prompts if prompt.original_name not in new_prompt_names and prompt.created_via in mcp_created_via_values]
            for i in range(0, len(stale_prompt_ids), delete_chunk_size):
                chunk = stale_prompt_ids[i : i + delete_chunk_size]
                db.execute(delete(PromptMetric).where(PromptMetric.prompt_id.in_(chunk)))
                db.execute(delete(server_prompt_association).where(server_prompt_association.c.prompt_id.in_(chunk)))
                db.execute(delete(DbPrompt).where(DbPrompt.id.in_(chunk)))
            result["prompts_removed"] = len(stale_prompt_ids)

        # The bulk DELETEs above bypass the loaded relationship collections,
        # so expire the gateway to avoid reading stale collections.
        if stale_tool_ids or stale_resource_ids or stale_prompt_ids:
            db.expire(db_gateway)

        # Add new items in chunks (tools, then resources, then prompts).
        add_chunk_size = 50
        for result_key, items in (("tools_added", tools_to_add), ("resources_added", resources_to_add), ("prompts_added", prompts_to_add)):
            for i in range(0, len(items), add_chunk_size):
                db.add_all(items[i : i + add_chunk_size])
                db.flush()
            result[result_key] = len(items)

        db_gateway.last_refresh_at = datetime.now(timezone.utc)

        changes_by_kind = {kind: result[f"{kind}_added"] + result[f"{kind}_removed"] + result[f"{kind}_updated"] for kind in ("tools", "resources", "prompts")}
        has_changes = sum(changes_by_kind.values()) > 0

        # Commit in every case: last_refresh_at was updated even without item changes.
        db.commit()

        if has_changes:
            logger.info(
                f"Refreshed gateway {gateway_name}: "
                f"tools(+{result['tools_added']}/-{result['tools_removed']}/~{result['tools_updated']}), "
                f"resources(+{result['resources_added']}/-{result['resources_removed']}/~{result['resources_updated']}), "
                f"prompts(+{result['prompts_added']}/-{result['prompts_removed']}/~{result['prompts_updated']})"
            )

            # Invalidate caches per-type based on actual changes.
            cache = _get_registry_cache()
            if changes_by_kind["tools"] > 0:
                await cache.invalidate_tools()
            if changes_by_kind["resources"] > 0:
                await cache.invalidate_resources()
            if changes_by_kind["prompts"] > 0:
                await cache.invalidate_prompts()

            # Invalidate tool lookup cache for this gateway.
            tool_lookup_cache = _get_tool_lookup_cache()
            await tool_lookup_cache.invalidate_gateway(str(gateway_id))
        else:
            logger.debug(f"No changes detected during refresh of gateway {gateway_name}")

    # Advance the poll schedule so hot/cold classification tracks the actual last refresh,
    # whether it was triggered by health check, manual API, or registration. This runs
    # after the DB session is released. Use gateway_base_url (pre-auth) to match classification keys.
    if self._classification_service and gateway_base_url:
        try:
            await self._classification_service.mark_poll_completed(gateway_base_url, "tool_discovery", gateway_id=str(gateway_id))
        except Exception as poll_ts_err:
            logger.debug(f"Best-effort tool_discovery poll timestamp update failed: {poll_ts_err}")

    return result
5113 def _get_refresh_lock(self, gateway_id: str) -> asyncio.Lock:
5114 """Get or create a per-gateway refresh lock.
5116 This ensures only one refresh operation can run for a given gateway at a time.
5118 Args:
5119 gateway_id: ID of the gateway to get the lock for
5121 Returns:
5122 asyncio.Lock: The lock for the specified gateway
5124 Examples:
5125 >>> from mcpgateway.services.gateway_service import GatewayService
5126 >>> service = GatewayService()
5127 >>> lock1 = service._get_refresh_lock('gw-123')
5128 >>> lock2 = service._get_refresh_lock('gw-123')
5129 >>> lock1 is lock2
5130 True
5131 >>> lock3 = service._get_refresh_lock('gw-456')
5132 >>> lock1 is lock3
5133 False
5134 """
5135 if gateway_id not in self._refresh_locks:
5136 self._refresh_locks[gateway_id] = asyncio.Lock()
5137 return self._refresh_locks[gateway_id]
async def refresh_gateway_manually(
    self,
    gateway_id: str,
    include_resources: bool = True,
    include_prompts: bool = True,
    user_email: Optional[str] = None,
    request_headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    """Manually trigger a refresh of tools/resources/prompts for a gateway.

    This method provides a public API for triggering an immediate refresh
    of a gateway's tools, resources, and prompts from its MCP server.
    It includes concurrency control via per-gateway locking: a second call
    for the same gateway while one is running fails instead of queueing.

    Args:
        gateway_id: Gateway ID to refresh
        include_resources: Whether to include resources in the refresh
        include_prompts: Whether to include prompts in the refresh
        user_email: Email of the user triggering the refresh
        request_headers: Optional request headers for passthrough authentication

    Returns:
        Dict with counts: {tools_added, tools_updated, tools_removed,
        resources_added, resources_updated, resources_removed,
        prompts_added, prompts_updated, prompts_removed,
        validation_errors, duration_ms, refreshed_at}

    Raises:
        GatewayNotFoundError: If the gateway does not exist
        GatewayError: If another refresh is already in progress for this gateway

    Examples:
        >>> from mcpgateway.services.gateway_service import GatewayService
        >>> from unittest.mock import patch, MagicMock, AsyncMock
        >>> import asyncio

        >>> # Test method is async
        >>> service = GatewayService()
        >>> import inspect
        >>> inspect.iscoroutinefunction(service.refresh_gateway_manually)
        True
    """
    start_time = time.monotonic()

    pre_auth_headers: Dict[str, str] = {}

    # Check if the gateway exists before acquiring the lock.
    with fresh_db_session() as db:
        gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
        if not gateway:
            raise GatewayNotFoundError(f"Gateway with ID '{gateway_id}' not found")
        gateway_name = gateway.name

        # Get passthrough headers if request headers were provided.
        if request_headers:
            pre_auth_headers = get_passthrough_headers(request_headers, {}, db, gateway)

    # gateway_id is caller-supplied; sanitize it consistently in every log line.
    safe_gateway_id = SecurityValidator.sanitize_log_message(gateway_id)
    lock = self._get_refresh_lock(gateway_id)

    # No await separates this check from acquiring the lock below, so a
    # concurrent refresh is rejected immediately rather than queued.
    if lock.locked():
        raise GatewayError(f"Refresh already in progress for gateway {gateway_name}")

    async with lock:
        logger.info(f"Starting manual refresh for gateway {gateway_name} (ID: {safe_gateway_id})")

        result = await self._refresh_gateway_tools_resources_prompts(
            gateway_id=gateway_id,
            _user_email=user_email,
            created_via="manual_refresh",
            pre_auth_headers=pre_auth_headers,
            gateway=gateway,
            include_resources=include_resources,
            include_prompts=include_prompts,
        )
        # Note: last_refresh_at is updated inside _refresh_gateway_tools_resources_prompts on success

        result["duration_ms"] = (time.monotonic() - start_time) * 1000
        result["refreshed_at"] = datetime.now(timezone.utc)

        succeeded = result.get("success", True)
        log_level = logging.INFO if succeeded else logging.WARNING
        status_msg = "succeeded" if succeeded else f"failed: {result.get('error')}"

        # *_updated counts are read with .get so a reduced result dict still logs.
        logger.log(
            log_level,
            f"Manual refresh for gateway {safe_gateway_id} {status_msg}. Stats: "
            f"tools(+{result['tools_added']}/-{result['tools_removed']}/~{result.get('tools_updated', 0)}), "
            f"resources(+{result['resources_added']}/-{result['resources_removed']}/~{result.get('resources_updated', 0)}), "
            f"prompts(+{result['prompts_added']}/-{result['prompts_removed']}/~{result.get('prompts_updated', 0)}) "
            f"in {result['duration_ms']:.2f}ms",
        )

        return result
5233 async def _publish_event(self, event: Dict[str, Any]) -> None:
5234 """Publish event to all subscribers.
5236 Args:
5237 event: event dictionary
5239 Examples:
5240 >>> import asyncio
5241 >>> from unittest.mock import AsyncMock
5242 >>> service = GatewayService()
5243 >>> # Mock the underlying event service
5244 >>> service._event_service = AsyncMock()
5245 >>> test_event = {"type": "test", "data": {}}
5246 >>>
5247 >>> asyncio.run(service._publish_event(test_event))
5248 >>>
5249 >>> # Verify the event was passed to the event service
5250 >>> service._event_service.publish_event.assert_awaited_with(test_event)
5251 """
5252 await self._event_service.publish_event(event)
def _validate_tools(self, tools: list[dict[str, Any]], context: str = "default") -> tuple[list[ToolCreate], list[str]]:
    """Validate each raw tool dict on its own, collecting failures instead of stopping.

    Every tool is validated against ``ToolCreate``. Failures are logged and
    gathered as human-readable messages. Only when *no* tool validates (and at
    least one failed) is an exception raised, and its type depends on ``context``.

    Args:
        tools: list of tool dicts
        context: caller context, e.g. "oauth" to tailor errors/messages

    Returns:
        tuple[list[ToolCreate], list[str]]: Tuple of (valid tools, validation errors)

    Raises:
        OAuthToolValidationError: If all tools fail validation in OAuth context
        GatewayConnectionError: If all tools fail validation in default context
    """
    accepted: list[ToolCreate] = []
    failures: list[str] = []

    for index, raw_tool in enumerate(tools):
        label = raw_tool.get("name", f"unknown_tool_{index}")
        try:
            logger.debug(f"Validating tool: {label}")
            model = ToolCreate.model_validate(raw_tool)
            accepted.append(model)
            logger.debug(f"Tool '{label}' validated successfully")
        except ValidationError as exc:
            message = f"Validation failed for tool '{label}': {exc.errors()}"
            logger.error(message)
            logger.debug(f"Failed tool schema: {raw_tool}")
            failures.append(message)
        except ValueError as exc:
            too_deep = "JSON structure exceeds maximum depth" in str(exc)
            if not too_deep:
                message = f"ValueError for tool '{label}': {exc!s}"
                logger.error(message)
            else:
                message = f"Tool '{label}' schema too deeply nested. Current depth limit: {settings.validation_max_json_depth}"
                logger.error(message)
                logger.warning("Consider increasing VALIDATION_MAX_JSON_DEPTH environment variable")
            failures.append(message)
        except Exception as exc:  # pragma: no cover - defensive
            message = f"Unexpected error validating tool '{label}': {type(exc).__name__}: {exc!s}"
            logger.error(message, exc_info=True)
            failures.append(message)

    if failures:
        logger.warning(f"Tool validation completed with {len(failures)} error(s). Successfully validated {len(accepted)} tool(s).")
        for failure in failures[:3]:
            logger.debug(f"Validation error: {failure}")

        # Nothing survived validation: surface the first error to the caller.
        if not accepted:
            first_error = failures[0][:200]
            if context == "oauth":
                raise OAuthToolValidationError(f"OAuth tool fetch failed: all {len(tools)} tools failed validation. First error: {first_error}")
            raise GatewayConnectionError(f"Failed to fetch tools: All {len(tools)} tools failed validation. First error: {first_error}")

    return accepted, failures
async def _connect_to_sse_server_without_validation(self, server_url: str, authentication: Optional[Dict[str, str]] = None):
    """Connect to an MCP server running with SSE transport, skipping URL validation.

    This is used for OAuth-protected servers where we've already validated the token works.

    Args:
        server_url: The URL of the SSE MCP server to connect to.
        authentication: Optional dictionary containing authentication headers.

    Returns:
        Tuple containing (capabilities, tools, resources, prompts) from the MCP server.

    Raises:
        GatewayConnectionError: If connecting, initializing the session or
            validating tools fails. The original exception is chained as
            ``__cause__`` and the message is sanitized.
    """
    if authentication is None:
        authentication = {}

    # Skip validation for OAuth servers - we already validated via the OAuth flow.
    try:
        async with sse_client(url=server_url, headers=authentication) as streams:
            async with ClientSession(*streams) as session:
                response = await session.initialize()
                capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
                logger.debug(f"Server capabilities: {capabilities}")

                response = await session.list_tools()
                tools = [tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in response.tools]

                # context="oauth" makes _validate_tools raise OAuthToolValidationError when every tool fails.
                # NOTE(review): the broad handler below re-wraps that as GatewayConnectionError;
                # confirm no caller needs the specific OAuthToolValidationError type.
                tools, _ = self._validate_tools(tools, context="oauth")
                if tools:
                    logger.info(f"Fetched {len(tools)} tools from gateway")

                # Fetch resources (and resource templates) if supported.
                logger.debug(f"Checking for resources support: {capabilities.get('resources')}")
                resources = []
                if capabilities.get("resources"):
                    try:
                        response = await session.list_resources()
                        for resource in response.resources:
                            resource_data = resource.model_dump(by_alias=True, exclude_none=True)
                            # Convert AnyUrl to string if present.
                            if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"):
                                resource_data["uri"] = str(resource_data["uri"])
                            # Add default content if not present (will be fetched on demand).
                            if "content" not in resource_data:
                                resource_data["content"] = ""
                            try:
                                resources.append(ResourceCreate.model_validate(resource_data))
                            except Exception:
                                # If validation fails, create a minimal resource.
                                resources.append(
                                    ResourceCreate(
                                        uri=str(resource_data.get("uri", "")),
                                        name=resource_data.get("name", ""),
                                        description=resource_data.get("description"),
                                        mime_type=resource_data.get("mimeType"),
                                        uri_template=resource_data.get("uriTemplate") or None,
                                        content="",
                                    )
                                )
                        logger.info(f"Fetched {len(resources)} resources from gateway")
                    except Exception as e:
                        logger.warning(f"Failed to fetch resources: {e}")

                    # Resource templates are stored as resources whose URI is the template.
                    try:
                        response_templates = await session.list_resource_templates()
                        resource_templates = []
                        for resource_template in response_templates.resourceTemplates:
                            resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True)

                            if "uriTemplate" in resource_template_data:
                                uri_template = str(resource_template_data["uriTemplate"])
                                resource_template_data["uri_template"] = uri_template
                                resource_template_data["uri"] = uri_template

                            if "content" not in resource_template_data:
                                resource_template_data["content"] = ""

                            # Validate once and share the model; resource_templates is only used for the count.
                            template_resource = ResourceCreate.model_validate(resource_template_data)
                            resources.append(template_resource)
                            resource_templates.append(template_resource)
                        logger.info(f"Fetched {len(resource_templates)} resource templates from gateway")
                    except Exception as e:
                        logger.warning(f"Failed to fetch resource templates: {e}")

                # Fetch prompts if supported.
                prompts = []
                logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}")
                if capabilities.get("prompts"):
                    try:
                        response = await session.list_prompts()
                        for prompt in response.prompts:
                            prompt_data = prompt.model_dump(by_alias=True, exclude_none=True)
                            # Add default template if not present.
                            if "template" not in prompt_data:
                                prompt_data["template"] = ""
                            try:
                                prompts.append(PromptCreate.model_validate(prompt_data))
                            except Exception:
                                # If validation fails, create a minimal prompt.
                                prompts.append(
                                    PromptCreate(
                                        name=prompt_data.get("name", ""),
                                        description=prompt_data.get("description"),
                                        template=prompt_data.get("template", ""),
                                    )
                                )
                        logger.info(f"Fetched {len(prompts)} prompts from gateway")
                    except Exception as e:
                        logger.warning(f"Failed to fetch prompts: {e}")

                return capabilities, tools, resources, prompts
    except Exception as e:
        # Note: This function is for OAuth servers only, which don't use query param auth.
        # Still sanitize in case the exception contains a URL with static sensitive params.
        sanitized_url = sanitize_url_for_logging(server_url)
        sanitized_error = sanitize_exception_message(str(e))
        logger.error(f"SSE connection error details: {type(e).__name__}: {sanitized_error}", exc_info=True)
        raise GatewayConnectionError(f"Failed to connect to SSE server at {sanitized_url}: {sanitized_error}") from e
async def connect_to_sse_server(
    self,
    server_url: str,
    authentication: Optional[Dict[str, str]] = None,
    ca_certificate: Optional[bytes] = None,
    include_prompts: bool = True,
    include_resources: bool = True,
    auth_query_params: Optional[Dict[str, str]] = None,
    client_cert: Optional[str] = None,
    client_key: Optional[str] = None,
):
    """Connect to an MCP server running with SSE transport.

    Opens an SSE transport, initializes an MCP ``ClientSession`` and collects the
    server's capabilities and tools. Resources, resource templates and prompts are
    fetched only when the matching ``include_*`` flag is set and the server
    advertises the capability. Failures while listing them are logged as warnings
    and do not abort the connection.

    Args:
        server_url: The URL of the SSE MCP server to connect to.
        authentication: Optional dictionary containing authentication headers.
        ca_certificate: Optional CA certificate for SSL verification.
        include_prompts: Whether to fetch prompts from the server.
        include_resources: Whether to fetch resources (and resource templates) from the server.
        auth_query_params: Query param names for URL sanitization in error logs.
        client_cert: Optional client certificate path or PEM for mTLS.
        client_key: Optional client private key path or PEM for mTLS.

    Returns:
        Tuple containing (capabilities, tools, resources, prompts) from the MCP server.
        Resource templates are appended to ``resources``.

    Raises:
        GatewayConnectionError: If the transport contexts exit without the session
            returning a result. Errors raised while initializing the session or
            listing tools are not caught here and propagate to the caller.
    """
    if authentication is None:
        authentication = {}

    def get_httpx_client_factory(
        headers: dict[str, str] | None = None,
        timeout: httpx.Timeout | None = None,
        auth: httpx.Auth | None = None,
    ) -> httpx.AsyncClient:
        """Factory function to create httpx.AsyncClient with optional CA certificate.

        Args:
            headers: Optional headers for the client
            timeout: Optional timeout for the client
            auth: Optional auth for the client

        Returns:
            httpx.AsyncClient: Configured HTTPX async client
        """
        # Plain-HTTP targets need no TLS context at all.
        # NOTE(review): client_cert/client_key are only applied together with
        # ca_certificate; an mTLS-only setup without a custom CA falls back to
        # get_default_verify() - confirm this is intended.
        if server_url and server_url.lower().startswith("http://"):
            ctx = None
        elif ca_certificate:
            ctx = get_cached_ssl_context(ca_certificate, client_cert=client_cert, client_key=client_key)
        else:
            ctx = None
        return httpx.AsyncClient(
            verify=ctx if ctx else get_default_verify(),
            follow_redirects=True,
            headers=headers,
            timeout=timeout if timeout else get_http_timeout(),
            auth=auth,
            limits=httpx.Limits(
                max_connections=settings.httpx_max_connections,
                max_keepalive_connections=settings.httpx_max_keepalive_connections,
                keepalive_expiry=settings.httpx_keepalive_expiry,
            ),
        )

    # Use async with for both sse_client and ClientSession
    async with sse_client(url=server_url, headers=authentication, httpx_client_factory=get_httpx_client_factory) as streams:
        async with ClientSession(*streams) as session:
            # Initialize the session
            response = await session.initialize()

            capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
            logger.debug(f"Server capabilities: {capabilities}")

            response = await session.list_tools()
            tools = [tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in response.tools]

            tools, _ = self._validate_tools(tools)
            if tools:
                logger.info(f"Fetched {len(tools)} tools from gateway")

            # Fetch resources if supported
            resources = []
            if include_resources:
                logger.debug(f"Checking for resources support: {capabilities.get('resources')}")
                if capabilities.get("resources"):
                    try:
                        response = await session.list_resources()
                        for resource in response.resources:
                            resource_data = resource.model_dump(by_alias=True, exclude_none=True)
                            # Convert AnyUrl to string if present
                            if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"):
                                resource_data["uri"] = str(resource_data["uri"])
                            # Add default content if not present (will be fetched on demand)
                            if "content" not in resource_data:
                                resource_data["content"] = ""
                            try:
                                resources.append(ResourceCreate.model_validate(resource_data))
                            except Exception:
                                # If validation fails, create minimal resource
                                resources.append(
                                    ResourceCreate(
                                        uri=str(resource_data.get("uri", "")),
                                        name=resource_data.get("name", ""),
                                        description=resource_data.get("description"),
                                        mime_type=resource_data.get("mimeType"),
                                        uri_template=resource_data.get("uriTemplate") or None,
                                        content="",
                                    )
                                )
                        logger.info(f"Fetched {len(resources)} resources from gateway")
                    except Exception as e:
                        logger.warning(f"Failed to fetch resources: {e}")

                    # Resource templates are stored as resources whose URI is the template itself.
                    try:
                        response_templates = await session.list_resource_templates()
                        resource_templates = []
                        for resource_template in response_templates.resourceTemplates:
                            template_data = resource_template.model_dump(by_alias=True, exclude_none=True)

                            if "uriTemplate" in template_data:
                                template_data["uri_template"] = str(template_data["uriTemplate"])
                                template_data["uri"] = str(template_data["uriTemplate"])

                            if "content" not in template_data:
                                template_data["content"] = ""

                            # Validate once and reuse the object; a malformed template is
                            # skipped on its own instead of dropping every later template.
                            try:
                                template = ResourceCreate.model_validate(template_data)
                            except Exception as template_error:
                                logger.warning(f"Skipping invalid resource template {template_data.get('name', '')!r}: {template_error}")
                                continue
                            resources.append(template)
                            resource_templates.append(template)
                        logger.info(f"Fetched {len(resource_templates)} resource templates from gateway")
                    except Exception as ei:
                        logger.warning(f"Failed to fetch resource templates: {ei}")

            # Fetch prompts if supported
            prompts = []
            if include_prompts:
                logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}")
                if capabilities.get("prompts"):
                    try:
                        response = await session.list_prompts()
                        for prompt in response.prompts:
                            prompt_data = prompt.model_dump(by_alias=True, exclude_none=True)
                            # Add default template if not present
                            if "template" not in prompt_data:
                                prompt_data["template"] = ""
                            try:
                                prompts.append(PromptCreate.model_validate(prompt_data))
                            except Exception:
                                # If validation fails, create minimal prompt
                                prompts.append(
                                    PromptCreate(
                                        name=prompt_data.get("name", ""),
                                        description=prompt_data.get("description"),
                                        template=prompt_data.get("template", ""),
                                    )
                                )
                        logger.info(f"Fetched {len(prompts)} prompts from gateway")
                    except Exception as e:
                        logger.warning(f"Failed to fetch prompts: {e}")

            return capabilities, tools, resources, prompts
    # Only reached if a transport context manager suppressed an exception.
    sanitized_url = sanitize_url_for_logging(server_url, auth_query_params)
    raise GatewayConnectionError(f"Failed to initialize gateway at {sanitized_url}: Connection could not be established")
async def connect_to_streamablehttp_server(
    self,
    server_url: str,
    authentication: Optional[Dict[str, str]] = None,
    ca_certificate: Optional[bytes] = None,
    include_prompts: bool = True,
    include_resources: bool = True,
    auth_query_params: Optional[Dict[str, str]] = None,
    client_cert: Optional[str] = None,
    client_key: Optional[str] = None,
):
    """Connect to an MCP server running with Streamable HTTP transport.

    Opens a Streamable HTTP transport, initializes an MCP ``ClientSession`` and
    collects the server's capabilities and tools. Every returned tool is tagged
    with ``request_type = "STREAMABLEHTTP"``. Resources, resource templates and
    prompts are fetched only when the matching ``include_*`` flag is set and the
    server advertises the capability. Failures while listing them are logged as
    warnings and do not abort the connection.

    Args:
        server_url: The URL of the Streamable HTTP MCP server to connect to.
        authentication: Optional dictionary containing authentication headers.
        ca_certificate: Optional CA certificate for SSL verification.
        include_prompts: Whether to fetch prompts from the server.
        include_resources: Whether to fetch resources (and resource templates) from the server.
        auth_query_params: Query param names for URL sanitization in error logs.
        client_cert: Optional client certificate path or PEM for mTLS.
        client_key: Optional client private key path or PEM for mTLS.

    Returns:
        Tuple containing (capabilities, tools, resources, prompts) from the MCP server.
        Resource templates are appended to ``resources``.

    Raises:
        GatewayConnectionError: If the transport contexts exit without the session
            returning a result. Errors raised while initializing the session or
            listing tools are not caught here and propagate to the caller.
    """
    if authentication is None:
        authentication = {}

    # Authentication headers are passed to streamablehttp_client directly; the
    # factory below only configures TLS, timeouts and connection limits.
    def get_httpx_client_factory(
        headers: dict[str, str] | None = None,
        timeout: httpx.Timeout | None = None,
        auth: httpx.Auth | None = None,
    ) -> httpx.AsyncClient:
        """Factory function to create httpx.AsyncClient with optional CA certificate.

        Args:
            headers: Optional headers for the client
            timeout: Optional timeout for the client
            auth: Optional auth for the client

        Returns:
            httpx.AsyncClient: Configured HTTPX async client
        """
        # Plain-HTTP targets need no TLS context at all.
        # NOTE(review): client_cert/client_key are only applied together with
        # ca_certificate; an mTLS-only setup without a custom CA falls back to
        # get_default_verify() - confirm this is intended.
        if server_url and server_url.lower().startswith("http://"):
            ctx = None
        elif ca_certificate:
            ctx = get_cached_ssl_context(ca_certificate, client_cert=client_cert, client_key=client_key)
        else:
            ctx = None
        return httpx.AsyncClient(
            verify=ctx if ctx else get_default_verify(),
            follow_redirects=True,
            headers=headers,
            timeout=timeout if timeout else get_http_timeout(),
            auth=auth,
            limits=httpx.Limits(
                max_connections=settings.httpx_max_connections,
                max_keepalive_connections=settings.httpx_max_keepalive_connections,
                keepalive_expiry=settings.httpx_keepalive_expiry,
            ),
        )

    async with streamablehttp_client(url=server_url, headers=authentication, httpx_client_factory=get_httpx_client_factory) as (read_stream, write_stream, _get_session_id):
        async with ClientSession(read_stream, write_stream) as session:
            # Initialize the session
            response = await session.initialize()
            capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
            logger.debug(f"Server capabilities: {capabilities}")

            response = await session.list_tools()
            tools = [tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in response.tools]

            tools, _ = self._validate_tools(tools)
            for tool in tools:
                tool.request_type = "STREAMABLEHTTP"
            if tools:
                logger.info(f"Fetched {len(tools)} tools from gateway")

            # Fetch resources if supported
            resources = []
            if include_resources:
                logger.debug(f"Checking for resources support: {capabilities.get('resources')}")
                if capabilities.get("resources"):
                    try:
                        response = await session.list_resources()
                        for resource in response.resources:
                            resource_data = resource.model_dump(by_alias=True, exclude_none=True)
                            # Convert AnyUrl to string if present
                            if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"):
                                resource_data["uri"] = str(resource_data["uri"])
                            # Add default content if not present
                            if "content" not in resource_data:
                                resource_data["content"] = ""
                            try:
                                resources.append(ResourceCreate.model_validate(resource_data))
                            except Exception:
                                # If validation fails, create minimal resource
                                resources.append(
                                    ResourceCreate(
                                        uri=str(resource_data.get("uri", "")),
                                        name=resource_data.get("name", ""),
                                        description=resource_data.get("description"),
                                        mime_type=resource_data.get("mimeType"),
                                        uri_template=resource_data.get("uriTemplate") or None,
                                        content="",
                                    )
                                )
                        logger.info(f"Fetched {len(resources)} resources from gateway")
                    except Exception as e:
                        logger.warning(f"Failed to fetch resources: {e}")

                    # Resource templates are stored as resources whose URI is the template itself.
                    try:
                        response_templates = await session.list_resource_templates()
                        resource_templates = []
                        for resource_template in response_templates.resourceTemplates:
                            template_data = resource_template.model_dump(by_alias=True, exclude_none=True)

                            if "uriTemplate" in template_data:
                                template_data["uri_template"] = str(template_data["uriTemplate"])
                                template_data["uri"] = str(template_data["uriTemplate"])

                            if "content" not in template_data:
                                template_data["content"] = ""

                            # Validate once and reuse the object; a malformed template is
                            # skipped on its own instead of dropping every later template.
                            try:
                                template = ResourceCreate.model_validate(template_data)
                            except Exception as template_error:
                                logger.warning(f"Skipping invalid resource template {template_data.get('name', '')!r}: {template_error}")
                                continue
                            resources.append(template)
                            resource_templates.append(template)
                        logger.info(f"Fetched {len(resource_templates)} resource templates from gateway")
                    except Exception as e:
                        logger.warning(f"Failed to fetch resource templates: {e}")

            # Fetch prompts if supported
            prompts = []
            if include_prompts:
                logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}")
                if capabilities.get("prompts"):
                    try:
                        response = await session.list_prompts()
                        for prompt in response.prompts:
                            prompt_data = prompt.model_dump(by_alias=True, exclude_none=True)
                            # Add default template if not present
                            if "template" not in prompt_data:
                                prompt_data["template"] = ""
                            try:
                                prompts.append(PromptCreate.model_validate(prompt_data))
                            except Exception:
                                # If validation fails, create minimal prompt so one bad
                                # prompt does not discard the remaining ones (matches SSE path).
                                prompts.append(
                                    PromptCreate(
                                        name=prompt_data.get("name", ""),
                                        description=prompt_data.get("description"),
                                        template=prompt_data.get("template", ""),
                                    )
                                )
                        logger.info(f"Fetched {len(prompts)} prompts from gateway")
                    except Exception as e:
                        logger.warning(f"Failed to fetch prompts: {e}")

            return capabilities, tools, resources, prompts
    # Only reached if a transport context manager suppressed an exception.
    sanitized_url = sanitize_url_for_logging(server_url, auth_query_params)
    raise GatewayConnectionError(f"Failed to initialize gateway at {sanitized_url}: Connection could not be established")
# Backing storage for the module-level ``gateway_service`` singleton. It stays
# None until the attribute is first accessed, so modules that only import the
# exception classes never construct a GatewayService.
_gateway_service_instance: Optional["GatewayService"] = None  # pylint: disable=invalid-name
def __getattr__(name: str):
    """Resolve module attributes that are created lazily (PEP 562 hook).

    Only ``gateway_service`` is handled here: the first lookup constructs a
    ``GatewayService`` and caches it in ``_gateway_service_instance``; later
    lookups return that same cached object.

    Args:
        name: The attribute name being accessed.

    Returns:
        The gateway_service singleton instance if name is "gateway_service".

    Raises:
        AttributeError: If the attribute name is not "gateway_service".
    """
    global _gateway_service_instance  # pylint: disable=global-statement
    # Guard clause: anything other than the lazy singleton is a genuine miss.
    if name != "gateway_service":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    if _gateway_service_instance is None:
        _gateway_service_instance = GatewayService()
    return _gateway_service_instance