Coverage for mcpgateway / services / mcp_session_pool.py: 99%

1001 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2""" 

3MCP Session Pool Implementation. 

4 

5Provides session pooling for MCP ClientSessions to reduce per-request overhead. 

6Sessions are isolated per user/tenant via identity hashing to prevent session collision. 

7 

8Performance Impact: 

9 - Without pooling: 20-23ms per tool call (new session each time) 

10 - With pooling: 1-2ms per tool call (10-20x improvement) 

11 

12Security: 

13 Sessions are isolated by (url, identity_hash, transport_type) to prevent: 

14 - Cross-user session sharing 

15 - Cross-tenant data leakage 

16 - Authentication bypass 

17 

18Copyright 2026 

19SPDX-License-Identifier: Apache-2.0 

20Authors: Mihai Criveti 

21""" 

22 

23# flake8: noqa: DAR101, DAR201, DAR401 

24 

25# Future 

26from __future__ import annotations 

27 

28# Standard 

29import asyncio 

30from contextlib import asynccontextmanager 

31from dataclasses import dataclass, field 

32from enum import Enum 

33import hashlib 

34import logging 

35import os 

36import re 

37import socket 

38import time 

39from typing import Any, Callable, Dict, Optional, Set, Tuple, TYPE_CHECKING 

40import uuid 

41 

42# Third-Party 

43import anyio 

44import httpx 

45from mcp import ClientSession, McpError 

46from mcp.client.sse import sse_client 

47from mcp.client.streamable_http import streamablehttp_client 

48from mcp.shared.session import RequestResponder 

49import mcp.types as mcp_types 

50import orjson 

51 

52# First-Party 

53from mcpgateway.common.validators import SecurityValidator 

54from mcpgateway.config import settings 

55from mcpgateway.utils.internal_http import internal_loopback_base_url, internal_loopback_verify 

56from mcpgateway.utils.url_auth import sanitize_url_for_logging 

57 

# JSON-RPC 2.0 standard error code for "Method not found".
METHOD_NOT_FOUND = -32601

# Shared validation for downstream MCP session IDs used for session affinity.
# Intentionally strict because these IDs are embedded in Redis keys, pub/sub
# channel names and log lines. The pattern ends in ``\Z`` rather than ``$``:
# ``$`` also matches just before a single trailing newline, so with ``$`` an ID
# such as "abc\n" would pass ``.match()`` and smuggle a newline through.
_MCP_SESSION_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,128}\Z")

# Worker ID for multi-worker session affinity.
# Hostname + PID keeps the ID unique across Docker containers (where every
# container may run as PID 1) and across gunicorn workers in one container.
WORKER_ID = f"{socket.gethostname()}:{os.getpid()}"

69 

70 

def _get_cleanup_timeout() -> float:
    """Return the session cleanup timeout configured in ``settings``.

    The timeout bounds how long session/transport ``__aexit__`` calls may take
    while a session is being closed. It prevents CPU spin loops when internal
    tasks do not respond to cancellation (anyio's ``_deliver_cancellation``
    issue).

    ``settings`` is imported at module level; the attribute read is guarded so
    that a missing or failing setting falls back to 5.0 seconds instead of
    breaking session teardown.

    Returns:
        Cleanup timeout in seconds (5.0 if the setting cannot be read).
    """
    try:
        return settings.mcp_session_pool_cleanup_timeout
    except Exception:
        return 5.0  # Fallback default

86 

87 

88if TYPE_CHECKING: 

89 # Standard 

90 from collections.abc import AsyncIterator # pragma: no cover 

91 

# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)

93 

94 

class TransportType(Enum):
    """Transport protocols the pool can use to reach an upstream MCP server.

    The member values are the strings placed into pool keys and
    session-affinity mapping keys (via ``TransportType.value``).
    """

    SSE = "sse"  # Server-Sent Events
    STREAMABLE_HTTP = "streamablehttp"  # MCP streamable HTTP

100 

101 

@dataclass(eq=False)  # keep identity-based __eq__/__hash__ so instances can live in sets
class PooledSession:
    """An MCP ``ClientSession`` held by the pool, plus its lifecycle bookkeeping.

    ``eq=False`` leaves the default object-identity ``__eq__``/``__hash__`` in
    place; the pool relies on that to keep active sessions in sets.
    """

    session: ClientSession
    transport_context: Any  # transport context manager, kept open while the session lives
    url: str
    transport_type: TransportType
    headers: Dict[str, str]  # headers the session was opened with (kept for reconnection)
    identity_key: str  # identity-hash component derived from the headers
    user_identity: str = "anonymous"  # owning user, for user isolation
    gateway_id: str = ""  # gateway used for notification attribution
    created_at: float = field(default_factory=time.time)
    last_used: float = field(default_factory=time.time)
    use_count: int = 0
    _closed: bool = field(default=False, repr=False)
    _owner_task: Optional[asyncio.Task] = field(default=None, repr=False)
    _shutdown_event: Optional[asyncio.Event] = field(default=None, repr=False)

    @property
    def age_seconds(self) -> float:
        """Seconds elapsed since ``created_at``.

        Returns:
            float: Age of the session in seconds.
        """
        now = time.time()
        return now - self.created_at

    @property
    def idle_seconds(self) -> float:
        """Seconds elapsed since ``last_used``.

        Returns:
            float: Time since the session was last used, in seconds.
        """
        now = time.time()
        return now - self.last_used

    @property
    def is_closed(self) -> bool:
        """Whether the session is unusable.

        A session counts as closed when it was marked closed, when its background
        owner task has finished, or when the MCP write stream looks broken (e.g.
        after a server restart or network drop). The stream probe uses
        ``getattr`` fallbacks on private MCP/anyio attributes and treats any
        error as "not broken", so it degrades gracefully if those internals change.

        Returns:
            bool: True if the session should not be used, False otherwise.
        """
        owner = self._owner_task
        if self._closed or (owner is not None and owner.done()):
            return True
        try:
            stream = getattr(self.session, "_write_stream", None)
            if stream is None:
                return False
            if getattr(stream, "_closed", False) is True:
                return True
            stream_state = getattr(stream, "_state", None)
            if stream_state is None:
                return False
            receivers = getattr(stream_state, "open_receive_channels", 1)
            # No receivers left on the other end means the transport is gone.
            return bool(isinstance(receivers, int) and receivers == 0)
        except Exception:  # nosec B110 - Graceful degradation if MCP internals change
            return False

    def mark_closed(self) -> None:
        """Flag this session as closed."""
        self._closed = True

    @property
    def owner_task(self) -> Optional[asyncio.Task]:
        """The background owner task, or None if there is none."""
        return self._owner_task

    @property
    def shutdown_event(self) -> Optional[asyncio.Event]:
        """The owner task's shutdown event, or None if there is none."""
        return self._shutdown_event

189 

190 

# Type aliases

# Key of one pooled-session bucket:
#   (user_identity_hash, url, identity_hash, transport_type, gateway_id)
# transport_type keeps SSE and streamable-HTTP sessions for the same URL apart;
# gateway_id keeps sessions apart per gateway so notifications are attributed correctly.
PoolKey = Tuple[str, str, str, str, str]

# Key of a session-affinity mapping entry: (mcp_session_id, url, transport_type, gateway_id)
SessionMappingKey = Tuple[str, str, str, str]

# Factory producing an httpx.AsyncClient; arguments are presumably
# (headers, timeout, auth) judging by their types - confirm against callers.
HttpxClientFactory = Callable[[Optional[Dict[str, str]], Optional[httpx.Timeout], Optional[httpx.Auth]], httpx.AsyncClient]

# Callback that extracts a stable identity from request headers
# (e.g. a user id decoded from a JWT). A falsy result means "no identity".
IdentityExtractor = Callable[[Dict[str, str]], Optional[str]]

# Factory called with (url, gateway_id) that returns a message handler. The
# handler receives server request responders, server notifications, or
# exceptions, and returns a coroutine.
MessageHandlerFactory = Callable[
    [str, Optional[str]],
    Callable[
        [RequestResponder[mcp_types.ServerRequest, mcp_types.ClientResult] | mcp_types.ServerNotification | Exception],
        Any,
    ],
]

218 

219 

220class MCPSessionPool: # pylint: disable=too-many-instance-attributes 

221 """ 

222 Pool of MCP ClientSessions keyed by (user_identity, server URL, identity hash, transport type, gateway_id). 

223 

224 Thread-Safety: 

225 This pool is designed for asyncio concurrency. It uses asyncio.Lock 

226 for synchronization, which is safe for coroutine-based concurrency 

227 but NOT for multi-threaded access. 

228 

229 Session Isolation: 

230 Sessions are isolated per user/tenant to prevent session collision. 

231 The identity hash is derived from authentication headers ensuring 

232 that different users never share MCP sessions. 

233 

234 Transport Isolation: 

235 Sessions are also isolated by transport type (SSE vs STREAMABLE_HTTP). 

236 The same URL with different transports will use separate pools. 

237 

238 Gateway Isolation: 

239 Sessions are isolated by gateway_id for correct notification attribution. 

240 When notifications are enabled, each gateway gets its own pooled sessions 

241 even if they share the same URL and authentication. 

242 

243 Features: 

244 - Session reuse across requests (10-20x latency improvement) 

245 - Per-user/tenant session isolation (prevents session collision) 

246 - Per-transport session isolation (prevents transport mismatch) 

247 - TTL-based expiration with configurable lifetime 

248 - Health checks on acquire for stale sessions 

249 - Configurable pool size per URL+identity+transport 

250 - Circuit breaker for failing endpoints 

251 - Idle pool key eviction to prevent unbounded growth 

252 - Custom identity extractor for rotating tokens (e.g., JWT decode) 

253 - Metrics for monitoring (hits, misses, evictions) 

254 - Graceful shutdown with close_all() 

255 

256 Usage: 

257 pool = MCPSessionPool() 

258 

259 # Use as context manager for lifecycle management 

260 async with pool: 

261 pooled = await pool.acquire(url, headers) 

262 try: 

263 result = await pooled.session.call_tool("my_tool", {}) 

264 finally: 

265 await pool.release(pooled) 

266 

267 # With custom identity extractor for JWT tokens: 

268 def extract_user_id(headers: dict) -> str: 

269 token = headers.get("Authorization", "").replace("Bearer ", "") 

270 claims = jwt.decode(token, options={"verify_signature": False}) 

271 return claims.get("sub") or claims.get("user_id") 

272 

273 pool = MCPSessionPool(identity_extractor=extract_user_id) 

274 """ 

275 

# Request headers (matched case-insensitively) whose values make up a session's identity.
DEFAULT_IDENTITY_HEADERS: frozenset[str] = frozenset(
    {
        "authorization",
        "x-tenant-id",
        "x-user-id",
        "x-api-key",
        "cookie",
        "x-mcp-session-id",
    }
)

287 

def __init__(
    self,
    max_sessions_per_key: int = 10,
    session_ttl_seconds: float = 300.0,
    health_check_interval_seconds: float = 60.0,
    acquire_timeout_seconds: float = 30.0,
    session_create_timeout_seconds: float = 30.0,
    circuit_breaker_threshold: int = 5,
    circuit_breaker_reset_seconds: float = 60.0,
    identity_headers: Optional[frozenset[str]] = None,
    identity_extractor: Optional[IdentityExtractor] = None,
    idle_pool_eviction_seconds: float = 600.0,
    default_transport_timeout_seconds: float = 30.0,
    health_check_methods: Optional[list[str]] = None,
    health_check_timeout_seconds: float = 5.0,
    message_handler_factory: Optional[MessageHandlerFactory] = None,
):
    """
    Initialize the session pool.

    Args:
        max_sessions_per_key: Maximum pooled sessions per (URL, identity, transport).
        session_ttl_seconds: Session TTL in seconds before forced expiration.
        health_check_interval_seconds: Seconds of idle time before health check.
        acquire_timeout_seconds: Timeout for waiting when pool is exhausted.
        session_create_timeout_seconds: Timeout for creating new sessions.
        circuit_breaker_threshold: Consecutive failures before circuit opens.
        circuit_breaker_reset_seconds: Seconds before circuit breaker resets.
        identity_headers: Headers that contribute to the identity hash. Names are
            normalized to lower case because request header names are
            lower-cased before lookup; an empty or None value falls back to
            ``DEFAULT_IDENTITY_HEADERS``.
        identity_extractor: Optional callback to extract stable identity from headers.
            Use this when tokens rotate frequently (e.g., short-lived JWTs).
            Should return a stable user/tenant ID string.
        idle_pool_eviction_seconds: Evict empty pool keys after this many seconds of no use.
        default_transport_timeout_seconds: Default timeout for transport connections.
        health_check_methods: Ordered list of health check methods to try.
            Options: ping, list_tools, list_prompts, list_resources, skip.
            Default: ["ping", "skip"] (try ping, skip if unsupported).
            A private copy is stored, so later changes to the caller's list
            do not affect the pool.
        health_check_timeout_seconds: Timeout for each health check attempt.
        message_handler_factory: Optional factory for creating message handlers.
            Called with (url, gateway_id) to create handlers for
            each new session. Enables notification handling.
    """
    # Configuration
    self._max_sessions = max_sessions_per_key
    self._session_ttl = session_ttl_seconds
    self._health_check_interval = health_check_interval_seconds
    self._acquire_timeout = acquire_timeout_seconds
    self._session_create_timeout = session_create_timeout_seconds
    self._circuit_breaker_threshold = circuit_breaker_threshold
    self._circuit_breaker_reset = circuit_breaker_reset_seconds
    # Identity matching lower-cases request header names, so configured names must
    # be lower case as well; otherwise e.g. "Authorization" would silently never match.
    self._identity_headers = frozenset(name.lower() for name in (identity_headers or self.DEFAULT_IDENTITY_HEADERS))
    self._identity_extractor = identity_extractor
    self._idle_pool_eviction = idle_pool_eviction_seconds
    self._default_transport_timeout = default_transport_timeout_seconds
    # Copy so the pool never aliases (or is affected by mutations of) the caller's list.
    self._health_check_methods = list(health_check_methods) if health_check_methods else ["ping", "skip"]
    self._health_check_timeout = health_check_timeout_seconds
    self._message_handler_factory = message_handler_factory

    # State - protected by _global_lock for creation, per-key locks for access
    self._global_lock = asyncio.Lock()
    self._pools: Dict[PoolKey, asyncio.Queue[PooledSession]] = {}
    self._active: Dict[PoolKey, Set[PooledSession]] = {}
    self._locks: Dict[PoolKey, asyncio.Lock] = {}
    self._semaphores: Dict[PoolKey, asyncio.Semaphore] = {}
    self._pool_last_used: Dict[PoolKey, float] = {}  # Track last use time per pool key

    # Circuit breaker state
    self._failures: Dict[str, int] = {}  # url -> consecutive failure count
    self._circuit_open_until: Dict[str, float] = {}  # url -> timestamp

    # Eviction throttling - only run eviction once per interval
    self._last_eviction_run: float = 0.0
    self._eviction_run_interval: float = 60.0  # Run eviction at most every 60 seconds

    # Metrics
    self._hits = 0
    self._misses = 0
    self._evictions = 0
    self._health_check_failures = 0
    self._circuit_breaker_trips = 0
    self._pool_keys_evicted = 0
    self._sessions_reaped = 0  # Sessions closed during background eviction
    self._anonymous_identity_count = 0  # Count of requests with no identity headers

    # Lifecycle
    self._closed = False

    # Pre-registered session mappings for session affinity
    # Mapping from (mcp_session_id, url, transport_type, gateway_id) -> pool_key
    # Set by broadcast() before acquire() is called to enable session affinity lookup
    self._mcp_session_mapping: Dict[SessionMappingKey, PoolKey] = {}
    self._mcp_session_mapping_lock = asyncio.Lock()

    # Multi-worker session affinity via Redis pub/sub
    # Track pending responses for forwarded RPC requests
    self._rpc_listener_task: Optional[asyncio.Task[None]] = None
    self._heartbeat_task: Optional[asyncio.Task[None]] = None

    # Session affinity metrics
    self._session_affinity_local_hits = 0
    self._session_affinity_redis_hits = 0
    self._session_affinity_misses = 0
    self._forwarded_requests = 0
    self._forwarded_request_failures = 0
    self._forwarded_request_timeouts = 0

async def __aenter__(self) -> "MCPSessionPool":
    """Enter the pool's async context.

    Starts the worker heartbeat via ``start_heartbeat()`` (which returns
    immediately unless session affinity is enabled) and yields the pool itself.

    Returns:
        MCPSessionPool: This pool instance.
    """
    pool = self
    pool.start_heartbeat()
    return pool

402 

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
    """Leave the async context by closing every pooled session.

    Returns None, so an exception raised inside the ``async with`` body is not
    suppressed.

    Args:
        exc_type: Type of the exception raised in the context body, if any.
        exc_val: The exception instance, if any.
        exc_tb: The exception traceback, if any.
    """
    await self.close_all()
    return None

412 

def _compute_identity_hash(self, headers: Optional[Dict[str, str]]) -> str:
    """Derive the identity hash that keeps pooled sessions separate per user/tenant.

    Resolution order (the first one that yields a value wins):

    1. ``identity_extractor``, if configured and it returns a truthy value
       (intended for rotating credentials such as short-lived JWTs). Any
       exception it raises is logged and ignored.
    2. The ``x-mcp-session-id`` header, when session affinity is enabled, so a
       downstream session keeps the same upstream session even when tokens
       rotate (different jti values per request).
    3. The configured identity headers, serialized with orjson and hashed
       (serialization avoids delimiter-collision issues of string joining).

    Args:
        headers: Request headers dict, or None.

    Returns:
        A SHA-256 hex digest, or "anonymous" when no identity information is
        present (in which case ``_anonymous_identity_count`` is incremented).
    """
    if not headers:
        self._anonymous_identity_count += 1
        logger.debug("Session pool identity collapsed to 'anonymous' (no headers provided). " + "Sessions will be shared. Ensure this is intentional for stateless MCP servers.")
        return "anonymous"

    extractor = self._identity_extractor
    if extractor:
        try:
            stable_id = extractor(headers)
            if stable_id:
                return hashlib.sha256(stable_id.encode()).hexdigest()
        except Exception as e:
            logger.debug(f"Identity extractor failed, falling back to header hash: {e}")

    # Header names are compared case-insensitively.
    lowered = {name.lower(): value for name, value in headers.items()}

    if settings.mcpgateway_session_affinity_enabled:
        session_id = lowered.get("x-mcp-session-id")
        if session_id:
            logger.debug(f"Using x-mcp-session-id for session affinity: {session_id[:8]}...")
            return hashlib.sha256(session_id.encode()).hexdigest()

    identity_parts = [f"{name}:{lowered[name]}" for name in sorted(self._identity_headers) if name in lowered]

    if not identity_parts:
        self._anonymous_identity_count += 1
        logger.debug(
            "Session pool identity collapsed to 'anonymous' (no identity headers found). " + "Expected headers: %s. Sessions will be shared.",
            list(self._identity_headers),
        )
        return "anonymous"

    return hashlib.sha256(orjson.dumps(identity_parts)).hexdigest()

480 

def _make_pool_key(
    self,
    url: str,
    headers: Optional[Dict[str, str]],
    transport_type: TransportType,
    user_identity: str,
    gateway_id: Optional[str] = None,
) -> PoolKey:
    """Build the composite pool key for a session request.

    The key is ``(user_hash, url, identity_hash, transport_type.value, gateway_id)``.
    The user identity is SHA-256 hashed (full digest, for collision resistance)
    unless it is the literal "anonymous". ``gateway_id`` is part of the key so
    gateways sharing a URL and credentials still get separate sessions and
    notifications are attributed correctly; None is stored as "".
    """
    identity_hash = self._compute_identity_hash(headers)
    user_hash = "anonymous" if user_identity == "anonymous" else hashlib.sha256(user_identity.encode()).hexdigest()
    return (user_hash, url, identity_hash, transport_type.value, gateway_id or "")

507 

async def _get_or_create_lock(self, pool_key: PoolKey) -> asyncio.Lock:
    """Return the per-key lock for *pool_key*, creating it on first use.

    Lookup and creation happen while holding ``_global_lock`` so concurrent
    callers for the same key always receive the same lock object.
    """
    async with self._global_lock:
        try:
            return self._locks[pool_key]
        except KeyError:
            fresh = self._locks[pool_key] = asyncio.Lock()
            return fresh

514 

async def _get_or_create_pool(self, pool_key: PoolKey) -> asyncio.Queue[PooledSession]:
    """Return the idle-session queue for *pool_key*, creating it on first use.

    On creation, an empty active-session set and a semaphore are registered
    for the key as well; both the queue bound and the semaphore size come from
    ``_max_sessions``. Everything happens under ``_global_lock``.
    """
    async with self._global_lock:
        if pool_key in self._pools:
            return self._pools[pool_key]
        capacity = self._max_sessions
        queue = asyncio.Queue(maxsize=capacity)
        self._pools[pool_key] = queue
        self._active[pool_key] = set()
        self._semaphores[pool_key] = asyncio.Semaphore(capacity)
        return queue

523 

def _is_circuit_open(self, url: str) -> bool:
    """Report whether the circuit breaker for *url* is currently open.

    Once the recorded deadline has passed, the breaker is closed as a side
    effect: its deadline entry is removed and the failure count reset to 0.
    """
    open_until = self._circuit_open_until
    if url in open_until:
        if time.time() >= open_until[url]:
            # Cool-down elapsed: close the breaker and start counting afresh.
            del open_until[url]
            self._failures[url] = 0
            logger.info(f"Circuit breaker reset for {sanitize_url_for_logging(url)}")
            return False
        return True
    return False

535 

def _record_failure(self, url: str) -> None:
    """Record a failure for *url* and open its circuit breaker at the threshold.

    Failures that arrive while the breaker is already open (e.g. requests that
    were in flight when it tripped) still increment the failure counter, but do
    not re-trip the breaker. They neither push back the reset deadline nor
    inflate the ``_circuit_breaker_trips`` metric. Once the deadline has
    passed, the next failure at or above the threshold trips the breaker again.
    """
    failures = self._failures.get(url, 0) + 1
    self._failures[url] = failures
    if failures < self._circuit_breaker_threshold:
        return

    now = time.time()
    open_until = self._circuit_open_until.get(url)
    if open_until is not None and now < open_until:
        return  # Already open: this failure does not constitute a new trip.

    self._circuit_open_until[url] = now + self._circuit_breaker_reset
    self._circuit_breaker_trips += 1
    logger.warning(f"Circuit breaker opened for {sanitize_url_for_logging(url)} after {self._failures[url]} failures. " f"Will reset in {self._circuit_breaker_reset}s")

543 

def _record_success(self, url: str) -> None:
    """Clear the consecutive-failure streak for *url* after a success."""
    failure_counts = self._failures
    failure_counts[url] = 0

547 

@staticmethod
def is_valid_mcp_session_id(session_id: str) -> bool:
    """Validate downstream MCP session ID format for affinity.

    Used for:
    - Redis key construction (ownership + mapping)
    - Pub/Sub channel naming
    - Avoiding log spam / injection

    Returns:
        True only if *session_id* is a non-empty ``str`` that matches
        ``_MCP_SESSION_ID_PATTERN`` in its entirety; False for anything else,
        including None and non-string values.
    """
    if not isinstance(session_id, str) or not session_id:
        return False
    # fullmatch(): with re.match, a "$"-anchored pattern also accepts a single
    # trailing newline ("abc\n"), which would let a newline into Redis keys,
    # channel names and log lines.
    return _MCP_SESSION_ID_PATTERN.fullmatch(session_id) is not None

560 

def _sanitize_redis_key_component(self, value: str) -> str:
    """Make *value* safe to embed in a Redis key.

    Every character outside ``[a-zA-Z0-9_-]`` is replaced with an underscore;
    falsy input yields an empty string.

    Args:
        value: The raw key component.

    Returns:
        The sanitized component.
    """
    return re.sub(r"[^a-zA-Z0-9_-]", "_", value) if value else ""

577 

def _session_mapping_redis_key(self, mcp_session_id: str, url: str, transport_type: str, gateway_id: str) -> str:
    """Build the Redis key under which a session-affinity mapping is stored.

    The session ID is sanitized, and the URL is replaced by the first 16 hex
    characters of its SHA-256 digest, which keeps the key short and free of
    URL punctuation.
    """
    safe_session = self._sanitize_redis_key_component(mcp_session_id)
    url_digest = hashlib.sha256(url.encode()).hexdigest()[:16]
    prefix = "mcpgw:session_mapping"
    return f"{prefix}:{safe_session}:{url_digest}:{transport_type}:{gateway_id}"

586 

@staticmethod
def _pool_owner_key(mcp_session_id: str) -> str:
    """Build the Redis key that records which worker owns *mcp_session_id*."""
    return "mcpgw:pool_owner:{}".format(mcp_session_id)

591 

def _worker_heartbeat_key(self) -> str:
    """Build the Redis key holding this worker's (``WORKER_ID``) heartbeat."""
    return "mcpgw:worker_heartbeat:{}".format(WORKER_ID)

595 

def start_heartbeat(self) -> None:
    """Launch the worker-heartbeat background task when session affinity is on.

    Must be called from within a running event loop (it uses
    ``asyncio.create_task``). Idempotent: if a heartbeat task already exists
    and has not finished, nothing happens.
    """
    if settings.mcpgateway_session_affinity_enabled:
        running = self._heartbeat_task
        if running is None or running.done():
            self._heartbeat_task = asyncio.create_task(self._run_heartbeat_loop())

606 

async def _run_heartbeat_loop(self) -> None:
    """Refresh this worker's liveness key in Redis until the pool is closed.

    The loop runs every 10 seconds and writes ``_worker_heartbeat_key()`` with
    a 30-second TTL. Failures (including Redis being unavailable) are logged
    at debug level and retried on the next tick.
    """
    # First-Party
    from mcpgateway.utils.redis_client import get_redis_client

    heartbeat_ttl_seconds = 30
    refresh_interval_seconds = 10
    while True:
        if self._closed:
            break
        try:
            client = await get_redis_client()
            if client:
                await client.setex(self._worker_heartbeat_key(), heartbeat_ttl_seconds, "alive")
        except Exception as e:
            logger.debug(f"Heartbeat update failed: {e}")

        await asyncio.sleep(refresh_interval_seconds)

622 

async def _is_worker_alive(self, worker_id: str) -> bool:
    """Report whether *worker_id* still has a heartbeat key in Redis.

    Fails open: when Redis is unavailable or anything goes wrong (including
    failing to import the Redis client), the worker is assumed to be alive.
    """
    try:
        # First-Party
        from mcpgateway.utils.redis_client import get_redis_client

        client = await get_redis_client()
        if client:
            return await client.exists(f"mcpgw:worker_heartbeat:{worker_id}") > 0
        return True  # No Redis: assume alive
    except Exception:
        return True  # Fail open

637 

async def register_session_mapping(
    self,
    mcp_session_id: str,
    url: str,
    gateway_id: str,
    transport_type: str,
    user_email: Optional[str] = None,
) -> None:
    """Pre-register session mapping for session affinity.

    Called from respond() to set up mapping BEFORE acquire() is called.
    This ensures acquire() can find the correct pool key for session affinity.

    The mapping stores the relationship between an incoming MCP session ID
    and the pool key that should be used for upstream connections. This
    enables session affinity even when JWT tokens rotate (different jti values
    per request).

    For multi-worker deployments, the mapping is also stored in Redis with TTL
    so that any worker can look it up during acquire(), and this worker tries
    to claim ownership of the session with an atomic ``SET NX``.

    Invalid session IDs (including non-string values) are rejected with a
    warning whose preview of the value is coerced to ``str`` and sanitized,
    because the rejected value is client-supplied.

    Args:
        mcp_session_id: The downstream MCP session ID from x-mcp-session-id header.
        url: The upstream MCP server URL.
        gateway_id: The gateway ID.
        transport_type: The transport type (sse, streamablehttp).
        user_email: The email of the authenticated user (or "system" for unauthenticated).
    """
    if not settings.mcpgateway_session_affinity_enabled:
        return

    # Validate mcp_session_id to prevent Redis key injection
    if not self.is_valid_mcp_session_id(mcp_session_id):
        # The value failed validation, so it may contain control characters or
        # not be a string at all: coerce and sanitize before logging it.
        preview = SecurityValidator.sanitize_log_message(str(mcp_session_id)[:20])
        logger.warning(f"Invalid mcp_session_id format, skipping session mapping: {preview}...")
        return

    # Use user email for user_identity, or "anonymous" if not provided
    user_identity = user_email or "anonymous"

    # Normalize gateway_id to empty string if None for consistent key matching
    normalized_gateway_id = gateway_id or ""

    mapping_key: SessionMappingKey = (mcp_session_id, url, transport_type, normalized_gateway_id)

    # Compute what the pool_key will be for this session
    # Use mcp_session_id as the identity basis for affinity
    identity_hash = hashlib.sha256(mcp_session_id.encode()).hexdigest()

    # Hash user identity for privacy (unless it's "anonymous")
    if user_identity == "anonymous":
        user_hash = "anonymous"
    else:
        user_hash = hashlib.sha256(user_identity.encode()).hexdigest()

    pool_key: PoolKey = (user_hash, url, identity_hash, transport_type, normalized_gateway_id)

    # Store in local memory
    async with self._mcp_session_mapping_lock:
        self._mcp_session_mapping[mapping_key] = pool_key
        logger.debug(f"Session affinity pre-registered (local): {mcp_session_id[:8]}... → {url}, user={SecurityValidator.sanitize_log_message(user_identity)}")

    # Store in Redis for multi-worker support AND register ownership atomically.
    # Registering ownership HERE (during mapping) instead of in acquire() prevents
    # a race condition where two workers could both start creating sessions before
    # either registers ownership.
    try:
        # First-Party
        from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

        redis = await get_redis_client()
        if redis:
            redis_key = self._session_mapping_redis_key(mcp_session_id, url, transport_type, normalized_gateway_id)

            # Store pool_key as JSON for easy deserialization
            pool_key_data = {
                "user_hash": user_hash,
                "url": url,
                "identity_hash": identity_hash,
                "transport_type": transport_type,
                "gateway_id": normalized_gateway_id,
            }
            await redis.setex(redis_key, settings.mcpgateway_session_affinity_ttl, orjson.dumps(pool_key_data))  # TTL from config

            # CRITICAL: Register ownership atomically with mapping.
            # This claims ownership BEFORE any session creation attempt, preventing
            # the race condition where two workers both start creating sessions
            owner_key = self._pool_owner_key(mcp_session_id)
            # Atomic claim with TTL (avoids the SETNX/EXPIRE crash window).
            was_set = await redis.set(owner_key, WORKER_ID, nx=True, ex=settings.mcpgateway_session_affinity_ttl)
            if was_set:
                logger.debug(f"Session ownership claimed (SET NX): {mcp_session_id[:8]}... → worker {WORKER_ID}")
            else:
                # Another worker already claimed ownership
                existing_owner = await redis.get(owner_key)
                owner_id = existing_owner.decode() if isinstance(existing_owner, bytes) else existing_owner
                logger.debug(f"Session ownership already claimed by {owner_id}: {mcp_session_id[:8]}...")

            logger.debug(f"Session affinity pre-registered (Redis): {mcp_session_id[:8]}... TTL={settings.mcpgateway_session_affinity_ttl}s")
    except Exception as e:
        # Redis failure is non-fatal - local mapping still works for same-worker requests
        logger.debug(f"Failed to store session mapping in Redis: {e}")

739 

async def acquire(
    self,
    url: str,
    headers: Optional[Dict[str, str]] = None,
    transport_type: TransportType = TransportType.STREAMABLE_HTTP,
    httpx_client_factory: Optional[HttpxClientFactory] = None,
    timeout: Optional[float] = None,
    user_identity: Optional[str] = None,
    gateway_id: Optional[str] = None,
) -> PooledSession:
    """
    Acquire a session for the given URL, identity, and transport type.

    Sessions are isolated by identity (derived from auth headers) AND
    transport type. Returns an initialized, healthy session ready for tool calls.

    Args:
        url: The MCP server URL.
        headers: Request headers (used for identity hashing and passed to server).
        transport_type: The transport type (SSE or STREAMABLE_HTTP).
        httpx_client_factory: Optional factory for creating httpx clients
            (for custom SSL/timeout configuration).
        timeout: Optional timeout in seconds for transport connection; the
            pool's default transport timeout is used when None.
        user_identity: Optional raw user identity used when computing the
            pool key; "anonymous" is used when omitted or empty.
        gateway_id: Optional gateway ID for notification handler context.

    Returns:
        PooledSession ready for use. Hand it back with release() when done.

    Raises:
        asyncio.TimeoutError: If acquire times out waiting for available session.
        RuntimeError: If pool is closed, circuit breaker is open, or the
            session is owned (or was just reclaimed) by another worker.
        Exception: If session creation fails.
    """
    if self._closed:
        raise RuntimeError("Session pool is closed")

    if self._is_circuit_open(url):
        raise RuntimeError(f"Circuit breaker open for {url}")

    # Use default timeout if not provided
    effective_timeout = timeout if timeout is not None else self._default_transport_timeout

    user_id = user_identity or "anonymous"
    pool_key: Optional[PoolKey] = None

    # Resolve the downstream MCP session id once: it drives both the affinity
    # lookup and the ownership check performed before creating a new session.
    mcp_session_id: Optional[str] = None
    if settings.mcpgateway_session_affinity_enabled and headers:
        headers_lower = {k.lower(): v for k, v in headers.items()}
        candidate_session_id = headers_lower.get("x-mcp-session-id")
        if candidate_session_id and self.is_valid_mcp_session_id(candidate_session_id):
            mcp_session_id = candidate_session_id

    # Check pre-registered mapping first (set by respond() for session affinity).
    # NOTE(review): a mapped pool_key is used as-is and is not re-derived from this
    # request's user_identity/headers - confirm the mapping is scoped per user.
    if mcp_session_id is not None:
        normalized_gateway_id = gateway_id or ""
        mapping_key: SessionMappingKey = (mcp_session_id, url, transport_type.value, normalized_gateway_id)

        # Check local memory first (fast path - same worker)
        async with self._mcp_session_mapping_lock:
            pool_key = self._mcp_session_mapping.get(mapping_key)
            if pool_key:
                self._session_affinity_local_hits += 1
                logger.debug(f"Session affinity hit (local): {mcp_session_id[:8]}...")

        # If not in local memory, check Redis (multi-worker support)
        if pool_key is None:
            try:
                # First-Party
                from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

                redis = await get_redis_client()
                if redis:
                    redis_key = self._session_mapping_redis_key(mcp_session_id, url, transport_type.value, normalized_gateway_id)
                    pool_key_data = await redis.get(redis_key)
                    if pool_key_data:
                        # Deserialize pool_key from JSON
                        data = orjson.loads(pool_key_data)
                        pool_key = (
                            data["user_hash"],
                            data["url"],
                            data["identity_hash"],
                            data["transport_type"],
                            data["gateway_id"],
                        )
                        # Cache in local memory for future requests
                        async with self._mcp_session_mapping_lock:
                            self._mcp_session_mapping[mapping_key] = pool_key
                        self._session_affinity_redis_hits += 1
                        logger.debug(f"Session affinity hit (Redis): {mcp_session_id[:8]}...")
            except Exception as e:
                logger.debug(f"Failed to check Redis for session mapping: {e}")

    # Fallback to normal pool key computation
    if pool_key is None:
        self._session_affinity_misses += 1
        pool_key = self._make_pool_key(url, headers, transport_type, user_id, gateway_id)

    pool = await self._get_or_create_pool(pool_key)

    # Update pool key last used time IMMEDIATELY after getting pool
    # This prevents race with eviction removing keys between awaits
    self._pool_last_used[pool_key] = time.time()

    lock = await self._get_or_create_lock(pool_key)

    # Guard semaphore access - eviction may have removed it between awaits
    # If so, re-create the pool structures
    if pool_key not in self._semaphores:
        pool = await self._get_or_create_pool(pool_key)
        self._pool_last_used[pool_key] = time.time()

    semaphore = self._semaphores[pool_key]

    # Throttled eviction - only run if enough time has passed (inline, not spawned)
    await self._maybe_evict_idle_pool_keys()

    # Try to get from pool first (quick path, no lock needed for queue get)
    while True:
        try:
            pooled = pool.get_nowait()
        except asyncio.QueueEmpty:
            break

        # Validate the session outside the lock
        if await self._validate_session(pooled):
            pooled.last_used = time.time()
            pooled.use_count += 1
            self._hits += 1
            async with lock:
                self._active[pool_key].add(pooled)
            logger.debug(f"Pool hit for {sanitize_url_for_logging(url)} (identity={pool_key[2][:8]}, transport={transport_type.value})")
            return pooled

        # Session invalid, close it; idle sessions hold a semaphore slot, so free it
        await self._close_session(pooled)
        self._evictions += 1
        semaphore.release()

    # No valid session in pool - try to create one or wait
    try:
        # Use semaphore with timeout to limit concurrent sessions
        acquired = await asyncio.wait_for(semaphore.acquire(), timeout=self._acquire_timeout)
        if not acquired:
            raise asyncio.TimeoutError("Failed to acquire session slot")
    except asyncio.TimeoutError:
        raise asyncio.TimeoutError(f"Timeout waiting for available session for {sanitize_url_for_logging(url)}") from None

    # Create new session (semaphore acquired)
    try:
        # Verify we own this session before creating (prevents race condition).
        # If another worker already claimed ownership, we should not create a new session.
        # Note: Ownership is registered atomically in register_session_mapping() using SETNX
        if mcp_session_id is not None:
            owner = await self._get_pool_session_owner(mcp_session_id)
            if owner and owner != WORKER_ID:
                # Check if owner is still alive
                if not await self._is_worker_alive(owner):
                    # Owner is dead - reclaim ownership with compare-and-swap
                    # First-Party
                    from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

                    redis = await get_redis_client()
                    if redis:
                        owner_key = self._pool_owner_key(mcp_session_id)
                        # Lua CAS: only reclaim if still owned by the dead worker
                        cas_script = """
                        local cur = redis.call('GET', KEYS[1])
                        if cur == ARGV[1] then
                            redis.call('SET', KEYS[1], ARGV[2], 'EX', ARGV[3])
                            return 1
                        end
                        return 0
                        """
                        ttl = int(settings.mcpgateway_session_affinity_ttl)
                        reclaimed = await redis.eval(cas_script, 1, owner_key, owner, WORKER_ID, ttl)
                        if reclaimed == 1:
                            logger.info(f"Reclaimed ownership from dead worker {owner}: {mcp_session_id[:8]}...")
                        else:
                            # Another worker already reclaimed - let it handle
                            # (outer except BaseException releases the semaphore)
                            raise RuntimeError(f"Session reclaimed by another worker: {mcp_session_id[:8]}...")
                else:
                    # Owner is alive - should have been forwarded
                    # (outer except BaseException releases the semaphore)
                    logger.warning(f"Session {mcp_session_id[:8]}... owned by worker {owner}, not us ({WORKER_ID})")
                    raise RuntimeError(f"Session owned by another worker: {owner}")

        pooled = await asyncio.wait_for(
            self._create_session(url, headers, transport_type, httpx_client_factory, effective_timeout, gateway_id),
            timeout=self._session_create_timeout,
        )
        # Store identity components for key reconstruction in release()
        pooled.identity_key = pool_key[2]
        pooled.user_identity = user_id

        # Note: Ownership is registered atomically in register_session_mapping()
        # before acquire() is called, so we don't need to register it here

        self._misses += 1
        self._record_success(url)
        async with lock:
            self._active[pool_key].add(pooled)
        logger.debug(f"Pool miss for {sanitize_url_for_logging(url)} - created new session (transport={transport_type.value})")
        return pooled
    except BaseException as e:
        # Release semaphore on ANY failure (including CancelledError)
        semaphore.release()
        if not isinstance(e, asyncio.CancelledError):
            self._record_failure(url)
            logger.warning(f"Failed to create session for {sanitize_url_for_logging(url)}: {e}")
        raise

950 

async def release(self, pooled: PooledSession, *, discard: bool = False) -> None:
    """
    Return a session to the pool for reuse, or discard it.

    The session is always removed from its pool key's active set. It is then
    closed and its semaphore slot freed when *discard* is set, when the
    session is already closed, when the pool is closed, or when the session
    has outlived the session TTL. Otherwise it is parked in the idle queue.

    Args:
        pooled: The session to release.
        discard: If True, close the session instead of returning it to the
            pool. Used when the caller detected a transport error
            (e.g. ``ClosedResourceError``) to prevent recycling a
            broken session.
    """
    # Treat already-closed sessions (e.g. dead owner task) as discards —
    # still need to remove from _active and release the semaphore slot.
    if pooled.is_closed:
        discard = True

    # Pool key includes transport type, user identity, and gateway_id.
    # Re-compute user hash from stored raw identity (full hash for collision resistance).
    user_hash = "anonymous"
    if pooled.user_identity != "anonymous":
        user_hash = hashlib.sha256(pooled.user_identity.encode()).hexdigest()

    pool_key = (user_hash, pooled.url, pooled.identity_key, pooled.transport_type.value, pooled.gateway_id)
    lock = await self._get_or_create_lock(pool_key)
    pool = await self._get_or_create_pool(pool_key)

    async with lock:
        # Update last-used FIRST to prevent eviction race:
        # If eviction runs between removing from _active and putting back in pool,
        # it would see key as idle + inactive and evict it. By updating last-used
        # while still holding the lock and before removing from _active, we ensure
        # eviction sees recent activity.
        self._pool_last_used[pool_key] = time.time()
        self._active.get(pool_key, set()).discard(pooled)

        # Discard broken sessions instead of recycling them
        if discard:
            logger.debug(f"Discarding broken session for {sanitize_url_for_logging(pooled.url)}")
            await self._close_session(pooled)
            if pool_key in self._semaphores:
                self._semaphores[pool_key].release()
            self._evictions += 1
            return

        # Read the (time-dependent) age once so the close decision and the
        # eviction counter below are based on the same value.
        expired = pooled.age_seconds > self._session_ttl

        # Check if session should be returned to pool
        if self._closed or expired:
            await self._close_session(pooled)
            if pool_key in self._semaphores:
                self._semaphores[pool_key].release()
            if expired:
                self._evictions += 1
            return

        # Return to pool (pool may have been evicted in edge case, recreate if needed)
        if pool_key not in self._pools:
            pool = await self._get_or_create_pool(pool_key)
            self._pool_last_used[pool_key] = time.time()

        try:
            pool.put_nowait(pooled)
            logger.debug(f"Session returned to pool for {sanitize_url_for_logging(pooled.url)}")
        except asyncio.QueueFull:
            # Pool full (shouldn't happen with semaphore), close session
            await self._close_session(pooled)
            if pool_key in self._semaphores:
                self._semaphores[pool_key].release()

1017 

async def _maybe_evict_idle_pool_keys(self) -> None:
    """
    Reap stale sessions and evict idle pool keys.

    This method is throttled: it only runs if at least
    ``self._eviction_run_interval`` seconds have passed since the last run.
    This prevents:
    - Unbounded task spawning on every acquire
    - Lock contention under high load

    Two-phase cleanup. It applies only to pool keys that have been idle for at least
    ``self._idle_pool_eviction`` seconds and have no active sessions:
    1. Close expired/stale sessions parked in the idle queue. This frees the
       connections and their semaphore slots. Every parked session is inspected
       once; still-valid sessions are put back in their original order.
    2. Evict pool keys whose idle queue is now empty, along with any local
       session-affinity mappings that point at them.

    This prevents unbounded connection and pool key growth when using
    rotating tokens (e.g., short-lived JWTs with unique identifiers).
    """
    if self._closed:
        return

    now = time.time()

    # Throttle: only run eviction once per interval
    if now - self._last_eviction_run < self._eviction_run_interval:
        return

    self._last_eviction_run = now

    # Collect sessions to close and keys to evict (minimize time holding lock)
    sessions_to_close: list[PooledSession] = []
    keys_to_evict: list[PoolKey] = []

    async with self._global_lock:
        for pool_key, last_used in list(self._pool_last_used.items()):
            # Skip recently-used pools
            if now - last_used < self._idle_pool_eviction:
                continue

            # Skip if there are active sessions (in use)
            if self._active.get(pool_key, set()):
                continue

            pool = self._pools.get(pool_key)
            if pool is None:
                continue

            # Phase 1: inspect each parked session exactly once (snapshot of the
            # queue length) so stale sessions queued behind a valid one are reaped too.
            survivors: list[PooledSession] = []
            for _ in range(pool.qsize()):
                try:
                    session = pool.get_nowait()
                except asyncio.QueueEmpty:
                    break
                # Close if expired OR idle too long (defense in depth)
                if session.age_seconds > self._session_ttl or session.idle_seconds > self._idle_pool_eviction:
                    sessions_to_close.append(session)
                    # Parked sessions hold a semaphore slot; free it
                    if pool_key in self._semaphores:
                        self._semaphores[pool_key].release()
                else:
                    survivors.append(session)
            for session in survivors:
                pool.put_nowait(session)

            # Phase 2: Evict pool key if now empty
            if pool.empty():
                keys_to_evict.append(pool_key)

        if keys_to_evict:
            # Remove evicted keys from all tracking dicts
            for pool_key in keys_to_evict:
                self._pools.pop(pool_key, None)
                self._active.pop(pool_key, None)
                self._locks.pop(pool_key, None)
                self._semaphores.pop(pool_key, None)
                self._pool_last_used.pop(pool_key, None)
                self._pool_keys_evicted += 1
                # Sanitize the URL like every other log line: it may carry credentials
                logger.debug(f"Evicted idle pool key: {pool_key[0][:8]}|{sanitize_url_for_logging(pool_key[1])}|{pool_key[2][:8]}")

            # Clean up session mappings pointing to any evicted pool key in one pass
            evicted = set(keys_to_evict)
            async with self._mcp_session_mapping_lock:
                stale_mappings = [k for k, v in self._mcp_session_mapping.items() if v in evicted]
                for mapping_key in stale_mappings:
                    self._mcp_session_mapping.pop(mapping_key, None)

    # Close sessions outside the lock (I/O operations)
    for session in sessions_to_close:
        await self._close_session(session)
        self._sessions_reaped += 1
        logger.debug(f"Reaped stale session for {sanitize_url_for_logging(session.url)} (age={session.age_seconds:.1f}s)")

1106 

async def _validate_session(self, pooled: PooledSession) -> bool:
    """
    Decide whether a parked session may be handed out again.

    A session is rejected outright when it is already closed or older than the
    pool's session TTL. A session that has been idle for longer than the
    health-check interval must also pass the configured health-check chain.

    Args:
        pooled: The session to validate.

    Returns:
        True if the session may be reused, False otherwise.
    """
    if pooled.is_closed:
        return False

    past_ttl = pooled.age_seconds > self._session_ttl
    if past_ttl:
        logger.debug(f"Session expired (age={pooled.age_seconds:.1f}s)")
        return False

    # Only sessions that have sat idle long enough are actively probed
    needs_probe = pooled.idle_seconds > self._health_check_interval
    return await self._run_health_check_chain(pooled) if needs_probe else True

1132 

async def _run_health_check_chain(self, pooled: PooledSession) -> bool:
    """
    Probe a session with each configured health-check method, in order.

    Methods come from ``self._health_check_methods``. ``"skip"`` accepts the
    session without probing. Unknown names are logged and ignored. Behaviour by probe outcome:
    - A method the server reports as METHOD_NOT_FOUND, or a probe that times
      out, moves on to the next method.
    - Any other error fails the check immediately and counts as a
      health-check failure.

    Args:
        pooled: The session to health check.

    Returns:
        True as soon as one probe succeeds (or "skip" is configured);
        False on a hard failure or when every method was exhausted.
    """
    # (configured name, ClientSession coroutine method that implements the probe)
    probes = (
        ("ping", "send_ping"),
        ("list_tools", "list_tools"),
        ("list_prompts", "list_prompts"),
        ("list_resources", "list_resources"),
    )

    for method in self._health_check_methods:
        try:
            probe = next(((name, attr) for name, attr in probes if method == name), None)
            if probe is not None:
                probe_name, probe_attr = probe
                with anyio.fail_after(self._health_check_timeout):
                    await getattr(pooled.session, probe_attr)()
                logger.debug(f"Health check passed: {probe_name} (url={sanitize_url_for_logging(pooled.url)})")
                return True

            if method == "skip":
                logger.debug(f"Health check skipped per configuration (url={sanitize_url_for_logging(pooled.url)})")
                return True

            logger.warning(f"Unknown health check method '{method}', skipping")

        except McpError as e:
            # METHOD_NOT_FOUND (-32601): the server lacks this method, so fall through to the next
            if e.error.code == METHOD_NOT_FOUND:
                logger.debug(f"Health check method '{method}' not supported by server, trying next")
                continue
            # Any other MCP error is a genuine failure
            logger.debug(f"Health check '{method}' failed with MCP error: {e}")
            self._health_check_failures += 1
            return False

        except TimeoutError:
            logger.debug(f"Health check '{method}' timed out after {self._health_check_timeout}s, trying next")

        except Exception as e:
            logger.debug(f"Health check '{method}' failed: {e}")
            self._health_check_failures += 1
            return False

    # Nothing succeeded: every method failed softly or was unsupported
    logger.warning(f"All health check methods failed or unsupported (methods={self._health_check_methods})")
    self._health_check_failures += 1
    return False

1197 return False 

1198 

async def _session_owner_coro(
    self,
    url: str,
    transport_type: TransportType,
    merged_headers: Dict[str, str],
    httpx_client_factory: Optional[HttpxClientFactory],
    timeout: Optional[float],
    gateway_id: Optional[str],
    ready_future: "asyncio.Future[Tuple[ClientSession, Any]]",
    shutdown_event: asyncio.Event,
) -> None:
    """Background task that owns the transport and session lifecycle.

    Runs transport and session inside proper ``async with`` blocks so that
    anyio cancel scopes are bound to THIS task, not the request handler.
    Signals readiness via *ready_future*, then blocks on *shutdown_event*
    until the pool requests cleanup.

    A failure before readiness is delivered through *ready_future* as a
    ``RuntimeError``; its message embeds the sanitized URL, because callers
    may log it. A failure after readiness ends the task and is logged at
    debug level (exception type only).

    Args:
        url: Upstream MCP server URL.
        transport_type: SSE or STREAMABLE_HTTP.
        merged_headers: Headers to send upstream.
        httpx_client_factory: Optional factory for custom httpx clients.
        timeout: Transport connection timeout in seconds.
        gateway_id: Gateway ID passed to the message handler factory.
        ready_future: Resolved with ``(session, transport_ctx)`` once the
            session is initialized, or failed with a RuntimeError.
        shutdown_event: Set by the pool to request an orderly shutdown.
    """
    try:
        # Build transport context; both clients are called with the same keyword arguments
        transport_kwargs: Dict[str, Any] = {"url": url, "headers": merged_headers, "timeout": timeout}
        if httpx_client_factory:
            transport_kwargs["httpx_client_factory"] = httpx_client_factory
        transport_client = sse_client if transport_type == TransportType.SSE else streamablehttp_client
        transport_ctx = transport_client(**transport_kwargs)

        async with transport_ctx as streams:
            # Only the first two items (read/write streams) are used, for either transport
            read_stream, write_stream = streams[0], streams[1]

            # Create message handler if factory is configured
            message_handler = None
            if self._message_handler_factory:
                try:
                    message_handler = self._message_handler_factory(url, gateway_id)
                    logger.debug(f"Created message handler for session {sanitize_url_for_logging(url)} (gateway={SecurityValidator.sanitize_log_message(gateway_id)})")
                except Exception as e:
                    logger.warning(f"Failed to create message handler for {sanitize_url_for_logging(url)}: {e}")

            async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
                await session.initialize()
                # Signal the session is ready for use
                if not ready_future.done():
                    ready_future.set_result((session, transport_ctx))
                logger.info(f"Created new MCP session for {sanitize_url_for_logging(url)} (transport={transport_type.value})")
                # Block here until the pool asks us to shut down
                await shutdown_event.wait()
            # async with ClientSession exits: session.__aexit__ unwinds properly
        # async with transport_ctx exits: transport.__aexit__ unwinds properly

    except BaseException as exc:
        if not ready_future.done():
            # Sanitize the URL: this message reaches callers and their logs
            ready_future.set_exception(RuntimeError(f"Failed to create MCP session for {sanitize_url_for_logging(url)}: {exc}"))
        elif not isinstance(exc, asyncio.CancelledError):
            # Session was already handed out; record why its owner died
            logger.debug(f"Session owner task for {sanitize_url_for_logging(url)} ended after readiness: {type(exc).__name__}")
        # Let the task finish; the pool will detect _owner_task.done()

1260 

async def _create_session(
    self,
    url: str,
    headers: Optional[Dict[str, str]],
    transport_type: TransportType,
    httpx_client_factory: Optional[HttpxClientFactory],
    timeout: Optional[float] = None,
    gateway_id: Optional[str] = None,
) -> PooledSession:
    """
    Start a dedicated owner task for a new MCP session and wait until it is ready.

    The transport and session contexts are entered inside a background
    ``asyncio.Task`` so that their anyio cancel scopes belong to that task
    rather than to the HTTP request handler. As a result, a failing child task
    inside the transport's TaskGroup cannot cancel the request itself.

    Args:
        url: Server URL.
        headers: Request headers; the gateway-internal affinity headers are
            removed before anything is sent upstream.
        transport_type: Transport type to use.
        httpx_client_factory: Optional factory for httpx clients.
        timeout: Optional timeout in seconds for transport connection.
        gateway_id: Optional gateway ID for notification handler context.

    Returns:
        Initialized PooledSession bound to its owner task.

    Raises:
        RuntimeError: If session creation or initialization fails.
        asyncio.TimeoutError: If the session is not ready within the create timeout.
        asyncio.CancelledError: If cancelled during creation.
    """
    # Defaults first, caller headers override; then drop gateway-internal affinity headers
    internal_headers = ("x-mcp-session-id", "mcp-session-id")
    combined = {"Accept": "application/json, text/event-stream", **(headers or {})}
    merged_headers = {name: value for name, value in combined.items() if name.lower() not in internal_headers}

    identity_key = self._compute_identity_hash(headers)
    shutdown_event = asyncio.Event()
    ready_future: asyncio.Future[Tuple[ClientSession, Any]] = asyncio.get_running_loop().create_future()

    owner_task = asyncio.create_task(
        self._session_owner_coro(url, transport_type, merged_headers, httpx_client_factory, timeout, gateway_id, ready_future, shutdown_event),
        name=f"mcp-session-owner-{sanitize_url_for_logging(url)}",
    )

    try:
        session, transport_ctx = await asyncio.wait_for(ready_future, timeout=self._session_create_timeout)
    except BaseException:
        # Any failure (timeout, error, cancellation): tear the owner task down, then re-raise
        shutdown_event.set()
        owner_task.cancel()
        with anyio.move_on_after(_get_cleanup_timeout()):
            try:
                await owner_task
            except BaseException:  # nosec B110 - Best effort cleanup
                pass
        raise

    return PooledSession(
        session=session,
        transport_context=transport_ctx,
        url=url,
        transport_type=transport_type,
        headers=merged_headers,
        identity_key=identity_key,
        gateway_id=gateway_id or "",
        _owner_task=owner_task,
        _shutdown_event=shutdown_event,
    )

1340 

async def _close_session(self, pooled: PooledSession) -> None:
    """
    Tear down a session and the transport beneath it.

    Sessions with an owner task are shut down in two steps. First the pool sets
    their shutdown event and waits, bounded by the cleanup timeout, for the
    task to unwind its ``async with`` blocks. If that wait times out, the task
    is cancelled. Sessions without an owner task instead call ``__aexit__``
    on the session and (if present) the transport directly, each bounded by
    the cleanup timeout. Finally, when session affinity is enabled, the Redis
    pool-owner key derived from the session's stored headers is cleaned up.

    Args:
        pooled: The session to close.
    """
    if pooled.is_closed and pooled.shutdown_event is None:
        # Already closed and no owner task (legacy path) — nothing left to do
        return

    pooled.mark_closed()
    cleanup_timeout = _get_cleanup_timeout()

    has_owner_task = pooled.shutdown_event is not None and pooled.owner_task is not None
    if has_owner_task:
        owner_task = pooled.owner_task
        # Ask the owner task to leave its async-with blocks on its own
        pooled.shutdown_event.set()

        if not owner_task.done():
            # NOTE(review): both waits below swallow CancelledError, which also absorbs
            # a cancellation aimed at the caller - confirm this is intended.
            with anyio.move_on_after(cleanup_timeout) as grace_scope:
                try:
                    await owner_task
                except (asyncio.CancelledError, Exception):  # nosec B110
                    pass
            if grace_scope.cancelled_caught:
                logger.warning(f"Session owner cleanup timed out for {sanitize_url_for_logging(pooled.url)} - force cancelling")
                owner_task.cancel()
                try:
                    await owner_task
                except (asyncio.CancelledError, Exception):  # nosec B110
                    pass
    else:

        async def _exit_quietly(ctx: Any, label: str) -> None:
            # Best-effort, time-bounded __aexit__ for sessions without an owner task
            with anyio.move_on_after(cleanup_timeout) as scope:
                try:
                    await ctx.__aexit__(None, None, None)  # pylint: disable=unnecessary-dunder-call
                except Exception as e:
                    logger.debug(f"Error closing {label}: {e}")
            if scope.cancelled_caught:
                logger.warning(f"{label.capitalize()} cleanup timed out for {sanitize_url_for_logging(pooled.url)} - proceeding anyway")

        await _exit_quietly(pooled.session, "session")
        if pooled.transport_context is not None:
            await _exit_quietly(pooled.transport_context, "transport")

    logger.debug(f"Closed session for {sanitize_url_for_logging(pooled.url)} (uses={pooled.use_count})")

    # Release the Redis pool_owner key used for session affinity.
    # NOTE(review): _create_session() strips x-mcp-session-id from the headers it stores
    # on the PooledSession, so this lookup may never find a session id - verify.
    if not (settings.mcpgateway_session_affinity_enabled and pooled.headers):
        return
    lowered = {name.lower(): value for name, value in pooled.headers.items()}
    session_id = lowered.get("x-mcp-session-id")
    if session_id and self.is_valid_mcp_session_id(session_id):
        await self._cleanup_pool_session_owner(session_id)

1405 

async def _cleanup_pool_session_owner(self, mcp_session_id: str) -> None:
    """Clean up the pool_owner key in Redis when a session is closed.

    The key is deleted only if this worker still owns it, so other workers'
    ownership is never removed. The ownership comparison and the delete run as
    a single Lua script. A separate GET followed by DELETE could instead delete
    a claim that another worker made after our key expired in between.
    Redis errors are logged at debug level and otherwise ignored.

    Args:
        mcp_session_id: The MCP session ID from x-mcp-session-id header.
    """
    try:
        # First-Party
        from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

        redis = await get_redis_client()
        if redis:
            key = self._pool_owner_key(mcp_session_id)
            # Atomic compare-and-delete: only remove the key if we own it
            compare_and_delete = """
            local cur = redis.call('GET', KEYS[1])
            if cur == ARGV[1] then
                return redis.call('DEL', KEYS[1])
            end
            return 0
            """
            deleted = await redis.eval(compare_and_delete, 1, key, WORKER_ID)
            if deleted == 1:
                logger.debug(f"Cleaned up pool session owner: {mcp_session_id[:8]}...")
    except Exception as e:
        # Cleanup failure is non-fatal
        logger.debug(f"Failed to cleanup pool session owner in Redis: {e}")

1431 

async def cleanup_streamable_http_session_owner(self, mcp_session_id: str) -> None:
    """Public entry point for dropping Streamable HTTP session ownership.

    Trusted internal MCP session teardown paths call this to remove affinity
    ownership without reaching into private helpers. Session IDs that fail
    validation are ignored (logged at debug level).

    Args:
        mcp_session_id: The MCP session ID whose ownership should be released.
    """
    if self.is_valid_mcp_session_id(mcp_session_id):
        await self._cleanup_pool_session_owner(mcp_session_id)
        return
    logger.debug("Invalid mcp_session_id for owner cleanup, skipping")

1442 

async def close_all(self) -> None:
    """
    Gracefully close all pooled and active sessions.

    The method works in four steps:
    - Marks the pool closed, so acquire() raises RuntimeError from then on.
    - Closes every idle and active session while holding the global lock.
    - Clears all per-key bookkeeping.
    - Stops the RPC listener and heartbeat tasks if they are running.

    Should be called during application shutdown.
    """
    self._closed = True
    logger.info("Closing all pooled sessions...")

    async with self._global_lock:
        # Close all pooled (idle) sessions
        for _pool_key, pool in list(self._pools.items()):
            while not pool.empty():
                try:
                    pooled = pool.get_nowait()
                except asyncio.QueueEmpty:
                    break
                await self._close_session(pooled)

        # Close all active sessions
        for _pool_key, active_set in list(self._active.items()):
            for pooled in list(active_set):
                await self._close_session(pooled)

        self._pools.clear()
        self._active.clear()
        self._locks.clear()
        self._semaphores.clear()
        # Drop idle timestamps too, so no per-key state outlives the pool
        self._pool_last_used.clear()
        self._mcp_session_mapping.clear()

    # Stop RPC listener if running
    if self._rpc_listener_task and not self._rpc_listener_task.done():
        self._rpc_listener_task.cancel()
        try:
            await self._rpc_listener_task
        except asyncio.CancelledError:
            pass
    self._rpc_listener_task = None

    # Stop heartbeat if running
    if self._heartbeat_task and not self._heartbeat_task.done():
        self._heartbeat_task.cancel()
        try:
            await self._heartbeat_task
        except asyncio.CancelledError:
            pass
    self._heartbeat_task = None

    logger.info("All sessions closed")

1492 

async def drain_all(self) -> None:
    """Close all pooled and active sessions without marking the pool as closed.

    Unlike ``close_all()``, the pool remains operational after draining.
    New sessions will be created on demand with fresh TLS state.
    Use this for certificate rotation (SIGHUP).

    Idle sessions parked in a pool hold a semaphore slot. Their slots are
    released here, because per-key semaphores survive the drain. Active
    sessions keep their slots until their holders call ``release()``.
    """
    logger.info("Draining all pooled sessions for TLS rotation...")

    async with self._global_lock:
        for pool_key, pool in list(self._pools.items()):
            semaphore = self._semaphores.get(pool_key)
            while not pool.empty():
                try:
                    pooled = pool.get_nowait()
                except asyncio.QueueEmpty:
                    break
                await self._close_session(pooled)
                # Without this, every drain would permanently shrink the key's capacity
                if semaphore is not None:
                    semaphore.release()

        for _pool_key, active_set in list(self._active.items()):
            for pooled in list(active_set):
                await self._close_session(pooled)

        self._pools.clear()
        self._active.clear()
        self._mcp_session_mapping.clear()

    logger.info("All pooled sessions drained; pool remains operational")

1520 

async def register_pool_session_owner(self, mcp_session_id: str) -> None:
    """Record this worker as the owner of a pool session in Redis.

    Ownership tracking is what enables multi-worker session affinity. When a
    request carrying x-mcp-session-id reaches a different worker, that worker
    can forward it to the owner.

    Initial ownership is normally claimed atomically (SET NX) in
    register_session_mapping(), so this method mostly refreshes the TTL. It
    never steals ownership: the key is set only when absent, and its TTL is
    extended only when this worker already owns it. Redis errors are logged
    at debug level and otherwise ignored.

    Args:
        mcp_session_id: The MCP session ID from x-mcp-session-id header.
    """
    if not settings.mcpgateway_session_affinity_enabled:
        return

    if not self.is_valid_mcp_session_id(mcp_session_id):
        logger.debug("Invalid mcp_session_id for owner registration, skipping")
        return

    try:
        # First-Party
        from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

        redis = await get_redis_client()
        if not redis:
            return

        owner_key = self._pool_owner_key(mcp_session_id)
        # Returns 1 = claimed, 2 = TTL refreshed (already ours), 0 = owned by someone else.
        claim_or_refresh = """
        local cur = redis.call('GET', KEYS[1])
        if not cur then
            redis.call('SET', KEYS[1], ARGV[1], 'EX', ARGV[2])
            return 1
        end
        if cur == ARGV[1] then
            redis.call('EXPIRE', KEYS[1], ARGV[2])
            return 2
        end
        return 0
        """
        ttl_seconds = int(settings.mcpgateway_session_affinity_ttl)
        outcome = await redis.eval(claim_or_refresh, 1, owner_key, WORKER_ID, ttl_seconds)
        logger.debug(f"Owner registration outcome={outcome} for session {mcp_session_id[:8]}...")
    except Exception as e:
        # Redis failure is non-fatal - single worker mode still works
        logger.debug(f"Failed to register pool session owner in Redis: {e}")

1569 

1570 async def _get_pool_session_owner(self, mcp_session_id: str) -> Optional[str]: 

1571 """Get the worker ID that owns a pool session. 

1572 

1573 Args: 

1574 mcp_session_id: The MCP session ID from x-mcp-session-id header. 

1575 

1576 Returns: 

1577 The worker ID that owns this session, or None if not found or Redis unavailable. 

1578 """ 

1579 if not settings.mcpgateway_session_affinity_enabled: 

1580 return None 

1581 

1582 if not self.is_valid_mcp_session_id(mcp_session_id): 

1583 return None 

1584 

1585 try: 

1586 # First-Party 

1587 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel 

1588 

1589 redis = await get_redis_client() 

1590 if redis: 

1591 key = self._pool_owner_key(mcp_session_id) 

1592 owner = await redis.get(key) 

1593 if owner: 

1594 decoded = owner.decode() if isinstance(owner, bytes) else owner 

1595 return decoded 

1596 except Exception as e: 

1597 logger.debug(f"Failed to get pool session owner from Redis: {e}") 

1598 return None 

1599 

1600 async def forward_request_to_owner( 

1601 self, 

1602 mcp_session_id: str, 

1603 request_data: Dict[str, Any], 

1604 timeout: Optional[float] = None, 

1605 ) -> Optional[Dict[str, Any]]: 

1606 """Forward RPC request to the worker that owns the pool session. 

1607 

1608 This method checks Redis to find which worker owns the pool session for 

1609 the given mcp_session_id. If owned by another worker, it forwards the 

1610 request via Redis pub/sub and waits for the response. 

1611 

1612 Args: 

1613 mcp_session_id: The MCP session ID from x-mcp-session-id header. 

1614 request_data: The RPC request data to forward. 

1615 timeout: Optional timeout in seconds (default from config). 

1616 

1617 Returns: 

1618 The response from the owner worker, or None if we own the session 

1619 (caller should execute locally) or if Redis is unavailable. 

1620 

1621 Raises: 

1622 asyncio.TimeoutError: If the forwarded request times out. 

1623 """ 

1624 if not settings.mcpgateway_session_affinity_enabled: 

1625 return None 

1626 

1627 if not self.is_valid_mcp_session_id(mcp_session_id): 

1628 return None 

1629 

1630 effective_timeout = timeout if timeout is not None else settings.mcpgateway_pool_rpc_forward_timeout 

1631 

1632 try: 

1633 # First-Party 

1634 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel 

1635 

1636 redis = await get_redis_client() 

1637 if not redis: 

1638 return None # Execute locally - no Redis 

1639 

1640 # Check who owns this session 

1641 owner = await redis.get(self._pool_owner_key(mcp_session_id)) 

1642 method = request_data.get("method", "unknown") 

1643 if not owner: 

1644 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | No owner → execute locally (new session)") 

1645 return None # No owner registered - execute locally (new session) 

1646 

1647 owner_id = owner.decode() if isinstance(owner, bytes) else owner 

1648 if owner_id == WORKER_ID: 

1649 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | We own it → execute locally") 

1650 return None # We own it - execute locally 

1651 

1652 if not await self._is_worker_alive(owner_id): 

1653 logger.warning(f"[AFFINITY] Owner {owner_id} is dead for session {mcp_session_id[:8]}...") 

1654 # CAS: reclaim only if still owned by the dead worker 

1655 cas_script = """ 

1656 local cur = redis.call('GET', KEYS[1]) 

1657 if cur == ARGV[1] then 

1658 redis.call('SET', KEYS[1], ARGV[2], 'EX', ARGV[3]) 

1659 return 1 

1660 end 

1661 return 0 

1662 """ 

1663 ttl = int(settings.mcpgateway_session_affinity_ttl) 

1664 reclaimed = await redis.eval(cas_script, 1, self._pool_owner_key(mcp_session_id), owner_id, WORKER_ID, ttl) 

1665 if reclaimed == 1: 

1666 logger.info(f"[AFFINITY] Reclaimed session {mcp_session_id[:8]}... from dead worker {owner_id} → execute locally") 

1667 return None # We won the reclaim - execute locally 

1668 # Another worker already reclaimed; re-read the new owner and forward 

1669 new_owner = await redis.get(self._pool_owner_key(mcp_session_id)) 

1670 if not new_owner: 

1671 return None # Key vanished - execute locally 

1672 owner_id = new_owner.decode() if isinstance(new_owner, bytes) else new_owner 

1673 if owner_id == WORKER_ID: 

1674 return None # We ended up as owner 

1675 logger.info(f"[AFFINITY] Session {mcp_session_id[:8]}... reclaimed by {owner_id} → forwarding to new owner") 

1676 

1677 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | Owner: {owner_id} → forwarding") 

1678 

1679 # Forward to owner worker via pub/sub 

1680 response_id = str(uuid.uuid4()) 

1681 response_channel = f"mcpgw:pool_rpc_response:{response_id}" 

1682 

1683 # Subscribe to response channel 

1684 pubsub = redis.pubsub() 

1685 await pubsub.subscribe(response_channel) 

1686 

1687 try: 

1688 # Prepare request with response channel 

1689 forward_data = { 

1690 "type": "rpc_forward", 

1691 **request_data, 

1692 "response_channel": response_channel, 

1693 "mcp_session_id": mcp_session_id, 

1694 } 

1695 

1696 # Publish request to owner's channel 

1697 await redis.publish(f"mcpgw:pool_rpc:{owner_id}", orjson.dumps(forward_data)) 

1698 self._forwarded_requests += 1 

1699 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | Published to worker {owner_id}") 

1700 

1701 # Wait for response 

1702 async with asyncio.timeout(effective_timeout): 

1703 async for msg in pubsub.listen(): 

1704 if msg["type"] == "message": 

1705 return orjson.loads(msg["data"]) 

1706 finally: 

1707 await pubsub.unsubscribe(response_channel) 

1708 

1709 except asyncio.TimeoutError: 

1710 self._forwarded_request_timeouts += 1 

1711 logger.warning(f"Timeout forwarding request to owner for session {mcp_session_id[:8]}...") 

1712 raise 

1713 except Exception as e: 

1714 self._forwarded_request_failures += 1 

1715 logger.debug(f"Error forwarding request to owner: {e}") 

1716 return None # Execute locally on error 

1717 

1718 async def start_rpc_listener(self) -> None: 

1719 """Start listening for forwarded RPC and HTTP requests on this worker's channels. 

1720 

1721 This method subscribes to Redis pub/sub channels specific to this worker 

1722 and processes incoming forwarded requests from other workers: 

1723 - mcpgw:pool_rpc:{WORKER_ID} - for SSE transport JSON-RPC forwards 

1724 - mcpgw:pool_http:{WORKER_ID} - for Streamable HTTP request forwards 

1725 """ 

1726 if not settings.mcpgateway_session_affinity_enabled: 

1727 return 

1728 

1729 try: 

1730 # First-Party 

1731 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel 

1732 

1733 redis = await get_redis_client() 

1734 if not redis: 

1735 logger.debug("Redis not available, RPC listener not started") 

1736 return 

1737 

1738 rpc_channel = f"mcpgw:pool_rpc:{WORKER_ID}" 

1739 http_channel = f"mcpgw:pool_http:{WORKER_ID}" 

1740 pubsub = redis.pubsub() 

1741 await pubsub.subscribe(rpc_channel, http_channel) 

1742 logger.info(f"RPC/HTTP listener started for worker {WORKER_ID} on channels: {rpc_channel}, {http_channel}") 

1743 

1744 try: 

1745 while not self._closed: 

1746 try: 

1747 msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0) 

1748 if msg and msg["type"] == "message": 

1749 request = orjson.loads(msg["data"]) 

1750 forward_type = request.get("type") 

1751 response_channel = request.get("response_channel") 

1752 

1753 if response_channel: 

1754 if forward_type == "rpc_forward": 

1755 # Execute forwarded RPC request for SSE transport 

1756 response = await self._execute_forwarded_request(request) 

1757 await redis.publish(response_channel, orjson.dumps(response)) 

1758 logger.debug(f"Processed forwarded RPC request, response sent to {response_channel}") 

1759 elif forward_type == "http_forward": 

1760 # Execute forwarded HTTP request for Streamable HTTP transport 

1761 await self._execute_forwarded_http_request(request, redis) 

1762 else: 

1763 logger.warning(f"Unknown forward type: {forward_type}") 

1764 except Exception as e: 

1765 logger.warning(f"Error processing forwarded request: {e}") 

1766 finally: 

1767 await pubsub.unsubscribe(rpc_channel, http_channel) 

1768 logger.info(f"RPC/HTTP listener stopped for worker {WORKER_ID}") 

1769 

1770 except Exception as e: 

1771 logger.warning(f"RPC/HTTP listener failed: {e}") 

1772 

1773 async def _execute_forwarded_request(self, request: Dict[str, Any]) -> Dict[str, Any]: 

1774 """Execute a forwarded RPC request locally via internal HTTP call. 

1775 

1776 This method handles RPC requests that were forwarded from another worker. 

1777 Instead of handling specific methods here, we make an internal HTTP call 

1778 to the local /rpc endpoint which reuses ALL existing method handling logic. 

1779 

1780 The x-forwarded-internally header prevents infinite forwarding loops. 

1781 

1782 Args: 

1783 request: The forwarded RPC request containing method, params, headers, req_id, etc. 

1784 

1785 Returns: 

1786 The JSON-RPC response from the local endpoint. 

1787 """ 

1788 try: 

1789 method = request.get("method") 

1790 params = request.get("params", {}) 

1791 headers = request.get("headers", {}) 

1792 req_id = request.get("req_id", 1) 

1793 mcp_session_id = request.get("mcp_session_id", "unknown") 

1794 session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id 

1795 

1796 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Received forwarded request, executing locally") 

1797 

1798 # Make internal HTTP/HTTPS call to local /rpc endpoint. 

1799 # This reuses ALL existing method handling logic without duplication. 

1800 internal_base_url = internal_loopback_base_url() 

1801 async with httpx.AsyncClient(verify=internal_loopback_verify()) as client: 

1802 # Build headers for internal request - forward original headers 

1803 # but add x-forwarded-internally to prevent infinite loops. 

1804 # Relies on the originating transport having already filtered 

1805 # passthrough headers via extract_headers_for_loopback (#3640). 

1806 internal_headers = dict(headers) 

1807 internal_headers["x-forwarded-internally"] = "true" 

1808 # Ensure content-type is set 

1809 internal_headers["content-type"] = "application/json" 

1810 

1811 response = await client.post( 

1812 f"{internal_base_url}/rpc", 

1813 json={"jsonrpc": "2.0", "method": method, "params": params, "id": req_id}, 

1814 headers=internal_headers, 

1815 timeout=settings.mcpgateway_pool_rpc_forward_timeout, 

1816 ) 

1817 

1818 # Gate on HTTP status first: non-2xx responses are errors 

1819 # even if the body parses as JSON. 

1820 if not response.is_success: 

1821 try: 

1822 response_data = response.json() 

1823 except ValueError: 

1824 response_data = {} 

1825 if not isinstance(response_data, dict): 

1826 response_data = {} 

1827 

1828 # If body is a JSON-RPC error ({"error": {...}}), propagate it 

1829 if "error" in response_data and isinstance(response_data["error"], dict): 

1830 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded execution completed with error (HTTP {response.status_code})") 

1831 return {"error": response_data["error"]} 

1832 

1833 # Non-JSON-RPC error body (e.g. {"detail": "..."}): map to JSON-RPC error 

1834 detail = response_data.get("detail", response.text[:200] or "Unknown error") 

1835 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded execution failed with HTTP {response.status_code}") 

1836 return {"error": {"code": -32603, "message": f"Forwarded request failed (HTTP {response.status_code}): {detail}"}} 

1837 

1838 # Parse successful response 

1839 response_data = response.json() 

1840 

1841 # Extract result or error from JSON-RPC response 

1842 if "error" in response_data: 

1843 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded execution completed with error") 

1844 return {"error": response_data["error"]} 

1845 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded execution completed successfully") 

1846 return {"result": response_data.get("result", {})} 

1847 

1848 except httpx.TimeoutException: 

1849 logger.warning(f"Timeout executing forwarded request: {request.get('method')}") 

1850 return {"error": {"code": -32603, "message": "Internal request timeout"}} 

1851 except Exception as e: 

1852 logger.warning(f"Error executing forwarded request: {e}") 

1853 return {"error": {"code": -32603, "message": str(e)}} 

1854 

1855 async def _execute_forwarded_http_request(self, request: Dict[str, Any], redis: Any) -> None: 

1856 """Execute a forwarded HTTP request locally and return response via Redis. 

1857 

1858 This method handles full HTTP requests forwarded from other workers for 

1859 Streamable HTTP transport session affinity. It reconstructs the HTTP request, 

1860 makes an internal call to the appropriate endpoint, and publishes the response 

1861 back through Redis. 

1862 

1863 Args: 

1864 request: Serialized HTTP request data from Redis Pub/Sub containing: 

1865 - type: "http_forward" 

1866 - response_channel: Redis channel to publish response to 

1867 - mcp_session_id: Session identifier 

1868 - method: HTTP method (GET, POST, DELETE) 

1869 - path: Request path (e.g., /mcp) 

1870 - query_string: Query parameters 

1871 - headers: Request headers dict 

1872 - body: Hex-encoded request body 

1873 redis: Redis client for publishing response 

1874 """ 

1875 response_channel = None 

1876 try: 

1877 response_channel = request.get("response_channel") 

1878 method = request.get("method") 

1879 path = request.get("path") 

1880 query_string = request.get("query_string", "") 

1881 headers = request.get("headers", {}) 

1882 body_hex = request.get("body", "") 

1883 mcp_session_id = request.get("mcp_session_id") 

1884 

1885 # Decode hex body back to bytes 

1886 body = bytes.fromhex(body_hex) if body_hex else b"" 

1887 

1888 session_short = mcp_session_id[:8] if mcp_session_id and len(mcp_session_id) >= 8 else "unknown" 

1889 logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Received forwarded HTTP request: {method} {path}") 

1890 

1891 # Add internal forwarding headers to prevent loops. 

1892 # Relies on the originating transport having already filtered 

1893 # passthrough headers via extract_headers_for_loopback (#3640). 

1894 internal_headers = dict(headers) 

1895 internal_headers["x-forwarded-internally"] = "true" 

1896 internal_headers["x-original-worker"] = request.get("original_worker", "unknown") 

1897 

1898 # Make internal HTTP/HTTPS request to local endpoint 

1899 url = f"{internal_loopback_base_url()}{path}" 

1900 if query_string: 

1901 url = f"{url}?{query_string}" 

1902 

1903 async with httpx.AsyncClient(verify=internal_loopback_verify()) as client: 

1904 response = await client.request( 

1905 method=method, 

1906 url=url, 

1907 headers=internal_headers, 

1908 content=body, 

1909 timeout=settings.mcpgateway_pool_rpc_forward_timeout, 

1910 ) 

1911 

1912 logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Executed locally: {response.status_code}") 

1913 

1914 # Serialize response for Redis transport 

1915 response_data = { 

1916 "status": response.status_code, 

1917 "headers": dict(response.headers), 

1918 "body": response.content.hex(), # Hex encode binary response 

1919 } 

1920 

1921 # Publish response back to requesting worker 

1922 if redis and response_channel: 

1923 await redis.publish(response_channel, orjson.dumps(response_data)) 

1924 logger.debug(f"[HTTP_AFFINITY] Published HTTP response to Redis channel: {response_channel}") 

1925 

1926 except Exception as e: 

1927 logger.error(f"Error executing forwarded HTTP request: {e}") 

1928 # Try to send error response if possible 

1929 if redis and response_channel: 

1930 error_response = { 

1931 "status": 500, 

1932 "headers": {"content-type": "application/json"}, 

1933 "body": orjson.dumps({"error": "Internal forwarding error"}).hex(), 

1934 } 

1935 try: 

1936 await redis.publish(response_channel, orjson.dumps(error_response)) 

1937 except Exception as publish_error: 

1938 logger.debug(f"Failed to publish error response via Redis: {publish_error}") 

1939 

1940 async def get_streamable_http_session_owner(self, mcp_session_id: str) -> Optional[str]: 

1941 """Get the worker ID that owns a Streamable HTTP session. 

1942 

1943 This is a public wrapper around _get_pool_session_owner for use by 

1944 streamablehttp_transport to check session ownership before handling requests. 

1945 

1946 Args: 

1947 mcp_session_id: The MCP session ID from mcp-session-id header. 

1948 

1949 Returns: 

1950 Worker ID if found, None otherwise. 

1951 """ 

1952 return await self._get_pool_session_owner(mcp_session_id) 

1953 

1954 async def forward_streamable_http_to_owner( 

1955 self, 

1956 owner_worker_id: str, 

1957 mcp_session_id: str, 

1958 method: str, 

1959 path: str, 

1960 headers: Dict[str, str], 

1961 body: bytes, 

1962 query_string: str = "", 

1963 ) -> Optional[Dict[str, Any]]: 

1964 """Forward a Streamable HTTP request to the worker that owns the session via Redis Pub/Sub. 

1965 

1966 This method forwards the entire HTTP request to another worker using Redis 

1967 Pub/Sub channels, similar to forward_request_to_owner() for SSE transport. 

1968 This ensures session affinity works correctly in single-host multi-worker 

1969 deployments where hostname-based routing fails. 

1970 

1971 Args: 

1972 owner_worker_id: The worker ID that owns the session. 

1973 mcp_session_id: The MCP session ID. 

1974 method: HTTP method (GET, POST, DELETE). 

1975 path: Request path (e.g., /mcp). 

1976 headers: Request headers. 

1977 body: Request body bytes. 

1978 query_string: Query string if any. 

1979 

1980 Returns: 

1981 Dict with 'status', 'headers', and 'body' from the owner worker's response, 

1982 or None if forwarding fails. 

1983 """ 

1984 if not settings.mcpgateway_session_affinity_enabled: 

1985 return None 

1986 

1987 if not self.is_valid_mcp_session_id(mcp_session_id): 

1988 return None 

1989 

1990 session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id 

1991 logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {session_short}... | {method} {path} | Forwarding to worker {owner_worker_id}") 

1992 

1993 try: 

1994 # First-Party 

1995 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel 

1996 

1997 redis = await get_redis_client() 

1998 if not redis: 

1999 logger.warning("Redis unavailable for HTTP forwarding, executing locally") 

2000 return None # Fall back to local execution 

2001 

2002 # Generate unique response channel for this request 

2003 response_uuid = uuid.uuid4().hex 

2004 response_channel = f"mcpgw:pool_http_response:{response_uuid}" 

2005 

2006 # Serialize HTTP request for Redis transport 

2007 forward_data = { 

2008 "type": "http_forward", 

2009 "response_channel": response_channel, 

2010 "mcp_session_id": mcp_session_id, 

2011 "method": method, 

2012 "path": path, 

2013 "query_string": query_string, 

2014 "headers": headers, 

2015 "body": body.hex() if body else "", # Hex encode binary body 

2016 "original_worker": WORKER_ID, 

2017 "timestamp": time.time(), 

2018 } 

2019 

2020 # Subscribe to response channel BEFORE publishing request (prevent race) 

2021 pubsub = redis.pubsub() 

2022 await pubsub.subscribe(response_channel) 

2023 

2024 try: 

2025 # Publish forwarded request to owner worker's HTTP channel 

2026 owner_channel = f"mcpgw:pool_http:{owner_worker_id}" 

2027 await redis.publish(owner_channel, orjson.dumps(forward_data)) 

2028 logger.debug(f"[HTTP_AFFINITY] Published HTTP request to Redis channel: {owner_channel}") 

2029 

2030 # Wait for response with timeout 

2031 timeout = settings.mcpgateway_pool_rpc_forward_timeout 

2032 async with asyncio.timeout(timeout): 

2033 while True: 

2034 msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1) 

2035 if msg and msg["type"] == "message": 

2036 response_data = orjson.loads(msg["data"]) 

2037 logger.debug(f"[HTTP_AFFINITY] Received HTTP response via Redis: status={response_data.get('status')}") 

2038 

2039 # Decode hex body back to bytes 

2040 body_hex = response_data.get("body", "") 

2041 response_data["body"] = bytes.fromhex(body_hex) if body_hex else b"" 

2042 

2043 self._forwarded_requests += 1 

2044 return response_data 

2045 

2046 finally: 

2047 await pubsub.unsubscribe(response_channel) 

2048 

2049 except asyncio.TimeoutError: 

2050 self._forwarded_request_timeouts += 1 

2051 logger.warning(f"Timeout forwarding HTTP request to owner {owner_worker_id}") 

2052 return None 

2053 except Exception as e: 

2054 self._forwarded_request_failures += 1 

2055 logger.warning(f"Error forwarding HTTP request via Redis: {e}") 

2056 return None 

2057 

2058 def get_metrics(self) -> Dict[str, Any]: 

2059 """ 

2060 Return pool metrics for monitoring. 

2061 

2062 Returns: 

2063 Dict with hits, misses, evictions, hit_rate, and per-pool stats. 

2064 """ 

2065 total_requests = self._hits + self._misses 

2066 total_affinity_requests = self._session_affinity_local_hits + self._session_affinity_redis_hits + self._session_affinity_misses 

2067 return { 

2068 "hits": self._hits, 

2069 "misses": self._misses, 

2070 "evictions": self._evictions, 

2071 "health_check_failures": self._health_check_failures, 

2072 "circuit_breaker_trips": self._circuit_breaker_trips, 

2073 "pool_keys_evicted": self._pool_keys_evicted, 

2074 "sessions_reaped": self._sessions_reaped, 

2075 "anonymous_identity_count": self._anonymous_identity_count, 

2076 "hit_rate": self._hits / total_requests if total_requests > 0 else 0.0, 

2077 "pool_key_count": len(self._pools), 

2078 # Session affinity metrics 

2079 "session_affinity": { 

2080 "local_hits": self._session_affinity_local_hits, 

2081 "redis_hits": self._session_affinity_redis_hits, 

2082 "misses": self._session_affinity_misses, 

2083 "hit_rate": (self._session_affinity_local_hits + self._session_affinity_redis_hits) / total_affinity_requests if total_affinity_requests > 0 else 0.0, 

2084 "forwarded_requests": self._forwarded_requests, 

2085 "forwarded_failures": self._forwarded_request_failures, 

2086 "forwarded_timeouts": self._forwarded_request_timeouts, 

2087 }, 

2088 "pools": { 

2089 f"{url}|{identity[:8]}|{transport}|{user}|{gw_id[:8] if gw_id else 'none'}": { 

2090 "available": pool.qsize(), 

2091 "active": len(self._active.get((user, url, identity, transport, gw_id), set())), 

2092 "max": self._max_sessions, 

2093 } 

2094 for (user, url, identity, transport, gw_id), pool in self._pools.items() 

2095 }, 

2096 "circuit_breakers": { 

2097 url: { 

2098 "failures": self._failures.get(url, 0), 

2099 "open_until": self._circuit_open_until.get(url), 

2100 } 

2101 for url in set(self._failures.keys()) | set(self._circuit_open_until.keys()) 

2102 }, 

2103 } 

2104 

2105 @asynccontextmanager 

2106 async def session( 

2107 self, 

2108 url: str, 

2109 headers: Optional[Dict[str, str]] = None, 

2110 transport_type: TransportType = TransportType.STREAMABLE_HTTP, 

2111 httpx_client_factory: Optional[HttpxClientFactory] = None, 

2112 timeout: Optional[float] = None, 

2113 user_identity: Optional[str] = None, 

2114 gateway_id: Optional[str] = None, 

2115 ) -> "AsyncIterator[PooledSession]": 

2116 """ 

2117 Context manager for acquiring and releasing a session. 

2118 

2119 Usage: 

2120 async with pool.session(url, headers) as pooled: 

2121 result = await pooled.session.call_tool("my_tool", {}) 

2122 

2123 Args: 

2124 url: The MCP server URL. 

2125 headers: Request headers. 

2126 transport_type: Transport type to use. 

2127 httpx_client_factory: Optional factory for httpx clients. 

2128 timeout: Optional timeout in seconds for transport connection. 

2129 user_identity: Optional user identity for strict isolation. 

2130 gateway_id: Optional gateway ID for notification handler context. 

2131 

2132 Yields: 

2133 PooledSession ready for use. 

2134 """ 

2135 pooled = await self.acquire(url, headers, transport_type, httpx_client_factory, timeout, user_identity, gateway_id) 

2136 failed = False 

2137 try: 

2138 yield pooled 

2139 except BaseException: 

2140 # Session encountered an error (e.g. ClosedResourceError) — evict it 

2141 # instead of returning a broken session to the pool. 

2142 failed = True 

2143 raise 

2144 finally: 

2145 await self.release(pooled, discard=failed) 

2146 

2147 

2148# Global pool instance - initialized by FastAPI lifespan 

2149_mcp_session_pool: Optional[MCPSessionPool] = None 

2150 

2151 

2152def get_mcp_session_pool() -> MCPSessionPool: 

2153 """Get the global MCP session pool instance. 

2154 

2155 Returns: 

2156 The global MCPSessionPool instance. 

2157 

2158 Raises: 

2159 RuntimeError: If pool has not been initialized. 

2160 """ 

2161 if _mcp_session_pool is None: 

2162 raise RuntimeError("MCP session pool not initialized. Call init_mcp_session_pool() first.") 

2163 return _mcp_session_pool 

2164 

2165 

2166def init_mcp_session_pool( 

2167 max_sessions_per_key: int = 10, 

2168 session_ttl_seconds: float = 300.0, 

2169 health_check_interval_seconds: float = 60.0, 

2170 acquire_timeout_seconds: float = 30.0, 

2171 session_create_timeout_seconds: float = 30.0, 

2172 circuit_breaker_threshold: int = 5, 

2173 circuit_breaker_reset_seconds: float = 60.0, 

2174 identity_headers: Optional[frozenset[str]] = None, 

2175 identity_extractor: Optional[IdentityExtractor] = None, 

2176 idle_pool_eviction_seconds: float = 600.0, 

2177 default_transport_timeout_seconds: float = 30.0, 

2178 health_check_methods: Optional[list[str]] = None, 

2179 health_check_timeout_seconds: float = 5.0, 

2180 message_handler_factory: Optional[MessageHandlerFactory] = None, 

2181 enable_notifications: bool = True, 

2182 notification_debounce_seconds: float = 5.0, 

2183) -> MCPSessionPool: 

2184 """Initialize the global MCP session pool. 

2185 

2186 Args: 

2187 See MCPSessionPool.__init__ for argument descriptions. 

2188 enable_notifications: Enable automatic notification service for list_changed events. 

2189 notification_debounce_seconds: Debounce interval for notification-triggered refreshes. 

2190 

2191 Returns: 

2192 The initialized MCPSessionPool instance. 

2193 """ 

2194 global _mcp_session_pool # pylint: disable=global-statement 

2195 

2196 # Auto-create notification service if enabled and no custom handler provided 

2197 effective_handler_factory = message_handler_factory 

2198 if enable_notifications and message_handler_factory is None: 

2199 # First-Party 

2200 from mcpgateway.services.notification_service import ( # pylint: disable=import-outside-toplevel 

2201 init_notification_service, 

2202 ) 

2203 

2204 # Initialize notification service (will be started during acquire with gateway context) 

2205 notification_svc = init_notification_service(debounce_seconds=notification_debounce_seconds) 

2206 

2207 # Create default handler factory that uses notification service 

2208 def default_handler_factory(url: str, gateway_id: Optional[str]): 

2209 """Create a message handler for MCP session notifications. 

2210 

2211 Args: 

2212 url: The MCP server URL for the session. 

2213 gateway_id: Optional gateway ID for attribution, falls back to URL if not provided. 

2214 

2215 Returns: 

2216 A message handler that forwards notifications to the notification service. 

2217 """ 

2218 return notification_svc.create_message_handler(gateway_id or url, url) 

2219 

2220 effective_handler_factory = default_handler_factory 

2221 logger.info("MCP notification service created (debounce=%ss)", notification_debounce_seconds) 

2222 

2223 _mcp_session_pool = MCPSessionPool( 

2224 max_sessions_per_key=max_sessions_per_key, 

2225 session_ttl_seconds=session_ttl_seconds, 

2226 health_check_interval_seconds=health_check_interval_seconds, 

2227 acquire_timeout_seconds=acquire_timeout_seconds, 

2228 session_create_timeout_seconds=session_create_timeout_seconds, 

2229 circuit_breaker_threshold=circuit_breaker_threshold, 

2230 circuit_breaker_reset_seconds=circuit_breaker_reset_seconds, 

2231 identity_headers=identity_headers, 

2232 identity_extractor=identity_extractor, 

2233 idle_pool_eviction_seconds=idle_pool_eviction_seconds, 

2234 default_transport_timeout_seconds=default_transport_timeout_seconds, 

2235 health_check_methods=health_check_methods, 

2236 health_check_timeout_seconds=health_check_timeout_seconds, 

2237 message_handler_factory=effective_handler_factory, 

2238 ) 

2239 logger.info("MCP session pool initialized") 

2240 return _mcp_session_pool 

2241 

2242 

2243async def close_mcp_session_pool() -> None: 

2244 """Close the global MCP session pool and notification service.""" 

2245 global _mcp_session_pool # pylint: disable=global-statement 

2246 if _mcp_session_pool is not None: 

2247 await _mcp_session_pool.close_all() 

2248 _mcp_session_pool = None 

2249 logger.info("MCP session pool closed") 

2250 

2251 # Close notification service if it was initialized 

2252 try: 

2253 # First-Party 

2254 from mcpgateway.services.notification_service import ( # pylint: disable=import-outside-toplevel 

2255 close_notification_service, 

2256 ) 

2257 

2258 await close_notification_service() 

2259 except (ImportError, RuntimeError): 

2260 pass # Notification service not initialized 

2261 

2262 

async def drain_mcp_session_pool() -> None:
    """Close every pooled session while keeping the global pool itself alive.

    Dropping the sessions means new ones reconnect with fresh TLS state.
    In contrast to ``close_mcp_session_pool()``, which shuts the pool down
    permanently, the pool object stays operational afterwards. Nothing
    happens when no global pool exists.
    """
    pool = _mcp_session_pool
    if pool is None:
        return
    await pool.drain_all()

2272 

2273 

async def start_pool_notification_service(gateway_service: Any = None) -> None:
    """Launch the notification service's background worker.

    Call this once ``gateway_service`` exists, so that event-driven refreshes
    can be triggered through it. A ``RuntimeError`` (treated here as "the
    notification service is not configured") is swallowed, and only a debug
    message is logged.

    Args:
        gateway_service: Optional GatewayService instance for triggering refreshes.
    """
    try:
        # First-Party
        from mcpgateway.services.notification_service import get_notification_service  # pylint: disable=import-outside-toplevel

        svc = get_notification_service()
        await svc.initialize(gateway_service)
        logger.info("MCP notification service started")
    except RuntimeError:
        logger.debug("Notification service not configured, skipping start")

2293 

2294 

def register_gateway_capabilities_for_notifications(gateway_id: str, capabilities: Dict[str, Any]) -> None:
    """Hand a gateway's server capabilities to the notification service.

    Invoke after the gateway has been initialized so that its
    ``list_changed`` notifications can be handled. A ``RuntimeError`` raised
    while reaching the notification service is treated as "not initialized",
    and the call silently does nothing.

    Args:
        gateway_id: The gateway ID.
        capabilities: Server capabilities from initialization response.
    """
    try:
        # First-Party
        from mcpgateway.services.notification_service import get_notification_service  # pylint: disable=import-outside-toplevel

        get_notification_service().register_gateway_capabilities(gateway_id, capabilities)
    except RuntimeError:
        # Notification service not initialized: nothing to register with.
        return

2314 

2315 

def unregister_gateway_from_notifications(gateway_id: str) -> None:
    """Remove a gateway from the notification service's bookkeeping.

    Intended to be called when a gateway is deleted. A ``RuntimeError``
    raised while reaching the notification service is treated as "not
    initialized", and the call silently does nothing.

    Args:
        gateway_id: The gateway ID to unregister.
    """
    try:
        # First-Party
        from mcpgateway.services.notification_service import get_notification_service  # pylint: disable=import-outside-toplevel

        get_notification_service().unregister_gateway(gateway_id)
    except RuntimeError:
        # Notification service not initialized: nothing to unregister from.
        return