Coverage for mcpgateway / services / mcp_session_pool.py: 99%
1001 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""
3MCP Session Pool Implementation.
5Provides session pooling for MCP ClientSessions to reduce per-request overhead.
6Sessions are isolated per user/tenant via identity hashing to prevent session collision.
8Performance Impact:
9 - Without pooling: 20-23ms per tool call (new session each time)
10 - With pooling: 1-2ms per tool call (10-20x improvement)
12Security:
13 Sessions are isolated by (url, identity_hash, transport_type) to prevent:
14 - Cross-user session sharing
15 - Cross-tenant data leakage
16 - Authentication bypass
18Copyright 2026
19SPDX-License-Identifier: Apache-2.0
20Authors: Mihai Criveti
21"""
23# flake8: noqa: DAR101, DAR201, DAR401
25# Future
26from __future__ import annotations
28# Standard
29import asyncio
30from contextlib import asynccontextmanager
31from dataclasses import dataclass, field
32from enum import Enum
33import hashlib
34import logging
35import os
36import re
37import socket
38import time
39from typing import Any, Callable, Dict, Optional, Set, Tuple, TYPE_CHECKING
40import uuid
42# Third-Party
43import anyio
44import httpx
45from mcp import ClientSession, McpError
46from mcp.client.sse import sse_client
47from mcp.client.streamable_http import streamablehttp_client
48from mcp.shared.session import RequestResponder
49import mcp.types as mcp_types
50import orjson
52# First-Party
53from mcpgateway.common.validators import SecurityValidator
54from mcpgateway.config import settings
55from mcpgateway.utils.internal_http import internal_loopback_base_url, internal_loopback_verify
56from mcpgateway.utils.url_auth import sanitize_url_for_logging
# JSON-RPC 2.0 reserved error code meaning "Method not found".
METHOD_NOT_FOUND = -32601

# Shared session-id validation (downstream MCP session IDs used for affinity).
# Intentionally strict: protects Redis key/channel construction and log lines.
# ``\Z`` (not ``$``) anchors at the true end of the string: ``$`` also matches
# just before a trailing newline, so "abc\n" would otherwise be accepted.
_MCP_SESSION_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,128}\Z")

# Worker ID for multi-worker session affinity.
# Hostname + PID keeps it unique across Docker containers (each container may
# run its main process as PID 1) and across gunicorn workers in one container.
WORKER_ID = f"{socket.gethostname()}:{os.getpid()}"
def _get_cleanup_timeout() -> float:
    """Return the session cleanup timeout from configuration.

    This timeout bounds how long to wait for session/transport ``__aexit__``
    calls when closing sessions. It prevents CPU spin loops when internal tasks
    don't respond to cancellation (anyio's ``_deliver_cancellation`` issue).

    The value is read from the module-level ``settings`` object on every call,
    so it is not captured once at import time.

    Returns:
        Cleanup timeout in seconds; 5.0 if the setting cannot be read.
    """
    try:
        return settings.mcp_session_pool_cleanup_timeout
    except Exception:  # Missing/misconfigured setting: fall back to the default.
        return 5.0
if TYPE_CHECKING:  # Evaluated only by static type checkers, never at runtime.
    # Standard
    from collections.abc import AsyncIterator  # pragma: no cover

# Module logger, named after this module so handlers/filters can target it.
logger: logging.Logger = logging.getLogger(__name__)
class TransportType(Enum):
    """Transports over which an upstream MCP server can be reached.

    The member values are the strings embedded in pool keys and Redis keys.
    """

    SSE = "sse"  # Server-Sent Events
    STREAMABLE_HTTP = "streamablehttp"  # Streamable HTTP
@dataclass(eq=False)  # keep identity-based __eq__/__hash__ so instances can live in sets
class PooledSession:
    """An MCP session held by the pool, plus the metadata used to manage it.

    ``eq=False`` leaves the default object-identity ``__eq__``/``__hash__`` in
    place. That is required because active sessions are tracked in sets, and
    two sessions with identical fields must still be distinct members.
    """

    session: ClientSession
    transport_context: Any  # Transport context manager, kept open while the session lives
    url: str
    transport_type: TransportType
    headers: Dict[str, str]  # Original request headers (kept for reconnection)
    identity_key: str  # Identity-hash component derived from the headers
    user_identity: str = "anonymous"  # Owning user, for isolation
    gateway_id: str = ""  # Gateway that notifications are attributed to
    created_at: float = field(default_factory=time.time)
    last_used: float = field(default_factory=time.time)
    use_count: int = 0
    _closed: bool = field(default=False, repr=False)
    _owner_task: Optional[asyncio.Task] = field(default=None, repr=False)
    _shutdown_event: Optional[asyncio.Event] = field(default=None, repr=False)

    @property
    def age_seconds(self) -> float:
        """Wall-clock seconds elapsed since the session was created.

        Returns:
            float: ``time.time() - created_at``.
        """
        return time.time() - self.created_at

    @property
    def idle_seconds(self) -> float:
        """Wall-clock seconds elapsed since the session was last used.

        Returns:
            float: ``time.time() - last_used``.
        """
        return time.time() - self.last_used

    @property
    def is_closed(self) -> bool:
        """Report whether the session is unusable.

        The session counts as closed when any of these holds:

        * it was explicitly marked closed;
        * its background owner task has finished;
        * the MCP write stream looks broken, for example after a server restart
          or network drop. This lets callers discard it before hitting
          ClosedResourceError.

        The write-stream probe reads private MCP/anyio attributes via
        ``getattr`` with defaults. Any error while probing is treated as
        "not broken", so the check degrades gracefully if those internals change.

        Returns:
            bool: True if the session should not be used, False otherwise.
        """
        if self._closed:
            return True
        task = self._owner_task
        if task is not None and task.done():
            return True
        try:
            stream = getattr(self.session, "_write_stream", None)
            if stream is None:
                return False
            if getattr(stream, "_closed", False) is True:
                return True
            stream_state = getattr(stream, "_state", None)
            if stream_state is None:
                return False
            receivers = getattr(stream_state, "open_receive_channels", 1)
            # No open receivers on the other end means nothing will read our writes.
            return isinstance(receivers, int) and receivers == 0
        except Exception:  # nosec B110 - Graceful degradation if MCP internals change
            return False

    def mark_closed(self) -> None:
        """Flag this session as closed; ``is_closed`` reports True from now on."""
        self._closed = True

    @property
    def owner_task(self) -> Optional[asyncio.Task]:
        """The background task that owns this session, or None if there is none."""
        return self._owner_task

    @property
    def shutdown_event(self) -> Optional[asyncio.Event]:
        """The event used to signal the owner task to shut down, or None."""
        return self._shutdown_event
# Type aliases
# Pool key: (user_identity_hash, url, identity_hash, transport_type, gateway_id).
# transport_type and gateway_id are part of the key so the same URL never returns
# a session on the wrong transport, and notifications are attributed to the
# right gateway when notifications are enabled.
PoolKey = Tuple[str, str, str, str, str]

# Session-affinity mapping key: (mcp_session_id, url, transport_type, gateway_id)
SessionMappingKey = Tuple[str, str, str, str]

# Builds an httpx client from (headers, timeout, auth).
HttpxClientFactory = Callable[
    [Optional[Dict[str, str]], Optional[httpx.Timeout], Optional[httpx.Auth]],
    httpx.AsyncClient,
]

# Derives a stable identity from request headers (e.g. decode a JWT to get the
# user id). A falsy result makes the pool fall back to header-based hashing.
IdentityExtractor = Callable[[Dict[str, str]], Optional[str]]

# A per-session message handler. It receives server-request responders, server
# notifications, or exceptions, and returns a coroutine.
_PooledSessionMessageHandler = Callable[
    [RequestResponder[mcp_types.ServerRequest, mcp_types.ClientResult] | mcp_types.ServerNotification | Exception],
    Any,  # Coroutine
]

# Called with (url, gateway_id) to build the message handler for each new session.
MessageHandlerFactory = Callable[[str, Optional[str]], _PooledSessionMessageHandler]
220class MCPSessionPool: # pylint: disable=too-many-instance-attributes
221 """
222 Pool of MCP ClientSessions keyed by (user_identity, server URL, identity hash, transport type, gateway_id).
224 Thread-Safety:
225 This pool is designed for asyncio concurrency. It uses asyncio.Lock
226 for synchronization, which is safe for coroutine-based concurrency
227 but NOT for multi-threaded access.
229 Session Isolation:
230 Sessions are isolated per user/tenant to prevent session collision.
231 The identity hash is derived from authentication headers ensuring
232 that different users never share MCP sessions.
234 Transport Isolation:
235 Sessions are also isolated by transport type (SSE vs STREAMABLE_HTTP).
236 The same URL with different transports will use separate pools.
238 Gateway Isolation:
239 Sessions are isolated by gateway_id for correct notification attribution.
240 When notifications are enabled, each gateway gets its own pooled sessions
241 even if they share the same URL and authentication.
243 Features:
244 - Session reuse across requests (10-20x latency improvement)
245 - Per-user/tenant session isolation (prevents session collision)
246 - Per-transport session isolation (prevents transport mismatch)
247 - TTL-based expiration with configurable lifetime
248 - Health checks on acquire for stale sessions
249 - Configurable pool size per URL+identity+transport
250 - Circuit breaker for failing endpoints
251 - Idle pool key eviction to prevent unbounded growth
252 - Custom identity extractor for rotating tokens (e.g., JWT decode)
253 - Metrics for monitoring (hits, misses, evictions)
254 - Graceful shutdown with close_all()
256 Usage:
257 pool = MCPSessionPool()
259 # Use as context manager for lifecycle management
260 async with pool:
261 pooled = await pool.acquire(url, headers)
262 try:
263 result = await pooled.session.call_tool("my_tool", {})
264 finally:
265 await pool.release(pooled)
267 # With custom identity extractor for JWT tokens:
268 def extract_user_id(headers: dict) -> str:
269 token = headers.get("Authorization", "").replace("Bearer ", "")
270 claims = jwt.decode(token, options={"verify_signature": False})
271 return claims.get("sub") or claims.get("user_id")
273 pool = MCPSessionPool(identity_extractor=extract_user_id)
274 """
276 # Headers that contribute to session identity (case-insensitive)
277 DEFAULT_IDENTITY_HEADERS: frozenset[str] = frozenset(
278 [
279 "authorization",
280 "x-tenant-id",
281 "x-user-id",
282 "x-api-key",
283 "cookie",
284 "x-mcp-session-id",
285 ]
286 )
288 def __init__(
289 self,
290 max_sessions_per_key: int = 10,
291 session_ttl_seconds: float = 300.0,
292 health_check_interval_seconds: float = 60.0,
293 acquire_timeout_seconds: float = 30.0,
294 session_create_timeout_seconds: float = 30.0,
295 circuit_breaker_threshold: int = 5,
296 circuit_breaker_reset_seconds: float = 60.0,
297 identity_headers: Optional[frozenset[str]] = None,
298 identity_extractor: Optional[IdentityExtractor] = None,
299 idle_pool_eviction_seconds: float = 600.0,
300 default_transport_timeout_seconds: float = 30.0,
301 health_check_methods: Optional[list[str]] = None,
302 health_check_timeout_seconds: float = 5.0,
303 message_handler_factory: Optional[MessageHandlerFactory] = None,
304 ):
305 """
306 Initialize the session pool.
308 Args:
309 max_sessions_per_key: Maximum pooled sessions per (URL, identity, transport).
310 session_ttl_seconds: Session TTL in seconds before forced expiration.
311 health_check_interval_seconds: Seconds of idle time before health check.
312 acquire_timeout_seconds: Timeout for waiting when pool is exhausted.
313 session_create_timeout_seconds: Timeout for creating new sessions.
314 circuit_breaker_threshold: Consecutive failures before circuit opens.
315 circuit_breaker_reset_seconds: Seconds before circuit breaker resets.
316 identity_headers: Headers that contribute to identity hash.
317 identity_extractor: Optional callback to extract stable identity from headers.
318 Use this when tokens rotate frequently (e.g., short-lived JWTs).
319 Should return a stable user/tenant ID string.
320 idle_pool_eviction_seconds: Evict empty pool keys after this many seconds of no use.
321 default_transport_timeout_seconds: Default timeout for transport connections.
322 health_check_methods: Ordered list of health check methods to try.
323 Options: ping, list_tools, list_prompts, list_resources, skip.
324 Default: ["ping", "skip"] (try ping, skip if unsupported).
325 health_check_timeout_seconds: Timeout for each health check attempt.
326 message_handler_factory: Optional factory for creating message handlers.
327 Called with (url, gateway_id) to create handlers for
328 each new session. Enables notification handling.
329 """
330 # Configuration
331 self._max_sessions = max_sessions_per_key
332 self._session_ttl = session_ttl_seconds
333 self._health_check_interval = health_check_interval_seconds
334 self._acquire_timeout = acquire_timeout_seconds
335 self._session_create_timeout = session_create_timeout_seconds
336 self._circuit_breaker_threshold = circuit_breaker_threshold
337 self._circuit_breaker_reset = circuit_breaker_reset_seconds
338 self._identity_headers = identity_headers or self.DEFAULT_IDENTITY_HEADERS
339 self._identity_extractor = identity_extractor
340 self._idle_pool_eviction = idle_pool_eviction_seconds
341 self._default_transport_timeout = default_transport_timeout_seconds
342 self._health_check_methods = health_check_methods or ["ping", "skip"]
343 self._health_check_timeout = health_check_timeout_seconds
344 self._message_handler_factory = message_handler_factory
346 # State - protected by _global_lock for creation, per-key locks for access
347 self._global_lock = asyncio.Lock()
348 self._pools: Dict[PoolKey, asyncio.Queue[PooledSession]] = {}
349 self._active: Dict[PoolKey, Set[PooledSession]] = {}
350 self._locks: Dict[PoolKey, asyncio.Lock] = {}
351 self._semaphores: Dict[PoolKey, asyncio.Semaphore] = {}
352 self._pool_last_used: Dict[PoolKey, float] = {} # Track last use time per pool key
354 # Circuit breaker state
355 self._failures: Dict[str, int] = {} # url -> consecutive failure count
356 self._circuit_open_until: Dict[str, float] = {} # url -> timestamp
358 # Eviction throttling - only run eviction once per interval
359 self._last_eviction_run: float = 0.0
360 self._eviction_run_interval: float = 60.0 # Run eviction at most every 60 seconds
362 # Metrics
363 self._hits = 0
364 self._misses = 0
365 self._evictions = 0
366 self._health_check_failures = 0
367 self._circuit_breaker_trips = 0
368 self._pool_keys_evicted = 0
369 self._sessions_reaped = 0 # Sessions closed during background eviction
370 self._anonymous_identity_count = 0 # Count of requests with no identity headers
372 # Lifecycle
373 self._closed = False
375 # Pre-registered session mappings for session affinity
376 # Mapping from (mcp_session_id, url, transport_type, gateway_id) -> pool_key
377 # Set by broadcast() before acquire() is called to enable session affinity lookup
378 self._mcp_session_mapping: Dict[SessionMappingKey, PoolKey] = {}
379 self._mcp_session_mapping_lock = asyncio.Lock()
381 # Multi-worker session affinity via Redis pub/sub
382 # Track pending responses for forwarded RPC requests
383 self._rpc_listener_task: Optional[asyncio.Task[None]] = None
384 self._heartbeat_task: Optional[asyncio.Task[None]] = None
386 # Session affinity metrics
387 self._session_affinity_local_hits = 0
388 self._session_affinity_redis_hits = 0
389 self._session_affinity_misses = 0
390 self._forwarded_requests = 0
391 self._forwarded_request_failures = 0
392 self._forwarded_request_timeouts = 0
394 async def __aenter__(self) -> "MCPSessionPool":
395 """Async context manager entry.
397 Returns:
398 MCPSessionPool: This pool instance.
399 """
400 self.start_heartbeat()
401 return self
403 async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
404 """Async context manager exit - closes all sessions.
406 Args:
407 exc_type: Exception type if an exception was raised.
408 exc_val: Exception value if an exception was raised.
409 exc_tb: Exception traceback if an exception was raised.
410 """
411 await self.close_all()
413 def _compute_identity_hash(self, headers: Optional[Dict[str, str]]) -> str:
414 """
415 Compute a hash of identity-relevant headers.
417 This ensures sessions are isolated per user/tenant. Different users
418 with different Authorization headers will have different identity hashes
419 and thus separate session pools.
421 Identity resolution order:
422 1. Custom identity_extractor (if configured) - for rotating tokens like JWTs
423 2. x-mcp-session-id header (if present) - for session affinity, ensures
424 requests with the same downstream session ID get the same upstream
425 session even when JWT tokens rotate (different jti values)
426 3. Configured identity headers - fallback to hashing all identity headers
428 Args:
429 headers: Request headers dict.
431 Returns:
432 Identity hash string, or "anonymous" if no identity headers present.
433 """
434 if not headers:
435 self._anonymous_identity_count += 1
436 logger.debug("Session pool identity collapsed to 'anonymous' (no headers provided). " + "Sessions will be shared. Ensure this is intentional for stateless MCP servers.")
437 return "anonymous"
439 # Try custom identity extractor first (for rotating tokens like JWTs)
440 if self._identity_extractor:
441 try:
442 extracted = self._identity_extractor(headers)
443 if extracted:
444 return hashlib.sha256(extracted.encode()).hexdigest()
445 except Exception as e:
446 logger.debug(f"Identity extractor failed, falling back to header hash: {e}")
448 # Normalize headers for case-insensitive lookup
449 headers_lower = {k.lower(): v for k, v in headers.items()}
451 # Session affinity: prioritize x-mcp-session-id for stable identity
452 # When present, use ONLY the session ID for identity hash. This ensures
453 # requests with the same downstream session ID get the same upstream session,
454 # even when JWT tokens rotate (different jti values per request).
455 if settings.mcpgateway_session_affinity_enabled:
456 session_id = headers_lower.get("x-mcp-session-id")
457 if session_id:
458 logger.debug(f"Using x-mcp-session-id for session affinity: {session_id[:8]}...")
459 return hashlib.sha256(session_id.encode()).hexdigest()
461 # Fallback: extract identity from configured headers
462 identity_parts = []
464 for header in sorted(self._identity_headers):
465 if header in headers_lower:
466 identity_parts.append(f"{header}:{headers_lower[header]}")
468 if not identity_parts:
469 self._anonymous_identity_count += 1
470 logger.debug(
471 "Session pool identity collapsed to 'anonymous' (no identity headers found). " + "Expected headers: %s. Sessions will be shared.",
472 list(self._identity_headers),
473 )
474 return "anonymous"
476 # Create a stable, deterministic hash using JSON serialization
477 # Prevents delimiter-collision or injection issues present in string joining
478 serialized_identity = orjson.dumps(identity_parts)
479 return hashlib.sha256(serialized_identity).hexdigest()
481 def _make_pool_key(
482 self,
483 url: str,
484 headers: Optional[Dict[str, str]],
485 transport_type: TransportType,
486 user_identity: str,
487 gateway_id: Optional[str] = None,
488 ) -> PoolKey:
489 """Create composite pool key from URL, identity, transport type, user identity, and gateway_id.
491 Including gateway_id ensures correct notification attribution when multiple gateways
492 share the same URL/auth. Sessions are isolated per gateway for proper event routing.
493 """
494 identity_hash = self._compute_identity_hash(headers)
496 # Anonymize user identity by hashing it (unless it's commonly "anonymous")
497 # Use full hash for collision resistance - truncate only for display in logs/metrics
498 if user_identity == "anonymous":
499 user_hash = "anonymous"
500 else:
501 user_hash = hashlib.sha256(user_identity.encode()).hexdigest()
503 # Use empty string for None gateway_id to maintain consistent key type
504 gw_id = gateway_id or ""
506 return (user_hash, url, identity_hash, transport_type.value, gw_id)
508 async def _get_or_create_lock(self, pool_key: PoolKey) -> asyncio.Lock:
509 """Get or create a lock for the given pool key (thread-safe)."""
510 async with self._global_lock:
511 if pool_key not in self._locks:
512 self._locks[pool_key] = asyncio.Lock()
513 return self._locks[pool_key]
515 async def _get_or_create_pool(self, pool_key: PoolKey) -> asyncio.Queue[PooledSession]:
516 """Get or create a pool queue for the given key (thread-safe)."""
517 async with self._global_lock:
518 if pool_key not in self._pools:
519 self._pools[pool_key] = asyncio.Queue(maxsize=self._max_sessions)
520 self._active[pool_key] = set()
521 self._semaphores[pool_key] = asyncio.Semaphore(self._max_sessions)
522 return self._pools[pool_key]
524 def _is_circuit_open(self, url: str) -> bool:
525 """Check if circuit breaker is open for a URL."""
526 if url not in self._circuit_open_until:
527 return False
528 if time.time() >= self._circuit_open_until[url]:
529 # Circuit breaker reset
530 del self._circuit_open_until[url]
531 self._failures[url] = 0
532 logger.info(f"Circuit breaker reset for {sanitize_url_for_logging(url)}")
533 return False
534 return True
536 def _record_failure(self, url: str) -> None:
537 """Record a failure and potentially trip circuit breaker."""
538 self._failures[url] = self._failures.get(url, 0) + 1
539 if self._failures[url] >= self._circuit_breaker_threshold:
540 self._circuit_open_until[url] = time.time() + self._circuit_breaker_reset
541 self._circuit_breaker_trips += 1
542 logger.warning(f"Circuit breaker opened for {sanitize_url_for_logging(url)} after {self._failures[url]} failures. " f"Will reset in {self._circuit_breaker_reset}s")
544 def _record_success(self, url: str) -> None:
545 """Record a success, resetting failure count."""
546 self._failures[url] = 0
548 @staticmethod
549 def is_valid_mcp_session_id(session_id: str) -> bool:
550 """Validate downstream MCP session ID format for affinity.
552 Used for:
553 - Redis key construction (ownership + mapping)
554 - Pub/Sub channel naming
555 - Avoiding log spam / injection
556 """
557 if not session_id:
558 return False
559 return bool(_MCP_SESSION_ID_PATTERN.match(session_id))
561 def _sanitize_redis_key_component(self, value: str) -> str:
562 """Sanitize a value for use in Redis key construction.
564 Replaces any characters that could cause key collision or injection.
566 Args:
567 value: The value to sanitize.
569 Returns:
570 Sanitized value safe for Redis key construction.
571 """
572 if not value:
573 return ""
575 # Replace problematic characters with underscores
576 return re.sub(r"[^a-zA-Z0-9_-]", "_", value)
578 def _session_mapping_redis_key(self, mcp_session_id: str, url: str, transport_type: str, gateway_id: str) -> str:
579 """Compute a bounded Redis key for session mapping.
581 The URL is hashed to keep keys small and avoid special character issues.
582 """
583 sanitized_session_id = self._sanitize_redis_key_component(mcp_session_id)
584 url_hash = hashlib.sha256(url.encode()).hexdigest()[:16]
585 return f"mcpgw:session_mapping:{sanitized_session_id}:{url_hash}:{transport_type}:{gateway_id}"
587 @staticmethod
588 def _pool_owner_key(mcp_session_id: str) -> str:
589 """Return Redis key for session ownership tracking."""
590 return f"mcpgw:pool_owner:{mcp_session_id}"
592 def _worker_heartbeat_key(self) -> str:
593 """Redis key for this worker's heartbeat."""
594 return f"mcpgw:worker_heartbeat:{WORKER_ID}"
596 def start_heartbeat(self) -> None:
597 """Start the worker heartbeat background task.
599 Must be called from an async context. Safe to call multiple times;
600 subsequent calls are no-ops if the heartbeat is already running.
601 """
602 if not settings.mcpgateway_session_affinity_enabled:
603 return
604 if self._heartbeat_task is None or self._heartbeat_task.done():
605 self._heartbeat_task = asyncio.create_task(self._run_heartbeat_loop())
607 async def _run_heartbeat_loop(self) -> None:
608 """Maintain worker heartbeat in Redis."""
609 # First-Party
610 from mcpgateway.utils.redis_client import get_redis_client
612 while not self._closed:
613 try:
614 redis = await get_redis_client()
615 if redis:
616 # Refresh heartbeat with 30s TTL (much shorter than session TTL)
617 await redis.setex(self._worker_heartbeat_key(), 30, "alive")
618 except Exception as e:
619 logger.debug(f"Heartbeat update failed: {e}")
621 await asyncio.sleep(10) # Refresh every 10s
623 async def _is_worker_alive(self, worker_id: str) -> bool:
624 """Check if a worker is alive via heartbeat."""
625 try:
626 # First-Party
627 from mcpgateway.utils.redis_client import get_redis_client
629 redis = await get_redis_client()
630 if not redis:
631 return True # Assume alive if Redis unavailable
633 heartbeat_key = f"mcpgw:worker_heartbeat:{worker_id}"
634 return await redis.exists(heartbeat_key) > 0
635 except Exception:
636 return True # Fail open
638 async def register_session_mapping(
639 self,
640 mcp_session_id: str,
641 url: str,
642 gateway_id: str,
643 transport_type: str,
644 user_email: Optional[str] = None,
645 ) -> None:
646 """Pre-register session mapping for session affinity.
648 Called from respond() to set up mapping BEFORE acquire() is called.
649 This ensures acquire() can find the correct pool key for session affinity.
651 The mapping stores the relationship between an incoming MCP session ID
652 and the pool key that should be used for upstream connections. This
653 enables session affinity even when JWT tokens rotate (different jti values
654 per request).
656 For multi-worker deployments, the mapping is also stored in Redis with TTL
657 so that any worker can look it up during acquire().
659 Args:
660 mcp_session_id: The downstream MCP session ID from x-mcp-session-id header.
661 url: The upstream MCP server URL.
662 gateway_id: The gateway ID.
663 transport_type: The transport type (sse, streamablehttp).
664 user_email: The email of the authenticated user (or "system" for unauthenticated).
665 """
666 if not settings.mcpgateway_session_affinity_enabled:
667 return
669 # Validate mcp_session_id to prevent Redis key injection
670 if not self.is_valid_mcp_session_id(mcp_session_id):
671 logger.warning(f"Invalid mcp_session_id format, skipping session mapping: {mcp_session_id[:20]}...")
672 return
674 # Use user email for user_identity, or "anonymous" if not provided
675 user_identity = user_email or "anonymous"
677 # Normalize gateway_id to empty string if None for consistent key matching
678 normalized_gateway_id = gateway_id or ""
680 mapping_key: SessionMappingKey = (mcp_session_id, url, transport_type, normalized_gateway_id)
682 # Compute what the pool_key will be for this session
683 # Use mcp_session_id as the identity basis for affinity
684 identity_hash = hashlib.sha256(mcp_session_id.encode()).hexdigest()
686 # Hash user identity for privacy (unless it's "anonymous")
687 if user_identity == "anonymous":
688 user_hash = "anonymous"
689 else:
690 user_hash = hashlib.sha256(user_identity.encode()).hexdigest()
692 pool_key: PoolKey = (user_hash, url, identity_hash, transport_type, normalized_gateway_id)
694 # Store in local memory
695 async with self._mcp_session_mapping_lock:
696 self._mcp_session_mapping[mapping_key] = pool_key
697 logger.debug(f"Session affinity pre-registered (local): {mcp_session_id[:8]}... → {url}, user={SecurityValidator.sanitize_log_message(user_identity)}")
699 # Store in Redis for multi-worker support AND register ownership atomically
700 # Registering ownership HERE (during mapping) instead of in acquire() prevents
701 # a race condition where two workers could both start creating sessions before
702 # either registers ownership
703 try:
704 # First-Party
705 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
707 redis = await get_redis_client()
708 if redis:
709 redis_key = self._session_mapping_redis_key(mcp_session_id, url, transport_type, normalized_gateway_id)
711 # Store pool_key as JSON for easy deserialization
712 pool_key_data = {
713 "user_hash": user_hash,
714 "url": url,
715 "identity_hash": identity_hash,
716 "transport_type": transport_type,
717 "gateway_id": normalized_gateway_id,
718 }
719 await redis.setex(redis_key, settings.mcpgateway_session_affinity_ttl, orjson.dumps(pool_key_data)) # TTL from config
721 # CRITICAL: Register ownership atomically with mapping.
722 # This claims ownership BEFORE any session creation attempt, preventing
723 # the race condition where two workers both start creating sessions
724 owner_key = self._pool_owner_key(mcp_session_id)
725 # Atomic claim with TTL (avoids the SETNX/EXPIRE crash window).
726 was_set = await redis.set(owner_key, WORKER_ID, nx=True, ex=settings.mcpgateway_session_affinity_ttl)
727 if was_set:
728 logger.debug(f"Session ownership claimed (SET NX): {mcp_session_id[:8]}... → worker {WORKER_ID}")
729 else:
730 # Another worker already claimed ownership
731 existing_owner = await redis.get(owner_key)
732 owner_id = existing_owner.decode() if isinstance(existing_owner, bytes) else existing_owner
733 logger.debug(f"Session ownership already claimed by {owner_id}: {mcp_session_id[:8]}...")
735 logger.debug(f"Session affinity pre-registered (Redis): {mcp_session_id[:8]}... TTL={settings.mcpgateway_session_affinity_ttl}s")
736 except Exception as e:
737 # Redis failure is non-fatal - local mapping still works for same-worker requests
738 logger.debug(f"Failed to store session mapping in Redis: {e}")
740 async def acquire(
741 self,
742 url: str,
743 headers: Optional[Dict[str, str]] = None,
744 transport_type: TransportType = TransportType.STREAMABLE_HTTP,
745 httpx_client_factory: Optional[HttpxClientFactory] = None,
746 timeout: Optional[float] = None,
747 user_identity: Optional[str] = None,
748 gateway_id: Optional[str] = None,
749 ) -> PooledSession:
750 """
751 Acquire a session for the given URL, identity, and transport type.
753 Sessions are isolated by identity (derived from auth headers) AND
754 transport type. Returns an initialized, healthy session ready for tool calls.
756 Args:
757 url: The MCP server URL.
758 headers: Request headers (used for identity hashing and passed to server).
759 transport_type: The transport type (SSE or STREAMABLE_HTTP).
760 httpx_client_factory: Optional factory for creating httpx clients
761 (for custom SSL/timeout configuration).
762 timeout: Optional timeout in seconds for transport connection.
763 gateway_id: Optional gateway ID for notification handler context.
765 Returns:
766 PooledSession ready for use.
768 Raises:
769 asyncio.TimeoutError: If acquire times out waiting for available session.
770 RuntimeError: If pool is closed or circuit breaker is open.
771 Exception: If session creation fails.
772 """
773 if self._closed:
774 raise RuntimeError("Session pool is closed")
776 if self._is_circuit_open(url):
777 raise RuntimeError(f"Circuit breaker open for {url}")
779 # Use default timeout if not provided
780 effective_timeout = timeout if timeout is not None else self._default_transport_timeout
782 user_id = user_identity or "anonymous"
783 pool_key: Optional[PoolKey] = None
785 # Check pre-registered mapping first (set by respond() for session affinity)
786 if settings.mcpgateway_session_affinity_enabled and headers:
787 headers_lower = {k.lower(): v for k, v in headers.items()}
788 mcp_session_id = headers_lower.get("x-mcp-session-id")
789 if mcp_session_id and self.is_valid_mcp_session_id(mcp_session_id):
790 normalized_gateway_id = gateway_id or ""
791 mapping_key: SessionMappingKey = (mcp_session_id, url, transport_type.value, normalized_gateway_id)
793 # Check local memory first (fast path - same worker)
794 async with self._mcp_session_mapping_lock:
795 pool_key = self._mcp_session_mapping.get(mapping_key)
796 if pool_key:
797 self._session_affinity_local_hits += 1
798 logger.debug(f"Session affinity hit (local): {mcp_session_id[:8]}...")
800 # If not in local memory, check Redis (multi-worker support)
801 if pool_key is None:
802 try:
803 # First-Party
804 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel
806 redis = await get_redis_client()
807 if redis:
808 redis_key = self._session_mapping_redis_key(mcp_session_id, url, transport_type.value, normalized_gateway_id)
809 pool_key_data = await redis.get(redis_key)
810 if pool_key_data:
811 # Deserialize pool_key from JSON
812 data = orjson.loads(pool_key_data)
813 pool_key = (
814 data["user_hash"],
815 data["url"],
816 data["identity_hash"],
817 data["transport_type"],
818 data["gateway_id"],
819 )
820 # Cache in local memory for future requests
821 async with self._mcp_session_mapping_lock:
822 self._mcp_session_mapping[mapping_key] = pool_key
823 self._session_affinity_redis_hits += 1
824 logger.debug(f"Session affinity hit (Redis): {mcp_session_id[:8]}...")
825 except Exception as e:
826 logger.debug(f"Failed to check Redis for session mapping: {e}")
828 # Fallback to normal pool key computation
829 if pool_key is None:
830 self._session_affinity_misses += 1
831 pool_key = self._make_pool_key(url, headers, transport_type, user_id, gateway_id)
833 pool = await self._get_or_create_pool(pool_key)
835 # Update pool key last used time IMMEDIATELY after getting pool
836 # This prevents race with eviction removing keys between awaits
837 self._pool_last_used[pool_key] = time.time()
839 lock = await self._get_or_create_lock(pool_key)
841 # Guard semaphore access - eviction may have removed it between awaits
842 # If so, re-create the pool structures
843 if pool_key not in self._semaphores:
844 pool = await self._get_or_create_pool(pool_key)
845 self._pool_last_used[pool_key] = time.time()
847 semaphore = self._semaphores[pool_key]
849 # Throttled eviction - only run if enough time has passed (inline, not spawned)
850 await self._maybe_evict_idle_pool_keys()
852 # Try to get from pool first (quick path, no lock needed for queue get)
853 while True:
854 try:
855 pooled = pool.get_nowait()
856 except asyncio.QueueEmpty:
857 break
859 # Validate the session outside the lock
860 if await self._validate_session(pooled):
861 pooled.last_used = time.time()
862 pooled.use_count += 1
863 self._hits += 1
864 async with lock:
865 self._active[pool_key].add(pooled)
866 logger.debug(f"Pool hit for {sanitize_url_for_logging(url)} (identity={pool_key[2][:8]}, transport={transport_type.value})")
867 return pooled
869 # Session invalid, close it
870 await self._close_session(pooled)
871 self._evictions += 1
872 semaphore.release() # Free up a slot
874 # No valid session in pool - try to create one or wait
875 try:
876 # Use semaphore with timeout to limit concurrent sessions
877 acquired = await asyncio.wait_for(semaphore.acquire(), timeout=self._acquire_timeout)
878 if not acquired:
879 raise asyncio.TimeoutError("Failed to acquire session slot")
880 except asyncio.TimeoutError:
881 raise asyncio.TimeoutError(f"Timeout waiting for available session for {sanitize_url_for_logging(url)}") from None
883 # Create new session (semaphore acquired)
884 try:
885 # Verify we own this session before creating (prevents race condition)
886 # If another worker already claimed ownership, we should not create a new session
887 # Note: Ownership is registered atomically in register_session_mapping() using SETNX
888 if settings.mcpgateway_session_affinity_enabled and headers:
889 headers_lower = {k.lower(): v for k, v in headers.items()}
890 mcp_session_id = headers_lower.get("x-mcp-session-id")
891 if mcp_session_id and self.is_valid_mcp_session_id(mcp_session_id):
892 owner = await self._get_pool_session_owner(mcp_session_id)
893 if owner and owner != WORKER_ID:
894 # Check if owner is still alive
895 if not await self._is_worker_alive(owner):
896 # Owner is dead - reclaim ownership with compare-and-swap
897 # First-Party
898 from mcpgateway.utils.redis_client import get_redis_client
900 redis = await get_redis_client()
901 if redis:
902 owner_key = self._pool_owner_key(mcp_session_id)
903 # Lua CAS: only reclaim if still owned by the dead worker
904 cas_script = """
905 local cur = redis.call('GET', KEYS[1])
906 if cur == ARGV[1] then
907 redis.call('SET', KEYS[1], ARGV[2], 'EX', ARGV[3])
908 return 1
909 end
910 return 0
911 """
912 ttl = int(settings.mcpgateway_session_affinity_ttl)
913 reclaimed = await redis.eval(cas_script, 1, owner_key, owner, WORKER_ID, ttl)
914 if reclaimed == 1:
915 logger.info(f"Reclaimed ownership from dead worker {owner}: {mcp_session_id[:8]}...")
916 else:
917 # Another worker already reclaimed - let it handle
918 # (outer except BaseException releases the semaphore)
919 raise RuntimeError(f"Session reclaimed by another worker: {mcp_session_id[:8]}...")
920 else:
921 # Owner is alive - should have been forwarded
922 # (outer except BaseException releases the semaphore)
923 logger.warning(f"Session {mcp_session_id[:8]}... owned by worker {owner}, not us ({WORKER_ID})")
924 raise RuntimeError(f"Session owned by another worker: {owner}")
926 pooled = await asyncio.wait_for(
927 self._create_session(url, headers, transport_type, httpx_client_factory, effective_timeout, gateway_id),
928 timeout=self._session_create_timeout,
929 )
930 # Store identity components for key reconstruction
931 pooled.identity_key = pool_key[2]
932 pooled.user_identity = user_id
934 # Note: Ownership is now registered atomically in register_session_mapping()
935 # before acquire() is called, so we don't need to register it here
937 self._misses += 1
938 self._record_success(url)
939 async with lock:
940 self._active[pool_key].add(pooled)
941 logger.debug(f"Pool miss for {sanitize_url_for_logging(url)} - created new session (transport={transport_type.value})")
942 return pooled
943 except BaseException as e:
944 # Release semaphore on ANY failure (including CancelledError)
945 semaphore.release()
946 if not isinstance(e, asyncio.CancelledError):
947 self._record_failure(url)
948 logger.warning(f"Failed to create session for {sanitize_url_for_logging(url)}: {e}")
949 raise
async def release(self, pooled: PooledSession, *, discard: bool = False) -> None:
    """
    Return a session to the pool for reuse, or discard it.

    The pool key is rebuilt from the identity stored on *pooled*. Bookkeeping
    (last-used timestamp, active set, idle queue) is updated while holding the
    per-key lock. Closing a session is I/O, so it happens after that lock is
    released, mirroring how ``_maybe_evict_idle_pool_keys`` closes sessions
    outside its lock.

    The semaphore slot held by a session that is closed here is released in a
    ``finally``. A cancellation while closing (for example a request torn down
    while its handler is releasing) therefore cannot leak the slot and
    permanently shrink the pool's capacity.

    Args:
        pooled: The session to release.
        discard: If True, close the session instead of returning it to the
            pool. Used when the caller detected a transport error
            (e.g. ``ClosedResourceError``) to prevent recycling a
            broken session. Already-closed sessions are always discarded.
    """
    # Treat already-closed sessions (e.g. dead owner task) as discards —
    # still need to remove from _active and release the semaphore slot.
    if pooled.is_closed:
        discard = True

    # Pool key includes transport type, user identity, and gateway_id.
    # Re-compute the user hash from the stored raw identity (full hash for collision resistance).
    user_hash = "anonymous"
    if pooled.user_identity != "anonymous":
        user_hash = hashlib.sha256(pooled.user_identity.encode()).hexdigest()

    pool_key = (user_hash, pooled.url, pooled.identity_key, pooled.transport_type.value, pooled.gateway_id)
    lock = await self._get_or_create_lock(pool_key)
    pool = await self._get_or_create_pool(pool_key)

    count_eviction = False
    async with lock:
        # Update last-used FIRST to prevent an eviction race: if eviction ran
        # between removing from _active and putting back in the pool, it would
        # see the key as idle + inactive and evict it. Refreshing last-used
        # before touching _active makes eviction see recent activity.
        self._pool_last_used[pool_key] = time.time()
        self._active.get(pool_key, set()).discard(pooled)

        if discard:
            # Broken sessions are never recycled.
            logger.debug(f"Discarding broken session for {sanitize_url_for_logging(pooled.url)}")
            count_eviction = True
        else:
            # Read the age once so the close decision and the eviction counter agree.
            expired = pooled.age_seconds > self._session_ttl
            if self._closed or expired:
                count_eviction = expired
            else:
                # Return to pool (pool may have been evicted in an edge case, recreate if needed).
                if pool_key not in self._pools:
                    pool = await self._get_or_create_pool(pool_key)
                    self._pool_last_used[pool_key] = time.time()
                try:
                    pool.put_nowait(pooled)
                except asyncio.QueueFull:
                    # Pool full (shouldn't happen with semaphore): fall through and close it.
                    pass
                else:
                    logger.debug(f"Session returned to pool for {sanitize_url_for_logging(pooled.url)}")
                    return

    # The session is no longer tracked anywhere, so close it without holding the per-key lock.
    try:
        await self._close_session(pooled)
    finally:
        # Always give the slot back, even if closing was cancelled or failed.
        semaphore = self._semaphores.get(pool_key)
        if semaphore is not None:
            semaphore.release()

    if count_eviction:
        self._evictions += 1
async def _maybe_evict_idle_pool_keys(self) -> None:
    """
    Reap stale sessions and evict idle pool keys.

    This method is throttled: it only runs eviction if enough time has
    passed since the last run (``self._eviction_run_interval``). This prevents:
    - Unbounded task spawning on every acquire
    - Lock contention under high load

    Two-phase cleanup:
    1. Close expired/stale sessions parked in idle pools (frees connections).
    2. Evict pool keys that are now empty and have no active sessions, and
       drop session-affinity mappings that point at them.

    This prevents unbounded connection and pool key growth when using
    rotating tokens (e.g., short-lived JWTs with unique identifiers). Because
    that scenario can produce many keys and many mappings, stale mappings are
    removed in a single pass over the mapping table rather than one pass per
    evicted key.
    """
    if self._closed:
        return

    now = time.time()

    # Throttle: only run eviction once per interval
    if now - self._last_eviction_run < self._eviction_run_interval:
        return

    self._last_eviction_run = now

    # Collect sessions to close and keys to evict (minimize time holding lock)
    sessions_to_close: list[PooledSession] = []
    keys_to_evict: list[PoolKey] = []

    async with self._global_lock:
        for pool_key, last_used in list(self._pool_last_used.items()):
            # Skip recently-used pools
            if now - last_used < self._idle_pool_eviction:
                continue

            # Skip if there are active sessions (in use)
            if self._active.get(pool_key):
                continue

            pool = self._pools.get(pool_key)
            if pool is None:
                continue

            # Phase 1: drain expired/stale sessions from the idle queue
            while not pool.empty():
                try:
                    session = pool.get_nowait()
                except asyncio.QueueEmpty:
                    break
                # Close if expired OR idle too long (defense in depth)
                if session.age_seconds > self._session_ttl or session.idle_seconds > self._idle_pool_eviction:
                    sessions_to_close.append(session)
                    # Release the semaphore slot this idle session occupied
                    semaphore = self._semaphores.get(pool_key)
                    if semaphore is not None:
                        semaphore.release()
                else:
                    # Session still valid: put it back and stop draining
                    pool.put_nowait(session)
                    break

            # Phase 2: evict the pool key if it is now empty
            if pool.empty():
                keys_to_evict.append(pool_key)

        # Remove evicted keys from all tracking dicts
        for pool_key in keys_to_evict:
            self._pools.pop(pool_key, None)
            self._active.pop(pool_key, None)
            self._locks.pop(pool_key, None)
            self._semaphores.pop(pool_key, None)
            self._pool_last_used.pop(pool_key, None)
            self._pool_keys_evicted += 1
            # URLs are sanitized like everywhere else in this module (may carry credentials).
            logger.debug(f"Evicted idle pool key: {pool_key[0][:8]}|{sanitize_url_for_logging(pool_key[1])}|{pool_key[2][:8]}")

        # Clean up session mappings pointing to any evicted pool key, in one pass
        # over the mapping table regardless of how many keys were evicted.
        if keys_to_evict:
            evicted = set(keys_to_evict)
            async with self._mcp_session_mapping_lock:
                stale_mappings = [k for k, v in self._mcp_session_mapping.items() if v in evicted]
                for mapping_key in stale_mappings:
                    self._mcp_session_mapping.pop(mapping_key, None)

    # Close sessions outside the lock (I/O operations)
    for session in sessions_to_close:
        await self._close_session(session)
        self._sessions_reaped += 1
        logger.debug(f"Reaped stale session for {sanitize_url_for_logging(session.url)} (age={session.age_seconds:.1f}s)")
async def _validate_session(self, pooled: PooledSession) -> bool:
    """
    Decide whether a pooled session may be handed out again.

    A session is rejected outright when it is already closed or older than
    ``self._session_ttl``. A session that has sat idle for longer than
    ``self._health_check_interval`` must additionally pass the health-check
    chain. Otherwise it is accepted without contacting the server.

    Args:
        pooled: The session to validate.

    Returns:
        True if session is valid, False otherwise.
    """
    if pooled.is_closed:
        return False

    # Hard lifetime limit, regardless of health.
    if pooled.age_seconds > self._session_ttl:
        logger.debug(f"Session expired (age={pooled.age_seconds:.1f}s)")
        return False

    # Only sessions that have been idle for a while are probed upstream.
    needs_probe = pooled.idle_seconds > self._health_check_interval
    return await self._run_health_check_chain(pooled) if needs_probe else True
async def _run_health_check_chain(self, pooled: PooledSession) -> bool:
    """
    Probe a pooled session with the configured health-check methods, in order.

    Each entry of ``self._health_check_methods`` is tried until one yields a
    verdict:

    * a probe (``ping``/``list_tools``/``list_prompts``/``list_resources``)
      that completes within ``self._health_check_timeout`` -> healthy;
    * ``"skip"`` -> healthy without contacting the server;
    * ``McpError`` with METHOD_NOT_FOUND, a timeout, or an unknown method
      name -> move on to the next entry;
    * any other ``McpError`` or exception -> unhealthy.

    Exhausting the list without a verdict also means unhealthy. Every
    unhealthy verdict increments ``self._health_check_failures``.

    Args:
        pooled: The session to health check.

    Returns:
        True if any health check method succeeds, False if all fail.
    """
    # Configured method name -> ClientSession coroutine used as the probe.
    probe_attrs = {
        "ping": "send_ping",
        "list_tools": "list_tools",
        "list_prompts": "list_prompts",
        "list_resources": "list_resources",
    }

    for method in self._health_check_methods:
        try:
            probe_attr = probe_attrs.get(method)
            if probe_attr is not None:
                with anyio.fail_after(self._health_check_timeout):
                    await getattr(pooled.session, probe_attr)()
                logger.debug(f"Health check passed: {method} (url={sanitize_url_for_logging(pooled.url)})")
                return True
            if method == "skip":
                logger.debug(f"Health check skipped per configuration (url={sanitize_url_for_logging(pooled.url)})")
                return True
            logger.warning(f"Unknown health check method '{method}', skipping")
        except McpError as e:
            if e.error.code != METHOD_NOT_FOUND:
                # A real MCP-level failure, not just an unsupported method.
                logger.debug(f"Health check '{method}' failed with MCP error: {e}")
                self._health_check_failures += 1
                return False
            logger.debug(f"Health check method '{method}' not supported by server, trying next")
        except TimeoutError:
            logger.debug(f"Health check '{method}' timed out after {self._health_check_timeout}s, trying next")
        except Exception as e:
            logger.debug(f"Health check '{method}' failed: {e}")
            self._health_check_failures += 1
            return False

    # Nothing in the chain produced a verdict.
    logger.warning(f"All health check methods failed or unsupported (methods={self._health_check_methods})")
    self._health_check_failures += 1
    return False
async def _session_owner_coro(
    self,
    url: str,
    transport_type: TransportType,
    merged_headers: Dict[str, str],
    httpx_client_factory: Optional[HttpxClientFactory],
    timeout: Optional[float],
    gateway_id: Optional[str],
    ready_future: "asyncio.Future[Tuple[ClientSession, Any]]",
    shutdown_event: asyncio.Event,
) -> None:
    """Background task that owns the transport and session lifecycle.

    Runs transport and session inside proper ``async with`` blocks so that
    anyio cancel scopes are bound to THIS task, not the request handler.
    Signals readiness via *ready_future*, then blocks on *shutdown_event*
    until the pool requests cleanup.

    Failures never propagate out of this task:

    * Before readiness, the failure is delivered to the creator through
      *ready_future* as a ``RuntimeError``. Its message uses a
      credential-sanitized URL, because callers log the exception text.
    * After readiness, non-cancellation errors are logged at debug level and
      the task simply ends. The pool detects this via the owner task being done.

    Args:
        url: Upstream MCP server URL.
        transport_type: Which MCP client transport to open.
        merged_headers: Headers to send upstream.
        httpx_client_factory: Optional httpx client factory; only forwarded when set.
        timeout: Transport connection timeout in seconds.
        gateway_id: Gateway ID passed to the message handler factory.
        ready_future: Resolved with ``(session, transport_ctx)`` once initialized,
            or with the creation error.
        shutdown_event: Set by the pool to request a graceful shutdown.
    """
    try:
        # Both transports accept the same keyword arguments; pass httpx_client_factory
        # only when provided so the library default is used otherwise.
        transport_kwargs: Dict[str, Any] = {"url": url, "headers": merged_headers, "timeout": timeout}
        if httpx_client_factory:
            transport_kwargs["httpx_client_factory"] = httpx_client_factory
        open_transport = sse_client if transport_type == TransportType.SSE else streamablehttp_client
        transport_ctx = open_transport(**transport_kwargs)

        async with transport_ctx as streams:
            # Both transports yield the read/write streams as their first two items.
            read_stream, write_stream = streams[0], streams[1]

            # Create message handler if factory is configured (failure is non-fatal)
            message_handler = None
            if self._message_handler_factory:
                try:
                    message_handler = self._message_handler_factory(url, gateway_id)
                    logger.debug(f"Created message handler for session {sanitize_url_for_logging(url)} (gateway={SecurityValidator.sanitize_log_message(gateway_id)})")
                except Exception as e:
                    logger.warning(f"Failed to create message handler for {sanitize_url_for_logging(url)}: {e}")

            async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
                await session.initialize()
                # Signal the session is ready for use (the creator may already have given up)
                if not ready_future.done():
                    ready_future.set_result((session, transport_ctx))
                logger.info(f"Created new MCP session for {sanitize_url_for_logging(url)} (transport={transport_type.value})")
                # Block here until the pool asks us to shut down; leaving the async-with
                # blocks then unwinds the session and the transport inside this task.
                await shutdown_event.wait()

    except BaseException as exc:  # pylint: disable=broad-exception-caught
        safe_url = sanitize_url_for_logging(url)
        if not ready_future.done():
            # The creator logs this exception's text, so it must not contain URL credentials.
            ready_future.set_exception(RuntimeError(f"Failed to create MCP session for {safe_url}: {exc}"))
        elif not isinstance(exc, asyncio.CancelledError):
            # Make unexpected transport/session failures after initialization visible.
            logger.debug(f"MCP session owner task for {safe_url} ended with error: {exc}")
        # Let the task finish; the pool will detect _owner_task.done()
async def _create_session(
    self,
    url: str,
    headers: Optional[Dict[str, str]],
    transport_type: TransportType,
    httpx_client_factory: Optional[HttpxClientFactory],
    timeout: Optional[float] = None,
    gateway_id: Optional[str] = None,
) -> PooledSession:
    """
    Open and initialize a new upstream MCP session owned by a background task.

    The transport and ``ClientSession`` contexts live inside a dedicated
    ``asyncio.Task`` running ``_session_owner_coro``. Their anyio cancel scopes
    therefore belong to that task rather than to the HTTP request handler, so a
    failing child task in the transport's TaskGroup cannot cancel the request.
    This coroutine only waits, up to ``self._session_create_timeout``, for the
    owner task to report readiness. On any failure it shuts that task down
    before re-raising.

    Headers sent upstream are the caller's headers layered over a default
    ``Accept`` header, minus the gateway-internal ``x-mcp-session-id`` /
    ``mcp-session-id`` affinity headers.

    Args:
        url: Server URL.
        headers: Request headers.
        transport_type: Transport type to use.
        httpx_client_factory: Optional factory for httpx clients.
        timeout: Optional timeout in seconds for transport connection.
        gateway_id: Optional gateway ID for notification handler context.

    Returns:
        Initialized PooledSession.

    Raises:
        RuntimeError: If session creation or initialization fails.
        asyncio.TimeoutError: If the owner task does not become ready in time.
        asyncio.CancelledError: If cancelled during creation.
    """
    # Caller-supplied headers override the default Accept header.
    merged_headers = {"Accept": "application/json, text/event-stream", **(headers or {})}
    # Gateway-internal session affinity headers must never reach the upstream server.
    internal_headers = ("x-mcp-session-id", "mcp-session-id")
    merged_headers = {name: value for name, value in merged_headers.items() if name.lower() not in internal_headers}

    identity_key = self._compute_identity_hash(headers)
    shutdown_event = asyncio.Event()
    ready_future: asyncio.Future[Tuple[ClientSession, Any]] = asyncio.get_running_loop().create_future()

    owner_task = asyncio.create_task(
        self._session_owner_coro(url, transport_type, merged_headers, httpx_client_factory, timeout, gateway_id, ready_future, shutdown_event),
        name=f"mcp-session-owner-{sanitize_url_for_logging(url)}",
    )

    try:
        session, transport_ctx = await asyncio.wait_for(ready_future, timeout=self._session_create_timeout)
    except BaseException:
        # Timeout, initialization error or cancellation of this request: never leave an
        # orphaned owner task holding the transport open.
        shutdown_event.set()
        owner_task.cancel()
        with anyio.move_on_after(_get_cleanup_timeout()):
            try:
                await owner_task
            except BaseException:  # nosec B110 - Best effort cleanup
                pass
        raise

    return PooledSession(
        session=session,
        transport_context=transport_ctx,
        url=url,
        transport_type=transport_type,
        headers=merged_headers,
        identity_key=identity_key,
        gateway_id=gateway_id or "",
        _owner_task=owner_task,
        _shutdown_event=shutdown_event,
    )
async def _close_session(self, pooled: PooledSession) -> None:
    """
    Close a session and its transport.

    For sessions with a background owner task, signals the task to shut down
    and waits for it to complete (which unwinds the ``async with`` contexts
    naturally). Falls back to manual ``__aexit__`` for legacy sessions
    without an owner task.

    Each graceful wait is bounded by ``_get_cleanup_timeout()``. An owner task
    that does not exit in time is cancelled and then awaited without a bound.
    ``Exception`` (and, when awaiting the owner task, ``CancelledError``) raised
    while closing is swallowed. When session affinity is enabled, the Redis
    ownership record for the ``x-mcp-session-id`` found in ``pooled.headers``
    is cleaned up at the end.

    NOTE(review): ``_create_session`` strips ``x-mcp-session-id`` /
    ``mcp-session-id`` from the headers it stores on the PooledSession. For
    sessions created there, the Redis owner cleanup at the end of this method
    therefore appears never to run. Confirm whether that is intended.

    Args:
        pooled: The session to close.
    """
    if pooled.is_closed and pooled.shutdown_event is None:
        # Truly closed (legacy path, no owner task) — nothing to clean up
        return

    # NOTE(review): presumably flips is_closed, which _validate_session and release() check.
    pooled.mark_closed()
    cleanup_timeout = _get_cleanup_timeout()

    if pooled.shutdown_event is not None and pooled.owner_task is not None:
        # Signal the owner task to shut down gracefully
        pooled.shutdown_event.set()

        if not pooled.owner_task.done():
            # Wait for graceful exit
            with anyio.move_on_after(cleanup_timeout) as scope:
                try:
                    await pooled.owner_task
                except (asyncio.CancelledError, Exception):  # nosec B110
                    pass
            if scope.cancelled_caught:
                # Graceful exit did not finish within cleanup_timeout: cancel the owner
                # task and wait (without a bound) for it to finish.
                logger.warning(f"Session owner cleanup timed out for {sanitize_url_for_logging(pooled.url)} - force cancelling")
                pooled.owner_task.cancel()
                try:
                    await pooled.owner_task
                except (asyncio.CancelledError, Exception):  # nosec B110
                    pass
    else:
        # Legacy path: manual __aexit__ for sessions without owner task
        with anyio.move_on_after(cleanup_timeout) as session_scope:
            try:
                await pooled.session.__aexit__(None, None, None)  # pylint: disable=unnecessary-dunder-call
            except Exception as e:
                logger.debug(f"Error closing session: {e}")
        if session_scope.cancelled_caught:
            logger.warning(f"Session cleanup timed out for {sanitize_url_for_logging(pooled.url)} - proceeding anyway")

        # Transport is exited after the session, each with its own timeout budget.
        if pooled.transport_context is not None:
            with anyio.move_on_after(cleanup_timeout) as transport_scope:
                try:
                    await pooled.transport_context.__aexit__(None, None, None)  # pylint: disable=unnecessary-dunder-call
                except Exception as e:
                    logger.debug(f"Error closing transport: {e}")
            if transport_scope.cancelled_caught:
                logger.warning(f"Transport cleanup timed out for {sanitize_url_for_logging(pooled.url)} - proceeding anyway")

    logger.debug(f"Closed session for {sanitize_url_for_logging(pooled.url)} (uses={pooled.use_count})")

    # Clean up pool_owner key in Redis for session affinity
    if settings.mcpgateway_session_affinity_enabled and pooled.headers:
        headers_lower = {k.lower(): v for k, v in pooled.headers.items()}
        mcp_session_id = headers_lower.get("x-mcp-session-id")
        if mcp_session_id and self.is_valid_mcp_session_id(mcp_session_id):
            await self._cleanup_pool_session_owner(mcp_session_id)
async def _cleanup_pool_session_owner(self, mcp_session_id: str) -> None:
    """Clean up pool_owner key in Redis when session is closed.

    Only deletes the key if this worker owns it, to avoid removing other
    workers' ownership. The ownership check and the delete run as a single Lua
    script, like the claim/reclaim scripts elsewhere in this module. Without
    that, another worker could claim the key (e.g. after it expired, or via a
    dead-owner reclaim) between our read and our delete, and we would then
    delete its ownership.

    Redis failures are non-fatal and only logged at debug level.

    Args:
        mcp_session_id: The MCP session ID from x-mcp-session-id header.
    """
    try:
        # First-Party
        from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

        redis = await get_redis_client()
        if redis:
            key = self._pool_owner_key(mcp_session_id)
            # Atomic compare-and-delete: only delete if we still own it.
            compare_and_delete = """
            if redis.call('GET', KEYS[1]) == ARGV[1] then
                return redis.call('DEL', KEYS[1])
            end
            return 0
            """
            deleted = await redis.eval(compare_and_delete, 1, key, WORKER_ID)
            if deleted == 1:
                logger.debug(f"Cleaned up pool session owner: {mcp_session_id[:8]}...")
    except Exception as e:
        # Cleanup failure is non-fatal
        logger.debug(f"Failed to cleanup pool session owner in Redis: {e}")
async def cleanup_streamable_http_session_owner(self, mcp_session_id: str) -> None:
    """Drop this worker's Redis ownership record for a Streamable HTTP session.

    Public entry point for trusted internal MCP session teardown paths, so they
    need not call the private ``_cleanup_pool_session_owner`` helper directly.
    Session IDs rejected by ``is_valid_mcp_session_id`` are ignored.

    Args:
        mcp_session_id: Downstream MCP session ID whose ownership should be removed.
    """
    if self.is_valid_mcp_session_id(mcp_session_id):
        await self._cleanup_pool_session_owner(mcp_session_id)
        return
    logger.debug("Invalid mcp_session_id for owner cleanup, skipping")
async def close_all(self) -> None:
    """
    Shut the pool down: close every session and stop background tasks.

    The pool is marked closed first, so ``acquire()`` rejects new requests.
    Then, under the global lock, idle sessions from every pool queue and all
    checked-out sessions are closed, and the per-key bookkeeping is cleared.
    Finally, the RPC listener and heartbeat tasks are cancelled and awaited if
    they are still running.

    Should be called during application shutdown.
    """
    self._closed = True
    logger.info("Closing all pooled sessions...")

    async with self._global_lock:
        # Idle sessions waiting in the per-key queues.
        for idle_queue in list(self._pools.values()):
            while not idle_queue.empty():
                try:
                    await self._close_session(idle_queue.get_nowait())
                except asyncio.QueueEmpty:
                    break

        # Sessions currently checked out by callers.
        for checked_out in list(self._active.values()):
            for pooled in list(checked_out):
                await self._close_session(pooled)

        for registry in (self._pools, self._active, self._locks, self._semaphores, self._mcp_session_mapping):
            registry.clear()

    # Stop background tasks (RPC listener first, then heartbeat) if running.
    for task_attr in ("_rpc_listener_task", "_heartbeat_task"):
        task = getattr(self, task_attr)
        if task and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            setattr(self, task_attr, None)

    logger.info("All sessions closed")
async def drain_all(self) -> None:
    """Close all pooled and active sessions without marking the pool as closed.

    Unlike ``close_all()``, the pool remains operational after draining.
    New sessions will be created on demand with fresh TLS state.
    Use this for certificate rotation (SIGHUP).

    Semaphore accounting:

    * Every idle session parked in a pool queue occupies one slot of its key's
      semaphore. ``acquire`` and ``_maybe_evict_idle_pool_keys`` both release a
      slot when they close an idle session, so the slot is released here too.
    * ``_semaphores`` survives a drain. If the existing semaphore is reused
      afterwards, skipping this release would leave the slots of drained idle
      sessions consumed, and later acquires for that key could time out.
    * Active sessions keep their slot until their holder calls ``release()``,
      which treats the now-closed session as a discard and releases the slot.
    """
    logger.info("Draining all pooled sessions for TLS rotation...")

    async with self._global_lock:
        for pool_key, pool in list(self._pools.items()):
            while not pool.empty():
                try:
                    pooled = pool.get_nowait()
                except asyncio.QueueEmpty:
                    break
                try:
                    await self._close_session(pooled)
                finally:
                    # Give back the slot this idle session occupied, even if closing was interrupted.
                    semaphore = self._semaphores.get(pool_key)
                    if semaphore is not None:
                        semaphore.release()

        # Active sessions: their holders return the semaphore slot via release().
        for active_set in list(self._active.values()):
            for pooled in list(active_set):
                await self._close_session(pooled)

        self._pools.clear()
        self._active.clear()
        self._mcp_session_mapping.clear()

    logger.info("All pooled sessions drained; pool remains operational")
async def register_pool_session_owner(self, mcp_session_id: str) -> None:
    """Record this worker as the Redis owner of a pool session, without stealing it.

    Multi-worker session affinity relies on knowing which worker owns which
    pool session. A request carrying x-mcp-session-id that lands on another
    worker can then be forwarded to the owner.

    An atomic Lua script claims the owner key only if it is absent, or
    refreshes its TTL if this worker already owns it. A key owned by another
    worker is left untouched. Initial ownership is normally claimed in
    register_session_mapping() via SETNX, so this mostly refreshes the TTL.

    No-op when session affinity is disabled or the ID is invalid. Redis errors
    are non-fatal (single-worker mode still works).

    Args:
        mcp_session_id: The MCP session ID from x-mcp-session-id header.
    """
    if not settings.mcpgateway_session_affinity_enabled:
        return

    if not self.is_valid_mcp_session_id(mcp_session_id):
        logger.debug("Invalid mcp_session_id for owner registration, skipping")
        return

    try:
        # First-Party
        from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

        redis = await get_redis_client()
        if not redis:
            return

        # Script outcome: 1 = claimed (key was absent), 2 = TTL refreshed (already ours),
        # 0 = owned by another worker, left untouched.
        claim_or_refresh = """
        local cur = redis.call('GET', KEYS[1])
        if not cur then
            redis.call('SET', KEYS[1], ARGV[1], 'EX', ARGV[2])
            return 1
        end
        if cur == ARGV[1] then
            redis.call('EXPIRE', KEYS[1], ARGV[2])
            return 2
        end
        return 0
        """
        ttl_seconds = int(settings.mcpgateway_session_affinity_ttl)
        outcome = await redis.eval(claim_or_refresh, 1, self._pool_owner_key(mcp_session_id), WORKER_ID, ttl_seconds)
        logger.debug(f"Owner registration outcome={outcome} for session {mcp_session_id[:8]}...")
    except Exception as e:
        # Redis failure is non-fatal - single worker mode still works
        logger.debug(f"Failed to register pool session owner in Redis: {e}")
async def _get_pool_session_owner(self, mcp_session_id: str) -> Optional[str]:
    """Look up which worker owns a pool session in Redis.

    Args:
        mcp_session_id: The MCP session ID from x-mcp-session-id header.

    Returns:
        The owning worker ID (decoded to ``str`` if Redis returned bytes). Returns
        None when affinity is disabled, the ID is invalid, no owner is recorded,
        Redis is unavailable, or the lookup fails.
    """
    if not (settings.mcpgateway_session_affinity_enabled and self.is_valid_mcp_session_id(mcp_session_id)):
        return None

    try:
        # First-Party
        from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

        redis = await get_redis_client()
        raw_owner = await redis.get(self._pool_owner_key(mcp_session_id)) if redis else None
        if raw_owner:
            return raw_owner.decode() if isinstance(raw_owner, bytes) else raw_owner
    except Exception as e:
        logger.debug(f"Failed to get pool session owner from Redis: {e}")
    return None
async def forward_request_to_owner(
    self,
    mcp_session_id: str,
    request_data: Dict[str, Any],
    timeout: Optional[float] = None,
) -> Optional[Dict[str, Any]]:
    """Forward RPC request to the worker that owns the pool session.

    This method checks Redis to find which worker owns the pool session for
    the given mcp_session_id. If owned by another worker, it forwards the
    request via Redis pub/sub and waits for the response. If the recorded
    owner is not alive, ownership is reclaimed with a compare-and-set Lua
    script so that only one worker wins the reclaim.

    The temporary pub/sub subscription used to receive the reply is always
    unsubscribed and closed. Failures during that cleanup are only logged:
    once a reply has been received it must be returned, because returning
    ``None`` tells the caller to execute locally and would run the same
    request a second time.

    Args:
        mcp_session_id: The MCP session ID from x-mcp-session-id header.
        request_data: The RPC request data to forward.
        timeout: Optional timeout in seconds (default from config).

    Returns:
        The response from the owner worker, or None if we own the session
        (caller should execute locally) or if Redis is unavailable.

    Raises:
        asyncio.TimeoutError: If the forwarded request times out.
    """
    if not settings.mcpgateway_session_affinity_enabled:
        return None

    if not self.is_valid_mcp_session_id(mcp_session_id):
        return None

    effective_timeout = timeout if timeout is not None else settings.mcpgateway_pool_rpc_forward_timeout

    try:
        # First-Party
        from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

        redis = await get_redis_client()
        if not redis:
            return None  # Execute locally - no Redis

        # Check who owns this session
        owner = await redis.get(self._pool_owner_key(mcp_session_id))
        method = request_data.get("method", "unknown")
        if not owner:
            logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | No owner → execute locally (new session)")
            return None  # No owner registered - execute locally (new session)

        owner_id = owner.decode() if isinstance(owner, bytes) else owner
        if owner_id == WORKER_ID:
            logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | We own it → execute locally")
            return None  # We own it - execute locally

        if not await self._is_worker_alive(owner_id):
            logger.warning(f"[AFFINITY] Owner {owner_id} is dead for session {mcp_session_id[:8]}...")
            # CAS: reclaim only if still owned by the dead worker
            cas_script = """
            local cur = redis.call('GET', KEYS[1])
            if cur == ARGV[1] then
                redis.call('SET', KEYS[1], ARGV[2], 'EX', ARGV[3])
                return 1
            end
            return 0
            """
            ttl = int(settings.mcpgateway_session_affinity_ttl)
            reclaimed = await redis.eval(cas_script, 1, self._pool_owner_key(mcp_session_id), owner_id, WORKER_ID, ttl)
            if reclaimed == 1:
                logger.info(f"[AFFINITY] Reclaimed session {mcp_session_id[:8]}... from dead worker {owner_id} → execute locally")
                return None  # We won the reclaim - execute locally
            # Another worker already reclaimed; re-read the new owner and forward
            new_owner = await redis.get(self._pool_owner_key(mcp_session_id))
            if not new_owner:
                return None  # Key vanished - execute locally
            owner_id = new_owner.decode() if isinstance(new_owner, bytes) else new_owner
            if owner_id == WORKER_ID:
                return None  # We ended up as owner
            logger.info(f"[AFFINITY] Session {mcp_session_id[:8]}... reclaimed by {owner_id} → forwarding to new owner")

        logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | Owner: {owner_id} → forwarding")

        # Forward to owner worker via pub/sub
        response_id = str(uuid.uuid4())
        response_channel = f"mcpgw:pool_rpc_response:{response_id}"

        pubsub = redis.pubsub()
        try:
            # Subscribe to the response channel BEFORE publishing so the reply cannot be missed.
            await pubsub.subscribe(response_channel)

            # Prepare request with response channel
            forward_data = {
                "type": "rpc_forward",
                **request_data,
                "response_channel": response_channel,
                "mcp_session_id": mcp_session_id,
            }

            # Publish request to owner's channel
            await redis.publish(f"mcpgw:pool_rpc:{owner_id}", orjson.dumps(forward_data))
            self._forwarded_requests += 1
            logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | Published to worker {owner_id}")

            # Wait for response
            async with asyncio.timeout(effective_timeout):
                async for msg in pubsub.listen():
                    if msg["type"] == "message":
                        return orjson.loads(msg["data"])
            return None  # Listener ended without a reply - execute locally
        finally:
            # Cleanup must never replace the outcome above (see docstring).
            try:
                await pubsub.unsubscribe(response_channel)
            except Exception as cleanup_error:
                logger.debug(f"Failed to unsubscribe from {response_channel}: {cleanup_error}")
            try:
                # unsubscribe() alone keeps the pub/sub connection checked out;
                # closing releases it (aclose() on newer clients, close() on older).
                closer = getattr(pubsub, "aclose", None) or getattr(pubsub, "close", None)
                if closer is not None:
                    pending = closer()
                    if asyncio.iscoroutine(pending):
                        await pending
            except Exception as cleanup_error:
                logger.debug(f"Failed to close pub/sub for {response_channel}: {cleanup_error}")

    except asyncio.TimeoutError:
        self._forwarded_request_timeouts += 1
        logger.warning(f"Timeout forwarding request to owner for session {mcp_session_id[:8]}...")
        raise
    except Exception as e:
        self._forwarded_request_failures += 1
        logger.debug(f"Error forwarding request to owner: {e}")
        return None  # Execute locally on error
async def start_rpc_listener(self) -> None:
    """Start listening for forwarded RPC and HTTP requests on this worker's channels.

    This method subscribes to Redis pub/sub channels specific to this worker
    and processes incoming forwarded requests from other workers:
    - mcpgw:pool_rpc:{WORKER_ID} - for SSE transport JSON-RPC forwards
    - mcpgw:pool_http:{WORKER_ID} - for Streamable HTTP request forwards

    Messages are handled one at a time, in arrival order, until
    ``self._closed`` becomes true. An error while handling one message is
    logged and the loop continues. An error while receiving from Redis is
    logged and followed by a one-second pause, so a broken connection does
    not turn the loop into a busy spin. The subscription is unsubscribed
    and its connection closed when the loop exits, including on task
    cancellation.
    """
    if not settings.mcpgateway_session_affinity_enabled:
        return

    try:
        # First-Party
        from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

        redis = await get_redis_client()
        if not redis:
            logger.debug("Redis not available, RPC listener not started")
            return

        rpc_channel = f"mcpgw:pool_rpc:{WORKER_ID}"
        http_channel = f"mcpgw:pool_http:{WORKER_ID}"
        pubsub = redis.pubsub()
        try:
            await pubsub.subscribe(rpc_channel, http_channel)
            logger.info(f"RPC/HTTP listener started for worker {WORKER_ID} on channels: {rpc_channel}, {http_channel}")

            while not self._closed:
                try:
                    msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
                except Exception as e:
                    logger.warning(f"Error processing forwarded request: {e}")
                    # Pause before retrying: a persistently failing connection
                    # would otherwise make this loop spin and flood the log.
                    await asyncio.sleep(1.0)
                    continue

                try:
                    if msg and msg["type"] == "message":
                        request = orjson.loads(msg["data"])
                        forward_type = request.get("type")
                        response_channel = request.get("response_channel")

                        # Requests without a reply channel cannot be answered; ignore them.
                        if response_channel:
                            if forward_type == "rpc_forward":
                                # Execute forwarded RPC request for SSE transport
                                response = await self._execute_forwarded_request(request)
                                await redis.publish(response_channel, orjson.dumps(response))
                                logger.debug(f"Processed forwarded RPC request, response sent to {response_channel}")
                            elif forward_type == "http_forward":
                                # Execute forwarded HTTP request for Streamable HTTP transport
                                await self._execute_forwarded_http_request(request, redis)
                            else:
                                logger.warning(f"Unknown forward type: {forward_type}")
                except Exception as e:
                    logger.warning(f"Error processing forwarded request: {e}")
        finally:
            try:
                await pubsub.unsubscribe(rpc_channel, http_channel)
            except Exception as cleanup_error:
                logger.debug(f"Failed to unsubscribe RPC/HTTP listener channels: {cleanup_error}")
            try:
                # unsubscribe() alone keeps the pub/sub connection checked out;
                # closing releases it (aclose() on newer clients, close() on older).
                closer = getattr(pubsub, "aclose", None) or getattr(pubsub, "close", None)
                if closer is not None:
                    pending = closer()
                    if asyncio.iscoroutine(pending):
                        await pending
            except Exception as cleanup_error:
                logger.debug(f"Failed to close RPC/HTTP listener pub/sub: {cleanup_error}")
            logger.info(f"RPC/HTTP listener stopped for worker {WORKER_ID}")

    except Exception as e:
        logger.warning(f"RPC/HTTP listener failed: {e}")
async def _execute_forwarded_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
    """Execute a forwarded RPC request locally via internal HTTP call.

    This method handles RPC requests that were forwarded from another worker.
    Instead of handling specific methods here, we make an internal HTTP call
    to the local /rpc endpoint which reuses ALL existing method handling logic.

    The x-forwarded-internally header prevents infinite forwarding loops.

    Args:
        request: The forwarded RPC request containing method, params, headers, req_id, etc.

    Returns:
        A dict with either a ``"result"`` key or an ``"error"`` key. Timeouts,
        transport errors, non-2xx statuses and malformed bodies are reported
        as JSON-RPC ``-32603`` errors rather than raised.
    """
    try:
        method = request.get("method")
        params = request.get("params", {})
        # Tolerate an explicit null for headers / session id in the payload.
        headers = request.get("headers") or {}
        req_id = request.get("req_id", 1)
        mcp_session_id = str(request.get("mcp_session_id") or "unknown")
        session_short = mcp_session_id[:8]

        logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Received forwarded request, executing locally")

        # Build headers for internal request - forward original headers
        # but add x-forwarded-internally to prevent infinite loops.
        # Relies on the originating transport having already filtered
        # passthrough headers via extract_headers_for_loopback (#3640).
        # Header names are case-insensitive: drop any caller-supplied variant
        # of the headers we set so exactly one value is sent, and drop
        # Content-Length because the JSON body is re-encoded below.
        overridden = ("content-type", "content-length", "x-forwarded-internally")
        internal_headers = {name: value for name, value in headers.items() if name.lower() not in overridden}
        internal_headers["x-forwarded-internally"] = "true"
        # Ensure content-type is set
        internal_headers["content-type"] = "application/json"

        # Make internal HTTP/HTTPS call to local /rpc endpoint.
        # This reuses ALL existing method handling logic without duplication.
        internal_base_url = internal_loopback_base_url()
        async with httpx.AsyncClient(verify=internal_loopback_verify()) as client:
            response = await client.post(
                f"{internal_base_url}/rpc",
                json={"jsonrpc": "2.0", "method": method, "params": params, "id": req_id},
                headers=internal_headers,
                timeout=settings.mcpgateway_pool_rpc_forward_timeout,
            )

            # Gate on HTTP status first: non-2xx responses are errors
            # even if the body parses as JSON.
            if not response.is_success:
                try:
                    response_data = response.json()
                except ValueError:
                    response_data = {}
                if not isinstance(response_data, dict):
                    response_data = {}

                # If body is a JSON-RPC error ({"error": {...}}), propagate it
                if "error" in response_data and isinstance(response_data["error"], dict):
                    logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded execution completed with error (HTTP {response.status_code})")
                    return {"error": response_data["error"]}

                # Non-JSON-RPC error body (e.g. {"detail": "..."}): map to JSON-RPC error
                detail = response_data.get("detail", response.text[:200] or "Unknown error")
                logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded execution failed with HTTP {response.status_code}")
                return {"error": {"code": -32603, "message": f"Forwarded request failed (HTTP {response.status_code}): {detail}"}}

            # Parse successful response
            response_data = response.json()

            # A 2xx body that is not a JSON object is not a JSON-RPC response.
            if not isinstance(response_data, dict):
                logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded execution returned a non-object body")
                return {"error": {"code": -32603, "message": "Forwarded request returned an invalid JSON-RPC response"}}

            # Extract result or error from JSON-RPC response
            if "error" in response_data:
                logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded execution completed with error")
                return {"error": response_data["error"]}
            logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded execution completed successfully")
            return {"result": response_data.get("result", {})}

    except httpx.TimeoutException:
        logger.warning(f"Timeout executing forwarded request: {request.get('method')}")
        return {"error": {"code": -32603, "message": "Internal request timeout"}}
    except Exception as e:
        logger.warning(f"Error executing forwarded request: {e}")
        return {"error": {"code": -32603, "message": str(e)}}
async def _execute_forwarded_http_request(self, request: Dict[str, Any], redis: Any) -> None:
    """Execute a forwarded HTTP request locally and return response via Redis.

    This method handles full HTTP requests forwarded from other workers for
    Streamable HTTP transport session affinity. It reconstructs the HTTP request,
    makes an internal call to the appropriate endpoint, and publishes the response
    back through Redis.

    httpx transparently decodes compressed bodies in ``response.content``.
    When the local response carried a ``Content-Encoding``, that header and
    the encoded ``Content-Length`` no longer describe the forwarded bytes, so
    both are dropped from the serialized response headers.

    Args:
        request: Serialized HTTP request data from Redis Pub/Sub containing:
            - type: "http_forward"
            - response_channel: Redis channel to publish response to
            - mcp_session_id: Session identifier
            - method: HTTP method (GET, POST, DELETE)
            - path: Request path (e.g., /mcp)
            - query_string: Query parameters
            - headers: Request headers dict
            - body: Hex-encoded request body
        redis: Redis client for publishing response
    """
    response_channel = None
    try:
        response_channel = request.get("response_channel")
        method = request.get("method")
        path = request.get("path")
        query_string = request.get("query_string", "")
        # Tolerate an explicit null for headers in the payload.
        headers = request.get("headers") or {}
        body_hex = request.get("body", "")
        mcp_session_id = request.get("mcp_session_id")

        # Decode hex body back to bytes
        body = bytes.fromhex(body_hex) if body_hex else b""

        session_short = mcp_session_id[:8] if mcp_session_id else "unknown"
        logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Received forwarded HTTP request: {method} {path}")

        # Add internal forwarding headers to prevent loops.
        # Relies on the originating transport having already filtered
        # passthrough headers via extract_headers_for_loopback (#3640).
        # Header names are case-insensitive: drop differently-cased copies of
        # the headers set below so only our values are sent.
        overridden = ("x-forwarded-internally", "x-original-worker")
        internal_headers = {name: value for name, value in headers.items() if name.lower() not in overridden}
        internal_headers["x-forwarded-internally"] = "true"
        internal_headers["x-original-worker"] = request.get("original_worker", "unknown")

        # Make internal HTTP/HTTPS request to local endpoint
        url = f"{internal_loopback_base_url()}{path}"
        if query_string:
            url = f"{url}?{query_string}"

        async with httpx.AsyncClient(verify=internal_loopback_verify()) as client:
            response = await client.request(
                method=method,
                url=url,
                headers=internal_headers,
                content=body,
                timeout=settings.mcpgateway_pool_rpc_forward_timeout,
            )

        logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Executed locally: {response.status_code}")

        response_headers = dict(response.headers)
        if "content-encoding" in response.headers:
            # response.content is already decoded (see docstring).
            response_headers = {name: value for name, value in response_headers.items() if name.lower() not in ("content-encoding", "content-length")}

        # Serialize response for Redis transport
        response_data = {
            "status": response.status_code,
            "headers": response_headers,
            "body": response.content.hex(),  # Hex encode binary response
        }

        # Publish response back to requesting worker
        if redis and response_channel:
            await redis.publish(response_channel, orjson.dumps(response_data))
            logger.debug(f"[HTTP_AFFINITY] Published HTTP response to Redis channel: {response_channel}")

    except Exception as e:
        logger.error(f"Error executing forwarded HTTP request: {e}")
        # Try to send error response if possible
        if redis and response_channel:
            error_response = {
                "status": 500,
                "headers": {"content-type": "application/json"},
                "body": orjson.dumps({"error": "Internal forwarding error"}).hex(),
            }
            try:
                await redis.publish(response_channel, orjson.dumps(error_response))
            except Exception as publish_error:
                logger.debug(f"Failed to publish error response via Redis: {publish_error}")
async def get_streamable_http_session_owner(self, mcp_session_id: str) -> Optional[str]:
    """Look up which worker currently owns a Streamable HTTP session.

    Public entry point for streamablehttp_transport, used to check session
    ownership before a request is handled. It only delegates to the private
    ``_get_pool_session_owner`` lookup.

    Args:
        mcp_session_id: The MCP session ID from mcp-session-id header.

    Returns:
        Worker ID if found, None otherwise.
    """
    owner_worker = await self._get_pool_session_owner(mcp_session_id)
    return owner_worker
async def forward_streamable_http_to_owner(
    self,
    owner_worker_id: str,
    mcp_session_id: str,
    method: str,
    path: str,
    headers: Dict[str, str],
    body: bytes,
    query_string: str = "",
) -> Optional[Dict[str, Any]]:
    """Forward a Streamable HTTP request to the worker that owns the session via Redis Pub/Sub.

    This method forwards the entire HTTP request to another worker using Redis
    Pub/Sub channels, similar to forward_request_to_owner() for SSE transport.
    This ensures session affinity works correctly in single-host multi-worker
    deployments where hostname-based routing fails.

    The temporary response subscription is always unsubscribed and closed.
    Failures during that cleanup are only logged, so a response that has
    already been received is still returned.

    Args:
        owner_worker_id: The worker ID that owns the session.
        mcp_session_id: The MCP session ID.
        method: HTTP method (GET, POST, DELETE).
        path: Request path (e.g., /mcp).
        headers: Request headers.
        body: Request body bytes.
        query_string: Query string if any.

    Returns:
        Dict with 'status', 'headers', and 'body' (decoded to bytes) from the
        owner worker's response, or None if forwarding fails or times out.
    """
    if not settings.mcpgateway_session_affinity_enabled:
        return None

    if not self.is_valid_mcp_session_id(mcp_session_id):
        return None

    session_short = mcp_session_id[:8]
    logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {session_short}... | {method} {path} | Forwarding to worker {owner_worker_id}")

    try:
        # First-Party
        from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

        redis = await get_redis_client()
        if not redis:
            logger.warning("Redis unavailable for HTTP forwarding, executing locally")
            return None  # Fall back to local execution

        # Generate unique response channel for this request
        response_uuid = uuid.uuid4().hex
        response_channel = f"mcpgw:pool_http_response:{response_uuid}"

        # Serialize HTTP request for Redis transport
        forward_data = {
            "type": "http_forward",
            "response_channel": response_channel,
            "mcp_session_id": mcp_session_id,
            "method": method,
            "path": path,
            "query_string": query_string,
            "headers": headers,
            "body": body.hex() if body else "",  # Hex encode binary body
            "original_worker": WORKER_ID,
            "timestamp": time.time(),
        }

        pubsub = redis.pubsub()
        try:
            # Subscribe to response channel BEFORE publishing request (prevent race)
            await pubsub.subscribe(response_channel)

            # Publish forwarded request to owner worker's HTTP channel
            owner_channel = f"mcpgw:pool_http:{owner_worker_id}"
            await redis.publish(owner_channel, orjson.dumps(forward_data))
            logger.debug(f"[HTTP_AFFINITY] Published HTTP request to Redis channel: {owner_channel}")

            # Wait for response with timeout
            timeout = settings.mcpgateway_pool_rpc_forward_timeout
            async with asyncio.timeout(timeout):
                while True:
                    msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
                    if msg and msg["type"] == "message":
                        response_data = orjson.loads(msg["data"])
                        logger.debug(f"[HTTP_AFFINITY] Received HTTP response via Redis: status={response_data.get('status')}")

                        # Decode hex body back to bytes
                        body_hex = response_data.get("body", "")
                        response_data["body"] = bytes.fromhex(body_hex) if body_hex else b""

                        self._forwarded_requests += 1
                        return response_data

        finally:
            # Cleanup must never replace the outcome above (see docstring).
            try:
                await pubsub.unsubscribe(response_channel)
            except Exception as cleanup_error:
                logger.debug(f"Failed to unsubscribe from {response_channel}: {cleanup_error}")
            try:
                # unsubscribe() alone keeps the pub/sub connection checked out;
                # closing releases it (aclose() on newer clients, close() on older).
                closer = getattr(pubsub, "aclose", None) or getattr(pubsub, "close", None)
                if closer is not None:
                    pending = closer()
                    if asyncio.iscoroutine(pending):
                        await pending
            except Exception as cleanup_error:
                logger.debug(f"Failed to close pub/sub for {response_channel}: {cleanup_error}")

    except asyncio.TimeoutError:
        self._forwarded_request_timeouts += 1
        logger.warning(f"Timeout forwarding HTTP request to owner {owner_worker_id}")
        return None
    except Exception as e:
        self._forwarded_request_failures += 1
        logger.warning(f"Error forwarding HTTP request via Redis: {e}")
        return None
def get_metrics(self) -> Dict[str, Any]:
    """
    Collect a point-in-time snapshot of pool statistics for monitoring.

    Returns:
        Dict with hits, misses, evictions, hit_rate, and per-pool stats.
    """
    lookups = self._hits + self._misses
    affinity_hits = self._session_affinity_local_hits + self._session_affinity_redis_hits
    affinity_lookups = affinity_hits + self._session_affinity_misses

    # One entry per pool key; the label abbreviates identity and gateway id.
    per_pool: Dict[str, Dict[str, Any]] = {}
    for pool_key, queue in self._pools.items():
        user, url, identity, transport, gw_id = pool_key
        gateway_label = gw_id[:8] if gw_id else "none"
        label = f"{url}|{identity[:8]}|{transport}|{user}|{gateway_label}"
        per_pool[label] = {
            "available": queue.qsize(),
            "active": len(self._active.get((user, url, identity, transport, gw_id), set())),
            "max": self._max_sessions,
        }

    # Any URL that has either a failure count or an open-circuit deadline.
    tracked_urls = set(self._failures.keys()) | set(self._circuit_open_until.keys())
    breakers = {
        endpoint: {
            "failures": self._failures.get(endpoint, 0),
            "open_until": self._circuit_open_until.get(endpoint),
        }
        for endpoint in tracked_urls
    }

    affinity = {
        "local_hits": self._session_affinity_local_hits,
        "redis_hits": self._session_affinity_redis_hits,
        "misses": self._session_affinity_misses,
        "hit_rate": affinity_hits / affinity_lookups if affinity_lookups > 0 else 0.0,
        "forwarded_requests": self._forwarded_requests,
        "forwarded_failures": self._forwarded_request_failures,
        "forwarded_timeouts": self._forwarded_request_timeouts,
    }

    return {
        "hits": self._hits,
        "misses": self._misses,
        "evictions": self._evictions,
        "health_check_failures": self._health_check_failures,
        "circuit_breaker_trips": self._circuit_breaker_trips,
        "pool_keys_evicted": self._pool_keys_evicted,
        "sessions_reaped": self._sessions_reaped,
        "anonymous_identity_count": self._anonymous_identity_count,
        "hit_rate": self._hits / lookups if lookups > 0 else 0.0,
        "pool_key_count": len(self._pools),
        # Session affinity metrics
        "session_affinity": affinity,
        "pools": per_pool,
        "circuit_breakers": breakers,
    }
@asynccontextmanager
async def session(
    self,
    url: str,
    headers: Optional[Dict[str, str]] = None,
    transport_type: TransportType = TransportType.STREAMABLE_HTTP,
    httpx_client_factory: Optional[HttpxClientFactory] = None,
    timeout: Optional[float] = None,
    user_identity: Optional[str] = None,
    gateway_id: Optional[str] = None,
) -> "AsyncIterator[PooledSession]":
    """
    Borrow a pooled session for the duration of an ``async with`` block.

    The session is acquired on entry and always released on exit. If the
    block exits by raising (any BaseException, including cancellation), the
    session is released with ``discard=True``. This keeps a possibly broken
    session (e.g. after ClosedResourceError) out of the pool.

    Usage:
        async with pool.session(url, headers) as pooled:
            result = await pooled.session.call_tool("my_tool", {})

    Args:
        url: The MCP server URL.
        headers: Request headers.
        transport_type: Transport type to use.
        httpx_client_factory: Optional factory for httpx clients.
        timeout: Optional timeout in seconds for transport connection.
        user_identity: Optional user identity for strict isolation.
        gateway_id: Optional gateway ID for notification handler context.

    Yields:
        PooledSession ready for use.
    """
    pooled = await self.acquire(url, headers, transport_type, httpx_client_factory, timeout, user_identity, gateway_id)
    discard = True
    try:
        yield pooled
        # Reached only when the caller's block finished without raising.
        discard = False
    finally:
        await self.release(pooled, discard=discard)
# Global pool instance - initialized by FastAPI lifespan.
# Assigned by init_mcp_session_pool(), read by get_mcp_session_pool() and
# drain_mcp_session_pool(), and reset to None by close_mcp_session_pool().
_mcp_session_pool: Optional[MCPSessionPool] = None
def get_mcp_session_pool() -> MCPSessionPool:
    """Return the process-wide MCP session pool.

    Returns:
        The global MCPSessionPool instance.

    Raises:
        RuntimeError: If pool has not been initialized.
    """
    pool = _mcp_session_pool
    if pool is not None:
        return pool
    raise RuntimeError("MCP session pool not initialized. Call init_mcp_session_pool() first.")
def init_mcp_session_pool(
    max_sessions_per_key: int = 10,
    session_ttl_seconds: float = 300.0,
    health_check_interval_seconds: float = 60.0,
    acquire_timeout_seconds: float = 30.0,
    session_create_timeout_seconds: float = 30.0,
    circuit_breaker_threshold: int = 5,
    circuit_breaker_reset_seconds: float = 60.0,
    identity_headers: Optional[frozenset[str]] = None,
    identity_extractor: Optional[IdentityExtractor] = None,
    idle_pool_eviction_seconds: float = 600.0,
    default_transport_timeout_seconds: float = 30.0,
    health_check_methods: Optional[list[str]] = None,
    health_check_timeout_seconds: float = 5.0,
    message_handler_factory: Optional[MessageHandlerFactory] = None,
    enable_notifications: bool = True,
    notification_debounce_seconds: float = 5.0,
) -> MCPSessionPool:
    """Create the global MCP session pool and store it as the module singleton.

    When notifications are enabled and the caller supplies no
    ``message_handler_factory``, a notification service is created here and
    a default handler factory that routes session messages to it is used.

    Args:
        See MCPSessionPool.__init__ for argument descriptions.
        enable_notifications: Enable automatic notification service for list_changed events.
        notification_debounce_seconds: Debounce interval for notification-triggered refreshes.

    Returns:
        The initialized MCPSessionPool instance.
    """
    global _mcp_session_pool  # pylint: disable=global-statement

    handler_factory = message_handler_factory
    if message_handler_factory is None and enable_notifications:
        # First-Party
        from mcpgateway.services.notification_service import (  # pylint: disable=import-outside-toplevel
            init_notification_service,
        )

        # Initialize notification service (will be started during acquire with gateway context)
        service = init_notification_service(debounce_seconds=notification_debounce_seconds)

        def default_handler_factory(url: str, gateway_id: Optional[str]):
            """Build a message handler that feeds session notifications to the notification service.

            Args:
                url: The MCP server URL for the session.
                gateway_id: Optional gateway ID for attribution, falls back to URL if not provided.

            Returns:
                A message handler that forwards notifications to the notification service.
            """
            attribution = gateway_id or url
            return service.create_message_handler(attribution, url)

        handler_factory = default_handler_factory
        logger.info("MCP notification service created (debounce=%ss)", notification_debounce_seconds)

    _mcp_session_pool = MCPSessionPool(
        max_sessions_per_key=max_sessions_per_key,
        session_ttl_seconds=session_ttl_seconds,
        health_check_interval_seconds=health_check_interval_seconds,
        acquire_timeout_seconds=acquire_timeout_seconds,
        session_create_timeout_seconds=session_create_timeout_seconds,
        circuit_breaker_threshold=circuit_breaker_threshold,
        circuit_breaker_reset_seconds=circuit_breaker_reset_seconds,
        identity_headers=identity_headers,
        identity_extractor=identity_extractor,
        idle_pool_eviction_seconds=idle_pool_eviction_seconds,
        default_transport_timeout_seconds=default_transport_timeout_seconds,
        health_check_methods=health_check_methods,
        health_check_timeout_seconds=health_check_timeout_seconds,
        message_handler_factory=handler_factory,
    )
    logger.info("MCP session pool initialized")
    return _mcp_session_pool
async def close_mcp_session_pool() -> None:
    """Close the global MCP session pool and notification service.

    The global reference is cleared, and the notification service is closed,
    even if closing the pool's sessions raises. Any such error is re-raised
    after that cleanup. A missing or uninitialized notification service
    (ImportError / RuntimeError) is ignored.
    """
    global _mcp_session_pool  # pylint: disable=global-statement
    pool = _mcp_session_pool
    try:
        if pool is not None:
            try:
                await pool.close_all()
            finally:
                # Clear the global even on failure so a half-closed pool is never handed out again.
                _mcp_session_pool = None
            logger.info("MCP session pool closed")
    finally:
        # Close notification service if it was initialized
        try:
            # First-Party
            from mcpgateway.services.notification_service import (  # pylint: disable=import-outside-toplevel
                close_notification_service,
            )

            await close_notification_service()
        except (ImportError, RuntimeError):
            pass  # Notification service not initialized
async def drain_mcp_session_pool() -> None:
    """Close every pooled session while keeping the global pool usable.

    New requests reconnect afterwards with fresh TLS state. Unlike
    ``close_mcp_session_pool()``, which shuts the pool down permanently,
    the pool object stays installed. Does nothing if no pool exists.
    """
    pool = _mcp_session_pool
    if pool is None:
        return
    await pool.drain_all()
async def start_pool_notification_service(gateway_service: Any = None) -> None:
    """Start the notification service's background processing.

    Call this once gateway_service is initialized to enable event-driven
    refresh. If the notification service was never configured, the start is
    skipped and a debug message is logged.

    Args:
        gateway_service: Optional GatewayService instance for triggering refreshes.
    """
    try:
        # First-Party
        from mcpgateway.services.notification_service import (  # pylint: disable=import-outside-toplevel
            get_notification_service,
        )

        service = get_notification_service()
        await service.initialize(gateway_service)
    except RuntimeError:
        logger.debug("Notification service not configured, skipping start")
    else:
        logger.info("MCP notification service started")
def register_gateway_capabilities_for_notifications(gateway_id: str, capabilities: Dict[str, Any]) -> None:
    """Record a gateway's server capabilities with the notification service.

    Call this after gateway initialization to enable list_changed
    notifications. This is a no-op when the notification service has not
    been initialized.

    Args:
        gateway_id: The gateway ID.
        capabilities: Server capabilities from initialization response.
    """
    try:
        # First-Party
        from mcpgateway.services.notification_service import (  # pylint: disable=import-outside-toplevel
            get_notification_service,
        )

        get_notification_service().register_gateway_capabilities(gateway_id, capabilities)
    except RuntimeError:
        # Notification service not initialized; registration is best-effort.
        return
def unregister_gateway_from_notifications(gateway_id: str) -> None:
    """Remove a gateway from the notification service.

    Call this when a gateway is deleted. This is a no-op when the
    notification service has not been initialized.

    Args:
        gateway_id: The gateway ID to unregister.
    """
    try:
        # First-Party
        from mcpgateway.services.notification_service import (  # pylint: disable=import-outside-toplevel
            get_notification_service,
        )

        get_notification_service().unregister_gateway(gateway_id)
    except RuntimeError:
        # Notification service not initialized; nothing to unregister.
        return