Coverage for mcpgateway/cache/session_registry.py: 100%

758 statements  

coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/cache/session_registry.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Session Registry with optional distributed state. 

8This module provides a registry for SSE sessions with support for distributed deployment 

9using Redis or SQLAlchemy as optional backends for shared state between workers. 

10 

11The SessionRegistry class manages server-sent event (SSE) sessions across multiple 

12worker processes, enabling horizontal scaling of MCP gateway deployments. It supports 

13three backend modes: 

14 

15- **memory**: In-memory storage for single-process deployments (default) 

16- **redis**: Redis-backed shared storage for multi-worker deployments 

17- **database**: SQLAlchemy-backed shared storage using any supported database 

18 

19In distributed mode (redis/database), session existence is tracked in the shared 

20backend while transport objects remain local to each worker process. This allows 

21workers to know about sessions on other workers and route messages appropriately. 

22 

23Examples: 

24 Basic usage with memory backend: 

25 

26 >>> from mcpgateway.cache.session_registry import SessionRegistry 

27 >>> class DummyTransport: 

28 ... async def disconnect(self): 

29 ... pass 

30 ... async def is_connected(self): 

31 ... return True 

32 >>> import asyncio 

33 >>> reg = SessionRegistry(backend='memory') 

34 >>> transport = DummyTransport() 

35 >>> asyncio.run(reg.add_session('sid123', transport)) 

36 >>> found = asyncio.run(reg.get_session('sid123')) 

37 >>> isinstance(found, DummyTransport) 

38 True 

39 >>> asyncio.run(reg.remove_session('sid123')) 

40 >>> asyncio.run(reg.get_session('sid123')) is None 

41 True 

42 

43 Broadcasting messages: 

44 

45 >>> reg = SessionRegistry(backend='memory') 

46 >>> asyncio.run(reg.broadcast('sid123', {'method': 'ping', 'id': 1})) 

47 >>> reg._session_message is not None 

48 True 

49""" 

50 

51# Standard 

52import asyncio 

53from asyncio import Task 

54from datetime import datetime, timedelta, timezone 

55import logging 

56import time 

57import traceback 

58from typing import Any, Dict, Optional 

59from urllib.parse import urlparse 

60import uuid 

61 

62# Third-Party 

63from fastapi import HTTPException, status 

64import orjson 

65 

66# First-Party 

67from mcpgateway import __version__ 

68from mcpgateway.common.models import Implementation, InitializeResult, ServerCapabilities 

69from mcpgateway.config import settings 

70from mcpgateway.db import get_db, SessionMessageRecord, SessionRecord 

71from mcpgateway.services import PromptService, ResourceService, ToolService 

72from mcpgateway.services.logging_service import LoggingService 

73from mcpgateway.transports import SSETransport 

74from mcpgateway.utils.create_jwt_token import create_jwt_token 

75from mcpgateway.utils.redis_client import get_redis_client 

76from mcpgateway.utils.retry_manager import ResilientHttpClient 

77from mcpgateway.validation.jsonrpc import JSONRPCError 

78 

79# Initialize logging service first 

80logging_service: LoggingService = LoggingService() 

81logger = logging_service.get_logger(__name__) 

82 

83tool_service: ToolService = ToolService() 

84resource_service: ResourceService = ResourceService() 

85prompt_service: PromptService = PromptService() 

86 

87try: 

88 # Third-Party 

89 from redis.asyncio import Redis 

90 

91 REDIS_AVAILABLE = True 

92except ImportError: 

93 REDIS_AVAILABLE = False 

94 

95try: 

96 # Third-Party 

97 from sqlalchemy import func 

98 

99 SQLALCHEMY_AVAILABLE = True 

100except ImportError: 

101 SQLALCHEMY_AVAILABLE = False 

102 

103 

104class SessionBackend: 

105 """Base class for session registry backend configuration. 

106 

107 This class handles the initialization and configuration of different backend 

108 types for session storage. It validates backend requirements and sets up 

109 necessary connections for Redis or database backends. 

110 

111 Attributes: 

112 _backend: The backend type ('memory', 'redis', 'database', or 'none') 

113 _session_ttl: Time-to-live for sessions in seconds 

114 _message_ttl: Time-to-live for messages in seconds 

115 _redis: Redis connection instance (redis backend only) 

116 _pubsub: Redis pubsub instance (redis backend only) 

117 _session_message: Temporary message storage (memory backend only) 

118 

119 Examples: 

120 >>> backend = SessionBackend(backend='memory') 

121 >>> backend._backend 

122 'memory' 

123 >>> backend._session_ttl 

124 3600 

125 

126 >>> try: 

127 ... backend = SessionBackend(backend='redis') 

128 ... except ValueError as e: 

129 ... str(e) 

130 'Redis backend requires redis_url' 

131 """ 

132 

133 def __init__( 

134 self, 

135 backend: str = "memory", 

136 redis_url: Optional[str] = None, 

137 database_url: Optional[str] = None, 

138 session_ttl: int = 3600, # 1 hour 

139 message_ttl: int = 600, # 10 min 

140 ): 

141 """Initialize session backend configuration. 

142 

143 Args: 

144 backend: Backend type. Must be one of 'memory', 'redis', 'database', or 'none'. 

145 - 'memory': In-memory storage, suitable for single-process deployments 

146 - 'redis': Redis-backed storage for multi-worker deployments 

147 - 'database': SQLAlchemy-backed storage for multi-worker deployments 

148 - 'none': No session tracking (dummy registry) 

149 redis_url: Redis connection URL. Required when backend='redis'. 

150 Format: 'redis://[:password]@host:port/db' 

151 database_url: Database connection URL. Required when backend='database'. 

152 Format depends on database type (e.g., 'postgresql://user:pass@host/db') 

153 session_ttl: Session time-to-live in seconds. Sessions are automatically 

154 cleaned up after this duration of inactivity. Default: 3600 (1 hour). 

155 message_ttl: Message time-to-live in seconds. Undelivered messages are 

156 removed after this duration. Default: 600 (10 minutes). 

157 

158 Raises: 

159 ValueError: If backend is invalid, required URL is missing, or required packages are not installed. 

160 

161 Examples: 

162 >>> # Memory backend (default) 

163 >>> backend = SessionBackend() 

164 >>> backend._backend 

165 'memory' 

166 

167 >>> # Redis backend requires URL 

168 >>> try: 

169 ... backend = SessionBackend(backend='redis') 

170 ... except ValueError as e: 

171 ... 'redis_url' in str(e) 

172 True 

173 

174 >>> # Invalid backend 

175 >>> try: 

176 ... backend = SessionBackend(backend='invalid') 

177 ... except ValueError as e: 

178 ... 'Invalid backend' in str(e) 

179 True 

180 """ 

181 

182 self._backend = backend.lower() 

183 self._session_ttl = session_ttl 

184 self._message_ttl = message_ttl 

185 

186 # Set up backend-specific components 

187 if self._backend == "memory": 

188 # Nothing special needed for memory backend 

189 self._session_message: dict[str, Any] | None = None 

190 

191 elif self._backend == "none": 

192 # No session tracking - this is just a dummy registry 

193 logger.info("Session registry initialized with 'none' backend - session tracking disabled") 

194 

195 elif self._backend == "redis": 

196 if not REDIS_AVAILABLE: 

197 raise ValueError("Redis backend requested but redis package not installed") 

198 if not redis_url: 

199 raise ValueError("Redis backend requires redis_url") 

200 

201 # Redis client is set in initialize() via the shared factory 

202 self._redis: Optional[Redis] = None 

203 self._pubsub = None 

204 

205 elif self._backend == "database": 

206 if not SQLALCHEMY_AVAILABLE: 

207 raise ValueError("Database backend requested but SQLAlchemy not installed") 

208 if not database_url: 

209 raise ValueError("Database backend requires database_url") 

210 else: 

211 raise ValueError(f"Invalid backend: {backend}") 

212 

213 

214class SessionRegistry(SessionBackend): 

215 """Registry for SSE sessions with optional distributed state. 

216 

217 This class manages server-sent event (SSE) sessions, providing methods to add, 

218 remove, and query sessions. It supports multiple backend types for different 

219 deployment scenarios: 

220 

221 - **Single-process deployments**: Use 'memory' backend (default) 

222 - **Multi-worker deployments**: Use 'redis' or 'database' backend 

223 - **Testing/development**: Use 'none' backend to disable session tracking 

224 

225 The registry maintains a local cache of transport objects while using the 

226 shared backend to track session existence across workers. This enables 

227 horizontal scaling while keeping transport objects process-local. 

228 

229 Attributes: 

230 _sessions: Local dictionary mapping session IDs to transport objects 

231 _lock: Asyncio lock for thread-safe access to _sessions 

232 _cleanup_task: Background task for cleaning up expired sessions 

233 

234 Examples: 

235 >>> import asyncio 

236 >>> from mcpgateway.cache.session_registry import SessionRegistry 

237 >>> 

238 >>> class MockTransport: 

239 ... async def disconnect(self): 

240 ... print("Disconnected") 

241 ... async def is_connected(self): 

242 ... return True 

243 ... async def send_message(self, msg): 

244 ... print(f"Sent: {msg}") 

245 >>> 

246 >>> # Create registry and add session 

247 >>> reg = SessionRegistry(backend='memory') 

248 >>> transport = MockTransport() 

249 >>> asyncio.run(reg.add_session('test123', transport)) 

250 >>> 

251 >>> # Retrieve session 

252 >>> found = asyncio.run(reg.get_session('test123')) 

253 >>> found is transport 

254 True 

255 >>> 

256 >>> # Remove session 

257 >>> asyncio.run(reg.remove_session('test123')) 

258 Disconnected 

259 >>> asyncio.run(reg.get_session('test123')) is None 

260 True 

261 """ 

262 

263 def __init__( 

264 self, 

265 backend: str = "memory", 

266 redis_url: Optional[str] = None, 

267 database_url: Optional[str] = None, 

268 session_ttl: int = 3600, # 1 hour 

269 message_ttl: int = 600, # 10 min 

270 ): 

271 """Initialize session registry with specified backend. 

272 

273 Args: 

274 backend: Backend type. Must be one of 'memory', 'redis', 'database', or 'none'. 

275 redis_url: Redis connection URL. Required when backend='redis'. 

276 database_url: Database connection URL. Required when backend='database'. 

277 session_ttl: Session time-to-live in seconds. Default: 3600. 

278 message_ttl: Message time-to-live in seconds. Default: 600. 

279 

280 Examples: 

281 >>> # Default memory backend 

282 >>> reg = SessionRegistry() 

283 >>> reg._backend 

284 'memory' 

285 >>> isinstance(reg._sessions, dict) 

286 True 

287 

288 >>> # Redis backend with custom TTL 

289 >>> try: 

290 ... reg = SessionRegistry( 

291 ... backend='redis', 

292 ... redis_url='redis://localhost:6379', 

293 ... session_ttl=7200 

294 ... ) 

295 ... except ValueError: 

296 ... pass # Redis may not be available 

297 """ 

298 super().__init__(backend=backend, redis_url=redis_url, database_url=database_url, session_ttl=session_ttl, message_ttl=message_ttl) 

299 self._sessions: Dict[str, Any] = {} # Local transport cache 

300 self._client_capabilities: Dict[str, Dict[str, Any]] = {} # Client capabilities by session_id 

301 self._respond_tasks: Dict[str, asyncio.Task] = {} # Track respond tasks for cancellation 

302 self._stuck_tasks: Dict[str, asyncio.Task] = {} # Tasks that couldn't be cancelled (for monitoring) 

303 self._closing_sessions: set[str] = set() # Sessions being closed - respond loop should exit 

304 self._lock = asyncio.Lock() 

305 self._cleanup_task: Task | None = None 

306 self._stuck_task_reaper: Task | None = None # Reaper for stuck tasks 

307 

308 def register_respond_task(self, session_id: str, task: asyncio.Task) -> None: 

309 """Register a respond task for later cancellation. 

310 

311 Associates an asyncio Task with a session_id so it can be cancelled 

312 when the session is removed. This prevents orphaned tasks that cause 

313 CPU spin loops. 

314 

315 Args: 

316 session_id: Session identifier the task belongs to. 

317 task: The asyncio Task to track. 

318 """ 

319 self._respond_tasks[session_id] = task 

320 logger.debug(f"Registered respond task for session {session_id}") 

321 
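# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the measured module): the intended pairing
# of respond() with register_respond_task(), so remove_session() can cancel
# the loop later. `registry`, `user`, `sid`, and `base_url` are placeholder
# assumptions for this sketch.
#
#     task = asyncio.create_task(
#         registry.respond(server_id=None, user=user, session_id=sid, base_url=base_url)
#     )
#     registry.register_respond_task(sid, task)
#     # ... later, remove_session(sid) cancels and awaits this task
# ---------------------------------------------------------------------------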

322 async def _cancel_respond_task(self, session_id: str, timeout: float = 5.0) -> None: 

323 """Cancel and await a respond task with timeout. 

324 

325 Safely cancels the respond task associated with a session. Uses a timeout 

326 to prevent hanging if the task doesn't respond to cancellation. 

327 

328 If initial cancellation times out, escalates by force-disconnecting the 

329 transport to unblock the task, then retries cancellation (Finding 1 fix). 

330 

331 Args: 

332 session_id: Session identifier whose task should be cancelled. 

333 timeout: Maximum seconds to wait for task cancellation. Default 5.0. 

334 """ 

335 task = self._respond_tasks.get(session_id) 

336 if task is None: 

337 return 

338 

339 if task.done(): 

340 # Task already finished - safe to remove from tracking 

341 self._respond_tasks.pop(session_id, None) 

342 try: 

343 task.result() 

344 except asyncio.CancelledError: 

345 pass 

346 except Exception as e: 

347 logger.warning(f"Respond task for {session_id} failed with: {e}") 

348 return 

349 

350 task.cancel() 

351 logger.debug(f"Cancelling respond task for session {session_id}") 

352 

353 try: 

354 await asyncio.wait_for(task, timeout=timeout) 

355 # Cancellation successful - remove from tracking 

356 self._respond_tasks.pop(session_id, None) 

357 logger.debug(f"Respond task cancelled for session {session_id}") 

358 except asyncio.TimeoutError: 

359 # ESCALATION (Finding 1): Force-disconnect transport to unblock the task 

360 logger.warning(f"Respond task cancellation timed out for {session_id}, " f"escalating with transport disconnect") 

361 

362 # Force-disconnect the transport to unblock any pending I/O 

363 transport = self._sessions.get(session_id) 

364 if transport and hasattr(transport, "disconnect"): 

365 try: 

366 await transport.disconnect() 

367 logger.debug(f"Force-disconnected transport for {session_id}") 

368 except Exception as e: 

369 logger.warning(f"Failed to force-disconnect transport for {session_id}: {e}") 

370 

371 # Retry cancellation with shorter timeout 

372 if not task.done(): 

373 try: 

374 await asyncio.wait_for(task, timeout=2.0) 

375 self._respond_tasks.pop(session_id, None) 

376 logger.info(f"Respond task cancelled after escalation for {session_id}") 

377 except asyncio.TimeoutError: 

378 # Still stuck - move to stuck_tasks for monitoring (Finding 2 fix) 

379 self._respond_tasks.pop(session_id, None) 

380 self._stuck_tasks[session_id] = task 

381 logger.error(f"Respond task for {session_id} still stuck after escalation, " f"moved to stuck_tasks for monitoring (total stuck: {len(self._stuck_tasks)})") 

382 except asyncio.CancelledError: 

383 self._respond_tasks.pop(session_id, None) 

384 logger.info(f"Respond task cancelled after escalation for {session_id}") 

385 except Exception as e: 

386 self._respond_tasks.pop(session_id, None) 

387 logger.warning(f"Error during retry cancellation for {session_id}: {e}") 

388 else: 

389 self._respond_tasks.pop(session_id, None) 

390 logger.debug(f"Respond task completed during escalation for {session_id}") 

391 

392 except asyncio.CancelledError: 

393 # Cancellation successful - remove from tracking 

394 self._respond_tasks.pop(session_id, None) 

395 logger.debug(f"Respond task cancelled for session {session_id}") 

396 except Exception as e: 

397 # Remove from tracking on unexpected error - task state unknown 

398 self._respond_tasks.pop(session_id, None) 

399 logger.warning(f"Error during respond task cancellation for {session_id}: {e}") 

400 

401 async def _reap_stuck_tasks(self) -> None: 

402 """Periodically clean up stuck tasks that have completed. 

403 

404 This reaper runs every 30 seconds and: 

405 1. Removes completed tasks from _stuck_tasks 

406 2. Retries cancellation for tasks that are still running 

407 3. Logs warnings for tasks that remain stuck 

408 

409 This prevents memory leaks from tasks that eventually complete after 

410 being moved to _stuck_tasks during escalation. 

411 """ 

412 reap_interval = 30.0 # seconds 

413 retry_timeout = 2.0 # seconds for retry cancellation 

414 

415 while True: 

416 try: 

417 await asyncio.sleep(reap_interval) 

418 

419 if not self._stuck_tasks: 

420 continue 

421 

422 # Collect completed and still-stuck tasks 

423 completed = [] 

424 still_stuck = [] 

425 

426 for session_id, task in list(self._stuck_tasks.items()): 

427 if task.done(): 

428 completed.append(session_id) 

429 try: 

430 task.result() # Consume result to avoid warnings 

431 except (asyncio.CancelledError, Exception): 

432 pass 

433 else: 

434 still_stuck.append((session_id, task)) 

435 

436 # Remove completed tasks 

437 for session_id in completed: 

438 self._stuck_tasks.pop(session_id, None) 

439 

440 if completed: 

441 logger.info(f"Reaped {len(completed)} completed stuck tasks") 

442 

443 # Retry cancellation for still-stuck tasks 

444 for session_id, task in still_stuck: 

445 task.cancel() 

446 try: 

447 await asyncio.wait_for(task, timeout=retry_timeout) 

448 self._stuck_tasks.pop(session_id, None) 

449 logger.info(f"Stuck task {session_id} finally cancelled during reap") 

450 except asyncio.TimeoutError: 

451 logger.warning(f"Task {session_id} still stuck after reap retry") 

452 except asyncio.CancelledError: 

453 self._stuck_tasks.pop(session_id, None) 

454 logger.info(f"Stuck task {session_id} cancelled during reap") 

455 except Exception as e: 

456 logger.warning(f"Error during stuck task reap for {session_id}: {e}") 

457 

458 if self._stuck_tasks: 

459 logger.warning(f"Stuck tasks remaining: {len(self._stuck_tasks)}") 

460 

461 except asyncio.CancelledError: 

462 logger.debug("Stuck task reaper cancelled") 

463 break 

464 except Exception as e: 

465 logger.error(f"Error in stuck task reaper: {e}") 

466 

467 async def initialize(self) -> None: 

468 """Initialize the registry with async setup. 

469 

470 This method performs asynchronous initialization tasks that cannot be done 

471 in __init__. It starts background cleanup tasks and sets up pubsub 

472 subscriptions for distributed backends. 

473 

474 Call this during application startup after creating the registry instance. 

475 

476 Examples: 

477 >>> import asyncio 

478 >>> reg = SessionRegistry(backend='memory') 

479 >>> asyncio.run(reg.initialize()) 

480 >>> reg._cleanup_task is not None 

481 True 

482 >>> 

483 >>> # Cleanup 

484 >>> asyncio.run(reg.shutdown()) 

485 """ 

486 logger.info(f"Initializing session registry with backend: {self._backend}") 

487 

488 if self._backend == "database": 

489 # Start database cleanup task 

490 self._cleanup_task = asyncio.create_task(self._db_cleanup_task()) 

491 logger.info("Database cleanup task started") 

492 

493 elif self._backend == "redis": 

494 # Get shared Redis client from factory 

495 self._redis = await get_redis_client() 

496 if self._redis: 

497 self._pubsub = self._redis.pubsub() 

498 await self._pubsub.subscribe("mcp_session_events") 

499 logger.info("Session registry connected to shared Redis client") 

500 

501 elif self._backend == "none": 

502 # Nothing to initialize for none backend 

503 pass 

504 

505 # Memory backend needs session cleanup 

506 elif self._backend == "memory": 

507 self._cleanup_task = asyncio.create_task(self._memory_cleanup_task()) 

508 logger.info("Memory cleanup task started") 

509 

510 # Start stuck task reaper for all backends 

511 self._stuck_task_reaper = asyncio.create_task(self._reap_stuck_tasks()) 

512 logger.info("Stuck task reaper started") 

513 

514 async def shutdown(self) -> None: 

515 """Shutdown the registry and clean up resources. 

516 

517 This method cancels background tasks and closes connections to external 

518 services. Call this during application shutdown to ensure clean termination. 

519 

520 Examples: 

521 >>> import asyncio 

522 >>> reg = SessionRegistry() 

523 >>> asyncio.run(reg.initialize()) 

524 >>> task_was_created = reg._cleanup_task is not None 

525 >>> asyncio.run(reg.shutdown()) 

526 >>> # After shutdown, cleanup task should be handled (cancelled or done) 

527 >>> task_was_created and (reg._cleanup_task.cancelled() or reg._cleanup_task.done()) 

528 True 

529 """ 

530 logger.info("Shutting down session registry") 

531 

532 # Cancel cleanup task 

533 if self._cleanup_task: 

534 self._cleanup_task.cancel() 

535 try: 

536 await self._cleanup_task 

537 except asyncio.CancelledError: 

538 pass 

539 

540 # Cancel stuck task reaper 

541 if self._stuck_task_reaper: 

542 self._stuck_task_reaper.cancel() 

543 try: 

544 await self._stuck_task_reaper 

545 except asyncio.CancelledError: 

546 pass 

547 

548 # CRITICAL: Cancel ALL respond tasks to prevent CPU spin loops 

549 if self._respond_tasks: 

550 logger.info(f"Cancelling {len(self._respond_tasks)} respond tasks") 

551 tasks_to_cancel = list(self._respond_tasks.values()) 

552 self._respond_tasks.clear() 

553 

554 for task in tasks_to_cancel: 

555 if not task.done(): 

556 task.cancel() 

557 

558 if tasks_to_cancel: 

559 try: 

560 await asyncio.wait_for(asyncio.gather(*tasks_to_cancel, return_exceptions=True), timeout=10.0) 

561 logger.info("All respond tasks cancelled successfully") 

562 except asyncio.TimeoutError: 

563 logger.warning("Timeout waiting for respond tasks to cancel") 

564 

565 # Also cancel any stuck tasks (tasks that previously couldn't be cancelled) 

566 if self._stuck_tasks: 

567 logger.warning(f"Attempting final cancellation of {len(self._stuck_tasks)} stuck tasks") 

568 stuck_to_cancel = list(self._stuck_tasks.values()) 

569 self._stuck_tasks.clear() 

570 

571 for task in stuck_to_cancel: 

572 if not task.done(): 

573 task.cancel() 

574 

575 if stuck_to_cancel: 

576 try: 

577 await asyncio.wait_for(asyncio.gather(*stuck_to_cancel, return_exceptions=True), timeout=5.0) 

578 logger.info("Stuck tasks cancelled during shutdown") 

579 except asyncio.TimeoutError: 

580 logger.error("Some stuck tasks could not be cancelled during shutdown") 

581 

582 # Close Redis pubsub (but not the shared client) 

583 # Use timeout to prevent blocking if pubsub doesn't close cleanly 

584 cleanup_timeout = settings.mcp_session_pool_cleanup_timeout 

585 if self._backend == "redis" and getattr(self, "_pubsub", None): 

586 try: 

587 await asyncio.wait_for(self._pubsub.aclose(), timeout=cleanup_timeout) 

588 except asyncio.TimeoutError: 

589 logger.warning("Redis pubsub close timed out - proceeding anyway") 

590 except Exception as e: 

591 logger.error(f"Error closing Redis pubsub: {e}") 

592 # Don't close self._redis - it's the shared client managed by redis_client.py 

593 self._redis = None 

594 self._pubsub = None 

595 
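# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the measured module): the initialize() /
# shutdown() lifecycle in an asyncio application. The surrounding app code is
# an assumption for illustration.
#
#     import asyncio
#     from mcpgateway.cache.session_registry import SessionRegistry
#
#     async def serve():
#         registry = SessionRegistry(backend="memory")
#         await registry.initialize()    # starts cleanup + stuck-task reaper
#         try:
#             ...                        # accept SSE sessions here
#         finally:
#             await registry.shutdown()  # cancels background and respond tasks
#
#     asyncio.run(serve())
# ---------------------------------------------------------------------------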

596 async def add_session(self, session_id: str, transport: SSETransport) -> None: 

597 """Add a session to the registry. 

598 

599 Stores the session in both the local cache and the distributed backend 

600 (if configured). For distributed backends, this notifies other workers 

601 about the new session. 

602 

603 Args: 

604 session_id: Unique session identifier. Should be a UUID or similar 

605 unique string to avoid collisions. 

606 transport: SSE transport object for this session. Must implement 

607 the SSETransport interface. 

608 

609 Examples: 

610 >>> import asyncio 

611 >>> from mcpgateway.cache.session_registry import SessionRegistry 

612 >>> 

613 >>> class MockTransport: 

614 ... async def disconnect(self): 

615 ... print(f"Transport disconnected") 

616 ... async def is_connected(self): 

617 ... return True 

618 >>> 

619 >>> reg = SessionRegistry() 

620 >>> transport = MockTransport() 

621 >>> asyncio.run(reg.add_session('test-456', transport)) 

622 >>> 

623 >>> # Found in local cache 

624 >>> found = asyncio.run(reg.get_session('test-456')) 

625 >>> found is transport 

626 True 

627 >>> 

628 >>> # Remove session 

629 >>> asyncio.run(reg.remove_session('test-456')) 

630 Transport disconnected 

631 """ 

632 # Skip for none backend 

633 if self._backend == "none": 

634 return 

635 

636 async with self._lock: 

637 self._sessions[session_id] = transport 

638 

639 if self._backend == "redis": 

640 # Store session marker in Redis 

641 if not self._redis: 

642 logger.warning(f"Redis client not initialized, skipping distributed session tracking for {session_id}") 

643 return 

644 try: 

645 await self._redis.setex(f"mcp:session:{session_id}", self._session_ttl, "1") 

646 # Publish event to notify other workers 

647 await self._redis.publish("mcp_session_events", orjson.dumps({"type": "add", "session_id": session_id, "timestamp": time.time()})) 

648 except Exception as e: 

649 logger.error(f"Redis error adding session {session_id}: {e}") 

650 

651 elif self._backend == "database": 

652 # Store session in database 

653 try: 

654 

655 def _db_add() -> None: 

656 """Store session record in the database. 

657 

658 Creates a new SessionRecord entry in the database for tracking 

659 distributed session state. Uses a fresh database connection from 

660 the connection pool. 

661 

662 This inner function is designed to be run in a thread executor 

663 to avoid blocking the async event loop during database I/O. 

664 

665 Raises: 

666 Exception: Any database error is re-raised after rollback. 

667 Common errors include duplicate session_id (unique constraint) 

668 or database connection issues. 

669 

670 Examples: 

671 >>> # This function is called internally by add_session() 

672 >>> # When executed, it creates a database record: 

673 >>> # SessionRecord(session_id='abc123', created_at=now()) 

674 """ 

675 db_session = next(get_db()) 

676 try: 

677 session_record = SessionRecord(session_id=session_id) 

678 db_session.add(session_record) 

679 db_session.commit() 

680 except Exception as ex: 

681 db_session.rollback() 

682 raise ex 

683 finally: 

684 db_session.close() 

685 

686 await asyncio.to_thread(_db_add) 

687 except Exception as e: 

688 logger.error(f"Database error adding session {session_id}: {e}") 

689 

690 logger.info(f"Added session: {session_id}") 

691 
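# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the measured module): decoding the event
# that add_session() publishes on the "mcp_session_events" channel, as another
# worker's subscriber would observe it. The timestamp is a placeholder value.
#
#     import orjson
#     raw = b'{"type": "add", "session_id": "sid123", "timestamp": 1700000000.0}'
#     event = orjson.loads(raw)
#     assert event["type"] == "add" and event["session_id"] == "sid123"
# ---------------------------------------------------------------------------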

692 async def get_session(self, session_id: str) -> Any: 

693 """Get session transport by ID. 

694 

695 First checks the local cache for the transport object. If not found locally 

696 but using a distributed backend, checks if the session exists on another 

697 worker. 

698 

699 Args: 

700 session_id: Session identifier to look up. 

701 

702 Returns: 

703 SSETransport object if found locally, None if not found or exists 

704 on another worker. 

705 

706 Examples: 

707 >>> import asyncio 

708 >>> from mcpgateway.cache.session_registry import SessionRegistry 

709 >>> 

710 >>> class MockTransport: 

711 ... pass 

712 >>> 

713 >>> reg = SessionRegistry() 

714 >>> transport = MockTransport() 

715 >>> asyncio.run(reg.add_session('test-456', transport)) 

716 >>> 

717 >>> # Found in local cache 

718 >>> found = asyncio.run(reg.get_session('test-456')) 

719 >>> found is transport 

720 True 

721 >>> 

722 >>> # Not found 

723 >>> asyncio.run(reg.get_session('nonexistent')) is None 

724 True 

725 """ 

726 # Skip for none backend 

727 if self._backend == "none": 

728 return None 

729 

730 # First check local cache 

731 async with self._lock: 

732 transport = self._sessions.get(session_id) 

733 if transport: 

734 logger.info(f"Session {session_id} exists in local cache") 

735 return transport 

736 

737 # If not in local cache, check if it exists in shared backend 

738 if self._backend == "redis": 

739 if not self._redis: 

740 return None 

741 try: 

742 exists = await self._redis.exists(f"mcp:session:{session_id}") 

743 session_exists = bool(exists) 

744 if session_exists: 

745 logger.info(f"Session {session_id} exists in Redis but not in local cache") 

746 return None # We don't have the transport locally 

747 except Exception as e: 

748 logger.error(f"Redis error checking session {session_id}: {e}") 

749 return None 

750 

751 elif self._backend == "database": 

752 try: 

753 

754 def _db_check() -> bool: 

755 """Check if a session exists in the database. 

756 

757 Queries the SessionRecord table to determine if a session with 

758 the given session_id exists. This is used when the session is not 

759 found in the local cache to check if it exists on another worker. 

760 

761 This inner function is designed to be run in a thread executor 

762 to avoid blocking the async event loop during database queries. 

763 

764 Returns: 

765 bool: True if the session exists in the database, False otherwise. 

766 

767 Examples: 

768 >>> # This function is called internally by get_session() 

769 >>> # Returns True if SessionRecord with session_id exists 

770 >>> # Returns False if no matching record found 

771 """ 

772 db_session = next(get_db()) 

773 try: 

774 record = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first() 

775 return record is not None 

776 finally: 

777 db_session.close() 

778 

779 exists = await asyncio.to_thread(_db_check) 

780 if exists: 

781 logger.info(f"Session {session_id} exists in database but not in local cache") 

782 return None 

783 except Exception as e: 

784 logger.error(f"Database error checking session {session_id}: {e}") 

785 return None 

786 

787 return None 

788 
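# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the measured module): in distributed mode,
# get_session() returns None for a session owned by another worker even though
# the session exists in the shared backend; cross-worker delivery goes through
# broadcast() and the owning worker's respond() loop instead.
#
#     transport = await registry.get_session(session_id)
#     if transport is None:
#         await registry.broadcast(session_id, {"method": "ping", "id": 1})
# ---------------------------------------------------------------------------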

789 async def remove_session(self, session_id: str) -> None: 

790 """Remove a session from the registry. 

791 

792 Removes the session from both local cache and distributed backend. 

793 If a transport is found locally, it will be disconnected before removal. 

794 For distributed backends, notifies other workers about the removal. 

795 

796 Args: 

797 session_id: Session identifier to remove. 

798 

799 Examples: 

800 >>> import asyncio 

801 >>> from mcpgateway.cache.session_registry import SessionRegistry 

802 >>> 

803 >>> class MockTransport: 

804 ... async def disconnect(self): 

805 ... print(f"Transport disconnected") 

806 ... async def is_connected(self): 

807 ... return True 

808 >>> 

809 >>> reg = SessionRegistry() 

810 >>> transport = MockTransport() 

811 >>> asyncio.run(reg.add_session('remove-test', transport)) 

812 >>> asyncio.run(reg.remove_session('remove-test')) 

813 Transport disconnected 

814 >>> 

815 >>> # Session no longer exists 

816 >>> asyncio.run(reg.get_session('remove-test')) is None 

817 True 

818 """ 

819 # Skip for none backend 

820 if self._backend == "none": 

821 return 

822 

823 # Mark session as closing FIRST so respond loop can exit early 

824 # This allows the loop to exit without waiting for cancellation to complete 

825 self._closing_sessions.add(session_id) 

826 

827 try: 

828 # CRITICAL: Cancel respond task before any cleanup 

829 # This prevents orphaned tasks that cause CPU spin loops 

830 await self._cancel_respond_task(session_id) 

831 

832 # Clean up local transport 

833 transport = None 

834 async with self._lock: 

835 if session_id in self._sessions: 

836 transport = self._sessions.pop(session_id) 

837 # Also clean up client capabilities 

838 if session_id in self._client_capabilities: 

839 self._client_capabilities.pop(session_id) 

840 logger.debug(f"Removed capabilities for session {session_id}") 

841 finally: 

842 # Always remove from closing set 

843 self._closing_sessions.discard(session_id) 

844 

845 # Disconnect transport if found 

846 if transport: 

847 try: 

848 await transport.disconnect() 

849 except Exception as e: 

850 logger.error(f"Error disconnecting transport for session {session_id}: {e}") 

851 

852 # Remove from shared backend 

853 if self._backend == "redis": 

854 if not self._redis: 

855 return 

856 try: 

857 await self._redis.delete(f"mcp:session:{session_id}") 

858 # Notify other workers 

859 await self._redis.publish("mcp_session_events", orjson.dumps({"type": "remove", "session_id": session_id, "timestamp": time.time()})) 

860 except Exception as e: 

861 logger.error(f"Redis error removing session {session_id}: {e}") 

862 

863 elif self._backend == "database": 

864 try: 

865 

866 def _db_remove() -> None: 

867 """Delete session record from the database. 

868 

869 Removes the SessionRecord entry with the specified session_id 

870 from the database. This is called when a session is being 

871 terminated or has expired. 

872 

873 This inner function is designed to be run in a thread executor 

874 to avoid blocking the async event loop during database operations. 

875 

876 Raises: 

877 Exception: Any database error is re-raised after rollback. 

878 This includes connection errors or constraint violations. 

879 

880 Examples: 

881 >>> # This function is called internally by remove_session() 

882 >>> # Deletes the SessionRecord where session_id matches 

883 >>> # No error if session_id doesn't exist (idempotent) 

884 """ 

885 db_session = next(get_db()) 

886 try: 

887 db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).delete() 

888 db_session.commit() 

889 except Exception as ex: 

890 db_session.rollback() 

891 raise ex 

892 finally: 

893 db_session.close() 

894 

895 await asyncio.to_thread(_db_remove) 

896 except Exception as e: 

897 logger.error(f"Database error removing session {session_id}: {e}") 

898 

899 logger.info(f"Removed session: {session_id}") 

900 

901 async def broadcast(self, session_id: str, message: Dict[str, Any]) -> None: 

902 """Broadcast a message to a session. 

903 

904 Sends a message to the specified session. The behavior depends on the backend: 

905 

906 - **memory**: Stores message temporarily for local delivery 

907 - **redis**: Publishes message to Redis channel for the session 

908 - **database**: Stores message in database for polling by worker with session 

909 - **none**: No operation 

910 

911 This method is used for inter-process communication in distributed deployments. 

912 

913 Args: 

914 session_id: Target session identifier. 

915 message: Message to broadcast. Can be a dict, list, or any JSON-serializable object. 

916 

917 Examples: 

918 >>> import asyncio 

919 >>> from mcpgateway.cache.session_registry import SessionRegistry 

920 >>> 

921 >>> reg = SessionRegistry(backend='memory') 

922 >>> message = {'method': 'tools/list', 'id': 1} 

923 >>> asyncio.run(reg.broadcast('session-789', message)) 

924 >>> 

925 >>> # Message stored for memory backend 

926 >>> reg._session_message is not None 

927 True 

928 >>> reg._session_message['session_id'] 

929 'session-789' 

930 >>> orjson.loads(reg._session_message['message'])['message'] == message 

931 True 

932 """ 

933 # Skip for none backend only 

934 if self._backend == "none": 

935 return 

936 

937 def _build_payload(msg: Any) -> str: 

938 """Build a JSON payload for message broadcasting. 

939 

940 Args: 

941 msg: Message to wrap in payload envelope. 

942 

943 Returns: 

944 JSON-encoded string containing type, message, and timestamp. 

945 """ 

946 payload = {"type": "message", "message": msg, "timestamp": time.time()} 

947 return orjson.dumps(payload).decode() 

948 

949 if self._backend == "memory": 

950 payload_json = _build_payload(message) 

951 self._session_message = {"session_id": session_id, "message": payload_json} 

952 

953 elif self._backend == "redis": 

954 if not self._redis: 

955 logger.warning(f"Redis client not initialized, cannot broadcast to {session_id}") 

956 return 

957 try: 

958 broadcast_payload = { 

959 "type": "message", 

960 "message": message, # Keep as original type, not pre-encoded 

961 "timestamp": time.time(), 

962 } 

963 # Single encode 

964 payload_json = orjson.dumps(broadcast_payload) 

965 await self._redis.publish(session_id, payload_json) 

966 except Exception as e: 

967 logger.error(f"Redis error during broadcast: {e}") 

968 elif self._backend == "database": 

969 try: 

970 msg_json = _build_payload(message) 

971 

972 def _db_add() -> None: 

973 """Store message in the database for inter-process communication. 

974 

975 Creates a new SessionMessageRecord entry containing the session_id 

976 and serialized message. This enables message passing between 

977 different worker processes through the shared database. 

978 

979 This inner function is designed to be run in a thread executor 

980 to avoid blocking the async event loop during database writes. 

981 

982 Raises: 

983 Exception: Any database error is re-raised after rollback. 

984 Common errors include database connection issues or 

985 constraint violations. 

986 

987 Examples: 

988 >>> # This function is called internally by broadcast() 

989 >>> # Creates a record like: 

990 >>> # SessionMessageRecord( 

991 >>> # session_id='abc123', 

992 >>> # message='{"method": "ping", "id": 1}', 

993 >>> # created_at=now() 

994 >>> # ) 

995 """ 

996 db_session = next(get_db()) 

997 try: 

998 message_record = SessionMessageRecord(session_id=session_id, message=msg_json) 

999 db_session.add(message_record) 

1000 db_session.commit() 

1001 except Exception as ex: 

1002 db_session.rollback() 

1003 raise ex 

1004 finally: 

1005 db_session.close() 

1006 

1007 await asyncio.to_thread(_db_add) 

1008 except Exception as e: 

1009 logger.error(f"Database error during broadcast: {e}") 

1010 
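# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the measured module): the envelope built by
# _build_payload() above, and how the respond() loops unwrap it (they take
# data["message"] when present, otherwise the decoded object itself).
#
#     import time
#     import orjson
#
#     payload = {"type": "message", "message": {"method": "ping", "id": 1}, "timestamp": time.time()}
#     encoded = orjson.dumps(payload).decode()
#     data = orjson.loads(encoded)
#     inner = data["message"] if isinstance(data, dict) and "message" in data else data
#     assert inner == {"method": "ping", "id": 1}
# ---------------------------------------------------------------------------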

1011 async def _register_session_mapping(self, session_id: str, message: Dict[str, Any], user_email: Optional[str] = None) -> None: 

1012 """Register session mapping for session affinity when tools are called. 

1013 

1014 This method is called on the worker that executes the request (the SSE session 

1015 owner) to pre-register the mapping between a downstream session ID and the 

1016 upstream MCP session pool key. This enables session affinity in multi-worker 

1017 deployments. 

1018 

1019 Only registers mappings for tools/call methods - list operations and other 

1020 methods don't need session affinity since they don't maintain state. 

1021 

1022 Args: 

1023 session_id: The downstream SSE session ID. 

1024 message: The MCP protocol message being broadcast. 

1025 user_email: Optional user email for session isolation. 

1026 """ 

1027 # Skip if session affinity is disabled 

1028 if not settings.mcpgateway_session_affinity_enabled: 

1029 return 

1030 

1031 # Only register for tools/call - other methods don't need session affinity 

1032 method = message.get("method") 

1033 if method != "tools/call": 

1034 return 

1035 

1036 # Extract tool name from params 

1037 params = message.get("params", {}) 

1038 tool_name = params.get("name") 

1039 if not tool_name: 

1040 return 

1041 

1042 try: 

1043 # Look up tool in cache to get gateway info 

1044 # First-Party 

1045 from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache # pylint: disable=import-outside-toplevel 

1046 

1047 tool_info = await tool_lookup_cache.get(tool_name) 

1048 if not tool_info: 

1049 logger.debug(f"Tool {tool_name} not found in cache, skipping session mapping registration") 

1050 return 

1051 

1052 # Extract gateway information 

1053 gateway = tool_info.get("gateway", {}) 

1054 gateway_url = gateway.get("url") 

1055 gateway_id = gateway.get("id") 

1056 transport = gateway.get("transport") 

1057 

1058 if not gateway_url or not gateway_id or not transport: 

1059 logger.debug(f"Incomplete gateway info for tool {tool_name}, skipping session mapping registration") 

1060 return 

1061 

1062 # Register the session mapping with the pool 

1063 # First-Party 

1064 from mcpgateway.services.mcp_session_pool import get_mcp_session_pool # pylint: disable=import-outside-toplevel 

1065 

1066 pool = get_mcp_session_pool() 

1067 await pool.register_session_mapping( 

1068 session_id, 

1069 gateway_url, 

1070 gateway_id, 

1071 transport, 

1072 user_email, 

1073 ) 

1074 

1075 logger.debug(f"Registered session mapping for session {session_id[:8]}... -> {gateway_url} (tool: {tool_name})") 

1076 

1077 except Exception as e: 

1078 # Don't fail the broadcast if session mapping registration fails 

1079 logger.warning(f"Failed to register session mapping for {session_id[:8]}...: {e}") 

1080 
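# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the measured module): the only message
# shape that triggers session-affinity registration above; the tool name is a
# placeholder assumption. Any other method, or a missing params["name"],
# returns early without registering a mapping.
#
#     message = {"jsonrpc": "2.0", "id": 1, "method": "tools/call",
#                "params": {"name": "example_tool", "arguments": {}}}
# ---------------------------------------------------------------------------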

1081 async def get_all_session_ids(self) -> list[str]: 

1082 """Return a snapshot list of all known local session IDs. 

1083 

1084 Returns: 

1085 list[str]: A snapshot list of currently known local session IDs. 

1086 """ 

1087 async with self._lock: 

1088 return list(self._sessions.keys()) 

1089 

1090 def get_session_sync(self, session_id: str) -> Any: 

1091 """Get session synchronously from local cache only. 

1092 

1093 This is a non-blocking method that only checks the local cache, 

1094 not the distributed backend. Use this when you need quick access 

1095 and know the session should be local. 

1096 

1097 Args: 

1098 session_id: Session identifier to look up. 

1099 

1100 Returns: 

1101 SSETransport object if found in local cache, None otherwise. 

1102 

1103 Examples: 

1104 >>> from mcpgateway.cache.session_registry import SessionRegistry 

1105 >>> import asyncio 

1106 >>> 

1107 >>> class MockTransport: 

1108 ... pass 

1109 >>> 

1110 >>> reg = SessionRegistry() 

1111 >>> transport = MockTransport() 

1112 >>> asyncio.run(reg.add_session('sync-test', transport)) 

1113 >>> 

1114 >>> # Synchronous lookup 

1115 >>> found = reg.get_session_sync('sync-test') 

1116 >>> found is transport 

1117 True 

1118 >>> 

1119 >>> # Not found 

1120 >>> reg.get_session_sync('nonexistent') is None 

1121 True 

1122 """ 

1123 # Skip for none backend 

1124 if self._backend == "none": 

1125 return None 

1126 

1127 return self._sessions.get(session_id) 

1128 

1129 async def respond( 

1130 self, 

1131 server_id: Optional[str], 

1132 user: Dict[str, Any], 

1133 session_id: str, 

1134 base_url: str, 

1135 ) -> None: 

1136 """Process and respond to broadcast messages for a session. 

1137 

1138 This method listens for messages directed to the specified session and 

1139 generates appropriate responses. The listening mechanism depends on the backend: 

1140 

1141 - **memory**: Checks the temporary message storage 

1142 - **redis**: Subscribes to Redis pubsub channel 

1143 - **database**: Polls database for new messages 

1144 

1145 When a message is received and the transport exists locally, it processes 

1146 the message and sends the response through the transport. 

1147 

1148 Args: 

1149 server_id: Optional server identifier for scoped operations. 

1150 user: User information including authentication token. 

1151 session_id: Session identifier to respond for. 

1152 base_url: Base URL for API calls (used for RPC endpoints). 

1153 

1154 Raises: 

1155 asyncio.CancelledError: When the respond task is cancelled (e.g., on session removal). 

1156 

1157 Examples: 

1158 >>> import asyncio 

1159 >>> from mcpgateway.cache.session_registry import SessionRegistry 

1160 >>> 

1161 >>> # This method is typically called internally by the SSE handler 

1162 >>> reg = SessionRegistry() 

1163 >>> user = {'token': 'test-token'} 

1164 >>> # asyncio.run(reg.respond(None, user, 'session-id', 'http://localhost')) 

1165 """ 

1166 

1167 if self._backend == "none": 

1168 pass 

1169 

1170 elif self._backend == "memory": 

1171 transport = self.get_session_sync(session_id) 

1172 if transport and self._session_message: 

1173 message_json = self._session_message.get("message") 

1174 if message_json: 

1175 data = orjson.loads(message_json) 

1176 if isinstance(data, dict) and "message" in data: 

1177 message = data["message"] 

1178 else: 

1179 message = data 

1180 await self.generate_response(message=message, transport=transport, server_id=server_id, user=user, base_url=base_url) 

1181 else: 

1182 logger.warning(f"Session message stored but message content is None for session {session_id}") 

1183 

1184 elif self._backend == "redis": 

1185 if not self._redis: 

1186 logger.warning(f"Redis client not initialized, cannot respond to {session_id}") 

1187 return 

1188 pubsub = self._redis.pubsub() 

1189 await pubsub.subscribe(session_id) 

1190 

1191 # Use timeout-based polling instead of infinite listen() to allow exit checks 

1192 # This is critical for allowing cancellation to work (Finding 2) 

1193 poll_timeout = 1.0 # Check every second if session still exists 

1194 

1195 try: 

1196 while True: 

1197 # Check if session still exists or is closing - exit early 

1198 if session_id not in self._sessions or session_id in self._closing_sessions: 

1199 logger.info(f"Session {session_id} removed or closing, exiting Redis respond loop") 

1200 break 

1201 

1202 # Use get_message with timeout instead of blocking listen() 

1203 try: 

1204 msg = await asyncio.wait_for( 

1205 pubsub.get_message(ignore_subscribe_messages=True, timeout=poll_timeout), 

1206 timeout=poll_timeout + 0.5)  # Slightly longer to account for Redis timeout 

1207 except asyncio.TimeoutError: 

1208 # No message, loop back to check session existence 

1209 continue 

1210 

1211 if msg is None: 

1212 # CRITICAL: Sleep to prevent tight loop when get_message returns immediately 

1213 # This can happen in certain Redis states after disconnects 

1214 await asyncio.sleep(0.1) 

1215 continue 

1216 if msg["type"] != "message": 

1217 # Sleep on non-message types to prevent spin in edge cases 

1218 await asyncio.sleep(0.1) 

1219 continue 

1220 

1221 data = orjson.loads(msg["data"]) 

1222 message = data.get("message", {}) 

1223 transport = self.get_session_sync(session_id) 

1224 if transport: 

1225 await self.generate_response(message=message, transport=transport, server_id=server_id, user=user, base_url=base_url) 

1226 except asyncio.CancelledError: 

1227 logger.info(f"PubSub listener for session {session_id} cancelled") 

1228 raise # Re-raise to properly complete cancellation 

1229 except Exception as e: 

1230 logger.error(f"PubSub listener error for session {session_id}: {e}") 

1231 finally: 

1232 # Pubsub cleanup first - use timeouts to prevent blocking 

1233 cleanup_timeout = settings.mcp_session_pool_cleanup_timeout 

1234 try: 

1235 await asyncio.wait_for(pubsub.unsubscribe(session_id), timeout=cleanup_timeout) 

1236 except asyncio.TimeoutError: 

1237 logger.debug(f"Pubsub unsubscribe timed out for session {session_id}") 

1238 except Exception as e: 

1239 logger.debug(f"Error unsubscribing pubsub for session {session_id}: {e}") 

1240 try: 

1241 try: 

1242 await asyncio.wait_for(pubsub.aclose(), timeout=cleanup_timeout) 

1243 except AttributeError: 

1244 await asyncio.wait_for(pubsub.close(), timeout=cleanup_timeout) 

1245 except asyncio.TimeoutError: 

1246 logger.debug(f"Pubsub close timed out for session {session_id}") 

1247 except Exception as e: 

1248 logger.debug(f"Error closing pubsub for session {session_id}: {e}") 

1249 logger.info(f"Cleaned up pubsub for session {session_id}") 

1250 # Clean up task reference LAST (idempotent - may already be removed by _cancel_respond_task) 

1251 self._respond_tasks.pop(session_id, None) 

1252 

1253 elif self._backend == "database": 

1254 

1255 def _db_read_session_and_message( 

1256 session_id: str, 

1257 ) -> tuple[SessionRecord | None, SessionMessageRecord | None]: 

1258 """ 

1259 Check whether a session exists and retrieve its next pending message 

1260 in a single database query. 

1261 

1262 This function performs a LEFT OUTER JOIN between SessionRecord and 

1263 SessionMessageRecord to determine: 

1264 

1265 - Whether the session still exists 

1266 - Whether there is a pending message for the session (FIFO order) 

1267 

1268 It is used by the database-backed message polling loop to reduce 

1269 database load by collapsing multiple reads into a single query. 

1270 

1271 Messages are returned in FIFO order based on the message primary key. 

1272 

1273 This function is designed to be run in a thread executor to avoid 

1274 blocking the async event loop during database access. 

1275 

1276 Args: 

1277 session_id: The session identifier to look up. 

1278 

1279 Returns: 

1280 Tuple[SessionRecord | None, SessionMessageRecord | None]: 

1281 

1282 - (None, None) 

1283 The session does not exist. 

1284 

1285 - (SessionRecord, None) 

1286 The session exists but has no pending messages. 

1287 

1288 - (SessionRecord, SessionMessageRecord) 

1289 The session exists and has a pending message. 

1290 

1291 Raises: 

1292 Exception: Any database error is re-raised after rollback. 

1293 

1294 Examples: 

1295 >>> # This function is called internally by message_check_loop() 

1296 >>> # Session exists and has a pending message 

1297 >>> # Returns (SessionRecord, SessionMessageRecord) 

1298 

1299 >>> # Session exists but has no pending messages 

1300 >>> # Returns (SessionRecord, None) 

1301 

1302 >>> # Session has been removed 

1303 >>> # Returns (None, None) 

1304 """ 

1305 db_session = next(get_db()) 

1306 try: 

1307 result = ( 

1308 db_session.query(SessionRecord, SessionMessageRecord) 

1309 .outerjoin( 

1310 SessionMessageRecord, 

1311 SessionMessageRecord.session_id == SessionRecord.session_id, 

1312 ) 

1313 .filter(SessionRecord.session_id == session_id) 

1314 .order_by(SessionMessageRecord.id.asc()) 

1315 .first() 

1316 ) 

1317 if not result: 

1318 return None, None 

1319 session, message = result 

1320 return session, message 

1321 except Exception as ex: 

1322 db_session.rollback() 

1323 raise ex 

1324 finally: 

1325 db_session.close() 

1326 

1327 def _db_remove(session_id: str, message: str) -> None: 

1328 """Remove processed message from the database. 

1329 

1330 Deletes a specific message record after it has been successfully 

1331 processed and sent to the transport. This prevents duplicate 

1332 message delivery. 

1333 

1334 This inner function is designed to be run in a thread executor 

1335 to avoid blocking the async event loop during database deletes. 

1336 

1337 Args: 

1338 session_id: The session identifier the message belongs to. 

1339 message: The message content to remove (must match the stored value exactly). 

1340 

1341 Raises: 

1342 Exception: Any database error is re-raised after rollback. 

1343 

1344 Examples: 

1345 >>> # This function is called internally after message processing 

1346 >>> # Deletes the specific SessionMessageRecord entry 

1347 >>> # Log: "Removed message from mcp_messages table" 

1348 """ 

1349 db_session = next(get_db()) 

1350 try: 

1351 db_session.query(SessionMessageRecord).filter(SessionMessageRecord.session_id == session_id).filter(SessionMessageRecord.message == message).delete() 

1352 db_session.commit() 

1353 logger.info("Removed message from mcp_messages table") 

1354 except Exception as ex: 

1355 db_session.rollback() 

1356 raise ex 

1357 finally: 

1358 db_session.close() 

1359 

1360 async def message_check_loop(session_id: str) -> None: 

1361 """ 

1362 Background task that polls the database for messages belonging to a session 

1363 using adaptive polling with exponential backoff. 

1364 

1365 The loop continues until the session is removed from the database. 

1366 

1367 Behavior: 

1368 - Starts with a fast polling interval for low-latency message delivery. 

1369 - When no message is found, the polling interval increases exponentially 

1370 (up to a configured maximum) to reduce database load. 

1371 - When a message is received, the polling interval is immediately reset 

1372 to the fast interval. 

1373 - The loop exits as soon as the session no longer exists. 

1374 

1375 Polling rules: 

1376 - Message found → process message, reset polling interval. 

1377 - No message → increase polling interval (backoff). 

1378 - Session gone → stop polling immediately. 

1379 

1380 Args: 

1381 session_id (str): Unique identifier of the session to monitor. 

1382 

1383 Raises: 

1384 asyncio.CancelledError: When the polling loop is cancelled. 

1385 

1386 Examples 

1387 -------- 

1388 Adaptive backoff when no messages are present: 

1389 

1390 >>> poll_interval = 0.1 

1391 >>> backoff_factor = 1.5 

1392 >>> max_interval = 5.0 

1393 >>> poll_interval = min(poll_interval * backoff_factor, max_interval) 

1394 >>> poll_interval 

1395 0.15000000000000002 

1396 

1397 Backoff continues until the maximum interval is reached: 

1398 

1399 >>> poll_interval = 4.0 

1400 >>> poll_interval = min(poll_interval * 1.5, 5.0) 

1401 >>> poll_interval 

1402 5.0 

1403 

1404 Polling interval resets immediately when a message arrives: 

1405 

1406 >>> poll_interval = 2.0 

1407 >>> poll_interval = 0.1 

1408 >>> poll_interval 

1409 0.1 

1410 

1411 Session termination stops polling: 

1412 

1413 >>> session_exists = False 

1414 >>> if not session_exists: 

1415 ... "polling stopped" 

1416 'polling stopped' 

1417 """ 

1418 

1419 poll_interval = settings.poll_interval # start fast 

1420 max_interval = settings.max_interval # cap at configured maximum 

1421 backoff_factor = settings.backoff_factor 

1422 try: 

1423 while True: 

1424 # Check if session is closing before querying DB 

1425 if session_id in self._closing_sessions: 

1426 logger.debug("Session %s closing, stopping poll loop early", session_id) 

1427 break 

1428 

1429 session, record = await asyncio.to_thread(_db_read_session_and_message, session_id) 

1430 

1431 # session gone → stop polling 

1432 if not session: 

1433 logger.debug("Session %s no longer exists, stopping poll loop", session_id) 

1434 break 

1435 

1436 if record: 

1437 poll_interval = settings.poll_interval # reset on activity 

1438 

1439 data = orjson.loads(record.message) 

1440 if isinstance(data, dict) and "message" in data: 

1441 message = data["message"] 

1442 else: 

1443 message = data 

1444 

1445 transport = self.get_session_sync(session_id) 

1446 if transport: 

1447 logger.info("Ready to respond") 

1448 await self.generate_response( 

1449 message=message, 

1450 transport=transport, 

1451 server_id=server_id, 

1452 user=user, 

1453 base_url=base_url, 

1454 ) 

1455 

1456 await asyncio.to_thread(_db_remove, session_id, record.message) 

1457 else: 

1458 # no message → backoff 

1459 # update polling interval with backoff factor 

1460 poll_interval = min(poll_interval * backoff_factor, max_interval) 

1461 

1462 await asyncio.sleep(poll_interval) 

1463 except asyncio.CancelledError: 

1464 logger.info(f"Message check loop cancelled for session {session_id}") 

1465 raise # Re-raise to properly complete cancellation 

1466 except Exception as e: 

1467 logger.error(f"Message check loop error for session {session_id}: {e}") 

1468 

1469 # CRITICAL: Await instead of fire-and-forget 

1470 # This ensures CancelledError propagates from outer respond() task to inner loop 

1471 # The outer task (registered from main.py) now runs until message_check_loop exits 

1472 try: 

1473 await message_check_loop(session_id) 

1474 except asyncio.CancelledError: 

1475 logger.info(f"Database respond cancelled for session {session_id}") 

1476 raise 

1477 finally: 

1478 # Clean up task reference on ANY exit (normal, cancelled, or error) 

1479 # Prevents stale done tasks from accumulating in _respond_tasks 

1480 self._respond_tasks.pop(session_id, None) 

1481 

1482 async def _refresh_redis_sessions(self) -> None: 

1483 """Refresh TTLs for Redis sessions and clean up disconnected sessions. 

1484 

1485 This internal method is used by the Redis backend to maintain session state. 

1486 It checks all local sessions, refreshes TTLs for connected sessions, and 

1487 removes disconnected ones. 

1488 """ 

1489 if not self._redis: 

1490 return 

1491 try: 

1492 # Check all local sessions 

1493 local_transports = {} 

1494 async with self._lock: 

1495 local_transports = self._sessions.copy() 

1496 

1497 for session_id, transport in local_transports.items(): 

1498 try: 

1499 if await transport.is_connected(): 

1500 # Refresh TTL in Redis 

1501 await self._redis.expire(f"mcp:session:{session_id}", self._session_ttl) 

1502 else: 

1503 # Remove disconnected session 

1504 await self.remove_session(session_id) 

1505 except Exception as e: 

1506 logger.error(f"Error refreshing session {session_id}: {e}") 

1507 

1508 except Exception as e: 

1509 logger.error(f"Error in Redis session refresh: {e}") 

1510 

1511 async def _db_cleanup_task(self) -> None: 

1512 """Background task to clean up expired database sessions. 

1513 

1514 Runs periodically (every 5 minutes) to remove expired sessions from the 

1515 database and refresh timestamps for active sessions. This prevents the 

1516 database from accumulating stale session records. 

1517 

1518 The task also verifies that local sessions still exist in the database 

1519 and removes them locally if they've been deleted elsewhere. 

1520 """ 

1521 logger.info("Starting database cleanup task") 

1522 while True: 

1523 try: 

1524 # Clean up expired sessions every 5 minutes 

1525 def _db_cleanup() -> int: 

1526 """Remove expired sessions from the database. 

1527 

1528 Deletes all SessionRecord entries that haven't been accessed 

1529 within the session TTL period. Uses database-specific date 

1530 arithmetic to calculate expiry time. 

1531 

1532 This inner function is designed to be run in a thread executor 

1533 to avoid blocking the async event loop during bulk deletes. 

1534 

1535 Returns: 

1536 int: Number of expired session records deleted. 

1537 

1538 Raises: 

1539 Exception: Any database error is re-raised after rollback. 

1540 

1541 Examples: 

1542 >>> # This function is called periodically by _db_cleanup_task() 

1543 >>> # Deletes sessions older than session_ttl seconds 

1544 >>> # Returns count of deleted records for logging 

1545 >>> # Log: "Cleaned up 5 expired database sessions" 

1546 """ 

1547 db_session = next(get_db()) 

1548 try: 

1549 # Delete sessions that haven't been accessed for TTL seconds 

1550 # Use Python datetime for database-agnostic expiry calculation 

1551 expiry_time = datetime.now(timezone.utc) - timedelta(seconds=self._session_ttl) 

1552 result = db_session.query(SessionRecord).filter(SessionRecord.last_accessed < expiry_time).delete() 

1553 db_session.commit() 

1554 return result 

1555 except Exception as ex: 

1556 db_session.rollback() 

1557 raise ex 

1558 finally: 

1559 db_session.close() 

1560 

1561 deleted = await asyncio.to_thread(_db_cleanup) 

1562 if deleted > 0: 

1563 logger.info(f"Cleaned up {deleted} expired database sessions") 

1564 

1565 # Check local sessions against database 

1566 await self._cleanup_database_sessions() 

1567 

1568 await asyncio.sleep(300) # Run every 5 minutes 

1569 

1570 except asyncio.CancelledError: 

1571 logger.info("Database cleanup task cancelled") 

1572 break 

1573 except Exception as e: 

1574 logger.error(f"Error in database cleanup task: {e}") 

1575 await asyncio.sleep(600) # Sleep longer on error 

1576 

1577 def _refresh_session_db(self, session_id: str) -> bool: 

1578 """Update session's last accessed timestamp in the database. 

1579 

1580 Refreshes the last_accessed field for an active session to 

1581 prevent it from being cleaned up as expired. This is called 

1582 periodically for all local sessions with active transports. 

1583 

1584 Args: 

1585 session_id: The session identifier to refresh. 

1586 

1587 Returns: 

1588 bool: True if the session was found and updated, False if not found. 

1589 

1590 Raises: 

1591 Exception: Any database error is re-raised after rollback. 
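
Examples:
    >>> # Invoked in a worker thread from _cleanup_database_sessions(), e.g.:
    >>> # ok = await asyncio.to_thread(self._refresh_session_db, "sid123")
    >>> # A False result means the row is gone; the caller then removes
    >>> # the local session.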

1592 """ 

1593 db_session = next(get_db()) 

1594 try: 

1595 session = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first() 

1596 if session: 

1597 session.last_accessed = func.now() # pylint: disable=not-callable 

1598 db_session.commit() 

1599 return True 

1600 return False 

1601 except Exception as ex: 

1602 db_session.rollback() 

1603 raise ex 

1604 finally: 

1605 db_session.close() 

1606 

1607 async def _cleanup_database_sessions(self, max_concurrent: int = 20) -> None: 

1608 """Parallelize session cleanup with bounded concurrency. 

1609 

1610 Checks connection status first (fast), then refreshes connected sessions 

1611 in parallel using asyncio.gather() with a semaphore to limit concurrent 

1612 DB operations and prevent resource exhaustion. 

1613 

1614 Args: 

1615 max_concurrent: Maximum number of concurrent DB refresh operations. 

1616 Defaults to 20 to balance parallelism with resource usage. 
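
Examples:
    A standalone sketch of the semaphore-bounded gather pattern used
    here (no registry or database involved):

    >>> import asyncio
    >>> async def demo():
    ...     sem = asyncio.Semaphore(2)
    ...     async def work(i):
    ...         async with sem:
    ...             return i * 2
    ...     return await asyncio.gather(*(work(i) for i in range(4)))
    >>> asyncio.run(demo())
    [0, 2, 4, 6]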

1617 """ 

1618 async with self._lock: 

1619 local_transports = self._sessions.copy() 

1620 

1621 # Check connections first (fast) 

1622 connected: list[str] = [] 

1623 for session_id, transport in local_transports.items(): 

1624 try: 

1625 if not await transport.is_connected(): 

1626 await self.remove_session(session_id) 

1627 else: 

1628 connected.append(session_id) 

1629 except Exception as e: 

1630 # Only log error, don't remove session on transient errors 

1631 logger.error(f"Error checking connection for session {session_id}: {e}") 

1632 

1633 # Parallel refresh of connected sessions with bounded concurrency 

1634 if connected: 

1635 semaphore = asyncio.Semaphore(max_concurrent) 

1636 

1637 async def bounded_refresh(session_id: str) -> bool: 

1638 """Refresh session with semaphore-bounded concurrency. 

1639 

1640 Args: 

1641 session_id: The session ID to refresh. 

1642 

1643 Returns: 

1644 True if refresh succeeded, False otherwise. 

1645 """ 

1646 async with semaphore: 

1647 return await asyncio.to_thread(self._refresh_session_db, session_id) 

1648 

1649 refresh_tasks = [bounded_refresh(session_id) for session_id in connected] 

1650 results = await asyncio.gather(*refresh_tasks, return_exceptions=True) 

1651 

1652 for session_id, result in zip(connected, results): 

1653 try: 

1654 if isinstance(result, Exception): 

1655 # Only log error, don't remove session on transient DB errors 

1656 logger.error(f"Error refreshing session {session_id}: {result}") 

1657 elif not result: 

1658 # Session no longer in database, remove locally 

1659 await self.remove_session(session_id) 

1660 except Exception as e: 

1661 logger.error(f"Error processing refresh result for session {session_id}: {e}") 

1662 

1663 async def _memory_cleanup_task(self) -> None: 

1664 """Background task to clean up disconnected sessions in memory backend. 

1665 

1666 Runs periodically (every minute) to check all local sessions and remove 

1667 those that are no longer connected. This prevents memory leaks from 

1668 accumulating disconnected transport objects. 
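
Examples:
    The per-session check in isolation, using a dummy transport (a
    sketch; the real task loops over all local sessions every minute):

    >>> import asyncio
    >>> class DeadTransport:
    ...     async def disconnect(self):
    ...         pass
    ...     async def is_connected(self):
    ...         return False
    >>> reg = SessionRegistry(backend='memory')
    >>> asyncio.run(reg.add_session('stale', DeadTransport()))
    >>> asyncio.run(reg.remove_session('stale'))  # what the loop does on disconnect
    >>> asyncio.run(reg.get_session('stale')) is None
    True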

1669 """ 

1670 logger.info("Starting memory cleanup task") 

1671 while True: 

1672 try: 

1673 # Check all local sessions 

1674 local_transports = {} 

1675 async with self._lock: 

1676 local_transports = self._sessions.copy() 

1677 

1678 for session_id, transport in local_transports.items(): 

1679 try: 

1680 if not await transport.is_connected(): 

1681 await self.remove_session(session_id) 

1682 except Exception as e: 

1683 logger.error(f"Error checking session {session_id}: {e}") 

1684 await self.remove_session(session_id) 

1685 

1686 await asyncio.sleep(60) # Run every minute 

1687 

1688 except asyncio.CancelledError: 

1689 logger.info("Memory cleanup task cancelled") 

1690 break 

1691 except Exception as e: 

1692 logger.error(f"Error in memory cleanup task: {e}") 

1693 await asyncio.sleep(300) # Sleep longer on error 

1694 

1695 def _get_oauth_experimental_config(self, server_id: str) -> Optional[Dict[str, Dict[str, Any]]]: 

1696 """Query OAuth configuration for a server (synchronous, run in threadpool). 

1697 

1698 This method queries the database for OAuth configuration and returns 

1699 RFC 9728-safe fields for advertising in MCP capabilities. 

1700 

1701 Args: 

1702 server_id: The server ID to query OAuth configuration for. 

1703 

1704 Returns: 

1705 Dict with 'oauth' key containing safe OAuth config, or None if not configured. 
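
Examples:
    Shape of the filtered payload for a hypothetical stored config (real
    values come from the server row's oauth_config):

    >>> cfg = {"authorization_server": "https://as.example.com", "scopes": ["mcp:read"]}
    >>> safe = {
    ...     "authorization_servers": [cfg["authorization_server"]],
    ...     "scopes_supported": cfg["scopes"],
    ...     "bearer_methods_supported": ["header"],
    ... }
    >>> sorted(safe)
    ['authorization_servers', 'bearer_methods_supported', 'scopes_supported']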

1706 """ 

1707 # First-Party 

1708 from mcpgateway.db import Server as DbServer # pylint: disable=import-outside-toplevel 

1709 from mcpgateway.db import SessionLocal # pylint: disable=import-outside-toplevel 

1710 

1711 db = SessionLocal() 

1712 try: 

1713 server = db.get(DbServer, server_id) 

1714 if server and getattr(server, "oauth_enabled", False) and getattr(server, "oauth_config", None): 

1715 # Filter oauth_config to RFC 9728-safe fields only (never expose secrets) 

1716 oauth_config = server.oauth_config 

1717 safe_oauth: Dict[str, Any] = {} 

1718 

1719 # Extract authorization servers 

1720 if oauth_config.get("authorization_servers"): 

1721 safe_oauth["authorization_servers"] = oauth_config["authorization_servers"] 

1722 elif oauth_config.get("authorization_server"): 

1723 safe_oauth["authorization_servers"] = [oauth_config["authorization_server"]] 

1724 

1725 # Extract scopes 

1726 scopes = oauth_config.get("scopes_supported") or oauth_config.get("scopes") 

1727 if scopes: 

1728 safe_oauth["scopes_supported"] = scopes 

1729 

1730 # Add bearer methods 

1731 safe_oauth["bearer_methods_supported"] = oauth_config.get("bearer_methods_supported", ["header"]) 

1732 

1733 if safe_oauth.get("authorization_servers"): 

1734 logger.debug(f"Advertising OAuth capability for server {server_id}") 

1735 return {"oauth": safe_oauth} 

1736 return None 

1737 finally: 

1738 db.close() 

1739 

1740 # Handle initialize logic 

1741 async def handle_initialize_logic(self, body: Dict[str, Any], session_id: Optional[str] = None, server_id: Optional[str] = None) -> InitializeResult: 

1742 """Process MCP protocol initialization request. 

1743 

1744 Validates the protocol version and returns server capabilities and information. 

1745 This method implements the MCP (Model Context Protocol) initialization handshake. 

1746 

1747 Args: 

1748 body: Request body containing protocol_version and optional client_info. 

1749 Expected keys: 'protocol_version' or 'protocolVersion', 'capabilities'. 

1750 session_id: Optional session ID to associate client capabilities with. 

1751 server_id: Optional server ID to query OAuth configuration for RFC 9728 support. 

1752 

1753 Returns: 

1754 InitializeResult containing protocol version, server capabilities, and server info. 

1755 

1756 Raises: 

1757 HTTPException: If protocol_version is missing (400 Bad Request with MCP error code -32002). 

1758 

1759 Examples: 

1760 >>> import asyncio 

1761 >>> from mcpgateway.cache.session_registry import SessionRegistry 

1762 >>> 

1763 >>> reg = SessionRegistry() 

1764 >>> body = {'protocol_version': '2025-06-18'} 

1765 >>> result = asyncio.run(reg.handle_initialize_logic(body)) 

1766 >>> result.protocol_version 

1767 '2025-06-18' 

1768 >>> result.server_info.name 

1769 'MCP_Gateway' 

1770 >>> 

1771 >>> # Missing protocol version 

1772 >>> try: 

1773 ... asyncio.run(reg.handle_initialize_logic({})) 

1774 ... except HTTPException as e: 

1775 ... e.status_code 

1776 400 
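>>>
>>> # The camelCase key sent by spec-compliant clients is accepted too:
>>> result = asyncio.run(reg.handle_initialize_logic({'protocolVersion': '2025-06-18'}))
>>> result.protocol_version
'2025-06-18'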

1777 """ 

1778 protocol_version = body.get("protocol_version") or body.get("protocolVersion") 

1779 client_capabilities = body.get("capabilities", {}) 

1780 # body.get("client_info") or body.get("clientInfo", {}) 

1781 

1782 if not protocol_version: 

1783 raise HTTPException( 

1784 status_code=status.HTTP_400_BAD_REQUEST, 

1785 detail="Missing protocol version", 

1786 headers={"MCP-Error-Code": "-32002"}, 

1787 ) 

1788 

1789 if protocol_version != settings.protocol_version: 

1790 logger.warning(f"Using non default protocol version: {protocol_version}") 

1791 

1792 # Store client capabilities if session_id provided 

1793 if session_id and client_capabilities: 

1794 await self.store_client_capabilities(session_id, client_capabilities) 

1795 logger.debug(f"Stored capabilities for session {session_id}: {client_capabilities}") 

1796 

1797 # Build experimental capabilities (including OAuth if configured) 

1798 experimental: Optional[Dict[str, Dict[str, Any]]] = None 

1799 

1800 # Query OAuth configuration if server_id is provided 

1801 if server_id: 

1802 try: 

1803 # Run synchronous DB query in threadpool to avoid blocking the event loop 

1804 experimental = await asyncio.to_thread(self._get_oauth_experimental_config, server_id) 

1805 except Exception as e: 

1806 logger.warning(f"Failed to query OAuth config for server {server_id}: {e}") 

1807 

1808 return InitializeResult( 

1809 protocolVersion=protocol_version, 

1810 capabilities=ServerCapabilities( 

1811 prompts={"listChanged": True}, 

1812 resources={"subscribe": True, "listChanged": True}, 

1813 tools={"listChanged": True}, 

1814 logging={}, 

1815 completions={}, # Advertise completions capability per MCP spec 

1816 experimental=experimental, # OAuth capability when configured 

1817 ), 

1818 serverInfo=Implementation(name=settings.app_name, version=__version__), 

1819 instructions=("MCP Gateway providing federated tools, resources and prompts. Use /admin interface for configuration."), 

1820 ) 

1821 

1822 async def store_client_capabilities(self, session_id: str, capabilities: Dict[str, Any]) -> None: 

1823 """Store client capabilities for a session. 

1824 

1825 Args: 

1826 session_id: The session ID 

1827 capabilities: Client capabilities dictionary from initialize request 

1828 """ 

1829 async with self._lock: 

1830 self._client_capabilities[session_id] = capabilities 

1831 logger.debug(f"Stored capabilities for session {session_id}") 

1832 

1833 async def get_client_capabilities(self, session_id: str) -> Optional[Dict[str, Any]]: 

1834 """Get client capabilities for a session. 

1835 

1836 Args: 

1837 session_id: The session ID 

1838 

1839 Returns: 

1840 Client capabilities dictionary, or None if not found 

1841 """ 

1842 async with self._lock: 

1843 return self._client_capabilities.get(session_id) 

1844 

1845 async def has_elicitation_capability(self, session_id: str) -> bool: 

1846 """Check if a session has elicitation capability. 

1847 

1848 Args: 

1849 session_id: The session ID 

1850 

1851 Returns: 

1852 True if session supports elicitation, False otherwise 
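
Examples:
    A sketch with an illustrative capability payload (the exact shape
    comes from the client's initialize request):

    >>> import asyncio
    >>> reg = SessionRegistry(backend='memory')
    >>> asyncio.run(reg.store_client_capabilities('s1', {'elicitation': {'supported': True}}))
    >>> asyncio.run(reg.has_elicitation_capability('s1'))
    True
    >>> asyncio.run(reg.has_elicitation_capability('unknown'))
    False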

1853 """ 

1854 capabilities = await self.get_client_capabilities(session_id) 

1855 if not capabilities: 

1856 return False 

1857 # Check if elicitation capability exists in client capabilities 

1858 return bool(capabilities.get("elicitation")) 

1859 

1860 async def get_elicitation_capable_sessions(self) -> list[str]: 

1861 """Get list of session IDs that support elicitation. 

1862 

1863 Returns: 

1864 List of session IDs with elicitation capability 

1865 """ 

1866 async with self._lock: 

1867 capable_sessions = [] 

1868 for session_id, capabilities in self._client_capabilities.items(): 

1869 if capabilities.get("elicitation"): 

1870 # Verify session still exists 

1871 if session_id in self._sessions: 

1872 capable_sessions.append(session_id) 

1873 return capable_sessions 

1874 

1875 async def generate_response(self, message: Dict[str, Any], transport: SSETransport, server_id: Optional[str], user: Dict[str, Any], base_url: str) -> None: 

1876 """Generate and send response for incoming MCP protocol message. 

1877 

1878 Processes MCP protocol messages and generates appropriate responses based on 

1879 the method. Supports various MCP methods including initialization, tool/resource/prompt 

1880 listing, tool invocation, and ping. 

1881 

1882 Args: 

1883 message: Incoming MCP message as a parsed dict. Requests must contain 'method' and 'id' fields; messages missing either are silently ignored. 

1884 transport: SSE transport to send responses through. 

1885 server_id: Optional server ID for scoped operations. 

1886 user: User information containing authentication token. 

1887 base_url: Base URL for constructing RPC endpoints. 

1888 

1889 Examples: 

1890 >>> import asyncio 

1891 >>> from mcpgateway.cache.session_registry import SessionRegistry 

1892 >>> 

1893 >>> class MockTransport: 

1894 ... async def send_message(self, msg): 

1895 ... print(f"Response: {msg['method'] if 'method' in msg else msg.get('result', {})}") 

1896 >>> 

1897 >>> reg = SessionRegistry() 

1898 >>> transport = MockTransport() 

1899 >>> message = {"method": "ping", "id": 1} 

1900 >>> user = {"token": "test-token"} 

1901 >>> # asyncio.run(reg.generate_response(message, transport, None, user, "http://localhost")) 

1902 >>> # Response: {} 

1903 """ 

1904 result = {} 

1905 

1906 if "method" in message and "id" in message: 

1907 method = message["method"] 

1908 params = message.get("params", {}) 

1909 params["server_id"] = server_id 

1910 req_id = message["id"] 

1911 

1912 rpc_input = { 

1913 "jsonrpc": "2.0", 

1914 "method": method, 

1915 "params": params, 

1916 "id": req_id, 

1917 } 

1918 # Get the token from the current authentication context 

1919 # The user object should contain auth_token, token_teams, and is_admin from the SSE endpoint 

1920 token = None 

1921 is_admin = user.get("is_admin", False) # Preserve admin status from SSE endpoint 

1922 

1923 try: 

1924 if hasattr(user, "get") and user.get("auth_token"): 

1925 token = user["auth_token"] 

1926 else: 

1927 # Fallback: create lightweight session token (teams resolved server-side by downstream /rpc) 

1928 logger.warning("No auth token available for SSE RPC call - creating fallback session token") 

1929 now = datetime.now(timezone.utc) 

1930 payload = { 

1931 "sub": user.get("email", "system"), 

1932 "iss": settings.jwt_issuer, 

1933 "aud": settings.jwt_audience, 

1934 "iat": int(now.timestamp()), 

1935 "jti": str(uuid.uuid4()), 

1936 "token_use": "session", # nosec B105 - token type marker, not a password 

1937 "user": { 

1938 "email": user.get("email", "system"), 

1939 "full_name": user.get("full_name", "System"), 

1940 "is_admin": is_admin, # Preserve admin status for cookie-authenticated admins 

1941 "auth_provider": "internal", 

1942 }, 

1943 } 

1944 # Generate token using centralized token creation 

1945 token = await create_jwt_token(payload) 

1946 

1947 # Pass downstream session id to /rpc for session affinity. 

1948 # This is gateway-internal only; the pool strips it before contacting upstream MCP servers. 

1949 if settings.mcpgateway_session_affinity_enabled: 

1950 await self._register_session_mapping(transport.session_id, message, user.get("email") if hasattr(user, "get") else None) 

1951 

1952 headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} 

1953 if settings.mcpgateway_session_affinity_enabled: 

1954 headers["x-mcp-session-id"] = transport.session_id 

1955 # Extract root URL from base_url (remove /servers/{id} path) 

1956 parsed_url = urlparse(base_url) 

1957 # Preserve the path up to the root path (before /servers/{id}) 

1958 path_parts = parsed_url.path.split("/") 

1959 if "/servers/" in parsed_url.path: 

1960 # Find the index of 'servers' and take everything before it 

1961 try: 

1962 servers_index = path_parts.index("servers") 

1963 root_path = "/" + "/".join(path_parts[1:servers_index]).strip("/") 

1964 if root_path == "/": 

1965 root_path = "" 

1966 except ValueError: 

1967 root_path = "" 

1968 else: 

1969 root_path = parsed_url.path.rstrip("/") 

1970 

1971 root_url = f"{parsed_url.scheme}://{parsed_url.netloc}{root_path}" 

1972 rpc_url = root_url + "/rpc" 

1973 

1974 logger.info(f"SSE RPC: Making call to {rpc_url} with method={method}, params={params}") 

1975 

1976 async with ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify}) as client: 

1977 logger.info(f"SSE RPC: Sending request to {rpc_url}") 

1978 rpc_response = await client.post( 

1979 url=rpc_url, 

1980 json=rpc_input, 

1981 headers=headers, 

1982 ) 

1983 logger.info(f"SSE RPC: Got response status {rpc_response.status_code}") 

1984 result = rpc_response.json() 

1985 logger.info(f"SSE RPC: Response content: {result}") 

1986 result = result.get("result", {}) 

1987 

1988 response = {"jsonrpc": "2.0", "result": result, "id": req_id} 

1989 except JSONRPCError as e: 

1990 logger.error(f"SSE RPC: JSON-RPC error: {e}") 

1991 result = e.to_dict() 

1992 response = {"jsonrpc": "2.0", "error": result["error"], "id": req_id} 

1993 except Exception as e: 

1994 logger.error(f"SSE RPC: Exception during RPC call: {type(e).__name__}: {e}") 

1995 logger.error(f"SSE RPC: Traceback: {traceback.format_exc()}") 

1996 result = {"code": -32000, "message": "Internal error", "data": str(e)} 

1997 response = {"jsonrpc": "2.0", "error": result, "id": req_id} 

1998 

1999 logging.debug(f"Sending sse message:{response}") 

2000 await transport.send_message(response) 

2001 

2002 if message["method"] == "initialize": 

2003 await transport.send_message( 

2004 { 

2005 "jsonrpc": "2.0", 

2006 "method": "notifications/initialized", 

2007 "params": {}, 

2008 } 

2009 ) 

2010 notifications = [ 

2011 "tools/list_changed", 

2012 "resources/list_changed", 

2013 "prompts/list_changed", 

2014 ] 

2015 for notification in notifications: 

2016 await transport.send_message( 

2017 { 

2018 "jsonrpc": "2.0", 

2019 "method": f"notifications/{notification}", 

2020 "params": {}, 

2021 } 

2022 )