Coverage for mcpgateway / cache / session_registry.py: 100%

974 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/cache/session_registry.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Session Registry with optional distributed state. 

8This module provides a registry for SSE sessions with support for distributed deployment 

9using Redis or SQLAlchemy as optional backends for shared state between workers. 

10 

11The SessionRegistry class manages server-sent event (SSE) sessions across multiple 

12worker processes, enabling horizontal scaling of MCP gateway deployments. It supports 

13three backend modes: 

14 

15- **memory**: In-memory storage for single-process deployments (default) 

16- **redis**: Redis-backed shared storage for multi-worker deployments 

17- **database**: SQLAlchemy-backed shared storage using any supported database 

18 

19In distributed mode (redis/database), session existence is tracked in the shared 

20backend while transport objects remain local to each worker process. This allows 

21workers to know about sessions on other workers and route messages appropriately. 

22 

23Examples: 

24 Basic usage with memory backend: 

25 

26 >>> from mcpgateway.cache.session_registry import SessionRegistry 

27 >>> class DummyTransport: 

28 ... async def disconnect(self): 

29 ... pass 

30 ... async def is_connected(self): 

31 ... return True 

32 >>> import asyncio 

33 >>> reg = SessionRegistry(backend='memory') 

34 >>> transport = DummyTransport() 

35 >>> asyncio.run(reg.add_session('sid123', transport)) 

36 >>> found = asyncio.run(reg.get_session('sid123')) 

37 >>> isinstance(found, DummyTransport) 

38 True 

39 >>> asyncio.run(reg.remove_session('sid123')) 

40 >>> asyncio.run(reg.get_session('sid123')) is None 

41 True 

42 

43 Broadcasting messages: 

44 

45 >>> reg = SessionRegistry(backend='memory') 

46 >>> asyncio.run(reg.broadcast('sid123', {'method': 'ping', 'id': 1})) 

47 >>> reg._session_message is not None 

48 True 

49""" 

50 

51# Standard 

52import asyncio 

53from asyncio import Task 

54from datetime import datetime, timedelta, timezone 

55import logging 

56import time 

57import traceback 

58from typing import Any, Dict, Optional 

59import uuid 

60 

61# Third-Party 

62from fastapi import HTTPException, status 

63import orjson 

64 

65# First-Party 

66from mcpgateway import __version__ 

67from mcpgateway.common.models import Implementation, InitializeResult, ServerCapabilities 

68from mcpgateway.config import settings 

69from mcpgateway.db import get_db, SessionMessageRecord, SessionRecord 

70from mcpgateway.services import PromptService, ResourceService, ToolService 

71from mcpgateway.services.logging_service import LoggingService 

72from mcpgateway.transports import SSETransport 

73from mcpgateway.utils.create_jwt_token import create_jwt_token 

74from mcpgateway.utils.internal_http import internal_loopback_base_url, internal_loopback_verify 

75from mcpgateway.utils.redis_client import get_redis_client 

76from mcpgateway.utils.retry_manager import ResilientHttpClient 

77from mcpgateway.validation.jsonrpc import JSONRPCError 

78 

# The logging service is created before anything else so that this module's
# logger is available to every definition below.
logging_service: LoggingService = LoggingService()
logger = logging_service.get_logger(__name__)

# Module-level service instances.
tool_service: ToolService = ToolService()
resource_service: ResourceService = ResourceService()
prompt_service: PromptService = PromptService()

# Optional dependency probe: the redis package. REDIS_AVAILABLE records
# whether the import succeeded; SessionBackend checks it for the 'redis' backend.
try:
    # Third-Party
    from redis.asyncio import Redis
except ImportError:
    REDIS_AVAILABLE = False
else:
    REDIS_AVAILABLE = True

# Optional dependency probe: SQLAlchemy. SQLALCHEMY_AVAILABLE records whether
# the import succeeded; SessionBackend checks it for the 'database' backend.
try:
    # Third-Party
    from sqlalchemy import func
except ImportError:
    SQLALCHEMY_AVAILABLE = False
else:
    SQLALCHEMY_AVAILABLE = True

102 

103 

class SessionBackend:
    """Backend configuration holder for the session registry.

    Normalizes and validates the requested backend type, stores the session
    and message TTLs, and creates the placeholder attributes that the chosen
    backend uses later.

    Attributes:
        _backend: Lower-cased backend name ('memory', 'redis', 'database', or 'none').
        _session_ttl: Time-to-live for sessions in seconds.
        _message_ttl: Time-to-live for messages in seconds.
        _redis: Redis client placeholder, set to None here (redis backend only).
        _pubsub: Redis pubsub placeholder, set to None here (redis backend only).
        _session_message: Message slot, set to None here (memory backend only).

    Examples:
        >>> backend = SessionBackend(backend='memory')
        >>> backend._backend
        'memory'
        >>> backend._session_ttl
        3600

        >>> try:
        ...     backend = SessionBackend(backend='redis')
        ... except ValueError as e:
        ...     str(e)
        'Redis backend requires redis_url'
    """

    def __init__(
        self,
        backend: str = "memory",
        redis_url: Optional[str] = None,
        database_url: Optional[str] = None,
        session_ttl: int = 3600,  # 1 hour
        message_ttl: int = 600,  # 10 min
    ):
        """Validate and record the backend configuration.

        Args:
            backend: Backend type, compared case-insensitively. One of:
                - 'memory': in-memory storage
                - 'redis': Redis-backed storage
                - 'database': SQLAlchemy-backed storage
                - 'none': no session tracking (dummy registry)
            redis_url: Redis connection URL. Must be non-empty when backend='redis'.
            database_url: Database connection URL. Must be non-empty when
                backend='database'.
            session_ttl: Session time-to-live in seconds. Default: 3600 (1 hour).
            message_ttl: Message time-to-live in seconds. Default: 600 (10 minutes).

        Raises:
            ValueError: If backend is invalid, required URL is missing, or required packages are not installed.

        Examples:
            >>> # Memory backend (default)
            >>> backend = SessionBackend()
            >>> backend._backend
            'memory'

            >>> # Redis backend requires URL
            >>> try:
            ...     backend = SessionBackend(backend='redis')
            ... except ValueError as e:
            ...     'redis_url' in str(e)
            True

            >>> # Invalid backend
            >>> try:
            ...     backend = SessionBackend(backend='invalid')
            ... except ValueError as e:
            ...     'Invalid backend' in str(e)
            True
        """
        self._backend = backend.lower()
        self._session_ttl = session_ttl
        self._message_ttl = message_ttl

        kind = self._backend

        # Reject unknown names up front; the message echoes the caller's
        # original (non-normalized) spelling.
        if kind not in ("memory", "none", "redis", "database"):
            raise ValueError(f"Invalid backend: {backend}")

        if kind == "memory":
            # Only the memory backend carries the in-process message slot.
            self._session_message: dict[str, Any] | None = None
            return

        if kind == "none":
            # Dummy registry: nothing to set up.
            logger.info("Session registry initialized with 'none' backend - session tracking disabled")
            return

        if kind == "redis":
            if not REDIS_AVAILABLE:
                raise ValueError("Redis backend requested but redis package not installed")
            if not redis_url:
                raise ValueError("Redis backend requires redis_url")

            # The actual client is assigned later, in initialize().
            self._redis: Optional[Redis] = None
            self._pubsub = None
            return

        # Remaining case: database backend.
        if not SQLALCHEMY_AVAILABLE:
            raise ValueError("Database backend requested but SQLAlchemy not installed")
        if not database_url:
            raise ValueError("Database backend requires database_url")

212 

213 

214class SessionRegistry(SessionBackend): 

215 """Registry for SSE sessions with optional distributed state. 

216 

217 This class manages server-sent event (SSE) sessions, providing methods to add, 

218 remove, and query sessions. It supports multiple backend types for different 

219 deployment scenarios: 

220 

221 - **Single-process deployments**: Use 'memory' backend (default) 

222 - **Multi-worker deployments**: Use 'redis' or 'database' backend 

223 - **Testing/development**: Use 'none' backend to disable session tracking 

224 

225 The registry maintains a local cache of transport objects while using the 

226 shared backend to track session existence across workers. This enables 

227 horizontal scaling while keeping transport objects process-local. 

228 

229 Attributes: 

230 _sessions: Local dictionary mapping session IDs to transport objects 

231 _lock: Asyncio lock for thread-safe access to _sessions 

232 _cleanup_task: Background task for cleaning up expired sessions 

233 

234 Examples: 

235 >>> import asyncio 

236 >>> from mcpgateway.cache.session_registry import SessionRegistry 

237 >>> 

238 >>> class MockTransport: 

239 ... async def disconnect(self): 

240 ... print("Disconnected") 

241 ... async def is_connected(self): 

242 ... return True 

243 ... async def send_message(self, msg): 

244 ... print(f"Sent: {msg}") 

245 >>> 

246 >>> # Create registry and add session 

247 >>> reg = SessionRegistry(backend='memory') 

248 >>> transport = MockTransport() 

249 >>> asyncio.run(reg.add_session('test123', transport)) 

250 >>> 

251 >>> # Retrieve session 

252 >>> found = asyncio.run(reg.get_session('test123')) 

253 >>> found is transport 

254 True 

255 >>> 

256 >>> # Remove session 

257 >>> asyncio.run(reg.remove_session('test123')) 

258 Disconnected 

259 >>> asyncio.run(reg.get_session('test123')) is None 

260 True 

261 """ 

262 

263 def __init__( 

264 self, 

265 backend: str = "memory", 

266 redis_url: Optional[str] = None, 

267 database_url: Optional[str] = None, 

268 session_ttl: int = 3600, # 1 hour 

269 message_ttl: int = 600, # 10 min 

270 ): 

271 """Initialize session registry with specified backend. 

272 

273 Args: 

274 backend: Backend type. Must be one of 'memory', 'redis', 'database', or 'none'. 

275 redis_url: Redis connection URL. Required when backend='redis'. 

276 database_url: Database connection URL. Required when backend='database'. 

277 session_ttl: Session time-to-live in seconds. Default: 3600. 

278 message_ttl: Message time-to-live in seconds. Default: 600. 

279 

280 Examples: 

281 >>> # Default memory backend 

282 >>> reg = SessionRegistry() 

283 >>> reg._backend 

284 'memory' 

285 >>> isinstance(reg._sessions, dict) 

286 True 

287 

288 >>> # Redis backend with custom TTL 

289 >>> try: 

290 ... reg = SessionRegistry( 

291 ... backend='redis', 

292 ... redis_url='redis://localhost:6379', 

293 ... session_ttl=7200 

294 ... ) 

295 ... except ValueError: 

296 ... pass # Redis may not be available 

297 """ 

298 super().__init__(backend=backend, redis_url=redis_url, database_url=database_url, session_ttl=session_ttl, message_ttl=message_ttl) 

299 self._sessions: Dict[str, Any] = {} # Local transport cache 

300 self._session_owners: Dict[str, str] = {} # Session owner email by session_id 

301 self._client_capabilities: Dict[str, Dict[str, Any]] = {} # Client capabilities by session_id 

302 self._respond_tasks: Dict[str, asyncio.Task] = {} # Track respond tasks for cancellation 

303 self._stuck_tasks: Dict[str, asyncio.Task] = {} # Tasks that couldn't be cancelled (for monitoring) 

304 self._closing_sessions: set[str] = set() # Sessions being closed - respond loop should exit 

305 self._lock = asyncio.Lock() 

306 self._cleanup_task: Task | None = None 

307 self._stuck_task_reaper: Task | None = None # Reaper for stuck tasks 

308 

309 def register_respond_task(self, session_id: str, task: asyncio.Task) -> None: 

310 """Register a respond task for later cancellation. 

311 

312 Associates an asyncio Task with a session_id so it can be cancelled 

313 when the session is removed. This prevents orphaned tasks that cause 

314 CPU spin loops. 

315 

316 Args: 

317 session_id: Session identifier the task belongs to. 

318 task: The asyncio Task to track. 

319 """ 

320 self._respond_tasks[session_id] = task 

321 logger.debug(f"Registered respond task for session {session_id}") 

322 

323 async def _cancel_respond_task(self, session_id: str, timeout: float = 5.0) -> None: 

324 """Cancel and await a respond task with timeout. 

325 

326 Safely cancels the respond task associated with a session. Uses a timeout 

327 to prevent hanging if the task doesn't respond to cancellation. 

328 

329 If initial cancellation times out, escalates by force-disconnecting the 

330 transport to unblock the task, then retries cancellation (Finding 1 fix). 

331 

332 Args: 

333 session_id: Session identifier whose task should be cancelled. 

334 timeout: Maximum seconds to wait for task cancellation. Default 5.0. 

335 """ 

336 task = self._respond_tasks.get(session_id) 

337 if task is None: 

338 return 

339 

340 if task.done(): 

341 # Task already finished - safe to remove from tracking 

342 self._respond_tasks.pop(session_id, None) 

343 try: 

344 task.result() 

345 except asyncio.CancelledError: 

346 pass 

347 except Exception as e: 

348 logger.warning(f"Respond task for {session_id} failed with: {e}") 

349 return 

350 

351 task.cancel() 

352 logger.debug(f"Cancelling respond task for session {session_id}") 

353 

354 try: 

355 await asyncio.wait_for(task, timeout=timeout) 

356 # Cancellation successful - remove from tracking 

357 self._respond_tasks.pop(session_id, None) 

358 logger.debug(f"Respond task cancelled for session {session_id}") 

359 except asyncio.TimeoutError: 

360 # ESCALATION (Finding 1): Force-disconnect transport to unblock the task 

361 logger.warning(f"Respond task cancellation timed out for {session_id}, " f"escalating with transport disconnect") 

362 

363 # Force-disconnect the transport to unblock any pending I/O 

364 transport = self._sessions.get(session_id) 

365 if transport and hasattr(transport, "disconnect"): 

366 try: 

367 await transport.disconnect() 

368 logger.debug(f"Force-disconnected transport for {session_id}") 

369 except Exception as e: 

370 logger.warning(f"Failed to force-disconnect transport for {session_id}: {e}") 

371 

372 # Retry cancellation with shorter timeout 

373 if not task.done(): 

374 try: 

375 await asyncio.wait_for(task, timeout=2.0) 

376 self._respond_tasks.pop(session_id, None) 

377 logger.info(f"Respond task cancelled after escalation for {session_id}") 

378 except asyncio.TimeoutError: 

379 # Still stuck - move to stuck_tasks for monitoring (Finding 2 fix) 

380 self._respond_tasks.pop(session_id, None) 

381 self._stuck_tasks[session_id] = task 

382 logger.error(f"Respond task for {session_id} still stuck after escalation, " f"moved to stuck_tasks for monitoring (total stuck: {len(self._stuck_tasks)})") 

383 except asyncio.CancelledError: 

384 self._respond_tasks.pop(session_id, None) 

385 logger.info(f"Respond task cancelled after escalation for {session_id}") 

386 except Exception as e: 

387 self._respond_tasks.pop(session_id, None) 

388 logger.warning(f"Error during retry cancellation for {session_id}: {e}") 

389 else: 

390 self._respond_tasks.pop(session_id, None) 

391 logger.debug(f"Respond task completed during escalation for {session_id}") 

392 

393 except asyncio.CancelledError: 

394 # Cancellation successful - remove from tracking 

395 self._respond_tasks.pop(session_id, None) 

396 logger.debug(f"Respond task cancelled for session {session_id}") 

397 except Exception as e: 

398 # Remove from tracking on unexpected error - task state unknown 

399 self._respond_tasks.pop(session_id, None) 

400 logger.warning(f"Error during respond task cancellation for {session_id}: {e}") 

401 

402 async def _reap_stuck_tasks(self) -> None: 

403 """Periodically clean up stuck tasks that have completed. 

404 

405 This reaper runs every 30 seconds and: 

406 1. Removes completed tasks from _stuck_tasks 

407 2. Retries cancellation for tasks that are still running 

408 3. Logs warnings for tasks that remain stuck 

409 

410 This prevents memory leaks from tasks that eventually complete after 

411 being moved to _stuck_tasks during escalation. 

412 

413 Raises: 

414 asyncio.CancelledError: If the task is cancelled during shutdown. 

415 """ 

416 reap_interval = 30.0 # seconds 

417 retry_timeout = 2.0 # seconds for retry cancellation 

418 

419 while True: 

420 try: 

421 await asyncio.sleep(reap_interval) 

422 

423 if not self._stuck_tasks: 

424 continue 

425 

426 # Collect completed and still-stuck tasks 

427 completed = [] 

428 still_stuck = [] 

429 

430 for session_id, task in list(self._stuck_tasks.items()): 

431 if task.done(): 

432 completed.append(session_id) 

433 try: 

434 task.result() # Consume result to avoid warnings 

435 except (asyncio.CancelledError, Exception): 

436 pass 

437 else: 

438 still_stuck.append((session_id, task)) 

439 

440 # Remove completed tasks 

441 for session_id in completed: 

442 self._stuck_tasks.pop(session_id, None) 

443 

444 if completed: 

445 logger.info(f"Reaped {len(completed)} completed stuck tasks") 

446 

447 # Retry cancellation for still-stuck tasks 

448 for session_id, task in still_stuck: 

449 task.cancel() 

450 try: 

451 await asyncio.wait_for(task, timeout=retry_timeout) 

452 self._stuck_tasks.pop(session_id, None) 

453 logger.info(f"Stuck task {session_id} finally cancelled during reap") 

454 except asyncio.TimeoutError: 

455 logger.warning(f"Task {session_id} still stuck after reap retry") 

456 except asyncio.CancelledError: 

457 self._stuck_tasks.pop(session_id, None) 

458 logger.info(f"Stuck task {session_id} cancelled during reap") 

459 except Exception as e: 

460 logger.warning(f"Error during stuck task reap for {session_id}: {e}") 

461 

462 if self._stuck_tasks: 

463 logger.warning(f"Stuck tasks remaining: {len(self._stuck_tasks)}") 

464 

465 except asyncio.CancelledError: 

466 logger.debug("Stuck task reaper cancelled") 

467 raise 

468 except Exception as e: 

469 logger.error(f"Error in stuck task reaper: {e}") 

470 

471 async def initialize(self) -> None: 

472 """Initialize the registry with async setup. 

473 

474 This method performs asynchronous initialization tasks that cannot be done 

475 in __init__. It starts background cleanup tasks and sets up pubsub 

476 subscriptions for distributed backends. 

477 

478 Call this during application startup after creating the registry instance. 

479 

480 Examples: 

481 >>> import asyncio 

482 >>> reg = SessionRegistry(backend='memory') 

483 >>> asyncio.run(reg.initialize()) 

484 >>> reg._cleanup_task is not None 

485 True 

486 >>> 

487 >>> # Cleanup 

488 >>> asyncio.run(reg.shutdown()) 

489 """ 

490 logger.info(f"Initializing session registry with backend: {self._backend}") 

491 

492 if self._backend == "database": 

493 # Start database cleanup task 

494 self._cleanup_task = asyncio.create_task(self._db_cleanup_task()) 

495 logger.info("Database cleanup task started") 

496 

497 elif self._backend == "redis": 

498 # Get shared Redis client from factory 

499 self._redis = await get_redis_client() 

500 if self._redis: 

501 self._pubsub = self._redis.pubsub() 

502 await self._pubsub.subscribe("mcp_session_events") 

503 logger.info("Session registry connected to shared Redis client") 

504 

505 elif self._backend == "none": 

506 # Nothing to initialize for none backend 

507 pass 

508 

509 # Memory backend needs session cleanup 

510 elif self._backend == "memory": 

511 self._cleanup_task = asyncio.create_task(self._memory_cleanup_task()) 

512 logger.info("Memory cleanup task started") 

513 

514 # Start stuck task reaper for all backends 

515 self._stuck_task_reaper = asyncio.create_task(self._reap_stuck_tasks()) 

516 logger.info("Stuck task reaper started") 

517 

518 async def shutdown(self) -> None: 

519 """Shutdown the registry and clean up resources. 

520 

521 This method cancels background tasks and closes connections to external 

522 services. Call this during application shutdown to ensure clean termination. 

523 

524 Examples: 

525 >>> import asyncio 

526 >>> reg = SessionRegistry() 

527 >>> asyncio.run(reg.initialize()) 

528 >>> task_was_created = reg._cleanup_task is not None 

529 >>> asyncio.run(reg.shutdown()) 

530 >>> # After shutdown, cleanup task should be handled (cancelled or done) 

531 >>> task_was_created and (reg._cleanup_task.cancelled() or reg._cleanup_task.done()) 

532 True 

533 """ 

534 logger.info("Shutting down session registry") 

535 

536 # Cancel cleanup task 

537 if self._cleanup_task: 

538 self._cleanup_task.cancel() 

539 try: 

540 await self._cleanup_task 

541 except asyncio.CancelledError: 

542 pass 

543 

544 # Cancel stuck task reaper 

545 if self._stuck_task_reaper: 

546 self._stuck_task_reaper.cancel() 

547 try: 

548 await self._stuck_task_reaper 

549 except asyncio.CancelledError: 

550 pass 

551 

552 # CRITICAL: Cancel ALL respond tasks to prevent CPU spin loops 

553 if self._respond_tasks: 

554 logger.info(f"Cancelling {len(self._respond_tasks)} respond tasks") 

555 tasks_to_cancel = list(self._respond_tasks.values()) 

556 self._respond_tasks.clear() 

557 

558 for task in tasks_to_cancel: 

559 if not task.done(): 

560 task.cancel() 

561 

562 if tasks_to_cancel: 

563 try: 

564 await asyncio.wait_for(asyncio.gather(*tasks_to_cancel, return_exceptions=True), timeout=10.0) 

565 logger.info("All respond tasks cancelled successfully") 

566 except asyncio.TimeoutError: 

567 logger.warning("Timeout waiting for respond tasks to cancel") 

568 

569 # Also cancel any stuck tasks (tasks that previously couldn't be cancelled) 

570 if self._stuck_tasks: 

571 logger.warning(f"Attempting final cancellation of {len(self._stuck_tasks)} stuck tasks") 

572 stuck_to_cancel = list(self._stuck_tasks.values()) 

573 self._stuck_tasks.clear() 

574 

575 for task in stuck_to_cancel: 

576 if not task.done(): 

577 task.cancel() 

578 

579 if stuck_to_cancel: 

580 try: 

581 await asyncio.wait_for(asyncio.gather(*stuck_to_cancel, return_exceptions=True), timeout=5.0) 

582 logger.info("Stuck tasks cancelled during shutdown") 

583 except asyncio.TimeoutError: 

584 logger.error("Some stuck tasks could not be cancelled during shutdown") 

585 

586 # Close Redis pubsub (but not the shared client) 

587 # Use timeout to prevent blocking if pubsub doesn't close cleanly 

588 cleanup_timeout = settings.mcp_session_pool_cleanup_timeout 

589 if self._backend == "redis" and getattr(self, "_pubsub", None): 

590 try: 

591 await asyncio.wait_for(self._pubsub.aclose(), timeout=cleanup_timeout) 

592 except asyncio.TimeoutError: 

593 logger.warning("Redis pubsub close timed out - proceeding anyway") 

594 except Exception as e: 

595 logger.error(f"Error closing Redis pubsub: {e}") 

596 # Don't close self._redis - it's the shared client managed by redis_client.py 

597 self._redis = None 

598 self._pubsub = None 

599 

600 async def add_session(self, session_id: str, transport: SSETransport) -> None: 

601 """Add a session to the registry. 

602 

603 Stores the session in both the local cache and the distributed backend 

604 (if configured). For distributed backends, this notifies other workers 

605 about the new session. 

606 

607 Args: 

608 session_id: Unique session identifier. Should be a UUID or similar 

609 unique string to avoid collisions. 

610 transport: SSE transport object for this session. Must implement 

611 the SSETransport interface. 

612 

613 Examples: 

614 >>> import asyncio 

615 >>> from mcpgateway.cache.session_registry import SessionRegistry 

616 >>> 

617 >>> class MockTransport: 

618 ... async def disconnect(self): 

619 ... print(f"Transport disconnected") 

620 ... async def is_connected(self): 

621 ... return True 

622 >>> 

623 >>> reg = SessionRegistry() 

624 >>> transport = MockTransport() 

625 >>> asyncio.run(reg.add_session('test-456', transport)) 

626 >>> 

627 >>> # Found in local cache 

628 >>> found = asyncio.run(reg.get_session('test-456')) 

629 >>> found is transport 

630 True 

631 >>> 

632 >>> # Remove session 

633 >>> asyncio.run(reg.remove_session('test-456')) 

634 Transport disconnected 

635 """ 

636 # Skip for none backend 

637 if self._backend == "none": 

638 return 

639 

640 async with self._lock: 

641 self._sessions[session_id] = transport 

642 

643 if self._backend == "redis": 

644 # Store session marker in Redis 

645 if not self._redis: 

646 logger.warning(f"Redis client not initialized, skipping distributed session tracking for {session_id}") 

647 return 

648 try: 

649 await self._redis.setex(f"mcp:session:{session_id}", self._session_ttl, "1") 

650 # Publish event to notify other workers 

651 await self._redis.publish("mcp_session_events", orjson.dumps({"type": "add", "session_id": session_id, "timestamp": time.time()})) 

652 except Exception as e: 

653 logger.error(f"Redis error adding session {session_id}: {e}") 

654 

655 elif self._backend == "database": 

656 # Store session in database 

657 try: 

658 

659 def _db_add() -> None: 

660 """Store session record in the database. 

661 

662 Creates a new SessionRecord entry in the database for tracking 

663 distributed session state. Uses a fresh database connection from 

664 the connection pool. 

665 

666 This inner function is designed to be run in a thread executor 

667 to avoid blocking the async event loop during database I/O. 

668 

669 Raises: 

670 Exception: Any database error is re-raised after rollback. 

671 Common errors include duplicate session_id (unique constraint) 

672 or database connection issues. 

673 

674 Examples: 

675 >>> # This function is called internally by add_session() 

676 >>> # When executed, it creates a database record: 

677 >>> # SessionRecord(session_id='abc123', created_at=now()) 

678 """ 

679 db_session = next(get_db()) 

680 try: 

681 session_record = SessionRecord(session_id=session_id) 

682 db_session.add(session_record) 

683 db_session.commit() 

684 except Exception as ex: 

685 db_session.rollback() 

686 raise ex 

687 finally: 

688 db_session.close() 

689 

690 await asyncio.to_thread(_db_add) 

691 except Exception as e: 

692 logger.error(f"Database error adding session {session_id}: {e}") 

693 

694 logger.info(f"Added session: {session_id}") 

695 

696 def _session_owner_key(self, session_id: str) -> str: 

697 """Return Redis key used to store session ownership. 

698 

699 Args: 

700 session_id: Session identifier. 

701 

702 Returns: 

703 Redis key for the session owner mapping. 

704 """ 

705 return f"mcp:session_owner:{session_id}" 

706 

707 async def set_session_owner(self, session_id: str, owner_email: Optional[str]) -> None: 

708 """Set or clear owner for a session. 

709 

710 Args: 

711 session_id: Session identifier to update. 

712 owner_email: Owner email to set. Passing ``None`` clears ownership. 

713 

714 Returns: 

715 None. 

716 """ 

717 # Skip for none backend 

718 if self._backend == "none": 

719 return 

720 

721 if owner_email: 

722 self._session_owners[session_id] = owner_email 

723 else: 

724 self._session_owners.pop(session_id, None) 

725 

726 if self._backend == "redis": 

727 if not self._redis: 

728 logger.warning(f"Redis client not initialized, cannot set owner for session {session_id}") 

729 return 

730 try: 

731 owner_key = self._session_owner_key(session_id) 

732 if owner_email: 

733 await self._redis.setex(owner_key, self._session_ttl, owner_email) 

734 else: 

735 await self._redis.delete(owner_key) 

736 except Exception as e: 

737 logger.error(f"Redis error setting owner for session {session_id}: {e}") 

738 

739 elif self._backend == "database": 

740 try: 

741 

742 def _db_set_owner() -> None: 

743 """Persist owner metadata for a session in the database backend. 

744 

745 Raises: 

746 Exception: Propagates database write failures to the caller. 

747 """ 

748 db_session = next(get_db()) 

749 try: 

750 record = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first() 

751 if not record: 

752 record = SessionRecord(session_id=session_id, data=None) 

753 db_session.add(record) 

754 db_session.flush() 

755 

756 record_data: Dict[str, Any] = {} 

757 if record.data: 

758 try: 

759 parsed = orjson.loads(record.data) 

760 if isinstance(parsed, dict): 

761 record_data = parsed 

762 except Exception: 

763 record_data = {} 

764 

765 if owner_email: 

766 record_data["owner_email"] = owner_email 

767 else: 

768 record_data.pop("owner_email", None) 

769 

770 record.data = orjson.dumps(record_data).decode() if record_data else None 

771 db_session.commit() 

772 except Exception as ex: 

773 db_session.rollback() 

774 raise ex 

775 finally: 

776 db_session.close() 

777 

778 await asyncio.to_thread(_db_set_owner) 

779 except Exception as e: 

780 logger.error(f"Database error setting owner for session {session_id}: {e}") 

781 

782 async def claim_session_owner(self, session_id: str, owner_email: str) -> Optional[str]: 

783 """Atomically claim ownership for a session and return the effective owner. 

784 

785 This method provides compare-and-set semantics: 

786 - If a session owner already exists, return the existing owner. 

787 - If no owner exists, claim ownership for ``owner_email``. 

788 

789 Args: 

790 session_id: Session identifier to claim. 

791 owner_email: Requesting owner email. 

792 

793 Returns: 

794 Effective owner email after the claim operation, or ``None`` if owner 

795 metadata could not be verified due to backend availability issues. 

796 """ 

797 if self._backend == "none": 

798 return owner_email 

799 

800 # Fast local cache path. 

801 cached_owner = self._session_owners.get(session_id) 

802 if cached_owner: 

803 return cached_owner 

804 

805 if self._backend == "memory": 

806 async with self._lock: 

807 existing_owner = self._session_owners.get(session_id) 

808 if existing_owner: 

809 return existing_owner 

810 self._session_owners[session_id] = owner_email 

811 return owner_email 

812 

813 if self._backend == "redis": 

814 if not self._redis: 

815 logger.warning("Redis client not initialized, session owner claim unavailable for %s", session_id) 

816 return None 

817 

818 owner_key = self._session_owner_key(session_id) 

819 try: 

820 claimed = await self._redis.set(owner_key, owner_email, ex=self._session_ttl, nx=True) 

821 if claimed: 

822 self._session_owners[session_id] = owner_email 

823 return owner_email 

824 

825 owner_raw = await self._redis.get(owner_key) 

826 if owner_raw is not None: 

827 existing_owner = owner_raw.decode() if isinstance(owner_raw, bytes) else str(owner_raw) 

828 if existing_owner: 

829 self._session_owners[session_id] = existing_owner 

830 return existing_owner 

831 

832 # Handle key expiry/race window by retrying once. 

833 claimed_retry = await self._redis.set(owner_key, owner_email, ex=self._session_ttl, nx=True) 

834 if claimed_retry: 

835 self._session_owners[session_id] = owner_email 

836 return owner_email 

837 

838 owner_raw = await self._redis.get(owner_key) 

839 if owner_raw is None: 

840 return None 

841 existing_owner = owner_raw.decode() if isinstance(owner_raw, bytes) else str(owner_raw) 

842 if existing_owner: 

843 self._session_owners[session_id] = existing_owner 

844 return existing_owner 

845 return None 

846 except Exception as e: 

847 logger.error("Redis error claiming owner for session %s: %s", session_id, e) 

848 return None 

849 

850 if self._backend == "database": 

851 try: 

852 

853 def _db_claim_owner() -> Optional[str]: 

854 """Claim owner in DB with optimistic compare-and-set retries. 

855 

856 Returns: 

857 Effective owner email or ``None`` when claim cannot be verified. 

858 """ 

859 db_session = next(get_db()) 

860 try: 

861 for _attempt in range(3): 

862 record = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first() 

863 if not record: 

864 owner_payload = {"owner_email": owner_email} 

865 new_record = SessionRecord(session_id=session_id, data=orjson.dumps(owner_payload).decode()) 

866 db_session.add(new_record) 

867 try: 

868 db_session.commit() 

869 return owner_email 

870 except Exception: 

871 db_session.rollback() 

872 # Another writer may have inserted concurrently. Retry. 

873 continue 

874 

875 record_data: Dict[str, Any] = {} 

876 if record.data: 

877 try: 

878 parsed = orjson.loads(record.data) 

879 if isinstance(parsed, dict): 

880 record_data = parsed 

881 except Exception: 

882 record_data = {} 

883 

884 existing_owner = record_data.get("owner_email") 

885 if isinstance(existing_owner, str) and existing_owner: 

886 return existing_owner 

887 

888 current_data = record.data 

889 updated_data = dict(record_data) 

890 updated_data["owner_email"] = owner_email 

891 serialized = orjson.dumps(updated_data).decode() 

892 

893 update_query = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id) 

894 if current_data is None: 

895 update_query = update_query.filter(SessionRecord.data.is_(None)) 

896 else: 

897 update_query = update_query.filter(SessionRecord.data == current_data) 

898 

899 updated_rows = update_query.update({"data": serialized}, synchronize_session=False) 

900 if updated_rows == 1: 

901 db_session.commit() 

902 return owner_email 

903 

904 db_session.rollback() 

905 

906 return None 

907 finally: 

908 db_session.close() 

909 

910 claimed_owner = await asyncio.to_thread(_db_claim_owner) 

911 if claimed_owner: 

912 self._session_owners[session_id] = claimed_owner 

913 return claimed_owner 

914 except Exception as e: 

915 logger.error("Database error claiming owner for session %s: %s", session_id, e) 

916 return None 

917 

918 return None 

919 

async def get_session_owner(self, session_id: str) -> Optional[str]:
    """Resolve the owner email recorded for a session.

    The local ``_session_owners`` cache is consulted first. On a miss, the
    Redis or database backend (selected by ``_backend``) is queried and any
    non-empty result is written back into the local cache.

    Args:
        session_id: Session identifier to resolve.

    Returns:
        Owner email when present, otherwise ``None``. ``None`` is also
        returned when the Redis client is unset or the backend lookup raises.
    """
    cached = self._session_owners.get(session_id)
    if cached:
        return cached

    backend = self._backend

    if backend == "redis":
        if not self._redis:
            return None
        try:
            raw_value = await self._redis.get(self._session_owner_key(session_id))
            if raw_value is None:
                return None
            # The client may hand back bytes or an already-decoded value.
            resolved = raw_value.decode() if isinstance(raw_value, bytes) else str(raw_value)
            if not resolved:
                return None
            self._session_owners[session_id] = resolved
            return resolved
        except Exception as e:
            logger.error(f"Redis error getting owner for session {session_id}: {e}")
            return None

    if backend == "database":

        def _db_get_owner() -> Optional[str]:
            """Read ``owner_email`` from the JSON stored in the session record.

            Returns:
                The stored owner email, or ``None`` when the record, its data,
                or a non-empty string ``owner_email`` entry is missing.
            """
            db_session = next(get_db())
            try:
                row = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
                if not row or not row.data:
                    return None
                try:
                    payload = orjson.loads(row.data)
                except Exception:
                    # Unparseable data is treated the same as "no owner".
                    return None
                if isinstance(payload, dict):
                    email = payload.get("owner_email")
                    if isinstance(email, str) and email:
                        return email
                return None
            finally:
                db_session.close()

        try:
            resolved = await asyncio.to_thread(_db_get_owner)
            if resolved:
                self._session_owners[session_id] = resolved
            return resolved
        except Exception as e:
            logger.error(f"Database error getting owner for session {session_id}: {e}")
            return None

    return None

980 

async def session_exists(self, session_id: str) -> Optional[bool]:
    """Report whether a marker for the session can be found.

    The local ``_sessions`` mapping is checked under the registry lock first;
    only on a local miss is the Redis or database backend consulted.

    Args:
        session_id: Session identifier to resolve.

    Returns:
        ``True`` when the session exists, ``False`` when it does not, and
        ``None`` when existence cannot be verified (Redis client unset or a
        backend error was raised).
    """
    backend = self._backend
    if backend == "none":
        return False

    async with self._lock:
        known_locally = session_id in self._sessions
    if known_locally:
        return True

    if backend == "redis":
        if not self._redis:
            return None
        try:
            marker_count = await self._redis.exists(f"mcp:session:{session_id}")
            return bool(marker_count)
        except Exception as e:
            logger.error("Redis error checking existence for session %s: %s", session_id, e)
            return None

    if backend == "database":

        def _db_exists() -> bool:
            """Check whether a session record exists in the database backend.

            Returns:
                ``True`` when a matching session record exists.
            """
            db_session = next(get_db())
            try:
                row = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
                return row is not None
            finally:
                db_session.close()

        try:
            return await asyncio.to_thread(_db_exists)
        except Exception as e:
            logger.error("Database error checking existence for session %s: %s", session_id, e)
            return None

    # Memory backend (local miss) and any unrecognised backend value.
    return False

1032 

async def get_session(self, session_id: str) -> Any:
    """Look up the transport registered for a session.

    Only the local cache can yield a transport. For the Redis and database
    backends a miss additionally probes the shared store, but solely to log
    that the session is known elsewhere; the result is still ``None``.

    Args:
        session_id: Session identifier to look up.

    Returns:
        The transport object when it is held in the local cache, ``None`` when
        the session is unknown, only present in the shared backend, or the
        backend is ``"none"``.

    Examples:
        >>> import asyncio
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>>
        >>> class MockTransport:
        ...     pass
        >>>
        >>> reg = SessionRegistry()
        >>> transport = MockTransport()
        >>> asyncio.run(reg.add_session('test-456', transport))
        >>>
        >>> # Found in local cache
        >>> found = asyncio.run(reg.get_session('test-456'))
        >>> found is transport
        True
        >>>
        >>> # Not found
        >>> asyncio.run(reg.get_session('nonexistent')) is None
        True
    """
    if self._backend == "none":
        return None

    # Local cache is authoritative for transport objects.
    async with self._lock:
        local_transport = self._sessions.get(session_id)
        if local_transport:
            logger.info(f"Session {session_id} exists in local cache")
            return local_transport

    if self._backend == "redis":
        if self._redis:
            try:
                marker_present = bool(await self._redis.exists(f"mcp:session:{session_id}"))
                if marker_present:
                    logger.info(f"Session {session_id} exists in Redis but not in local cache")
            except Exception as e:
                logger.error(f"Redis error checking session {session_id}: {e}")
        # The transport is not held by this worker in any of these cases.
        return None

    if self._backend == "database":

        def _db_check() -> bool:
            """Tell whether a ``SessionRecord`` row exists for the session.

            Intended to run via ``asyncio.to_thread`` so the query does not
            block the event loop.

            Returns:
                bool: True if the session exists in the database, False otherwise.
            """
            db_session = next(get_db())
            try:
                row = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
                return row is not None
            finally:
                db_session.close()

        try:
            if await asyncio.to_thread(_db_check):
                logger.info(f"Session {session_id} exists in database but not in local cache")
        except Exception as e:
            logger.error(f"Database error checking session {session_id}: {e}")
        return None

    return None

1129 

async def remove_session(self, session_id: str) -> None:
    """Drop a session from the local registry and the shared backend.

    The session is flagged as closing, its respond task is cancelled, and the
    local transport, owner entry and client capabilities are discarded. A
    locally held transport is then disconnected, and finally the Redis keys or
    database row for the session are deleted. On Redis a ``remove`` event is
    also published on ``mcp_session_events``.

    Args:
        session_id: Session identifier to remove.

    Examples:
        >>> import asyncio
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>>
        >>> class MockTransport:
        ...     async def disconnect(self):
        ...         print(f"Transport disconnected")
        ...     async def is_connected(self):
        ...         return True
        >>>
        >>> reg = SessionRegistry()
        >>> transport = MockTransport()
        >>> asyncio.run(reg.add_session('remove-test', transport))
        >>> asyncio.run(reg.remove_session('remove-test'))
        Transport disconnected
        >>>
        >>> # Session no longer exists
        >>> asyncio.run(reg.get_session('remove-test')) is None
        True
    """
    if self._backend == "none":
        return

    # Flag the session before anything else so the respond loop can bail out
    # without waiting for the cancellation below to finish.
    self._closing_sessions.add(session_id)

    local_transport = None
    try:
        # The respond task must be cancelled ahead of cleanup; otherwise it
        # can be orphaned and spin.
        await self._cancel_respond_task(session_id)

        async with self._lock:
            local_transport = self._sessions.pop(session_id, None)
            self._session_owners.pop(session_id, None)
            if session_id in self._client_capabilities:
                del self._client_capabilities[session_id]
                logger.debug(f"Removed capabilities for session {session_id}")
    finally:
        # The closing flag must not outlive this call, even on failure.
        self._closing_sessions.discard(session_id)

    if local_transport:
        try:
            await local_transport.disconnect()
        except Exception as e:
            logger.error(f"Error disconnecting transport for session {session_id}: {e}")

    if self._backend == "redis":
        if not self._redis:
            return
        try:
            await self._redis.delete(f"mcp:session:{session_id}")
            await self._redis.delete(self._session_owner_key(session_id))
            # Let the other workers know the session is gone.
            removal_event = {"type": "remove", "session_id": session_id, "timestamp": time.time()}
            await self._redis.publish("mcp_session_events", orjson.dumps(removal_event))
        except Exception as e:
            logger.error(f"Redis error removing session {session_id}: {e}")

    elif self._backend == "database":

        def _db_remove() -> None:
            """Delete the ``SessionRecord`` row for the session.

            Intended to run via ``asyncio.to_thread`` so the delete does not
            block the event loop.

            Raises:
                Exception: Any database error is re-raised after rollback.
            """
            db_session = next(get_db())
            try:
                db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).delete()
                db_session.commit()
            except Exception:
                db_session.rollback()
                raise
            finally:
                db_session.close()

        try:
            await asyncio.to_thread(_db_remove)
        except Exception as e:
            logger.error(f"Database error removing session {session_id}: {e}")

    logger.info(f"Removed session: {session_id}")

1243 

async def broadcast(self, session_id: str, message: Dict[str, Any]) -> None:
    """Hand a message to a session through the configured backend.

    What "broadcast" means depends on ``_backend``:

    - **memory**: the wrapped message is stored on ``_session_message``
    - **redis**: the wrapped message is published on the channel named after
      the session id
    - **database**: the wrapped message is inserted as a ``SessionMessageRecord``
    - **none**: nothing happens

    In every case the message is wrapped in an envelope with ``type``,
    ``message`` and ``timestamp`` keys. Redis and database errors are logged,
    not raised.

    Args:
        session_id: Target session identifier.
        message: Message to broadcast. Can be a dict, list, or any JSON-serializable object.

    Examples:
        >>> import asyncio
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>>
        >>> reg = SessionRegistry(backend='memory')
        >>> message = {'method': 'tools/list', 'id': 1}
        >>> asyncio.run(reg.broadcast('session-789', message))
        >>>
        >>> # Message stored for memory backend
        >>> reg._session_message is not None
        True
        >>> reg._session_message['session_id']
        'session-789'
        >>> orjson.loads(reg._session_message['message'])['message'] == message
        True
    """
    backend = self._backend
    if backend == "none":
        return

    def _build_payload(msg: Any) -> str:
        """Wrap a message in the broadcast envelope and encode it.

        Args:
            msg: Message to wrap in payload envelope.

        Returns:
            JSON-encoded string containing type, message, and timestamp.
        """
        envelope = {"type": "message", "message": msg, "timestamp": time.time()}
        return orjson.dumps(envelope).decode()

    if backend == "memory":
        encoded_text = _build_payload(message)
        self._session_message: Dict[str, Any] | None = {"session_id": session_id, "message": encoded_text}
        return

    if backend == "redis":
        if not self._redis:
            logger.warning(f"Redis client not initialized, cannot broadcast to {session_id}")
            return
        try:
            # The message is embedded as-is and the envelope is encoded once,
            # so subscribers decode a single JSON layer.
            encoded_bytes = orjson.dumps({"type": "message", "message": message, "timestamp": time.time()})
            await self._redis.publish(session_id, encoded_bytes)
        except Exception as e:
            logger.error(f"Redis error during broadcast: {e}")
        return

    if backend == "database":
        try:
            serialized = _build_payload(message)

            def _db_add() -> None:
                """Insert the encoded message as a ``SessionMessageRecord`` row.

                Intended to run via ``asyncio.to_thread`` so the write does
                not block the event loop.

                Raises:
                    Exception: Any database error is re-raised after rollback.
                """
                db_session = next(get_db())
                try:
                    db_session.add(SessionMessageRecord(session_id=session_id, message=serialized))
                    db_session.commit()
                except Exception:
                    db_session.rollback()
                    raise
                finally:
                    db_session.close()

            await asyncio.to_thread(_db_add)
        except Exception as e:
            logger.error(f"Database error during broadcast: {e}")

1353 

async def _register_session_mapping(self, session_id: str, message: Dict[str, Any], user_email: Optional[str] = None) -> None:
    """Register session mapping for session affinity when tools are called.

    Pre-registers, with the MCP session pool, the mapping between a downstream
    session ID and the gateway that serves the requested tool. Nothing is
    registered unless session affinity is enabled in settings and the message
    is a ``tools/call`` request naming a tool.

    Registration is best-effort: malformed messages are ignored and any
    failure while resolving the tool or talking to the pool is logged as a
    warning instead of being raised.

    Args:
        session_id: The downstream SSE session ID.
        message: The MCP protocol message being broadcast.
        user_email: Optional user email for session isolation.
    """
    # Skip if session affinity is disabled
    if not settings.mcpgateway_session_affinity_enabled:
        return

    # Only register for tools/call - other methods don't need session affinity
    if message.get("method") != "tools/call":
        return

    # ``params`` comes from the client and may be absent, null, or a
    # positional array. Only an object can carry a tool name, and anything
    # else must be ignored rather than raise outside the guarded section.
    params = message.get("params")
    if not isinstance(params, dict):
        return

    tool_name = params.get("name")
    if not tool_name:
        return

    try:
        # Look up tool in cache to get gateway info
        # First-Party
        from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache  # pylint: disable=import-outside-toplevel

        tool_info = await tool_lookup_cache.get(tool_name)
        if not tool_info:
            logger.debug(f"Tool {tool_name} not found in cache, skipping session mapping registration")
            return

        # A cached entry may carry an explicit null gateway; treat it like a
        # missing one so it is reported as incomplete info below.
        gateway = tool_info.get("gateway") or {}
        gateway_url = gateway.get("url")
        gateway_id = gateway.get("id")
        transport = gateway.get("transport")

        if not gateway_url or not gateway_id or not transport:
            logger.debug(f"Incomplete gateway info for tool {tool_name}, skipping session mapping registration")
            return

        # Register the session mapping with the pool
        # First-Party
        from mcpgateway.services.mcp_session_pool import get_mcp_session_pool  # pylint: disable=import-outside-toplevel

        pool = get_mcp_session_pool()
        await pool.register_session_mapping(
            session_id,
            gateway_url,
            gateway_id,
            transport,
            user_email,
        )

        logger.debug(f"Registered session mapping for session {session_id[:8]}... -> {gateway_url} (tool: {tool_name})")

    except Exception as e:
        # Don't fail the broadcast if session mapping registration fails
        logger.warning(f"Failed to register session mapping for {session_id[:8]}...: {e}")

1423 

async def get_all_session_ids(self) -> list[str]:
    """Snapshot the IDs of the sessions held in the local registry.

    The copy is taken while holding the registry lock, so later changes to
    the registry do not affect the returned list.

    Returns:
        list[str]: A snapshot list of currently known local session IDs.
    """
    async with self._lock:
        snapshot = list(self._sessions)
    return snapshot

1432 

def get_session_sync(self, session_id: str) -> Any:
    """Fetch a session's transport from the local cache without awaiting.

    Neither the registry lock nor the shared backend is touched; only the
    local ``_sessions`` mapping is read.

    Args:
        session_id: Session identifier to look up.

    Returns:
        The transport object if found in the local cache, ``None`` otherwise
        (always ``None`` for the ``"none"`` backend).

    Examples:
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>> import asyncio
        >>>
        >>> class MockTransport:
        ...     pass
        >>>
        >>> reg = SessionRegistry()
        >>> transport = MockTransport()
        >>> asyncio.run(reg.add_session('sync-test', transport))
        >>>
        >>> # Synchronous lookup
        >>> found = reg.get_session_sync('sync-test')
        >>> found is transport
        True
        >>>
        >>> # Not found
        >>> reg.get_session_sync('nonexistent') is None
        True
    """
    return None if self._backend == "none" else self._sessions.get(session_id)

1471 

1472 async def respond( 

1473 self, 

1474 server_id: Optional[str], 

1475 user: Dict[str, Any], 

1476 session_id: str, 

1477 ) -> None: 

1478 """Process and respond to broadcast messages for a session. 

1479 

1480 This method listens for messages directed to the specified session and 

1481 generates appropriate responses. The listening mechanism depends on the backend: 

1482 

1483 - **memory**: Checks the temporary message storage 

1484 - **redis**: Subscribes to Redis pubsub channel 

1485 - **database**: Polls database for new messages 

1486 

1487 When a message is received and the transport exists locally, it processes 

1488 the message and sends the response through the transport. 

1489 

1490 Args: 

1491 server_id: Optional server identifier for scoped operations. 

1492 user: User information including authentication token. 

1493 session_id: Session identifier to respond for. 

1494 

1495 Raises: 

1496 asyncio.CancelledError: When the respond task is cancelled (e.g., on session removal). 

1497 

1498 Examples: 

1499 >>> import asyncio 

1500 >>> from mcpgateway.cache.session_registry import SessionRegistry 

1501 >>> 

1502 >>> # This method is typically called internally by the SSE handler 

1503 >>> reg = SessionRegistry() 

1504 >>> user = {'token': 'test-token'} 

1505 >>> # asyncio.run(reg.respond(None, user, 'session-id')) 

1506 """ 

1507 

1508 if self._backend == "none": 

1509 pass 

1510 

1511 elif self._backend == "memory": 

1512 transport = self.get_session_sync(session_id) 

1513 if transport and self._session_message: 

1514 message_json = self._session_message.get("message") 

1515 if message_json: 

1516 data = orjson.loads(message_json) 

1517 if isinstance(data, dict) and "message" in data: 

1518 message = data["message"] 

1519 else: 

1520 message = data 

1521 await self.generate_response(message=message, transport=transport, server_id=server_id, user=user) 

1522 else: 

1523 logger.warning(f"Session message stored but message content is None for session {session_id}") 

1524 

1525 elif self._backend == "redis": 

1526 if not self._redis: 

1527 logger.warning(f"Redis client not initialized, cannot respond to {session_id}") 

1528 return 

1529 pubsub = self._redis.pubsub() 

1530 await pubsub.subscribe(session_id) 

1531 

1532 # Use timeout-based polling instead of infinite listen() to allow exit checks 

1533 # This is critical for allowing cancellation to work (Finding 2) 

1534 poll_timeout = 1.0 # Check every second if session still exists 

1535 

1536 try: 

1537 while True: 

1538 # Check if session still exists or is closing - exit early 

1539 if session_id not in self._sessions or session_id in self._closing_sessions: 

1540 logger.info(f"Session {session_id} removed or closing, exiting Redis respond loop") 

1541 break 

1542 

1543 # Use get_message with timeout instead of blocking listen() 

1544 try: 

1545 msg = await asyncio.wait_for( 

1546 pubsub.get_message(ignore_subscribe_messages=True, timeout=poll_timeout), timeout=poll_timeout + 0.5 # Slightly longer to account for Redis timeout 

1547 ) 

1548 except asyncio.TimeoutError: 

1549 # No message, loop back to check session existence 

1550 continue 

1551 

1552 if msg is None: 

1553 # CRITICAL: Sleep to prevent tight loop when get_message returns immediately 

1554 # This can happen in certain Redis states after disconnects 

1555 await asyncio.sleep(0.1) 

1556 continue 

1557 if msg["type"] != "message": 

1558 # Sleep on non-message types to prevent spin in edge cases 

1559 await asyncio.sleep(0.1) 

1560 continue 

1561 

1562 data = orjson.loads(msg["data"]) 

1563 message = data.get("message", {}) 

1564 transport = self.get_session_sync(session_id) 

1565 if transport: 

1566 await self.generate_response(message=message, transport=transport, server_id=server_id, user=user) 

1567 except asyncio.CancelledError: 

1568 logger.info(f"PubSub listener for session {session_id} cancelled") 

1569 raise # Re-raise to properly complete cancellation 

1570 except Exception as e: 

1571 logger.error(f"PubSub listener error for session {session_id}: {e}") 

1572 finally: 

1573 # Pubsub cleanup first - use timeouts to prevent blocking 

1574 cleanup_timeout = settings.mcp_session_pool_cleanup_timeout 

1575 try: 

1576 await asyncio.wait_for(pubsub.unsubscribe(session_id), timeout=cleanup_timeout) 

1577 except asyncio.TimeoutError: 

1578 logger.debug(f"Pubsub unsubscribe timed out for session {session_id}") 

1579 except Exception as e: 

1580 logger.debug(f"Error unsubscribing pubsub for session {session_id}: {e}") 

1581 try: 

1582 try: 

1583 await asyncio.wait_for(pubsub.aclose(), timeout=cleanup_timeout) 

1584 except AttributeError: 

1585 await asyncio.wait_for(pubsub.close(), timeout=cleanup_timeout) 

1586 except asyncio.TimeoutError: 

1587 logger.debug(f"Pubsub close timed out for session {session_id}") 

1588 except Exception as e: 

1589 logger.debug(f"Error closing pubsub for session {session_id}: {e}") 

1590 logger.info(f"Cleaned up pubsub for session {session_id}") 

1591 # Clean up task reference LAST (idempotent - may already be removed by _cancel_respond_task) 

1592 self._respond_tasks.pop(session_id, None) 

1593 

1594 elif self._backend == "database": 

1595 

1596 def _db_read_session_and_message( 

1597 session_id: str, 

1598 ) -> tuple[SessionRecord | None, SessionMessageRecord | None]: 

1599 """ 

1600 Check whether a session exists and retrieve its next pending message 

1601 in a single database query. 

1602 

1603 This function performs a LEFT OUTER JOIN between SessionRecord and 

1604 SessionMessageRecord to determine: 

1605 

1606 - Whether the session still exists 

1607 - Whether there is a pending message for the session (FIFO order) 

1608 

1609 It is used by the database-backed message polling loop to reduce 

1610 database load by collapsing multiple reads into a single query. 

1611 

1612 Messages are returned in FIFO order based on the message primary key. 

1613 

1614 This function is designed to be run in a thread executor to avoid 

1615 blocking the async event loop during database access. 

1616 

1617 Args: 

1618 session_id: The session identifier to look up. 

1619 

1620 Returns: 

1621 Tuple[SessionRecord | None, SessionMessageRecord | None]: 

1622 

1623 - (None, None) 

1624 The session does not exist. 

1625 

1626 - (SessionRecord, None) 

1627 The session exists but has no pending messages. 

1628 

1629 - (SessionRecord, SessionMessageRecord) 

1630 The session exists and has a pending message. 

1631 

1632 Raises: 

1633 Exception: Any database error is re-raised after rollback. 

1634 

1635 Examples: 

1636 >>> # This function is called internally by message_check_loop() 

1637 >>> # Session exists and has a pending message 

1638 >>> # Returns (SessionRecord, SessionMessageRecord) 

1639 

1640 >>> # Session exists but has no pending messages 

1641 >>> # Returns (SessionRecord, None) 

1642 

1643 >>> # Session has been removed 

1644 >>> # Returns (None, None) 

1645 """ 

1646 db_session = next(get_db()) 

1647 try: 

1648 result = ( 

1649 db_session.query(SessionRecord, SessionMessageRecord) 

1650 .outerjoin( 

1651 SessionMessageRecord, 

1652 SessionMessageRecord.session_id == SessionRecord.session_id, 

1653 ) 

1654 .filter(SessionRecord.session_id == session_id) 

1655 .order_by(SessionMessageRecord.id.asc()) 

1656 .first() 

1657 ) 

1658 if not result: 

1659 return None, None 

1660 session, message = result 

1661 return session, message 

1662 except Exception as ex: 

1663 db_session.rollback() 

1664 raise ex 

1665 finally: 

1666 db_session.close() 

1667 

1668 def _db_remove(session_id: str, message: str) -> None: 

1669 """Remove processed message from the database. 

1670 

1671 Deletes a specific message record after it has been successfully 

1672 processed and sent to the transport. This prevents duplicate 

1673 message delivery. 

1674 

1675 This inner function is designed to be run in a thread executor 

1676 to avoid blocking the async event loop during database deletes. 

1677 

1678 Args: 

1679 session_id: The session identifier the message belongs to. 

1680 message: The exact message content to remove (must match exactly). 

1681 

1682 Raises: 

1683 Exception: Any database error is re-raised after rollback. 

1684 

1685 Examples: 

1686 >>> # This function is called internally after message processing 

1687 >>> # Deletes the specific SessionMessageRecord entry 

1688 >>> # Log: "Removed message from mcp_messages table" 

1689 """ 

1690 db_session = next(get_db()) 

1691 try: 

1692 db_session.query(SessionMessageRecord).filter(SessionMessageRecord.session_id == session_id).filter(SessionMessageRecord.message == message).delete() 

1693 db_session.commit() 

1694 logger.info("Removed message from mcp_messages table") 

1695 except Exception as ex: 

1696 db_session.rollback() 

1697 raise ex 

1698 finally: 

1699 db_session.close() 

1700 

1701 async def message_check_loop(session_id: str) -> None: 

1702 """ 

1703 Background task that polls the database for messages belonging to a session 

1704 using adaptive polling with exponential backoff. 

1705 

1706 The loop continues until the session is removed from the database. 

1707 

1708 Behavior: 

1709 - Starts with a fast polling interval for low-latency message delivery. 

1710 - When no message is found, the polling interval increases exponentially 

1711 (up to a configured maximum) to reduce database load. 

1712 - When a message is received, the polling interval is immediately reset 

1713 to the fast interval. 

1714 - The loop exits as soon as the session no longer exists. 

1715 

1716 Polling rules: 

1717 - Message found → process message, reset polling interval. 

1718 - No message → increase polling interval (backoff). 

1719 - Session gone → stop polling immediately. 

1720 

1721 Args: 

1722 session_id (str): Unique identifier of the session to monitor. 

1723 

1724 Raises: 

1725 asyncio.CancelledError: When the polling loop is cancelled. 

1726 

1727 Examples 

1728 -------- 

1729 Adaptive backoff when no messages are present: 

1730 

1731 >>> poll_interval = 0.1 

1732 >>> backoff_factor = 1.5 

1733 >>> max_interval = 5.0 

1734 >>> poll_interval = min(poll_interval * backoff_factor, max_interval) 

1735 >>> poll_interval 

1736 0.15000000000000002 

1737 

1738 Backoff continues until the maximum interval is reached: 

1739 

1740 >>> poll_interval = 4.0 

1741 >>> poll_interval = min(poll_interval * 1.5, 5.0) 

1742 >>> poll_interval 

1743 5.0 

1744 

1745 Polling interval resets immediately when a message arrives: 

1746 

1747 >>> poll_interval = 2.0 

1748 >>> poll_interval = 0.1 

1749 >>> poll_interval 

1750 0.1 

1751 

1752 Session termination stops polling: 

1753 

1754 >>> session_exists = False 

1755 >>> if not session_exists: 

1756 ... "polling stopped" 

1757 'polling stopped' 

1758 """ 

1759 

1760 poll_interval = settings.poll_interval # start fast 

1761 max_interval = settings.max_interval # cap at configured maximum 

1762 backoff_factor = settings.backoff_factor 

1763 try: 

1764 while True: 

1765 # Check if session is closing before querying DB 

1766 if session_id in self._closing_sessions: 

1767 logger.debug("Session %s closing, stopping poll loop early", session_id) 

1768 break 

1769 

1770 session, record = await asyncio.to_thread(_db_read_session_and_message, session_id) 

1771 

1772 # session gone → stop polling 

1773 if not session: 

1774 logger.debug("Session %s no longer exists, stopping poll loop", session_id) 

1775 break 

1776 

1777 if record: 

1778 poll_interval = settings.poll_interval # reset on activity 

1779 

1780 data = orjson.loads(record.message) 

1781 if isinstance(data, dict) and "message" in data: 

1782 message = data["message"] 

1783 else: 

1784 message = data 

1785 

1786 transport = self.get_session_sync(session_id) 

1787 if transport: 

1788 logger.info("Ready to respond") 

1789 await self.generate_response( 

1790 message=message, 

1791 transport=transport, 

1792 server_id=server_id, 

1793 user=user, 

1794 ) 

1795 

1796 await asyncio.to_thread(_db_remove, session_id, record.message) 

1797 else: 

1798 # no message → backoff 

1799 # update polling interval with backoff factor 

1800 poll_interval = min(poll_interval * backoff_factor, max_interval) 

1801 

1802 await asyncio.sleep(poll_interval) 

1803 except asyncio.CancelledError: 

1804 logger.info(f"Message check loop cancelled for session {session_id}") 

1805 raise # Re-raise to properly complete cancellation 

1806 except Exception as e: 

1807 logger.error(f"Message check loop error for session {session_id}: {e}") 

1808 

1809 # CRITICAL: Await instead of fire-and-forget 

1810 # This ensures CancelledError propagates from outer respond() task to inner loop 

1811 # The outer task (registered from main.py) now runs until message_check_loop exits 

1812 try: 

1813 await message_check_loop(session_id) 

1814 except asyncio.CancelledError: 

1815 logger.info(f"Database respond cancelled for session {session_id}") 

1816 raise 

1817 finally: 

1818 # Clean up task reference on ANY exit (normal, cancelled, or error) 

1819 # Prevents stale done tasks from accumulating in _respond_tasks 

1820 self._respond_tasks.pop(session_id, None) 

1821 

async def _refresh_redis_sessions(self) -> None:
    """Extend Redis TTLs for live local sessions and drop disconnected ones.

    Does nothing when no Redis client is configured. Otherwise a snapshot
    of the local session map is taken under the registry lock, and each
    transport in the snapshot is probed: connected transports get their
    ``mcp:session:<id>`` key re-expired with the configured session TTL,
    disconnected ones are passed to ``remove_session``.

    Failures are logged and never propagated: a failure for one session
    does not stop the remaining sessions from being processed.
    """
    if not self._redis:
        return

    try:
        # Snapshot under the lock so iteration is not affected by concurrent changes.
        async with self._lock:
            snapshot = self._sessions.copy()

        for sid, conn in snapshot.items():
            try:
                if not await conn.is_connected():
                    # Transport is gone: remove the session.
                    await self.remove_session(sid)
                    continue
                # Still connected: push the expiry forward in Redis.
                await self._redis.expire(f"mcp:session:{sid}", self._session_ttl)
            except Exception as session_err:
                logger.error(f"Error refreshing session {sid}: {session_err}")

    except Exception as refresh_err:
        logger.error(f"Error in Redis session refresh: {refresh_err}")

1850 

async def _db_cleanup_task(self) -> None:
    """Periodically purge expired session rows and reconcile local sessions.

    Each cycle deletes ``SessionRecord`` rows whose ``last_accessed`` is
    older than the session TTL (in a worker thread), then calls
    ``_cleanup_database_sessions`` to check local sessions against the
    database, and finally sleeps for 300 seconds. If a cycle fails, the
    error is logged and the task sleeps for 600 seconds before retrying.

    Raises:
        asyncio.CancelledError: Re-raised when the task is cancelled.
    """

    def _db_cleanup() -> int:
        """Delete session rows not accessed within the session TTL.

        Runs synchronously; intended to be called through
        ``asyncio.to_thread`` so the event loop is not blocked.

        Returns:
            int: Number of session records deleted.

        Raises:
            Exception: Any database error is re-raised after rollback.
        """
        db = next(get_db())
        try:
            # The cutoff is computed in Python rather than with database-side date arithmetic.
            cutoff = datetime.now(timezone.utc) - timedelta(seconds=self._session_ttl)
            removed = db.query(SessionRecord).filter(SessionRecord.last_accessed < cutoff).delete()
            db.commit()
            return removed
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

    logger.info("Starting database cleanup task")
    while True:
        try:
            purged = await asyncio.to_thread(_db_cleanup)
            if purged > 0:
                logger.info(f"Cleaned up {purged} expired database sessions")

            # Reconcile local sessions with what the database still holds.
            await self._cleanup_database_sessions()

            await asyncio.sleep(300)  # Run every 5 minutes

        except asyncio.CancelledError:
            logger.info("Database cleanup task cancelled")
            raise
        except Exception as cycle_err:
            logger.error(f"Error in database cleanup task: {cycle_err}")
            await asyncio.sleep(600)  # Sleep longer on error

1919 

def _refresh_session_db(self, session_id: str) -> bool:
    """Stamp a session row with the current database time.

    Looks up the ``SessionRecord`` for ``session_id`` and, when present,
    sets its ``last_accessed`` column to ``func.now()`` and commits.

    Args:
        session_id: The session identifier to refresh.

    Returns:
        bool: True when a row was found and updated, False when no row exists.

    Raises:
        Exception: Any database error is re-raised after rollback.
    """
    db = next(get_db())
    try:
        record = db.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
        if not record:
            return False
        record.last_accessed = func.now()  # pylint: disable=not-callable
        db.commit()
        return True
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()

1949 

async def _cleanup_database_sessions(self, max_concurrent: int = 20) -> None:
    """Reconcile local sessions with the database using bounded parallelism.

    Works in two passes over a snapshot of the local session map:

    1. Probe every transport; disconnected sessions are removed, connected
       ones are collected for refresh. A probe that raises is only logged.
    2. Refresh the collected sessions concurrently via
       ``_refresh_session_db`` in worker threads, with at most
       ``max_concurrent`` refreshes in flight. A refresh that returns a
       falsy value causes the session to be removed; a refresh that raised
       is only logged.

    Args:
        max_concurrent: Upper bound on simultaneous DB refresh operations.
    """
    async with self._lock:
        snapshot = self._sessions.copy()

    # Pass 1: connection probes.
    alive: list[str] = []
    for sid, conn in snapshot.items():
        try:
            if await conn.is_connected():
                alive.append(sid)
            else:
                await self.remove_session(sid)
        except Exception as probe_err:
            # Only log error, don't remove session on transient errors
            logger.error(f"Error checking connection for session {sid}: {probe_err}")

    if not alive:
        return

    # Pass 2: refresh in parallel, gated by a semaphore.
    gate = asyncio.Semaphore(max_concurrent)

    async def bounded_refresh(session_id: str) -> bool:
        """Run one DB refresh in a worker thread while holding a semaphore slot.

        Args:
            session_id: The session ID to refresh.

        Returns:
            The value returned by ``_refresh_session_db``.
        """
        async with gate:
            return await asyncio.to_thread(self._refresh_session_db, session_id)

    outcomes = await asyncio.gather(*(bounded_refresh(sid) for sid in alive), return_exceptions=True)

    for sid, outcome in zip(alive, outcomes):
        try:
            if isinstance(outcome, Exception):
                # Only log error, don't remove session on transient DB errors
                logger.error(f"Error refreshing session {sid}: {outcome}")
            elif not outcome:
                # Session no longer in database, remove locally
                await self.remove_session(sid)
        except Exception as handle_err:
            logger.error(f"Error processing refresh result for session {sid}: {handle_err}")

2005 

async def _memory_cleanup_task(self) -> None:
    """Periodically evict disconnected sessions for the memory backend.

    Each cycle snapshots the local session map under the registry lock,
    probes every transport, and removes the sessions that report as
    disconnected. A session whose probe (or removal) raises is logged and
    then removed. The task sleeps 60 seconds between cycles, or 300
    seconds after a cycle-level failure.

    Raises:
        asyncio.CancelledError: Re-raised when the task is cancelled.
    """
    logger.info("Starting memory cleanup task")
    while True:
        try:
            async with self._lock:
                snapshot = self._sessions.copy()

            for sid, conn in snapshot.items():
                try:
                    if await conn.is_connected():
                        continue
                    await self.remove_session(sid)
                except Exception as probe_err:
                    logger.error(f"Error checking session {sid}: {probe_err}")
                    await self.remove_session(sid)

            await asyncio.sleep(60)  # Run every minute

        except asyncio.CancelledError:
            logger.info("Memory cleanup task cancelled")
            raise
        except Exception as cycle_err:
            logger.error(f"Error in memory cleanup task: {cycle_err}")
            await asyncio.sleep(300)  # Sleep longer on error

2040 

def _get_oauth_experimental_config(self, server_id: str) -> Optional[Dict[str, Dict[str, Any]]]:
    """Build the advertisable OAuth capability for a server (blocking DB call).

    Loads the server row and, when OAuth is enabled and configured, copies
    only an allow-listed set of fields out of its ``oauth_config``:
    authorization servers, supported scopes and bearer methods. Nothing
    else from the stored configuration is returned.

    Args:
        server_id: The server ID to query OAuth configuration for.

    Returns:
        ``{"oauth": {...}}`` when at least one authorization server is
        configured, otherwise None.
    """
    # First-Party
    from mcpgateway.db import Server as DbServer  # pylint: disable=import-outside-toplevel
    from mcpgateway.db import SessionLocal  # pylint: disable=import-outside-toplevel

    db = SessionLocal()
    try:
        server = db.get(DbServer, server_id)
        if not server or not getattr(server, "oauth_enabled", False) or not getattr(server, "oauth_config", None):
            return None

        # Copy allow-listed fields only (never expose secrets).
        stored = server.oauth_config
        advertised: Dict[str, Any] = {}

        # Prefer the plural key; fall back to wrapping the singular one in a list.
        if stored.get("authorization_servers"):
            advertised["authorization_servers"] = stored["authorization_servers"]
        elif stored.get("authorization_server"):
            advertised["authorization_servers"] = [stored["authorization_server"]]

        supported_scopes = stored.get("scopes_supported") or stored.get("scopes")
        if supported_scopes:
            advertised["scopes_supported"] = supported_scopes

        advertised["bearer_methods_supported"] = stored.get("bearer_methods_supported", ["header"])

        if not advertised.get("authorization_servers"):
            return None

        logger.debug(f"Advertising OAuth capability for server {server_id}")
        return {"oauth": advertised}
    finally:
        db.close()

2085 

2086 # Handle initialize logic 

async def handle_initialize_logic(self, body: Dict[str, Any], session_id: Optional[str] = None, server_id: Optional[str] = None) -> InitializeResult:
    """Handle the MCP ``initialize`` handshake.

    Reads the requested protocol version from the body, records the
    client's capabilities against the session when both are present, and
    returns the gateway's capabilities and server information. When a
    server ID is given, its OAuth configuration is looked up in a worker
    thread and advertised under ``experimental``; a failed lookup is
    logged and the capability is simply omitted.

    Args:
        body: Request body. Keys read: 'protocol_version' or
            'protocolVersion', and 'capabilities'.
        session_id: Optional session ID to associate client capabilities with.
        server_id: Optional server ID whose OAuth configuration should be advertised.

    Returns:
        InitializeResult containing protocol version, server capabilities, and server info.

    Raises:
        HTTPException: If the protocol version is missing (400 Bad Request
            with header ``MCP-Error-Code: -32002``).

    Examples:
        >>> import asyncio
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>>
        >>> reg = SessionRegistry()
        >>> body = {'protocol_version': '2025-06-18'}
        >>> result = asyncio.run(reg.handle_initialize_logic(body))
        >>> result.protocol_version
        '2025-06-18'
        >>> result.server_info.name
        'ContextForge'
        >>>
        >>> # Missing protocol version
        >>> try:
        ...     asyncio.run(reg.handle_initialize_logic({}))
        ... except HTTPException as e:
        ...     e.status_code
        400
    """
    # First-Party
    from mcpgateway.observability import create_span  # pylint: disable=import-outside-toplevel

    requested_version = body.get("protocol_version") or body.get("protocolVersion")
    declared_caps = body.get("capabilities", {})

    with create_span(
        "mcp.initialize",
        {
            "mcp.protocol_version": requested_version,
            "mcp.session_id": session_id,
            "server.id": server_id,
        },
    ):
        if not requested_version:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Missing protocol version",
                headers={"MCP-Error-Code": "-32002"},
            )

        # A differing version is accepted; it is only reported.
        if requested_version != settings.protocol_version:
            logger.warning(f"Using non default protocol version: {requested_version}")

        # Capabilities can only be stored when there is a session to attach them to.
        if session_id and declared_caps:
            await self.store_client_capabilities(session_id, declared_caps)
            logger.debug(f"Stored capabilities for session {session_id}: {declared_caps}")

        # OAuth is advertised as an experimental capability when configured.
        oauth_experimental: Optional[Dict[str, Dict[str, Any]]] = None
        if server_id:
            try:
                # The lookup is synchronous, so it runs in a worker thread.
                oauth_experimental = await asyncio.to_thread(self._get_oauth_experimental_config, server_id)
            except Exception as lookup_err:
                logger.warning(f"Failed to query OAuth config for server {server_id}: {lookup_err}")

        advertised = ServerCapabilities(
            prompts={"listChanged": True},
            resources={"subscribe": True, "listChanged": True},
            tools={"listChanged": True},
            logging={},
            completions={},  # Advertise completions capability per MCP spec
            experimental=oauth_experimental,  # OAuth capability when configured
        )
        return InitializeResult(
            protocolVersion=requested_version,
            capabilities=advertised,
            serverInfo=Implementation(name=settings.app_name, version=__version__),
            instructions=("ContextForge providing federated tools, resources and prompts. Use /admin interface for configuration."),
        )

2177 

async def store_client_capabilities(self, session_id: str, capabilities: Dict[str, Any]) -> None:
    """Remember the capabilities a client declared for its session.

    The entry is written while holding the registry lock and replaces any
    capabilities previously stored for the same session.

    Args:
        session_id: The session ID
        capabilities: Client capabilities dictionary from initialize request
    """
    async with self._lock:
        self._client_capabilities[session_id] = capabilities
        logger.debug(f"Stored capabilities for session {session_id}")

2188 

async def get_client_capabilities(self, session_id: str) -> Optional[Dict[str, Any]]:
    """Look up the capabilities stored for a session.

    The lookup is performed while holding the registry lock.

    Args:
        session_id: The session ID

    Returns:
        The stored capabilities dictionary, or None when nothing is stored
        for the session.
    """
    async with self._lock:
        stored = self._client_capabilities.get(session_id)
    return stored

2200 

async def has_elicitation_capability(self, session_id: str) -> bool:
    """Report whether a session's client declared elicitation support.

    Args:
        session_id: The session ID

    Returns:
        True when capabilities are stored for the session and their
        ``elicitation`` entry is truthy, False otherwise.
    """
    declared = await self.get_client_capabilities(session_id)
    # Missing or empty capabilities short-circuit to False before the key lookup.
    return bool(declared and declared.get("elicitation"))

2215 

async def get_elicitation_capable_sessions(self) -> list[str]:
    """List the local sessions whose clients declared elicitation support.

    A session is included only when its stored capabilities have a truthy
    ``elicitation`` entry and the session is still present in the local
    session map. The scan runs while holding the registry lock.

    Returns:
        List of session IDs with elicitation capability
    """
    async with self._lock:
        return [sid for sid, declared in self._client_capabilities.items() if declared.get("elicitation") and sid in self._sessions]

2230 

async def generate_response(self, message: Dict[str, Any], transport: SSETransport, server_id: Optional[str], user: Dict[str, Any]) -> None:
    """Generate and send the response for an incoming MCP protocol message.

    Messages that carry both ``method`` and ``id`` are forwarded as a
    JSON-RPC request to the gateway's own ``/rpc`` endpoint (at the URL
    returned by ``internal_loopback_base_url()``), and the outcome is sent
    back through ``transport``. Messages missing either field are ignored.

    The request is authenticated with ``user["auth_token"]`` when present;
    otherwise a fallback session JWT is created from the user's e-mail,
    full name and admin flag. When session affinity is enabled the
    transport's session ID is registered and sent as ``x-mcp-session-id``.
    Passthrough headers stored under ``user["_passthrough_headers"]`` are
    added after filtering through ``filter_loopback_skip_headers``.

    If the ``/rpc`` response body carries a JSON-RPC ``error`` member, that
    error is relayed to the client instead of an empty ``result``. Any
    exception raised while preparing or performing the call is turned
    into a JSON-RPC error response (code -32000 for non-JSON-RPC errors).

    After responding to ``initialize``, ``notifications/initialized`` and
    the three ``list_changed`` notifications are sent on the transport.

    Args:
        message: Incoming MCP message as JSON. Must contain 'method' and 'id' fields.
        transport: SSE transport to send responses through.
        server_id: Optional server ID for scoped operations; injected into the request params.
        user: User information containing authentication token.

    Examples:
        >>> import asyncio
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>>
        >>> class MockTransport:
        ...     async def send_message(self, msg):
        ...         print(f"Response: {msg['method'] if 'method' in msg else msg.get('result', {})}")
        >>>
        >>> reg = SessionRegistry()
        >>> transport = MockTransport()
        >>> message = {"method": "ping", "id": 1}
        >>> user = {"token": "test-token"}
        >>> # asyncio.run(reg.generate_response(message, transport, None, user))
        >>> # Response: {}
    """
    # Only requests (method + id) are answered; anything else is ignored.
    if "method" not in message or "id" not in message:
        return

    method = message["method"]
    req_id = message["id"]

    # An explicit ``"params": null`` would otherwise fail on the item assignment below.
    params = message.get("params")
    if params is None:
        params = {}
    params["server_id"] = server_id

    rpc_input = {
        "jsonrpc": "2.0",
        "method": method,
        "params": params,
        "id": req_id,
    }
    # Get the token from the current authentication context
    # The user object should contain auth_token, token_teams, and is_admin from the SSE endpoint
    token = None
    is_admin = user.get("is_admin", False)  # Preserve admin status from SSE endpoint

    try:
        if hasattr(user, "get") and user.get("auth_token"):
            token = user["auth_token"]
        else:
            # Fallback: create lightweight session token (teams resolved server-side by downstream /rpc)
            logger.warning("No auth token available for SSE RPC call - creating fallback session token")
            now = datetime.now(timezone.utc)
            payload = {
                "sub": user.get("email", "system"),
                "iss": settings.jwt_issuer,
                "aud": settings.jwt_audience,
                "iat": int(now.timestamp()),
                "jti": str(uuid.uuid4()),
                "token_use": "session",  # nosec B105 - token type marker, not a password
                "user": {
                    "email": user.get("email", "system"),
                    "full_name": user.get("full_name", "System"),
                    "is_admin": is_admin,  # Preserve admin status for cookie-authenticated admins
                    "auth_provider": "internal",
                },
            }
            # Generate token using centralized token creation
            token = await create_jwt_token(payload)

        # Pass downstream session id to /rpc for session affinity.
        # This is gateway-internal only; the pool strips it before contacting upstream MCP servers.
        if settings.mcpgateway_session_affinity_enabled:
            await self._register_session_mapping(transport.session_id, message, user.get("email") if hasattr(user, "get") else None)

        headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
        if settings.mcpgateway_session_affinity_enabled:
            headers["x-mcp-session-id"] = transport.session_id
        # Forward passthrough headers captured at SSE connection time (see #3640).
        # This ensures X-Upstream-Authorization and other client passthrough headers
        # reach the /rpc endpoint, which then forwards them to upstream MCP servers.
        # Defense-in-depth: filter via filter_loopback_skip_headers() so passthrough
        # can never override the gateway's internal JWT, content-type, or session/routing headers.
        # First-Party
        from mcpgateway.utils.passthrough_headers import filter_loopback_skip_headers  # pylint: disable=import-outside-toplevel

        passthrough = user.get("_passthrough_headers") or {}
        if passthrough and isinstance(passthrough, dict):
            headers.update(filter_loopback_skip_headers(passthrough))
        # Use loopback for internal RPC call (consistent with other self-call sites
        # in mcp_session_pool.py and streamablehttp_transport.py). This avoids
        # failures when the client-facing URL is not reachable from the server
        # (e.g., behind a reverse proxy or service mesh). See #3049.
        rpc_url = f"{internal_loopback_base_url()}/rpc"

        logger.info(f"SSE RPC: Making call to {rpc_url} with method={method}, params={params}")

        async with ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": internal_loopback_verify()}) as client:
            logger.info(f"SSE RPC: Sending request to {rpc_url}")
            rpc_response = await client.post(
                url=rpc_url,
                json=rpc_input,
                headers=headers,
            )
            logger.info(f"SSE RPC: Got response status {rpc_response.status_code}")
            rpc_body = rpc_response.json()
            logger.info(f"SSE RPC: Response content: {rpc_body}")

        # A JSON-RPC response carries either "result" or "error". Relay an
        # error as an error instead of masking it with an empty result.
        if rpc_body.get("error") is not None:
            response = {"jsonrpc": "2.0", "error": rpc_body["error"], "id": req_id}
        else:
            response = {"jsonrpc": "2.0", "result": rpc_body.get("result", {}), "id": req_id}
    except JSONRPCError as e:
        logger.error(f"SSE RPC: JSON-RPC error: {e}")
        error_body = e.to_dict()
        response = {"jsonrpc": "2.0", "error": error_body["error"], "id": req_id}
    except Exception as e:
        logger.error(f"SSE RPC: Exception during RPC call: {type(e).__name__}: {e}")
        logger.error(f"SSE RPC: Traceback: {traceback.format_exc()}")
        response = {"jsonrpc": "2.0", "error": {"code": -32000, "message": "Internal error", "data": str(e)}, "id": req_id}

    # Use the module logger (not the root ``logging`` module) like the rest of this block.
    logger.debug(f"Sending sse message:{response}")
    await transport.send_message(response)

    if method == "initialize":
        await transport.send_message(
            {
                "jsonrpc": "2.0",
                "method": "notifications/initialized",
                "params": {},
            }
        )
        notifications = [
            "tools/list_changed",
            "resources/list_changed",
            "prompts/list_changed",
        ]
        for notification in notifications:
            await transport.send_message(
                {
                    "jsonrpc": "2.0",
                    "method": f"notifications/{notification}",
                    "params": {},
                }
            )