Coverage for mcpgateway / services / mcp_session_pool.py: 100%

840 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2""" 

3MCP Session Pool Implementation. 

4 

5Provides session pooling for MCP ClientSessions to reduce per-request overhead. 

6Sessions are isolated per user/tenant via identity hashing to prevent session collision. 

7 

8Performance Impact: 

9 - Without pooling: 20-23ms per tool call (new session each time) 

10 - With pooling: 1-2ms per tool call (10-20x improvement) 

11 

12Security: 

13 Sessions are isolated by (url, identity_hash, transport_type) to prevent: 

14 - Cross-user session sharing 

15 - Cross-tenant data leakage 

16 - Authentication bypass 

17 

18Copyright 2026 

19SPDX-License-Identifier: Apache-2.0 

20Authors: Mihai Criveti 

21""" 

22 

23# flake8: noqa: DAR101, DAR201, DAR401 

24 

25# Future 

26from __future__ import annotations 

27 

28# Standard 

29import asyncio 

30from contextlib import asynccontextmanager 

31from dataclasses import dataclass, field 

32from enum import Enum 

33import hashlib 

34import logging 

35import os 

36import re 

37import socket 

38import time 

39from typing import Any, Callable, Dict, Optional, Set, Tuple, TYPE_CHECKING 

40import uuid 

41 

42# Third-Party 

43import anyio 

44import httpx 

45from mcp import ClientSession, McpError 

46from mcp.client.sse import sse_client 

47from mcp.client.streamable_http import streamablehttp_client 

48from mcp.shared.session import RequestResponder 

49import mcp.types as mcp_types 

50import orjson 

51 

52# First-Party 

53from mcpgateway.config import settings 

54from mcpgateway.utils.url_auth import sanitize_url_for_logging 

55 

56# JSON-RPC standard error code for method not found 

57METHOD_NOT_FOUND = -32601 

58 

59# Shared session-id validation (downstream MCP session IDs used for affinity). 

60# Intentionally strict: protects Redis key/channel construction and log lines. 

61_MCP_SESSION_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,128}$") 

62 

63# Worker ID for multi-worker session affinity 

64# Uses hostname + PID to be unique across Docker containers (each container has PID 1) 

65# and across gunicorn workers within the same container 

66WORKER_ID = f"{socket.gethostname()}:{os.getpid()}" 

67 

68 

def _get_cleanup_timeout() -> float:
    """Return the session cleanup timeout, falling back to a safe default.

    The value bounds how long session/transport ``__aexit__`` calls may run
    when a pooled session is torn down. This prevents CPU spin loops when
    internal tasks don't respond to cancellation (anyio's
    _deliver_cancellation issue).

    Returns:
        Cleanup timeout in seconds (5.0 when settings cannot be read).
    """
    try:
        # Read lazily so a partially-initialized settings object during
        # startup (circular-import window) cannot break pool construction.
        timeout = settings.mcp_session_pool_cleanup_timeout
    except Exception:
        timeout = 5.0  # Conservative fallback default
    return timeout

84 

85 

# Imported only for static type checking; not needed at runtime.
if TYPE_CHECKING:
    # Standard
    from collections.abc import AsyncIterator  # pragma: no cover

# Module-level logger following the standard `logging.getLogger(__name__)` convention.
logger = logging.getLogger(__name__)

91 

92 

class TransportType(Enum):
    """Supported MCP transport types.

    The ``value`` strings are embedded in pool keys and Redis mapping keys,
    so they must remain stable.
    """

    SSE = "sse"  # Server-Sent Events transport
    STREAMABLE_HTTP = "streamablehttp"  # Streamable HTTP transport

98 

99 

@dataclass(eq=False)  # eq=False keeps instances hashable by object identity
class PooledSession:
    """A pooled MCP session plus the bookkeeping needed for lifecycle management.

    ``eq=False`` is deliberate: pooled sessions are stored in ``Set``s for
    active-session tracking, so instances must hash by object identity rather
    than by field values.
    """

    session: ClientSession
    transport_context: Any  # Transport context manager, kept open for the session's lifetime
    url: str
    transport_type: TransportType
    headers: Dict[str, str]  # Original request headers (kept for reconnection)
    identity_key: str  # Identity-hash component derived from headers
    user_identity: str = "anonymous"  # Marker used for per-user isolation
    gateway_id: str = ""  # Gateway ID used for notification attribution
    created_at: float = field(default_factory=time.time)
    last_used: float = field(default_factory=time.time)
    use_count: int = 0
    _closed: bool = field(default=False, repr=False)

    @property
    def age_seconds(self) -> float:
        """Seconds elapsed since this session was created.

        Returns:
            float: Age of the session in seconds.
        """
        now = time.time()
        return now - self.created_at

    @property
    def idle_seconds(self) -> float:
        """Seconds elapsed since this session was last used.

        Returns:
            float: Idle time in seconds.
        """
        now = time.time()
        return now - self.last_used

    @property
    def is_closed(self) -> bool:
        """Whether :meth:`mark_closed` has been called on this session.

        Returns:
            bool: True once the session has been closed.
        """
        return self._closed

    def mark_closed(self) -> None:
        """Flag this session as closed (idempotent)."""
        self._closed = True

151 

152 

# Type aliases
# Pool key includes transport type and gateway_id to prevent returning the wrong
# transport for the same URL, and to ensure correct notification attribution
# when notifications are enabled.
PoolKey = Tuple[str, str, str, str, str]  # (user_identity_hash, url, identity_hash, transport_type, gateway_id)

# Session affinity mapping key: (mcp_session_id, url, transport_type, gateway_id)
SessionMappingKey = Tuple[str, str, str, str]
# Factory producing configured httpx clients (custom SSL/timeout/auth).
HttpxClientFactory = Callable[
    [Optional[Dict[str, str]], Optional[httpx.Timeout], Optional[httpx.Auth]],
    httpx.AsyncClient,
]


# Type alias for identity extractor callback.
# Extracts a stable identity from headers (e.g., decode JWT to get user_id).
IdentityExtractor = Callable[[Dict[str, str]], Optional[str]]

# Type alias for message handler factory.
# Factory that creates message handlers given URL and optional gateway_id.
# The handler receives ServerNotification, ServerRequest responders, or Exceptions.
MessageHandlerFactory = Callable[
    [str, Optional[str]],  # (url, gateway_id)
    Callable[
        [RequestResponder[mcp_types.ServerRequest, mcp_types.ClientResult] | mcp_types.ServerNotification | Exception],
        Any,  # Coroutine returned by the handler
    ],
]

180 

181 

182class MCPSessionPool: # pylint: disable=too-many-instance-attributes 

183 """ 

184 Pool of MCP ClientSessions keyed by (user_identity, server URL, identity hash, transport type, gateway_id). 

185 

186 Thread-Safety: 

187 This pool is designed for asyncio concurrency. It uses asyncio.Lock 

188 for synchronization, which is safe for coroutine-based concurrency 

189 but NOT for multi-threaded access. 

190 

191 Session Isolation: 

192 Sessions are isolated per user/tenant to prevent session collision. 

193 The identity hash is derived from authentication headers ensuring 

194 that different users never share MCP sessions. 

195 

196 Transport Isolation: 

197 Sessions are also isolated by transport type (SSE vs STREAMABLE_HTTP). 

198 The same URL with different transports will use separate pools. 

199 

200 Gateway Isolation: 

201 Sessions are isolated by gateway_id for correct notification attribution. 

202 When notifications are enabled, each gateway gets its own pooled sessions 

203 even if they share the same URL and authentication. 

204 

205 Features: 

206 - Session reuse across requests (10-20x latency improvement) 

207 - Per-user/tenant session isolation (prevents session collision) 

208 - Per-transport session isolation (prevents transport mismatch) 

209 - TTL-based expiration with configurable lifetime 

210 - Health checks on acquire for stale sessions 

211 - Configurable pool size per URL+identity+transport 

212 - Circuit breaker for failing endpoints 

213 - Idle pool key eviction to prevent unbounded growth 

214 - Custom identity extractor for rotating tokens (e.g., JWT decode) 

215 - Metrics for monitoring (hits, misses, evictions) 

216 - Graceful shutdown with close_all() 

217 

218 Usage: 

219 pool = MCPSessionPool() 

220 

221 # Use as context manager for lifecycle management 

222 async with pool: 

223 pooled = await pool.acquire(url, headers) 

224 try: 

225 result = await pooled.session.call_tool("my_tool", {}) 

226 finally: 

227 await pool.release(pooled) 

228 

229 # With custom identity extractor for JWT tokens: 

230 def extract_user_id(headers: dict) -> str: 

231 token = headers.get("Authorization", "").replace("Bearer ", "") 

232 claims = jwt.decode(token, options={"verify_signature": False}) 

233 return claims.get("sub") or claims.get("user_id") 

234 

235 pool = MCPSessionPool(identity_extractor=extract_user_id) 

236 """ 

237 

    # Headers that contribute to session identity (case-insensitive).
    # Any of these present in a request changes the identity hash, so callers
    # with different credentials never share a pooled MCP session.
    DEFAULT_IDENTITY_HEADERS: frozenset[str] = frozenset(
        [
            "authorization",
            "x-tenant-id",
            "x-user-id",
            "x-api-key",
            "cookie",
            "x-mcp-session-id",
        ]
    )

249 

    def __init__(
        self,
        max_sessions_per_key: int = 10,
        session_ttl_seconds: float = 300.0,
        health_check_interval_seconds: float = 60.0,
        acquire_timeout_seconds: float = 30.0,
        session_create_timeout_seconds: float = 30.0,
        circuit_breaker_threshold: int = 5,
        circuit_breaker_reset_seconds: float = 60.0,
        identity_headers: Optional[frozenset[str]] = None,
        identity_extractor: Optional[IdentityExtractor] = None,
        idle_pool_eviction_seconds: float = 600.0,
        default_transport_timeout_seconds: float = 30.0,
        health_check_methods: Optional[list[str]] = None,
        health_check_timeout_seconds: float = 5.0,
        message_handler_factory: Optional[MessageHandlerFactory] = None,
    ):
        """
        Initialize the session pool.

        Args:
            max_sessions_per_key: Maximum pooled sessions per (URL, identity, transport).
            session_ttl_seconds: Session TTL in seconds before forced expiration.
            health_check_interval_seconds: Seconds of idle time before health check.
            acquire_timeout_seconds: Timeout for waiting when pool is exhausted.
            session_create_timeout_seconds: Timeout for creating new sessions.
            circuit_breaker_threshold: Consecutive failures before circuit opens.
            circuit_breaker_reset_seconds: Seconds before circuit breaker resets.
            identity_headers: Headers that contribute to identity hash.
            identity_extractor: Optional callback to extract stable identity from headers.
                Use this when tokens rotate frequently (e.g., short-lived JWTs).
                Should return a stable user/tenant ID string.
            idle_pool_eviction_seconds: Evict empty pool keys after this many seconds of no use.
            default_transport_timeout_seconds: Default timeout for transport connections.
            health_check_methods: Ordered list of health check methods to try.
                Options: ping, list_tools, list_prompts, list_resources, skip.
                Default: ["ping", "skip"] (try ping, skip if unsupported).
            health_check_timeout_seconds: Timeout for each health check attempt.
            message_handler_factory: Optional factory for creating message handlers.
                Called with (url, gateway_id) to create handlers for
                each new session. Enables notification handling.
        """
        # Configuration (read-only after construction)
        self._max_sessions = max_sessions_per_key
        self._session_ttl = session_ttl_seconds
        self._health_check_interval = health_check_interval_seconds
        self._acquire_timeout = acquire_timeout_seconds
        self._session_create_timeout = session_create_timeout_seconds
        self._circuit_breaker_threshold = circuit_breaker_threshold
        self._circuit_breaker_reset = circuit_breaker_reset_seconds
        self._identity_headers = identity_headers or self.DEFAULT_IDENTITY_HEADERS
        self._identity_extractor = identity_extractor
        self._idle_pool_eviction = idle_pool_eviction_seconds
        self._default_transport_timeout = default_transport_timeout_seconds
        self._health_check_methods = health_check_methods or ["ping", "skip"]
        self._health_check_timeout = health_check_timeout_seconds
        self._message_handler_factory = message_handler_factory

        # State - protected by _global_lock for creation, per-key locks for access
        self._global_lock = asyncio.Lock()
        self._pools: Dict[PoolKey, asyncio.Queue[PooledSession]] = {}
        self._active: Dict[PoolKey, Set[PooledSession]] = {}
        self._locks: Dict[PoolKey, asyncio.Lock] = {}
        self._semaphores: Dict[PoolKey, asyncio.Semaphore] = {}
        self._pool_last_used: Dict[PoolKey, float] = {}  # Track last use time per pool key

        # Circuit breaker state
        self._failures: Dict[str, int] = {}  # url -> consecutive failure count
        self._circuit_open_until: Dict[str, float] = {}  # url -> timestamp

        # Eviction throttling - only run eviction once per interval
        self._last_eviction_run: float = 0.0
        self._eviction_run_interval: float = 60.0  # Run eviction at most every 60 seconds

        # Metrics (monotonically increasing counters)
        self._hits = 0
        self._misses = 0
        self._evictions = 0
        self._health_check_failures = 0
        self._circuit_breaker_trips = 0
        self._pool_keys_evicted = 0
        self._sessions_reaped = 0  # Sessions closed during background eviction
        self._anonymous_identity_count = 0  # Count of requests with no identity headers

        # Lifecycle
        self._closed = False

        # Pre-registered session mappings for session affinity
        # Mapping from (mcp_session_id, url, transport_type, gateway_id) -> pool_key
        # Set by broadcast() before acquire() is called to enable session affinity lookup
        self._mcp_session_mapping: Dict[SessionMappingKey, PoolKey] = {}
        self._mcp_session_mapping_lock = asyncio.Lock()

        # Multi-worker session affinity via Redis pub/sub
        # Track pending responses for forwarded RPC requests
        self._rpc_listener_task: Optional[asyncio.Task[None]] = None
        self._pending_responses: Dict[str, asyncio.Future[Dict[str, Any]]] = {}

        # Session affinity metrics
        self._session_affinity_local_hits = 0
        self._session_affinity_redis_hits = 0
        self._session_affinity_misses = 0
        self._forwarded_requests = 0
        self._forwarded_request_failures = 0
        self._forwarded_request_timeouts = 0

355 

356 async def __aenter__(self) -> "MCPSessionPool": 

357 """Async context manager entry. 

358 

359 Returns: 

360 MCPSessionPool: This pool instance. 

361 """ 

362 return self 

363 

364 async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: 

365 """Async context manager exit - closes all sessions. 

366 

367 Args: 

368 exc_type: Exception type if an exception was raised. 

369 exc_val: Exception value if an exception was raised. 

370 exc_tb: Exception traceback if an exception was raised. 

371 """ 

372 await self.close_all() 

373 

374 def _compute_identity_hash(self, headers: Optional[Dict[str, str]]) -> str: 

375 """ 

376 Compute a hash of identity-relevant headers. 

377 

378 This ensures sessions are isolated per user/tenant. Different users 

379 with different Authorization headers will have different identity hashes 

380 and thus separate session pools. 

381 

382 Identity resolution order: 

383 1. Custom identity_extractor (if configured) - for rotating tokens like JWTs 

384 2. x-mcp-session-id header (if present) - for session affinity, ensures 

385 requests with the same downstream session ID get the same upstream 

386 session even when JWT tokens rotate (different jti values) 

387 3. Configured identity headers - fallback to hashing all identity headers 

388 

389 Args: 

390 headers: Request headers dict. 

391 

392 Returns: 

393 Identity hash string, or "anonymous" if no identity headers present. 

394 """ 

395 if not headers: 

396 self._anonymous_identity_count += 1 

397 logger.debug("Session pool identity collapsed to 'anonymous' (no headers provided). " + "Sessions will be shared. Ensure this is intentional for stateless MCP servers.") 

398 return "anonymous" 

399 

400 # Try custom identity extractor first (for rotating tokens like JWTs) 

401 if self._identity_extractor: 

402 try: 

403 extracted = self._identity_extractor(headers) 

404 if extracted: 

405 return hashlib.sha256(extracted.encode()).hexdigest() 

406 except Exception as e: 

407 logger.debug(f"Identity extractor failed, falling back to header hash: {e}") 

408 

409 # Normalize headers for case-insensitive lookup 

410 headers_lower = {k.lower(): v for k, v in headers.items()} 

411 

412 # Session affinity: prioritize x-mcp-session-id for stable identity 

413 # When present, use ONLY the session ID for identity hash. This ensures 

414 # requests with the same downstream session ID get the same upstream session, 

415 # even when JWT tokens rotate (different jti values per request). 

416 if settings.mcpgateway_session_affinity_enabled: 

417 session_id = headers_lower.get("x-mcp-session-id") 

418 if session_id: 

419 logger.debug(f"Using x-mcp-session-id for session affinity: {session_id[:8]}...") 

420 return hashlib.sha256(session_id.encode()).hexdigest() 

421 

422 # Fallback: extract identity from configured headers 

423 identity_parts = [] 

424 

425 for header in sorted(self._identity_headers): 

426 if header in headers_lower: 

427 identity_parts.append(f"{header}:{headers_lower[header]}") 

428 

429 if not identity_parts: 

430 self._anonymous_identity_count += 1 

431 logger.debug( 

432 "Session pool identity collapsed to 'anonymous' (no identity headers found). " + "Expected headers: %s. Sessions will be shared.", 

433 list(self._identity_headers), 

434 ) 

435 return "anonymous" 

436 

437 # Create a stable, deterministic hash using JSON serialization 

438 # Prevents delimiter-collision or injection issues present in string joining 

439 serialized_identity = orjson.dumps(identity_parts) 

440 return hashlib.sha256(serialized_identity).hexdigest() 

441 

442 def _make_pool_key( 

443 self, 

444 url: str, 

445 headers: Optional[Dict[str, str]], 

446 transport_type: TransportType, 

447 user_identity: str, 

448 gateway_id: Optional[str] = None, 

449 ) -> PoolKey: 

450 """Create composite pool key from URL, identity, transport type, user identity, and gateway_id. 

451 

452 Including gateway_id ensures correct notification attribution when multiple gateways 

453 share the same URL/auth. Sessions are isolated per gateway for proper event routing. 

454 """ 

455 identity_hash = self._compute_identity_hash(headers) 

456 

457 # Anonymize user identity by hashing it (unless it's commonly "anonymous") 

458 # Use full hash for collision resistance - truncate only for display in logs/metrics 

459 if user_identity == "anonymous": 

460 user_hash = "anonymous" 

461 else: 

462 user_hash = hashlib.sha256(user_identity.encode()).hexdigest() 

463 

464 # Use empty string for None gateway_id to maintain consistent key type 

465 gw_id = gateway_id or "" 

466 

467 return (user_hash, url, identity_hash, transport_type.value, gw_id) 

468 

469 async def _get_or_create_lock(self, pool_key: PoolKey) -> asyncio.Lock: 

470 """Get or create a lock for the given pool key (thread-safe).""" 

471 async with self._global_lock: 

472 if pool_key not in self._locks: 

473 self._locks[pool_key] = asyncio.Lock() 

474 return self._locks[pool_key] 

475 

476 async def _get_or_create_pool(self, pool_key: PoolKey) -> asyncio.Queue[PooledSession]: 

477 """Get or create a pool queue for the given key (thread-safe).""" 

478 async with self._global_lock: 

479 if pool_key not in self._pools: 

480 self._pools[pool_key] = asyncio.Queue(maxsize=self._max_sessions) 

481 self._active[pool_key] = set() 

482 self._semaphores[pool_key] = asyncio.Semaphore(self._max_sessions) 

483 return self._pools[pool_key] 

484 

485 def _is_circuit_open(self, url: str) -> bool: 

486 """Check if circuit breaker is open for a URL.""" 

487 if url not in self._circuit_open_until: 

488 return False 

489 if time.time() >= self._circuit_open_until[url]: 

490 # Circuit breaker reset 

491 del self._circuit_open_until[url] 

492 self._failures[url] = 0 

493 logger.info(f"Circuit breaker reset for {sanitize_url_for_logging(url)}") 

494 return False 

495 return True 

496 

497 def _record_failure(self, url: str) -> None: 

498 """Record a failure and potentially trip circuit breaker.""" 

499 self._failures[url] = self._failures.get(url, 0) + 1 

500 if self._failures[url] >= self._circuit_breaker_threshold: 

501 self._circuit_open_until[url] = time.time() + self._circuit_breaker_reset 

502 self._circuit_breaker_trips += 1 

503 logger.warning(f"Circuit breaker opened for {sanitize_url_for_logging(url)} after {self._failures[url]} failures. " f"Will reset in {self._circuit_breaker_reset}s") 

504 

505 def _record_success(self, url: str) -> None: 

506 """Record a success, resetting failure count.""" 

507 self._failures[url] = 0 

508 

509 @staticmethod 

510 def is_valid_mcp_session_id(session_id: str) -> bool: 

511 """Validate downstream MCP session ID format for affinity. 

512 

513 Used for: 

514 - Redis key construction (ownership + mapping) 

515 - Pub/Sub channel naming 

516 - Avoiding log spam / injection 

517 """ 

518 if not session_id: 

519 return False 

520 return bool(_MCP_SESSION_ID_PATTERN.match(session_id)) 

521 

522 def _sanitize_redis_key_component(self, value: str) -> str: 

523 """Sanitize a value for use in Redis key construction. 

524 

525 Replaces any characters that could cause key collision or injection. 

526 

527 Args: 

528 value: The value to sanitize. 

529 

530 Returns: 

531 Sanitized value safe for Redis key construction. 

532 """ 

533 if not value: 

534 return "" 

535 

536 # Replace problematic characters with underscores 

537 return re.sub(r"[^a-zA-Z0-9_-]", "_", value) 

538 

539 def _session_mapping_redis_key(self, mcp_session_id: str, url: str, transport_type: str, gateway_id: str) -> str: 

540 """Compute a bounded Redis key for session mapping. 

541 

542 The URL is hashed to keep keys small and avoid special character issues. 

543 """ 

544 sanitized_session_id = self._sanitize_redis_key_component(mcp_session_id) 

545 url_hash = hashlib.sha256(url.encode()).hexdigest()[:16] 

546 return f"mcpgw:session_mapping:{sanitized_session_id}:{url_hash}:{transport_type}:{gateway_id}" 

547 

548 @staticmethod 

549 def _pool_owner_key(mcp_session_id: str) -> str: 

550 """Return Redis key for session ownership tracking.""" 

551 return f"mcpgw:pool_owner:{mcp_session_id}" 

552 

    async def register_session_mapping(
        self,
        mcp_session_id: str,
        url: str,
        gateway_id: str,
        transport_type: str,
        user_email: Optional[str] = None,
    ) -> None:
        """Pre-register session mapping for session affinity.

        Called from respond() to set up mapping BEFORE acquire() is called.
        This ensures acquire() can find the correct pool key for session affinity.

        The mapping stores the relationship between an incoming MCP session ID
        and the pool key that should be used for upstream connections. This
        enables session affinity even when JWT tokens rotate (different jti values
        per request).

        For multi-worker deployments, the mapping is also stored in Redis with TTL
        so that any worker can look it up during acquire(). Worker ownership of the
        session is claimed in the same step via an atomic SET NX EX.

        Args:
            mcp_session_id: The downstream MCP session ID from x-mcp-session-id header.
            url: The upstream MCP server URL.
            gateway_id: The gateway ID.
            transport_type: The transport type (sse, streamablehttp).
            user_email: The email of the authenticated user (or "system" for unauthenticated).
        """
        # No-op when affinity is disabled in settings.
        if not settings.mcpgateway_session_affinity_enabled:
            return

        # Validate mcp_session_id to prevent Redis key injection
        if not self.is_valid_mcp_session_id(mcp_session_id):
            logger.warning(f"Invalid mcp_session_id format, skipping session mapping: {mcp_session_id[:20]}...")
            return

        # Use user email for user_identity, or "anonymous" if not provided
        user_identity = user_email or "anonymous"

        # Normalize gateway_id to empty string if None for consistent key matching
        normalized_gateway_id = gateway_id or ""

        mapping_key: SessionMappingKey = (mcp_session_id, url, transport_type, normalized_gateway_id)

        # Compute what the pool_key will be for this session.
        # Use mcp_session_id as the identity basis for affinity (must match the
        # hash computed by _compute_identity_hash when affinity is enabled).
        identity_hash = hashlib.sha256(mcp_session_id.encode()).hexdigest()

        # Hash user identity for privacy (unless it's "anonymous")
        if user_identity == "anonymous":
            user_hash = "anonymous"
        else:
            user_hash = hashlib.sha256(user_identity.encode()).hexdigest()

        pool_key: PoolKey = (user_hash, url, identity_hash, transport_type, normalized_gateway_id)

        # Store in local memory (fast path for same-worker lookups in acquire())
        async with self._mcp_session_mapping_lock:
            self._mcp_session_mapping[mapping_key] = pool_key
            logger.debug(f"Session affinity pre-registered (local): {mcp_session_id[:8]}... → {url}, user={user_identity}")

        # Store in Redis for multi-worker support AND register ownership atomically.
        # Registering ownership HERE (during mapping) instead of in acquire() prevents
        # a race condition where two workers could both start creating sessions before
        # either registers ownership.
        try:
            # First-Party
            from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

            redis = await get_redis_client()
            if redis:
                redis_key = self._session_mapping_redis_key(mcp_session_id, url, transport_type, normalized_gateway_id)

                # Store pool_key as JSON for easy deserialization
                pool_key_data = {
                    "user_hash": user_hash,
                    "url": url,
                    "identity_hash": identity_hash,
                    "transport_type": transport_type,
                    "gateway_id": normalized_gateway_id,
                }
                await redis.setex(redis_key, settings.mcpgateway_session_affinity_ttl, orjson.dumps(pool_key_data))  # TTL from config

                # CRITICAL: Register ownership atomically with mapping.
                # This claims ownership BEFORE any session creation attempt, preventing
                # the race condition where two workers both start creating sessions
                owner_key = self._pool_owner_key(mcp_session_id)
                # Atomic claim with TTL (avoids the SETNX/EXPIRE crash window).
                was_set = await redis.set(owner_key, WORKER_ID, nx=True, ex=settings.mcpgateway_session_affinity_ttl)
                if was_set:
                    logger.debug(f"Session ownership claimed (SET NX): {mcp_session_id[:8]}... → worker {WORKER_ID}")
                else:
                    # Another worker already claimed ownership
                    existing_owner = await redis.get(owner_key)
                    owner_id = existing_owner.decode() if isinstance(existing_owner, bytes) else existing_owner
                    logger.debug(f"Session ownership already claimed by {owner_id}: {mcp_session_id[:8]}...")

                logger.debug(f"Session affinity pre-registered (Redis): {mcp_session_id[:8]}... TTL={settings.mcpgateway_session_affinity_ttl}s")
        except Exception as e:
            # Redis failure is non-fatal - local mapping still works for same-worker requests
            logger.debug(f"Failed to store session mapping in Redis: {e}")

654 

655 async def acquire( 

656 self, 

657 url: str, 

658 headers: Optional[Dict[str, str]] = None, 

659 transport_type: TransportType = TransportType.STREAMABLE_HTTP, 

660 httpx_client_factory: Optional[HttpxClientFactory] = None, 

661 timeout: Optional[float] = None, 

662 user_identity: Optional[str] = None, 

663 gateway_id: Optional[str] = None, 

664 ) -> PooledSession: 

665 """ 

666 Acquire a session for the given URL, identity, and transport type. 

667 

668 Sessions are isolated by identity (derived from auth headers) AND 

669 transport type. Returns an initialized, healthy session ready for tool calls. 

670 

671 Args: 

672 url: The MCP server URL. 

673 headers: Request headers (used for identity hashing and passed to server). 

674 transport_type: The transport type (SSE or STREAMABLE_HTTP). 

675 httpx_client_factory: Optional factory for creating httpx clients 

676 (for custom SSL/timeout configuration). 

677 timeout: Optional timeout in seconds for transport connection. 

678 gateway_id: Optional gateway ID for notification handler context. 

679 

680 Returns: 

681 PooledSession ready for use. 

682 

683 Raises: 

684 asyncio.TimeoutError: If acquire times out waiting for available session. 

685 RuntimeError: If pool is closed or circuit breaker is open. 

686 Exception: If session creation fails. 

687 """ 

688 if self._closed: 

689 raise RuntimeError("Session pool is closed") 

690 

691 if self._is_circuit_open(url): 

692 raise RuntimeError(f"Circuit breaker open for {url}") 

693 

694 # Use default timeout if not provided 

695 effective_timeout = timeout if timeout is not None else self._default_transport_timeout 

696 

697 user_id = user_identity or "anonymous" 

698 pool_key: Optional[PoolKey] = None 

699 

700 # Check pre-registered mapping first (set by respond() for session affinity) 

701 if settings.mcpgateway_session_affinity_enabled and headers: 

702 headers_lower = {k.lower(): v for k, v in headers.items()} 

703 mcp_session_id = headers_lower.get("x-mcp-session-id") 

704 if mcp_session_id and self.is_valid_mcp_session_id(mcp_session_id): 

705 normalized_gateway_id = gateway_id or "" 

706 mapping_key: SessionMappingKey = (mcp_session_id, url, transport_type.value, normalized_gateway_id) 

707 

708 # Check local memory first (fast path - same worker) 

709 async with self._mcp_session_mapping_lock: 

710 pool_key = self._mcp_session_mapping.get(mapping_key) 

711 if pool_key: 

712 self._session_affinity_local_hits += 1 

713 logger.debug(f"Session affinity hit (local): {mcp_session_id[:8]}...") 

714 

715 # If not in local memory, check Redis (multi-worker support) 

716 if pool_key is None: 

717 try: 

718 # First-Party 

719 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel 

720 

721 redis = await get_redis_client() 

722 if redis: 

723 redis_key = self._session_mapping_redis_key(mcp_session_id, url, transport_type.value, normalized_gateway_id) 

724 pool_key_data = await redis.get(redis_key) 

725 if pool_key_data: 

726 # Deserialize pool_key from JSON 

727 data = orjson.loads(pool_key_data) 

728 pool_key = ( 

729 data["user_hash"], 

730 data["url"], 

731 data["identity_hash"], 

732 data["transport_type"], 

733 data["gateway_id"], 

734 ) 

735 # Cache in local memory for future requests 

736 async with self._mcp_session_mapping_lock: 

737 self._mcp_session_mapping[mapping_key] = pool_key 

738 self._session_affinity_redis_hits += 1 

739 logger.debug(f"Session affinity hit (Redis): {mcp_session_id[:8]}...") 

740 except Exception as e: 

741 logger.debug(f"Failed to check Redis for session mapping: {e}") 

742 

743 # Fallback to normal pool key computation 

744 if pool_key is None: 

745 self._session_affinity_misses += 1 

746 pool_key = self._make_pool_key(url, headers, transport_type, user_id, gateway_id) 

747 

748 pool = await self._get_or_create_pool(pool_key) 

749 

750 # Update pool key last used time IMMEDIATELY after getting pool 

751 # This prevents race with eviction removing keys between awaits 

752 self._pool_last_used[pool_key] = time.time() 

753 

754 lock = await self._get_or_create_lock(pool_key) 

755 

756 # Guard semaphore access - eviction may have removed it between awaits 

757 # If so, re-create the pool structures 

758 if pool_key not in self._semaphores: 

759 pool = await self._get_or_create_pool(pool_key) 

760 self._pool_last_used[pool_key] = time.time() 

761 

762 semaphore = self._semaphores[pool_key] 

763 

764 # Throttled eviction - only run if enough time has passed (inline, not spawned) 

765 await self._maybe_evict_idle_pool_keys() 

766 

767 # Try to get from pool first (quick path, no lock needed for queue get) 

768 while True: 

769 try: 

770 pooled = pool.get_nowait() 

771 except asyncio.QueueEmpty: 

772 break 

773 

774 # Validate the session outside the lock 

775 if await self._validate_session(pooled): 

776 pooled.last_used = time.time() 

777 pooled.use_count += 1 

778 self._hits += 1 

779 async with lock: 

780 self._active[pool_key].add(pooled) 

781 logger.debug(f"Pool hit for {sanitize_url_for_logging(url)} (identity={pool_key[2][:8]}, transport={transport_type.value})") 

782 return pooled 

783 

784 # Session invalid, close it 

785 await self._close_session(pooled) 

786 self._evictions += 1 

787 semaphore.release() # Free up a slot 

788 

789 # No valid session in pool - try to create one or wait 

790 try: 

791 # Use semaphore with timeout to limit concurrent sessions 

792 acquired = await asyncio.wait_for(semaphore.acquire(), timeout=self._acquire_timeout) 

793 if not acquired: 

794 raise asyncio.TimeoutError("Failed to acquire session slot") 

795 except asyncio.TimeoutError: 

796 raise asyncio.TimeoutError(f"Timeout waiting for available session for {sanitize_url_for_logging(url)}") from None 

797 

798 # Create new session (semaphore acquired) 

799 try: 

800 # Verify we own this session before creating (prevents race condition) 

801 # If another worker already claimed ownership, we should not create a new session 

802 # Note: Ownership is registered atomically in register_session_mapping() using SETNX 

803 if settings.mcpgateway_session_affinity_enabled and headers: 

804 headers_lower = {k.lower(): v for k, v in headers.items()} 

805 mcp_session_id = headers_lower.get("x-mcp-session-id") 

806 if mcp_session_id and self.is_valid_mcp_session_id(mcp_session_id): 

807 owner = await self._get_pool_session_owner(mcp_session_id) 

808 if owner and owner != WORKER_ID: 

809 # Another worker claimed ownership - should have been forwarded 

810 # Release semaphore and raise to trigger forwarding 

811 semaphore.release() 

812 logger.warning(f"Session {mcp_session_id[:8]}... owned by worker {owner}, not us ({WORKER_ID})") 

813 raise RuntimeError(f"Session owned by another worker: {owner}") 

814 

815 pooled = await asyncio.wait_for( 

816 self._create_session(url, headers, transport_type, httpx_client_factory, effective_timeout, gateway_id), 

817 timeout=self._session_create_timeout, 

818 ) 

819 # Store identity components for key reconstruction 

820 pooled.identity_key = pool_key[2] 

821 pooled.user_identity = user_id 

822 

823 # Note: Ownership is now registered atomically in register_session_mapping() 

824 # before acquire() is called, so we don't need to register it here 

825 

826 self._misses += 1 

827 self._record_success(url) 

828 async with lock: 

829 self._active[pool_key].add(pooled) 

830 logger.debug(f"Pool miss for {sanitize_url_for_logging(url)} - created new session (transport={transport_type.value})") 

831 return pooled 

832 except BaseException as e: 

833 # Release semaphore on ANY failure (including CancelledError) 

834 semaphore.release() 

835 if not isinstance(e, asyncio.CancelledError): 

836 self._record_failure(url) 

837 logger.warning(f"Failed to create session for {sanitize_url_for_logging(url)}: {e}") 

838 raise 

839 

    async def release(self, pooled: PooledSession) -> None:
        """
        Return a session to the pool for reuse.

        Re-derives the pool key from the identity fields stored on the session
        (user hash, url, identity hash, transport, gateway), removes the session
        from the active set, and either parks it back in the idle queue or
        closes it when the pool is shut down or the session exceeded its TTL.

        Args:
            pooled: The session to release.
        """
        # A session that was already closed (e.g. by eviction) must not be
        # re-pooled; warn because the caller's bookkeeping is likely off.
        if pooled.is_closed:
            logger.warning("Attempted to release already-closed session")
            return

        # Pool key includes transport type, user identity, and gateway_id
        # Re-compute user hash from stored raw identity (full hash for collision resistance)
        user_hash = "anonymous"
        if pooled.user_identity != "anonymous":
            user_hash = hashlib.sha256(pooled.user_identity.encode()).hexdigest()

        pool_key = (user_hash, pooled.url, pooled.identity_key, pooled.transport_type.value, pooled.gateway_id)
        lock = await self._get_or_create_lock(pool_key)
        pool = await self._get_or_create_pool(pool_key)

        async with lock:
            # Update last-used FIRST to prevent eviction race:
            # If eviction runs between removing from _active and putting back in pool,
            # it would see key as idle + inactive and evict it. By updating last-used
            # while still holding the lock and before removing from _active, we ensure
            # eviction sees recent activity.
            self._pool_last_used[pool_key] = time.time()
            self._active.get(pool_key, set()).discard(pooled)

            # Check if session should be returned to pool.
            # During shutdown (_closed) or past TTL, close instead of re-pooling;
            # releasing the semaphore frees the concurrency slot for new sessions.
            if self._closed or pooled.age_seconds > self._session_ttl:
                await self._close_session(pooled)
                if pool_key in self._semaphores:
                    self._semaphores[pool_key].release()
                # Only count as an eviction when TTL expiry (not shutdown) caused it.
                if pooled.age_seconds > self._session_ttl:
                    self._evictions += 1
                return

            # Return to pool (pool may have been evicted in edge case, recreate if needed)
            if pool_key not in self._pools:
                pool = await self._get_or_create_pool(pool_key)
                self._pool_last_used[pool_key] = time.time()

            try:
                pool.put_nowait(pooled)
                logger.debug(f"Session returned to pool for {sanitize_url_for_logging(pooled.url)}")
            except asyncio.QueueFull:
                # Pool full (shouldn't happen with semaphore), close session
                await self._close_session(pooled)
                if pool_key in self._semaphores:
                    self._semaphores[pool_key].release()

892 

    async def _maybe_evict_idle_pool_keys(self) -> None:
        """
        Reap stale sessions and evict idle pool keys.

        This method is throttled - it only runs eviction if enough time has
        passed since the last run (default: 60 seconds). This prevents:
        - Unbounded task spawning on every acquire
        - Lock contention under high load

        Two-phase cleanup:
        1. Close expired/stale sessions parked in idle pools (frees connections)
        2. Evict pool keys that are now empty and have no active sessions

        This prevents unbounded connection and pool key growth when using
        rotating tokens (e.g., short-lived JWTs with unique identifiers).
        """
        if self._closed:
            return

        now = time.time()

        # Throttle: only run eviction once per interval
        if now - self._last_eviction_run < self._eviction_run_interval:
            return

        self._last_eviction_run = now

        # Collect sessions to close and keys to evict (minimize time holding lock)
        sessions_to_close: list[PooledSession] = []
        keys_to_evict: list[PoolKey] = []

        async with self._global_lock:
            # Iterate over a snapshot since we mutate tracking dicts below.
            for pool_key, last_used in list(self._pool_last_used.items()):
                # Skip recently-used pools
                if now - last_used < self._idle_pool_eviction:
                    continue

                pool = self._pools.get(pool_key)
                active = self._active.get(pool_key, set())

                # Skip if there are active sessions (in use)
                if active:
                    continue

                if pool:
                    # Phase 1: Drain and collect expired/stale sessions from idle pools
                    while not pool.empty():
                        try:
                            session = pool.get_nowait()
                            # Close if expired OR idle too long (defense in depth)
                            if session.age_seconds > self._session_ttl or session.idle_seconds > self._idle_pool_eviction:
                                sessions_to_close.append(session)
                                # Release semaphore slot for this session
                                if pool_key in self._semaphores:
                                    self._semaphores[pool_key].release()
                            else:
                                # Session still valid, put it back
                                pool.put_nowait(session)
                                break  # Stop draining if we find a valid session
                        except asyncio.QueueEmpty:
                            break

                # Phase 2: Evict pool key if now empty
                if pool.empty():
                    keys_to_evict.append(pool_key)

            # Remove evicted keys from all tracking dicts
            for pool_key in keys_to_evict:
                self._pools.pop(pool_key, None)
                self._active.pop(pool_key, None)
                self._locks.pop(pool_key, None)
                self._semaphores.pop(pool_key, None)
                self._pool_last_used.pop(pool_key, None)
                self._pool_keys_evicted += 1
                logger.debug(f"Evicted idle pool key: {pool_key[0][:8]}|{pool_key[1]}|{pool_key[2][:8]}")

        # Close sessions outside the lock (I/O operations)
        for session in sessions_to_close:
            await self._close_session(session)
            self._sessions_reaped += 1
            logger.debug(f"Reaped stale session for {sanitize_url_for_logging(session.url)} (age={session.age_seconds:.1f}s)")

974 

975 async def _validate_session(self, pooled: PooledSession) -> bool: 

976 """ 

977 Validate a session is still usable. 

978 

979 Checks TTL and performs health check if session is stale. 

980 

981 Args: 

982 pooled: The session to validate. 

983 

984 Returns: 

985 True if session is valid, False otherwise. 

986 """ 

987 if pooled.is_closed: 

988 return False 

989 

990 # Check TTL 

991 if pooled.age_seconds > self._session_ttl: 

992 logger.debug(f"Session expired (age={pooled.age_seconds:.1f}s)") 

993 return False 

994 

995 # Health check if stale 

996 if pooled.idle_seconds > self._health_check_interval: 

997 return await self._run_health_check_chain(pooled) 

998 

999 return True 

1000 

1001 async def _run_health_check_chain(self, pooled: PooledSession) -> bool: 

1002 """ 

1003 Run health check methods in configured order until one succeeds. 

1004 

1005 The health check chain allows configuring which methods to try and in what order. 

1006 This supports both modern servers (with ping support) and legacy servers 

1007 (that may only support list_tools or no health check at all). 

1008 

1009 Args: 

1010 pooled: The session to health check. 

1011 

1012 Returns: 

1013 True if any health check method succeeds, False if all fail. 

1014 """ 

1015 for method in self._health_check_methods: 

1016 try: 

1017 if method == "ping": 

1018 await asyncio.wait_for(pooled.session.send_ping(), timeout=self._health_check_timeout) 

1019 logger.debug(f"Health check passed: ping (url={sanitize_url_for_logging(pooled.url)})") 

1020 return True 

1021 if method == "list_tools": 

1022 await asyncio.wait_for(pooled.session.list_tools(), timeout=self._health_check_timeout) 

1023 logger.debug(f"Health check passed: list_tools (url={sanitize_url_for_logging(pooled.url)})") 

1024 return True 

1025 if method == "list_prompts": 

1026 await asyncio.wait_for(pooled.session.list_prompts(), timeout=self._health_check_timeout) 

1027 logger.debug(f"Health check passed: list_prompts (url={sanitize_url_for_logging(pooled.url)})") 

1028 return True 

1029 if method == "list_resources": 

1030 await asyncio.wait_for(pooled.session.list_resources(), timeout=self._health_check_timeout) 

1031 logger.debug(f"Health check passed: list_resources (url={sanitize_url_for_logging(pooled.url)})") 

1032 return True 

1033 if method == "skip": 

1034 logger.debug(f"Health check skipped per configuration (url={sanitize_url_for_logging(pooled.url)})") 

1035 return True 

1036 logger.warning(f"Unknown health check method '{method}', skipping") 

1037 continue 

1038 

1039 except McpError as e: 

1040 # METHOD_NOT_FOUND (-32601) means the method isn't supported - try next 

1041 if e.error.code == METHOD_NOT_FOUND: 

1042 logger.debug(f"Health check method '{method}' not supported by server, trying next") 

1043 continue 

1044 # Other MCP errors are real failures 

1045 logger.debug(f"Health check '{method}' failed with MCP error: {e}") 

1046 self._health_check_failures += 1 

1047 return False 

1048 

1049 except asyncio.TimeoutError: 

1050 logger.debug(f"Health check '{method}' timed out after {self._health_check_timeout}s, trying next") 

1051 continue 

1052 

1053 except Exception as e: 

1054 logger.debug(f"Health check '{method}' failed: {e}") 

1055 self._health_check_failures += 1 

1056 return False 

1057 

1058 # All methods failed or were unsupported 

1059 logger.warning(f"All health check methods failed or unsupported (methods={self._health_check_methods})") 

1060 self._health_check_failures += 1 

1061 return False 

1062 

    async def _create_session(
        self,
        url: str,
        headers: Optional[Dict[str, str]],
        transport_type: TransportType,
        httpx_client_factory: Optional[HttpxClientFactory],
        timeout: Optional[float] = None,
        gateway_id: Optional[str] = None,
    ) -> PooledSession:
        """
        Create a new initialized MCP session.

        Opens the transport (SSE or Streamable HTTP) and the ClientSession via
        direct __aenter__ calls so their lifetimes are owned by the pool rather
        than a lexical `async with` scope; the finally block tears both down on
        any failure path.

        Args:
            url: Server URL.
            headers: Request headers.
            transport_type: Transport type to use.
            httpx_client_factory: Optional factory for httpx clients.
            timeout: Optional timeout in seconds for transport connection.
            gateway_id: Optional gateway ID for notification handler context.

        Returns:
            Initialized PooledSession.

        Raises:
            RuntimeError: If session creation or initialization fails.
            asyncio.CancelledError: If cancelled during creation.
        """
        # Merge headers with defaults
        merged_headers = {"Accept": "application/json, text/event-stream"}
        if headers:
            merged_headers.update(headers)

        # Strip gateway-internal session affinity headers before sending to upstream
        # x-mcp-session-id is our internal representation, mcp-session-id is the MCP protocol header
        # Neither should be forwarded to upstream servers
        keys_to_remove = [k for k in merged_headers if k.lower() in ("x-mcp-session-id", "mcp-session-id")]
        for k in keys_to_remove:
            del merged_headers[k]

        # Identity hash is computed from the ORIGINAL headers (pre-merge) so it
        # matches what _make_pool_key derives for the same request.
        identity_key = self._compute_identity_hash(headers)
        transport_ctx = None
        session = None
        # success gates the finally-block cleanup: only torn down on failure.
        success = False

        try:
            # Create transport context
            if transport_type == TransportType.SSE:
                if httpx_client_factory:
                    transport_ctx = sse_client(url=url, headers=merged_headers, httpx_client_factory=httpx_client_factory, timeout=timeout)
                else:
                    transport_ctx = sse_client(url=url, headers=merged_headers, timeout=timeout)
                # pylint: disable=unnecessary-dunder-call,no-member
                streams = await transport_ctx.__aenter__()  # noqa: PLC2801 - Must call directly for manual lifecycle management
                read_stream, write_stream = streams[0], streams[1]
            else:  # STREAMABLE_HTTP
                if httpx_client_factory:
                    transport_ctx = streamablehttp_client(url=url, headers=merged_headers, httpx_client_factory=httpx_client_factory, timeout=timeout)
                else:
                    transport_ctx = streamablehttp_client(url=url, headers=merged_headers, timeout=timeout)
                # pylint: disable=unnecessary-dunder-call,no-member
                # Streamable HTTP yields a third element (session-id getter) we don't need here.
                read_stream, write_stream, _ = await transport_ctx.__aenter__()  # noqa: PLC2801 - Must call directly for manual lifecycle management

            # Create message handler if factory is configured
            message_handler = None
            if self._message_handler_factory:
                try:
                    message_handler = self._message_handler_factory(url, gateway_id)
                    logger.debug(f"Created message handler for session {sanitize_url_for_logging(url)} (gateway={gateway_id})")
                except Exception as e:
                    # Handler creation is best-effort; the session works without one.
                    logger.warning(f"Failed to create message handler for {sanitize_url_for_logging(url)}: {e}")

            # Create and initialize session
            session = ClientSession(read_stream, write_stream, message_handler=message_handler)
            # pylint: disable=unnecessary-dunder-call
            await session.__aenter__()  # noqa: PLC2801 - Must call directly for manual lifecycle management
            await session.initialize()

            logger.info(f"Created new MCP session for {sanitize_url_for_logging(url)} (transport={transport_type.value})")
            success = True

            return PooledSession(
                session=session,
                transport_context=transport_ctx,
                url=url,
                transport_type=transport_type,
                headers=merged_headers,
                identity_key=identity_key,
                gateway_id=gateway_id or "",
            )

        except asyncio.CancelledError:  # pylint: disable=try-except-raise
            # Re-raise CancelledError after cleanup (handled in finally)
            raise

        except Exception as e:
            raise RuntimeError(f"Failed to create MCP session for {url}: {e}") from e

        finally:
            # Clean up on ANY failure (Exception, CancelledError, etc.)
            # Only clean up if we didn't succeed
            # Use anyio.move_on_after instead of asyncio.wait_for to properly propagate
            # cancellation through anyio's cancel scope system (prevents orphaned spinning tasks)
            if not success:
                cleanup_timeout = _get_cleanup_timeout()
                # Close in reverse order of creation: session first, then transport.
                if session is not None:
                    with anyio.move_on_after(cleanup_timeout):
                        try:
                            await session.__aexit__(None, None, None)  # pylint: disable=unnecessary-dunder-call
                        except Exception:  # nosec B110 - Best effort cleanup on connection failure
                            pass
                if transport_ctx is not None:
                    with anyio.move_on_after(cleanup_timeout):
                        try:
                            await transport_ctx.__aexit__(None, None, None)  # pylint: disable=unnecessary-dunder-call
                        except Exception:  # nosec B110 - Best effort cleanup on connection failure
                            pass

1179 

1180 async def _close_session(self, pooled: PooledSession) -> None: 

1181 """ 

1182 Close a session and its transport. 

1183 

1184 Uses timeouts to prevent indefinite blocking if session/transport tasks 

1185 don't respond to cancellation. This prevents CPU spin loops in anyio's 

1186 _deliver_cancellation which can occur when async iterators or blocking 

1187 operations don't properly handle CancelledError. 

1188 

1189 Args: 

1190 pooled: The session to close. 

1191 """ 

1192 if pooled.is_closed: 

1193 return 

1194 

1195 pooled.mark_closed() 

1196 

1197 # Use anyio's move_on_after instead of asyncio.wait_for to properly propagate 

1198 # cancellation through anyio's cancel scope system. asyncio.wait_for() creates 

1199 # orphaned anyio tasks that keep spinning in _deliver_cancellation. 

1200 cleanup_timeout = _get_cleanup_timeout() 

1201 

1202 # Close session with anyio timeout 

1203 with anyio.move_on_after(cleanup_timeout) as session_scope: 

1204 try: 

1205 await pooled.session.__aexit__(None, None, None) # pylint: disable=unnecessary-dunder-call 

1206 except Exception as e: 

1207 logger.debug(f"Error closing session: {e}") 

1208 if session_scope.cancelled_caught: 

1209 logger.warning(f"Session cleanup timed out for {sanitize_url_for_logging(pooled.url)} - proceeding anyway") 

1210 

1211 # Close transport with anyio timeout 

1212 with anyio.move_on_after(cleanup_timeout) as transport_scope: 

1213 try: 

1214 await pooled.transport_context.__aexit__(None, None, None) # pylint: disable=unnecessary-dunder-call 

1215 except Exception as e: 

1216 logger.debug(f"Error closing transport: {e}") 

1217 if transport_scope.cancelled_caught: 

1218 logger.warning(f"Transport cleanup timed out for {sanitize_url_for_logging(pooled.url)} - proceeding anyway") 

1219 

1220 logger.debug(f"Closed session for {sanitize_url_for_logging(pooled.url)} (uses={pooled.use_count})") 

1221 

1222 # Clean up pool_owner key in Redis for session affinity 

1223 if settings.mcpgateway_session_affinity_enabled and pooled.headers: 

1224 headers_lower = {k.lower(): v for k, v in pooled.headers.items()} 

1225 mcp_session_id = headers_lower.get("x-mcp-session-id") 

1226 if mcp_session_id and self.is_valid_mcp_session_id(mcp_session_id): 

1227 await self._cleanup_pool_session_owner(mcp_session_id) 

1228 

1229 async def _cleanup_pool_session_owner(self, mcp_session_id: str) -> None: 

1230 """Clean up pool_owner key in Redis when session is closed. 

1231 

1232 Only deletes the key if this worker owns it (to prevent removing other workers' ownership). 

1233 

1234 Args: 

1235 mcp_session_id: The MCP session ID from x-mcp-session-id header. 

1236 """ 

1237 try: 

1238 # First-Party 

1239 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel 

1240 

1241 redis = await get_redis_client() 

1242 if redis: 

1243 key = self._pool_owner_key(mcp_session_id) 

1244 # Only delete if we own it 

1245 owner = await redis.get(key) 

1246 if owner: 

1247 owner_id = owner.decode() if isinstance(owner, bytes) else owner 

1248 if owner_id == WORKER_ID: 

1249 await redis.delete(key) 

1250 logger.debug(f"Cleaned up pool session owner: {mcp_session_id[:8]}...") 

1251 except Exception as e: 

1252 # Cleanup failure is non-fatal 

1253 logger.debug(f"Failed to cleanup pool session owner in Redis: {e}") 

1254 

1255 async def close_all(self) -> None: 

1256 """ 

1257 Gracefully close all pooled and active sessions. 

1258 

1259 Should be called during application shutdown. 

1260 """ 

1261 self._closed = True 

1262 logger.info("Closing all pooled sessions...") 

1263 

1264 async with self._global_lock: 

1265 # Close all pooled sessions 

1266 for _pool_key, pool in list(self._pools.items()): 

1267 while not pool.empty(): 

1268 try: 

1269 pooled = pool.get_nowait() 

1270 await self._close_session(pooled) 

1271 except asyncio.QueueEmpty: 

1272 break 

1273 

1274 # Close all active sessions 

1275 for _pool_key, active_set in list(self._active.items()): 

1276 for pooled in list(active_set): 

1277 await self._close_session(pooled) 

1278 

1279 self._pools.clear() 

1280 self._active.clear() 

1281 self._locks.clear() 

1282 self._semaphores.clear() 

1283 

1284 # Stop RPC listener if running 

1285 if self._rpc_listener_task and not self._rpc_listener_task.done(): 

1286 self._rpc_listener_task.cancel() 

1287 try: 

1288 await self._rpc_listener_task 

1289 except asyncio.CancelledError: 

1290 pass 

1291 self._rpc_listener_task = None 

1292 

1293 logger.info("All sessions closed") 

1294 

1295 async def register_pool_session_owner(self, mcp_session_id: str) -> None: 

1296 """Register this worker as owner of a pool session in Redis. 

1297 

1298 This enables multi-worker session affinity by tracking which worker owns 

1299 which pool session. When a request with x-mcp-session-id arrives at a 

1300 different worker, it can forward the request to the owner worker. 

1301 

1302 Note: This method is now primarily used for refreshing TTL on existing ownership. 

1303 Initial ownership is claimed atomically in register_session_mapping() using SETNX. 

1304 

1305 Args: 

1306 mcp_session_id: The MCP session ID from x-mcp-session-id header. 

1307 """ 

1308 if not settings.mcpgateway_session_affinity_enabled: 

1309 return 

1310 

1311 if not self.is_valid_mcp_session_id(mcp_session_id): 

1312 logger.debug("Invalid mcp_session_id for owner registration, skipping") 

1313 return 

1314 

1315 try: 

1316 # First-Party 

1317 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel 

1318 

1319 redis = await get_redis_client() 

1320 if redis: 

1321 key = self._pool_owner_key(mcp_session_id) 

1322 

1323 # Do not steal ownership: only claim if missing, or refresh TTL if we already own. 

1324 # Lua keeps this atomic. 

1325 script = """ 

1326 local cur = redis.call('GET', KEYS[1]) 

1327 if not cur then 

1328 redis.call('SET', KEYS[1], ARGV[1], 'EX', ARGV[2]) 

1329 return 1 

1330 end 

1331 if cur == ARGV[1] then 

1332 redis.call('EXPIRE', KEYS[1], ARGV[2]) 

1333 return 2 

1334 end 

1335 return 0 

1336 """ 

1337 ttl = int(settings.mcpgateway_session_affinity_ttl) 

1338 outcome = await redis.eval(script, 1, key, WORKER_ID, ttl) 

1339 logger.debug(f"Owner registration outcome={outcome} for session {mcp_session_id[:8]}...") 

1340 except Exception as e: 

1341 # Redis failure is non-fatal - single worker mode still works 

1342 logger.debug(f"Failed to register pool session owner in Redis: {e}") 

1343 

1344 async def _get_pool_session_owner(self, mcp_session_id: str) -> Optional[str]: 

1345 """Get the worker ID that owns a pool session. 

1346 

1347 Args: 

1348 mcp_session_id: The MCP session ID from x-mcp-session-id header. 

1349 

1350 Returns: 

1351 The worker ID that owns this session, or None if not found or Redis unavailable. 

1352 """ 

1353 if not settings.mcpgateway_session_affinity_enabled: 

1354 return None 

1355 

1356 if not self.is_valid_mcp_session_id(mcp_session_id): 

1357 return None 

1358 

1359 try: 

1360 # First-Party 

1361 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel 

1362 

1363 redis = await get_redis_client() 

1364 if redis: 

1365 key = self._pool_owner_key(mcp_session_id) 

1366 owner = await redis.get(key) 

1367 if owner: 

1368 decoded = owner.decode() if isinstance(owner, bytes) else owner 

1369 return decoded 

1370 except Exception as e: 

1371 logger.debug(f"Failed to get pool session owner from Redis: {e}") 

1372 return None 

1373 

1374 async def forward_request_to_owner( 

1375 self, 

1376 mcp_session_id: str, 

1377 request_data: Dict[str, Any], 

1378 timeout: Optional[float] = None, 

1379 ) -> Optional[Dict[str, Any]]: 

1380 """Forward RPC request to the worker that owns the pool session. 

1381 

1382 This method checks Redis to find which worker owns the pool session for 

1383 the given mcp_session_id. If owned by another worker, it forwards the 

1384 request via Redis pub/sub and waits for the response. 

1385 

1386 Args: 

1387 mcp_session_id: The MCP session ID from x-mcp-session-id header. 

1388 request_data: The RPC request data to forward. 

1389 timeout: Optional timeout in seconds (default from config). 

1390 

1391 Returns: 

1392 The response from the owner worker, or None if we own the session 

1393 (caller should execute locally) or if Redis is unavailable. 

1394 

1395 Raises: 

1396 asyncio.TimeoutError: If the forwarded request times out. 

1397 """ 

1398 if not settings.mcpgateway_session_affinity_enabled: 

1399 return None 

1400 

1401 if not self.is_valid_mcp_session_id(mcp_session_id): 

1402 return None 

1403 

1404 effective_timeout = timeout if timeout is not None else settings.mcpgateway_pool_rpc_forward_timeout 

1405 

1406 try: 

1407 # First-Party 

1408 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel 

1409 

1410 redis = await get_redis_client() 

1411 if not redis: 

1412 return None # Execute locally - no Redis 

1413 

1414 # Check who owns this session 

1415 owner = await redis.get(self._pool_owner_key(mcp_session_id)) 

1416 method = request_data.get("method", "unknown") 

1417 if not owner: 

1418 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | No owner → execute locally (new session)") 

1419 return None # No owner registered - execute locally (new session) 

1420 

1421 owner_id = owner.decode() if isinstance(owner, bytes) else owner 

1422 if owner_id == WORKER_ID: 

1423 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | We own it → execute locally") 

1424 return None # We own it - execute locally 

1425 

1426 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | Owner: {owner_id} → forwarding") 

1427 

1428 # Forward to owner worker via pub/sub 

1429 response_id = str(uuid.uuid4()) 

1430 response_channel = f"mcpgw:pool_rpc_response:{response_id}" 

1431 

1432 # Subscribe to response channel 

1433 pubsub = redis.pubsub() 

1434 await pubsub.subscribe(response_channel) 

1435 

1436 try: 

1437 # Prepare request with response channel 

1438 forward_data = { 

1439 "type": "rpc_forward", 

1440 **request_data, 

1441 "response_channel": response_channel, 

1442 "mcp_session_id": mcp_session_id, 

1443 } 

1444 

1445 # Publish request to owner's channel 

1446 await redis.publish(f"mcpgw:pool_rpc:{owner_id}", orjson.dumps(forward_data)) 

1447 self._forwarded_requests += 1 

1448 logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Method: {method} | Published to worker {owner_id}") 

1449 

1450 # Wait for response 

1451 async with asyncio.timeout(effective_timeout): 

1452 while True: 

1453 msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1) 

1454 if msg and msg["type"] == "message": 

1455 return orjson.loads(msg["data"]) 

1456 finally: 

1457 await pubsub.unsubscribe(response_channel) 

1458 

1459 except asyncio.TimeoutError: 

1460 self._forwarded_request_timeouts += 1 

1461 logger.warning(f"Timeout forwarding request to owner for session {mcp_session_id[:8]}...") 

1462 raise 

1463 except Exception as e: 

1464 self._forwarded_request_failures += 1 

1465 logger.debug(f"Error forwarding request to owner: {e}") 

1466 return None # Execute locally on error 

1467 

1468 async def start_rpc_listener(self) -> None: 

1469 """Start listening for forwarded RPC and HTTP requests on this worker's channels. 

1470 

1471 This method subscribes to Redis pub/sub channels specific to this worker 

1472 and processes incoming forwarded requests from other workers: 

1473 - mcpgw:pool_rpc:{WORKER_ID} - for SSE transport JSON-RPC forwards 

1474 - mcpgw:pool_http:{WORKER_ID} - for Streamable HTTP request forwards 

1475 """ 

1476 if not settings.mcpgateway_session_affinity_enabled: 

1477 return 

1478 

1479 try: 

1480 # First-Party 

1481 from mcpgateway.utils.redis_client import get_redis_client # pylint: disable=import-outside-toplevel 

1482 

1483 redis = await get_redis_client() 

1484 if not redis: 

1485 logger.debug("Redis not available, RPC listener not started") 

1486 return 

1487 

1488 rpc_channel = f"mcpgw:pool_rpc:{WORKER_ID}" 

1489 http_channel = f"mcpgw:pool_http:{WORKER_ID}" 

1490 pubsub = redis.pubsub() 

1491 await pubsub.subscribe(rpc_channel, http_channel) 

1492 logger.info(f"RPC/HTTP listener started for worker {WORKER_ID} on channels: {rpc_channel}, {http_channel}") 

1493 

1494 try: 

1495 while not self._closed: 

1496 try: 

1497 msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0) 

1498 if msg and msg["type"] == "message": 

1499 request = orjson.loads(msg["data"]) 

1500 forward_type = request.get("type") 

1501 response_channel = request.get("response_channel") 

1502 

1503 if response_channel: 

1504 if forward_type == "rpc_forward": 

1505 # Execute forwarded RPC request for SSE transport 

1506 response = await self._execute_forwarded_request(request) 

1507 await redis.publish(response_channel, orjson.dumps(response)) 

1508 logger.debug(f"Processed forwarded RPC request, response sent to {response_channel}") 

1509 elif forward_type == "http_forward": 

1510 # Execute forwarded HTTP request for Streamable HTTP transport 

1511 await self._execute_forwarded_http_request(request, redis) 

1512 else: 

1513 logger.warning(f"Unknown forward type: {forward_type}") 

1514 except Exception as e: 

1515 logger.warning(f"Error processing forwarded request: {e}") 

1516 finally: 

1517 await pubsub.unsubscribe(rpc_channel, http_channel) 

1518 logger.info(f"RPC/HTTP listener stopped for worker {WORKER_ID}") 

1519 

1520 except Exception as e: 

1521 logger.warning(f"RPC/HTTP listener failed: {e}") 

1522 

async def _execute_forwarded_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
    """Run a forwarded RPC request on this worker via the local /rpc endpoint.

    Requests land here when another worker determined this worker owns the SSE
    session. Instead of duplicating per-method dispatch, the request is replayed
    against the local ``/rpc`` route so all existing handler logic is reused.
    The ``x-forwarded-internally`` header marks the replay so it is never
    forwarded a second time (prevents infinite forwarding loops).

    Args:
        request: The forwarded RPC request containing method, params, headers, req_id, etc.

    Returns:
        The JSON-RPC response from the local endpoint.
    """
    try:
        method = request.get("method")
        params = request.get("params", {})
        original_headers = request.get("headers", {})
        req_id = request.get("req_id", 1)
        mcp_session_id = request.get("mcp_session_id", "unknown")
        session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id

        logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Received forwarded request, executing locally")

        # Replay headers: keep the originals, add the loop guard, and force a
        # JSON content type for the internal POST.
        replay_headers = dict(original_headers)
        replay_headers["x-forwarded-internally"] = "true"
        replay_headers["content-type"] = "application/json"

        rpc_body = {"jsonrpc": "2.0", "method": method, "params": params, "id": req_id}

        # Internal HTTP call to the local /rpc endpoint reuses ALL existing
        # method handling logic without duplication.
        async with httpx.AsyncClient() as client:
            rpc_response = await client.post(
                f"http://127.0.0.1:{settings.port}/rpc",
                json=rpc_body,
                headers=replay_headers,
                timeout=settings.mcpgateway_pool_rpc_forward_timeout,
            )
            payload = rpc_response.json()

            # Translate the JSON-RPC envelope into the forwarding protocol's
            # {"result": ...} / {"error": ...} shape.
            if "error" in payload:
                logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded execution completed with error")
                return {"error": payload["error"]}
            logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded execution completed successfully")
            return {"result": payload.get("result", {})}

    except httpx.TimeoutException:
        logger.warning(f"Timeout executing forwarded request: {request.get('method')}")
        return {"error": {"code": -32603, "message": "Internal request timeout"}}
    except Exception as e:  # pylint: disable=broad-except
        logger.warning(f"Error executing forwarded request: {e}")
        return {"error": {"code": -32603, "message": str(e)}}

1581 

async def _execute_forwarded_http_request(self, request: Dict[str, Any], redis: Any) -> None:
    """Execute a forwarded HTTP request locally and return response via Redis.

    This method handles full HTTP requests forwarded from other workers for
    Streamable HTTP transport session affinity. It reconstructs the HTTP request,
    makes an internal call to the appropriate endpoint, and publishes the response
    back through Redis.

    Note:
        The response body is fully buffered before publishing, so streaming
        responses are delivered as a single payload to the requesting worker.
        Errors never propagate to the caller; a 500 response is published on
        the response channel instead (best effort).

    Args:
        request: Serialized HTTP request data from Redis Pub/Sub containing:
            - type: "http_forward"
            - response_channel: Redis channel to publish response to
            - mcp_session_id: Session identifier
            - method: HTTP method (GET, POST, DELETE)
            - path: Request path (e.g., /mcp)
            - query_string: Query parameters
            - headers: Request headers dict
            - body: Hex-encoded request body
        redis: Redis client for publishing response
    """
    # Bind before the try block so the except clause can still publish an
    # error response even if extraction of the other fields fails.
    response_channel = None
    try:
        response_channel = request.get("response_channel")
        method = request.get("method")
        path = request.get("path")
        query_string = request.get("query_string", "")
        headers = request.get("headers", {})
        body_hex = request.get("body", "")
        mcp_session_id = request.get("mcp_session_id")

        # Decode hex body back to bytes (hex keeps binary safe in JSON transport)
        body = bytes.fromhex(body_hex) if body_hex else b""

        # Truncated session id for log lines only — never used for routing
        session_short = mcp_session_id[:8] if mcp_session_id and len(mcp_session_id) >= 8 else "unknown"
        logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Received forwarded HTTP request: {method} {path}")

        # Add internal forwarding headers to prevent loops
        internal_headers = dict(headers)
        internal_headers["x-forwarded-internally"] = "true"
        internal_headers["x-original-worker"] = request.get("original_worker", "unknown")

        # Make internal HTTP request to local endpoint (loopback only)
        url = f"http://127.0.0.1:{settings.port}{path}"
        if query_string:
            url = f"{url}?{query_string}"

        async with httpx.AsyncClient() as client:
            response = await client.request(
                method=method,
                url=url,
                headers=internal_headers,
                content=body,
                timeout=settings.mcpgateway_pool_rpc_forward_timeout,
            )

            logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Executed locally: {response.status_code}")

            # Serialize response for Redis transport
            response_data = {
                "status": response.status_code,
                "headers": dict(response.headers),
                "body": response.content.hex(),  # Hex encode binary response
            }

            # Publish response back to requesting worker
            if redis and response_channel:
                await redis.publish(response_channel, orjson.dumps(response_data))
                logger.debug(f"[HTTP_AFFINITY] Published HTTP response to Redis channel: {response_channel}")

    except Exception as e:
        logger.error(f"Error executing forwarded HTTP request: {e}")
        # Try to send error response if possible so the requesting worker
        # gets a 500 instead of waiting for its forwarding timeout.
        if redis and response_channel:
            error_response = {
                "status": 500,
                "headers": {"content-type": "application/json"},
                "body": orjson.dumps({"error": "Internal forwarding error"}).hex(),
            }
            try:
                await redis.publish(response_channel, orjson.dumps(error_response))
            except Exception as publish_error:
                logger.debug(f"Failed to publish error response via Redis: {publish_error}")

1664 

async def get_streamable_http_session_owner(self, mcp_session_id: str) -> Optional[str]:
    """Resolve which worker currently owns a Streamable HTTP session.

    Thin public facade over ``_get_pool_session_owner``, exposed so the
    streamablehttp transport can check ownership before handling a request.

    Args:
        mcp_session_id: Session identifier from the mcp-session-id header.

    Returns:
        The owning worker's ID, or None when no owner is recorded.
    """
    owner = await self._get_pool_session_owner(mcp_session_id)
    return owner

1678 

async def forward_streamable_http_to_owner(
    self,
    owner_worker_id: str,
    mcp_session_id: str,
    method: str,
    path: str,
    headers: Dict[str, str],
    body: bytes,
    query_string: str = "",
) -> Optional[Dict[str, Any]]:
    """Forward a Streamable HTTP request to the worker that owns the session via Redis Pub/Sub.

    This method forwards the entire HTTP request to another worker using Redis
    Pub/Sub channels, similar to forward_request_to_owner() for SSE transport.
    This ensures session affinity works correctly in single-host multi-worker
    deployments where hostname-based routing fails.

    Args:
        owner_worker_id: The worker ID that owns the session.
        mcp_session_id: The MCP session ID.
        method: HTTP method (GET, POST, DELETE).
        path: Request path (e.g., /mcp).
        headers: Request headers.
        body: Request body bytes.
        query_string: Query string if any.

    Returns:
        Dict with 'status', 'headers', and 'body' from the owner worker's response,
        or None if forwarding fails (caller falls back to local execution).
    """
    if not settings.mcpgateway_session_affinity_enabled:
        return None

    # Reject malformed session IDs before they reach Redis channel names.
    if not self.is_valid_mcp_session_id(mcp_session_id):
        return None

    session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id
    logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {session_short}... | {method} {path} | Forwarding to worker {owner_worker_id}")

    try:
        # First-Party
        from mcpgateway.utils.redis_client import get_redis_client  # pylint: disable=import-outside-toplevel

        redis = await get_redis_client()
        if not redis:
            logger.warning("Redis unavailable for HTTP forwarding, executing locally")
            return None  # Fall back to local execution

        # Generate unique response channel for this request
        response_uuid = uuid.uuid4().hex
        response_channel = f"mcpgw:pool_http_response:{response_uuid}"

        # Serialize HTTP request for Redis transport
        forward_data = {
            "type": "http_forward",
            "response_channel": response_channel,
            "mcp_session_id": mcp_session_id,
            "method": method,
            "path": path,
            "query_string": query_string,
            "headers": headers,
            "body": body.hex() if body else "",  # Hex encode binary body
            "original_worker": WORKER_ID,
            "timestamp": time.time(),
        }

        # Subscribe to response channel BEFORE publishing request (prevent race)
        pubsub = redis.pubsub()
        await pubsub.subscribe(response_channel)

        try:
            # Publish forwarded request to owner worker's HTTP channel
            owner_channel = f"mcpgw:pool_http:{owner_worker_id}"
            await redis.publish(owner_channel, orjson.dumps(forward_data))
            logger.debug(f"[HTTP_AFFINITY] Published HTTP request to Redis channel: {owner_channel}")

            # Wait for response with timeout
            timeout = settings.mcpgateway_pool_rpc_forward_timeout
            async with asyncio.timeout(timeout):
                while True:
                    msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
                    if msg and msg["type"] == "message":
                        response_data = orjson.loads(msg["data"])
                        logger.debug(f"[HTTP_AFFINITY] Received HTTP response via Redis: status={response_data.get('status')}")

                        # Decode hex body back to bytes
                        body_hex = response_data.get("body", "")
                        response_data["body"] = bytes.fromhex(body_hex) if body_hex else b""

                        self._forwarded_requests += 1
                        return response_data

        finally:
            await pubsub.unsubscribe(response_channel)
            # BUGFIX: unsubscribe() alone does not release the PubSub's
            # dedicated connection back to the pool, leaking one Redis
            # connection per forwarded request. Close it explicitly; guarded
            # so a broken connection cannot mask a pending TimeoutError.
            try:
                await pubsub.close()
            except Exception as close_error:  # pylint: disable=broad-except
                logger.debug(f"Error closing pubsub after HTTP forward: {close_error}")

    except asyncio.TimeoutError:
        self._forwarded_request_timeouts += 1
        logger.warning(f"Timeout forwarding HTTP request to owner {owner_worker_id}")
        return None
    except Exception as e:
        self._forwarded_request_failures += 1
        logger.warning(f"Error forwarding HTTP request via Redis: {e}")
        return None

1782 

def get_metrics(self) -> Dict[str, Any]:
    """Snapshot pool counters for monitoring.

    Returns:
        Dict with hit/miss/eviction counters, derived hit rates, session
        affinity statistics, per-pool occupancy, and circuit breaker state.
    """
    lookups = self._hits + self._misses
    affinity_hits = self._session_affinity_local_hits + self._session_affinity_redis_hits
    affinity_lookups = affinity_hits + self._session_affinity_misses

    # Per-pool occupancy keyed by a human-readable composite label.
    pool_stats: Dict[str, Any] = {}
    for pool_key, pool in self._pools.items():
        user, url, identity, transport, gw_id = pool_key
        label = f"{url}|{identity[:8]}|{transport}|{user}|{gw_id[:8] if gw_id else 'none'}"
        pool_stats[label] = {
            "available": pool.qsize(),
            "active": len(self._active.get(pool_key, set())),
            "max": self._max_sessions,
        }

    # Union of URLs with recorded failures or an open breaker window.
    breaker_urls = set(self._failures.keys()) | set(self._circuit_open_until.keys())
    breaker_stats = {
        url: {
            "failures": self._failures.get(url, 0),
            "open_until": self._circuit_open_until.get(url),
        }
        for url in breaker_urls
    }

    return {
        "hits": self._hits,
        "misses": self._misses,
        "evictions": self._evictions,
        "health_check_failures": self._health_check_failures,
        "circuit_breaker_trips": self._circuit_breaker_trips,
        "pool_keys_evicted": self._pool_keys_evicted,
        "sessions_reaped": self._sessions_reaped,
        "anonymous_identity_count": self._anonymous_identity_count,
        "hit_rate": self._hits / lookups if lookups > 0 else 0.0,
        "pool_key_count": len(self._pools),
        # Session affinity metrics
        "session_affinity": {
            "local_hits": self._session_affinity_local_hits,
            "redis_hits": self._session_affinity_redis_hits,
            "misses": self._session_affinity_misses,
            "hit_rate": affinity_hits / affinity_lookups if affinity_lookups > 0 else 0.0,
            "forwarded_requests": self._forwarded_requests,
            "forwarded_failures": self._forwarded_request_failures,
            "forwarded_timeouts": self._forwarded_request_timeouts,
        },
        "pools": pool_stats,
        "circuit_breakers": breaker_stats,
    }

1829 

@asynccontextmanager
async def session(
    self,
    url: str,
    headers: Optional[Dict[str, str]] = None,
    transport_type: TransportType = TransportType.STREAMABLE_HTTP,
    httpx_client_factory: Optional[HttpxClientFactory] = None,
    timeout: Optional[float] = None,
    user_identity: Optional[str] = None,
    gateway_id: Optional[str] = None,
) -> "AsyncIterator[PooledSession]":
    """Check a session out of the pool for the duration of a ``with`` block.

    The session is obtained via ``acquire()`` on entry and always handed
    back via ``release()`` on exit, even when the body raises.

    Usage:
        async with pool.session(url, headers) as pooled:
            result = await pooled.session.call_tool("my_tool", {})

    Args:
        url: The MCP server URL.
        headers: Request headers.
        transport_type: Transport type to use.
        httpx_client_factory: Optional factory for httpx clients.
        timeout: Optional timeout in seconds for transport connection.
        user_identity: Optional user identity for strict isolation.
        gateway_id: Optional gateway ID for notification handler context.

    Yields:
        PooledSession ready for use.
    """
    leased = await self.acquire(url, headers, transport_type, httpx_client_factory, timeout, user_identity, gateway_id)
    try:
        yield leased
    finally:
        await self.release(leased)

1865 

1866 

# Process-wide singleton, created by init_mcp_session_pool() from the FastAPI lifespan
_mcp_session_pool: Optional[MCPSessionPool] = None


def get_mcp_session_pool() -> MCPSessionPool:
    """Return the process-wide MCP session pool.

    Returns:
        The global MCPSessionPool instance.

    Raises:
        RuntimeError: If pool has not been initialized.
    """
    pool = _mcp_session_pool
    if pool is None:
        raise RuntimeError("MCP session pool not initialized. Call init_mcp_session_pool() first.")
    return pool

1883 

1884 

def init_mcp_session_pool(
    max_sessions_per_key: int = 10,
    session_ttl_seconds: float = 300.0,
    health_check_interval_seconds: float = 60.0,
    acquire_timeout_seconds: float = 30.0,
    session_create_timeout_seconds: float = 30.0,
    circuit_breaker_threshold: int = 5,
    circuit_breaker_reset_seconds: float = 60.0,
    identity_headers: Optional[frozenset[str]] = None,
    identity_extractor: Optional[IdentityExtractor] = None,
    idle_pool_eviction_seconds: float = 600.0,
    default_transport_timeout_seconds: float = 30.0,
    health_check_methods: Optional[list[str]] = None,
    health_check_timeout_seconds: float = 5.0,
    message_handler_factory: Optional[MessageHandlerFactory] = None,
    enable_notifications: bool = True,
    notification_debounce_seconds: float = 5.0,
) -> MCPSessionPool:
    """Create and install the global MCP session pool singleton.

    Args:
        See MCPSessionPool.__init__ for argument descriptions.
        enable_notifications: Enable automatic notification service for list_changed events.
        notification_debounce_seconds: Debounce interval for notification-triggered refreshes.

    Returns:
        The initialized MCPSessionPool instance.
    """
    global _mcp_session_pool  # pylint: disable=global-statement

    handler_factory = message_handler_factory
    # When the caller supplied no handler, wire up the built-in notification
    # service so list_changed events can trigger refreshes. The service is
    # started later, during acquire, once gateway context is available.
    if handler_factory is None and enable_notifications:
        # First-Party
        from mcpgateway.services.notification_service import (  # pylint: disable=import-outside-toplevel
            init_notification_service,
        )

        notification_svc = init_notification_service(debounce_seconds=notification_debounce_seconds)

        def _default_handler_factory(url: str, gateway_id: Optional[str]):
            """Build a message handler that routes session notifications to the service.

            Args:
                url: The MCP server URL for the session.
                gateway_id: Optional gateway ID for attribution; falls back to the URL.

            Returns:
                A message handler bound to the notification service.
            """
            return notification_svc.create_message_handler(gateway_id or url, url)

        handler_factory = _default_handler_factory
        logger.info("MCP notification service created (debounce=%ss)", notification_debounce_seconds)

    _mcp_session_pool = MCPSessionPool(
        max_sessions_per_key=max_sessions_per_key,
        session_ttl_seconds=session_ttl_seconds,
        health_check_interval_seconds=health_check_interval_seconds,
        acquire_timeout_seconds=acquire_timeout_seconds,
        session_create_timeout_seconds=session_create_timeout_seconds,
        circuit_breaker_threshold=circuit_breaker_threshold,
        circuit_breaker_reset_seconds=circuit_breaker_reset_seconds,
        identity_headers=identity_headers,
        identity_extractor=identity_extractor,
        idle_pool_eviction_seconds=idle_pool_eviction_seconds,
        default_transport_timeout_seconds=default_transport_timeout_seconds,
        health_check_methods=health_check_methods,
        health_check_timeout_seconds=health_check_timeout_seconds,
        message_handler_factory=handler_factory,
    )
    logger.info("MCP session pool initialized")
    return _mcp_session_pool

1960 

1961 

async def close_mcp_session_pool() -> None:
    """Tear down the global MCP session pool and its notification service."""
    global _mcp_session_pool  # pylint: disable=global-statement

    pool = _mcp_session_pool
    if pool is not None:
        # Drain and close every pooled session before dropping the singleton.
        await pool.close_all()
        _mcp_session_pool = None
        logger.info("MCP session pool closed")

    # Shut down the notification service as well, if it was ever created.
    try:
        # First-Party
        from mcpgateway.services.notification_service import (  # pylint: disable=import-outside-toplevel
            close_notification_service,
        )

        await close_notification_service()
    except (ImportError, RuntimeError):
        pass  # Notification service not initialized

1980 

1981 

async def start_pool_notification_service(gateway_service: Any = None) -> None:
    """Start the notification service background worker.

    Call this after gateway_service is initialized to enable event-driven refresh.

    Args:
        gateway_service: Optional GatewayService instance for triggering refreshes.
    """
    try:
        # First-Party
        from mcpgateway.services.notification_service import (  # pylint: disable=import-outside-toplevel
            get_notification_service,
        )

        notification_svc = get_notification_service()
        await notification_svc.initialize(gateway_service)
        logger.info("MCP notification service started")
    except (ImportError, RuntimeError):
        # RuntimeError: service was never initialized. ImportError: optional
        # notification module absent. Treat both as "not configured" for
        # consistency with close_mcp_session_pool().
        logger.debug("Notification service not configured, skipping start")

2001 

2002 

def register_gateway_capabilities_for_notifications(gateway_id: str, capabilities: Dict[str, Any]) -> None:
    """Register gateway capabilities for notification handling.

    Call this after gateway initialization to enable list_changed notifications.
    Best-effort: silently a no-op when the notification service is unavailable.

    Args:
        gateway_id: The gateway ID.
        capabilities: Server capabilities from initialization response.
    """
    try:
        # First-Party
        from mcpgateway.services.notification_service import (  # pylint: disable=import-outside-toplevel
            get_notification_service,
        )

        notification_svc = get_notification_service()
        notification_svc.register_gateway_capabilities(gateway_id, capabilities)
    except (ImportError, RuntimeError):
        # Tolerate a missing notification module (ImportError) the same as an
        # uninitialized service (RuntimeError), matching close_mcp_session_pool().
        pass

2022 

2023 

def unregister_gateway_from_notifications(gateway_id: str) -> None:
    """Unregister a gateway from notification handling.

    Call this when a gateway is deleted. Best-effort: silently a no-op when
    the notification service is unavailable.

    Args:
        gateway_id: The gateway ID to unregister.
    """
    try:
        # First-Party
        from mcpgateway.services.notification_service import (  # pylint: disable=import-outside-toplevel
            get_notification_service,
        )

        notification_svc = get_notification_service()
        notification_svc.unregister_gateway(gateway_id)
    except (ImportError, RuntimeError):
        # Tolerate a missing notification module (ImportError) the same as an
        # uninitialized service (RuntimeError), matching close_mcp_session_pool().
        pass