Coverage for mcpgateway / cache / session_registry.py: 100%

966 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/cache/session_registry.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Session Registry with optional distributed state. 

8This module provides a registry for SSE sessions with support for distributed deployment 

9using Redis or SQLAlchemy as optional backends for shared state between workers. 

10 

11The SessionRegistry class manages server-sent event (SSE) sessions across multiple 

12worker processes, enabling horizontal scaling of MCP gateway deployments. It supports 

13three backend modes: 

14 

15- **memory**: In-memory storage for single-process deployments (default) 

16- **redis**: Redis-backed shared storage for multi-worker deployments 

17- **database**: SQLAlchemy-backed shared storage using any supported database 

18 

19In distributed mode (redis/database), session existence is tracked in the shared 

20backend while transport objects remain local to each worker process. This allows 

21workers to know about sessions on other workers and route messages appropriately. 

22 

23Examples: 

24 Basic usage with memory backend: 

25 

26 >>> from mcpgateway.cache.session_registry import SessionRegistry 

27 >>> class DummyTransport: 

28 ... async def disconnect(self): 

29 ... pass 

30 ... async def is_connected(self): 

31 ... return True 

32 >>> import asyncio 

33 >>> reg = SessionRegistry(backend='memory') 

34 >>> transport = DummyTransport() 

35 >>> asyncio.run(reg.add_session('sid123', transport)) 

36 >>> found = asyncio.run(reg.get_session('sid123')) 

37 >>> isinstance(found, DummyTransport) 

38 True 

39 >>> asyncio.run(reg.remove_session('sid123')) 

40 >>> asyncio.run(reg.get_session('sid123')) is None 

41 True 

42 

43 Broadcasting messages: 

44 

45 >>> reg = SessionRegistry(backend='memory') 

46 >>> asyncio.run(reg.broadcast('sid123', {'method': 'ping', 'id': 1})) 

47 >>> reg._session_message is not None 

48 True 

49""" 

50 

51# Standard 

52import asyncio 

53from asyncio import Task 

54from datetime import datetime, timedelta, timezone 

55import logging 

56import time 

57import traceback 

58from typing import Any, Dict, Optional 

59import uuid 

60 

61# Third-Party 

62from fastapi import HTTPException, status 

63import orjson 

64 

65# First-Party 

66from mcpgateway import __version__ 

67from mcpgateway.common.models import Implementation, InitializeResult, ServerCapabilities 

68from mcpgateway.config import settings 

69from mcpgateway.db import get_db, SessionMessageRecord, SessionRecord 

70from mcpgateway.services import PromptService, ResourceService, ToolService 

71from mcpgateway.services.logging_service import LoggingService 

72from mcpgateway.transports import SSETransport 

73from mcpgateway.utils.create_jwt_token import create_jwt_token 

74from mcpgateway.utils.redis_client import get_redis_client 

75from mcpgateway.utils.retry_manager import ResilientHttpClient 

76from mcpgateway.validation.jsonrpc import JSONRPCError 

77 

# Initialize logging service first
logging_service: LoggingService = LoggingService()
logger = logging_service.get_logger(__name__)

# Module-level service singletons shared by every session in this worker process.
tool_service: ToolService = ToolService()
resource_service: ResourceService = ResourceService()
prompt_service: PromptService = PromptService()

85 

# Optional-dependency probes: these flags gate the 'redis' and 'database'
# backends at runtime so the module still imports when the extras are absent.
try:
    # Third-Party
    from redis.asyncio import Redis

    REDIS_AVAILABLE = True
except ImportError:
    REDIS_AVAILABLE = False

try:
    # Third-Party
    from sqlalchemy import func

    SQLALCHEMY_AVAILABLE = True
except ImportError:
    SQLALCHEMY_AVAILABLE = False

101 

102 

class SessionBackend:
    """Base class for session registry backend configuration.

    This class handles the initialization and configuration of different backend
    types for session storage. It validates backend requirements and sets up
    necessary connections for Redis or database backends.

    Attributes:
        _backend: The backend type ('memory', 'redis', 'database', or 'none')
        _session_ttl: Time-to-live for sessions in seconds
        _message_ttl: Time-to-live for messages in seconds
        _redis: Redis connection instance (populated only for the redis backend)
        _pubsub: Redis pubsub instance (populated only for the redis backend)
        _session_message: Temporary message storage (used by the memory backend)

    Examples:
        >>> backend = SessionBackend(backend='memory')
        >>> backend._backend
        'memory'
        >>> backend._session_ttl
        3600

        >>> try:
        ...     backend = SessionBackend(backend='redis')
        ... except ValueError as e:
        ...     str(e)
        'Redis backend requires redis_url'
    """

    def __init__(
        self,
        backend: str = "memory",
        redis_url: Optional[str] = None,
        database_url: Optional[str] = None,
        session_ttl: int = 3600,  # 1 hour
        message_ttl: int = 600,  # 10 min
    ):
        """Initialize session backend configuration.

        Args:
            backend: Backend type. Must be one of 'memory', 'redis', 'database', or 'none'.
                - 'memory': In-memory storage, suitable for single-process deployments
                - 'redis': Redis-backed storage for multi-worker deployments
                - 'database': SQLAlchemy-backed storage for multi-worker deployments
                - 'none': No session tracking (dummy registry)
            redis_url: Redis connection URL. Required when backend='redis'.
                Format: 'redis://[:password]@host:port/db'
            database_url: Database connection URL. Required when backend='database'.
                Format depends on database type (e.g., 'postgresql://user:pass@host/db')
            session_ttl: Session time-to-live in seconds. Sessions are automatically
                cleaned up after this duration of inactivity. Default: 3600 (1 hour).
            message_ttl: Message time-to-live in seconds. Undelivered messages are
                removed after this duration. Default: 600 (10 minutes).

        Raises:
            ValueError: If backend is invalid, required URL is missing, or required packages are not installed.

        Examples:
            >>> # Memory backend (default)
            >>> backend = SessionBackend()
            >>> backend._backend
            'memory'

            >>> # Redis backend requires URL
            >>> try:
            ...     backend = SessionBackend(backend='redis')
            ... except ValueError as e:
            ...     'redis_url' in str(e)
            True

            >>> # Invalid backend
            >>> try:
            ...     backend = SessionBackend(backend='invalid')
            ... except ValueError as e:
            ...     'Invalid backend' in str(e)
            True
        """
        self._backend = backend.lower()  # accept any case, e.g. 'Memory'
        self._session_ttl = session_ttl
        self._message_ttl = message_ttl

        # Robustness fix: always define the backend-specific attributes so that
        # cleanup paths (e.g. SessionRegistry.shutdown) can reference them on
        # any backend without hasattr/getattr guards or AttributeError risk.
        self._session_message: dict[str, Any] | None = None  # memory backend scratch slot
        self._pubsub = None  # redis pubsub, set during initialize() (redis backend only)
        self._redis: Optional["Redis"] = None  # redis client, set during initialize()

        # Validate backend-specific requirements.
        if self._backend == "memory":
            # Nothing extra needed for in-process storage.
            pass

        elif self._backend == "none":
            # No session tracking - this is just a dummy registry
            logger.info("Session registry initialized with 'none' backend - session tracking disabled")

        elif self._backend == "redis":
            if not REDIS_AVAILABLE:
                raise ValueError("Redis backend requested but redis package not installed")
            if not redis_url:
                raise ValueError("Redis backend requires redis_url")

        elif self._backend == "database":
            if not SQLALCHEMY_AVAILABLE:
                raise ValueError("Database backend requested but SQLAlchemy not installed")
            if not database_url:
                raise ValueError("Database backend requires database_url")

        else:
            raise ValueError(f"Invalid backend: {backend}")

211 

212 

213class SessionRegistry(SessionBackend): 

214 """Registry for SSE sessions with optional distributed state. 

215 

216 This class manages server-sent event (SSE) sessions, providing methods to add, 

217 remove, and query sessions. It supports multiple backend types for different 

218 deployment scenarios: 

219 

220 - **Single-process deployments**: Use 'memory' backend (default) 

221 - **Multi-worker deployments**: Use 'redis' or 'database' backend 

222 - **Testing/development**: Use 'none' backend to disable session tracking 

223 

224 The registry maintains a local cache of transport objects while using the 

225 shared backend to track session existence across workers. This enables 

226 horizontal scaling while keeping transport objects process-local. 

227 

228 Attributes: 

229 _sessions: Local dictionary mapping session IDs to transport objects 

230 _lock: Asyncio lock for thread-safe access to _sessions 

231 _cleanup_task: Background task for cleaning up expired sessions 

232 

233 Examples: 

234 >>> import asyncio 

235 >>> from mcpgateway.cache.session_registry import SessionRegistry 

236 >>> 

237 >>> class MockTransport: 

238 ... async def disconnect(self): 

239 ... print("Disconnected") 

240 ... async def is_connected(self): 

241 ... return True 

242 ... async def send_message(self, msg): 

243 ... print(f"Sent: {msg}") 

244 >>> 

245 >>> # Create registry and add session 

246 >>> reg = SessionRegistry(backend='memory') 

247 >>> transport = MockTransport() 

248 >>> asyncio.run(reg.add_session('test123', transport)) 

249 >>> 

250 >>> # Retrieve session 

251 >>> found = asyncio.run(reg.get_session('test123')) 

252 >>> found is transport 

253 True 

254 >>> 

255 >>> # Remove session 

256 >>> asyncio.run(reg.remove_session('test123')) 

257 Disconnected 

258 >>> asyncio.run(reg.get_session('test123')) is None 

259 True 

260 """ 

261 

262 def __init__( 

263 self, 

264 backend: str = "memory", 

265 redis_url: Optional[str] = None, 

266 database_url: Optional[str] = None, 

267 session_ttl: int = 3600, # 1 hour 

268 message_ttl: int = 600, # 10 min 

269 ): 

270 """Initialize session registry with specified backend. 

271 

272 Args: 

273 backend: Backend type. Must be one of 'memory', 'redis', 'database', or 'none'. 

274 redis_url: Redis connection URL. Required when backend='redis'. 

275 database_url: Database connection URL. Required when backend='database'. 

276 session_ttl: Session time-to-live in seconds. Default: 3600. 

277 message_ttl: Message time-to-live in seconds. Default: 600. 

278 

279 Examples: 

280 >>> # Default memory backend 

281 >>> reg = SessionRegistry() 

282 >>> reg._backend 

283 'memory' 

284 >>> isinstance(reg._sessions, dict) 

285 True 

286 

287 >>> # Redis backend with custom TTL 

288 >>> try: 

289 ... reg = SessionRegistry( 

290 ... backend='redis', 

291 ... redis_url='redis://localhost:6379', 

292 ... session_ttl=7200 

293 ... ) 

294 ... except ValueError: 

295 ... pass # Redis may not be available 

296 """ 

297 super().__init__(backend=backend, redis_url=redis_url, database_url=database_url, session_ttl=session_ttl, message_ttl=message_ttl) 

298 self._sessions: Dict[str, Any] = {} # Local transport cache 

299 self._session_owners: Dict[str, str] = {} # Session owner email by session_id 

300 self._client_capabilities: Dict[str, Dict[str, Any]] = {} # Client capabilities by session_id 

301 self._respond_tasks: Dict[str, asyncio.Task] = {} # Track respond tasks for cancellation 

302 self._stuck_tasks: Dict[str, asyncio.Task] = {} # Tasks that couldn't be cancelled (for monitoring) 

303 self._closing_sessions: set[str] = set() # Sessions being closed - respond loop should exit 

304 self._lock = asyncio.Lock() 

305 self._cleanup_task: Task | None = None 

306 self._stuck_task_reaper: Task | None = None # Reaper for stuck tasks 

307 

308 def register_respond_task(self, session_id: str, task: asyncio.Task) -> None: 

309 """Register a respond task for later cancellation. 

310 

311 Associates an asyncio Task with a session_id so it can be cancelled 

312 when the session is removed. This prevents orphaned tasks that cause 

313 CPU spin loops. 

314 

315 Args: 

316 session_id: Session identifier the task belongs to. 

317 task: The asyncio Task to track. 

318 """ 

319 self._respond_tasks[session_id] = task 

320 logger.debug(f"Registered respond task for session {session_id}") 

321 

    async def _cancel_respond_task(self, session_id: str, timeout: float = 5.0) -> None:
        """Cancel and await a respond task with timeout.

        Safely cancels the respond task associated with a session. Uses a timeout
        to prevent hanging if the task doesn't respond to cancellation.

        If initial cancellation times out, escalates by force-disconnecting the
        transport to unblock the task, then retries cancellation (Finding 1 fix).
        Tasks that survive escalation are parked in ``_stuck_tasks`` so the
        reaper can keep retrying (Finding 2 fix).

        Args:
            session_id: Session identifier whose task should be cancelled.
            timeout: Maximum seconds to wait for task cancellation. Default 5.0.
        """
        task = self._respond_tasks.get(session_id)
        if task is None:
            # Nothing registered for this session - nothing to cancel.
            return

        if task.done():
            # Task already finished - safe to remove from tracking
            self._respond_tasks.pop(session_id, None)
            try:
                # Consume the result so a failed task doesn't log an
                # "exception was never retrieved" warning later.
                task.result()
            except asyncio.CancelledError:
                pass
            except Exception as e:
                logger.warning(f"Respond task for {session_id} failed with: {e}")
            return

        # Request cooperative cancellation, then bound the wait below.
        task.cancel()
        logger.debug(f"Cancelling respond task for session {session_id}")

        try:
            # NOTE: wait_for re-cancels the task and raises TimeoutError if it
            # doesn't finish in time (stdlib asyncio semantics).
            await asyncio.wait_for(task, timeout=timeout)
            # Cancellation successful - remove from tracking
            self._respond_tasks.pop(session_id, None)
            logger.debug(f"Respond task cancelled for session {session_id}")
        except asyncio.TimeoutError:
            # ESCALATION (Finding 1): Force-disconnect transport to unblock the task
            logger.warning(f"Respond task cancellation timed out for {session_id}, " f"escalating with transport disconnect")

            # Force-disconnect the transport to unblock any pending I/O
            transport = self._sessions.get(session_id)
            if transport and hasattr(transport, "disconnect"):
                try:
                    await transport.disconnect()
                    logger.debug(f"Force-disconnected transport for {session_id}")
                except Exception as e:
                    logger.warning(f"Failed to force-disconnect transport for {session_id}: {e}")

            # Retry cancellation with shorter timeout
            if not task.done():
                try:
                    await asyncio.wait_for(task, timeout=2.0)
                    self._respond_tasks.pop(session_id, None)
                    logger.info(f"Respond task cancelled after escalation for {session_id}")
                except asyncio.TimeoutError:
                    # Still stuck - move to stuck_tasks for monitoring (Finding 2 fix)
                    self._respond_tasks.pop(session_id, None)
                    self._stuck_tasks[session_id] = task
                    logger.error(f"Respond task for {session_id} still stuck after escalation, " f"moved to stuck_tasks for monitoring (total stuck: {len(self._stuck_tasks)})")
                except asyncio.CancelledError:
                    # Task acknowledged the cancellation during the retry wait.
                    self._respond_tasks.pop(session_id, None)
                    logger.info(f"Respond task cancelled after escalation for {session_id}")
                except Exception as e:
                    self._respond_tasks.pop(session_id, None)
                    logger.warning(f"Error during retry cancellation for {session_id}: {e}")
            else:
                # Task finished on its own while we were escalating.
                self._respond_tasks.pop(session_id, None)
                logger.debug(f"Respond task completed during escalation for {session_id}")

        except asyncio.CancelledError:
            # Cancellation successful - remove from tracking
            self._respond_tasks.pop(session_id, None)
            logger.debug(f"Respond task cancelled for session {session_id}")
        except Exception as e:
            # Remove from tracking on unexpected error - task state unknown
            self._respond_tasks.pop(session_id, None)
            logger.warning(f"Error during respond task cancellation for {session_id}: {e}")

400 

    async def _reap_stuck_tasks(self) -> None:
        """Periodically clean up stuck tasks that have completed.

        This reaper runs every 30 seconds and:
        1. Removes completed tasks from _stuck_tasks
        2. Retries cancellation for tasks that are still running
        3. Logs warnings for tasks that remain stuck

        This prevents memory leaks from tasks that eventually complete after
        being moved to _stuck_tasks during escalation.

        Raises:
            asyncio.CancelledError: If the task is cancelled during shutdown.
        """
        reap_interval = 30.0  # seconds
        retry_timeout = 2.0  # seconds for retry cancellation

        while True:
            try:
                await asyncio.sleep(reap_interval)

                if not self._stuck_tasks:
                    continue

                # Collect completed and still-stuck tasks
                completed = []
                still_stuck = []

                # Snapshot the items: the dict may be mutated by other
                # coroutines while this loop awaits.
                for session_id, task in list(self._stuck_tasks.items()):
                    if task.done():
                        completed.append(session_id)
                        try:
                            task.result()  # Consume result to avoid warnings
                        except (asyncio.CancelledError, Exception):
                            pass
                    else:
                        still_stuck.append((session_id, task))

                # Remove completed tasks
                for session_id in completed:
                    self._stuck_tasks.pop(session_id, None)

                if completed:
                    logger.info(f"Reaped {len(completed)} completed stuck tasks")

                # Retry cancellation for still-stuck tasks
                for session_id, task in still_stuck:
                    task.cancel()
                    try:
                        await asyncio.wait_for(task, timeout=retry_timeout)
                        self._stuck_tasks.pop(session_id, None)
                        logger.info(f"Stuck task {session_id} finally cancelled during reap")
                    except asyncio.TimeoutError:
                        # Leave the task in _stuck_tasks for the next cycle.
                        logger.warning(f"Task {session_id} still stuck after reap retry")
                    except asyncio.CancelledError:
                        # NOTE(review): this also fires if the reaper itself is
                        # cancelled while awaiting the stuck task - confirm the
                        # swallow-and-continue here is intentional.
                        self._stuck_tasks.pop(session_id, None)
                        logger.info(f"Stuck task {session_id} cancelled during reap")
                    except Exception as e:
                        logger.warning(f"Error during stuck task reap for {session_id}: {e}")

                if self._stuck_tasks:
                    logger.warning(f"Stuck tasks remaining: {len(self._stuck_tasks)}")

            except asyncio.CancelledError:
                # Propagate so shutdown() can await this task's termination.
                logger.debug("Stuck task reaper cancelled")
                raise
            except Exception as e:
                # Keep the reaper alive across unexpected per-cycle errors.
                logger.error(f"Error in stuck task reaper: {e}")

469 

470 async def initialize(self) -> None: 

471 """Initialize the registry with async setup. 

472 

473 This method performs asynchronous initialization tasks that cannot be done 

474 in __init__. It starts background cleanup tasks and sets up pubsub 

475 subscriptions for distributed backends. 

476 

477 Call this during application startup after creating the registry instance. 

478 

479 Examples: 

480 >>> import asyncio 

481 >>> reg = SessionRegistry(backend='memory') 

482 >>> asyncio.run(reg.initialize()) 

483 >>> reg._cleanup_task is not None 

484 True 

485 >>> 

486 >>> # Cleanup 

487 >>> asyncio.run(reg.shutdown()) 

488 """ 

489 logger.info(f"Initializing session registry with backend: {self._backend}") 

490 

491 if self._backend == "database": 

492 # Start database cleanup task 

493 self._cleanup_task = asyncio.create_task(self._db_cleanup_task()) 

494 logger.info("Database cleanup task started") 

495 

496 elif self._backend == "redis": 

497 # Get shared Redis client from factory 

498 self._redis = await get_redis_client() 

499 if self._redis: 

500 self._pubsub = self._redis.pubsub() 

501 await self._pubsub.subscribe("mcp_session_events") 

502 logger.info("Session registry connected to shared Redis client") 

503 

504 elif self._backend == "none": 

505 # Nothing to initialize for none backend 

506 pass 

507 

508 # Memory backend needs session cleanup 

509 elif self._backend == "memory": 

510 self._cleanup_task = asyncio.create_task(self._memory_cleanup_task()) 

511 logger.info("Memory cleanup task started") 

512 

513 # Start stuck task reaper for all backends 

514 self._stuck_task_reaper = asyncio.create_task(self._reap_stuck_tasks()) 

515 logger.info("Stuck task reaper started") 

516 

    async def shutdown(self) -> None:
        """Shutdown the registry and clean up resources.

        This method cancels background tasks and closes connections to external
        services. Call this during application shutdown to ensure clean termination.

        Teardown order: cleanup task, stuck-task reaper, respond tasks,
        previously-stuck tasks, then the Redis pubsub channel.

        Examples:
            >>> import asyncio
            >>> reg = SessionRegistry()
            >>> asyncio.run(reg.initialize())
            >>> task_was_created = reg._cleanup_task is not None
            >>> asyncio.run(reg.shutdown())
            >>> # After shutdown, cleanup task should be handled (cancelled or done)
            >>> task_was_created and (reg._cleanup_task.cancelled() or reg._cleanup_task.done())
            True
        """
        logger.info("Shutting down session registry")

        # Cancel cleanup task
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass

        # Cancel stuck task reaper
        if self._stuck_task_reaper:
            self._stuck_task_reaper.cancel()
            try:
                await self._stuck_task_reaper
            except asyncio.CancelledError:
                pass

        # CRITICAL: Cancel ALL respond tasks to prevent CPU spin loops
        if self._respond_tasks:
            logger.info(f"Cancelling {len(self._respond_tasks)} respond tasks")
            # Snapshot then clear tracking before awaiting, so concurrent
            # registrations don't race with this teardown pass.
            tasks_to_cancel = list(self._respond_tasks.values())
            self._respond_tasks.clear()

            for task in tasks_to_cancel:
                if not task.done():
                    task.cancel()

            if tasks_to_cancel:
                try:
                    # return_exceptions=True: one failing task must not abort
                    # the gather and leave the rest unawaited.
                    await asyncio.wait_for(asyncio.gather(*tasks_to_cancel, return_exceptions=True), timeout=10.0)
                    logger.info("All respond tasks cancelled successfully")
                except asyncio.TimeoutError:
                    logger.warning("Timeout waiting for respond tasks to cancel")

        # Also cancel any stuck tasks (tasks that previously couldn't be cancelled)
        if self._stuck_tasks:
            logger.warning(f"Attempting final cancellation of {len(self._stuck_tasks)} stuck tasks")
            stuck_to_cancel = list(self._stuck_tasks.values())
            self._stuck_tasks.clear()

            for task in stuck_to_cancel:
                if not task.done():
                    task.cancel()

            if stuck_to_cancel:
                try:
                    await asyncio.wait_for(asyncio.gather(*stuck_to_cancel, return_exceptions=True), timeout=5.0)
                    logger.info("Stuck tasks cancelled during shutdown")
                except asyncio.TimeoutError:
                    logger.error("Some stuck tasks could not be cancelled during shutdown")

        # Close Redis pubsub (but not the shared client)
        # Use timeout to prevent blocking if pubsub doesn't close cleanly
        cleanup_timeout = settings.mcp_session_pool_cleanup_timeout
        if self._backend == "redis" and getattr(self, "_pubsub", None):
            try:
                await asyncio.wait_for(self._pubsub.aclose(), timeout=cleanup_timeout)
            except asyncio.TimeoutError:
                logger.warning("Redis pubsub close timed out - proceeding anyway")
            except Exception as e:
                logger.error(f"Error closing Redis pubsub: {e}")
            # Don't close self._redis - it's the shared client managed by redis_client.py
            self._redis = None
            self._pubsub = None

598 

    async def add_session(self, session_id: str, transport: SSETransport) -> None:
        """Add a session to the registry.

        Stores the session in both the local cache and the distributed backend
        (if configured). For distributed backends, this notifies other workers
        about the new session.

        Args:
            session_id: Unique session identifier. Should be a UUID or similar
                unique string to avoid collisions.
            transport: SSE transport object for this session. Must implement
                the SSETransport interface.

        Examples:
            >>> import asyncio
            >>> from mcpgateway.cache.session_registry import SessionRegistry
            >>>
            >>> class MockTransport:
            ...     async def disconnect(self):
            ...         print(f"Transport disconnected")
            ...     async def is_connected(self):
            ...         return True
            >>>
            >>> reg = SessionRegistry()
            >>> transport = MockTransport()
            >>> asyncio.run(reg.add_session('test-456', transport))
            >>>
            >>> # Found in local cache
            >>> found = asyncio.run(reg.get_session('test-456'))
            >>> found is transport
            True
            >>>
            >>> # Remove session
            >>> asyncio.run(reg.remove_session('test-456'))
            Transport disconnected
        """
        # Skip for none backend
        if self._backend == "none":
            return

        # Always cache the transport locally, whatever the backend.
        async with self._lock:
            self._sessions[session_id] = transport

        if self._backend == "redis":
            # Store session marker in Redis
            if not self._redis:
                # NOTE(review): this early return also skips the final
                # "Added session" log below - confirm that is intended.
                logger.warning(f"Redis client not initialized, skipping distributed session tracking for {session_id}")
                return
            try:
                # Marker key with TTL lets other workers see the session exists.
                await self._redis.setex(f"mcp:session:{session_id}", self._session_ttl, "1")
                # Publish event to notify other workers
                await self._redis.publish("mcp_session_events", orjson.dumps({"type": "add", "session_id": session_id, "timestamp": time.time()}))
            except Exception as e:
                logger.error(f"Redis error adding session {session_id}: {e}")

        elif self._backend == "database":
            # Store session in database
            try:

                def _db_add() -> None:
                    """Store session record in the database.

                    Creates a new SessionRecord entry in the database for tracking
                    distributed session state. Uses a fresh database connection from
                    the connection pool.

                    This inner function is designed to be run in a thread executor
                    to avoid blocking the async event loop during database I/O.

                    Raises:
                        Exception: Any database error is re-raised after rollback.
                            Common errors include duplicate session_id (unique constraint)
                            or database connection issues.

                    Examples:
                        >>> # This function is called internally by add_session()
                        >>> # When executed, it creates a database record:
                        >>> # SessionRecord(session_id='abc123', created_at=now())
                    """
                    db_session = next(get_db())
                    try:
                        session_record = SessionRecord(session_id=session_id)
                        db_session.add(session_record)
                        db_session.commit()
                    except Exception as ex:
                        db_session.rollback()
                        raise ex
                    finally:
                        db_session.close()

                # Run blocking DB I/O off the event loop.
                await asyncio.to_thread(_db_add)
            except Exception as e:
                logger.error(f"Database error adding session {session_id}: {e}")

        logger.info(f"Added session: {session_id}")

694 

695 def _session_owner_key(self, session_id: str) -> str: 

696 """Return Redis key used to store session ownership. 

697 

698 Args: 

699 session_id: Session identifier. 

700 

701 Returns: 

702 Redis key for the session owner mapping. 

703 """ 

704 return f"mcp:session_owner:{session_id}" 

705 

    async def set_session_owner(self, session_id: str, owner_email: Optional[str]) -> None:
        """Set or clear owner for a session.

        Updates the local owner cache and, for distributed backends, persists
        the ownership in Redis (TTL-bound key) or the database (merged into
        the session record's JSON ``data`` blob).

        Args:
            session_id: Session identifier to update.
            owner_email: Owner email to set. Passing ``None`` clears ownership.

        Returns:
            None.
        """
        # Skip for none backend
        if self._backend == "none":
            return

        # Keep the local cache in sync regardless of backend.
        if owner_email:
            self._session_owners[session_id] = owner_email
        else:
            self._session_owners.pop(session_id, None)

        if self._backend == "redis":
            if not self._redis:
                logger.warning(f"Redis client not initialized, cannot set owner for session {session_id}")
                return
            try:
                owner_key = self._session_owner_key(session_id)
                if owner_email:
                    # Owner key expires alongside the session TTL.
                    await self._redis.setex(owner_key, self._session_ttl, owner_email)
                else:
                    await self._redis.delete(owner_key)
            except Exception as e:
                logger.error(f"Redis error setting owner for session {session_id}: {e}")

        elif self._backend == "database":
            try:

                def _db_set_owner() -> None:
                    """Persist owner metadata for a session in the database backend.

                    Raises:
                        Exception: Propagates database write failures to the caller.
                    """
                    db_session = next(get_db())
                    try:
                        record = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
                        if not record:
                            # Create the session row on demand so ownership can
                            # be recorded before add_session runs on this worker.
                            record = SessionRecord(session_id=session_id, data=None)
                            db_session.add(record)
                            db_session.flush()

                        # Merge owner into the JSON blob stored in record.data;
                        # malformed existing data is discarded rather than raised.
                        record_data: Dict[str, Any] = {}
                        if record.data:
                            try:
                                parsed = orjson.loads(record.data)
                                if isinstance(parsed, dict):
                                    record_data = parsed
                            except Exception:
                                record_data = {}

                        if owner_email:
                            record_data["owner_email"] = owner_email
                        else:
                            record_data.pop("owner_email", None)

                        # Store None instead of '{}' when no metadata remains.
                        record.data = orjson.dumps(record_data).decode() if record_data else None
                        db_session.commit()
                    except Exception as ex:
                        db_session.rollback()
                        raise ex
                    finally:
                        db_session.close()

                # Run blocking DB I/O off the event loop.
                await asyncio.to_thread(_db_set_owner)
            except Exception as e:
                logger.error(f"Database error setting owner for session {session_id}: {e}")

780 

781 async def claim_session_owner(self, session_id: str, owner_email: str) -> Optional[str]: 

782 """Atomically claim ownership for a session and return the effective owner. 

783 

784 This method provides compare-and-set semantics: 

785 - If a session owner already exists, return the existing owner. 

786 - If no owner exists, claim ownership for ``owner_email``. 

787 

788 Args: 

789 session_id: Session identifier to claim. 

790 owner_email: Requesting owner email. 

791 

792 Returns: 

793 Effective owner email after the claim operation, or ``None`` if owner 

794 metadata could not be verified due to backend availability issues. 

795 """ 

796 if self._backend == "none": 

797 return owner_email 

798 

799 # Fast local cache path. 

800 cached_owner = self._session_owners.get(session_id) 

801 if cached_owner: 

802 return cached_owner 

803 

804 if self._backend == "memory": 

805 async with self._lock: 

806 existing_owner = self._session_owners.get(session_id) 

807 if existing_owner: 

808 return existing_owner 

809 self._session_owners[session_id] = owner_email 

810 return owner_email 

811 

812 if self._backend == "redis": 

813 if not self._redis: 

814 logger.warning("Redis client not initialized, session owner claim unavailable for %s", session_id) 

815 return None 

816 

817 owner_key = self._session_owner_key(session_id) 

818 try: 

819 claimed = await self._redis.set(owner_key, owner_email, ex=self._session_ttl, nx=True) 

820 if claimed: 

821 self._session_owners[session_id] = owner_email 

822 return owner_email 

823 

824 owner_raw = await self._redis.get(owner_key) 

825 if owner_raw is not None: 

826 existing_owner = owner_raw.decode() if isinstance(owner_raw, bytes) else str(owner_raw) 

827 if existing_owner: 

828 self._session_owners[session_id] = existing_owner 

829 return existing_owner 

830 

831 # Handle key expiry/race window by retrying once. 

832 claimed_retry = await self._redis.set(owner_key, owner_email, ex=self._session_ttl, nx=True) 

833 if claimed_retry: 

834 self._session_owners[session_id] = owner_email 

835 return owner_email 

836 

837 owner_raw = await self._redis.get(owner_key) 

838 if owner_raw is None: 

839 return None 

840 existing_owner = owner_raw.decode() if isinstance(owner_raw, bytes) else str(owner_raw) 

841 if existing_owner: 

842 self._session_owners[session_id] = existing_owner 

843 return existing_owner 

844 return None 

845 except Exception as e: 

846 logger.error("Redis error claiming owner for session %s: %s", session_id, e) 

847 return None 

848 

849 if self._backend == "database": 

850 try: 

851 

852 def _db_claim_owner() -> Optional[str]: 

853 """Claim owner in DB with optimistic compare-and-set retries. 

854 

855 Returns: 

856 Effective owner email or ``None`` when claim cannot be verified. 

857 """ 

858 db_session = next(get_db()) 

859 try: 

860 for _attempt in range(3): 

861 record = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first() 

862 if not record: 

863 owner_payload = {"owner_email": owner_email} 

864 new_record = SessionRecord(session_id=session_id, data=orjson.dumps(owner_payload).decode()) 

865 db_session.add(new_record) 

866 try: 

867 db_session.commit() 

868 return owner_email 

869 except Exception: 

870 db_session.rollback() 

871 # Another writer may have inserted concurrently. Retry. 

872 continue 

873 

874 record_data: Dict[str, Any] = {} 

875 if record.data: 

876 try: 

877 parsed = orjson.loads(record.data) 

878 if isinstance(parsed, dict): 

879 record_data = parsed 

880 except Exception: 

881 record_data = {} 

882 

883 existing_owner = record_data.get("owner_email") 

884 if isinstance(existing_owner, str) and existing_owner: 

885 return existing_owner 

886 

887 current_data = record.data 

888 updated_data = dict(record_data) 

889 updated_data["owner_email"] = owner_email 

890 serialized = orjson.dumps(updated_data).decode() 

891 

892 update_query = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id) 

893 if current_data is None: 

894 update_query = update_query.filter(SessionRecord.data.is_(None)) 

895 else: 

896 update_query = update_query.filter(SessionRecord.data == current_data) 

897 

898 updated_rows = update_query.update({"data": serialized}, synchronize_session=False) 

899 if updated_rows == 1: 

900 db_session.commit() 

901 return owner_email 

902 

903 db_session.rollback() 

904 

905 return None 

906 finally: 

907 db_session.close() 

908 

909 claimed_owner = await asyncio.to_thread(_db_claim_owner) 

910 if claimed_owner: 

911 self._session_owners[session_id] = claimed_owner 

912 return claimed_owner 

913 except Exception as e: 

914 logger.error("Database error claiming owner for session %s: %s", session_id, e) 

915 return None 

916 

917 return None 

918 

async def get_session_owner(self, session_id: str) -> Optional[str]:
    """Get owner email for a session.

    Checks the local owner cache first, then falls back to the configured
    distributed backend (Redis or database). A successful backend lookup is
    cached locally so subsequent calls avoid the round trip.

    Args:
        session_id: Session identifier to resolve.

    Returns:
        Owner email when present, otherwise ``None``.
    """
    owner = self._session_owners.get(session_id)
    if owner:
        return owner

    if self._backend == "redis":
        if not self._redis:
            return None
        try:
            owner_raw = await self._redis.get(self._session_owner_key(session_id))
            if owner_raw is None:
                return None
            # Redis clients may return bytes or str depending on decode settings.
            if isinstance(owner_raw, bytes):
                owner = owner_raw.decode()
            else:
                owner = str(owner_raw)
            if owner:
                self._session_owners[session_id] = owner
            # Normalize empty string to None.
            return owner or None
        except Exception as e:
            # Lazy %-style args, consistent with the other backend methods.
            logger.error("Redis error getting owner for session %s: %s", session_id, e)
            return None

    if self._backend == "database":
        try:

            def _db_get_owner() -> Optional[str]:
                """Read the owner email from the session record in the database.

                Runs in a thread executor to avoid blocking the event loop.

                Returns:
                    Owner email when stored on the record, otherwise ``None``.
                """
                db_session = next(get_db())
                try:
                    record = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
                    if not record or not record.data:
                        return None
                    try:
                        data = orjson.loads(record.data)
                    except Exception:
                        return None
                    if not isinstance(data, dict):
                        return None
                    owner_email = data.get("owner_email")
                    return owner_email if isinstance(owner_email, str) and owner_email else None
                finally:
                    db_session.close()

            owner = await asyncio.to_thread(_db_get_owner)
            if owner:
                self._session_owners[session_id] = owner
            return owner
        except Exception as e:
            logger.error("Database error getting owner for session %s: %s", session_id, e)
            return None

    return None

979 

980 async def session_exists(self, session_id: str) -> Optional[bool]: 

981 """Return whether a session marker exists. 

982 

983 Args: 

984 session_id: Session identifier to resolve. 

985 

986 Returns: 

987 ``True`` when session exists, ``False`` when session does not exist, 

988 and ``None`` when existence cannot be verified due to backend errors. 

989 """ 

990 if self._backend == "none": 

991 return False 

992 

993 async with self._lock: 

994 if session_id in self._sessions: 

995 return True 

996 

997 if self._backend == "memory": 

998 return False 

999 

1000 if self._backend == "redis": 

1001 if not self._redis: 

1002 return None 

1003 try: 

1004 return bool(await self._redis.exists(f"mcp:session:{session_id}")) 

1005 except Exception as e: 

1006 logger.error("Redis error checking existence for session %s: %s", session_id, e) 

1007 return None 

1008 

1009 if self._backend == "database": 

1010 try: 

1011 

1012 def _db_exists() -> bool: 

1013 """Check whether a session record exists in the database backend. 

1014 

1015 Returns: 

1016 ``True`` when a matching session record exists. 

1017 """ 

1018 db_session = next(get_db()) 

1019 try: 

1020 record = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first() 

1021 return record is not None 

1022 finally: 

1023 db_session.close() 

1024 

1025 return await asyncio.to_thread(_db_exists) 

1026 except Exception as e: 

1027 logger.error("Database error checking existence for session %s: %s", session_id, e) 

1028 return None 

1029 

1030 return False 

1031 

1032 async def get_session(self, session_id: str) -> Any: 

1033 """Get session transport by ID. 

1034 

1035 First checks the local cache for the transport object. If not found locally 

1036 but using a distributed backend, checks if the session exists on another 

1037 worker. 

1038 

1039 Args: 

1040 session_id: Session identifier to look up. 

1041 

1042 Returns: 

1043 SSETransport object if found locally, None if not found or exists 

1044 on another worker. 

1045 

1046 Examples: 

1047 >>> import asyncio 

1048 >>> from mcpgateway.cache.session_registry import SessionRegistry 

1049 >>> 

1050 >>> class MockTransport: 

1051 ... pass 

1052 >>> 

1053 >>> reg = SessionRegistry() 

1054 >>> transport = MockTransport() 

1055 >>> asyncio.run(reg.add_session('test-456', transport)) 

1056 >>> 

1057 >>> # Found in local cache 

1058 >>> found = asyncio.run(reg.get_session('test-456')) 

1059 >>> found is transport 

1060 True 

1061 >>> 

1062 >>> # Not found 

1063 >>> asyncio.run(reg.get_session('nonexistent')) is None 

1064 True 

1065 """ 

1066 # Skip for none backend 

1067 if self._backend == "none": 

1068 return None 

1069 

1070 # First check local cache 

1071 async with self._lock: 

1072 transport = self._sessions.get(session_id) 

1073 if transport: 

1074 logger.info(f"Session {session_id} exists in local cache") 

1075 return transport 

1076 

1077 # If not in local cache, check if it exists in shared backend 

1078 if self._backend == "redis": 

1079 if not self._redis: 

1080 return None 

1081 try: 

1082 exists = await self._redis.exists(f"mcp:session:{session_id}") 

1083 session_exists = bool(exists) 

1084 if session_exists: 

1085 logger.info(f"Session {session_id} exists in Redis but not in local cache") 

1086 return None # We don't have the transport locally 

1087 except Exception as e: 

1088 logger.error(f"Redis error checking session {session_id}: {e}") 

1089 return None 

1090 

1091 elif self._backend == "database": 

1092 try: 

1093 

1094 def _db_check() -> bool: 

1095 """Check if a session exists in the database. 

1096 

1097 Queries the SessionRecord table to determine if a session with 

1098 the given session_id exists. This is used when the session is not 

1099 found in the local cache to check if it exists on another worker. 

1100 

1101 This inner function is designed to be run in a thread executor 

1102 to avoid blocking the async event loop during database queries. 

1103 

1104 Returns: 

1105 bool: True if the session exists in the database, False otherwise. 

1106 

1107 Examples: 

1108 >>> # This function is called internally by get_session() 

1109 >>> # Returns True if SessionRecord with session_id exists 

1110 >>> # Returns False if no matching record found 

1111 """ 

1112 db_session = next(get_db()) 

1113 try: 

1114 record = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first() 

1115 return record is not None 

1116 finally: 

1117 db_session.close() 

1118 

1119 exists = await asyncio.to_thread(_db_check) 

1120 if exists: 

1121 logger.info(f"Session {session_id} exists in database but not in local cache") 

1122 return None 

1123 except Exception as e: 

1124 logger.error(f"Database error checking session {session_id}: {e}") 

1125 return None 

1126 

1127 return None 

1128 

async def remove_session(self, session_id: str) -> None:
    """Remove a session from the registry.

    Removes the session from both local cache and distributed backend.
    If a transport is found locally, it will be disconnected before removal.
    For distributed backends, notifies other workers about the removal.

    Args:
        session_id: Session identifier to remove.

    Examples:
        >>> import asyncio
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>>
        >>> class MockTransport:
        ...     async def disconnect(self):
        ...         print(f"Transport disconnected")
        ...     async def is_connected(self):
        ...         return True
        >>>
        >>> reg = SessionRegistry()
        >>> transport = MockTransport()
        >>> asyncio.run(reg.add_session('remove-test', transport))
        >>> asyncio.run(reg.remove_session('remove-test'))
        Transport disconnected
        >>>
        >>> # Session no longer exists
        >>> asyncio.run(reg.get_session('remove-test')) is None
        True
    """
    # Skip for none backend
    if self._backend == "none":
        return

    # Mark session as closing FIRST so respond loop can exit early
    # This allows the loop to exit without waiting for cancellation to complete
    self._closing_sessions.add(session_id)

    try:
        # CRITICAL: Cancel respond task before any cleanup
        # This prevents orphaned tasks that cause CPU spin loops
        await self._cancel_respond_task(session_id)

        # Clean up local transport (and the cached owner mapping) under the lock
        transport = None
        async with self._lock:
            if session_id in self._sessions:
                transport = self._sessions.pop(session_id)
            self._session_owners.pop(session_id, None)
            # Also clean up client capabilities
            if session_id in self._client_capabilities:
                self._client_capabilities.pop(session_id)
                logger.debug(f"Removed capabilities for session {session_id}")
    finally:
        # Always remove from closing set, even if cancellation/cleanup raised
        self._closing_sessions.discard(session_id)

    # Disconnect transport if found; disconnect errors are logged, not raised
    if transport:
        try:
            await transport.disconnect()
        except Exception as e:
            logger.error(f"Error disconnecting transport for session {session_id}: {e}")

    # Remove from shared backend
    if self._backend == "redis":
        if not self._redis:
            return
        try:
            await self._redis.delete(f"mcp:session:{session_id}")
            await self._redis.delete(self._session_owner_key(session_id))
            # Notify other workers
            await self._redis.publish("mcp_session_events", orjson.dumps({"type": "remove", "session_id": session_id, "timestamp": time.time()}))
        except Exception as e:
            logger.error(f"Redis error removing session {session_id}: {e}")

    elif self._backend == "database":
        try:

            def _db_remove() -> None:
                """Delete session record from the database.

                Removes the SessionRecord entry with the specified session_id
                from the database. This is called when a session is being
                terminated or has expired. Idempotent: no error is raised if
                the session_id does not exist.

                This inner function is designed to be run in a thread executor
                to avoid blocking the async event loop during database operations.

                Raises:
                    Exception: Any database error is re-raised after rollback.
                        This includes connection errors or constraint violations.
                """
                db_session = next(get_db())
                try:
                    db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).delete()
                    db_session.commit()
                except Exception as ex:
                    db_session.rollback()
                    raise ex
                finally:
                    db_session.close()

            await asyncio.to_thread(_db_remove)
        except Exception as e:
            logger.error(f"Database error removing session {session_id}: {e}")

    logger.info(f"Removed session: {session_id}")

1242 

async def broadcast(self, session_id: str, message: Dict[str, Any]) -> None:
    """Broadcast a message to a session.

    Routing depends on the configured backend:

    - **memory**: Stores message temporarily for local delivery
    - **redis**: Publishes message to Redis channel for the session
    - **database**: Stores message in database for polling by worker with session
    - **none**: No operation

    Used for inter-process communication in distributed deployments.

    Args:
        session_id: Target session identifier.
        message: Message to broadcast. Can be a dict, list, or any JSON-serializable object.

    Examples:
        >>> import asyncio
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>>
        >>> reg = SessionRegistry(backend='memory')
        >>> message = {'method': 'tools/list', 'id': 1}
        >>> asyncio.run(reg.broadcast('session-789', message))
        >>>
        >>> # Message stored for memory backend
        >>> reg._session_message is not None
        True
        >>> reg._session_message['session_id']
        'session-789'
        >>> orjson.loads(reg._session_message['message'])['message'] == message
        True
    """
    # The "none" backend performs no message routing at all.
    if self._backend == "none":
        return

    def _envelope(msg: Any) -> str:
        """Wrap *msg* in the standard broadcast envelope and JSON-encode it.

        Args:
            msg: Message to wrap in payload envelope.

        Returns:
            JSON-encoded string containing type, message, and timestamp.
        """
        return orjson.dumps({"type": "message", "message": msg, "timestamp": time.time()}).decode()

    if self._backend == "memory":
        self._session_message: Dict[str, Any] | None = {"session_id": session_id, "message": _envelope(message)}

    elif self._backend == "redis":
        if not self._redis:
            logger.warning(f"Redis client not initialized, cannot broadcast to {session_id}")
            return
        try:
            # Build the envelope with the message kept as its original type,
            # then encode exactly once and publish the raw bytes.
            payload = orjson.dumps({"type": "message", "message": message, "timestamp": time.time()})
            await self._redis.publish(session_id, payload)
        except Exception as e:
            logger.error(f"Redis error during broadcast: {e}")

    elif self._backend == "database":
        try:
            serialized = _envelope(message)

            def _persist_message() -> None:
                """Insert the serialized message row for cross-worker delivery.

                Runs in a thread executor so the DB write does not block the
                event loop.

                Raises:
                    Exception: Any database error is re-raised after rollback.
                """
                db_session = next(get_db())
                try:
                    db_session.add(SessionMessageRecord(session_id=session_id, message=serialized))
                    db_session.commit()
                except Exception as ex:
                    db_session.rollback()
                    raise ex
                finally:
                    db_session.close()

            await asyncio.to_thread(_persist_message)
        except Exception as e:
            logger.error(f"Database error during broadcast: {e}")

1352 

async def _register_session_mapping(self, session_id: str, message: Dict[str, Any], user_email: Optional[str] = None) -> None:
    """Register session mapping for session affinity when tools are called.

    Runs on the worker that owns the SSE session, pre-registering the link
    between the downstream session ID and the upstream MCP session pool key
    so multi-worker deployments can route follow-up calls to the same
    upstream session. Registration happens only for ``tools/call`` — list
    operations and other methods are stateless and need no affinity.

    Args:
        session_id: The downstream SSE session ID.
        message: The MCP protocol message being broadcast.
        user_email: Optional user email for session isolation.
    """
    # Feature flag: bail out when affinity is globally disabled.
    if not settings.mcpgateway_session_affinity_enabled:
        return

    # Only tools/call maintains upstream state worth pinning.
    if message.get("method") != "tools/call":
        return

    tool_name = message.get("params", {}).get("name")
    if not tool_name:
        return

    try:
        # Resolve gateway details for the tool via the lookup cache.
        # First-Party
        from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache  # pylint: disable=import-outside-toplevel

        tool_info = await tool_lookup_cache.get(tool_name)
        if not tool_info:
            logger.debug(f"Tool {tool_name} not found in cache, skipping session mapping registration")
            return

        gateway = tool_info.get("gateway", {})
        gateway_url = gateway.get("url")
        gateway_id = gateway.get("id")
        gateway_transport = gateway.get("transport")
        if not (gateway_url and gateway_id and gateway_transport):
            logger.debug(f"Incomplete gateway info for tool {tool_name}, skipping session mapping registration")
            return

        # Hand the mapping to the session pool.
        # First-Party
        from mcpgateway.services.mcp_session_pool import get_mcp_session_pool  # pylint: disable=import-outside-toplevel

        await get_mcp_session_pool().register_session_mapping(
            session_id,
            gateway_url,
            gateway_id,
            gateway_transport,
            user_email,
        )

        logger.debug(f"Registered session mapping for session {session_id[:8]}... -> {gateway_url} (tool: {tool_name})")

    except Exception as e:
        # Best-effort: never fail the broadcast over a mapping registration.
        logger.warning(f"Failed to register session mapping for {session_id[:8]}...: {e}")

1422 

async def get_all_session_ids(self) -> list[str]:
    """Return a snapshot of every session ID known to this worker.

    Returns:
        list[str]: A snapshot list of currently known local session IDs.
    """
    async with self._lock:
        return [sid for sid in self._sessions]

1431 

def get_session_sync(self, session_id: str) -> Any:
    """Get session synchronously from local cache only.

    Non-blocking lookup that consults only this worker's cache, never the
    distributed backend. Use when quick access is needed and the session is
    expected to be local.

    Args:
        session_id: Session identifier to look up.

    Returns:
        SSETransport object if found in local cache, None otherwise.

    Examples:
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>> import asyncio
        >>>
        >>> class MockTransport:
        ...     pass
        >>>
        >>> reg = SessionRegistry()
        >>> transport = MockTransport()
        >>> asyncio.run(reg.add_session('sync-test', transport))
        >>>
        >>> # Synchronous lookup
        >>> found = reg.get_session_sync('sync-test')
        >>> found is transport
        True
        >>>
        >>> # Not found
        >>> reg.get_session_sync('nonexistent') is None
        True
    """
    # The "none" backend never stores sessions.
    return None if self._backend == "none" else self._sessions.get(session_id)

1470 

1471 async def respond( 

1472 self, 

1473 server_id: Optional[str], 

1474 user: Dict[str, Any], 

1475 session_id: str, 

1476 ) -> None: 

1477 """Process and respond to broadcast messages for a session. 

1478 

1479 This method listens for messages directed to the specified session and 

1480 generates appropriate responses. The listening mechanism depends on the backend: 

1481 

1482 - **memory**: Checks the temporary message storage 

1483 - **redis**: Subscribes to Redis pubsub channel 

1484 - **database**: Polls database for new messages 

1485 

1486 When a message is received and the transport exists locally, it processes 

1487 the message and sends the response through the transport. 

1488 

1489 Args: 

1490 server_id: Optional server identifier for scoped operations. 

1491 user: User information including authentication token. 

1492 session_id: Session identifier to respond for. 

1493 

1494 Raises: 

1495 asyncio.CancelledError: When the respond task is cancelled (e.g., on session removal). 

1496 

1497 Examples: 

1498 >>> import asyncio 

1499 >>> from mcpgateway.cache.session_registry import SessionRegistry 

1500 >>> 

1501 >>> # This method is typically called internally by the SSE handler 

1502 >>> reg = SessionRegistry() 

1503 >>> user = {'token': 'test-token'} 

1504 >>> # asyncio.run(reg.respond(None, user, 'session-id')) 

1505 """ 

1506 

1507 if self._backend == "none": 

1508 pass 

1509 

1510 elif self._backend == "memory": 

1511 transport = self.get_session_sync(session_id) 

1512 if transport and self._session_message: 

1513 message_json = self._session_message.get("message") 

1514 if message_json: 

1515 data = orjson.loads(message_json) 

1516 if isinstance(data, dict) and "message" in data: 

1517 message = data["message"] 

1518 else: 

1519 message = data 

1520 await self.generate_response(message=message, transport=transport, server_id=server_id, user=user) 

1521 else: 

1522 logger.warning(f"Session message stored but message content is None for session {session_id}") 

1523 

1524 elif self._backend == "redis": 

1525 if not self._redis: 

1526 logger.warning(f"Redis client not initialized, cannot respond to {session_id}") 

1527 return 

1528 pubsub = self._redis.pubsub() 

1529 await pubsub.subscribe(session_id) 

1530 

1531 # Use timeout-based polling instead of infinite listen() to allow exit checks 

1532 # This is critical for allowing cancellation to work (Finding 2) 

1533 poll_timeout = 1.0 # Check every second if session still exists 

1534 

1535 try: 

1536 while True: 

1537 # Check if session still exists or is closing - exit early 

1538 if session_id not in self._sessions or session_id in self._closing_sessions: 

1539 logger.info(f"Session {session_id} removed or closing, exiting Redis respond loop") 

1540 break 

1541 

1542 # Use get_message with timeout instead of blocking listen() 

1543 try: 

1544 msg = await asyncio.wait_for( 

1545 pubsub.get_message(ignore_subscribe_messages=True, timeout=poll_timeout), timeout=poll_timeout + 0.5 # Slightly longer to account for Redis timeout 

1546 ) 

1547 except asyncio.TimeoutError: 

1548 # No message, loop back to check session existence 

1549 continue 

1550 

1551 if msg is None: 

1552 # CRITICAL: Sleep to prevent tight loop when get_message returns immediately 

1553 # This can happen in certain Redis states after disconnects 

1554 await asyncio.sleep(0.1) 

1555 continue 

1556 if msg["type"] != "message": 

1557 # Sleep on non-message types to prevent spin in edge cases 

1558 await asyncio.sleep(0.1) 

1559 continue 

1560 

1561 data = orjson.loads(msg["data"]) 

1562 message = data.get("message", {}) 

1563 transport = self.get_session_sync(session_id) 

1564 if transport: 

1565 await self.generate_response(message=message, transport=transport, server_id=server_id, user=user) 

1566 except asyncio.CancelledError: 

1567 logger.info(f"PubSub listener for session {session_id} cancelled") 

1568 raise # Re-raise to properly complete cancellation 

1569 except Exception as e: 

1570 logger.error(f"PubSub listener error for session {session_id}: {e}") 

1571 finally: 

1572 # Pubsub cleanup first - use timeouts to prevent blocking 

1573 cleanup_timeout = settings.mcp_session_pool_cleanup_timeout 

1574 try: 

1575 await asyncio.wait_for(pubsub.unsubscribe(session_id), timeout=cleanup_timeout) 

1576 except asyncio.TimeoutError: 

1577 logger.debug(f"Pubsub unsubscribe timed out for session {session_id}") 

1578 except Exception as e: 

1579 logger.debug(f"Error unsubscribing pubsub for session {session_id}: {e}") 

1580 try: 

1581 try: 

1582 await asyncio.wait_for(pubsub.aclose(), timeout=cleanup_timeout) 

1583 except AttributeError: 

1584 await asyncio.wait_for(pubsub.close(), timeout=cleanup_timeout) 

1585 except asyncio.TimeoutError: 

1586 logger.debug(f"Pubsub close timed out for session {session_id}") 

1587 except Exception as e: 

1588 logger.debug(f"Error closing pubsub for session {session_id}: {e}") 

1589 logger.info(f"Cleaned up pubsub for session {session_id}") 

1590 # Clean up task reference LAST (idempotent - may already be removed by _cancel_respond_task) 

1591 self._respond_tasks.pop(session_id, None) 

1592 

1593 elif self._backend == "database": 

1594 

1595 def _db_read_session_and_message( 

1596 session_id: str, 

1597 ) -> tuple[SessionRecord | None, SessionMessageRecord | None]: 

1598 """ 

1599 Check whether a session exists and retrieve its next pending message 

1600 in a single database query. 

1601 

1602 This function performs a LEFT OUTER JOIN between SessionRecord and 

1603 SessionMessageRecord to determine: 

1604 

1605 - Whether the session still exists 

1606 - Whether there is a pending message for the session (FIFO order) 

1607 

1608 It is used by the database-backed message polling loop to reduce 

1609 database load by collapsing multiple reads into a single query. 

1610 

1611 Messages are returned in FIFO order based on the message primary key. 

1612 

1613 This function is designed to be run in a thread executor to avoid 

1614 blocking the async event loop during database access. 

1615 

1616 Args: 

1617 session_id: The session identifier to look up. 

1618 

1619 Returns: 

1620 Tuple[SessionRecord | None, SessionMessageRecord | None]: 

1621 

1622 - (None, None) 

1623 The session does not exist. 

1624 

1625 - (SessionRecord, None) 

1626 The session exists but has no pending messages. 

1627 

1628 - (SessionRecord, SessionMessageRecord) 

1629 The session exists and has a pending message. 

1630 

1631 Raises: 

1632 Exception: Any database error is re-raised after rollback. 

1633 

1634 Examples: 

1635 >>> # This function is called internally by message_check_loop() 

1636 >>> # Session exists and has a pending message 

1637 >>> # Returns (SessionRecord, SessionMessageRecord) 

1638 

1639 >>> # Session exists but has no pending messages 

1640 >>> # Returns (SessionRecord, None) 

1641 

1642 >>> # Session has been removed 

1643 >>> # Returns (None, None) 

1644 """ 

1645 db_session = next(get_db()) 

1646 try: 

1647 result = ( 

1648 db_session.query(SessionRecord, SessionMessageRecord) 

1649 .outerjoin( 

1650 SessionMessageRecord, 

1651 SessionMessageRecord.session_id == SessionRecord.session_id, 

1652 ) 

1653 .filter(SessionRecord.session_id == session_id) 

1654 .order_by(SessionMessageRecord.id.asc()) 

1655 .first() 

1656 ) 

1657 if not result: 

1658 return None, None 

1659 session, message = result 

1660 return session, message 

1661 except Exception as ex: 

1662 db_session.rollback() 

1663 raise ex 

1664 finally: 

1665 db_session.close() 

1666 

1667 def _db_remove(session_id: str, message: str) -> None: 

1668 """Remove processed message from the database. 

1669 

1670 Deletes a specific message record after it has been successfully 

1671 processed and sent to the transport. This prevents duplicate 

1672 message delivery. 

1673 

1674 This inner function is designed to be run in a thread executor 

1675 to avoid blocking the async event loop during database deletes. 

1676 

1677 Args: 

1678 session_id: The session identifier the message belongs to. 

1679 message: The exact message content to remove (must match exactly). 

1680 

1681 Raises: 

1682 Exception: Any database error is re-raised after rollback. 

1683 

1684 Examples: 

1685 >>> # This function is called internally after message processing 

1686 >>> # Deletes the specific SessionMessageRecord entry 

1687 >>> # Log: "Removed message from mcp_messages table" 

1688 """ 

1689 db_session = next(get_db()) 

1690 try: 

1691 db_session.query(SessionMessageRecord).filter(SessionMessageRecord.session_id == session_id).filter(SessionMessageRecord.message == message).delete() 

1692 db_session.commit() 

1693 logger.info("Removed message from mcp_messages table") 

1694 except Exception as ex: 

1695 db_session.rollback() 

1696 raise ex 

1697 finally: 

1698 db_session.close() 

1699 

        async def message_check_loop(session_id: str) -> None:
            """
            Background task that polls the database for messages belonging to a session
            using adaptive polling with exponential backoff.

            The loop continues until the session is removed from the database.

            Behavior:
                - Starts with a fast polling interval for low-latency message delivery.
                - When no message is found, the polling interval increases exponentially
                  (up to a configured maximum) to reduce database load.
                - When a message is received, the polling interval is immediately reset
                  to the fast interval.
                - The loop exits as soon as the session no longer exists.

            Polling rules:
                - Message found → process message, reset polling interval.
                - No message → increase polling interval (backoff).
                - Session gone → stop polling immediately.

            Args:
                session_id (str): Unique identifier of the session to monitor.

            Raises:
                asyncio.CancelledError: When the polling loop is cancelled.

            Examples
            --------
            Adaptive backoff when no messages are present:

            >>> poll_interval = 0.1
            >>> backoff_factor = 1.5
            >>> max_interval = 5.0
            >>> poll_interval = min(poll_interval * backoff_factor, max_interval)
            >>> poll_interval
            0.15000000000000002

            Backoff continues until the maximum interval is reached:

            >>> poll_interval = 4.0
            >>> poll_interval = min(poll_interval * 1.5, 5.0)
            >>> poll_interval
            5.0

            Polling interval resets immediately when a message arrives:

            >>> poll_interval = 2.0
            >>> poll_interval = 0.1
            >>> poll_interval
            0.1

            Session termination stops polling:

            >>> session_exists = False
            >>> if not session_exists:
            ...     "polling stopped"
            'polling stopped'
            """

            poll_interval = settings.poll_interval  # start fast
            max_interval = settings.max_interval  # cap at configured maximum
            backoff_factor = settings.backoff_factor
            try:
                while True:
                    # Check if session is closing before querying DB
                    # (avoids one extra DB round-trip during shutdown)
                    if session_id in self._closing_sessions:
                        logger.debug("Session %s closing, stopping poll loop early", session_id)
                        break

                    # Single blocking DB read (session row + oldest pending message),
                    # offloaded to a thread so the event loop stays responsive.
                    session, record = await asyncio.to_thread(_db_read_session_and_message, session_id)

                    # session gone → stop polling
                    if not session:
                        logger.debug("Session %s no longer exists, stopping poll loop", session_id)
                        break

                    if record:
                        poll_interval = settings.poll_interval  # reset on activity

                        # Stored payload may be wrapped as {"message": ...} or be the raw message.
                        data = orjson.loads(record.message)
                        if isinstance(data, dict) and "message" in data:
                            message = data["message"]
                        else:
                            message = data

                        # Only the worker that owns the local transport can respond.
                        transport = self.get_session_sync(session_id)
                        if transport:
                            logger.info("Ready to respond")
                            await self.generate_response(
                                message=message,
                                transport=transport,
                                server_id=server_id,
                                user=user,
                            )

                        # Delete the message AFTER processing so it is not lost on error.
                        await asyncio.to_thread(_db_remove, session_id, record.message)
                    else:
                        # no message → backoff
                        # update polling interval with backoff factor
                        poll_interval = min(poll_interval * backoff_factor, max_interval)

                    await asyncio.sleep(poll_interval)
            except asyncio.CancelledError:
                logger.info(f"Message check loop cancelled for session {session_id}")
                raise  # Re-raise to properly complete cancellation
            except Exception as e:
                logger.error(f"Message check loop error for session {session_id}: {e}")

1807 

1808 # CRITICAL: Await instead of fire-and-forget 

1809 # This ensures CancelledError propagates from outer respond() task to inner loop 

1810 # The outer task (registered from main.py) now runs until message_check_loop exits 

1811 try: 

1812 await message_check_loop(session_id) 

1813 except asyncio.CancelledError: 

1814 logger.info(f"Database respond cancelled for session {session_id}") 

1815 raise 

1816 finally: 

1817 # Clean up task reference on ANY exit (normal, cancelled, or error) 

1818 # Prevents stale done tasks from accumulating in _respond_tasks 

1819 self._respond_tasks.pop(session_id, None) 

1820 

1821 async def _refresh_redis_sessions(self) -> None: 

1822 """Refresh TTLs for Redis sessions and clean up disconnected sessions. 

1823 

1824 This internal method is used by the Redis backend to maintain session state. 

1825 It checks all local sessions, refreshes TTLs for connected sessions, and 

1826 removes disconnected ones. 

1827 """ 

1828 if not self._redis: 

1829 return 

1830 try: 

1831 # Check all local sessions 

1832 local_transports = {} 

1833 async with self._lock: 

1834 local_transports = self._sessions.copy() 

1835 

1836 for session_id, transport in local_transports.items(): 

1837 try: 

1838 if await transport.is_connected(): 

1839 # Refresh TTL in Redis 

1840 await self._redis.expire(f"mcp:session:{session_id}", self._session_ttl) 

1841 else: 

1842 # Remove disconnected session 

1843 await self.remove_session(session_id) 

1844 except Exception as e: 

1845 logger.error(f"Error refreshing session {session_id}: {e}") 

1846 

1847 except Exception as e: 

1848 logger.error(f"Error in Redis session refresh: {e}") 

1849 

    async def _db_cleanup_task(self) -> None:
        """Background task to clean up expired database sessions.

        Runs periodically (every 5 minutes) to remove expired sessions from the
        database and refresh timestamps for active sessions. This prevents the
        database from accumulating stale session records.

        The task also verifies that local sessions still exist in the database
        and removes them locally if they've been deleted elsewhere.

        Raises:
            asyncio.CancelledError: If the task is cancelled during shutdown.
        """
        logger.info("Starting database cleanup task")
        while True:
            try:
                # Clean up expired sessions every 5 minutes
                def _db_cleanup() -> int:
                    """Remove expired sessions from the database.

                    Deletes all SessionRecord entries that haven't been accessed
                    within the session TTL period. Uses database-specific date
                    arithmetic to calculate expiry time.

                    This inner function is designed to be run in a thread executor
                    to avoid blocking the async event loop during bulk deletes.

                    Returns:
                        int: Number of expired session records deleted.

                    Raises:
                        Exception: Any database error is re-raised after rollback.

                    Examples:
                        >>> # This function is called periodically by _db_cleanup_task()
                        >>> # Deletes sessions older than session_ttl seconds
                        >>> # Returns count of deleted records for logging
                        >>> # Log: "Cleaned up 5 expired database sessions"
                    """
                    db_session = next(get_db())
                    try:
                        # Delete sessions that haven't been accessed for TTL seconds
                        # Use Python datetime for database-agnostic expiry calculation
                        expiry_time = datetime.now(timezone.utc) - timedelta(seconds=self._session_ttl)
                        result = db_session.query(SessionRecord).filter(SessionRecord.last_accessed < expiry_time).delete()
                        db_session.commit()
                        return result
                    except Exception as ex:
                        db_session.rollback()
                        raise ex
                    finally:
                        db_session.close()

                # Bulk delete runs in a worker thread; only the count comes back.
                deleted = await asyncio.to_thread(_db_cleanup)
                if deleted > 0:
                    logger.info(f"Cleaned up {deleted} expired database sessions")

                # Check local sessions against database
                await self._cleanup_database_sessions()

                await asyncio.sleep(300)  # Run every 5 minutes

            except asyncio.CancelledError:
                logger.info("Database cleanup task cancelled")
                raise
            except Exception as e:
                # Unexpected errors: log and back off harder before retrying.
                logger.error(f"Error in database cleanup task: {e}")
                await asyncio.sleep(600)  # Sleep longer on error

1918 

1919 def _refresh_session_db(self, session_id: str) -> bool: 

1920 """Update session's last accessed timestamp in the database. 

1921 

1922 Refreshes the last_accessed field for an active session to 

1923 prevent it from being cleaned up as expired. This is called 

1924 periodically for all local sessions with active transports. 

1925 

1926 Args: 

1927 session_id: The session identifier to refresh. 

1928 

1929 Returns: 

1930 bool: True if the session was found and updated, False if not found. 

1931 

1932 Raises: 

1933 Exception: Any database error is re-raised after rollback. 

1934 """ 

1935 db_session = next(get_db()) 

1936 try: 

1937 session = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first() 

1938 if session: 

1939 session.last_accessed = func.now() # pylint: disable=not-callable 

1940 db_session.commit() 

1941 return True 

1942 return False 

1943 except Exception as ex: 

1944 db_session.rollback() 

1945 raise ex 

1946 finally: 

1947 db_session.close() 

1948 

    async def _cleanup_database_sessions(self, max_concurrent: int = 20) -> None:
        """Parallelize session cleanup with bounded concurrency.

        Checks connection status first (fast), then refreshes connected sessions
        in parallel using asyncio.gather() with a semaphore to limit concurrent
        DB operations and prevent resource exhaustion.

        Args:
            max_concurrent: Maximum number of concurrent DB refresh operations.
                Defaults to 20 to balance parallelism with resource usage.
        """
        # Snapshot under the lock; iterate outside it so slow checks don't block adds/removes.
        async with self._lock:
            local_transports = self._sessions.copy()

        # Check connections first (fast)
        connected: list[str] = []
        for session_id, transport in local_transports.items():
            try:
                if not await transport.is_connected():
                    await self.remove_session(session_id)
                else:
                    connected.append(session_id)
            except Exception as e:
                # Only log error, don't remove session on transient errors
                logger.error(f"Error checking connection for session {session_id}: {e}")

        # Parallel refresh of connected sessions with bounded concurrency
        if connected:
            semaphore = asyncio.Semaphore(max_concurrent)

            async def bounded_refresh(session_id: str) -> bool:
                """Refresh session with semaphore-bounded concurrency.

                Args:
                    session_id: The session ID to refresh.

                Returns:
                    True if refresh succeeded, False otherwise.
                """
                async with semaphore:
                    # Sync DB call runs in a worker thread; semaphore caps thread fan-out.
                    return await asyncio.to_thread(self._refresh_session_db, session_id)

            refresh_tasks = [bounded_refresh(session_id) for session_id in connected]
            # return_exceptions=True: one failed refresh must not abort the rest.
            results = await asyncio.gather(*refresh_tasks, return_exceptions=True)

            # gather() preserves input order, so zip pairs each result with its session.
            for session_id, result in zip(connected, results):
                try:
                    if isinstance(result, Exception):
                        # Only log error, don't remove session on transient DB errors
                        logger.error(f"Error refreshing session {session_id}: {result}")
                    elif not result:
                        # Session no longer in database, remove locally
                        await self.remove_session(session_id)
                except Exception as e:
                    logger.error(f"Error processing refresh result for session {session_id}: {e}")

2004 

    async def _memory_cleanup_task(self) -> None:
        """Background task to clean up disconnected sessions in memory backend.

        Runs periodically (every minute) to check all local sessions and remove
        those that are no longer connected. This prevents memory leaks from
        accumulating disconnected transport objects.

        Raises:
            asyncio.CancelledError: If the task is cancelled during shutdown.
        """
        logger.info("Starting memory cleanup task")
        while True:
            try:
                # Check all local sessions
                local_transports = {}
                async with self._lock:
                    local_transports = self._sessions.copy()

                for session_id, transport in local_transports.items():
                    try:
                        if not await transport.is_connected():
                            await self.remove_session(session_id)
                    except Exception as e:
                        # Unlike the DB backend, a failing check here removes the
                        # session defensively — there is no shared state to recover from.
                        logger.error(f"Error checking session {session_id}: {e}")
                        await self.remove_session(session_id)

                await asyncio.sleep(60)  # Run every minute

            except asyncio.CancelledError:
                logger.info("Memory cleanup task cancelled")
                raise
            except Exception as e:
                logger.error(f"Error in memory cleanup task: {e}")
                await asyncio.sleep(300)  # Sleep longer on error

2039 

2040 def _get_oauth_experimental_config(self, server_id: str) -> Optional[Dict[str, Dict[str, Any]]]: 

2041 """Query OAuth configuration for a server (synchronous, run in threadpool). 

2042 

2043 This method queries the database for OAuth configuration and returns 

2044 RFC 9728-safe fields for advertising in MCP capabilities. 

2045 

2046 Args: 

2047 server_id: The server ID to query OAuth configuration for. 

2048 

2049 Returns: 

2050 Dict with 'oauth' key containing safe OAuth config, or None if not configured. 

2051 """ 

2052 # First-Party 

2053 from mcpgateway.db import Server as DbServer # pylint: disable=import-outside-toplevel 

2054 from mcpgateway.db import SessionLocal # pylint: disable=import-outside-toplevel 

2055 

2056 db = SessionLocal() 

2057 try: 

2058 server = db.get(DbServer, server_id) 

2059 if server and getattr(server, "oauth_enabled", False) and getattr(server, "oauth_config", None): 

2060 # Filter oauth_config to RFC 9728-safe fields only (never expose secrets) 

2061 oauth_config = server.oauth_config 

2062 safe_oauth: Dict[str, Any] = {} 

2063 

2064 # Extract authorization servers 

2065 if oauth_config.get("authorization_servers"): 

2066 safe_oauth["authorization_servers"] = oauth_config["authorization_servers"] 

2067 elif oauth_config.get("authorization_server"): 

2068 safe_oauth["authorization_servers"] = [oauth_config["authorization_server"]] 

2069 

2070 # Extract scopes 

2071 scopes = oauth_config.get("scopes_supported") or oauth_config.get("scopes") 

2072 if scopes: 

2073 safe_oauth["scopes_supported"] = scopes 

2074 

2075 # Add bearer methods 

2076 safe_oauth["bearer_methods_supported"] = oauth_config.get("bearer_methods_supported", ["header"]) 

2077 

2078 if safe_oauth.get("authorization_servers"): 

2079 logger.debug(f"Advertising OAuth capability for server {server_id}") 

2080 return {"oauth": safe_oauth} 

2081 return None 

2082 finally: 

2083 db.close() 

2084 

2085 # Handle initialize logic 

2086 async def handle_initialize_logic(self, body: Dict[str, Any], session_id: Optional[str] = None, server_id: Optional[str] = None) -> InitializeResult: 

2087 """Process MCP protocol initialization request. 

2088 

2089 Validates the protocol version and returns server capabilities and information. 

2090 This method implements the MCP (Model Context Protocol) initialization handshake. 

2091 

2092 Args: 

2093 body: Request body containing protocol_version and optional client_info. 

2094 Expected keys: 'protocol_version' or 'protocolVersion', 'capabilities'. 

2095 session_id: Optional session ID to associate client capabilities with. 

2096 server_id: Optional server ID to query OAuth configuration for RFC 9728 support. 

2097 

2098 Returns: 

2099 InitializeResult containing protocol version, server capabilities, and server info. 

2100 

2101 Raises: 

2102 HTTPException: If protocol_version is missing (400 Bad Request with MCP error code -32002). 

2103 

2104 Examples: 

2105 >>> import asyncio 

2106 >>> from mcpgateway.cache.session_registry import SessionRegistry 

2107 >>> 

2108 >>> reg = SessionRegistry() 

2109 >>> body = {'protocol_version': '2025-06-18'} 

2110 >>> result = asyncio.run(reg.handle_initialize_logic(body)) 

2111 >>> result.protocol_version 

2112 '2025-06-18' 

2113 >>> result.server_info.name 

2114 'ContextForge' 

2115 >>> 

2116 >>> # Missing protocol version 

2117 >>> try: 

2118 ... asyncio.run(reg.handle_initialize_logic({})) 

2119 ... except HTTPException as e: 

2120 ... e.status_code 

2121 400 

2122 """ 

2123 protocol_version = body.get("protocol_version") or body.get("protocolVersion") 

2124 client_capabilities = body.get("capabilities", {}) 

2125 # body.get("client_info") or body.get("clientInfo", {}) 

2126 

2127 if not protocol_version: 

2128 raise HTTPException( 

2129 status_code=status.HTTP_400_BAD_REQUEST, 

2130 detail="Missing protocol version", 

2131 headers={"MCP-Error-Code": "-32002"}, 

2132 ) 

2133 

2134 if protocol_version != settings.protocol_version: 

2135 logger.warning(f"Using non default protocol version: {protocol_version}") 

2136 

2137 # Store client capabilities if session_id provided 

2138 if session_id and client_capabilities: 

2139 await self.store_client_capabilities(session_id, client_capabilities) 

2140 logger.debug(f"Stored capabilities for session {session_id}: {client_capabilities}") 

2141 

2142 # Build experimental capabilities (including OAuth if configured) 

2143 experimental: Optional[Dict[str, Dict[str, Any]]] = None 

2144 

2145 # Query OAuth configuration if server_id is provided 

2146 if server_id: 

2147 try: 

2148 # Run synchronous DB query in threadpool to avoid blocking the event loop 

2149 experimental = await asyncio.to_thread(self._get_oauth_experimental_config, server_id) 

2150 except Exception as e: 

2151 logger.warning(f"Failed to query OAuth config for server {server_id}: {e}") 

2152 

2153 return InitializeResult( 

2154 protocolVersion=protocol_version, 

2155 capabilities=ServerCapabilities( 

2156 prompts={"listChanged": True}, 

2157 resources={"subscribe": True, "listChanged": True}, 

2158 tools={"listChanged": True}, 

2159 logging={}, 

2160 completions={}, # Advertise completions capability per MCP spec 

2161 experimental=experimental, # OAuth capability when configured 

2162 ), 

2163 serverInfo=Implementation(name=settings.app_name, version=__version__), 

2164 instructions=("ContextForge providing federated tools, resources and prompts. Use /admin interface for configuration."), 

2165 ) 

2166 

2167 async def store_client_capabilities(self, session_id: str, capabilities: Dict[str, Any]) -> None: 

2168 """Store client capabilities for a session. 

2169 

2170 Args: 

2171 session_id: The session ID 

2172 capabilities: Client capabilities dictionary from initialize request 

2173 """ 

2174 async with self._lock: 

2175 self._client_capabilities[session_id] = capabilities 

2176 logger.debug(f"Stored capabilities for session {session_id}") 

2177 

2178 async def get_client_capabilities(self, session_id: str) -> Optional[Dict[str, Any]]: 

2179 """Get client capabilities for a session. 

2180 

2181 Args: 

2182 session_id: The session ID 

2183 

2184 Returns: 

2185 Client capabilities dictionary, or None if not found 

2186 """ 

2187 async with self._lock: 

2188 return self._client_capabilities.get(session_id) 

2189 

2190 async def has_elicitation_capability(self, session_id: str) -> bool: 

2191 """Check if a session has elicitation capability. 

2192 

2193 Args: 

2194 session_id: The session ID 

2195 

2196 Returns: 

2197 True if session supports elicitation, False otherwise 

2198 """ 

2199 capabilities = await self.get_client_capabilities(session_id) 

2200 if not capabilities: 

2201 return False 

2202 # Check if elicitation capability exists in client capabilities 

2203 return bool(capabilities.get("elicitation")) 

2204 

2205 async def get_elicitation_capable_sessions(self) -> list[str]: 

2206 """Get list of session IDs that support elicitation. 

2207 

2208 Returns: 

2209 List of session IDs with elicitation capability 

2210 """ 

2211 async with self._lock: 

2212 capable_sessions = [] 

2213 for session_id, capabilities in self._client_capabilities.items(): 

2214 if capabilities.get("elicitation"): 

2215 # Verify session still exists 

2216 if session_id in self._sessions: 

2217 capable_sessions.append(session_id) 

2218 return capable_sessions 

2219 

2220 async def generate_response(self, message: Dict[str, Any], transport: SSETransport, server_id: Optional[str], user: Dict[str, Any]) -> None: 

2221 """Generate and send response for incoming MCP protocol message. 

2222 

2223 Processes MCP protocol messages and generates appropriate responses based on 

2224 the method. Supports various MCP methods including initialization, tool/resource/prompt 

2225 listing, tool invocation, and ping. 

2226 

2227 Uses loopback (127.0.0.1) for the internal RPC call to avoid issues when 

2228 the gateway is behind a reverse proxy or service mesh where the client-facing 

2229 URL is not reachable from the server itself. 

2230 

2231 Args: 

2232 message: Incoming MCP message as JSON. Must contain 'method' and 'id' fields. 

2233 transport: SSE transport to send responses through. 

2234 server_id: Optional server ID for scoped operations. 

2235 user: User information containing authentication token. 

2236 

2237 Examples: 

2238 >>> import asyncio 

2239 >>> from mcpgateway.cache.session_registry import SessionRegistry 

2240 >>> 

2241 >>> class MockTransport: 

2242 ... async def send_message(self, msg): 

2243 ... print(f"Response: {msg['method'] if 'method' in msg else msg.get('result', {})}") 

2244 >>> 

2245 >>> reg = SessionRegistry() 

2246 >>> transport = MockTransport() 

2247 >>> message = {"method": "ping", "id": 1} 

2248 >>> user = {"token": "test-token"} 

2249 >>> # asyncio.run(reg.generate_response(message, transport, None, user)) 

2250 >>> # Response: {} 

2251 """ 

2252 result = {} 

2253 

2254 if "method" in message and "id" in message: 

2255 method = message["method"] 

2256 params = message.get("params", {}) 

2257 params["server_id"] = server_id 

2258 req_id = message["id"] 

2259 

2260 rpc_input = { 

2261 "jsonrpc": "2.0", 

2262 "method": method, 

2263 "params": params, 

2264 "id": req_id, 

2265 } 

2266 # Get the token from the current authentication context 

2267 # The user object should contain auth_token, token_teams, and is_admin from the SSE endpoint 

2268 token = None 

2269 is_admin = user.get("is_admin", False) # Preserve admin status from SSE endpoint 

2270 

2271 try: 

2272 if hasattr(user, "get") and user.get("auth_token"): 

2273 token = user["auth_token"] 

2274 else: 

2275 # Fallback: create lightweight session token (teams resolved server-side by downstream /rpc) 

2276 logger.warning("No auth token available for SSE RPC call - creating fallback session token") 

2277 now = datetime.now(timezone.utc) 

2278 payload = { 

2279 "sub": user.get("email", "system"), 

2280 "iss": settings.jwt_issuer, 

2281 "aud": settings.jwt_audience, 

2282 "iat": int(now.timestamp()), 

2283 "jti": str(uuid.uuid4()), 

2284 "token_use": "session", # nosec B105 - token type marker, not a password 

2285 "user": { 

2286 "email": user.get("email", "system"), 

2287 "full_name": user.get("full_name", "System"), 

2288 "is_admin": is_admin, # Preserve admin status for cookie-authenticated admins 

2289 "auth_provider": "internal", 

2290 }, 

2291 } 

2292 # Generate token using centralized token creation 

2293 token = await create_jwt_token(payload) 

2294 

2295 # Pass downstream session id to /rpc for session affinity. 

2296 # This is gateway-internal only; the pool strips it before contacting upstream MCP servers. 

2297 if settings.mcpgateway_session_affinity_enabled: 

2298 await self._register_session_mapping(transport.session_id, message, user.get("email") if hasattr(user, "get") else None) 

2299 

2300 headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} 

2301 if settings.mcpgateway_session_affinity_enabled: 

2302 headers["x-mcp-session-id"] = transport.session_id 

2303 # Use loopback for internal RPC call (consistent with other self-call sites 

2304 # in mcp_session_pool.py and streamablehttp_transport.py). This avoids 

2305 # failures when the client-facing URL is not reachable from the server 

2306 # (e.g., behind a reverse proxy or service mesh). See #3049. 

2307 rpc_url = f"http://127.0.0.1:{settings.port}/rpc" 

2308 

2309 logger.info(f"SSE RPC: Making call to {rpc_url} with method={method}, params={params}") 

2310 

2311 async with ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify}) as client: 

2312 logger.info(f"SSE RPC: Sending request to {rpc_url}") 

2313 rpc_response = await client.post( 

2314 url=rpc_url, 

2315 json=rpc_input, 

2316 headers=headers, 

2317 ) 

2318 logger.info(f"SSE RPC: Got response status {rpc_response.status_code}") 

2319 result = rpc_response.json() 

2320 logger.info(f"SSE RPC: Response content: {result}") 

2321 result = result.get("result", {}) 

2322 

2323 response = {"jsonrpc": "2.0", "result": result, "id": req_id} 

2324 except JSONRPCError as e: 

2325 logger.error(f"SSE RPC: JSON-RPC error: {e}") 

2326 result = e.to_dict() 

2327 response = {"jsonrpc": "2.0", "error": result["error"], "id": req_id} 

2328 except Exception as e: 

2329 logger.error(f"SSE RPC: Exception during RPC call: {type(e).__name__}: {e}") 

2330 logger.error(f"SSE RPC: Traceback: {traceback.format_exc()}") 

2331 result = {"code": -32000, "message": "Internal error", "data": str(e)} 

2332 response = {"jsonrpc": "2.0", "error": result, "id": req_id} 

2333 

2334 logging.debug(f"Sending sse message:{response}") 

2335 await transport.send_message(response) 

2336 

2337 if message["method"] == "initialize": 

2338 await transport.send_message( 

2339 { 

2340 "jsonrpc": "2.0", 

2341 "method": "notifications/initialized", 

2342 "params": {}, 

2343 } 

2344 ) 

2345 notifications = [ 

2346 "tools/list_changed", 

2347 "resources/list_changed", 

2348 "prompts/list_changed", 

2349 ] 

2350 for notification in notifications: 

2351 await transport.send_message( 

2352 { 

2353 "jsonrpc": "2.0", 

2354 "method": f"notifications/{notification}", 

2355 "params": {}, 

2356 } 

2357 )