# -*- coding: utf-8 -*-
"""
MCP Session Pool Implementation.

Provides session pooling for MCP ClientSessions to reduce per-request overhead.
Sessions are isolated per user/tenant via identity hashing to prevent session collision.

Performance Impact:
    - Without pooling: 20-23ms per tool call (new session each time)
    - With pooling: 1-2ms per tool call (10-20x improvement)

Security:
    Sessions are isolated by (url, identity_hash, transport_type) to prevent:
    - Cross-user session sharing
    - Cross-tenant data leakage
    - Authentication bypass

Copyright 2026
SPDX-License-Identifier: Apache-2.0
Authors: Mihai Criveti
"""

# flake8: noqa: DAR101, DAR201, DAR401

# Future
from __future__ import annotations

# Standard
import asyncio
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from enum import Enum
import hashlib
import logging
import os
import re
import socket
import time
from typing import Any, Callable, Dict, Optional, Set, Tuple, TYPE_CHECKING
import uuid

# Third-Party
import anyio
import httpx
from mcp import ClientSession, McpError
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.session import RequestResponder
import mcp.types as mcp_types
import orjson

# First-Party
from mcpgateway.config import settings
from mcpgateway.utils.url_auth import sanitize_url_for_logging

# JSON-RPC standard error code for method not found
METHOD_NOT_FOUND = -32601

# Shared session-id validation (downstream MCP session IDs used for affinity).
# Intentionally strict: protects Redis key/channel construction and log lines.
_MCP_SESSION_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,128}$")

# Worker ID for multi-worker session affinity.
# Uses hostname + PID to be unique across Docker containers (each container has PID 1)
# and across gunicorn workers within the same container.
WORKER_ID = f"{socket.gethostname()}:{os.getpid()}"
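
# Illustrative values only: a worker alone in its container (PID 1) and a sibling
# gunicorn worker in the same container would report, e.g.:
#
#     WORKER_ID == "gw-7f3a:1"    # hostname "gw-7f3a" is hypothetical
#     WORKER_ID == "gw-7f3a:42"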


def _get_cleanup_timeout() -> float:
    """Get session cleanup timeout from config (lazy import to avoid circular deps).

    This timeout controls how long to wait for session/transport __aexit__ calls
    when closing sessions. It prevents CPU spin loops when internal tasks don't
    respond to cancellation (anyio's _deliver_cancellation issue).

    Returns:
        Cleanup timeout in seconds (default: 5.0)
    """
    try:
        # Lazy import to avoid circular dependency during startup
        return settings.mcp_session_pool_cleanup_timeout
    except Exception:
        return 5.0  # Fallback default


if TYPE_CHECKING:
    # Standard
    from collections.abc import AsyncIterator  # pragma: no cover

logger = logging.getLogger(__name__)


class TransportType(Enum):
    """Supported MCP transport types."""

    SSE = "sse"
    STREAMABLE_HTTP = "streamablehttp"


@dataclass(eq=False)  # eq=False makes instances hashable by object identity
class PooledSession:
    """A pooled MCP session with metadata for lifecycle management.

    Note: eq=False is required because we store these in Sets for active session
    tracking. This makes instances hashable by their object id (identity).
    """

    session: ClientSession
    transport_context: Any  # The transport context manager (kept open)
    url: str
    transport_type: TransportType
    headers: Dict[str, str]  # Original headers (for reconnection)
    identity_key: str  # Identity hash component for headers
    user_identity: str = "anonymous"  # For user isolation
    gateway_id: str = ""  # Gateway ID for notification attribution
    created_at: float = field(default_factory=time.time)
    last_used: float = field(default_factory=time.time)
    use_count: int = 0
    _closed: bool = field(default=False, repr=False)

    @property
    def age_seconds(self) -> float:
        """Return session age in seconds.

        Returns:
            float: Session age in seconds since creation.
        """
        return time.time() - self.created_at

    @property
    def idle_seconds(self) -> float:
        """Return seconds since last use.

        Returns:
            float: Seconds since last use of this session.
        """
        return time.time() - self.last_used

    @property
    def is_closed(self) -> bool:
        """Return whether this session has been closed.

        Returns:
            bool: True if session is closed, False otherwise.
        """
        return self._closed

    def mark_closed(self) -> None:
        """Mark this session as closed."""
        self._closed = True
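
# A minimal sketch of how PooledSession metadata drives lifecycle decisions.
# The session and transport objects below are hypothetical stand-ins (real ones
# come from MCPSessionPool._create_session); only the bookkeeping matters here:
#
#     pooled = PooledSession(
#         session=client_session,            # assumed: an initialized ClientSession
#         transport_context=transport_ctx,   # assumed: its still-open transport context
#         url="https://mcp.example.com/mcp",
#         transport_type=TransportType.STREAMABLE_HTTP,
#         headers={"Authorization": "Bearer ..."},
#         identity_key="ab12cd34...",
#     )
#     pooled.age_seconds    # > session_ttl_seconds   -> evicted on validate/release
#     pooled.idle_seconds   # > health_check_interval -> health-checked on acquire
#     pooled.mark_closed()  # makes release() refuse to repool it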


# Type aliases
# Pool key includes transport type and gateway_id to prevent returning wrong transport for same URL
# and to ensure correct notification attribution when notifications are enabled
PoolKey = Tuple[str, str, str, str, str]  # (user_identity_hash, url, identity_hash, transport_type, gateway_id)

# Session affinity mapping key: (mcp_session_id, url, transport_type, gateway_id)
SessionMappingKey = Tuple[str, str, str, str]
HttpxClientFactory = Callable[
    [Optional[Dict[str, str]], Optional[httpx.Timeout], Optional[httpx.Auth]],
    httpx.AsyncClient,
]


# Type alias for identity extractor callback
# Extracts stable identity from headers (e.g., decode JWT to get user_id)
IdentityExtractor = Callable[[Dict[str, str]], Optional[str]]

# Type alias for message handler factory
# Factory that creates message handlers given URL and optional gateway_id
# The handler receives ServerNotification, ServerRequest responders, or Exceptions
MessageHandlerFactory = Callable[
    [str, Optional[str]],  # (url, gateway_id)
    Callable[
        [RequestResponder[mcp_types.ServerRequest, mcp_types.ClientResult] | mcp_types.ServerNotification | Exception],
        Any,  # Coroutine
    ],
]
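
# A minimal MessageHandlerFactory sketch that only logs upstream notifications
# (illustrative; the gateway's real handler does more). The factory signature
# matches the alias above:
#
#     def make_logging_handler(url: str, gateway_id: Optional[str]):
#         async def handle(message) -> None:
#             if isinstance(message, Exception):
#                 logger.warning(f"Session error from {url}: {message}")
#             elif isinstance(message, mcp_types.ServerNotification):
#                 logger.info(f"Notification from {url} (gateway={gateway_id})")
#             # RequestResponder (server-to-client requests) intentionally ignored here
#         return handle
#
#     pool = MCPSessionPool(message_handler_factory=make_logging_handler)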


class MCPSessionPool:  # pylint: disable=too-many-instance-attributes
    """
    Pool of MCP ClientSessions keyed by (user_identity, server URL, identity hash, transport type, gateway_id).

    Thread-Safety:
        This pool is designed for asyncio concurrency. It uses asyncio.Lock
        for synchronization, which is safe for coroutine-based concurrency
        but NOT for multi-threaded access.

    Session Isolation:
        Sessions are isolated per user/tenant to prevent session collision.
        The identity hash is derived from authentication headers, ensuring
        that different users never share MCP sessions.

    Transport Isolation:
        Sessions are also isolated by transport type (SSE vs STREAMABLE_HTTP).
        The same URL with different transports will use separate pools.

    Gateway Isolation:
        Sessions are isolated by gateway_id for correct notification attribution.
        When notifications are enabled, each gateway gets its own pooled sessions
        even if they share the same URL and authentication.

    Features:
        - Session reuse across requests (10-20x latency improvement)
        - Per-user/tenant session isolation (prevents session collision)
        - Per-transport session isolation (prevents transport mismatch)
        - TTL-based expiration with configurable lifetime
        - Health checks on acquire for stale sessions
        - Configurable pool size per URL+identity+transport
        - Circuit breaker for failing endpoints
        - Idle pool key eviction to prevent unbounded growth
        - Custom identity extractor for rotating tokens (e.g., JWT decode)
        - Metrics for monitoring (hits, misses, evictions)
        - Graceful shutdown with close_all()

    Usage:
        pool = MCPSessionPool()

        # Use as context manager for lifecycle management
        async with pool:
            pooled = await pool.acquire(url, headers)
            try:
                result = await pooled.session.call_tool("my_tool", {})
            finally:
                await pool.release(pooled)

        # With custom identity extractor for JWT tokens:
        def extract_user_id(headers: dict) -> str:
            token = headers.get("Authorization", "").replace("Bearer ", "")
            claims = jwt.decode(token, options={"verify_signature": False})
            return claims.get("sub") or claims.get("user_id")

        pool = MCPSessionPool(identity_extractor=extract_user_id)
    """

    # Headers that contribute to session identity (case-insensitive)
    DEFAULT_IDENTITY_HEADERS: frozenset[str] = frozenset(
        [
            "authorization",
            "x-tenant-id",
            "x-user-id",
            "x-api-key",
            "cookie",
            "x-mcp-session-id",
        ]
    )

    def __init__(
        self,
        max_sessions_per_key: int = 10,
        session_ttl_seconds: float = 300.0,
        health_check_interval_seconds: float = 60.0,
        acquire_timeout_seconds: float = 30.0,
        session_create_timeout_seconds: float = 30.0,
        circuit_breaker_threshold: int = 5,
        circuit_breaker_reset_seconds: float = 60.0,
        identity_headers: Optional[frozenset[str]] = None,
        identity_extractor: Optional[IdentityExtractor] = None,
        idle_pool_eviction_seconds: float = 600.0,
        default_transport_timeout_seconds: float = 30.0,
        health_check_methods: Optional[list[str]] = None,
        health_check_timeout_seconds: float = 5.0,
        message_handler_factory: Optional[MessageHandlerFactory] = None,
    ):
        """
        Initialize the session pool.

        Args:
            max_sessions_per_key: Maximum pooled sessions per (URL, identity, transport).
            session_ttl_seconds: Session TTL in seconds before forced expiration.
            health_check_interval_seconds: Seconds of idle time before health check.
            acquire_timeout_seconds: Timeout for waiting when pool is exhausted.
            session_create_timeout_seconds: Timeout for creating new sessions.
            circuit_breaker_threshold: Consecutive failures before circuit opens.
            circuit_breaker_reset_seconds: Seconds before circuit breaker resets.
            identity_headers: Headers that contribute to identity hash.
            identity_extractor: Optional callback to extract stable identity from headers.
                Use this when tokens rotate frequently (e.g., short-lived JWTs).
                Should return a stable user/tenant ID string.
            idle_pool_eviction_seconds: Evict empty pool keys after this many seconds of no use.
            default_transport_timeout_seconds: Default timeout for transport connections.
            health_check_methods: Ordered list of health check methods to try.
                Options: ping, list_tools, list_prompts, list_resources, skip.
                Default: ["ping", "skip"] (try ping, skip if unsupported).
            health_check_timeout_seconds: Timeout for each health check attempt.
            message_handler_factory: Optional factory for creating message handlers.
                Called with (url, gateway_id) to create handlers for
                each new session. Enables notification handling.
        """
        # Configuration
        self._max_sessions = max_sessions_per_key
        self._session_ttl = session_ttl_seconds
        self._health_check_interval = health_check_interval_seconds
        self._acquire_timeout = acquire_timeout_seconds
        self._session_create_timeout = session_create_timeout_seconds
        self._circuit_breaker_threshold = circuit_breaker_threshold
        self._circuit_breaker_reset = circuit_breaker_reset_seconds
        self._identity_headers = identity_headers or self.DEFAULT_IDENTITY_HEADERS
        self._identity_extractor = identity_extractor
        self._idle_pool_eviction = idle_pool_eviction_seconds
        self._default_transport_timeout = default_transport_timeout_seconds
        self._health_check_methods = health_check_methods or ["ping", "skip"]
        self._health_check_timeout = health_check_timeout_seconds
        self._message_handler_factory = message_handler_factory

        # State - protected by _global_lock for creation, per-key locks for access
        self._global_lock = asyncio.Lock()
        self._pools: Dict[PoolKey, asyncio.Queue[PooledSession]] = {}
        self._active: Dict[PoolKey, Set[PooledSession]] = {}
        self._locks: Dict[PoolKey, asyncio.Lock] = {}
        self._semaphores: Dict[PoolKey, asyncio.Semaphore] = {}
        self._pool_last_used: Dict[PoolKey, float] = {}  # Track last use time per pool key

        # Circuit breaker state
        self._failures: Dict[str, int] = {}  # url -> consecutive failure count
        self._circuit_open_until: Dict[str, float] = {}  # url -> timestamp

        # Eviction throttling - only run eviction once per interval
        self._last_eviction_run: float = 0.0
        self._eviction_run_interval: float = 60.0  # Run eviction at most every 60 seconds

        # Metrics
        self._hits = 0
        self._misses = 0
        self._evictions = 0
        self._health_check_failures = 0
        self._circuit_breaker_trips = 0
        self._pool_keys_evicted = 0
        self._sessions_reaped = 0  # Sessions closed during background eviction
        self._anonymous_identity_count = 0  # Count of requests with no identity headers

        # Lifecycle
        self._closed = False

        # Pre-registered session mappings for session affinity
        # Mapping from (mcp_session_id, url, transport_type, gateway_id) -> pool_key
        # Set by broadcast() before acquire() is called to enable session affinity lookup
        self._mcp_session_mapping: Dict[SessionMappingKey, PoolKey] = {}
        self._mcp_session_mapping_lock = asyncio.Lock()

        # Multi-worker session affinity via Redis pub/sub
        # Track pending responses for forwarded RPC requests
        self._rpc_listener_task: Optional[asyncio.Task[None]] = None
        self._pending_responses: Dict[str, asyncio.Future[Dict[str, Any]]] = {}

        # Session affinity metrics
        self._session_affinity_local_hits = 0
        self._session_affinity_redis_hits = 0
        self._session_affinity_misses = 0
        self._forwarded_requests = 0
        self._forwarded_request_failures = 0
        self._forwarded_request_timeouts = 0

    async def __aenter__(self) -> "MCPSessionPool":
        """Async context manager entry.

        Returns:
            MCPSessionPool: This pool instance.
        """
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Async context manager exit - closes all sessions.

        Args:
            exc_type: Exception type if an exception was raised.
            exc_val: Exception value if an exception was raised.
            exc_tb: Exception traceback if an exception was raised.
        """
        await self.close_all()

    def _compute_identity_hash(self, headers: Optional[Dict[str, str]]) -> str:
        """
        Compute a hash of identity-relevant headers.

        This ensures sessions are isolated per user/tenant. Different users
        with different Authorization headers will have different identity hashes
        and thus separate session pools.

        Identity resolution order:
            1. Custom identity_extractor (if configured) - for rotating tokens like JWTs
            2. x-mcp-session-id header (if present) - for session affinity; ensures
               requests with the same downstream session ID get the same upstream
               session even when JWT tokens rotate (different jti values)
            3. Configured identity headers - fallback to hashing all identity headers

        Args:
            headers: Request headers dict.

        Returns:
            Identity hash string, or "anonymous" if no identity headers present.
        """
        if not headers:
            self._anonymous_identity_count += 1
            logger.debug("Session pool identity collapsed to 'anonymous' (no headers provided). Sessions will be shared. Ensure this is intentional for stateless MCP servers.")
            return "anonymous"

        # Try custom identity extractor first (for rotating tokens like JWTs)
        if self._identity_extractor:
            try:
                extracted = self._identity_extractor(headers)
                if extracted:
                    return hashlib.sha256(extracted.encode()).hexdigest()
            except Exception as e:
                logger.debug(f"Identity extractor failed, falling back to header hash: {e}")

        # Normalize headers for case-insensitive lookup
        headers_lower = {k.lower(): v for k, v in headers.items()}

        # Session affinity: prioritize x-mcp-session-id for stable identity.
        # When present, use ONLY the session ID for the identity hash. This ensures
        # requests with the same downstream session ID get the same upstream session,
        # even when JWT tokens rotate (different jti values per request).
        if settings.mcpgateway_session_affinity_enabled:
            session_id = headers_lower.get("x-mcp-session-id")
            if session_id:
                logger.debug(f"Using x-mcp-session-id for session affinity: {session_id[:8]}...")
                return hashlib.sha256(session_id.encode()).hexdigest()

        # Fallback: extract identity from configured headers
        identity_parts = []

        for header in sorted(self._identity_headers):
            if header in headers_lower:
                identity_parts.append(f"{header}:{headers_lower[header]}")

        if not identity_parts:
            self._anonymous_identity_count += 1
            logger.debug(
                "Session pool identity collapsed to 'anonymous' (no identity headers found). Expected headers: %s. Sessions will be shared.",
                list(self._identity_headers),
            )
            return "anonymous"

        # Create a stable, deterministic hash using JSON serialization.
        # Prevents delimiter-collision or injection issues present in string joining.
        serialized_identity = orjson.dumps(identity_parts)
        return hashlib.sha256(serialized_identity).hexdigest()
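
    # Worked example of the resolution order above (header values are made up):
    #
    #     pool = MCPSessionPool()
    #     a = pool._compute_identity_hash({"Authorization": "Bearer token-alice"})
    #     b = pool._compute_identity_hash({"Authorization": "Bearer token-bob"})
    #     assert a != b  # different users -> different session pools
    #
    #     # With affinity enabled, x-mcp-session-id alone determines the hash, so
    #     # two requests with rotated JWTs but one session ID hash identically:
    #     h1 = pool._compute_identity_hash({"authorization": "Bearer jwt-1", "x-mcp-session-id": "sess-1"})
    #     h2 = pool._compute_identity_hash({"authorization": "Bearer jwt-2", "x-mcp-session-id": "sess-1"})
    #     assert h1 == h2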


    def _make_pool_key(
        self,
        url: str,
        headers: Optional[Dict[str, str]],
        transport_type: TransportType,
        user_identity: str,
        gateway_id: Optional[str] = None,
    ) -> PoolKey:
        """Create composite pool key from URL, identity, transport type, user identity, and gateway_id.

        Including gateway_id ensures correct notification attribution when multiple gateways
        share the same URL/auth. Sessions are isolated per gateway for proper event routing.
        """
        identity_hash = self._compute_identity_hash(headers)

        # Anonymize user identity by hashing it (unless it is literally "anonymous").
        # Use the full hash for collision resistance - truncate only for display in logs/metrics.
        if user_identity == "anonymous":
            user_hash = "anonymous"
        else:
            user_hash = hashlib.sha256(user_identity.encode()).hexdigest()

        # Use empty string for None gateway_id to maintain consistent key type
        gw_id = gateway_id or ""

        return (user_hash, url, identity_hash, transport_type.value, gw_id)

    async def _get_or_create_lock(self, pool_key: PoolKey) -> asyncio.Lock:
        """Get or create a lock for the given pool key (thread-safe)."""
        async with self._global_lock:
            if pool_key not in self._locks:
                self._locks[pool_key] = asyncio.Lock()
            return self._locks[pool_key]

    async def _get_or_create_pool(self, pool_key: PoolKey) -> asyncio.Queue[PooledSession]:
        """Get or create a pool queue for the given key (thread-safe)."""
        async with self._global_lock:
            if pool_key not in self._pools:
                self._pools[pool_key] = asyncio.Queue(maxsize=self._max_sessions)
                self._active[pool_key] = set()
                self._semaphores[pool_key] = asyncio.Semaphore(self._max_sessions)
            return self._pools[pool_key]

    def _is_circuit_open(self, url: str) -> bool:
        """Check if the circuit breaker is open for a URL."""
        if url not in self._circuit_open_until:
            return False
        if time.time() >= self._circuit_open_until[url]:
            # Circuit breaker reset
            del self._circuit_open_until[url]
            self._failures[url] = 0
            logger.info(f"Circuit breaker reset for {sanitize_url_for_logging(url)}")
            return False
        return True

    def _record_failure(self, url: str) -> None:
        """Record a failure and potentially trip the circuit breaker."""
        self._failures[url] = self._failures.get(url, 0) + 1
        if self._failures[url] >= self._circuit_breaker_threshold:
            self._circuit_open_until[url] = time.time() + self._circuit_breaker_reset
            self._circuit_breaker_trips += 1
            logger.warning(f"Circuit breaker opened for {sanitize_url_for_logging(url)} after {self._failures[url]} failures. Will reset in {self._circuit_breaker_reset}s")

    def _record_success(self, url: str) -> None:
        """Record a success, resetting the failure count."""
        self._failures[url] = 0
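
    # Breaker lifecycle with the defaults (threshold=5, reset=60s); illustrative:
    #
    #     url = "https://mcp.example.com/mcp"
    #     for _ in range(5):
    #         pool._record_failure(url)
    #     pool._is_circuit_open(url)   # True - acquire() now raises RuntimeError
    #     # ...at least 60 seconds later...
    #     pool._is_circuit_open(url)   # False - breaker resets, failure count cleared
    #     pool._record_success(url)    # any success also clears the count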


    @staticmethod
    def is_valid_mcp_session_id(session_id: str) -> bool:
        """Validate downstream MCP session ID format for affinity.

        Used for:
            - Redis key construction (ownership + mapping)
            - Pub/Sub channel naming
            - Avoiding log spam / injection

        Returns:
            bool: True if the session ID matches the allowed pattern.
        """
        if not session_id:
            return False
        return bool(_MCP_SESSION_ID_PATTERN.match(session_id))
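
    # Examples against _MCP_SESSION_ID_PATTERN:
    #
    #     MCPSessionPool.is_valid_mcp_session_id("a1B2-c3_d4")  # True
    #     MCPSessionPool.is_valid_mcp_session_id("")            # False (empty)
    #     MCPSessionPool.is_valid_mcp_session_id("x" * 129)     # False (over 128 chars)
    #     MCPSessionPool.is_valid_mcp_session_id("abc:def")     # False (':' would break Redis keys)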


    def _sanitize_redis_key_component(self, value: str) -> str:
        """Sanitize a value for use in Redis key construction.

        Replaces any characters that could cause key collision or injection.

        Args:
            value: The value to sanitize.

        Returns:
            Sanitized value safe for Redis key construction.
        """
        if not value:
            return ""

        # Replace problematic characters with underscores
        return re.sub(r"[^a-zA-Z0-9_-]", "_", value)

    def _session_mapping_redis_key(self, mcp_session_id: str, url: str, transport_type: str, gateway_id: str) -> str:
        """Compute a bounded Redis key for session mapping.

        The URL is hashed to keep keys small and avoid special character issues.
        """
        sanitized_session_id = self._sanitize_redis_key_component(mcp_session_id)
        url_hash = hashlib.sha256(url.encode()).hexdigest()[:16]
        return f"mcpgw:session_mapping:{sanitized_session_id}:{url_hash}:{transport_type}:{gateway_id}"
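
    # Shape of the resulting key (the 16-hex-digit URL hash below is made up):
    #
    #     pool._session_mapping_redis_key("sess-1", "https://mcp.example.com/mcp", "streamablehttp", "gw-1")
    #     # -> "mcpgw:session_mapping:sess-1:3f2a9c0d1b4e5f6a:streamablehttp:gw-1"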


    @staticmethod
    def _pool_owner_key(mcp_session_id: str) -> str:
        """Return Redis key for session ownership tracking."""
        return f"mcpgw:pool_owner:{mcp_session_id}"

    async def register_session_mapping(
        self,
        mcp_session_id: str,
        url: str,
        gateway_id: str,
        transport_type: str,
        user_email: Optional[str] = None,
    ) -> None:
        """Pre-register session mapping for session affinity.

        Called from respond() to set up the mapping BEFORE acquire() is called.
        This ensures acquire() can find the correct pool key for session affinity.

        The mapping stores the relationship between an incoming MCP session ID
        and the pool key that should be used for upstream connections. This
        enables session affinity even when JWT tokens rotate (different jti values
        per request).

        For multi-worker deployments, the mapping is also stored in Redis with a TTL
        so that any worker can look it up during acquire().

        Args:
            mcp_session_id: The downstream MCP session ID from the x-mcp-session-id header.
            url: The upstream MCP server URL.
            gateway_id: The gateway ID.
            transport_type: The transport type (sse, streamablehttp).
            user_email: The email of the authenticated user (or "system" for unauthenticated).
        """
        if not settings.mcpgateway_session_affinity_enabled:
            return

        # Validate mcp_session_id to prevent Redis key injection
        if not self.is_valid_mcp_session_id(mcp_session_id):
            logger.warning(f"Invalid mcp_session_id format, skipping session mapping: {mcp_session_id[:20]}...")
            return

        # Use user email for user_identity, or "anonymous" if not provided
        user_identity = user_email or "anonymous"

        # Normalize gateway_id to empty string if None for consistent key matching
        normalized_gateway_id = gateway_id or ""

        mapping_key: SessionMappingKey = (mcp_session_id, url, transport_type, normalized_gateway_id)

        # Compute what the pool_key will be for this session.
        # Use mcp_session_id as the identity basis for affinity.
        identity_hash = hashlib.sha256(mcp_session_id.encode()).hexdigest()

        # Hash user identity for privacy (unless it's "anonymous")
        if user_identity == "anonymous":
            user_hash = "anonymous"
        else:
            user_hash = hashlib.sha256(user_identity.encode()).hexdigest()

        pool_key: PoolKey = (user_hash, url, identity_hash, transport_type, normalized_gateway_id)

        # Store in local memory
        async with self._mcp_session_mapping_lock:
            self._mcp_session_mapping[mapping_key] = pool_key
        logger.debug(f"Session affinity pre-registered (local): {mcp_session_id[:8]}... → {url}, user={user_identity}")

        # Store in Redis for multi-worker support AND register ownership atomically.
        # Registering ownership HERE (during mapping) instead of in acquire() prevents
        # a race condition where two workers could both start creating sessions before
        # either registers ownership.
        try:
            # First-Party
            from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

            redis = await get_redis_client()
            if redis:
                redis_key = self._session_mapping_redis_key(mcp_session_id, url, transport_type, normalized_gateway_id)

                # Store pool_key as JSON for easy deserialization
                pool_key_data = {
                    "user_hash": user_hash,
                    "url": url,
                    "identity_hash": identity_hash,
                    "transport_type": transport_type,
                    "gateway_id": normalized_gateway_id,
                }
                await redis.setex(redis_key, settings.mcpgateway_session_affinity_ttl, orjson.dumps(pool_key_data))  # TTL from config

                # CRITICAL: Register ownership atomically with the mapping.
                # This claims ownership BEFORE any session creation attempt, preventing
                # the race condition where two workers both start creating sessions.
                owner_key = self._pool_owner_key(mcp_session_id)
                # Atomic claim with TTL (avoids the SETNX/EXPIRE crash window).
                was_set = await redis.set(owner_key, WORKER_ID, nx=True, ex=settings.mcpgateway_session_affinity_ttl)
                if was_set:
                    logger.debug(f"Session ownership claimed (SET NX): {mcp_session_id[:8]}... → worker {WORKER_ID}")
                else:
                    # Another worker already claimed ownership
                    existing_owner = await redis.get(owner_key)
                    owner_id = existing_owner.decode() if isinstance(existing_owner, bytes) else existing_owner
                    logger.debug(f"Session ownership already claimed by {owner_id}: {mcp_session_id[:8]}...")

                logger.debug(f"Session affinity pre-registered (Redis): {mcp_session_id[:8]}... TTL={settings.mcpgateway_session_affinity_ttl}s")
        except Exception as e:
            # Redis failure is non-fatal - local mapping still works for same-worker requests
            logger.debug(f"Failed to store session mapping in Redis: {e}")
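
    # A hedged sketch of the caller's side. respond() and its plumbing are
    # assumptions drawn from the docstring above, not verbatim gateway code:
    #
    #     async def respond(request_headers: Dict[str, str], gateway) -> None:
    #         sid = request_headers.get("x-mcp-session-id")
    #         if sid:
    #             await pool.register_session_mapping(
    #                 mcp_session_id=sid,
    #                 url=gateway.url,
    #                 gateway_id=gateway.id,
    #                 transport_type="streamablehttp",
    #                 user_email="alice@example.com",
    #             )
    #         pooled = await pool.acquire(gateway.url, headers=request_headers)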


    async def acquire(
        self,
        url: str,
        headers: Optional[Dict[str, str]] = None,
        transport_type: TransportType = TransportType.STREAMABLE_HTTP,
        httpx_client_factory: Optional[HttpxClientFactory] = None,
        timeout: Optional[float] = None,
        user_identity: Optional[str] = None,
        gateway_id: Optional[str] = None,
    ) -> PooledSession:
        """
        Acquire a session for the given URL, identity, and transport type.

        Sessions are isolated by identity (derived from auth headers) AND
        transport type. Returns an initialized, healthy session ready for tool calls.

        Args:
            url: The MCP server URL.
            headers: Request headers (used for identity hashing and passed to the server).
            transport_type: The transport type (SSE or STREAMABLE_HTTP).
            httpx_client_factory: Optional factory for creating httpx clients
                (for custom SSL/timeout configuration).
            timeout: Optional timeout in seconds for transport connection.
            user_identity: Optional stable user identity for per-user pool isolation.
            gateway_id: Optional gateway ID for notification handler context.

        Returns:
            PooledSession ready for use.

        Raises:
            asyncio.TimeoutError: If acquire times out waiting for an available session.
            RuntimeError: If the pool is closed or the circuit breaker is open.
            Exception: If session creation fails.
        """
        if self._closed:
            raise RuntimeError("Session pool is closed")

        if self._is_circuit_open(url):
            raise RuntimeError(f"Circuit breaker open for {url}")

        # Use the default timeout if not provided
        effective_timeout = timeout if timeout is not None else self._default_transport_timeout

        user_id = user_identity or "anonymous"
        pool_key: Optional[PoolKey] = None

        # Check the pre-registered mapping first (set by respond() for session affinity)
        if settings.mcpgateway_session_affinity_enabled and headers:
            headers_lower = {k.lower(): v for k, v in headers.items()}
            mcp_session_id = headers_lower.get("x-mcp-session-id")
            if mcp_session_id and self.is_valid_mcp_session_id(mcp_session_id):
                normalized_gateway_id = gateway_id or ""
                mapping_key: SessionMappingKey = (mcp_session_id, url, transport_type.value, normalized_gateway_id)

                # Check local memory first (fast path - same worker)
                async with self._mcp_session_mapping_lock:
                    pool_key = self._mcp_session_mapping.get(mapping_key)
                if pool_key:
                    self._session_affinity_local_hits += 1
                    logger.debug(f"Session affinity hit (local): {mcp_session_id[:8]}...")

                # If not in local memory, check Redis (multi-worker support)
                if pool_key is None:
                    try:
                        # First-Party
                        from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

                        redis = await get_redis_client()
                        if redis:
                            redis_key = self._session_mapping_redis_key(mcp_session_id, url, transport_type.value, normalized_gateway_id)
                            pool_key_data = await redis.get(redis_key)
                            if pool_key_data:
                                # Deserialize pool_key from JSON
                                data = orjson.loads(pool_key_data)
                                pool_key = (
                                    data["user_hash"],
                                    data["url"],
                                    data["identity_hash"],
                                    data["transport_type"],
                                    data["gateway_id"],
                                )
                                # Cache in local memory for future requests
                                async with self._mcp_session_mapping_lock:
                                    self._mcp_session_mapping[mapping_key] = pool_key
                                self._session_affinity_redis_hits += 1
                                logger.debug(f"Session affinity hit (Redis): {mcp_session_id[:8]}...")
                    except Exception as e:
                        logger.debug(f"Failed to check Redis for session mapping: {e}")

        # Fall back to normal pool key computation
        if pool_key is None:
            self._session_affinity_misses += 1
            pool_key = self._make_pool_key(url, headers, transport_type, user_id, gateway_id)

        pool = await self._get_or_create_pool(pool_key)

        # Update the pool key's last-used time IMMEDIATELY after getting the pool.
        # This prevents a race with eviction removing keys between awaits.
        self._pool_last_used[pool_key] = time.time()

        lock = await self._get_or_create_lock(pool_key)

        # Guard semaphore access - eviction may have removed it between awaits.
        # If so, re-create the pool structures.
        if pool_key not in self._semaphores:
            pool = await self._get_or_create_pool(pool_key)
            self._pool_last_used[pool_key] = time.time()

        semaphore = self._semaphores[pool_key]

        # Throttled eviction - only run if enough time has passed (inline, not spawned)
        await self._maybe_evict_idle_pool_keys()

        # Try to get from the pool first (quick path, no lock needed for queue get)
        while True:
            try:
                pooled = pool.get_nowait()
            except asyncio.QueueEmpty:
                break

            # Validate the session outside the lock
            if await self._validate_session(pooled):
                pooled.last_used = time.time()
                pooled.use_count += 1
                self._hits += 1
                async with lock:
                    self._active[pool_key].add(pooled)
                logger.debug(f"Pool hit for {sanitize_url_for_logging(url)} (identity={pool_key[2][:8]}, transport={transport_type.value})")
                return pooled

            # Session invalid, close it
            await self._close_session(pooled)
            self._evictions += 1
            semaphore.release()  # Free up a slot

        # No valid session in the pool - try to create one or wait
        try:
            # Use the semaphore with a timeout to limit concurrent sessions
            acquired = await asyncio.wait_for(semaphore.acquire(), timeout=self._acquire_timeout)
            if not acquired:
                raise asyncio.TimeoutError("Failed to acquire session slot")
        except asyncio.TimeoutError:
            raise asyncio.TimeoutError(f"Timeout waiting for available session for {sanitize_url_for_logging(url)}") from None

        # Create a new session (semaphore acquired)
        try:
            # Verify we own this session before creating (prevents a race condition).
            # If another worker already claimed ownership, we should not create a new session.
            # Note: Ownership is registered atomically in register_session_mapping() using SETNX.
            if settings.mcpgateway_session_affinity_enabled and headers:
                headers_lower = {k.lower(): v for k, v in headers.items()}
                mcp_session_id = headers_lower.get("x-mcp-session-id")
                if mcp_session_id and self.is_valid_mcp_session_id(mcp_session_id):
                    owner = await self._get_pool_session_owner(mcp_session_id)
                    if owner and owner != WORKER_ID:
                        # Another worker claimed ownership - the request should have been forwarded.
                        # Raise to trigger forwarding; the BaseException handler below releases the
                        # semaphore (releasing here as well would double-release the slot).
                        logger.warning(f"Session {mcp_session_id[:8]}... owned by worker {owner}, not us ({WORKER_ID})")
                        raise RuntimeError(f"Session owned by another worker: {owner}")

            pooled = await asyncio.wait_for(
                self._create_session(url, headers, transport_type, httpx_client_factory, effective_timeout, gateway_id),
                timeout=self._session_create_timeout,
            )
            # Store identity components for key reconstruction
            pooled.identity_key = pool_key[2]
            pooled.user_identity = user_id

            # Note: Ownership is registered atomically in register_session_mapping()
            # before acquire() is called, so we don't need to register it here.

            self._misses += 1
            self._record_success(url)
            async with lock:
                self._active[pool_key].add(pooled)
            logger.debug(f"Pool miss for {sanitize_url_for_logging(url)} - created new session (transport={transport_type.value})")
            return pooled
        except BaseException as e:
            # Release the semaphore on ANY failure (including CancelledError)
            semaphore.release()
            if not isinstance(e, asyncio.CancelledError):
                self._record_failure(url)
                logger.warning(f"Failed to create session for {sanitize_url_for_logging(url)}: {e}")
            raise
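
    # End-to-end usage, mirroring the class docstring (server URL is illustrative):
    #
    #     async def call_once() -> None:
    #         async with MCPSessionPool() as pool:
    #             pooled = await pool.acquire(
    #                 "https://mcp.example.com/mcp",
    #                 headers={"Authorization": "Bearer token-alice"},
    #                 transport_type=TransportType.STREAMABLE_HTTP,
    #             )
    #             try:
    #                 result = await pooled.session.call_tool("echo", {"text": "hi"})
    #             finally:
    #                 await pool.release(pooled)  # return the session for reuse
    #
    #     asyncio.run(call_once())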


    async def release(self, pooled: PooledSession) -> None:
        """
        Return a session to the pool for reuse.

        Args:
            pooled: The session to release.
        """
        if pooled.is_closed:
            logger.warning("Attempted to release already-closed session")
            return

        # Pool key includes transport type, user identity, and gateway_id.
        # Re-compute user hash from stored raw identity (full hash for collision resistance).
        user_hash = "anonymous"
        if pooled.user_identity != "anonymous":
            user_hash = hashlib.sha256(pooled.user_identity.encode()).hexdigest()

        pool_key = (user_hash, pooled.url, pooled.identity_key, pooled.transport_type.value, pooled.gateway_id)
        lock = await self._get_or_create_lock(pool_key)
        pool = await self._get_or_create_pool(pool_key)

        async with lock:
            # Update last-used FIRST to prevent an eviction race:
            # If eviction runs between removing from _active and putting back in the pool,
            # it would see the key as idle + inactive and evict it. By updating last-used
            # while still holding the lock and before removing from _active, we ensure
            # eviction sees recent activity.
            self._pool_last_used[pool_key] = time.time()
            self._active.get(pool_key, set()).discard(pooled)

            # Check if the session should be returned to the pool
            if self._closed or pooled.age_seconds > self._session_ttl:
                await self._close_session(pooled)
                if pool_key in self._semaphores:
                    self._semaphores[pool_key].release()
                if pooled.age_seconds > self._session_ttl:
                    self._evictions += 1
                return

            # Return to the pool (it may have been evicted in an edge case; recreate if needed)
            if pool_key not in self._pools:
                pool = await self._get_or_create_pool(pool_key)
                self._pool_last_used[pool_key] = time.time()

            try:
                pool.put_nowait(pooled)
                logger.debug(f"Session returned to pool for {sanitize_url_for_logging(pooled.url)}")
            except asyncio.QueueFull:
                # Pool full (shouldn't happen with the semaphore), close the session
                await self._close_session(pooled)
                if pool_key in self._semaphores:
                    self._semaphores[pool_key].release()

    async def _maybe_evict_idle_pool_keys(self) -> None:
        """
        Reap stale sessions and evict idle pool keys.

        This method is throttled - it only runs eviction if enough time has
        passed since the last run (default: 60 seconds). This prevents:
            - Unbounded task spawning on every acquire
            - Lock contention under high load

        Two-phase cleanup:
            1. Close expired/stale sessions parked in idle pools (frees connections)
            2. Evict pool keys that are now empty and have no active sessions

        This prevents unbounded connection and pool key growth when using
        rotating tokens (e.g., short-lived JWTs with unique identifiers).
        """
        if self._closed:
            return

        now = time.time()

        # Throttle: only run eviction once per interval
        if now - self._last_eviction_run < self._eviction_run_interval:
            return

        self._last_eviction_run = now

        # Collect sessions to close and keys to evict (minimize time holding the lock)
        sessions_to_close: list[PooledSession] = []
        keys_to_evict: list[PoolKey] = []

        async with self._global_lock:
            for pool_key, last_used in list(self._pool_last_used.items()):
                # Skip recently-used pools
                if now - last_used < self._idle_pool_eviction:
                    continue

                pool = self._pools.get(pool_key)
                active = self._active.get(pool_key, set())

                # Skip if there are active sessions (in use)
                if active:
                    continue

                if pool:
                    # Phase 1: Drain and collect expired/stale sessions from idle pools
                    while not pool.empty():
                        try:
                            session = pool.get_nowait()
                            # Close if expired OR idle too long (defense in depth)
                            if session.age_seconds > self._session_ttl or session.idle_seconds > self._idle_pool_eviction:
                                sessions_to_close.append(session)
                                # Release the semaphore slot for this session
                                if pool_key in self._semaphores:
                                    self._semaphores[pool_key].release()
                            else:
                                # Session still valid, put it back
                                pool.put_nowait(session)
                                break  # Stop draining if we find a valid session
                        except asyncio.QueueEmpty:
                            break

                    # Phase 2: Evict the pool key if now empty
                    if pool.empty():
                        keys_to_evict.append(pool_key)

            # Remove evicted keys from all tracking dicts
            for pool_key in keys_to_evict:
                self._pools.pop(pool_key, None)
                self._active.pop(pool_key, None)
                self._locks.pop(pool_key, None)
                self._semaphores.pop(pool_key, None)
                self._pool_last_used.pop(pool_key, None)
                self._pool_keys_evicted += 1
                logger.debug(f"Evicted idle pool key: {pool_key[0][:8]}|{pool_key[1]}|{pool_key[2][:8]}")

        # Close sessions outside the lock (I/O operations)
        for session in sessions_to_close:
            await self._close_session(session)
            self._sessions_reaped += 1
            logger.debug(f"Reaped stale session for {sanitize_url_for_logging(session.url)} (age={session.age_seconds:.1f}s)")

    async def _validate_session(self, pooled: PooledSession) -> bool:
        """
        Validate a session is still usable.

        Checks TTL and performs a health check if the session is stale.

        Args:
            pooled: The session to validate.

        Returns:
            True if the session is valid, False otherwise.
        """
        if pooled.is_closed:
            return False

        # Check TTL
        if pooled.age_seconds > self._session_ttl:
            logger.debug(f"Session expired (age={pooled.age_seconds:.1f}s)")
            return False

        # Health check if stale
        if pooled.idle_seconds > self._health_check_interval:
            return await self._run_health_check_chain(pooled)

        return True

    async def _run_health_check_chain(self, pooled: PooledSession) -> bool:
        """
        Run health check methods in the configured order until one succeeds.

        The health check chain allows configuring which methods to try and in what order.
        This supports both modern servers (with ping support) and legacy servers
        (that may only support list_tools, or no health check at all).

        Args:
            pooled: The session to health check.

        Returns:
            True if any health check method succeeds, False if all fail.
        """
        for method in self._health_check_methods:
            try:
                if method == "ping":
                    await asyncio.wait_for(pooled.session.send_ping(), timeout=self._health_check_timeout)
                    logger.debug(f"Health check passed: ping (url={sanitize_url_for_logging(pooled.url)})")
                    return True
                if method == "list_tools":
                    await asyncio.wait_for(pooled.session.list_tools(), timeout=self._health_check_timeout)
                    logger.debug(f"Health check passed: list_tools (url={sanitize_url_for_logging(pooled.url)})")
                    return True
                if method == "list_prompts":
                    await asyncio.wait_for(pooled.session.list_prompts(), timeout=self._health_check_timeout)
                    logger.debug(f"Health check passed: list_prompts (url={sanitize_url_for_logging(pooled.url)})")
                    return True
                if method == "list_resources":
                    await asyncio.wait_for(pooled.session.list_resources(), timeout=self._health_check_timeout)
                    logger.debug(f"Health check passed: list_resources (url={sanitize_url_for_logging(pooled.url)})")
                    return True
                if method == "skip":
                    logger.debug(f"Health check skipped per configuration (url={sanitize_url_for_logging(pooled.url)})")
                    return True
                logger.warning(f"Unknown health check method '{method}', skipping")
                continue

            except McpError as e:
                # METHOD_NOT_FOUND (-32601) means the method isn't supported - try the next one
                if e.error.code == METHOD_NOT_FOUND:
                    logger.debug(f"Health check method '{method}' not supported by server, trying next")
                    continue
                # Other MCP errors are real failures
                logger.debug(f"Health check '{method}' failed with MCP error: {e}")
                self._health_check_failures += 1
                return False

            except asyncio.TimeoutError:
                logger.debug(f"Health check '{method}' timed out after {self._health_check_timeout}s, trying next")
                continue

            except Exception as e:
                logger.debug(f"Health check '{method}' failed: {e}")
                self._health_check_failures += 1
                return False

        # All methods failed or were unsupported
        logger.warning(f"All health check methods failed or unsupported (methods={self._health_check_methods})")
        self._health_check_failures += 1
        return False
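
    # Configuring the chain for a legacy server without ping support (illustrative):
    #
    #     pool = MCPSessionPool(health_check_methods=["ping", "list_tools", "skip"])
    #     # ping is tried first; on METHOD_NOT_FOUND it falls through to list_tools;
    #     # if that is also unsupported, "skip" treats the session as healthy.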


    async def _create_session(
        self,
        url: str,
        headers: Optional[Dict[str, str]],
        transport_type: TransportType,
        httpx_client_factory: Optional[HttpxClientFactory],
        timeout: Optional[float] = None,
        gateway_id: Optional[str] = None,
    ) -> PooledSession:
        """
        Create a new initialized MCP session.

        Args:
            url: Server URL.
            headers: Request headers.
            transport_type: Transport type to use.
            httpx_client_factory: Optional factory for httpx clients.
            timeout: Optional timeout in seconds for transport connection.
            gateway_id: Optional gateway ID for notification handler context.

        Returns:
            Initialized PooledSession.

        Raises:
            RuntimeError: If session creation or initialization fails.
            asyncio.CancelledError: If cancelled during creation.
        """
        # Merge headers with defaults
        merged_headers = {"Accept": "application/json, text/event-stream"}
        if headers:
            merged_headers.update(headers)

        # Strip gateway-internal session affinity headers before sending to upstream.
        # x-mcp-session-id is our internal representation; mcp-session-id is the MCP protocol header.
        # Neither should be forwarded to upstream servers.
        keys_to_remove = [k for k in merged_headers if k.lower() in ("x-mcp-session-id", "mcp-session-id")]
        for k in keys_to_remove:
            del merged_headers[k]

        identity_key = self._compute_identity_hash(headers)
        transport_ctx = None
        session = None
        success = False

        try:
            # Create the transport context
            if transport_type == TransportType.SSE:
                if httpx_client_factory:
                    transport_ctx = sse_client(url=url, headers=merged_headers, httpx_client_factory=httpx_client_factory, timeout=timeout)
                else:
                    transport_ctx = sse_client(url=url, headers=merged_headers, timeout=timeout)
                # pylint: disable=unnecessary-dunder-call,no-member
                streams = await transport_ctx.__aenter__()  # Must call directly for manual lifecycle management
                read_stream, write_stream = streams[0], streams[1]
            else:  # STREAMABLE_HTTP
                if httpx_client_factory:
                    transport_ctx = streamablehttp_client(url=url, headers=merged_headers, httpx_client_factory=httpx_client_factory, timeout=timeout)
                else:
                    transport_ctx = streamablehttp_client(url=url, headers=merged_headers, timeout=timeout)
                # pylint: disable=unnecessary-dunder-call,no-member
                read_stream, write_stream, _ = await transport_ctx.__aenter__()  # Must call directly for manual lifecycle management

            # Create a message handler if a factory is configured
            message_handler = None
            if self._message_handler_factory:
                try:
                    message_handler = self._message_handler_factory(url, gateway_id)
                    logger.debug(f"Created message handler for session {sanitize_url_for_logging(url)} (gateway={gateway_id})")
                except Exception as e:
                    logger.warning(f"Failed to create message handler for {sanitize_url_for_logging(url)}: {e}")

            # Create and initialize the session
            session = ClientSession(read_stream, write_stream, message_handler=message_handler)
            # pylint: disable=unnecessary-dunder-call
            await session.__aenter__()  # Must call directly for manual lifecycle management
            await session.initialize()

            logger.info(f"Created new MCP session for {sanitize_url_for_logging(url)} (transport={transport_type.value})")
            success = True

            return PooledSession(
                session=session,
                transport_context=transport_ctx,
                url=url,
                transport_type=transport_type,
                headers=merged_headers,
                identity_key=identity_key,
                gateway_id=gateway_id or "",
            )

        except asyncio.CancelledError:  # pylint: disable=try-except-raise
            # Re-raise CancelledError after cleanup (handled in finally)
            raise

        except Exception as e:
            raise RuntimeError(f"Failed to create MCP session for {url}: {e}") from e

        finally:
            # Clean up on ANY failure (Exception, CancelledError, etc.), but only
            # if we didn't succeed. Use anyio.move_on_after instead of asyncio.wait_for
            # to properly propagate cancellation through anyio's cancel scope system
            # (prevents orphaned spinning tasks).
            if not success:
                cleanup_timeout = _get_cleanup_timeout()
                if session is not None:
                    with anyio.move_on_after(cleanup_timeout):
                        try:
                            await session.__aexit__(None, None, None)  # pylint: disable=unnecessary-dunder-call
                        except Exception:  # nosec B110 - Best-effort cleanup on connection failure
                            pass
                if transport_ctx is not None:
                    with anyio.move_on_after(cleanup_timeout):
                        try:
                            await transport_ctx.__aexit__(None, None, None)  # pylint: disable=unnecessary-dunder-call
                        except Exception:  # nosec B110 - Best-effort cleanup on connection failure
                            pass

1180 async def _close_session(self, pooled: PooledSession) -> None: 

1181 """ 

1182 Close a session and its transport. 

1183 

1184 Uses timeouts to prevent indefinite blocking if session/transport tasks 

1185 don't respond to cancellation. This prevents CPU spin loops in anyio's 

1186 _deliver_cancellation which can occur when async iterators or blocking 

1187 operations don't properly handle CancelledError. 

1188 

1189 Args: 

1190 pooled: The session to close. 

1191 """ 

1192 if pooled.is_closed: 

1193 return 

1194 

1195 pooled.mark_closed() 

1196 

1197 # Use anyio's move_on_after instead of asyncio.wait_for to properly propagate 

1198 # cancellation through anyio's cancel scope system. asyncio.wait_for() creates 

1199 # orphaned anyio tasks that keep spinning in _deliver_cancellation. 

1200 cleanup_timeout = _get_cleanup_timeout() 

1201 

1202 # Close session with anyio timeout 

1203 with anyio.move_on_after(cleanup_timeout) as session_scope: 

1204 try: 

1205 await pooled.session.__aexit__(None, None, None) # pylint: disable=unnecessary-dunder-call 

1206 except Exception as e: 

1207 logger.debug(f"Error closing session: {e}") 

1208 if session_scope.cancelled_caught: 

1209 logger.warning(f"Session cleanup timed out for {sanitize_url_for_logging(pooled.url)} - proceeding anyway") 

1210 

1211 # Close transport with anyio timeout 

1212 with anyio.move_on_after(cleanup_timeout) as transport_scope: 

1213 try: 

1214 await pooled.transport_context.__aexit__(None, None, None) # pylint: disable=unnecessary-dunder-call 

1215 except Exception as e: 

1216 logger.debug(f"Error closing transport: {e}") 

1217 if transport_scope.cancelled_caught: 

1218 logger.warning(f"Transport cleanup timed out for {sanitize_url_for_logging(pooled.url)} - proceeding anyway") 

1219 

1220 logger.debug(f"Closed session for {sanitize_url_for_logging(pooled.url)} (uses={pooled.use_count})") 

1221 

1222 # Clean up pool_owner key in Redis for session affinity 

1223 if settings.mcpgateway_session_affinity_enabled and pooled.headers: 

1224 headers_lower = {k.lower(): v for k, v in pooled.headers.items()} 

1225 mcp_session_id = headers_lower.get("x-mcp-session-id") 

1226 if mcp_session_id and self.is_valid_mcp_session_id(mcp_session_id): 

1227 await self._cleanup_pool_session_owner(mcp_session_id) 

1228 

1229 async def _cleanup_pool_session_owner(self, mcp_session_id: str) -> None: 

1230 """Clean up pool_owner key in Redis when session is closed. 

1231 

1232 Only deletes the key if this worker owns it (to prevent removing other workers' ownership). 

1233 

1234 Args: 

1235 mcp_session_id: The MCP session ID from x-mcp-session-id header. 

1236 """ 

1237 try: 

1238 # First-Party 

1239 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel 

1240 

1241 redis = await get_redis_client() 

1242 if redis: 

1243 key = self._pool_owner_key(mcp_session_id) 

1244 # Only delete if we own it 

1245 owner = await redis.get(key) 

1246 if owner: 

1247 owner_id = owner.decode() if isinstance(owner, bytes) else owner 

1248 if owner_id == WORKER_ID: 

1249 await redis.delete(key) 

1250 logger.debug(f"Cleaned up pool session owner: {mcp_session_id[:8]}...") 

1251 except Exception as e: 

1252 # Cleanup failure is non-fatal 

1253 logger.debug(f"Failed to cleanup pool session owner in Redis: {e}") 

1254 

1255 async def close_all(self) -> None: 

1256 """ 

1257 Gracefully close all pooled and active sessions. 

1258 

1259 Should be called during application shutdown. 

1260 """ 

1261 self._closed = True 

1262 logger.info("Closing all pooled sessions...") 

1263 

1264 async with self._global_lock: 

1265 # Close all pooled sessions 

1266 for _pool_key, pool in list(self._pools.items()): 

1267 while not pool.empty(): 

1268 try: 

1269 pooled = pool.get_nowait() 

1270 await self._close_session(pooled) 

1271 except asyncio.QueueEmpty: 

1272 break 

1273 

1274 # Close all active sessions 

1275 for _pool_key, active_set in list(self._active.items()): 

1276 for pooled in list(active_set): 

1277 await self._close_session(pooled) 

1278 

1279 self._pools.clear() 

1280 self._active.clear() 

1281 self._locks.clear() 

1282 self._semaphores.clear() 

1283 

1284 # Stop RPC listener if running 

1285 if self._rpc_listener_task and not self._rpc_listener_task.done(): 

1286 self._rpc_listener_task.cancel() 

1287 try: 

1288 await self._rpc_listener_task 

1289 except asyncio.CancelledError: 

1290 pass 

1291 self._rpc_listener_task = None 

1292 

1293 logger.info("All sessions closed") 
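
    # Typical shutdown wiring (the surrounding app lifecycle is an assumption,
    # not part of this module):
    #
    #     pool = MCPSessionPool()
    #     try:
    #         ...  # serve requests via pool.acquire() / pool.release()
    #     finally:
    #         await pool.close_all()  # drain queues, close active sessions,
    #                                 # cancel the RPC listener task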

1294 

1295 async def register_pool_session_owner(self, mcp_session_id: str) -> None: 

1296 """Register this worker as owner of a pool session in Redis. 

1297 

1298 This enables multi-worker session affinity by tracking which worker owns 

1299 which pool session. When a request with x-mcp-session-id arrives at a 

1300 different worker, it can forward the request to the owner worker. 

1301 

1302 Note: This method is now primarily used for refreshing TTL on existing ownership. 

1303 Initial ownership is claimed atomically in register_session_mapping() using SETNX. 

1304 

1305 Args: 

1306 mcp_session_id: The MCP session ID from x-mcp-session-id header. 

1307 """ 

1308 if not settings.mcpgateway_session_affinity_enabled: 

1309 return 

1310 

1311 if not self.is_valid_mcp_session_id(mcp_session_id): 

1312 logger.debug("Invalid mcp_session_id for owner registration, skipping") 

1313 return 

1314 

1315 try: 

1316 # First-Party 

1317 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel 

1318 

1319 redis = await get_redis_client() 

1320 if redis: 

1321 key = self._pool_owner_key(mcp_session_id) 

1322 

1323 # Do not steal ownership: only claim if missing, or refresh TTL if we already own. 

1324 # Lua keeps this atomic. 

1325 script = """ 

1326 local cur = redis.call('GET', KEYS[1]) 

1327 if not cur then 

1328 redis.call('SET', KEYS[1], ARGV[1], 'EX', ARGV[2]) 

1329 return 1 

1330 end 

1331 if cur == ARGV[1] then 

1332 redis.call('EXPIRE', KEYS[1], ARGV[2]) 

1333 return 2 

1334 end 

1335 return 0 

1336 """ 

1337 ttl = int(settings.mcpgateway_session_affinity_ttl) 

1338 outcome = await redis.eval(script, 1, key, WORKER_ID, ttl) 

1339 logger.debug(f"Owner registration outcome={outcome} for session {mcp_session_id[:8]}...") 

1340 except Exception as e: 

1341 # Redis failure is non-fatal - single worker mode still works 

1342 logger.debug(f"Failed to register pool session owner in Redis: {e}") 

1343 

1344 async def _get_pool_session_owner(self, mcp_session_id: str) -> Optional[str]: 

1345 """Get the worker ID that owns a pool session. 

1346 

1347 Args: 

1348 mcp_session_id: The MCP session ID from x-mcp-session-id header. 

1349 

1350 Returns: 

1351 The worker ID that owns this session, or None if not found or Redis unavailable. 

1352 """ 

1353 if not settings.mcpgateway_session_affinity_enabled: 

1354 return None 

1355 

1356 if not self.is_valid_mcp_session_id(mcp_session_id): 

1357 return None 

1358 

1359 try: 

1360 # First-Party 

1361 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel 

1362 

1363 redis = await get_redis_client() 

1364 if redis: 

1365 key = self._pool_owner_key(mcp_session_id) 

1366 owner = await redis.get(key) 

1367 if owner: 

1368 decoded = owner.decode() if isinstance(owner, bytes) else owner 

1369 return decoded 

1370 except Exception as e: 

1371 logger.debug(f"Failed to get pool session owner from Redis: {e}") 

1372 return None 

1373 

1374 async def forward_request_to_owner( 

1375 self, 

1376 mcp_session_id: str, 

1377 request_data: Dict[str, Any], 

1378 timeout: Optional[float] = None, 

1379 ) -> Optional[Dict[str, Any]]: 

1380 """Forward RPC request to the worker that owns the pool session. 

1381 

1382 This method checks Redis to find which worker owns the pool session for 

1383 the given mcp_session_id. If owned by another worker, it forwards the 

1384 request via Redis pub/sub and waits for the response. 

1385 

1386 Args: 

1387 mcp_session_id: The MCP session ID from x-mcp-session-id header. 

1388 request_data: The RPC request data to forward. 

1389 timeout: Optional timeout in seconds (default from config). 

1390 

1391 Returns: 

1392 The response from the owner worker, or None if we own the session 

1393 (caller should execute locally) or if Redis is unavailable. 

1394 

1395 Raises: 

1396 asyncio.TimeoutError: If the forwarded request times out. 

1397 """ 

1398 if not settings.mcpgateway_session_affinity_enabled: 

1399 return None 

1400 

1401 if not self.is_valid_mcp_session_id(mcp_session_id): 

1402 return None 

1403 

1404 effective_timeout = timeout if timeout is not None else settings.mcpgateway_pool_rpc_forward_timeout 

1405 

1406 try: 

1407 # First-Party 

1408 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel 

1409 

1410 redis = await get_redis_client() 

1411 if not redis: 

1412 return None # Execute locally - no Redis 

1413 

1414 # Check who owns this session 

1415 owner = await redis.get(self._pool_owner_key(mcp_session_id)) 

1416 method = request_data.get("method", "unknown") 

1417 if not owner: 

1418 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | No owner → execute locally (new session)") 

1419 return None # No owner registered - execute locally (new session) 

1420 

1421 owner_id = owner.decode() if isinstance(owner, bytes) else owner 

1422 if owner_id == WORKER_ID: 

1423 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | We own it → execute locally") 

1424 return None # We own it - execute locally 

1425 

1426 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | Owner: {owner_id} → forwarding") 

1427 

1428 # Forward to owner worker via pub/sub 

1429 response_id = str(uuid.uuid4()) 

1430 response_channel = f"mcpgw:pool_rpc_response:{response_id}" 

1431 

1432 # Subscribe to response channel 

1433 pubsub = redis.pubsub() 

1434 await pubsub.subscribe(response_channel) 

1435 

1436 try: 

1437 # Prepare request with response channel 

1438 forward_data = { 

1439 "type": "rpc_forward", 

1440 **request_data, 

1441 "response_channel": response_channel, 

1442 "mcp_session_id": mcp_session_id, 

1443 } 

1444 

1445 # Publish request to owner's channel 

1446 await redis.publish(f"mcpgw:pool_rpc:{owner_id}", orjson.dumps(forward_data)) 

1447 self._forwarded_requests += 1 

1448 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | Published to worker {owner_id}") 

1449 

1450 # Wait for response 

1451 async with asyncio.timeout(effective_timeout): 

1452 while True: 

1453 msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1) 

1454 if msg and msg["type"] == "message": 

1455 return orjson.loads(msg["data"]) 

1456 finally: 

1457 await pubsub.unsubscribe(response_channel) 

1458 

1459 except asyncio.TimeoutError: 

1460 self._forwarded_request_timeouts += 1 

1461 logger.warning(f"Timeout forwarding request to owner for session {mcp_session_id[:8]}...") 

1462 raise 

1463 except Exception as e: 

1464 self._forwarded_request_failures += 1 

1465 logger.debug(f"Error forwarding request to owner: {e}") 

1466 return None # Execute locally on error 

1467 
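# --- Illustrative caller-side sketch (assumption, not from this module) ---
# A typical /rpc handler tries forward_request_to_owner() first; a None
# return means "execute locally" (we own the session, no owner is
# registered, or Redis is unavailable). `execute_rpc_locally` is a
# hypothetical local handler, not an API of this module.

async def handle_rpc(pool, mcp_session_id: str, request_data: dict) -> dict:
    try:
        forwarded = await pool.forward_request_to_owner(mcp_session_id, request_data)
    except asyncio.TimeoutError:
        # Owner did not answer in time; surface a JSON-RPC style error.
        return {"error": {"code": -32603, "message": "Forwarding timed out"}}
    if forwarded is not None:
        return forwarded  # Response produced by the owning worker
    return await execute_rpc_locally(request_data)  # hypothetical
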

1468 async def start_rpc_listener(self) -> None: 

1469 """Start listening for forwarded RPC and HTTP requests on this worker's channels. 

1470 

1471 This method subscribes to Redis pub/sub channels specific to this worker 

1472 and processes incoming forwarded requests from other workers: 

1473 - mcpgw:pool_rpc:{WORKER_ID} - for SSE transport JSON-RPC forwards 

1474 - mcpgw:pool_http:{WORKER_ID} - for Streamable HTTP request forwards 

1475 """ 

1476 if not settings.mcpgateway_session_affinity_enabled: 

1477 return 

1478 

1479 try: 

1480 # First-Party 

1481 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel 

1482 

1483 redis = await get_redis_client() 

1484 if not redis: 

1485 logger.debug("Redis not available, RPC listener not started") 

1486 return 

1487 

1488 rpc_channel = f"mcpgw:pool_rpc:{WORKER_ID}" 

1489 http_channel = f"mcpgw:pool_http:{WORKER_ID}" 

1490 pubsub = redis.pubsub() 

1491 await pubsub.subscribe(rpc_channel, http_channel) 

1492 logger.info(f"RPC/HTTP listener started for worker {WORKER_ID} on channels: {rpc_channel}, {http_channel}") 

1493 

1494 while not self._closed: 

1495 try: 

1496 msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0) 

1497 if msg and msg["type"] == "message": 

1498 request = orjson.loads(msg["data"]) 

1499 forward_type = request.get("type") 

1500 response_channel = request.get("response_channel") 

1501 

1502 if response_channel: 

1503 if forward_type == "rpc_forward": 

1504 # Execute forwarded RPC request for SSE transport 

1505 response = await self._execute_forwarded_request(request) 

1506 await redis.publish(response_channel, orjson.dumps(response)) 

1507 logger.debug(f"Processed forwarded RPC request, response sent to {response_channel}") 

1508 elif forward_type == "http_forward": 

1509 # Execute forwarded HTTP request for Streamable HTTP transport 

1510 await self._execute_forwarded_http_request(request, redis) 

1511 else: 

1512 logger.warning(f"Unknown forward type: {forward_type}") 

1513 except asyncio.CancelledError: 

1514 break 

1515 except Exception as e: 

1516 logger.warning(f"Error processing forwarded request: {e}") 

1517 

1518 await pubsub.unsubscribe(rpc_channel, http_channel) 

1519 logger.info(f"RPC/HTTP listener stopped for worker {WORKER_ID}") 

1520 

1521 except Exception as e: 

1522 logger.warning(f"RPC/HTTP listener failed: {e}") 

1523 
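# --- Illustrative startup sketch (assumption) ---
# start_rpc_listener() loops until the pool is closed, so it is meant to
# run as a background task. One plausible wiring: store the task on the
# pool's _rpc_listener_task attribute, which close_all() above cancels.

def start_affinity_listener(pool) -> None:
    if pool._rpc_listener_task is None or pool._rpc_listener_task.done():
        pool._rpc_listener_task = asyncio.create_task(pool.start_rpc_listener())
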

1524 async def _execute_forwarded_request(self, request: Dict[str, Any]) -> Dict[str, Any]: 

1525 """Execute a forwarded RPC request locally via internal HTTP call. 

1526 

1527 This method handles RPC requests that were forwarded from another worker. 

1528 Instead of handling specific methods here, we make an internal HTTP call 

1529 to the local /rpc endpoint, which reuses ALL existing method handling logic. 

1530 

1531 The x-forwarded-internally header prevents infinite forwarding loops. 

1532 

1533 Args: 

1534 request: The forwarded RPC request containing method, params, headers, req_id, etc. 

1535 

1536 Returns: 

1537 The JSON-RPC response from the local endpoint. 

1538 """ 

1539 try: 

1540 method = request.get("method") 

1541 params = request.get("params", {}) 

1542 headers = request.get("headers", {}) 

1543 req_id = request.get("req_id", 1) 

1544 mcp_session_id = request.get("mcp_session_id", "unknown") 

1545 session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id 

1546 

1547 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Received forwarded request, executing locally") 

1548 

1549 # Make internal HTTP call to local /rpc endpoint 

1550 # This reuses ALL existing method handling logic without duplication 

1551 async with httpx.AsyncClient() as client: 

1552 # Build headers for internal request - forward original headers 

1553 # but add x-forwarded-internally to prevent infinite loops 

1554 internal_headers = dict(headers) 

1555 internal_headers["x-forwarded-internally"] = "true" 

1556 # Ensure content-type is set 

1557 internal_headers["content-type"] = "application/json" 

1558 

1559 response = await client.post( 

1560 f"http://127.0.0.1:{settings.port}/rpc", 

1561 json={"jsonrpc": "2.0", "method": method, "params": params, "id": req_id}, 

1562 headers=internal_headers, 

1563 timeout=settings.mcpgateway_pool_rpc_forward_timeout, 

1564 ) 

1565 

1566 # Parse response 

1567 response_data = response.json() 

1568 

1569 # Extract result or error from JSON-RPC response 

1570 if "error" in response_data: 

1571 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded execution completed with error") 

1572 return {"error": response_data["error"]} 

1573 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded execution completed successfully") 

1574 return {"result": response_data.get("result", {})} 

1575 

1576 except httpx.TimeoutException: 

1577 logger.warning(f"Timeout executing forwarded request: {request.get('method')}") 

1578 return {"error": {"code": -32603, "message": "Internal request timeout"}} 

1579 except Exception as e: 

1580 logger.warning(f"Error executing forwarded request: {e}") 

1581 return {"error": {"code": -32603, "message": str(e)}} 

1582 
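# --- Illustrative loop-prevention sketch (assumption) ---
# The receiving /rpc endpoint should skip affinity forwarding when it
# sees the x-forwarded-internally header that _execute_forwarded_request()
# sets; otherwise two workers could bounce a request back and forth
# forever. `headers` is assumed to be a case-insensitive mapping, as
# most ASGI frameworks provide.

def should_attempt_forward(headers: Dict[str, str]) -> bool:
    # Requests already forwarded by a peer worker must run locally.
    return headers.get("x-forwarded-internally", "").lower() != "true"
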

1583 async def _execute_forwarded_http_request(self, request: Dict[str, Any], redis: Any) -> None: 

1584 """Execute a forwarded HTTP request locally and return response via Redis. 

1585 

1586 This method handles full HTTP requests forwarded from other workers for 

1587 Streamable HTTP transport session affinity. It reconstructs the HTTP request, 

1588 makes an internal call to the appropriate endpoint, and publishes the response 

1589 back through Redis. 

1590 

1591 Args: 

1592 request: Serialized HTTP request data from Redis Pub/Sub containing: 

1593 - type: "http_forward" 

1594 - response_channel: Redis channel to publish response to 

1595 - mcp_session_id: Session identifier 

1596 - method: HTTP method (GET, POST, DELETE) 

1597 - path: Request path (e.g., /mcp) 

1598 - query_string: Query parameters 

1599 - headers: Request headers dict 

1600 - body: Hex-encoded request body 

1601 redis: Redis client for publishing response 

1602 """ 

1603 response_channel = None 

1604 try: 

1605 response_channel = request.get("response_channel") 

1606 method = request.get("method") 

1607 path = request.get("path") 

1608 query_string = request.get("query_string", "") 

1609 headers = request.get("headers", {}) 

1610 body_hex = request.get("body", "") 

1611 mcp_session_id = request.get("mcp_session_id") 

1612 

1613 # Decode hex body back to bytes 

1614 body = bytes.fromhex(body_hex) if body_hex else b"" 

1615 

1616 session_short = mcp_session_id[:8] if mcp_session_id and len(mcp_session_id) >= 8 else "unknown" 

1617 logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Received forwarded HTTP request: {method} {path}") 

1618 

1619 # Add internal forwarding headers to prevent loops 

1620 internal_headers = dict(headers) 

1621 internal_headers["x-forwarded-internally"] = "true" 

1622 internal_headers["x-original-worker"] = request.get("original_worker", "unknown") 

1623 

1624 # Make internal HTTP request to local endpoint 

1625 url = f"http://127.0.0.1:{settings.port}{path}" 

1626 if query_string: 

1627 url = f"{url}?{query_string}" 

1628 

1629 async with httpx.AsyncClient() as client: 

1630 response = await client.request( 

1631 method=method, 

1632 url=url, 

1633 headers=internal_headers, 

1634 content=body, 

1635 timeout=settings.mcpgateway_pool_rpc_forward_timeout, 

1636 ) 

1637 

1638 logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Executed locally: {response.status_code}") 

1639 

1640 # Serialize response for Redis transport 

1641 response_data = { 

1642 "status": response.status_code, 

1643 "headers": dict(response.headers), 

1644 "body": response.content.hex(), # Hex encode binary response 

1645 } 

1646 

1647 # Publish response back to requesting worker 

1648 if redis and response_channel: 

1649 await redis.publish(response_channel, orjson.dumps(response_data)) 

1650 logger.debug(f"[HTTP_AFFINITY] Published HTTP response to Redis channel: {response_channel}") 

1651 

1652 except Exception as e: 

1653 logger.error(f"Error executing forwarded HTTP request: {e}") 

1654 # Try to send error response if possible 

1655 if redis and response_channel: 

1656 error_response = { 

1657 "status": 500, 

1658 "headers": {"content-type": "application/json"}, 

1659 "body": orjson.dumps({"error": "Internal forwarding error"}).hex(), 

1660 } 

1661 try: 

1662 await redis.publish(response_channel, orjson.dumps(error_response)) 

1663 except Exception as publish_error: 

1664 logger.debug(f"Failed to publish error response via Redis: {publish_error}") 

1665 
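# --- Illustrative encoding sketch ---
# JSON (and therefore orjson) cannot carry raw bytes, so HTTP bodies are
# hex-encoded for the Redis hop and decoded on arrival, exactly as the
# handlers above do. A minimal round trip:

def encode_body(body: bytes) -> str:
    return body.hex() if body else ""

def decode_body(body_hex: str) -> bytes:
    return bytes.fromhex(body_hex) if body_hex else b""

assert decode_body(encode_body(b'{"jsonrpc": "2.0"}')) == b'{"jsonrpc": "2.0"}'
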

1666 async def get_streamable_http_session_owner(self, mcp_session_id: str) -> Optional[str]: 

1667 """Get the worker ID that owns a Streamable HTTP session. 

1668 

1669 This is a public wrapper around _get_pool_session_owner for use by 

1670 streamablehttp_transport to check session ownership before handling requests. 

1671 

1672 Args: 

1673 mcp_session_id: The MCP session ID from mcp-session-id header. 

1674 

1675 Returns: 

1676 Worker ID if found, None otherwise. 

1677 """ 

1678 return await self._get_pool_session_owner(mcp_session_id) 

1679 

1680 async def forward_streamable_http_to_owner( 

1681 self, 

1682 owner_worker_id: str, 

1683 mcp_session_id: str, 

1684 method: str, 

1685 path: str, 

1686 headers: Dict[str, str], 

1687 body: bytes, 

1688 query_string: str = "", 

1689 ) -> Optional[Dict[str, Any]]: 

1690 """Forward a Streamable HTTP request to the worker that owns the session via Redis Pub/Sub. 

1691 

1692 This method forwards the entire HTTP request to another worker using Redis 

1693 Pub/Sub channels, similar to forward_request_to_owner() for SSE transport. 

1694 This ensures session affinity works correctly in single-host multi-worker 

1695 deployments where hostname-based routing fails. 

1696 

1697 Args: 

1698 owner_worker_id: The worker ID that owns the session. 

1699 mcp_session_id: The MCP session ID. 

1700 method: HTTP method (GET, POST, DELETE). 

1701 path: Request path (e.g., /mcp). 

1702 headers: Request headers. 

1703 body: Request body bytes. 

1704 query_string: Query string if any. 

1705 

1706 Returns: 

1707 Dict with 'status', 'headers', and 'body' from the owner worker's response, 

1708 or None if forwarding fails. 

1709 """ 

1710 if not settings.mcpgateway_session_affinity_enabled: 

1711 return None 

1712 

1713 if not self.is_valid_mcp_session_id(mcp_session_id): 

1714 return None 

1715 

1716 session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id 

1717 logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {session_short}... | {method} {path} | Forwarding to worker {owner_worker_id}") 

1718 

1719 try: 

1720 # First-Party 

1721 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel 

1722 

1723 redis = await get_redis_client() 

1724 if not redis: 

1725 logger.warning("Redis unavailable for HTTP forwarding, executing locally") 

1726 return None # Fall back to local execution 

1727 

1728 # Generate unique response channel for this request 

1729 response_uuid = uuid.uuid4().hex 

1730 response_channel = f"mcpgw:pool_http_response:{response_uuid}" 

1731 

1732 # Serialize HTTP request for Redis transport 

1733 forward_data = { 

1734 "type": "http_forward", 

1735 "response_channel": response_channel, 

1736 "mcp_session_id": mcp_session_id, 

1737 "method": method, 

1738 "path": path, 

1739 "query_string": query_string, 

1740 "headers": headers, 

1741 "body": body.hex() if body else "", # Hex encode binary body 

1742 "original_worker": WORKER_ID, 

1743 "timestamp": time.time(), 

1744 } 

1745 

1746 # Subscribe to response channel BEFORE publishing request (prevent race) 

1747 pubsub = redis.pubsub() 

1748 await pubsub.subscribe(response_channel) 

1749 

1750 try: 

1751 # Publish forwarded request to owner worker's HTTP channel 

1752 owner_channel = f"mcpgw:pool_http:{owner_worker_id}" 

1753 await redis.publish(owner_channel, orjson.dumps(forward_data)) 

1754 logger.debug(f"[HTTP_AFFINITY] Published HTTP request to Redis channel: {owner_channel}") 

1755 

1756 # Wait for response with timeout 

1757 timeout = settings.mcpgateway_pool_rpc_forward_timeout 

1758 async with asyncio.timeout(timeout): 

1759 while True: 

1760 msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1) 

1761 if msg and msg["type"] == "message": 

1762 response_data = orjson.loads(msg["data"]) 

1763 logger.debug(f"[HTTP_AFFINITY] Received HTTP response via Redis: status={response_data.get('status')}") 

1764 

1765 # Decode hex body back to bytes 

1766 body_hex = response_data.get("body", "") 

1767 response_data["body"] = bytes.fromhex(body_hex) if body_hex else b"" 

1768 

1769 self._forwarded_requests += 1 

1770 return response_data 

1771 

1772 finally: 

1773 await pubsub.unsubscribe(response_channel) 

1774 

1775 except asyncio.TimeoutError: 

1776 self._forwarded_request_timeouts += 1 

1777 logger.warning(f"Timeout forwarding HTTP request to owner {owner_worker_id}") 

1778 return None 

1779 except Exception as e: 

1780 self._forwarded_request_failures += 1 

1781 logger.warning(f"Error forwarding HTTP request via Redis: {e}") 

1782 return None 

1783 
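# --- Illustrative transport-side sketch (assumption) ---
# A Streamable HTTP handler would check ownership first and only forward
# when another worker holds the session; a None result falls back to
# local handling. `build_response` is a hypothetical helper that turns
# the serialized dict back into a framework response object.

async def route_streamable_http(pool, session_id, method, path, headers, body):
    owner = await pool.get_streamable_http_session_owner(session_id)
    if owner and owner != WORKER_ID:
        result = await pool.forward_streamable_http_to_owner(
            owner, session_id, method, path, headers, body
        )
        if result is not None:
            # result carries 'status', 'headers', and decoded 'body' bytes
            return build_response(result)  # hypothetical
    return None  # Caller handles the request locally
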

1784 def get_metrics(self) -> Dict[str, Any]: 

1785 """ 

1786 Return pool metrics for monitoring. 

1787 

1788 Returns: 

1789 Dict with hit/miss/eviction counters, hit_rate, session affinity metrics, per-pool stats, and circuit breaker state. 

1790 """ 

1791 total_requests = self._hits + self._misses 

1792 total_affinity_requests = self._session_affinity_local_hits + self._session_affinity_redis_hits + self._session_affinity_misses 

1793 return { 

1794 "hits": self._hits, 

1795 "misses": self._misses, 

1796 "evictions": self._evictions, 

1797 "health_check_failures": self._health_check_failures, 

1798 "circuit_breaker_trips": self._circuit_breaker_trips, 

1799 "pool_keys_evicted": self._pool_keys_evicted, 

1800 "sessions_reaped": self._sessions_reaped, 

1801 "anonymous_identity_count": self._anonymous_identity_count, 

1802 "hit_rate": self._hits / total_requests if total_requests > 0 else 0.0, 

1803 "pool_key_count": len(self._pools), 

1804 # Session affinity metrics 

1805 "session_affinity": { 

1806 "local_hits": self._session_affinity_local_hits, 

1807 "redis_hits": self._session_affinity_redis_hits, 

1808 "misses": self._session_affinity_misses, 

1809 "hit_rate": (self._session_affinity_local_hits + self._session_affinity_redis_hits) / total_affinity_requests if total_affinity_requests > 0 else 0.0, 

1810 "forwarded_requests": self._forwarded_requests, 

1811 "forwarded_failures": self._forwarded_request_failures, 

1812 "forwarded_timeouts": self._forwarded_request_timeouts, 

1813 }, 

1814 "pools": { 

1815 f"{url}|{identity[:8]}|{transport}|{user}|{gw_id[:8] if gw_id else 'none'}": { 

1816 "available": pool.qsize(), 

1817 "active": len(self._active.get((user, url, identity, transport, gw_id), set())), 

1818 "max": self._max_sessions, 

1819 } 

1820 for (user, url, identity, transport, gw_id), pool in self._pools.items() 

1821 }, 

1822 "circuit_breakers": { 

1823 url: { 

1824 "failures": self._failures.get(url, 0), 

1825 "open_until": self._circuit_open_until.get(url), 

1826 } 

1827 for url in set(self._failures.keys()) | set(self._circuit_open_until.keys()) 

1828 }, 

1829 } 

1830 
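# --- Illustrative monitoring sketch (assumption) ---
# get_metrics() is synchronous and cheap, so it can back a health or
# metrics endpoint directly. A minimal periodic log of the headline
# numbers might look like this.

def log_pool_metrics(pool) -> None:
    m = pool.get_metrics()
    logger.info(
        "session pool: hit_rate=%.2f pools=%d evictions=%d",
        m["hit_rate"], m["pool_key_count"], m["evictions"],
    )
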

1831 @asynccontextmanager 

1832 async def session( 

1833 self, 

1834 url: str, 

1835 headers: Optional[Dict[str, str]] = None, 

1836 transport_type: TransportType = TransportType.STREAMABLE_HTTP, 

1837 httpx_client_factory: Optional[HttpxClientFactory] = None, 

1838 timeout: Optional[float] = None, 

1839 user_identity: Optional[str] = None, 

1840 gateway_id: Optional[str] = None, 

1841 ) -> "AsyncIterator[PooledSession]": 

1842 """ 

1843 Context manager for acquiring and releasing a session. 

1844 

1845 Usage: 

1846 async with pool.session(url, headers) as pooled: 

1847 result = await pooled.session.call_tool("my_tool", {}) 

1848 

1849 Args: 

1850 url: The MCP server URL. 

1851 headers: Request headers. 

1852 transport_type: Transport type to use. 

1853 httpx_client_factory: Optional factory for httpx clients. 

1854 timeout: Optional timeout in seconds for transport connection. 

1855 user_identity: Optional user identity for strict isolation. 

1856 gateway_id: Optional gateway ID for notification handler context. 

1857 

1858 Yields: 

1859 PooledSession ready for use. 

1860 """ 

1861 pooled = await self.acquire(url, headers, transport_type, httpx_client_factory, timeout, user_identity, gateway_id) 

1862 try: 

1863 yield pooled 

1864 finally: 

1865 await self.release(pooled) 

1866 
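# --- Illustrative usage sketch (assumption) ---
# The context manager guarantees release() even when the tool call
# raises, which is what keeps sessions returning to the pool. The URL,
# token, identity, and tool name below are placeholders.

async def call_pooled_tool(pool) -> Any:
    async with pool.session(
        "https://mcp.example.com/mcp",  # placeholder URL
        headers={"Authorization": "Bearer <token>"},
        transport_type=TransportType.STREAMABLE_HTTP,
        user_identity="user-123",  # placeholder identity
    ) as pooled:
        return await pooled.session.call_tool("my_tool", {})
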

1867 

1868# Global pool instance - initialized by FastAPI lifespan 

1869_mcp_session_pool: Optional[MCPSessionPool] = None 

1870 

1871 

1872def get_mcp_session_pool() -> MCPSessionPool: 

1873 """Get the global MCP session pool instance. 

1874 

1875 Returns: 

1876 The global MCPSessionPool instance. 

1877 

1878 Raises: 

1879 RuntimeError: If pool has not been initialized. 

1880 """ 

1881 if _mcp_session_pool is None: 

1882 raise RuntimeError("MCP session pool not initialized. Call init_mcp_session_pool() first.") 

1883 return _mcp_session_pool 

1884 

1885 

1886def init_mcp_session_pool( 

1887 max_sessions_per_key: int = 10, 

1888 session_ttl_seconds: float = 300.0, 

1889 health_check_interval_seconds: float = 60.0, 

1890 acquire_timeout_seconds: float = 30.0, 

1891 session_create_timeout_seconds: float = 30.0, 

1892 circuit_breaker_threshold: int = 5, 

1893 circuit_breaker_reset_seconds: float = 60.0, 

1894 identity_headers: Optional[frozenset[str]] = None, 

1895 identity_extractor: Optional[IdentityExtractor] = None, 

1896 idle_pool_eviction_seconds: float = 600.0, 

1897 default_transport_timeout_seconds: float = 30.0, 

1898 health_check_methods: Optional[list[str]] = None, 

1899 health_check_timeout_seconds: float = 5.0, 

1900 message_handler_factory: Optional[MessageHandlerFactory] = None, 

1901 enable_notifications: bool = True, 

1902 notification_debounce_seconds: float = 5.0, 

1903) -> MCPSessionPool: 

1904 """Initialize the global MCP session pool. 

1905 

1906 Args: 

1907 See MCPSessionPool.__init__ for argument descriptions. 

1908 enable_notifications: Enable automatic notification service for list_changed events. 

1909 notification_debounce_seconds: Debounce interval for notification-triggered refreshes. 

1910 

1911 Returns: 

1912 The initialized MCPSessionPool instance. 

1913 """ 

1914 global _mcp_session_pool # pylint: disable=global-statement 

1915 

1916 # Auto-create notification service if enabled and no custom handler provided 

1917 effective_handler_factory = message_handler_factory 

1918 if enable_notifications and message_handler_factory is None: 

1919 # First-Party 

1920 from mcpgateway.services.notification_service import ( # pylint: disable=import-outside-toplevel 

1921 init_notification_service, 

1922 ) 

1923 

1924 # Initialize notification service (will be started during acquire with gateway context) 

1925 notification_svc = init_notification_service(debounce_seconds=notification_debounce_seconds) 

1926 

1927 # Create default handler factory that uses notification service 

1928 def default_handler_factory(url: str, gateway_id: Optional[str]): 

1929 """Create a message handler for MCP session notifications. 

1930 

1931 Args: 

1932 url: The MCP server URL for the session. 

1933 gateway_id: Optional gateway ID for attribution, falls back to URL if not provided. 

1934 

1935 Returns: 

1936 A message handler that forwards notifications to the notification service. 

1937 """ 

1938 return notification_svc.create_message_handler(gateway_id or url, url) 

1939 

1940 effective_handler_factory = default_handler_factory 

1941 logger.info("MCP notification service created (debounce=%ss)", notification_debounce_seconds) 

1942 

1943 _mcp_session_pool = MCPSessionPool( 

1944 max_sessions_per_key=max_sessions_per_key, 

1945 session_ttl_seconds=session_ttl_seconds, 

1946 health_check_interval_seconds=health_check_interval_seconds, 

1947 acquire_timeout_seconds=acquire_timeout_seconds, 

1948 session_create_timeout_seconds=session_create_timeout_seconds, 

1949 circuit_breaker_threshold=circuit_breaker_threshold, 

1950 circuit_breaker_reset_seconds=circuit_breaker_reset_seconds, 

1951 identity_headers=identity_headers, 

1952 identity_extractor=identity_extractor, 

1953 idle_pool_eviction_seconds=idle_pool_eviction_seconds, 

1954 default_transport_timeout_seconds=default_transport_timeout_seconds, 

1955 health_check_methods=health_check_methods, 

1956 health_check_timeout_seconds=health_check_timeout_seconds, 

1957 message_handler_factory=effective_handler_factory, 

1958 ) 

1959 logger.info("MCP session pool initialized") 

1960 return _mcp_session_pool 

1961 

1962 

1963async def close_mcp_session_pool() -> None: 

1964 """Close the global MCP session pool and notification service.""" 

1965 global _mcp_session_pool # pylint: disable=global-statement 

1966 if _mcp_session_pool is not None: 

1967 await _mcp_session_pool.close_all() 

1968 _mcp_session_pool = None 

1969 logger.info("MCP session pool closed") 

1970 

1971 # Close notification service if it was initialized 

1972 try: 

1973 # First-Party 

1974 from mcpgateway.services.notification_service import ( # pylint: disable=import-outside-toplevel 

1975 close_notification_service, 

1976 ) 

1977 

1978 await close_notification_service() 

1979 except (ImportError, RuntimeError): 

1980 pass # Notification service not initialized 

1981 

1982 
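# --- Illustrative FastAPI lifespan sketch (assumption) ---
# One plausible way to wire init/close plus the affinity listener into
# an application lifespan; only init_mcp_session_pool,
# close_mcp_session_pool, and start_rpc_listener are this module's API.

@asynccontextmanager
async def lifespan(app):
    pool = init_mcp_session_pool()
    # close_all() cancels this task, so storing it on the pool suffices.
    pool._rpc_listener_task = asyncio.create_task(pool.start_rpc_listener())
    try:
        yield
    finally:
        await close_mcp_session_pool()
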

1983async def start_pool_notification_service(gateway_service: Any = None) -> None: 

1984 """Start the notification service background worker. 

1985 

1986 Call this after gateway_service is initialized to enable event-driven refresh. 

1987 

1988 Args: 

1989 gateway_service: Optional GatewayService instance for triggering refreshes. 

1990 """ 

1991 try: 

1992 # First-Party 

1993 from mcpgateway.services.notification_service import ( # pylint: disable=import-outside-toplevel 

1994 get_notification_service, 

1995 ) 

1996 

1997 notification_svc = get_notification_service() 

1998 await notification_svc.initialize(gateway_service) 

1999 logger.info("MCP notification service started") 

2000 except RuntimeError: 

2001 logger.debug("Notification service not configured, skipping start") 

2002 

2003 

2004def register_gateway_capabilities_for_notifications(gateway_id: str, capabilities: Dict[str, Any]) -> None: 

2005 """Register gateway capabilities for notification handling. 

2006 

2007 Call this after gateway initialization to enable list_changed notifications. 

2008 

2009 Args: 

2010 gateway_id: The gateway ID. 

2011 capabilities: Server capabilities from initialization response. 

2012 """ 

2013 try: 

2014 # First-Party 

2015 from mcpgateway.services.notification_service import ( # pylint: disable=import-outside-toplevel 

2016 get_notification_service, 

2017 ) 

2018 

2019 notification_svc = get_notification_service() 

2020 notification_svc.register_gateway_capabilities(gateway_id, capabilities) 

2021 except RuntimeError: 

2022 pass # Notification service not initialized 

2023 

2024 

2025def unregister_gateway_from_notifications(gateway_id: str) -> None: 

2026 """Unregister a gateway from notification handling. 

2027 

2028 Call this when a gateway is deleted. 

2029 

2030 Args: 

2031 gateway_id: The gateway ID to unregister. 

2032 """ 

2033 try: 

2034 # First-Party 

2035 from mcpgateway.services.notification_service import ( # pylint: disable=import-outside-toplevel 

2036 get_notification_service, 

2037 ) 

2038 

2039 notification_svc = get_notification_service() 

2040 notification_svc.unregister_gateway(gateway_id) 

2041 except RuntimeError: 

2042 pass # Notification service not initialized