Coverage for mcpgateway/services/mcp_session_pool.py: 100%
842 statements
coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
# -*- coding: utf-8 -*-
"""
MCP Session Pool Implementation.

Provides session pooling for MCP ClientSessions to reduce per-request overhead.
Sessions are isolated per user/tenant via identity hashing to prevent session collision.

Performance Impact:
    - Without pooling: 20-23ms per tool call (new session each time)
    - With pooling: 1-2ms per tool call (10-20x improvement)

Security:
    Sessions are isolated by (url, identity_hash, transport_type) to prevent:
    - Cross-user session sharing
    - Cross-tenant data leakage
    - Authentication bypass

Copyright 2026
SPDX-License-Identifier: Apache-2.0
Authors: Mihai Criveti
"""
# flake8: noqa: DAR101, DAR201, DAR401

# Future
from __future__ import annotations

# Standard
import asyncio
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from enum import Enum
import hashlib
import logging
import os
import re
import socket
import time
from typing import Any, Callable, Dict, Optional, Set, Tuple, TYPE_CHECKING
import uuid

# Third-Party
import anyio
import httpx
from mcp import ClientSession, McpError
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.session import RequestResponder
import mcp.types as mcp_types
import orjson

# First-Party
from mcpgateway.config import settings
from mcpgateway.utils.url_auth import sanitize_url_for_logging

# JSON-RPC standard error code for method not found
METHOD_NOT_FOUND = -32601

# Shared session-id validation (downstream MCP session IDs used for affinity).
# Intentionally strict: protects Redis key/channel construction and log lines.
_MCP_SESSION_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,128}$")

# Worker ID for multi-worker session affinity.
# Uses hostname + PID to be unique across Docker containers (each container has PID 1)
# and across gunicorn workers within the same container.
WORKER_ID = f"{socket.gethostname()}:{os.getpid()}"
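# Illustrative only: on a host named "gw-1" running worker PID 42, WORKER_ID
# evaluates to "gw-1:42".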


def _get_cleanup_timeout() -> float:
    """Get session cleanup timeout from config (lazy import to avoid circular deps).

    This timeout controls how long to wait for session/transport __aexit__ calls
    when closing sessions. It prevents CPU spin loops when internal tasks don't
    respond to cancellation (anyio's _deliver_cancellation issue).

    Returns:
        Cleanup timeout in seconds (default: 5.0)
    """
    try:
        # Lazy import to avoid circular dependency during startup
        return settings.mcp_session_pool_cleanup_timeout
    except Exception:
        return 5.0  # Fallback default


if TYPE_CHECKING:
    # Standard
    from collections.abc import AsyncIterator  # pragma: no cover

logger = logging.getLogger(__name__)


class TransportType(Enum):
    """Supported MCP transport types."""

    SSE = "sse"
    STREAMABLE_HTTP = "streamablehttp"


@dataclass(eq=False)  # eq=False makes instances hashable by object identity
class PooledSession:
    """A pooled MCP session with metadata for lifecycle management.

    Note: eq=False is required because we store these in Sets for active session
    tracking. This makes instances hashable by their object id (identity).
    """

    session: ClientSession
    transport_context: Any  # The transport context manager (kept open)
    url: str
    transport_type: TransportType
    headers: Dict[str, str]  # Original headers (for reconnection)
    identity_key: str  # Identity hash component for headers
    user_identity: str = "anonymous"  # For user isolation
    gateway_id: str = ""  # Gateway ID for notification attribution
    created_at: float = field(default_factory=time.time)
    last_used: float = field(default_factory=time.time)
    use_count: int = 0
    _closed: bool = field(default=False, repr=False)

    @property
    def age_seconds(self) -> float:
        """Return session age in seconds.

        Returns:
            float: Session age in seconds since creation.
        """
        return time.time() - self.created_at

    @property
    def idle_seconds(self) -> float:
        """Return seconds since last use.

        Returns:
            float: Seconds since last use of this session.
        """
        return time.time() - self.last_used

    @property
    def is_closed(self) -> bool:
        """Return whether this session has been closed.

        Returns:
            bool: True if session is closed, False otherwise.
        """
        return self._closed

    def mark_closed(self) -> None:
        """Mark this session as closed."""
        self._closed = True


# Type aliases
# Pool key includes transport type and gateway_id to prevent returning the wrong transport for the same URL
# and to ensure correct notification attribution when notifications are enabled
PoolKey = Tuple[str, str, str, str, str]  # (user_identity_hash, url, identity_hash, transport_type, gateway_id)

# Session affinity mapping key: (mcp_session_id, url, transport_type, gateway_id)
SessionMappingKey = Tuple[str, str, str, str]
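# Illustrative PoolKey instance (hash values abbreviated for readability):
#     ("9f2c...", "https://upstream.example/mcp", "4ab1...", "streamablehttp", "gw-1")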
HttpxClientFactory = Callable[
    [Optional[Dict[str, str]], Optional[httpx.Timeout], Optional[httpx.Auth]],
    httpx.AsyncClient,
]


# Type alias for identity extractor callback
# Extracts stable identity from headers (e.g., decode JWT to get user_id)
IdentityExtractor = Callable[[Dict[str, str]], Optional[str]]

# Type alias for message handler factory
# Factory that creates message handlers given URL and optional gateway_id
# The handler receives ServerNotification, ServerRequest responders, or Exceptions
MessageHandlerFactory = Callable[
    [str, Optional[str]],  # (url, gateway_id)
    Callable[
        [RequestResponder[mcp_types.ServerRequest, mcp_types.ClientResult] | mcp_types.ServerNotification | Exception],
        Any,  # Coroutine
    ],
]
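# A minimal sketch of a factory matching MessageHandlerFactory, kept as a comment
# because it is illustrative and not part of this module; the names below are
# assumptions. The handler only logs server notifications.
#
#     def make_logging_handler(url: str, gateway_id: Optional[str]):
#         async def handle(message) -> None:
#             if isinstance(message, mcp_types.ServerNotification):
#                 logger.debug(f"Notification from {url} (gateway={gateway_id}): {message}")
#         return handle
#
#     pool = MCPSessionPool(message_handler_factory=make_logging_handler)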


class MCPSessionPool:  # pylint: disable=too-many-instance-attributes
    """
    Pool of MCP ClientSessions keyed by (user_identity, server URL, identity hash, transport type, gateway_id).

    Thread-Safety:
        This pool is designed for asyncio concurrency. It uses asyncio.Lock
        for synchronization, which is safe for coroutine-based concurrency
        but NOT for multi-threaded access.

    Session Isolation:
        Sessions are isolated per user/tenant to prevent session collision.
        The identity hash is derived from authentication headers, ensuring
        that different users never share MCP sessions.

    Transport Isolation:
        Sessions are also isolated by transport type (SSE vs STREAMABLE_HTTP).
        The same URL with different transports will use separate pools.

    Gateway Isolation:
        Sessions are isolated by gateway_id for correct notification attribution.
        When notifications are enabled, each gateway gets its own pooled sessions
        even if they share the same URL and authentication.

    Features:
        - Session reuse across requests (10-20x latency improvement)
        - Per-user/tenant session isolation (prevents session collision)
        - Per-transport session isolation (prevents transport mismatch)
        - TTL-based expiration with configurable lifetime
        - Health checks on acquire for stale sessions
        - Configurable pool size per URL+identity+transport
        - Circuit breaker for failing endpoints
        - Idle pool key eviction to prevent unbounded growth
        - Custom identity extractor for rotating tokens (e.g., JWT decode)
        - Metrics for monitoring (hits, misses, evictions)
        - Graceful shutdown with close_all()

    Usage:
        pool = MCPSessionPool()

        # Use as context manager for lifecycle management
        async with pool:
            pooled = await pool.acquire(url, headers)
            try:
                result = await pooled.session.call_tool("my_tool", {})
            finally:
                await pool.release(pooled)

        # With custom identity extractor for JWT tokens:
        def extract_user_id(headers: dict) -> str:
            token = headers.get("Authorization", "").replace("Bearer ", "")
            claims = jwt.decode(token, options={"verify_signature": False})
            return claims.get("sub") or claims.get("user_id")

        pool = MCPSessionPool(identity_extractor=extract_user_id)
    """

    # Headers that contribute to session identity (case-insensitive)
    DEFAULT_IDENTITY_HEADERS: frozenset[str] = frozenset(
        [
            "authorization",
            "x-tenant-id",
            "x-user-id",
            "x-api-key",
            "cookie",
            "x-mcp-session-id",
        ]
    )
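
    # Illustrative override (an assumption for demonstration, not a recommended
    # default): restrict identity hashing to the Authorization header only.
    #     pool = MCPSessionPool(identity_headers=frozenset({"authorization"}))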

    def __init__(
        self,
        max_sessions_per_key: int = 10,
        session_ttl_seconds: float = 300.0,
        health_check_interval_seconds: float = 60.0,
        acquire_timeout_seconds: float = 30.0,
        session_create_timeout_seconds: float = 30.0,
        circuit_breaker_threshold: int = 5,
        circuit_breaker_reset_seconds: float = 60.0,
        identity_headers: Optional[frozenset[str]] = None,
        identity_extractor: Optional[IdentityExtractor] = None,
        idle_pool_eviction_seconds: float = 600.0,
        default_transport_timeout_seconds: float = 30.0,
        health_check_methods: Optional[list[str]] = None,
        health_check_timeout_seconds: float = 5.0,
        message_handler_factory: Optional[MessageHandlerFactory] = None,
    ):
        """
        Initialize the session pool.

        Args:
            max_sessions_per_key: Maximum pooled sessions per (URL, identity, transport).
            session_ttl_seconds: Session TTL in seconds before forced expiration.
            health_check_interval_seconds: Seconds of idle time before health check.
            acquire_timeout_seconds: Timeout for waiting when pool is exhausted.
            session_create_timeout_seconds: Timeout for creating new sessions.
            circuit_breaker_threshold: Consecutive failures before circuit opens.
            circuit_breaker_reset_seconds: Seconds before circuit breaker resets.
            identity_headers: Headers that contribute to identity hash.
            identity_extractor: Optional callback to extract stable identity from headers.
                Use this when tokens rotate frequently (e.g., short-lived JWTs).
                Should return a stable user/tenant ID string.
            idle_pool_eviction_seconds: Evict empty pool keys after this many seconds of no use.
            default_transport_timeout_seconds: Default timeout for transport connections.
            health_check_methods: Ordered list of health check methods to try.
                Options: ping, list_tools, list_prompts, list_resources, skip.
                Default: ["ping", "skip"] (try ping, skip if unsupported).
            health_check_timeout_seconds: Timeout for each health check attempt.
            message_handler_factory: Optional factory for creating message handlers.
                Called with (url, gateway_id) to create handlers for
                each new session. Enables notification handling.
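
        Example:
            A conservatively tuned pool; these values are illustrative, not
            recommendations:

                pool = MCPSessionPool(
                    max_sessions_per_key=2,
                    session_ttl_seconds=120.0,
                    health_check_methods=["list_tools", "skip"],
                )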
291 """
292 # Configuration
293 self._max_sessions = max_sessions_per_key
294 self._session_ttl = session_ttl_seconds
295 self._health_check_interval = health_check_interval_seconds
296 self._acquire_timeout = acquire_timeout_seconds
297 self._session_create_timeout = session_create_timeout_seconds
298 self._circuit_breaker_threshold = circuit_breaker_threshold
299 self._circuit_breaker_reset = circuit_breaker_reset_seconds
300 self._identity_headers = identity_headers or self.DEFAULT_IDENTITY_HEADERS
301 self._identity_extractor = identity_extractor
302 self._idle_pool_eviction = idle_pool_eviction_seconds
303 self._default_transport_timeout = default_transport_timeout_seconds
304 self._health_check_methods = health_check_methods or ["ping", "skip"]
305 self._health_check_timeout = health_check_timeout_seconds
306 self._message_handler_factory = message_handler_factory
308 # State - protected by _global_lock for creation, per-key locks for access
309 self._global_lock = asyncio.Lock()
310 self._pools: Dict[PoolKey, asyncio.Queue[PooledSession]] = {}
311 self._active: Dict[PoolKey, Set[PooledSession]] = {}
312 self._locks: Dict[PoolKey, asyncio.Lock] = {}
313 self._semaphores: Dict[PoolKey, asyncio.Semaphore] = {}
314 self._pool_last_used: Dict[PoolKey, float] = {} # Track last use time per pool key
316 # Circuit breaker state
317 self._failures: Dict[str, int] = {} # url -> consecutive failure count
318 self._circuit_open_until: Dict[str, float] = {} # url -> timestamp
320 # Eviction throttling - only run eviction once per interval
321 self._last_eviction_run: float = 0.0
322 self._eviction_run_interval: float = 60.0 # Run eviction at most every 60 seconds
324 # Metrics
325 self._hits = 0
326 self._misses = 0
327 self._evictions = 0
328 self._health_check_failures = 0
329 self._circuit_breaker_trips = 0
330 self._pool_keys_evicted = 0
331 self._sessions_reaped = 0 # Sessions closed during background eviction
332 self._anonymous_identity_count = 0 # Count of requests with no identity headers
334 # Lifecycle
335 self._closed = False
337 # Pre-registered session mappings for session affinity
338 # Mapping from (mcp_session_id, url, transport_type, gateway_id) -> pool_key
339 # Set by broadcast() before acquire() is called to enable session affinity lookup
340 self._mcp_session_mapping: Dict[SessionMappingKey, PoolKey] = {}
341 self._mcp_session_mapping_lock = asyncio.Lock()
343 # Multi-worker session affinity via Redis pub/sub
344 # Track pending responses for forwarded RPC requests
345 self._rpc_listener_task: Optional[asyncio.Task[None]] = None
346 self._pending_responses: Dict[str, asyncio.Future[Dict[str, Any]]] = {}
348 # Session affinity metrics
349 self._session_affinity_local_hits = 0
350 self._session_affinity_redis_hits = 0
351 self._session_affinity_misses = 0
352 self._forwarded_requests = 0
353 self._forwarded_request_failures = 0
354 self._forwarded_request_timeouts = 0
356 async def __aenter__(self) -> "MCPSessionPool":
357 """Async context manager entry.
359 Returns:
360 MCPSessionPool: This pool instance.
361 """
362 return self
364 async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
365 """Async context manager exit - closes all sessions.
367 Args:
368 exc_type: Exception type if an exception was raised.
369 exc_val: Exception value if an exception was raised.
370 exc_tb: Exception traceback if an exception was raised.
371 """
372 await self.close_all()
374 def _compute_identity_hash(self, headers: Optional[Dict[str, str]]) -> str:
375 """
376 Compute a hash of identity-relevant headers.
378 This ensures sessions are isolated per user/tenant. Different users
379 with different Authorization headers will have different identity hashes
380 and thus separate session pools.
382 Identity resolution order:
383 1. Custom identity_extractor (if configured) - for rotating tokens like JWTs
384 2. x-mcp-session-id header (if present) - for session affinity, ensures
385 requests with the same downstream session ID get the same upstream
386 session even when JWT tokens rotate (different jti values)
387 3. Configured identity headers - fallback to hashing all identity headers
389 Args:
390 headers: Request headers dict.
392 Returns:
393 Identity hash string, or "anonymous" if no identity headers present.
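
        Example:
            Illustrative: two callers with different Authorization headers hash
            to different identities, so they never share pooled sessions:

                h1 = pool._compute_identity_hash({"Authorization": "Bearer alice"})
                h2 = pool._compute_identity_hash({"Authorization": "Bearer bob"})
                assert h1 != h2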
394 """
395 if not headers:
396 self._anonymous_identity_count += 1
397 logger.debug("Session pool identity collapsed to 'anonymous' (no headers provided). " + "Sessions will be shared. Ensure this is intentional for stateless MCP servers.")
398 return "anonymous"
400 # Try custom identity extractor first (for rotating tokens like JWTs)
401 if self._identity_extractor:
402 try:
403 extracted = self._identity_extractor(headers)
404 if extracted:
405 return hashlib.sha256(extracted.encode()).hexdigest()
406 except Exception as e:
407 logger.debug(f"Identity extractor failed, falling back to header hash: {e}")
409 # Normalize headers for case-insensitive lookup
410 headers_lower = {k.lower(): v for k, v in headers.items()}
412 # Session affinity: prioritize x-mcp-session-id for stable identity
413 # When present, use ONLY the session ID for identity hash. This ensures
414 # requests with the same downstream session ID get the same upstream session,
415 # even when JWT tokens rotate (different jti values per request).
416 if settings.mcpgateway_session_affinity_enabled:
417 session_id = headers_lower.get("x-mcp-session-id")
418 if session_id:
419 logger.debug(f"Using x-mcp-session-id for session affinity: {session_id[:8]}...")
420 return hashlib.sha256(session_id.encode()).hexdigest()
422 # Fallback: extract identity from configured headers
423 identity_parts = []
425 for header in sorted(self._identity_headers):
426 if header in headers_lower:
427 identity_parts.append(f"{header}:{headers_lower[header]}")
429 if not identity_parts:
430 self._anonymous_identity_count += 1
431 logger.debug(
432 "Session pool identity collapsed to 'anonymous' (no identity headers found). " + "Expected headers: %s. Sessions will be shared.",
433 list(self._identity_headers),
434 )
435 return "anonymous"
437 # Create a stable, deterministic hash using JSON serialization
438 # Prevents delimiter-collision or injection issues present in string joining
439 serialized_identity = orjson.dumps(identity_parts)
440 return hashlib.sha256(serialized_identity).hexdigest()
442 def _make_pool_key(
443 self,
444 url: str,
445 headers: Optional[Dict[str, str]],
446 transport_type: TransportType,
447 user_identity: str,
448 gateway_id: Optional[str] = None,
449 ) -> PoolKey:
450 """Create composite pool key from URL, identity, transport type, user identity, and gateway_id.
452 Including gateway_id ensures correct notification attribution when multiple gateways
453 share the same URL/auth. Sessions are isolated per gateway for proper event routing.
454 """
455 identity_hash = self._compute_identity_hash(headers)
457 # Anonymize user identity by hashing it (unless it's commonly "anonymous")
458 # Use full hash for collision resistance - truncate only for display in logs/metrics
459 if user_identity == "anonymous":
460 user_hash = "anonymous"
461 else:
462 user_hash = hashlib.sha256(user_identity.encode()).hexdigest()
464 # Use empty string for None gateway_id to maintain consistent key type
465 gw_id = gateway_id or ""
467 return (user_hash, url, identity_hash, transport_type.value, gw_id)
469 async def _get_or_create_lock(self, pool_key: PoolKey) -> asyncio.Lock:
470 """Get or create a lock for the given pool key (thread-safe)."""
471 async with self._global_lock:
472 if pool_key not in self._locks:
473 self._locks[pool_key] = asyncio.Lock()
474 return self._locks[pool_key]
476 async def _get_or_create_pool(self, pool_key: PoolKey) -> asyncio.Queue[PooledSession]:
477 """Get or create a pool queue for the given key (thread-safe)."""
478 async with self._global_lock:
479 if pool_key not in self._pools:
480 self._pools[pool_key] = asyncio.Queue(maxsize=self._max_sessions)
481 self._active[pool_key] = set()
482 self._semaphores[pool_key] = asyncio.Semaphore(self._max_sessions)
483 return self._pools[pool_key]
485 def _is_circuit_open(self, url: str) -> bool:
486 """Check if circuit breaker is open for a URL."""
487 if url not in self._circuit_open_until:
488 return False
489 if time.time() >= self._circuit_open_until[url]:
490 # Circuit breaker reset
491 del self._circuit_open_until[url]
492 self._failures[url] = 0
493 logger.info(f"Circuit breaker reset for {sanitize_url_for_logging(url)}")
494 return False
495 return True
497 def _record_failure(self, url: str) -> None:
498 """Record a failure and potentially trip circuit breaker."""
499 self._failures[url] = self._failures.get(url, 0) + 1
500 if self._failures[url] >= self._circuit_breaker_threshold:
501 self._circuit_open_until[url] = time.time() + self._circuit_breaker_reset
502 self._circuit_breaker_trips += 1
503 logger.warning(f"Circuit breaker opened for {sanitize_url_for_logging(url)} after {self._failures[url]} failures. " f"Will reset in {self._circuit_breaker_reset}s")
505 def _record_success(self, url: str) -> None:
506 """Record a success, resetting failure count."""
507 self._failures[url] = 0
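
    # Worked example of the breaker lifecycle (illustrative): with the default
    # circuit_breaker_threshold=5 and circuit_breaker_reset_seconds=60.0, the
    # fifth consecutive _record_failure(url) opens the circuit; acquire() then
    # raises RuntimeError for that URL until 60s elapse, after which
    # _is_circuit_open() clears the entry and resets the failure count to zero.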

    @staticmethod
    def is_valid_mcp_session_id(session_id: str) -> bool:
        """Validate downstream MCP session ID format for affinity.

        Used for:
            - Redis key construction (ownership + mapping)
            - Pub/Sub channel naming
            - Avoiding log spam / injection
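
        Examples (illustrative, per _MCP_SESSION_ID_PATTERN):
            is_valid_mcp_session_id("session-abc_123")  # True
            is_valid_mcp_session_id("bad:id")  # False (colon not allowed)
            is_valid_mcp_session_id("")  # False (empty)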
517 """
518 if not session_id:
519 return False
520 return bool(_MCP_SESSION_ID_PATTERN.match(session_id))
522 def _sanitize_redis_key_component(self, value: str) -> str:
523 """Sanitize a value for use in Redis key construction.
525 Replaces any characters that could cause key collision or injection.
527 Args:
528 value: The value to sanitize.
530 Returns:
531 Sanitized value safe for Redis key construction.
532 """
533 if not value:
534 return ""
536 # Replace problematic characters with underscores
537 return re.sub(r"[^a-zA-Z0-9_-]", "_", value)
539 def _session_mapping_redis_key(self, mcp_session_id: str, url: str, transport_type: str, gateway_id: str) -> str:
540 """Compute a bounded Redis key for session mapping.
542 The URL is hashed to keep keys small and avoid special character issues.
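
        Example (illustrative; the session ID, URL hash, and gateway ID are made up):
            mcpgw:session_mapping:abc123:0f9a3c5d7e1b2a4c:streamablehttp:gw-1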
543 """
544 sanitized_session_id = self._sanitize_redis_key_component(mcp_session_id)
545 url_hash = hashlib.sha256(url.encode()).hexdigest()[:16]
546 return f"mcpgw:session_mapping:{sanitized_session_id}:{url_hash}:{transport_type}:{gateway_id}"
548 @staticmethod
549 def _pool_owner_key(mcp_session_id: str) -> str:
550 """Return Redis key for session ownership tracking."""
551 return f"mcpgw:pool_owner:{mcp_session_id}"
553 async def register_session_mapping(
554 self,
555 mcp_session_id: str,
556 url: str,
557 gateway_id: str,
558 transport_type: str,
559 user_email: Optional[str] = None,
560 ) -> None:
561 """Pre-register session mapping for session affinity.
563 Called from respond() to set up mapping BEFORE acquire() is called.
564 This ensures acquire() can find the correct pool key for session affinity.
566 The mapping stores the relationship between an incoming MCP session ID
567 and the pool key that should be used for upstream connections. This
568 enables session affinity even when JWT tokens rotate (different jti values
569 per request).
571 For multi-worker deployments, the mapping is also stored in Redis with TTL
572 so that any worker can look it up during acquire().
574 Args:
575 mcp_session_id: The downstream MCP session ID from x-mcp-session-id header.
576 url: The upstream MCP server URL.
577 gateway_id: The gateway ID.
578 transport_type: The transport type (sse, streamablehttp).
579 user_email: The email of the authenticated user (or "system" for unauthenticated).
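
        Example (illustrative values):
            await pool.register_session_mapping(
                mcp_session_id="abc123",
                url="https://upstream.example/mcp",
                gateway_id="gw-1",
                transport_type="streamablehttp",
                user_email="alice@example.com",
            )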
580 """
581 if not settings.mcpgateway_session_affinity_enabled:
582 return
584 # Validate mcp_session_id to prevent Redis key injection
585 if not self.is_valid_mcp_session_id(mcp_session_id):
586 logger.warning(f"Invalid mcp_session_id format, skipping session mapping: {mcp_session_id[:20]}...")
587 return
589 # Use user email for user_identity, or "anonymous" if not provided
590 user_identity = user_email or "anonymous"
592 # Normalize gateway_id to empty string if None for consistent key matching
593 normalized_gateway_id = gateway_id or ""
595 mapping_key: SessionMappingKey = (mcp_session_id, url, transport_type, normalized_gateway_id)
597 # Compute what the pool_key will be for this session
598 # Use mcp_session_id as the identity basis for affinity
599 identity_hash = hashlib.sha256(mcp_session_id.encode()).hexdigest()
601 # Hash user identity for privacy (unless it's "anonymous")
602 if user_identity == "anonymous":
603 user_hash = "anonymous"
604 else:
605 user_hash = hashlib.sha256(user_identity.encode()).hexdigest()
607 pool_key: PoolKey = (user_hash, url, identity_hash, transport_type, normalized_gateway_id)
609 # Store in local memory
610 async with self._mcp_session_mapping_lock:
611 self._mcp_session_mapping[mapping_key] = pool_key
612 logger.debug(f"Session affinity pre-registered (local): {mcp_session_id[:8]}... → {url}, user={user_identity}")
614 # Store in Redis for multi-worker support AND register ownership atomically
615 # Registering ownership HERE (during mapping) instead of in acquire() prevents
616 # a race condition where two workers could both start creating sessions before
617 # either registers ownership
618 try:
619 # First-Party
620 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
622 redis = await get_redis_client()
623 if redis:
624 redis_key = self._session_mapping_redis_key(mcp_session_id, url, transport_type, normalized_gateway_id)
626 # Store pool_key as JSON for easy deserialization
627 pool_key_data = {
628 "user_hash": user_hash,
629 "url": url,
630 "identity_hash": identity_hash,
631 "transport_type": transport_type,
632 "gateway_id": normalized_gateway_id,
633 }
634 await redis.setex(redis_key, settings.mcpgateway_session_affinity_ttl, orjson.dumps(pool_key_data)) # TTL from config
636 # CRITICAL: Register ownership atomically with mapping.
637 # This claims ownership BEFORE any session creation attempt, preventing
638 # the race condition where two workers both start creating sessions
639 owner_key = self._pool_owner_key(mcp_session_id)
640 # Atomic claim with TTL (avoids the SETNX/EXPIRE crash window).
641 was_set = await redis.set(owner_key, WORKER_ID, nx=True, ex=settings.mcpgateway_session_affinity_ttl)
642 if was_set:
643 logger.debug(f"Session ownership claimed (SET NX): {mcp_session_id[:8]}... → worker {WORKER_ID}")
644 else:
645 # Another worker already claimed ownership
646 existing_owner = await redis.get(owner_key)
647 owner_id = existing_owner.decode() if isinstance(existing_owner, bytes) else existing_owner
648 logger.debug(f"Session ownership already claimed by {owner_id}: {mcp_session_id[:8]}...")
650 logger.debug(f"Session affinity pre-registered (Redis): {mcp_session_id[:8]}... TTL={settings.mcpgateway_session_affinity_ttl}s")
651 except Exception as e:
652 # Redis failure is non-fatal - local mapping still works for same-worker requests
653 logger.debug(f"Failed to store session mapping in Redis: {e}")
655 async def acquire(
656 self,
657 url: str,
658 headers: Optional[Dict[str, str]] = None,
659 transport_type: TransportType = TransportType.STREAMABLE_HTTP,
660 httpx_client_factory: Optional[HttpxClientFactory] = None,
661 timeout: Optional[float] = None,
662 user_identity: Optional[str] = None,
663 gateway_id: Optional[str] = None,
664 ) -> PooledSession:
665 """
666 Acquire a session for the given URL, identity, and transport type.
668 Sessions are isolated by identity (derived from auth headers) AND
669 transport type. Returns an initialized, healthy session ready for tool calls.
671 Args:
672 url: The MCP server URL.
673 headers: Request headers (used for identity hashing and passed to server).
674 transport_type: The transport type (SSE or STREAMABLE_HTTP).
675 httpx_client_factory: Optional factory for creating httpx clients
676 (for custom SSL/timeout configuration).
677 timeout: Optional timeout in seconds for transport connection.
678 gateway_id: Optional gateway ID for notification handler context.
680 Returns:
681 PooledSession ready for use.
683 Raises:
684 asyncio.TimeoutError: If acquire times out waiting for available session.
685 RuntimeError: If pool is closed or circuit breaker is open.
686 Exception: If session creation fails.
687 """
688 if self._closed:
689 raise RuntimeError("Session pool is closed")
691 if self._is_circuit_open(url):
692 raise RuntimeError(f"Circuit breaker open for {url}")
694 # Use default timeout if not provided
695 effective_timeout = timeout if timeout is not None else self._default_transport_timeout
697 user_id = user_identity or "anonymous"
698 pool_key: Optional[PoolKey] = None
700 # Check pre-registered mapping first (set by respond() for session affinity)
701 if settings.mcpgateway_session_affinity_enabled and headers:
702 headers_lower = {k.lower(): v for k, v in headers.items()}
703 mcp_session_id = headers_lower.get("x-mcp-session-id")
704 if mcp_session_id and self.is_valid_mcp_session_id(mcp_session_id):
705 normalized_gateway_id = gateway_id or ""
706 mapping_key: SessionMappingKey = (mcp_session_id, url, transport_type.value, normalized_gateway_id)
708 # Check local memory first (fast path - same worker)
709 async with self._mcp_session_mapping_lock:
710 pool_key = self._mcp_session_mapping.get(mapping_key)
711 if pool_key:
712 self._session_affinity_local_hits += 1
713 logger.debug(f"Session affinity hit (local): {mcp_session_id[:8]}...")
715 # If not in local memory, check Redis (multi-worker support)
716 if pool_key is None:
717 try:
718 # First-Party
719 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
721 redis = await get_redis_client()
722 if redis:
723 redis_key = self._session_mapping_redis_key(mcp_session_id, url, transport_type.value, normalized_gateway_id)
724 pool_key_data = await redis.get(redis_key)
725 if pool_key_data:
726 # Deserialize pool_key from JSON
727 data = orjson.loads(pool_key_data)
728 pool_key = (
729 data["user_hash"],
730 data["url"],
731 data["identity_hash"],
732 data["transport_type"],
733 data["gateway_id"],
734 )
735 # Cache in local memory for future requests
736 async with self._mcp_session_mapping_lock:
737 self._mcp_session_mapping[mapping_key] = pool_key
738 self._session_affinity_redis_hits += 1
739 logger.debug(f"Session affinity hit (Redis): {mcp_session_id[:8]}...")
740 except Exception as e:
741 logger.debug(f"Failed to check Redis for session mapping: {e}")
743 # Fallback to normal pool key computation
744 if pool_key is None:
745 self._session_affinity_misses += 1
746 pool_key = self._make_pool_key(url, headers, transport_type, user_id, gateway_id)
748 pool = await self._get_or_create_pool(pool_key)
750 # Update pool key last used time IMMEDIATELY after getting pool
751 # This prevents race with eviction removing keys between awaits
752 self._pool_last_used[pool_key] = time.time()
754 lock = await self._get_or_create_lock(pool_key)
756 # Guard semaphore access - eviction may have removed it between awaits
757 # If so, re-create the pool structures
758 if pool_key not in self._semaphores:
759 pool = await self._get_or_create_pool(pool_key)
760 self._pool_last_used[pool_key] = time.time()
762 semaphore = self._semaphores[pool_key]
764 # Throttled eviction - only run if enough time has passed (inline, not spawned)
765 await self._maybe_evict_idle_pool_keys()
767 # Try to get from pool first (quick path, no lock needed for queue get)
768 while True:
769 try:
770 pooled = pool.get_nowait()
771 except asyncio.QueueEmpty:
772 break
774 # Validate the session outside the lock
775 if await self._validate_session(pooled):
776 pooled.last_used = time.time()
777 pooled.use_count += 1
778 self._hits += 1
779 async with lock:
780 self._active[pool_key].add(pooled)
781 logger.debug(f"Pool hit for {sanitize_url_for_logging(url)} (identity={pool_key[2][:8]}, transport={transport_type.value})")
782 return pooled
784 # Session invalid, close it
785 await self._close_session(pooled)
786 self._evictions += 1
787 semaphore.release() # Free up a slot
789 # No valid session in pool - try to create one or wait
790 try:
791 # Use semaphore with timeout to limit concurrent sessions
792 acquired = await asyncio.wait_for(semaphore.acquire(), timeout=self._acquire_timeout)
793 if not acquired:
794 raise asyncio.TimeoutError("Failed to acquire session slot")
795 except asyncio.TimeoutError:
796 raise asyncio.TimeoutError(f"Timeout waiting for available session for {sanitize_url_for_logging(url)}") from None
798 # Create new session (semaphore acquired)
799 try:
800 # Verify we own this session before creating (prevents race condition)
801 # If another worker already claimed ownership, we should not create a new session
802 # Note: Ownership is registered atomically in register_session_mapping() using SETNX
803 if settings.mcpgateway_session_affinity_enabled and headers:
804 headers_lower = {k.lower(): v for k, v in headers.items()}
805 mcp_session_id = headers_lower.get("x-mcp-session-id")
806 if mcp_session_id and self.is_valid_mcp_session_id(mcp_session_id):
807 owner = await self._get_pool_session_owner(mcp_session_id)
808 if owner and owner != WORKER_ID:
809 # Another worker claimed ownership - should have been forwarded
810 # Release semaphore and raise to trigger forwarding
811 semaphore.release()
812 logger.warning(f"Session {mcp_session_id[:8]}... owned by worker {owner}, not us ({WORKER_ID})")
813 raise RuntimeError(f"Session owned by another worker: {owner}")
815 pooled = await asyncio.wait_for(
816 self._create_session(url, headers, transport_type, httpx_client_factory, effective_timeout, gateway_id),
817 timeout=self._session_create_timeout,
818 )
819 # Store identity components for key reconstruction
820 pooled.identity_key = pool_key[2]
821 pooled.user_identity = user_id
823 # Note: Ownership is now registered atomically in register_session_mapping()
824 # before acquire() is called, so we don't need to register it here
826 self._misses += 1
827 self._record_success(url)
828 async with lock:
829 self._active[pool_key].add(pooled)
830 logger.debug(f"Pool miss for {sanitize_url_for_logging(url)} - created new session (transport={transport_type.value})")
831 return pooled
832 except BaseException as e:
833 # Release semaphore on ANY failure (including CancelledError)
834 semaphore.release()
835 if not isinstance(e, asyncio.CancelledError):
836 self._record_failure(url)
837 logger.warning(f"Failed to create session for {sanitize_url_for_logging(url)}: {e}")
838 raise
840 async def release(self, pooled: PooledSession) -> None:
841 """
842 Return a session to the pool for reuse.
844 Args:
845 pooled: The session to release.
846 """
847 if pooled.is_closed:
848 logger.warning("Attempted to release already-closed session")
849 return
851 # Pool key includes transport type, user identity, and gateway_id
852 # Re-compute user hash from stored raw identity (full hash for collision resistance)
853 user_hash = "anonymous"
854 if pooled.user_identity != "anonymous":
855 user_hash = hashlib.sha256(pooled.user_identity.encode()).hexdigest()
857 pool_key = (user_hash, pooled.url, pooled.identity_key, pooled.transport_type.value, pooled.gateway_id)
858 lock = await self._get_or_create_lock(pool_key)
859 pool = await self._get_or_create_pool(pool_key)
861 async with lock:
862 # Update last-used FIRST to prevent eviction race:
863 # If eviction runs between removing from _active and putting back in pool,
864 # it would see key as idle + inactive and evict it. By updating last-used
865 # while still holding the lock and before removing from _active, we ensure
866 # eviction sees recent activity.
867 self._pool_last_used[pool_key] = time.time()
868 self._active.get(pool_key, set()).discard(pooled)
870 # Check if session should be returned to pool
871 if self._closed or pooled.age_seconds > self._session_ttl:
872 await self._close_session(pooled)
873 if pool_key in self._semaphores:
874 self._semaphores[pool_key].release()
875 if pooled.age_seconds > self._session_ttl:
876 self._evictions += 1
877 return
879 # Return to pool (pool may have been evicted in edge case, recreate if needed)
880 if pool_key not in self._pools:
881 pool = await self._get_or_create_pool(pool_key)
882 self._pool_last_used[pool_key] = time.time()
884 try:
885 pool.put_nowait(pooled)
886 logger.debug(f"Session returned to pool for {sanitize_url_for_logging(pooled.url)}")
887 except asyncio.QueueFull:
888 # Pool full (shouldn't happen with semaphore), close session
889 await self._close_session(pooled)
890 if pool_key in self._semaphores:
891 self._semaphores[pool_key].release()
893 async def _maybe_evict_idle_pool_keys(self) -> None:
894 """
895 Reap stale sessions and evict idle pool keys.
897 This method is throttled - it only runs eviction if enough time has
898 passed since the last run (default: 60 seconds). This prevents:
899 - Unbounded task spawning on every acquire
900 - Lock contention under high load
902 Two-phase cleanup:
903 1. Close expired/stale sessions parked in idle pools (frees connections)
904 2. Evict pool keys that are now empty and have no active sessions
906 This prevents unbounded connection and pool key growth when using
907 rotating tokens (e.g., short-lived JWTs with unique identifiers).
908 """
909 if self._closed:
910 return
912 now = time.time()
914 # Throttle: only run eviction once per interval
915 if now - self._last_eviction_run < self._eviction_run_interval:
916 return
918 self._last_eviction_run = now
920 # Collect sessions to close and keys to evict (minimize time holding lock)
921 sessions_to_close: list[PooledSession] = []
922 keys_to_evict: list[PoolKey] = []
924 async with self._global_lock:
925 for pool_key, last_used in list(self._pool_last_used.items()):
926 # Skip recently-used pools
927 if now - last_used < self._idle_pool_eviction:
928 continue
930 pool = self._pools.get(pool_key)
931 active = self._active.get(pool_key, set())
933 # Skip if there are active sessions (in use)
934 if active:
935 continue
937 if pool:
938 # Phase 1: Drain and collect expired/stale sessions from idle pools
939 while not pool.empty():
940 try:
941 session = pool.get_nowait()
942 # Close if expired OR idle too long (defense in depth)
943 if session.age_seconds > self._session_ttl or session.idle_seconds > self._idle_pool_eviction:
944 sessions_to_close.append(session)
945 # Release semaphore slot for this session
946 if pool_key in self._semaphores:
947 self._semaphores[pool_key].release()
948 else:
949 # Session still valid, put it back
950 pool.put_nowait(session)
951 break # Stop draining if we find a valid session
952 except asyncio.QueueEmpty:
953 break
955 # Phase 2: Evict pool key if now empty
956 if pool.empty():
957 keys_to_evict.append(pool_key)
959 # Remove evicted keys from all tracking dicts
960 for pool_key in keys_to_evict:
961 self._pools.pop(pool_key, None)
962 self._active.pop(pool_key, None)
963 self._locks.pop(pool_key, None)
964 self._semaphores.pop(pool_key, None)
965 self._pool_last_used.pop(pool_key, None)
966 self._pool_keys_evicted += 1
967 logger.debug(f"Evicted idle pool key: {pool_key[0][:8]}|{pool_key[1]}|{pool_key[2][:8]}")
969 # Close sessions outside the lock (I/O operations)
970 for session in sessions_to_close:
971 await self._close_session(session)
972 self._sessions_reaped += 1
973 logger.debug(f"Reaped stale session for {sanitize_url_for_logging(session.url)} (age={session.age_seconds:.1f}s)")
975 async def _validate_session(self, pooled: PooledSession) -> bool:
976 """
977 Validate a session is still usable.
979 Checks TTL and performs health check if session is stale.
981 Args:
982 pooled: The session to validate.
984 Returns:
985 True if session is valid, False otherwise.
986 """
987 if pooled.is_closed:
988 return False
990 # Check TTL
991 if pooled.age_seconds > self._session_ttl:
992 logger.debug(f"Session expired (age={pooled.age_seconds:.1f}s)")
993 return False
995 # Health check if stale
996 if pooled.idle_seconds > self._health_check_interval:
997 return await self._run_health_check_chain(pooled)
999 return True
1001 async def _run_health_check_chain(self, pooled: PooledSession) -> bool:
1002 """
1003 Run health check methods in configured order until one succeeds.
1005 The health check chain allows configuring which methods to try and in what order.
1006 This supports both modern servers (with ping support) and legacy servers
1007 (that may only support list_tools or no health check at all).
1009 Args:
1010 pooled: The session to health check.
1012 Returns:
1013 True if any health check method succeeds, False if all fail.
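
        Example:
            For a legacy server without ping support, a chain such as
            ["list_tools", "skip"] (illustrative) tries list_tools first and,
            if the server reports METHOD_NOT_FOUND, falls through to "skip",
            which treats the session as healthy:

                pool = MCPSessionPool(health_check_methods=["list_tools", "skip"])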
1014 """
1015 for method in self._health_check_methods:
1016 try:
1017 if method == "ping":
1018 await asyncio.wait_for(pooled.session.send_ping(), timeout=self._health_check_timeout)
1019 logger.debug(f"Health check passed: ping (url={sanitize_url_for_logging(pooled.url)})")
1020 return True
1021 if method == "list_tools":
1022 await asyncio.wait_for(pooled.session.list_tools(), timeout=self._health_check_timeout)
1023 logger.debug(f"Health check passed: list_tools (url={sanitize_url_for_logging(pooled.url)})")
1024 return True
1025 if method == "list_prompts":
1026 await asyncio.wait_for(pooled.session.list_prompts(), timeout=self._health_check_timeout)
1027 logger.debug(f"Health check passed: list_prompts (url={sanitize_url_for_logging(pooled.url)})")
1028 return True
1029 if method == "list_resources":
1030 await asyncio.wait_for(pooled.session.list_resources(), timeout=self._health_check_timeout)
1031 logger.debug(f"Health check passed: list_resources (url={sanitize_url_for_logging(pooled.url)})")
1032 return True
1033 if method == "skip":
1034 logger.debug(f"Health check skipped per configuration (url={sanitize_url_for_logging(pooled.url)})")
1035 return True
1036 logger.warning(f"Unknown health check method '{method}', skipping")
1037 continue
1039 except McpError as e:
1040 # METHOD_NOT_FOUND (-32601) means the method isn't supported - try next
1041 if e.error.code == METHOD_NOT_FOUND:
1042 logger.debug(f"Health check method '{method}' not supported by server, trying next")
1043 continue
1044 # Other MCP errors are real failures
1045 logger.debug(f"Health check '{method}' failed with MCP error: {e}")
1046 self._health_check_failures += 1
1047 return False
1049 except asyncio.TimeoutError:
1050 logger.debug(f"Health check '{method}' timed out after {self._health_check_timeout}s, trying next")
1051 continue
1053 except Exception as e:
1054 logger.debug(f"Health check '{method}' failed: {e}")
1055 self._health_check_failures += 1
1056 return False
1058 # All methods failed or were unsupported
1059 logger.warning(f"All health check methods failed or unsupported (methods={self._health_check_methods})")
1060 self._health_check_failures += 1
1061 return False
1063 async def _create_session(
1064 self,
1065 url: str,
1066 headers: Optional[Dict[str, str]],
1067 transport_type: TransportType,
1068 httpx_client_factory: Optional[HttpxClientFactory],
1069 timeout: Optional[float] = None,
1070 gateway_id: Optional[str] = None,
1071 ) -> PooledSession:
1072 """
1073 Create a new initialized MCP session.
1075 Args:
1076 url: Server URL.
1077 headers: Request headers.
1078 transport_type: Transport type to use.
1079 httpx_client_factory: Optional factory for httpx clients.
1080 timeout: Optional timeout in seconds for transport connection.
1081 gateway_id: Optional gateway ID for notification handler context.
1083 Returns:
1084 Initialized PooledSession.
1086 Raises:
1087 RuntimeError: If session creation or initialization fails.
1088 asyncio.CancelledError: If cancelled during creation.
1089 """
1090 # Merge headers with defaults
1091 merged_headers = {"Accept": "application/json, text/event-stream"}
1092 if headers:
1093 merged_headers.update(headers)
1095 # Strip gateway-internal session affinity headers before sending to upstream
1096 # x-mcp-session-id is our internal representation, mcp-session-id is the MCP protocol header
1097 # Neither should be forwarded to upstream servers
1098 keys_to_remove = [k for k in merged_headers if k.lower() in ("x-mcp-session-id", "mcp-session-id")]
1099 for k in keys_to_remove:
1100 del merged_headers[k]
1102 identity_key = self._compute_identity_hash(headers)
1103 transport_ctx = None
1104 session = None
1105 success = False
1107 try:
1108 # Create transport context
1109 if transport_type == TransportType.SSE:
1110 if httpx_client_factory:
1111 transport_ctx = sse_client(url=url, headers=merged_headers, httpx_client_factory=httpx_client_factory, timeout=timeout)
1112 else:
1113 transport_ctx = sse_client(url=url, headers=merged_headers, timeout=timeout)
1114 # pylint: disable=unnecessary-dunder-call,no-member
1115 streams = await transport_ctx.__aenter__() # Must call directly for manual lifecycle management
1116 read_stream, write_stream = streams[0], streams[1]
1117 else: # STREAMABLE_HTTP
1118 if httpx_client_factory:
1119 transport_ctx = streamablehttp_client(url=url, headers=merged_headers, httpx_client_factory=httpx_client_factory, timeout=timeout)
1120 else:
1121 transport_ctx = streamablehttp_client(url=url, headers=merged_headers, timeout=timeout)
1122 # pylint: disable=unnecessary-dunder-call,no-member
1123 read_stream, write_stream, _ = await transport_ctx.__aenter__() # Must call directly for manual lifecycle management
1125 # Create message handler if factory is configured
1126 message_handler = None
1127 if self._message_handler_factory:
1128 try:
1129 message_handler = self._message_handler_factory(url, gateway_id)
1130 logger.debug(f"Created message handler for session {sanitize_url_for_logging(url)} (gateway={gateway_id})")
1131 except Exception as e:
1132 logger.warning(f"Failed to create message handler for {sanitize_url_for_logging(url)}: {e}")
1134 # Create and initialize session
1135 session = ClientSession(read_stream, write_stream, message_handler=message_handler)
1136 # pylint: disable=unnecessary-dunder-call
1137 await session.__aenter__() # Must call directly for manual lifecycle management
1138 await session.initialize()
1140 logger.info(f"Created new MCP session for {sanitize_url_for_logging(url)} (transport={transport_type.value})")
1141 success = True
1143 return PooledSession(
1144 session=session,
1145 transport_context=transport_ctx,
1146 url=url,
1147 transport_type=transport_type,
1148 headers=merged_headers,
1149 identity_key=identity_key,
1150 gateway_id=gateway_id or "",
1151 )
1153 except asyncio.CancelledError: # pylint: disable=try-except-raise
1154 # Re-raise CancelledError after cleanup (handled in finally)
1155 raise
1157 except Exception as e:
1158 raise RuntimeError(f"Failed to create MCP session for {url}: {e}") from e
1160 finally:
1161 # Clean up on ANY failure (Exception, CancelledError, etc.)
1162 # Only clean up if we didn't succeed
1163 # Use anyio.move_on_after instead of asyncio.wait_for to properly propagate
1164 # cancellation through anyio's cancel scope system (prevents orphaned spinning tasks)
1165 if not success:
1166 cleanup_timeout = _get_cleanup_timeout()
1167 if session is not None:
1168 with anyio.move_on_after(cleanup_timeout):
1169 try:
1170 await session.__aexit__(None, None, None) # pylint: disable=unnecessary-dunder-call
1171 except Exception: # nosec B110 - Best effort cleanup on connection failure
1172 pass
1173 if transport_ctx is not None:
1174 with anyio.move_on_after(cleanup_timeout):
1175 try:
1176 await transport_ctx.__aexit__(None, None, None) # pylint: disable=unnecessary-dunder-call
1177 except Exception: # nosec B110 - Best effort cleanup on connection failure
1178 pass
1180 async def _close_session(self, pooled: PooledSession) -> None:
1181 """
1182 Close a session and its transport.
1184 Uses timeouts to prevent indefinite blocking if session/transport tasks
1185 don't respond to cancellation. This prevents CPU spin loops in anyio's
1186 _deliver_cancellation which can occur when async iterators or blocking
1187 operations don't properly handle CancelledError.
1189 Args:
1190 pooled: The session to close.
1191 """
1192 if pooled.is_closed:
1193 return
1195 pooled.mark_closed()
1197 # Use anyio's move_on_after instead of asyncio.wait_for to properly propagate
1198 # cancellation through anyio's cancel scope system. asyncio.wait_for() creates
1199 # orphaned anyio tasks that keep spinning in _deliver_cancellation.
1200 cleanup_timeout = _get_cleanup_timeout()
1202 # Close session with anyio timeout
1203 with anyio.move_on_after(cleanup_timeout) as session_scope:
1204 try:
1205 await pooled.session.__aexit__(None, None, None) # pylint: disable=unnecessary-dunder-call
1206 except Exception as e:
1207 logger.debug(f"Error closing session: {e}")
1208 if session_scope.cancelled_caught:
1209 logger.warning(f"Session cleanup timed out for {sanitize_url_for_logging(pooled.url)} - proceeding anyway")
1211 # Close transport with anyio timeout
1212 with anyio.move_on_after(cleanup_timeout) as transport_scope:
1213 try:
1214 await pooled.transport_context.__aexit__(None, None, None) # pylint: disable=unnecessary-dunder-call
1215 except Exception as e:
1216 logger.debug(f"Error closing transport: {e}")
1217 if transport_scope.cancelled_caught:
1218 logger.warning(f"Transport cleanup timed out for {sanitize_url_for_logging(pooled.url)} - proceeding anyway")
1220 logger.debug(f"Closed session for {sanitize_url_for_logging(pooled.url)} (uses={pooled.use_count})")
1222 # Clean up pool_owner key in Redis for session affinity
1223 if settings.mcpgateway_session_affinity_enabled and pooled.headers:
1224 headers_lower = {k.lower(): v for k, v in pooled.headers.items()}
1225 mcp_session_id = headers_lower.get("x-mcp-session-id")
1226 if mcp_session_id and self.is_valid_mcp_session_id(mcp_session_id):
1227 await self._cleanup_pool_session_owner(mcp_session_id)
1229 async def _cleanup_pool_session_owner(self, mcp_session_id: str) -> None:
1230 """Clean up pool_owner key in Redis when session is closed.
1232 Only deletes the key if this worker owns it (to prevent removing other workers' ownership).
1234 Args:
1235 mcp_session_id: The MCP session ID from x-mcp-session-id header.
1236 """
1237 try:
1238 # First-Party
1239 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
1241 redis = await get_redis_client()
1242 if redis:
1243 key = self._pool_owner_key(mcp_session_id)
1244 # Only delete if we own it
1245 owner = await redis.get(key)
1246 if owner:
1247 owner_id = owner.decode() if isinstance(owner, bytes) else owner
1248 if owner_id == WORKER_ID:
1249 await redis.delete(key)
1250 logger.debug(f"Cleaned up pool session owner: {mcp_session_id[:8]}...")
1251 except Exception as e:
1252 # Cleanup failure is non-fatal
1253 logger.debug(f"Failed to cleanup pool session owner in Redis: {e}")
1255 async def close_all(self) -> None:
1256 """
1257 Gracefully close all pooled and active sessions.
1259 Should be called during application shutdown.
1260 """
1261 self._closed = True
1262 logger.info("Closing all pooled sessions...")
1264 async with self._global_lock:
1265 # Close all pooled sessions
1266 for _pool_key, pool in list(self._pools.items()):
1267 while not pool.empty():
1268 try:
1269 pooled = pool.get_nowait()
1270 await self._close_session(pooled)
1271 except asyncio.QueueEmpty:
1272 break
1274 # Close all active sessions
1275 for _pool_key, active_set in list(self._active.items()):
1276 for pooled in list(active_set):
1277 await self._close_session(pooled)
1279 self._pools.clear()
1280 self._active.clear()
1281 self._locks.clear()
1282 self._semaphores.clear()
1284 # Stop RPC listener if running
1285 if self._rpc_listener_task and not self._rpc_listener_task.done():
1286 self._rpc_listener_task.cancel()
1287 try:
1288 await self._rpc_listener_task
1289 except asyncio.CancelledError:
1290 pass
1291 self._rpc_listener_task = None
1293 logger.info("All sessions closed")
1295 async def register_pool_session_owner(self, mcp_session_id: str) -> None:
1296 """Register this worker as owner of a pool session in Redis.
1298 This enables multi-worker session affinity by tracking which worker owns
1299 which pool session. When a request with x-mcp-session-id arrives at a
1300 different worker, it can forward the request to the owner worker.
1302 Note: This method is now primarily used for refreshing TTL on existing ownership.
1303 Initial ownership is claimed atomically in register_session_mapping() using SETNX.
1305 Args:
1306 mcp_session_id: The MCP session ID from x-mcp-session-id header.
1307 """
1308 if not settings.mcpgateway_session_affinity_enabled:
1309 return
1311 if not self.is_valid_mcp_session_id(mcp_session_id):
1312 logger.debug("Invalid mcp_session_id for owner registration, skipping")
1313 return
1315 try:
1316 # First-Party
1317 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
1319 redis = await get_redis_client()
1320 if redis:
1321 key = self._pool_owner_key(mcp_session_id)
1323 # Do not steal ownership: only claim if missing, or refresh TTL if we already own.
1324 # Lua keeps this atomic.
1325 script = """
1326 local cur = redis.call('GET', KEYS[1])
1327 if not cur then
1328 redis.call('SET', KEYS[1], ARGV[1], 'EX', ARGV[2])
1329 return 1
1330 end
1331 if cur == ARGV[1] then
1332 redis.call('EXPIRE', KEYS[1], ARGV[2])
1333 return 2
1334 end
1335 return 0
1336 """
1337 ttl = int(settings.mcpgateway_session_affinity_ttl)
1338 outcome = await redis.eval(script, 1, key, WORKER_ID, ttl)
1339 logger.debug(f"Owner registration outcome={outcome} for session {mcp_session_id[:8]}...")
1340 except Exception as e:
1341 # Redis failure is non-fatal - single worker mode still works
1342 logger.debug(f"Failed to register pool session owner in Redis: {e}")
1344 async def _get_pool_session_owner(self, mcp_session_id: str) -> Optional[str]:
1345 """Get the worker ID that owns a pool session.
1347 Args:
1348 mcp_session_id: The MCP session ID from x-mcp-session-id header.
1350 Returns:
1351 The worker ID that owns this session, or None if not found or Redis unavailable.
1352 """
1353 if not settings.mcpgateway_session_affinity_enabled:
1354 return None
1356 if not self.is_valid_mcp_session_id(mcp_session_id):
1357 return None
1359 try:
1360 # First-Party
1361 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
1363 redis = await get_redis_client()
1364 if redis:
1365 key = self._pool_owner_key(mcp_session_id)
1366 owner = await redis.get(key)
1367 if owner:
1368 decoded = owner.decode() if isinstance(owner, bytes) else owner
1369 return decoded
1370 except Exception as e:
1371 logger.debug(f"Failed to get pool session owner from Redis: {e}")
1372 return None
1374 async def forward_request_to_owner(
1375 self,
1376 mcp_session_id: str,
1377 request_data: Dict[str, Any],
1378 timeout: Optional[float] = None,
1379 ) -> Optional[Dict[str, Any]]:
1380 """Forward RPC request to the worker that owns the pool session.
1382 This method checks Redis to find which worker owns the pool session for
1383 the given mcp_session_id. If owned by another worker, it forwards the
1384 request via Redis pub/sub and waits for the response.
1386 Args:
1387 mcp_session_id: The MCP session ID from x-mcp-session-id header.
1388 request_data: The RPC request data to forward.
1389 timeout: Optional timeout in seconds (default from config).
1391 Returns:
1392 The response from the owner worker, or None if we own the session
1393 (caller should execute locally) or if Redis is unavailable.
1395 Raises:
1396 asyncio.TimeoutError: If the forwarded request times out.
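
        Example:
            Illustrative call shape; the payload mirrors a JSON-RPC request:

                response = await pool.forward_request_to_owner(
                    "abc123",
                    {"jsonrpc": "2.0", "id": 1, "method": "tools/call", "params": {"name": "my_tool", "arguments": {}}},
                )
                if response is None:
                    # We own the session (or Redis is unavailable): execute locally.
                    ...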
1397 """
1398 if not settings.mcpgateway_session_affinity_enabled:
1399 return None
1401 if not self.is_valid_mcp_session_id(mcp_session_id):
1402 return None
1404 effective_timeout = timeout if timeout is not None else settings.mcpgateway_pool_rpc_forward_timeout
1406 try:
1407 # First-Party
1408 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
1410 redis = await get_redis_client()
1411 if not redis:
1412 return None # Execute locally - no Redis
1414 # Check who owns this session
1415 owner = await redis.get(self._pool_owner_key(mcp_session_id))
1416 method = request_data.get("method", "unknown")
1417 if not owner:
1418 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | No owner → execute locally (new session)")
1419 return None # No owner registered - execute locally (new session)
1421 owner_id = owner.decode() if isinstance(owner, bytes) else owner
1422 if owner_id == WORKER_ID:
1423 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | We own it → execute locally")
1424 return None # We own it - execute locally
1426 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | Owner: {owner_id} → forwarding")
1428 # Forward to owner worker via pub/sub
1429 response_id = str(uuid.uuid4())
1430 response_channel = f"mcpgw:pool_rpc_response:{response_id}"
1432 # Subscribe to response channel
1433 pubsub = redis.pubsub()
1434 await pubsub.subscribe(response_channel)
1436 try:
1437 # Prepare request with response channel
1438 forward_data = {
1439 "type": "rpc_forward",
1440 **request_data,
1441 "response_channel": response_channel,
1442 "mcp_session_id": mcp_session_id,
1443 }
1445 # Publish request to owner's channel
1446 await redis.publish(f"mcpgw:pool_rpc:{owner_id}", orjson.dumps(forward_data))
1447 self._forwarded_requests += 1
1448 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | Published to worker {owner_id}")
1450 # Wait for response
1451 async with asyncio.timeout(effective_timeout):
1452 while True:
1453 msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
1454 if msg and msg["type"] == "message":
1455 return orjson.loads(msg["data"])
1456 finally:
1457 await pubsub.unsubscribe(response_channel)
1459 except asyncio.TimeoutError:
1460 self._forwarded_request_timeouts += 1
1461 logger.warning(f"Timeout forwarding request to owner for session {mcp_session_id[:8]}...")
1462 raise
1463 except Exception as e:
1464 self._forwarded_request_failures += 1
1465 logger.debug(f"Error forwarding request to owner: {e}")
1466 return None # Execute locally on error
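
A hedged sketch of the calling side: an RPC handler builds request_data with the keys that _execute_forwarded_request reads on the owning worker (method, params, headers, req_id), and treats a None return as the cue to execute locally. handle_rpc_locally is a hypothetical stand-in for the gateway's real dispatcher.

    import asyncio
    from typing import Any, Dict

    async def handle_rpc_locally(body: Dict[str, Any]) -> Dict[str, Any]:
        # Hypothetical local dispatcher; the real gateway routes to its /rpc logic.
        return {"result": {}}

    async def dispatch_rpc(mcp_session_id: str, body: Dict[str, Any], headers: Dict[str, str]) -> Dict[str, Any]:
        pool = get_mcp_session_pool()
        request_data = {
            "method": body.get("method"),
            "params": body.get("params", {}),
            "headers": headers,
            "req_id": body.get("id", 1),
        }
        try:
            forwarded = await pool.forward_request_to_owner(mcp_session_id, request_data)
        except asyncio.TimeoutError:
            return {"error": {"code": -32603, "message": "Forwarding timeout"}}
        if forwarded is not None:
            return forwarded  # Answered by the owning worker via Redis pub/sub.
        return await handle_rpc_locally(body)  # We own the session, or Redis is down.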
1468 async def start_rpc_listener(self) -> None:
1469 """Start listening for forwarded RPC and HTTP requests on this worker's channels.
1471 This method subscribes to Redis pub/sub channels specific to this worker
1472 and processes incoming forwarded requests from other workers:
1473 - mcpgw:pool_rpc:{WORKER_ID} - for SSE transport JSON-RPC forwards
1474 - mcpgw:pool_http:{WORKER_ID} - for Streamable HTTP request forwards
1475 """
1476 if not settings.mcpgateway_session_affinity_enabled:
1477 return
1479 try:
1480 # First-Party
1481 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
1483 redis = await get_redis_client()
1484 if not redis:
1485 logger.debug("Redis not available, RPC listener not started")
1486 return
1488 rpc_channel = f"mcpgw:pool_rpc:{WORKER_ID}"
1489 http_channel = f"mcpgw:pool_http:{WORKER_ID}"
1490 pubsub = redis.pubsub()
1491 await pubsub.subscribe(rpc_channel, http_channel)
1492 logger.info(f"RPC/HTTP listener started for worker {WORKER_ID} on channels: {rpc_channel}, {http_channel}")
1494 while not self._closed:
1495 try:
1496 msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
1497 if msg and msg["type"] == "message":
1498 request = orjson.loads(msg["data"])
1499 forward_type = request.get("type")
1500 response_channel = request.get("response_channel")
1502 if response_channel:
1503 if forward_type == "rpc_forward":
1504 # Execute forwarded RPC request for SSE transport
1505 response = await self._execute_forwarded_request(request)
1506 await redis.publish(response_channel, orjson.dumps(response))
1507 logger.debug(f"Processed forwarded RPC request, response sent to {response_channel}")
1508 elif forward_type == "http_forward":
1509 # Execute forwarded HTTP request for Streamable HTTP transport
1510 await self._execute_forwarded_http_request(request, redis)
1511 else:
1512 logger.warning(f"Unknown forward type: {forward_type}")
1513 except asyncio.CancelledError:
1514 break
1515 except Exception as e:
1516 logger.warning(f"Error processing forwarded request: {e}")
1518 await pubsub.unsubscribe(rpc_channel, http_channel)
1519 logger.info(f"RPC/HTTP listener stopped for worker {WORKER_ID}")
1521 except Exception as e:
1522 logger.warning(f"RPC/HTTP listener failed: {e}")
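
Because start_rpc_listener loops until the pool is closed, it has to run as a background task for the lifetime of the worker. A minimal sketch, assuming a FastAPI-style lifespan and an already-initialized global pool:

    import asyncio
    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def lifespan(app):
        pool = get_mcp_session_pool()
        # Launch the pub/sub listener; it exits when the pool closes or on cancel.
        listener = asyncio.create_task(pool.start_rpc_listener())
        try:
            yield
        finally:
            listener.cancel()
            try:
                await listener
            except asyncio.CancelledError:
                pass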
1524 async def _execute_forwarded_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
1525 """Execute a forwarded RPC request locally via internal HTTP call.
1527 This method handles RPC requests that were forwarded from another worker.
1528 Instead of handling specific methods here, we make an internal HTTP call
1529 to the local /rpc endpoint which reuses ALL existing method handling logic.
1531 The x-forwarded-internally header prevents infinite forwarding loops.
1533 Args:
1534 request: The forwarded RPC request containing method, params, headers, req_id, etc.
1536 Returns:
1537 The JSON-RPC response from the local endpoint.
1538 """
1539 try:
1540 method = request.get("method")
1541 params = request.get("params", {})
1542 headers = request.get("headers", {})
1543 req_id = request.get("req_id", 1)
1544 mcp_session_id = request.get("mcp_session_id", "unknown")
1545 session_short = mcp_session_id[:8]
1547 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Received forwarded request, executing locally")
1549 # Make internal HTTP call to local /rpc endpoint
1550 # This reuses ALL existing method handling logic without duplication
1551 async with httpx.AsyncClient() as client:
1552 # Build headers for internal request - forward original headers
1553 # but add x-forwarded-internally to prevent infinite loops
1554 internal_headers = dict(headers)
1555 internal_headers["x-forwarded-internally"] = "true"
1556 # Ensure content-type is set
1557 internal_headers["content-type"] = "application/json"
1559 response = await client.post(
1560 f"http://127.0.0.1:{settings.port}/rpc",
1561 json={"jsonrpc": "2.0", "method": method, "params": params, "id": req_id},
1562 headers=internal_headers,
1563 timeout=settings.mcpgateway_pool_rpc_forward_timeout,
1564 )
1566 # Parse response
1567 response_data = response.json()
1569 # Extract result or error from JSON-RPC response
1570 if "error" in response_data:
1571 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded execution completed with error")
1572 return {"error": response_data["error"]}
1573 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded execution completed successfully")
1574 return {"result": response_data.get("result", {})}
1576 except httpx.TimeoutException:
1577 logger.warning(f"Timeout executing forwarded request: {request.get('method')}")
1578 return {"error": {"code": -32603, "message": "Internal request timeout"}}
1579 except Exception as e:
1580 logger.warning(f"Error executing forwarded request: {e}")
1581 return {"error": {"code": -32603, "message": str(e)}}
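
The loop guard only works if the /rpc route honors the marker. A hedged sketch of the receiving side (execute_locally is a hypothetical stand-in for the route's normal execution path):

    from typing import Any, Dict
    from fastapi import Request

    async def execute_locally(body: Dict[str, Any]) -> Dict[str, Any]:
        # Hypothetical stand-in for the endpoint's normal method handling.
        return {"result": {}}

    async def rpc_endpoint(request: Request) -> Dict[str, Any]:
        body = await request.json()
        # Requests replayed by _execute_forwarded_request carry this header;
        # executing them locally caps forwarding at a single hop.
        if request.headers.get("x-forwarded-internally") == "true":
            return await execute_locally(body)
        session_id = request.headers.get("x-mcp-session-id", "")
        # Same request_data shape as the earlier dispatch sketch.
        request_data = {"method": body.get("method"), "params": body.get("params", {}), "headers": dict(request.headers), "req_id": body.get("id", 1)}
        forwarded = await get_mcp_session_pool().forward_request_to_owner(session_id, request_data)
        return forwarded if forwarded is not None else await execute_locally(body)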
1583 async def _execute_forwarded_http_request(self, request: Dict[str, Any], redis: Any) -> None:
1584 """Execute a forwarded HTTP request locally and return response via Redis.
1586 This method handles full HTTP requests forwarded from other workers for
1587 Streamable HTTP transport session affinity. It reconstructs the HTTP request,
1588 makes an internal call to the appropriate endpoint, and publishes the response
1589 back through Redis.
1591 Args:
1592 request: Serialized HTTP request data from Redis Pub/Sub containing:
1593 - type: "http_forward"
1594 - response_channel: Redis channel to publish response to
1595 - mcp_session_id: Session identifier
1596 - method: HTTP method (GET, POST, DELETE)
1597 - path: Request path (e.g., /mcp)
1598 - query_string: Query parameters
1599 - headers: Request headers dict
1600 - body: Hex-encoded request body
1601 redis: Redis client for publishing response
1602 """
1603 response_channel = None
1604 try:
1605 response_channel = request.get("response_channel")
1606 method = request.get("method")
1607 path = request.get("path")
1608 query_string = request.get("query_string", "")
1609 headers = request.get("headers", {})
1610 body_hex = request.get("body", "")
1611 mcp_session_id = request.get("mcp_session_id")
1613 # Decode hex body back to bytes
1614 body = bytes.fromhex(body_hex) if body_hex else b""
1616 session_short = mcp_session_id[:8] if mcp_session_id else "unknown"
1617 logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Received forwarded HTTP request: {method} {path}")
1619 # Add internal forwarding headers to prevent loops
1620 internal_headers = dict(headers)
1621 internal_headers["x-forwarded-internally"] = "true"
1622 internal_headers["x-original-worker"] = request.get("original_worker", "unknown")
1624 # Make internal HTTP request to local endpoint
1625 url = f"http://127.0.0.1:{settings.port}{path}"
1626 if query_string:
1627 url = f"{url}?{query_string}"
1629 async with httpx.AsyncClient() as client:
1630 response = await client.request(
1631 method=method,
1632 url=url,
1633 headers=internal_headers,
1634 content=body,
1635 timeout=settings.mcpgateway_pool_rpc_forward_timeout,
1636 )
1638 logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Executed locally: {response.status_code}")
1640 # Serialize response for Redis transport
1641 response_data = {
1642 "status": response.status_code,
1643 "headers": dict(response.headers),
1644 "body": response.content.hex(), # Hex encode binary response
1645 }
1647 # Publish response back to requesting worker
1648 if redis and response_channel:
1649 await redis.publish(response_channel, orjson.dumps(response_data))
1650 logger.debug(f"[HTTP_AFFINITY] Published HTTP response to Redis channel: {response_channel}")
1652 except Exception as e:
1653 logger.error(f"Error executing forwarded HTTP request: {e}")
1654 # Try to send error response if possible
1655 if redis and response_channel:
1656 error_response = {
1657 "status": 500,
1658 "headers": {"content-type": "application/json"},
1659 "body": orjson.dumps({"error": "Internal forwarding error"}).hex(),
1660 }
1661 try:
1662 await redis.publish(response_channel, orjson.dumps(error_response))
1663 except Exception as publish_error:
1664 logger.debug(f"Failed to publish error response via Redis: {publish_error}")
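
The body crosses Redis as a hex string because the forwarded message is JSON and raw bytes are not JSON-serializable. The round-trip is lossless, as this self-contained check shows:

    import orjson

    original = b"\x00\x01binary payload\xff"
    wire = orjson.dumps({"body": original.hex()})        # hex keeps bytes JSON-safe
    decoded = bytes.fromhex(orjson.loads(wire)["body"])  # reverse on the other worker
    assert decoded == original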
1666 async def get_streamable_http_session_owner(self, mcp_session_id: str) -> Optional[str]:
1667 """Get the worker ID that owns a Streamable HTTP session.
1669 This is a public wrapper around _get_pool_session_owner for use by
1670 streamablehttp_transport to check session ownership before handling requests.
1672 Args:
1673 mcp_session_id: The MCP session ID from mcp-session-id header.
1675 Returns:
1676 Worker ID if found, None otherwise.
1677 """
1678 return await self._get_pool_session_owner(mcp_session_id)
1680 async def forward_streamable_http_to_owner(
1681 self,
1682 owner_worker_id: str,
1683 mcp_session_id: str,
1684 method: str,
1685 path: str,
1686 headers: Dict[str, str],
1687 body: bytes,
1688 query_string: str = "",
1689 ) -> Optional[Dict[str, Any]]:
1690 """Forward a Streamable HTTP request to the worker that owns the session via Redis Pub/Sub.
1692 This method forwards the entire HTTP request to another worker using Redis
1693 Pub/Sub channels, similar to forward_request_to_owner() for SSE transport.
1694 This ensures session affinity works correctly in single-host multi-worker
1695 deployments where hostname-based routing fails.
1697 Args:
1698 owner_worker_id: The worker ID that owns the session.
1699 mcp_session_id: The MCP session ID.
1700 method: HTTP method (GET, POST, DELETE).
1701 path: Request path (e.g., /mcp).
1702 headers: Request headers.
1703 body: Request body bytes.
1704 query_string: Query string if any.
1706 Returns:
1707 Dict with 'status', 'headers', and 'body' from the owner worker's response,
1708 or None if forwarding fails.
1709 """
1710 if not settings.mcpgateway_session_affinity_enabled:
1711 return None
1713 if not self.is_valid_mcp_session_id(mcp_session_id):
1714 return None
1716 session_short = mcp_session_id[:8]
1717 logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {session_short}... | {method} {path} | Forwarding to worker {owner_worker_id}")
1719 try:
1720 # First-Party
1721 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
1723 redis = await get_redis_client()
1724 if not redis:
1725 logger.warning("Redis unavailable for HTTP forwarding, executing locally")
1726 return None # Fall back to local execution
1728 # Generate unique response channel for this request
1729 response_uuid = uuid.uuid4().hex
1730 response_channel = f"mcpgw:pool_http_response:{response_uuid}"
1732 # Serialize HTTP request for Redis transport
1733 forward_data = {
1734 "type": "http_forward",
1735 "response_channel": response_channel,
1736 "mcp_session_id": mcp_session_id,
1737 "method": method,
1738 "path": path,
1739 "query_string": query_string,
1740 "headers": headers,
1741 "body": body.hex() if body else "", # Hex encode binary body
1742 "original_worker": WORKER_ID,
1743 "timestamp": time.time(),
1744 }
1746 # Subscribe to response channel BEFORE publishing request (prevent race)
1747 pubsub = redis.pubsub()
1748 await pubsub.subscribe(response_channel)
1750 try:
1751 # Publish forwarded request to owner worker's HTTP channel
1752 owner_channel = f"mcpgw:pool_http:{owner_worker_id}"
1753 await redis.publish(owner_channel, orjson.dumps(forward_data))
1754 logger.debug(f"[HTTP_AFFINITY] Published HTTP request to Redis channel: {owner_channel}")
1756 # Wait for response with timeout
1757 timeout = settings.mcpgateway_pool_rpc_forward_timeout
1758 async with asyncio.timeout(timeout):
1759 while True:
1760 msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
1761 if msg and msg["type"] == "message":
1762 response_data = orjson.loads(msg["data"])
1763 logger.debug(f"[HTTP_AFFINITY] Received HTTP response via Redis: status={response_data.get('status')}")
1765 # Decode hex body back to bytes
1766 body_hex = response_data.get("body", "")
1767 response_data["body"] = bytes.fromhex(body_hex) if body_hex else b""
1769 self._forwarded_requests += 1
1770 return response_data
1772 finally:
1773 await pubsub.unsubscribe(response_channel)
1775 except asyncio.TimeoutError:
1776 self._forwarded_request_timeouts += 1
1777 logger.warning(f"Timeout forwarding HTTP request to owner {owner_worker_id}")
1778 return None
1779 except Exception as e:
1780 self._forwarded_request_failures += 1
1781 logger.warning(f"Error forwarding HTTP request via Redis: {e}")
1782 return None
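
A hedged sketch of how the transport layer could combine the ownership lookup with forwarding, assuming a Starlette-style request object; handle_locally is a hypothetical fallback for sessions this worker owns:

    async def handle_streamable_http(request):
        pool = get_mcp_session_pool()
        session_id = request.headers.get("mcp-session-id", "")
        owner = await pool.get_streamable_http_session_owner(session_id)
        if owner and owner != WORKER_ID:
            response = await pool.forward_streamable_http_to_owner(
                owner_worker_id=owner,
                mcp_session_id=session_id,
                method=request.method,
                path=request.url.path,
                headers=dict(request.headers),
                body=await request.body(),
                query_string=request.url.query,
            )
            if response is not None:
                return response  # dict with status, headers, and decoded body bytes
        return await handle_locally(request)  # hypothetical local execution path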
1784 def get_metrics(self) -> Dict[str, Any]:
1785 """
1786 Return pool metrics for monitoring.
1788 Returns:
1789 Dict with hits, misses, evictions, hit_rate, and per-pool stats.
1790 """
1791 total_requests = self._hits + self._misses
1792 total_affinity_requests = self._session_affinity_local_hits + self._session_affinity_redis_hits + self._session_affinity_misses
1793 return {
1794 "hits": self._hits,
1795 "misses": self._misses,
1796 "evictions": self._evictions,
1797 "health_check_failures": self._health_check_failures,
1798 "circuit_breaker_trips": self._circuit_breaker_trips,
1799 "pool_keys_evicted": self._pool_keys_evicted,
1800 "sessions_reaped": self._sessions_reaped,
1801 "anonymous_identity_count": self._anonymous_identity_count,
1802 "hit_rate": self._hits / total_requests if total_requests > 0 else 0.0,
1803 "pool_key_count": len(self._pools),
1804 # Session affinity metrics
1805 "session_affinity": {
1806 "local_hits": self._session_affinity_local_hits,
1807 "redis_hits": self._session_affinity_redis_hits,
1808 "misses": self._session_affinity_misses,
1809 "hit_rate": (self._session_affinity_local_hits + self._session_affinity_redis_hits) / total_affinity_requests if total_affinity_requests > 0 else 0.0,
1810 "forwarded_requests": self._forwarded_requests,
1811 "forwarded_failures": self._forwarded_request_failures,
1812 "forwarded_timeouts": self._forwarded_request_timeouts,
1813 },
1814 "pools": {
1815 f"{url}|{identity[:8]}|{transport}|{user}|{gw_id[:8] if gw_id else 'none'}": {
1816 "available": pool.qsize(),
1817 "active": len(self._active.get((user, url, identity, transport, gw_id), set())),
1818 "max": self._max_sessions,
1819 }
1820 for (user, url, identity, transport, gw_id), pool in self._pools.items()
1821 },
1822 "circuit_breakers": {
1823 url: {
1824 "failures": self._failures.get(url, 0),
1825 "open_until": self._circuit_open_until.get(url),
1826 }
1827 for url in set(self._failures.keys()) | set(self._circuit_open_until.keys())
1828 },
1829 }
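
Since get_metrics() is synchronous and only reads counters, it can back a monitoring route directly; a minimal sketch (the route path and FastAPI wiring are assumptions):

    from fastapi import APIRouter

    router = APIRouter()

    @router.get("/admin/pool-metrics")
    async def pool_metrics():
        metrics = get_mcp_session_pool().get_metrics()
        # Surface the headline numbers; the full dict also includes per-pool
        # stats and circuit-breaker state.
        return {
            "hit_rate": metrics["hit_rate"],
            "pool_key_count": metrics["pool_key_count"],
            "session_affinity": metrics["session_affinity"],
        }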
1831 @asynccontextmanager
1832 async def session(
1833 self,
1834 url: str,
1835 headers: Optional[Dict[str, str]] = None,
1836 transport_type: TransportType = TransportType.STREAMABLE_HTTP,
1837 httpx_client_factory: Optional[HttpxClientFactory] = None,
1838 timeout: Optional[float] = None,
1839 user_identity: Optional[str] = None,
1840 gateway_id: Optional[str] = None,
1841 ) -> "AsyncIterator[PooledSession]":
1842 """
1843 Context manager for acquiring and releasing a session.
1845 Usage:
1846 async with pool.session(url, headers) as pooled:
1847 result = await pooled.session.call_tool("my_tool", {})
1849 Args:
1850 url: The MCP server URL.
1851 headers: Request headers.
1852 transport_type: Transport type to use.
1853 httpx_client_factory: Optional factory for httpx clients.
1854 timeout: Optional timeout in seconds for transport connection.
1855 user_identity: Optional user identity for strict isolation.
1856 gateway_id: Optional gateway ID for notification handler context.
1858 Yields:
1859 PooledSession ready for use.
1860 """
1861 pooled = await self.acquire(url, headers, transport_type, httpx_client_factory, timeout, user_identity, gateway_id)
1862 try:
1863 yield pooled
1864 finally:
1865 await self.release(pooled)
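
Beyond the docstring's minimal form, callers that need strict isolation typically pass identity and gateway context as well; a hedged sketch with placeholder values:

    async def call_pooled_tool():
        pool = get_mcp_session_pool()
        async with pool.session(
            url="https://mcp.example.com/mcp",            # placeholder URL
            headers={"Authorization": "Bearer <token>"},  # placeholder credential
            transport_type=TransportType.STREAMABLE_HTTP,
            user_identity="user-123",                     # keys the per-user pool
            gateway_id="gw-abc",                          # notification attribution
        ) as pooled:
            return await pooled.session.call_tool("my_tool", {})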
1868# Global pool instance - initialized by FastAPI lifespan
1869_mcp_session_pool: Optional[MCPSessionPool] = None
1872def get_mcp_session_pool() -> MCPSessionPool:
1873 """Get the global MCP session pool instance.
1875 Returns:
1876 The global MCPSessionPool instance.
1878 Raises:
1879 RuntimeError: If pool has not been initialized.
1880 """
1881 if _mcp_session_pool is None:
1882 raise RuntimeError("MCP session pool not initialized. Call init_mcp_session_pool() first.")
1883 return _mcp_session_pool
1886def init_mcp_session_pool(
1887 max_sessions_per_key: int = 10,
1888 session_ttl_seconds: float = 300.0,
1889 health_check_interval_seconds: float = 60.0,
1890 acquire_timeout_seconds: float = 30.0,
1891 session_create_timeout_seconds: float = 30.0,
1892 circuit_breaker_threshold: int = 5,
1893 circuit_breaker_reset_seconds: float = 60.0,
1894 identity_headers: Optional[frozenset[str]] = None,
1895 identity_extractor: Optional[IdentityExtractor] = None,
1896 idle_pool_eviction_seconds: float = 600.0,
1897 default_transport_timeout_seconds: float = 30.0,
1898 health_check_methods: Optional[list[str]] = None,
1899 health_check_timeout_seconds: float = 5.0,
1900 message_handler_factory: Optional[MessageHandlerFactory] = None,
1901 enable_notifications: bool = True,
1902 notification_debounce_seconds: float = 5.0,
1903) -> MCPSessionPool:
1904 """Initialize the global MCP session pool.
1906 Args:
1907 See MCPSessionPool.__init__ for the shared argument descriptions; the two arguments below are specific to this initializer.
1908 enable_notifications: Enable automatic notification service for list_changed events.
1909 notification_debounce_seconds: Debounce interval for notification-triggered refreshes.
1911 Returns:
1912 The initialized MCPSessionPool instance.
1913 """
1914 global _mcp_session_pool # pylint: disable=global-statement
1916 # Auto-create notification service if enabled and no custom handler provided
1917 effective_handler_factory = message_handler_factory
1918 if enable_notifications and message_handler_factory is None:
1919 # First-Party
1920 from mcpgateway.services.notification_service import ( # pylint: disable=import-outside-toplevel
1921 init_notification_service,
1922 )
1924 # Initialize notification service (will be started during acquire with gateway context)
1925 notification_svc = init_notification_service(debounce_seconds=notification_debounce_seconds)
1927 # Create default handler factory that uses notification service
1928 def default_handler_factory(url: str, gateway_id: Optional[str]):
1929 """Create a message handler for MCP session notifications.
1931 Args:
1932 url: The MCP server URL for the session.
1933 gateway_id: Optional gateway ID for attribution, falls back to URL if not provided.
1935 Returns:
1936 A message handler that forwards notifications to the notification service.
1937 """
1938 return notification_svc.create_message_handler(gateway_id or url, url)
1940 effective_handler_factory = default_handler_factory
1941 logger.info("MCP notification service created (debounce=%ss)", notification_debounce_seconds)
1943 _mcp_session_pool = MCPSessionPool(
1944 max_sessions_per_key=max_sessions_per_key,
1945 session_ttl_seconds=session_ttl_seconds,
1946 health_check_interval_seconds=health_check_interval_seconds,
1947 acquire_timeout_seconds=acquire_timeout_seconds,
1948 session_create_timeout_seconds=session_create_timeout_seconds,
1949 circuit_breaker_threshold=circuit_breaker_threshold,
1950 circuit_breaker_reset_seconds=circuit_breaker_reset_seconds,
1951 identity_headers=identity_headers,
1952 identity_extractor=identity_extractor,
1953 idle_pool_eviction_seconds=idle_pool_eviction_seconds,
1954 default_transport_timeout_seconds=default_transport_timeout_seconds,
1955 health_check_methods=health_check_methods,
1956 health_check_timeout_seconds=health_check_timeout_seconds,
1957 message_handler_factory=effective_handler_factory,
1958 )
1959 logger.info("MCP session pool initialized")
1960 return _mcp_session_pool
1963async def close_mcp_session_pool() -> None:
1964 """Close the global MCP session pool and notification service."""
1965 global _mcp_session_pool # pylint: disable=global-statement
1966 if _mcp_session_pool is not None:
1967 await _mcp_session_pool.close_all()
1968 _mcp_session_pool = None
1969 logger.info("MCP session pool closed")
1971 # Close notification service if it was initialized
1972 try:
1973 # First-Party
1974 from mcpgateway.services.notification_service import ( # pylint: disable=import-outside-toplevel
1975 close_notification_service,
1976 )
1978 await close_notification_service()
1979 except (ImportError, RuntimeError):
1980 pass # Notification service not initialized
1983async def start_pool_notification_service(gateway_service: Any = None) -> None:
1984 """Start the notification service background worker.
1986 Call this after gateway_service is initialized to enable event-driven refresh.
1988 Args:
1989 gateway_service: Optional GatewayService instance for triggering refreshes.
1990 """
1991 try:
1992 # First-Party
1993 from mcpgateway.services.notification_service import ( # pylint: disable=import-outside-toplevel
1994 get_notification_service,
1995 )
1997 notification_svc = get_notification_service()
1998 await notification_svc.initialize(gateway_service)
1999 logger.info("MCP notification service started")
2000 except RuntimeError:
2001 logger.debug("Notification service not configured, skipping start")
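
Taken together, the module-level helpers imply this startup/shutdown ordering; a minimal sketch (the surrounding lifespan wiring is an assumption):

    async def startup(gateway_service=None):
        init_mcp_session_pool(max_sessions_per_key=10, session_ttl_seconds=300.0)
        # Start event-driven refresh once the gateway service exists.
        await start_pool_notification_service(gateway_service)

    async def shutdown():
        await close_mcp_session_pool()  # also closes the notification service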
2004def register_gateway_capabilities_for_notifications(gateway_id: str, capabilities: Dict[str, Any]) -> None:
2005 """Register gateway capabilities for notification handling.
2007 Call this after gateway initialization to enable list_changed notifications.
2009 Args:
2010 gateway_id: The gateway ID.
2011 capabilities: Server capabilities from initialization response.
2012 """
2013 try:
2014 # First-Party
2015 from mcpgateway.services.notification_service import ( # pylint: disable=import-outside-toplevel
2016 get_notification_service,
2017 )
2019 notification_svc = get_notification_service()
2020 notification_svc.register_gateway_capabilities(gateway_id, capabilities)
2021 except RuntimeError:
2022 pass # Notification service not initialized
2025def unregister_gateway_from_notifications(gateway_id: str) -> None:
2026 """Unregister a gateway from notification handling.
2028 Call this when a gateway is deleted.
2030 Args:
2031 gateway_id: The gateway ID to unregister.
2032 """
2033 try:
2034 # First-Party
2035 from mcpgateway.services.notification_service import ( # pylint: disable=import-outside-toplevel
2036 get_notification_service,
2037 )
2039 notification_svc = get_notification_service()
2040 notification_svc.unregister_gateway(gateway_id)
2041 except RuntimeError:
2042 pass # Notification service not initialized
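
Registration mirrors the gateway's lifecycle: register after a successful initialize, unregister on deletion. A hedged sketch with a capabilities dict shaped like an MCP initialize result:

    capabilities = {"tools": {"listChanged": True}, "resources": {"listChanged": False}}
    register_gateway_capabilities_for_notifications("gw-abc", capabilities)

    # ... later, when the gateway is deleted:
    unregister_gateway_from_notifications("gw-abc")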