Coverage for mcpgateway / cache / session_registry.py: 100%
758 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/cache/session_registry.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Session Registry with optional distributed state.
8This module provides a registry for SSE sessions with support for distributed deployment
9using Redis or SQLAlchemy as optional backends for shared state between workers.
11The SessionRegistry class manages server-sent event (SSE) sessions across multiple
12worker processes, enabling horizontal scaling of MCP gateway deployments. It supports
13three backend modes:
15- **memory**: In-memory storage for single-process deployments (default)
16- **redis**: Redis-backed shared storage for multi-worker deployments
17- **database**: SQLAlchemy-backed shared storage using any supported database
19In distributed mode (redis/database), session existence is tracked in the shared
20backend while transport objects remain local to each worker process. This allows
21workers to know about sessions on other workers and route messages appropriately.
23Examples:
24 Basic usage with memory backend:
26 >>> from mcpgateway.cache.session_registry import SessionRegistry
27 >>> class DummyTransport:
28 ... async def disconnect(self):
29 ... pass
30 ... async def is_connected(self):
31 ... return True
32 >>> import asyncio
33 >>> reg = SessionRegistry(backend='memory')
34 >>> transport = DummyTransport()
35 >>> asyncio.run(reg.add_session('sid123', transport))
36 >>> found = asyncio.run(reg.get_session('sid123'))
37 >>> isinstance(found, DummyTransport)
38 True
39 >>> asyncio.run(reg.remove_session('sid123'))
40 >>> asyncio.run(reg.get_session('sid123')) is None
41 True
43 Broadcasting messages:
45 >>> reg = SessionRegistry(backend='memory')
46 >>> asyncio.run(reg.broadcast('sid123', {'method': 'ping', 'id': 1}))
47 >>> reg._session_message is not None
48 True
49"""
51# Standard
52import asyncio
53from asyncio import Task
54from datetime import datetime, timedelta, timezone
55import logging
56import time
57import traceback
58from typing import Any, Dict, Optional
59from urllib.parse import urlparse
60import uuid
62# Third-Party
63from fastapi import HTTPException, status
64import orjson
66# First-Party
67from mcpgateway import __version__
68from mcpgateway.common.models import Implementation, InitializeResult, ServerCapabilities
69from mcpgateway.config import settings
70from mcpgateway.db import get_db, SessionMessageRecord, SessionRecord
71from mcpgateway.services import PromptService, ResourceService, ToolService
72from mcpgateway.services.logging_service import LoggingService
73from mcpgateway.transports import SSETransport
74from mcpgateway.utils.create_jwt_token import create_jwt_token
75from mcpgateway.utils.redis_client import get_redis_client
76from mcpgateway.utils.retry_manager import ResilientHttpClient
77from mcpgateway.validation.jsonrpc import JSONRPCError
79# Initialize logging service first
80logging_service: LoggingService = LoggingService()
81logger = logging_service.get_logger(__name__)
83tool_service: ToolService = ToolService()
84resource_service: ResourceService = ResourceService()
85prompt_service: PromptService = PromptService()
87try:
88 # Third-Party
89 from redis.asyncio import Redis
91 REDIS_AVAILABLE = True
92except ImportError:
93 REDIS_AVAILABLE = False
95try:
96 # Third-Party
97 from sqlalchemy import func
99 SQLALCHEMY_AVAILABLE = True
100except ImportError:
101 SQLALCHEMY_AVAILABLE = False
104class SessionBackend:
105 """Base class for session registry backend configuration.
107 This class handles the initialization and configuration of different backend
108 types for session storage. It validates backend requirements and sets up
109 necessary connections for Redis or database backends.
111 Attributes:
112 _backend: The backend type ('memory', 'redis', 'database', or 'none')
113 _session_ttl: Time-to-live for sessions in seconds
114 _message_ttl: Time-to-live for messages in seconds
115 _redis: Redis connection instance (redis backend only)
116 _pubsub: Redis pubsub instance (redis backend only)
117 _session_message: Temporary message storage (memory backend only)
119 Examples:
120 >>> backend = SessionBackend(backend='memory')
121 >>> backend._backend
122 'memory'
123 >>> backend._session_ttl
124 3600
126 >>> try:
127 ... backend = SessionBackend(backend='redis')
128 ... except ValueError as e:
129 ... str(e)
130 'Redis backend requires redis_url'
131 """
133 def __init__(
134 self,
135 backend: str = "memory",
136 redis_url: Optional[str] = None,
137 database_url: Optional[str] = None,
138 session_ttl: int = 3600, # 1 hour
139 message_ttl: int = 600, # 10 min
140 ):
141 """Initialize session backend configuration.
143 Args:
144 backend: Backend type. Must be one of 'memory', 'redis', 'database', or 'none'.
145 - 'memory': In-memory storage, suitable for single-process deployments
146 - 'redis': Redis-backed storage for multi-worker deployments
147 - 'database': SQLAlchemy-backed storage for multi-worker deployments
148 - 'none': No session tracking (dummy registry)
149 redis_url: Redis connection URL. Required when backend='redis'.
150 Format: 'redis://[:password]@host:port/db'
151 database_url: Database connection URL. Required when backend='database'.
152 Format depends on database type (e.g., 'postgresql://user:pass@host/db')
153 session_ttl: Session time-to-live in seconds. Sessions are automatically
154 cleaned up after this duration of inactivity. Default: 3600 (1 hour).
155 message_ttl: Message time-to-live in seconds. Undelivered messages are
156 removed after this duration. Default: 600 (10 minutes).
158 Raises:
159 ValueError: If backend is invalid, required URL is missing, or required packages are not installed.
161 Examples:
162 >>> # Memory backend (default)
163 >>> backend = SessionBackend()
164 >>> backend._backend
165 'memory'
167 >>> # Redis backend requires URL
168 >>> try:
169 ... backend = SessionBackend(backend='redis')
170 ... except ValueError as e:
171 ... 'redis_url' in str(e)
172 True
174 >>> # Invalid backend
175 >>> try:
176 ... backend = SessionBackend(backend='invalid')
177 ... except ValueError as e:
178 ... 'Invalid backend' in str(e)
179 True
180 """
182 self._backend = backend.lower()
183 self._session_ttl = session_ttl
184 self._message_ttl = message_ttl
186 # Set up backend-specific components
187 if self._backend == "memory":
188 # Nothing special needed for memory backend
189 self._session_message: dict[str, Any] | None = None
191 elif self._backend == "none":
192 # No session tracking - this is just a dummy registry
193 logger.info("Session registry initialized with 'none' backend - session tracking disabled")
195 elif self._backend == "redis":
196 if not REDIS_AVAILABLE:
197 raise ValueError("Redis backend requested but redis package not installed")
198 if not redis_url:
199 raise ValueError("Redis backend requires redis_url")
201 # Redis client is set in initialize() via the shared factory
202 self._redis: Optional[Redis] = None
203 self._pubsub = None
205 elif self._backend == "database":
206 if not SQLALCHEMY_AVAILABLE:
207 raise ValueError("Database backend requested but SQLAlchemy not installed")
208 if not database_url:
209 raise ValueError("Database backend requires database_url")
210 else:
211 raise ValueError(f"Invalid backend: {backend}")
214class SessionRegistry(SessionBackend):
215 """Registry for SSE sessions with optional distributed state.
217 This class manages server-sent event (SSE) sessions, providing methods to add,
218 remove, and query sessions. It supports multiple backend types for different
219 deployment scenarios:
221 - **Single-process deployments**: Use 'memory' backend (default)
222 - **Multi-worker deployments**: Use 'redis' or 'database' backend
223 - **Testing/development**: Use 'none' backend to disable session tracking
225 The registry maintains a local cache of transport objects while using the
226 shared backend to track session existence across workers. This enables
227 horizontal scaling while keeping transport objects process-local.
229 Attributes:
230 _sessions: Local dictionary mapping session IDs to transport objects
231 _lock: Asyncio lock for thread-safe access to _sessions
232 _cleanup_task: Background task for cleaning up expired sessions
234 Examples:
235 >>> import asyncio
236 >>> from mcpgateway.cache.session_registry import SessionRegistry
237 >>>
238 >>> class MockTransport:
239 ... async def disconnect(self):
240 ... print("Disconnected")
241 ... async def is_connected(self):
242 ... return True
243 ... async def send_message(self, msg):
244 ... print(f"Sent: {msg}")
245 >>>
246 >>> # Create registry and add session
247 >>> reg = SessionRegistry(backend='memory')
248 >>> transport = MockTransport()
249 >>> asyncio.run(reg.add_session('test123', transport))
250 >>>
251 >>> # Retrieve session
252 >>> found = asyncio.run(reg.get_session('test123'))
253 >>> found is transport
254 True
255 >>>
256 >>> # Remove session
257 >>> asyncio.run(reg.remove_session('test123'))
258 Disconnected
259 >>> asyncio.run(reg.get_session('test123')) is None
260 True
261 """
263 def __init__(
264 self,
265 backend: str = "memory",
266 redis_url: Optional[str] = None,
267 database_url: Optional[str] = None,
268 session_ttl: int = 3600, # 1 hour
269 message_ttl: int = 600, # 10 min
270 ):
271 """Initialize session registry with specified backend.
273 Args:
274 backend: Backend type. Must be one of 'memory', 'redis', 'database', or 'none'.
275 redis_url: Redis connection URL. Required when backend='redis'.
276 database_url: Database connection URL. Required when backend='database'.
277 session_ttl: Session time-to-live in seconds. Default: 3600.
278 message_ttl: Message time-to-live in seconds. Default: 600.
280 Examples:
281 >>> # Default memory backend
282 >>> reg = SessionRegistry()
283 >>> reg._backend
284 'memory'
285 >>> isinstance(reg._sessions, dict)
286 True
288 >>> # Redis backend with custom TTL
289 >>> try:
290 ... reg = SessionRegistry(
291 ... backend='redis',
292 ... redis_url='redis://localhost:6379',
293 ... session_ttl=7200
294 ... )
295 ... except ValueError:
296 ... pass # Redis may not be available
297 """
298 super().__init__(backend=backend, redis_url=redis_url, database_url=database_url, session_ttl=session_ttl, message_ttl=message_ttl)
299 self._sessions: Dict[str, Any] = {} # Local transport cache
300 self._client_capabilities: Dict[str, Dict[str, Any]] = {} # Client capabilities by session_id
301 self._respond_tasks: Dict[str, asyncio.Task] = {} # Track respond tasks for cancellation
302 self._stuck_tasks: Dict[str, asyncio.Task] = {} # Tasks that couldn't be cancelled (for monitoring)
303 self._closing_sessions: set[str] = set() # Sessions being closed - respond loop should exit
304 self._lock = asyncio.Lock()
305 self._cleanup_task: Task | None = None
306 self._stuck_task_reaper: Task | None = None # Reaper for stuck tasks
308 def register_respond_task(self, session_id: str, task: asyncio.Task) -> None:
309 """Register a respond task for later cancellation.
311 Associates an asyncio Task with a session_id so it can be cancelled
312 when the session is removed. This prevents orphaned tasks that cause
313 CPU spin loops.
315 Args:
316 session_id: Session identifier the task belongs to.
317 task: The asyncio Task to track.
318 """
319 self._respond_tasks[session_id] = task
320 logger.debug(f"Registered respond task for session {session_id}")
322 async def _cancel_respond_task(self, session_id: str, timeout: float = 5.0) -> None:
323 """Cancel and await a respond task with timeout.
325 Safely cancels the respond task associated with a session. Uses a timeout
326 to prevent hanging if the task doesn't respond to cancellation.
328 If initial cancellation times out, escalates by force-disconnecting the
329 transport to unblock the task, then retries cancellation (Finding 1 fix).
331 Args:
332 session_id: Session identifier whose task should be cancelled.
333 timeout: Maximum seconds to wait for task cancellation. Default 5.0.
334 """
335 task = self._respond_tasks.get(session_id)
336 if task is None:
337 return
339 if task.done():
340 # Task already finished - safe to remove from tracking
341 self._respond_tasks.pop(session_id, None)
342 try:
343 task.result()
344 except asyncio.CancelledError:
345 pass
346 except Exception as e:
347 logger.warning(f"Respond task for {session_id} failed with: {e}")
348 return
350 task.cancel()
351 logger.debug(f"Cancelling respond task for session {session_id}")
353 try:
354 await asyncio.wait_for(task, timeout=timeout)
355 # Cancellation successful - remove from tracking
356 self._respond_tasks.pop(session_id, None)
357 logger.debug(f"Respond task cancelled for session {session_id}")
358 except asyncio.TimeoutError:
359 # ESCALATION (Finding 1): Force-disconnect transport to unblock the task
360 logger.warning(f"Respond task cancellation timed out for {session_id}, " f"escalating with transport disconnect")
362 # Force-disconnect the transport to unblock any pending I/O
363 transport = self._sessions.get(session_id)
364 if transport and hasattr(transport, "disconnect"):
365 try:
366 await transport.disconnect()
367 logger.debug(f"Force-disconnected transport for {session_id}")
368 except Exception as e:
369 logger.warning(f"Failed to force-disconnect transport for {session_id}: {e}")
371 # Retry cancellation with shorter timeout
372 if not task.done():
373 try:
374 await asyncio.wait_for(task, timeout=2.0)
375 self._respond_tasks.pop(session_id, None)
376 logger.info(f"Respond task cancelled after escalation for {session_id}")
377 except asyncio.TimeoutError:
378 # Still stuck - move to stuck_tasks for monitoring (Finding 2 fix)
379 self._respond_tasks.pop(session_id, None)
380 self._stuck_tasks[session_id] = task
381 logger.error(f"Respond task for {session_id} still stuck after escalation, " f"moved to stuck_tasks for monitoring (total stuck: {len(self._stuck_tasks)})")
382 except asyncio.CancelledError:
383 self._respond_tasks.pop(session_id, None)
384 logger.info(f"Respond task cancelled after escalation for {session_id}")
385 except Exception as e:
386 self._respond_tasks.pop(session_id, None)
387 logger.warning(f"Error during retry cancellation for {session_id}: {e}")
388 else:
389 self._respond_tasks.pop(session_id, None)
390 logger.debug(f"Respond task completed during escalation for {session_id}")
392 except asyncio.CancelledError:
393 # Cancellation successful - remove from tracking
394 self._respond_tasks.pop(session_id, None)
395 logger.debug(f"Respond task cancelled for session {session_id}")
396 except Exception as e:
397 # Remove from tracking on unexpected error - task state unknown
398 self._respond_tasks.pop(session_id, None)
399 logger.warning(f"Error during respond task cancellation for {session_id}: {e}")
401 async def _reap_stuck_tasks(self) -> None:
402 """Periodically clean up stuck tasks that have completed.
404 This reaper runs every 30 seconds and:
405 1. Removes completed tasks from _stuck_tasks
406 2. Retries cancellation for tasks that are still running
407 3. Logs warnings for tasks that remain stuck
409 This prevents memory leaks from tasks that eventually complete after
410 being moved to _stuck_tasks during escalation.
411 """
412 reap_interval = 30.0 # seconds
413 retry_timeout = 2.0 # seconds for retry cancellation
415 while True:
416 try:
417 await asyncio.sleep(reap_interval)
419 if not self._stuck_tasks:
420 continue
422 # Collect completed and still-stuck tasks
423 completed = []
424 still_stuck = []
426 for session_id, task in list(self._stuck_tasks.items()):
427 if task.done():
428 completed.append(session_id)
429 try:
430 task.result() # Consume result to avoid warnings
431 except (asyncio.CancelledError, Exception):
432 pass
433 else:
434 still_stuck.append((session_id, task))
436 # Remove completed tasks
437 for session_id in completed:
438 self._stuck_tasks.pop(session_id, None)
440 if completed:
441 logger.info(f"Reaped {len(completed)} completed stuck tasks")
443 # Retry cancellation for still-stuck tasks
444 for session_id, task in still_stuck:
445 task.cancel()
446 try:
447 await asyncio.wait_for(task, timeout=retry_timeout)
448 self._stuck_tasks.pop(session_id, None)
449 logger.info(f"Stuck task {session_id} finally cancelled during reap")
450 except asyncio.TimeoutError:
451 logger.warning(f"Task {session_id} still stuck after reap retry")
452 except asyncio.CancelledError:
453 self._stuck_tasks.pop(session_id, None)
454 logger.info(f"Stuck task {session_id} cancelled during reap")
455 except Exception as e:
456 logger.warning(f"Error during stuck task reap for {session_id}: {e}")
458 if self._stuck_tasks:
459 logger.warning(f"Stuck tasks remaining: {len(self._stuck_tasks)}")
461 except asyncio.CancelledError:
462 logger.debug("Stuck task reaper cancelled")
463 break
464 except Exception as e:
465 logger.error(f"Error in stuck task reaper: {e}")
467 async def initialize(self) -> None:
468 """Initialize the registry with async setup.
470 This method performs asynchronous initialization tasks that cannot be done
471 in __init__. It starts background cleanup tasks and sets up pubsub
472 subscriptions for distributed backends.
474 Call this during application startup after creating the registry instance.
476 Examples:
477 >>> import asyncio
478 >>> reg = SessionRegistry(backend='memory')
479 >>> asyncio.run(reg.initialize())
480 >>> reg._cleanup_task is not None
481 True
482 >>>
483 >>> # Cleanup
484 >>> asyncio.run(reg.shutdown())
485 """
486 logger.info(f"Initializing session registry with backend: {self._backend}")
488 if self._backend == "database":
489 # Start database cleanup task
490 self._cleanup_task = asyncio.create_task(self._db_cleanup_task())
491 logger.info("Database cleanup task started")
493 elif self._backend == "redis":
494 # Get shared Redis client from factory
495 self._redis = await get_redis_client()
496 if self._redis:
497 self._pubsub = self._redis.pubsub()
498 await self._pubsub.subscribe("mcp_session_events")
499 logger.info("Session registry connected to shared Redis client")
501 elif self._backend == "none":
502 # Nothing to initialize for none backend
503 pass
505 # Memory backend needs session cleanup
506 elif self._backend == "memory":
507 self._cleanup_task = asyncio.create_task(self._memory_cleanup_task())
508 logger.info("Memory cleanup task started")
510 # Start stuck task reaper for all backends
511 self._stuck_task_reaper = asyncio.create_task(self._reap_stuck_tasks())
512 logger.info("Stuck task reaper started")
514 async def shutdown(self) -> None:
515 """Shutdown the registry and clean up resources.
517 This method cancels background tasks and closes connections to external
518 services. Call this during application shutdown to ensure clean termination.
520 Examples:
521 >>> import asyncio
522 >>> reg = SessionRegistry()
523 >>> asyncio.run(reg.initialize())
524 >>> task_was_created = reg._cleanup_task is not None
525 >>> asyncio.run(reg.shutdown())
526 >>> # After shutdown, cleanup task should be handled (cancelled or done)
527 >>> task_was_created and (reg._cleanup_task.cancelled() or reg._cleanup_task.done())
528 True
529 """
530 logger.info("Shutting down session registry")
532 # Cancel cleanup task
533 if self._cleanup_task:
534 self._cleanup_task.cancel()
535 try:
536 await self._cleanup_task
537 except asyncio.CancelledError:
538 pass
540 # Cancel stuck task reaper
541 if self._stuck_task_reaper:
542 self._stuck_task_reaper.cancel()
543 try:
544 await self._stuck_task_reaper
545 except asyncio.CancelledError:
546 pass
548 # CRITICAL: Cancel ALL respond tasks to prevent CPU spin loops
549 if self._respond_tasks:
550 logger.info(f"Cancelling {len(self._respond_tasks)} respond tasks")
551 tasks_to_cancel = list(self._respond_tasks.values())
552 self._respond_tasks.clear()
554 for task in tasks_to_cancel:
555 if not task.done():
556 task.cancel()
558 if tasks_to_cancel:
559 try:
560 await asyncio.wait_for(asyncio.gather(*tasks_to_cancel, return_exceptions=True), timeout=10.0)
561 logger.info("All respond tasks cancelled successfully")
562 except asyncio.TimeoutError:
563 logger.warning("Timeout waiting for respond tasks to cancel")
565 # Also cancel any stuck tasks (tasks that previously couldn't be cancelled)
566 if self._stuck_tasks:
567 logger.warning(f"Attempting final cancellation of {len(self._stuck_tasks)} stuck tasks")
568 stuck_to_cancel = list(self._stuck_tasks.values())
569 self._stuck_tasks.clear()
571 for task in stuck_to_cancel:
572 if not task.done():
573 task.cancel()
575 if stuck_to_cancel:
576 try:
577 await asyncio.wait_for(asyncio.gather(*stuck_to_cancel, return_exceptions=True), timeout=5.0)
578 logger.info("Stuck tasks cancelled during shutdown")
579 except asyncio.TimeoutError:
580 logger.error("Some stuck tasks could not be cancelled during shutdown")
582 # Close Redis pubsub (but not the shared client)
583 # Use timeout to prevent blocking if pubsub doesn't close cleanly
584 cleanup_timeout = settings.mcp_session_pool_cleanup_timeout
585 if self._backend == "redis" and getattr(self, "_pubsub", None):
586 try:
587 await asyncio.wait_for(self._pubsub.aclose(), timeout=cleanup_timeout)
588 except asyncio.TimeoutError:
589 logger.warning("Redis pubsub close timed out - proceeding anyway")
590 except Exception as e:
591 logger.error(f"Error closing Redis pubsub: {e}")
592 # Don't close self._redis - it's the shared client managed by redis_client.py
593 self._redis = None
594 self._pubsub = None
596 async def add_session(self, session_id: str, transport: SSETransport) -> None:
597 """Add a session to the registry.
599 Stores the session in both the local cache and the distributed backend
600 (if configured). For distributed backends, this notifies other workers
601 about the new session.
603 Args:
604 session_id: Unique session identifier. Should be a UUID or similar
605 unique string to avoid collisions.
606 transport: SSE transport object for this session. Must implement
607 the SSETransport interface.
609 Examples:
610 >>> import asyncio
611 >>> from mcpgateway.cache.session_registry import SessionRegistry
612 >>>
613 >>> class MockTransport:
614 ... async def disconnect(self):
615 ... print(f"Transport disconnected")
616 ... async def is_connected(self):
617 ... return True
618 >>>
619 >>> reg = SessionRegistry()
620 >>> transport = MockTransport()
621 >>> asyncio.run(reg.add_session('test-456', transport))
622 >>>
623 >>> # Found in local cache
624 >>> found = asyncio.run(reg.get_session('test-456'))
625 >>> found is transport
626 True
627 >>>
628 >>> # Remove session
629 >>> asyncio.run(reg.remove_session('test-456'))
630 Transport disconnected
631 """
632 # Skip for none backend
633 if self._backend == "none":
634 return
636 async with self._lock:
637 self._sessions[session_id] = transport
639 if self._backend == "redis":
640 # Store session marker in Redis
641 if not self._redis:
642 logger.warning(f"Redis client not initialized, skipping distributed session tracking for {session_id}")
643 return
644 try:
645 await self._redis.setex(f"mcp:session:{session_id}", self._session_ttl, "1")
646 # Publish event to notify other workers
647 await self._redis.publish("mcp_session_events", orjson.dumps({"type": "add", "session_id": session_id, "timestamp": time.time()}))
648 except Exception as e:
649 logger.error(f"Redis error adding session {session_id}: {e}")
651 elif self._backend == "database":
652 # Store session in database
653 try:
655 def _db_add() -> None:
656 """Store session record in the database.
658 Creates a new SessionRecord entry in the database for tracking
659 distributed session state. Uses a fresh database connection from
660 the connection pool.
662 This inner function is designed to be run in a thread executor
663 to avoid blocking the async event loop during database I/O.
665 Raises:
666 Exception: Any database error is re-raised after rollback.
667 Common errors include duplicate session_id (unique constraint)
668 or database connection issues.
670 Examples:
671 >>> # This function is called internally by add_session()
672 >>> # When executed, it creates a database record:
673 >>> # SessionRecord(session_id='abc123', created_at=now())
674 """
675 db_session = next(get_db())
676 try:
677 session_record = SessionRecord(session_id=session_id)
678 db_session.add(session_record)
679 db_session.commit()
680 except Exception as ex:
681 db_session.rollback()
682 raise ex
683 finally:
684 db_session.close()
686 await asyncio.to_thread(_db_add)
687 except Exception as e:
688 logger.error(f"Database error adding session {session_id}: {e}")
690 logger.info(f"Added session: {session_id}")
692 async def get_session(self, session_id: str) -> Any:
693 """Get session transport by ID.
695 First checks the local cache for the transport object. If not found locally
696 but using a distributed backend, checks if the session exists on another
697 worker.
699 Args:
700 session_id: Session identifier to look up.
702 Returns:
703 SSETransport object if found locally, None if not found or exists
704 on another worker.
706 Examples:
707 >>> import asyncio
708 >>> from mcpgateway.cache.session_registry import SessionRegistry
709 >>>
710 >>> class MockTransport:
711 ... pass
712 >>>
713 >>> reg = SessionRegistry()
714 >>> transport = MockTransport()
715 >>> asyncio.run(reg.add_session('test-456', transport))
716 >>>
717 >>> # Found in local cache
718 >>> found = asyncio.run(reg.get_session('test-456'))
719 >>> found is transport
720 True
721 >>>
722 >>> # Not found
723 >>> asyncio.run(reg.get_session('nonexistent')) is None
724 True
725 """
726 # Skip for none backend
727 if self._backend == "none":
728 return None
730 # First check local cache
731 async with self._lock:
732 transport = self._sessions.get(session_id)
733 if transport:
734 logger.info(f"Session {session_id} exists in local cache")
735 return transport
737 # If not in local cache, check if it exists in shared backend
738 if self._backend == "redis":
739 if not self._redis:
740 return None
741 try:
742 exists = await self._redis.exists(f"mcp:session:{session_id}")
743 session_exists = bool(exists)
744 if session_exists:
745 logger.info(f"Session {session_id} exists in Redis but not in local cache")
746 return None # We don't have the transport locally
747 except Exception as e:
748 logger.error(f"Redis error checking session {session_id}: {e}")
749 return None
751 elif self._backend == "database":
752 try:
754 def _db_check() -> bool:
755 """Check if a session exists in the database.
757 Queries the SessionRecord table to determine if a session with
758 the given session_id exists. This is used when the session is not
759 found in the local cache to check if it exists on another worker.
761 This inner function is designed to be run in a thread executor
762 to avoid blocking the async event loop during database queries.
764 Returns:
765 bool: True if the session exists in the database, False otherwise.
767 Examples:
768 >>> # This function is called internally by get_session()
769 >>> # Returns True if SessionRecord with session_id exists
770 >>> # Returns False if no matching record found
771 """
772 db_session = next(get_db())
773 try:
774 record = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
775 return record is not None
776 finally:
777 db_session.close()
779 exists = await asyncio.to_thread(_db_check)
780 if exists:
781 logger.info(f"Session {session_id} exists in database but not in local cache")
782 return None
783 except Exception as e:
784 logger.error(f"Database error checking session {session_id}: {e}")
785 return None
787 return None
789 async def remove_session(self, session_id: str) -> None:
790 """Remove a session from the registry.
792 Removes the session from both local cache and distributed backend.
793 If a transport is found locally, it will be disconnected before removal.
794 For distributed backends, notifies other workers about the removal.
796 Args:
797 session_id: Session identifier to remove.
799 Examples:
800 >>> import asyncio
801 >>> from mcpgateway.cache.session_registry import SessionRegistry
802 >>>
803 >>> class MockTransport:
804 ... async def disconnect(self):
805 ... print(f"Transport disconnected")
806 ... async def is_connected(self):
807 ... return True
808 >>>
809 >>> reg = SessionRegistry()
810 >>> transport = MockTransport()
811 >>> asyncio.run(reg.add_session('remove-test', transport))
812 >>> asyncio.run(reg.remove_session('remove-test'))
813 Transport disconnected
814 >>>
815 >>> # Session no longer exists
816 >>> asyncio.run(reg.get_session('remove-test')) is None
817 True
818 """
819 # Skip for none backend
820 if self._backend == "none":
821 return
823 # Mark session as closing FIRST so respond loop can exit early
824 # This allows the loop to exit without waiting for cancellation to complete
825 self._closing_sessions.add(session_id)
827 try:
828 # CRITICAL: Cancel respond task before any cleanup
829 # This prevents orphaned tasks that cause CPU spin loops
830 await self._cancel_respond_task(session_id)
832 # Clean up local transport
833 transport = None
834 async with self._lock:
835 if session_id in self._sessions:
836 transport = self._sessions.pop(session_id)
837 # Also clean up client capabilities
838 if session_id in self._client_capabilities:
839 self._client_capabilities.pop(session_id)
840 logger.debug(f"Removed capabilities for session {session_id}")
841 finally:
842 # Always remove from closing set
843 self._closing_sessions.discard(session_id)
845 # Disconnect transport if found
846 if transport:
847 try:
848 await transport.disconnect()
849 except Exception as e:
850 logger.error(f"Error disconnecting transport for session {session_id}: {e}")
852 # Remove from shared backend
853 if self._backend == "redis":
854 if not self._redis:
855 return
856 try:
857 await self._redis.delete(f"mcp:session:{session_id}")
858 # Notify other workers
859 await self._redis.publish("mcp_session_events", orjson.dumps({"type": "remove", "session_id": session_id, "timestamp": time.time()}))
860 except Exception as e:
861 logger.error(f"Redis error removing session {session_id}: {e}")
863 elif self._backend == "database":
864 try:
866 def _db_remove() -> None:
867 """Delete session record from the database.
869 Removes the SessionRecord entry with the specified session_id
870 from the database. This is called when a session is being
871 terminated or has expired.
873 This inner function is designed to be run in a thread executor
874 to avoid blocking the async event loop during database operations.
876 Raises:
877 Exception: Any database error is re-raised after rollback.
878 This includes connection errors or constraint violations.
880 Examples:
881 >>> # This function is called internally by remove_session()
882 >>> # Deletes the SessionRecord where session_id matches
883 >>> # No error if session_id doesn't exist (idempotent)
884 """
885 db_session = next(get_db())
886 try:
887 db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).delete()
888 db_session.commit()
889 except Exception as ex:
890 db_session.rollback()
891 raise ex
892 finally:
893 db_session.close()
895 await asyncio.to_thread(_db_remove)
896 except Exception as e:
897 logger.error(f"Database error removing session {session_id}: {e}")
899 logger.info(f"Removed session: {session_id}")
901 async def broadcast(self, session_id: str, message: Dict[str, Any]) -> None:
902 """Broadcast a message to a session.
904 Sends a message to the specified session. The behavior depends on the backend:
906 - **memory**: Stores message temporarily for local delivery
907 - **redis**: Publishes message to Redis channel for the session
908 - **database**: Stores message in database for polling by worker with session
909 - **none**: No operation
911 This method is used for inter-process communication in distributed deployments.
913 Args:
914 session_id: Target session identifier.
915 message: Message to broadcast. Can be a dict, list, or any JSON-serializable object.
917 Examples:
918 >>> import asyncio
919 >>> from mcpgateway.cache.session_registry import SessionRegistry
920 >>>
921 >>> reg = SessionRegistry(backend='memory')
922 >>> message = {'method': 'tools/list', 'id': 1}
923 >>> asyncio.run(reg.broadcast('session-789', message))
924 >>>
925 >>> # Message stored for memory backend
926 >>> reg._session_message is not None
927 True
928 >>> reg._session_message['session_id']
929 'session-789'
930 >>> orjson.loads(reg._session_message['message'])['message'] == message
931 True
932 """
933 # Skip for none backend only
934 if self._backend == "none":
935 return
937 def _build_payload(msg: Any) -> str:
938 """Build a JSON payload for message broadcasting.
940 Args:
941 msg: Message to wrap in payload envelope.
943 Returns:
944 JSON-encoded string containing type, message, and timestamp.
945 """
946 payload = {"type": "message", "message": msg, "timestamp": time.time()}
947 return orjson.dumps(payload).decode()
949 if self._backend == "memory":
950 payload_json = _build_payload(message)
951 self._session_message: Dict[str, Any] | None = {"session_id": session_id, "message": payload_json}
953 elif self._backend == "redis":
954 if not self._redis:
955 logger.warning(f"Redis client not initialized, cannot broadcast to {session_id}")
956 return
957 try:
958 broadcast_payload = {
959 "type": "message",
960 "message": message, # Keep as original type, not pre-encoded
961 "timestamp": time.time(),
962 }
963 # Single encode
964 payload_json = orjson.dumps(broadcast_payload)
965 await self._redis.publish(session_id, payload_json) # Single encode
966 except Exception as e:
967 logger.error(f"Redis error during broadcast: {e}")
968 elif self._backend == "database":
969 try:
970 msg_json = _build_payload(message)
972 def _db_add() -> None:
973 """Store message in the database for inter-process communication.
975 Creates a new SessionMessageRecord entry containing the session_id
976 and serialized message. This enables message passing between
977 different worker processes through the shared database.
979 This inner function is designed to be run in a thread executor
980 to avoid blocking the async event loop during database writes.
982 Raises:
983 Exception: Any database error is re-raised after rollback.
984 Common errors include database connection issues or
985 constraints violations.
987 Examples:
988 >>> # This function is called internally by broadcast()
989 >>> # Creates a record like:
990 >>> # SessionMessageRecord(
991 >>> # session_id='abc123',
992 >>> # message='{"method": "ping", "id": 1}',
993 >>> # created_at=now()
994 >>> # )
995 """
996 db_session = next(get_db())
997 try:
998 message_record = SessionMessageRecord(session_id=session_id, message=msg_json)
999 db_session.add(message_record)
1000 db_session.commit()
1001 except Exception as ex:
1002 db_session.rollback()
1003 raise ex
1004 finally:
1005 db_session.close()
1007 await asyncio.to_thread(_db_add)
1008 except Exception as e:
1009 logger.error(f"Database error during broadcast: {e}")
1011 async def _register_session_mapping(self, session_id: str, message: Dict[str, Any], user_email: Optional[str] = None) -> None:
1012 """Register session mapping for session affinity when tools are called.
1014 This method is called on the worker that executes the request (the SSE session
1015 owner) to pre-register the mapping between a downstream session ID and the
1016 upstream MCP session pool key. This enables session affinity in multi-worker
1017 deployments.
1019 Only registers mappings for tools/call methods - list operations and other
1020 methods don't need session affinity since they don't maintain state.
1022 Args:
1023 session_id: The downstream SSE session ID.
1024 message: The MCP protocol message being broadcast.
1025 user_email: Optional user email for session isolation.
1026 """
1027 # Skip if session affinity is disabled
1028 if not settings.mcpgateway_session_affinity_enabled:
1029 return
1031 # Only register for tools/call - other methods don't need session affinity
1032 method = message.get("method")
1033 if method != "tools/call":
1034 return
1036 # Extract tool name from params
1037 params = message.get("params", {})
1038 tool_name = params.get("name")
1039 if not tool_name:
1040 return
1042 try:
1043 # Look up tool in cache to get gateway info
1044 # First-Party
1045 from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache # pylint: disable=import-outside-toplevel
1047 tool_info = await tool_lookup_cache.get(tool_name)
1048 if not tool_info:
1049 logger.debug(f"Tool {tool_name} not found in cache, skipping session mapping registration")
1050 return
1052 # Extract gateway information
1053 gateway = tool_info.get("gateway", {})
1054 gateway_url = gateway.get("url")
1055 gateway_id = gateway.get("id")
1056 transport = gateway.get("transport")
1058 if not gateway_url or not gateway_id or not transport:
1059 logger.debug(f"Incomplete gateway info for tool {tool_name}, skipping session mapping registration")
1060 return
1062 # Register the session mapping with the pool
1063 # First-Party
1064 from mcpgateway.services.mcp_session_pool import get_mcp_session_pool # pylint: disable=import-outside-toplevel
1066 pool = get_mcp_session_pool()
1067 await pool.register_session_mapping(
1068 session_id,
1069 gateway_url,
1070 gateway_id,
1071 transport,
1072 user_email,
1073 )
1075 logger.debug(f"Registered session mapping for session {session_id[:8]}... -> {gateway_url} (tool: {tool_name})")
1077 except Exception as e:
1078 # Don't fail the broadcast if session mapping registration fails
1079 logger.warning(f"Failed to register session mapping for {session_id[:8]}...: {e}")
1081 async def get_all_session_ids(self) -> list[str]:
1082 """Return a snapshot list of all known local session IDs.
1084 Returns:
1085 list[str]: A snapshot list of currently known local session IDs.
1086 """
1087 async with self._lock:
1088 return list(self._sessions.keys())
1090 def get_session_sync(self, session_id: str) -> Any:
1091 """Get session synchronously from local cache only.
1093 This is a non-blocking method that only checks the local cache,
1094 not the distributed backend. Use this when you need quick access
1095 and know the session should be local.
1097 Args:
1098 session_id: Session identifier to look up.
1100 Returns:
1101 SSETransport object if found in local cache, None otherwise.
1103 Examples:
1104 >>> from mcpgateway.cache.session_registry import SessionRegistry
1105 >>> import asyncio
1106 >>>
1107 >>> class MockTransport:
1108 ... pass
1109 >>>
1110 >>> reg = SessionRegistry()
1111 >>> transport = MockTransport()
1112 >>> asyncio.run(reg.add_session('sync-test', transport))
1113 >>>
1114 >>> # Synchronous lookup
1115 >>> found = reg.get_session_sync('sync-test')
1116 >>> found is transport
1117 True
1118 >>>
1119 >>> # Not found
1120 >>> reg.get_session_sync('nonexistent') is None
1121 True
1122 """
1123 # Skip for none backend
1124 if self._backend == "none":
1125 return None
1127 return self._sessions.get(session_id)
1129 async def respond(
1130 self,
1131 server_id: Optional[str],
1132 user: Dict[str, Any],
1133 session_id: str,
1134 base_url: str,
1135 ) -> None:
1136 """Process and respond to broadcast messages for a session.
1138 This method listens for messages directed to the specified session and
1139 generates appropriate responses. The listening mechanism depends on the backend:
1141 - **memory**: Checks the temporary message storage
1142 - **redis**: Subscribes to Redis pubsub channel
1143 - **database**: Polls database for new messages
1145 When a message is received and the transport exists locally, it processes
1146 the message and sends the response through the transport.
1148 Args:
1149 server_id: Optional server identifier for scoped operations.
1150 user: User information including authentication token.
1151 session_id: Session identifier to respond for.
1152 base_url: Base URL for API calls (used for RPC endpoints).
1154 Raises:
1155 asyncio.CancelledError: When the respond task is cancelled (e.g., on session removal).
1157 Examples:
1158 >>> import asyncio
1159 >>> from mcpgateway.cache.session_registry import SessionRegistry
1160 >>>
1161 >>> # This method is typically called internally by the SSE handler
1162 >>> reg = SessionRegistry()
1163 >>> user = {'token': 'test-token'}
1164 >>> # asyncio.run(reg.respond(None, user, 'session-id', 'http://localhost'))
1165 """
1167 if self._backend == "none":
1168 pass
1170 elif self._backend == "memory":
1171 transport = self.get_session_sync(session_id)
1172 if transport and self._session_message:
1173 message_json = self._session_message.get("message")
1174 if message_json:
1175 data = orjson.loads(message_json)
1176 if isinstance(data, dict) and "message" in data:
1177 message = data["message"]
1178 else:
1179 message = data
1180 await self.generate_response(message=message, transport=transport, server_id=server_id, user=user, base_url=base_url)
1181 else:
1182 logger.warning(f"Session message stored but message content is None for session {session_id}")
1184 elif self._backend == "redis":
1185 if not self._redis:
1186 logger.warning(f"Redis client not initialized, cannot respond to {session_id}")
1187 return
1188 pubsub = self._redis.pubsub()
1189 await pubsub.subscribe(session_id)
1191 # Use timeout-based polling instead of infinite listen() to allow exit checks
1192 # This is critical for allowing cancellation to work (Finding 2)
1193 poll_timeout = 1.0 # Check every second if session still exists
1195 try:
1196 while True:
1197 # Check if session still exists or is closing - exit early
1198 if session_id not in self._sessions or session_id in self._closing_sessions:
1199 logger.info(f"Session {session_id} removed or closing, exiting Redis respond loop")
1200 break
1202 # Use get_message with timeout instead of blocking listen()
1203 try:
1204 msg = await asyncio.wait_for(
1205 pubsub.get_message(ignore_subscribe_messages=True, timeout=poll_timeout), timeout=poll_timeout + 0.5 # Slightly longer to account for Redis timeout
1206 )
1207 except asyncio.TimeoutError:
1208 # No message, loop back to check session existence
1209 continue
1211 if msg is None:
1212 # CRITICAL: Sleep to prevent tight loop when get_message returns immediately
1213 # This can happen in certain Redis states after disconnects
1214 await asyncio.sleep(0.1)
1215 continue
1216 if msg["type"] != "message":
1217 # Sleep on non-message types to prevent spin in edge cases
1218 await asyncio.sleep(0.1)
1219 continue
1221 data = orjson.loads(msg["data"])
1222 message = data.get("message", {})
1223 transport = self.get_session_sync(session_id)
1224 if transport:
1225 await self.generate_response(message=message, transport=transport, server_id=server_id, user=user, base_url=base_url)
1226 except asyncio.CancelledError:
1227 logger.info(f"PubSub listener for session {session_id} cancelled")
1228 raise # Re-raise to properly complete cancellation
1229 except Exception as e:
1230 logger.error(f"PubSub listener error for session {session_id}: {e}")
1231 finally:
1232 # Pubsub cleanup first - use timeouts to prevent blocking
1233 cleanup_timeout = settings.mcp_session_pool_cleanup_timeout
1234 try:
1235 await asyncio.wait_for(pubsub.unsubscribe(session_id), timeout=cleanup_timeout)
1236 except asyncio.TimeoutError:
1237 logger.debug(f"Pubsub unsubscribe timed out for session {session_id}")
1238 except Exception as e:
1239 logger.debug(f"Error unsubscribing pubsub for session {session_id}: {e}")
1240 try:
1241 try:
1242 await asyncio.wait_for(pubsub.aclose(), timeout=cleanup_timeout)
1243 except AttributeError:
1244 await asyncio.wait_for(pubsub.close(), timeout=cleanup_timeout)
1245 except asyncio.TimeoutError:
1246 logger.debug(f"Pubsub close timed out for session {session_id}")
1247 except Exception as e:
1248 logger.debug(f"Error closing pubsub for session {session_id}: {e}")
1249 logger.info(f"Cleaned up pubsub for session {session_id}")
1250 # Clean up task reference LAST (idempotent - may already be removed by _cancel_respond_task)
1251 self._respond_tasks.pop(session_id, None)
1253 elif self._backend == "database":
1255 def _db_read_session_and_message(
1256 session_id: str,
1257 ) -> tuple[SessionRecord | None, SessionMessageRecord | None]:
1258 """
1259 Check whether a session exists and retrieve its next pending message
1260 in a single database query.
1262 This function performs a LEFT OUTER JOIN between SessionRecord and
1263 SessionMessageRecord to determine:
1265 - Whether the session still exists
1266 - Whether there is a pending message for the session (FIFO order)
1268 It is used by the database-backed message polling loop to reduce
1269 database load by collapsing multiple reads into a single query.
1271 Messages are returned in FIFO order based on the message primary key.
1273 This function is designed to be run in a thread executor to avoid
1274 blocking the async event loop during database access.
1276 Args:
1277 session_id: The session identifier to look up.
1279 Returns:
1280 Tuple[SessionRecord | None, SessionMessageRecord | None]:
1282 - (None, None)
1283 The session does not exist.
1285 - (SessionRecord, None)
1286 The session exists but has no pending messages.
1288 - (SessionRecord, SessionMessageRecord)
1289 The session exists and has a pending message.
1291 Raises:
1292 Exception: Any database error is re-raised after rollback.
1294 Examples:
1295 >>> # This function is called internally by message_check_loop()
1296 >>> # Session exists and has a pending message
1297 >>> # Returns (SessionRecord, SessionMessageRecord)
1299 >>> # Session exists but has no pending messages
1300 >>> # Returns (SessionRecord, None)
1302 >>> # Session has been removed
1303 >>> # Returns (None, None)
1304 """
1305 db_session = next(get_db())
1306 try:
1307 result = (
1308 db_session.query(SessionRecord, SessionMessageRecord)
1309 .outerjoin(
1310 SessionMessageRecord,
1311 SessionMessageRecord.session_id == SessionRecord.session_id,
1312 )
1313 .filter(SessionRecord.session_id == session_id)
1314 .order_by(SessionMessageRecord.id.asc())
1315 .first()
1316 )
1317 if not result:
1318 return None, None
1319 session, message = result
1320 return session, message
1321 except Exception as ex:
1322 db_session.rollback()
1323 raise ex
1324 finally:
1325 db_session.close()
1327 def _db_remove(session_id: str, message: str) -> None:
1328 """Remove processed message from the database.
1330 Deletes a specific message record after it has been successfully
1331 processed and sent to the transport. This prevents duplicate
1332 message delivery.
1334 This inner function is designed to be run in a thread executor
1335 to avoid blocking the async event loop during database deletes.
1337 Args:
1338 session_id: The session identifier the message belongs to.
1339 message: The exact message content to remove (must match exactly).
1341 Raises:
1342 Exception: Any database error is re-raised after rollback.
1344 Examples:
1345 >>> # This function is called internally after message processing
1346 >>> # Deletes the specific SessionMessageRecord entry
1347 >>> # Log: "Removed message from mcp_messages table"
1348 """
1349 db_session = next(get_db())
1350 try:
1351 db_session.query(SessionMessageRecord).filter(SessionMessageRecord.session_id == session_id).filter(SessionMessageRecord.message == message).delete()
1352 db_session.commit()
1353 logger.info("Removed message from mcp_messages table")
1354 except Exception as ex:
1355 db_session.rollback()
1356 raise ex
1357 finally:
1358 db_session.close()
1360 async def message_check_loop(session_id: str) -> None:
1361 """
1362 Background task that polls the database for messages belonging to a session
1363 using adaptive polling with exponential backoff.
1365 The loop continues until the session is removed from the database.
1367 Behavior:
1368 - Starts with a fast polling interval for low-latency message delivery.
1369 - When no message is found, the polling interval increases exponentially
1370 (up to a configured maximum) to reduce database load.
1371 - When a message is received, the polling interval is immediately reset
1372 to the fast interval.
1373 - The loop exits as soon as the session no longer exists.
1375 Polling rules:
1376 - Message found → process message, reset polling interval.
1377 - No message → increase polling interval (backoff).
1378 - Session gone → stop polling immediately.
1380 Args:
1381 session_id (str): Unique identifier of the session to monitor.
1383 Raises:
1384 asyncio.CancelledError: When the polling loop is cancelled.
1386 Examples
1387 --------
1388 Adaptive backoff when no messages are present:
1390 >>> poll_interval = 0.1
1391 >>> backoff_factor = 1.5
1392 >>> max_interval = 5.0
1393 >>> poll_interval = min(poll_interval * backoff_factor, max_interval)
1394 >>> poll_interval
1395 0.15000000000000002
1397 Backoff continues until the maximum interval is reached:
1399 >>> poll_interval = 4.0
1400 >>> poll_interval = min(poll_interval * 1.5, 5.0)
1401 >>> poll_interval
1402 5.0
1404 Polling interval resets immediately when a message arrives:
1406 >>> poll_interval = 2.0
1407 >>> poll_interval = 0.1
1408 >>> poll_interval
1409 0.1
1411 Session termination stops polling:
1413 >>> session_exists = False
1414 >>> if not session_exists:
1415 ... "polling stopped"
1416 'polling stopped'
1417 """
1419 poll_interval = settings.poll_interval # start fast
1420 max_interval = settings.max_interval # cap at configured maximum
1421 backoff_factor = settings.backoff_factor
1422 try:
1423 while True:
1424 # Check if session is closing before querying DB
1425 if session_id in self._closing_sessions:
1426 logger.debug("Session %s closing, stopping poll loop early", session_id)
1427 break
1429 session, record = await asyncio.to_thread(_db_read_session_and_message, session_id)
1431 # session gone → stop polling
1432 if not session:
1433 logger.debug("Session %s no longer exists, stopping poll loop", session_id)
1434 break
1436 if record:
1437 poll_interval = settings.poll_interval # reset on activity
1439 data = orjson.loads(record.message)
1440 if isinstance(data, dict) and "message" in data:
1441 message = data["message"]
1442 else:
1443 message = data
1445 transport = self.get_session_sync(session_id)
1446 if transport:
1447 logger.info("Ready to respond")
1448 await self.generate_response(
1449 message=message,
1450 transport=transport,
1451 server_id=server_id,
1452 user=user,
1453 base_url=base_url,
1454 )
1456 await asyncio.to_thread(_db_remove, session_id, record.message)
1457 else:
1458 # no message → backoff
1459 # update polling interval with backoff factor
1460 poll_interval = min(poll_interval * backoff_factor, max_interval)
1462 await asyncio.sleep(poll_interval)
1463 except asyncio.CancelledError:
1464 logger.info(f"Message check loop cancelled for session {session_id}")
1465 raise # Re-raise to properly complete cancellation
1466 except Exception as e:
1467 logger.error(f"Message check loop error for session {session_id}: {e}")
1469 # CRITICAL: Await instead of fire-and-forget
1470 # This ensures CancelledError propagates from outer respond() task to inner loop
1471 # The outer task (registered from main.py) now runs until message_check_loop exits
1472 try:
1473 await message_check_loop(session_id)
1474 except asyncio.CancelledError:
1475 logger.info(f"Database respond cancelled for session {session_id}")
1476 raise
1477 finally:
1478 # Clean up task reference on ANY exit (normal, cancelled, or error)
1479 # Prevents stale done tasks from accumulating in _respond_tasks
1480 self._respond_tasks.pop(session_id, None)
1482 async def _refresh_redis_sessions(self) -> None:
1483 """Refresh TTLs for Redis sessions and clean up disconnected sessions.
1485 This internal method is used by the Redis backend to maintain session state.
1486 It checks all local sessions, refreshes TTLs for connected sessions, and
1487 removes disconnected ones.
1488 """
1489 if not self._redis:
1490 return
1491 try:
1492 # Check all local sessions
1493 local_transports = {}
1494 async with self._lock:
1495 local_transports = self._sessions.copy()
1497 for session_id, transport in local_transports.items():
1498 try:
1499 if await transport.is_connected():
1500 # Refresh TTL in Redis
1501 await self._redis.expire(f"mcp:session:{session_id}", self._session_ttl)
1502 else:
1503 # Remove disconnected session
1504 await self.remove_session(session_id)
1505 except Exception as e:
1506 logger.error(f"Error refreshing session {session_id}: {e}")
1508 except Exception as e:
1509 logger.error(f"Error in Redis session refresh: {e}")
1511 async def _db_cleanup_task(self) -> None:
1512 """Background task to clean up expired database sessions.
1514 Runs periodically (every 5 minutes) to remove expired sessions from the
1515 database and refresh timestamps for active sessions. This prevents the
1516 database from accumulating stale session records.
1518 The task also verifies that local sessions still exist in the database
1519 and removes them locally if they've been deleted elsewhere.
1520 """
1521 logger.info("Starting database cleanup task")
1522 while True:
1523 try:
1524 # Clean up expired sessions every 5 minutes
1525 def _db_cleanup() -> int:
1526 """Remove expired sessions from the database.
1528 Deletes all SessionRecord entries that haven't been accessed
1529 within the session TTL period. Uses database-specific date
1530 arithmetic to calculate expiry time.
1532 This inner function is designed to be run in a thread executor
1533 to avoid blocking the async event loop during bulk deletes.
1535 Returns:
1536 int: Number of expired session records deleted.
1538 Raises:
1539 Exception: Any database error is re-raised after rollback.
1541 Examples:
1542 >>> # This function is called periodically by _db_cleanup_task()
1543 >>> # Deletes sessions older than session_ttl seconds
1544 >>> # Returns count of deleted records for logging
1545 >>> # Log: "Cleaned up 5 expired database sessions"
1546 """
1547 db_session = next(get_db())
1548 try:
1549 # Delete sessions that haven't been accessed for TTL seconds
1550 # Use Python datetime for database-agnostic expiry calculation
1551 expiry_time = datetime.now(timezone.utc) - timedelta(seconds=self._session_ttl)
1552 result = db_session.query(SessionRecord).filter(SessionRecord.last_accessed < expiry_time).delete()
1553 db_session.commit()
1554 return result
1555 except Exception as ex:
1556 db_session.rollback()
1557 raise ex
1558 finally:
1559 db_session.close()
1561 deleted = await asyncio.to_thread(_db_cleanup)
1562 if deleted > 0:
1563 logger.info(f"Cleaned up {deleted} expired database sessions")
1565 # Check local sessions against database
1566 await self._cleanup_database_sessions()
1568 await asyncio.sleep(300) # Run every 5 minutes
1570 except asyncio.CancelledError:
1571 logger.info("Database cleanup task cancelled")
1572 break
1573 except Exception as e:
1574 logger.error(f"Error in database cleanup task: {e}")
1575 await asyncio.sleep(600) # Sleep longer on error
1577 def _refresh_session_db(self, session_id: str) -> bool:
1578 """Update session's last accessed timestamp in the database.
1580 Refreshes the last_accessed field for an active session to
1581 prevent it from being cleaned up as expired. This is called
1582 periodically for all local sessions with active transports.
1584 Args:
1585 session_id: The session identifier to refresh.
1587 Returns:
1588 bool: True if the session was found and updated, False if not found.
1590 Raises:
1591 Exception: Any database error is re-raised after rollback.
1592 """
1593 db_session = next(get_db())
1594 try:
1595 session = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
1596 if session:
1597 session.last_accessed = func.now() # pylint: disable=not-callable
1598 db_session.commit()
1599 return True
1600 return False
1601 except Exception as ex:
1602 db_session.rollback()
1603 raise ex
1604 finally:
1605 db_session.close()
1607 async def _cleanup_database_sessions(self, max_concurrent: int = 20) -> None:
1608 """Parallelize session cleanup with bounded concurrency.
1610 Checks connection status first (fast), then refreshes connected sessions
1611 in parallel using asyncio.gather() with a semaphore to limit concurrent
1612 DB operations and prevent resource exhaustion.
1614 Args:
1615 max_concurrent: Maximum number of concurrent DB refresh operations.
1616 Defaults to 20 to balance parallelism with resource usage.
1617 """
1618 async with self._lock:
1619 local_transports = self._sessions.copy()
1621 # Check connections first (fast)
1622 connected: list[str] = []
1623 for session_id, transport in local_transports.items():
1624 try:
1625 if not await transport.is_connected():
1626 await self.remove_session(session_id)
1627 else:
1628 connected.append(session_id)
1629 except Exception as e:
1630 # Only log error, don't remove session on transient errors
1631 logger.error(f"Error checking connection for session {session_id}: {e}")
1633 # Parallel refresh of connected sessions with bounded concurrency
1634 if connected:
1635 semaphore = asyncio.Semaphore(max_concurrent)
1637 async def bounded_refresh(session_id: str) -> bool:
1638 """Refresh session with semaphore-bounded concurrency.
1640 Args:
1641 session_id: The session ID to refresh.
1643 Returns:
1644 True if refresh succeeded, False otherwise.
1645 """
1646 async with semaphore:
1647 return await asyncio.to_thread(self._refresh_session_db, session_id)
1649 refresh_tasks = [bounded_refresh(session_id) for session_id in connected]
1650 results = await asyncio.gather(*refresh_tasks, return_exceptions=True)
1652 for session_id, result in zip(connected, results):
1653 try:
1654 if isinstance(result, Exception):
1655 # Only log error, don't remove session on transient DB errors
1656 logger.error(f"Error refreshing session {session_id}: {result}")
1657 elif not result:
1658 # Session no longer in database, remove locally
1659 await self.remove_session(session_id)
1660 except Exception as e:
1661 logger.error(f"Error processing refresh result for session {session_id}: {e}")
1663 async def _memory_cleanup_task(self) -> None:
1664 """Background task to clean up disconnected sessions in memory backend.
1666 Runs periodically (every minute) to check all local sessions and remove
1667 those that are no longer connected. This prevents memory leaks from
1668 accumulating disconnected transport objects.
1669 """
1670 logger.info("Starting memory cleanup task")
1671 while True:
1672 try:
1673 # Check all local sessions
1674 local_transports = {}
1675 async with self._lock:
1676 local_transports = self._sessions.copy()
1678 for session_id, transport in local_transports.items():
1679 try:
1680 if not await transport.is_connected():
1681 await self.remove_session(session_id)
1682 except Exception as e:
1683 logger.error(f"Error checking session {session_id}: {e}")
1684 await self.remove_session(session_id)
1686 await asyncio.sleep(60) # Run every minute
1688 except asyncio.CancelledError:
1689 logger.info("Memory cleanup task cancelled")
1690 break
1691 except Exception as e:
1692 logger.error(f"Error in memory cleanup task: {e}")
1693 await asyncio.sleep(300) # Sleep longer on error
1695 def _get_oauth_experimental_config(self, server_id: str) -> Optional[Dict[str, Dict[str, Any]]]:
1696 """Query OAuth configuration for a server (synchronous, run in threadpool).
1698 This method queries the database for OAuth configuration and returns
1699 RFC 9728-safe fields for advertising in MCP capabilities.
1701 Args:
1702 server_id: The server ID to query OAuth configuration for.
1704 Returns:
1705 Dict with 'oauth' key containing safe OAuth config, or None if not configured.
1706 """
1707 # First-Party
1708 from mcpgateway.db import Server as DbServer # pylint: disable=import-outside-toplevel
1709 from mcpgateway.db import SessionLocal # pylint: disable=import-outside-toplevel
1711 db = SessionLocal()
1712 try:
1713 server = db.get(DbServer, server_id)
1714 if server and getattr(server, "oauth_enabled", False) and getattr(server, "oauth_config", None):
1715 # Filter oauth_config to RFC 9728-safe fields only (never expose secrets)
1716 oauth_config = server.oauth_config
1717 safe_oauth: Dict[str, Any] = {}
1719 # Extract authorization servers
1720 if oauth_config.get("authorization_servers"):
1721 safe_oauth["authorization_servers"] = oauth_config["authorization_servers"]
1722 elif oauth_config.get("authorization_server"):
1723 safe_oauth["authorization_servers"] = [oauth_config["authorization_server"]]
1725 # Extract scopes
1726 scopes = oauth_config.get("scopes_supported") or oauth_config.get("scopes")
1727 if scopes:
1728 safe_oauth["scopes_supported"] = scopes
1730 # Add bearer methods
1731 safe_oauth["bearer_methods_supported"] = oauth_config.get("bearer_methods_supported", ["header"])
1733 if safe_oauth.get("authorization_servers"):
1734 logger.debug(f"Advertising OAuth capability for server {server_id}")
1735 return {"oauth": safe_oauth}
1736 return None
1737 finally:
1738 db.close()
1740 # Handle initialize logic
1741 async def handle_initialize_logic(self, body: Dict[str, Any], session_id: Optional[str] = None, server_id: Optional[str] = None) -> InitializeResult:
1742 """Process MCP protocol initialization request.
1744 Validates the protocol version and returns server capabilities and information.
1745 This method implements the MCP (Model Context Protocol) initialization handshake.
1747 Args:
1748 body: Request body containing protocol_version and optional client_info.
1749 Expected keys: 'protocol_version' or 'protocolVersion', 'capabilities'.
1750 session_id: Optional session ID to associate client capabilities with.
1751 server_id: Optional server ID to query OAuth configuration for RFC 9728 support.
1753 Returns:
1754 InitializeResult containing protocol version, server capabilities, and server info.
1756 Raises:
1757 HTTPException: If protocol_version is missing (400 Bad Request with MCP error code -32002).
1759 Examples:
1760 >>> import asyncio
1761 >>> from mcpgateway.cache.session_registry import SessionRegistry
1762 >>>
1763 >>> reg = SessionRegistry()
1764 >>> body = {'protocol_version': '2025-06-18'}
1765 >>> result = asyncio.run(reg.handle_initialize_logic(body))
1766 >>> result.protocol_version
1767 '2025-06-18'
1768 >>> result.server_info.name
1769 'MCP_Gateway'
1770 >>>
1771 >>> # Missing protocol version
1772 >>> try:
1773 ... asyncio.run(reg.handle_initialize_logic({}))
1774 ... except HTTPException as e:
1775 ... e.status_code
1776 400
1777 """
1778 protocol_version = body.get("protocol_version") or body.get("protocolVersion")
1779 client_capabilities = body.get("capabilities", {})
1780 # body.get("client_info") or body.get("clientInfo", {})
1782 if not protocol_version:
1783 raise HTTPException(
1784 status_code=status.HTTP_400_BAD_REQUEST,
1785 detail="Missing protocol version",
1786 headers={"MCP-Error-Code": "-32002"},
1787 )
1789 if protocol_version != settings.protocol_version:
1790 logger.warning(f"Using non default protocol version: {protocol_version}")
1792 # Store client capabilities if session_id provided
1793 if session_id and client_capabilities:
1794 await self.store_client_capabilities(session_id, client_capabilities)
1795 logger.debug(f"Stored capabilities for session {session_id}: {client_capabilities}")
1797 # Build experimental capabilities (including OAuth if configured)
1798 experimental: Optional[Dict[str, Dict[str, Any]]] = None
1800 # Query OAuth configuration if server_id is provided
1801 if server_id:
1802 try:
1803 # Run synchronous DB query in threadpool to avoid blocking the event loop
1804 experimental = await asyncio.to_thread(self._get_oauth_experimental_config, server_id)
1805 except Exception as e:
1806 logger.warning(f"Failed to query OAuth config for server {server_id}: {e}")
1808 return InitializeResult(
1809 protocolVersion=protocol_version,
1810 capabilities=ServerCapabilities(
1811 prompts={"listChanged": True},
1812 resources={"subscribe": True, "listChanged": True},
1813 tools={"listChanged": True},
1814 logging={},
1815 completions={}, # Advertise completions capability per MCP spec
1816 experimental=experimental, # OAuth capability when configured
1817 ),
1818 serverInfo=Implementation(name=settings.app_name, version=__version__),
1819 instructions=("MCP Gateway providing federated tools, resources and prompts. Use /admin interface for configuration."),
1820 )
1822 async def store_client_capabilities(self, session_id: str, capabilities: Dict[str, Any]) -> None:
1823 """Store client capabilities for a session.
1825 Args:
1826 session_id: The session ID
1827 capabilities: Client capabilities dictionary from initialize request
1828 """
1829 async with self._lock:
1830 self._client_capabilities[session_id] = capabilities
1831 logger.debug(f"Stored capabilities for session {session_id}")
1833 async def get_client_capabilities(self, session_id: str) -> Optional[Dict[str, Any]]:
1834 """Get client capabilities for a session.
1836 Args:
1837 session_id: The session ID
1839 Returns:
1840 Client capabilities dictionary, or None if not found
1841 """
1842 async with self._lock:
1843 return self._client_capabilities.get(session_id)
1845 async def has_elicitation_capability(self, session_id: str) -> bool:
1846 """Check if a session has elicitation capability.
1848 Args:
1849 session_id: The session ID
1851 Returns:
1852 True if session supports elicitation, False otherwise
1853 """
1854 capabilities = await self.get_client_capabilities(session_id)
1855 if not capabilities:
1856 return False
1857 # Check if elicitation capability exists in client capabilities
1858 return bool(capabilities.get("elicitation"))
1860 async def get_elicitation_capable_sessions(self) -> list[str]:
1861 """Get list of session IDs that support elicitation.
1863 Returns:
1864 List of session IDs with elicitation capability
1865 """
1866 async with self._lock:
1867 capable_sessions = []
1868 for session_id, capabilities in self._client_capabilities.items():
1869 if capabilities.get("elicitation"):
1870 # Verify session still exists
1871 if session_id in self._sessions:
1872 capable_sessions.append(session_id)
1873 return capable_sessions
1875 async def generate_response(self, message: Dict[str, Any], transport: SSETransport, server_id: Optional[str], user: Dict[str, Any], base_url: str) -> None:
1876 """Generate and send response for incoming MCP protocol message.
1878 Processes MCP protocol messages and generates appropriate responses based on
1879 the method. Supports various MCP methods including initialization, tool/resource/prompt
1880 listing, tool invocation, and ping.
1882 Args:
1883 message: Incoming MCP message as JSON. Must contain 'method' and 'id' fields.
1884 transport: SSE transport to send responses through.
1885 server_id: Optional server ID for scoped operations.
1886 user: User information containing authentication token.
1887 base_url: Base URL for constructing RPC endpoints.
1889 Examples:
1890 >>> import asyncio
1891 >>> from mcpgateway.cache.session_registry import SessionRegistry
1892 >>>
1893 >>> class MockTransport:
1894 ... async def send_message(self, msg):
1895 ... print(f"Response: {msg['method'] if 'method' in msg else msg.get('result', {})}")
1896 >>>
1897 >>> reg = SessionRegistry()
1898 >>> transport = MockTransport()
1899 >>> message = {"method": "ping", "id": 1}
1900 >>> user = {"token": "test-token"}
1901 >>> # asyncio.run(reg.generate_response(message, transport, None, user, "http://localhost"))
1902 >>> # Response: {}
1903 """
1904 result = {}
1906 if "method" in message and "id" in message:
1907 method = message["method"]
1908 params = message.get("params", {})
1909 params["server_id"] = server_id
1910 req_id = message["id"]
1912 rpc_input = {
1913 "jsonrpc": "2.0",
1914 "method": method,
1915 "params": params,
1916 "id": req_id,
1917 }
1918 # Get the token from the current authentication context
1919 # The user object should contain auth_token, token_teams, and is_admin from the SSE endpoint
1920 token = None
1921 is_admin = user.get("is_admin", False) # Preserve admin status from SSE endpoint
1923 try:
1924 if hasattr(user, "get") and user.get("auth_token"):
1925 token = user["auth_token"]
1926 else:
1927 # Fallback: create lightweight session token (teams resolved server-side by downstream /rpc)
1928 logger.warning("No auth token available for SSE RPC call - creating fallback session token")
1929 now = datetime.now(timezone.utc)
1930 payload = {
1931 "sub": user.get("email", "system"),
1932 "iss": settings.jwt_issuer,
1933 "aud": settings.jwt_audience,
1934 "iat": int(now.timestamp()),
1935 "jti": str(uuid.uuid4()),
1936 "token_use": "session", # nosec B105 - token type marker, not a password
1937 "user": {
1938 "email": user.get("email", "system"),
1939 "full_name": user.get("full_name", "System"),
1940 "is_admin": is_admin, # Preserve admin status for cookie-authenticated admins
1941 "auth_provider": "internal",
1942 },
1943 }
1944 # Generate token using centralized token creation
1945 token = await create_jwt_token(payload)
1947 # Pass downstream session id to /rpc for session affinity.
1948 # This is gateway-internal only; the pool strips it before contacting upstream MCP servers.
1949 if settings.mcpgateway_session_affinity_enabled:
1950 await self._register_session_mapping(transport.session_id, message, user.get("email") if hasattr(user, "get") else None)
1952 headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
1953 if settings.mcpgateway_session_affinity_enabled:
1954 headers["x-mcp-session-id"] = transport.session_id
1955 # Extract root URL from base_url (remove /servers/{id} path)
1956 parsed_url = urlparse(base_url)
1957 # Preserve the path up to the root path (before /servers/{id})
1958 path_parts = parsed_url.path.split("/")
1959 if "/servers/" in parsed_url.path:
1960 # Find the index of 'servers' and take everything before it
1961 try:
1962 servers_index = path_parts.index("servers")
1963 root_path = "/" + "/".join(path_parts[1:servers_index]).strip("/")
1964 if root_path == "/":
1965 root_path = ""
1966 except ValueError:
1967 root_path = ""
1968 else:
1969 root_path = parsed_url.path.rstrip("/")
1971 root_url = f"{parsed_url.scheme}://{parsed_url.netloc}{root_path}"
1972 rpc_url = root_url + "/rpc"
1974 logger.info(f"SSE RPC: Making call to {rpc_url} with method={method}, params={params}")
1976 async with ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify}) as client:
1977 logger.info(f"SSE RPC: Sending request to {rpc_url}")
1978 rpc_response = await client.post(
1979 url=rpc_url,
1980 json=rpc_input,
1981 headers=headers,
1982 )
1983 logger.info(f"SSE RPC: Got response status {rpc_response.status_code}")
1984 result = rpc_response.json()
1985 logger.info(f"SSE RPC: Response content: {result}")
1986 result = result.get("result", {})
1988 response = {"jsonrpc": "2.0", "result": result, "id": req_id}
1989 except JSONRPCError as e:
1990 logger.error(f"SSE RPC: JSON-RPC error: {e}")
1991 result = e.to_dict()
1992 response = {"jsonrpc": "2.0", "error": result["error"], "id": req_id}
1993 except Exception as e:
1994 logger.error(f"SSE RPC: Exception during RPC call: {type(e).__name__}: {e}")
1995 logger.error(f"SSE RPC: Traceback: {traceback.format_exc()}")
1996 result = {"code": -32000, "message": "Internal error", "data": str(e)}
1997 response = {"jsonrpc": "2.0", "error": result, "id": req_id}
1999 logging.debug(f"Sending sse message:{response}")
2000 await transport.send_message(response)
2002 if message["method"] == "initialize":
2003 await transport.send_message(
2004 {
2005 "jsonrpc": "2.0",
2006 "method": "notifications/initialized",
2007 "params": {},
2008 }
2009 )
2010 notifications = [
2011 "tools/list_changed",
2012 "resources/list_changed",
2013 "prompts/list_changed",
2014 ]
2015 for notification in notifications:
2016 await transport.send_message(
2017 {
2018 "jsonrpc": "2.0",
2019 "method": f"notifications/{notification}",
2020 "params": {},
2021 }
2022 )