Coverage for mcpgateway / cache / session_registry.py: 100%
966 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/cache/session_registry.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Session Registry with optional distributed state.
8This module provides a registry for SSE sessions with support for distributed deployment
9using Redis or SQLAlchemy as optional backends for shared state between workers.
11The SessionRegistry class manages server-sent event (SSE) sessions across multiple
12worker processes, enabling horizontal scaling of MCP gateway deployments. It supports
13three backend modes:
15- **memory**: In-memory storage for single-process deployments (default)
16- **redis**: Redis-backed shared storage for multi-worker deployments
17- **database**: SQLAlchemy-backed shared storage using any supported database
19In distributed mode (redis/database), session existence is tracked in the shared
20backend while transport objects remain local to each worker process. This allows
21workers to know about sessions on other workers and route messages appropriately.
23Examples:
24 Basic usage with memory backend:
26 >>> from mcpgateway.cache.session_registry import SessionRegistry
27 >>> class DummyTransport:
28 ... async def disconnect(self):
29 ... pass
30 ... async def is_connected(self):
31 ... return True
32 >>> import asyncio
33 >>> reg = SessionRegistry(backend='memory')
34 >>> transport = DummyTransport()
35 >>> asyncio.run(reg.add_session('sid123', transport))
36 >>> found = asyncio.run(reg.get_session('sid123'))
37 >>> isinstance(found, DummyTransport)
38 True
39 >>> asyncio.run(reg.remove_session('sid123'))
40 >>> asyncio.run(reg.get_session('sid123')) is None
41 True
43 Broadcasting messages:
45 >>> reg = SessionRegistry(backend='memory')
46 >>> asyncio.run(reg.broadcast('sid123', {'method': 'ping', 'id': 1}))
47 >>> reg._session_message is not None
48 True
49"""
51# Standard
52import asyncio
53from asyncio import Task
54from datetime import datetime, timedelta, timezone
55import logging
56import time
57import traceback
58from typing import Any, Dict, Optional
59import uuid
61# Third-Party
62from fastapi import HTTPException, status
63import orjson
65# First-Party
66from mcpgateway import __version__
67from mcpgateway.common.models import Implementation, InitializeResult, ServerCapabilities
68from mcpgateway.config import settings
69from mcpgateway.db import get_db, SessionMessageRecord, SessionRecord
70from mcpgateway.services import PromptService, ResourceService, ToolService
71from mcpgateway.services.logging_service import LoggingService
72from mcpgateway.transports import SSETransport
73from mcpgateway.utils.create_jwt_token import create_jwt_token
74from mcpgateway.utils.redis_client import get_redis_client
75from mcpgateway.utils.retry_manager import ResilientHttpClient
76from mcpgateway.validation.jsonrpc import JSONRPCError
# Initialize logging service first so every module-level singleton below can log.
logging_service: LoggingService = LoggingService()
logger = logging_service.get_logger(__name__)

# Module-level service singletons shared by session handling in this module.
tool_service: ToolService = ToolService()
resource_service: ResourceService = ResourceService()
prompt_service: PromptService = PromptService()

# Optional-dependency probes: redis and sqlalchemy are only required when the
# corresponding backend is selected, so import failures are recorded in
# availability flags instead of failing at import time. SessionBackend checks
# these flags before accepting the 'redis' / 'database' backends.
try:
    # Third-Party
    from redis.asyncio import Redis

    REDIS_AVAILABLE = True
except ImportError:
    REDIS_AVAILABLE = False

try:
    # Third-Party
    # NOTE(review): `func` is not referenced in this chunk — presumably used by
    # the database cleanup code later in the file; confirm before removing.
    from sqlalchemy import func

    SQLALCHEMY_AVAILABLE = True
except ImportError:
    SQLALCHEMY_AVAILABLE = False
class SessionBackend:
    """Backend configuration shared by session registries.

    Validates the requested backend type, stores TTL settings, and prepares
    the per-backend attributes used later by :class:`SessionRegistry`.
    Supported backends are ``memory``, ``redis``, ``database`` and ``none``.

    Attributes:
        _backend: Normalized (lower-cased) backend name.
        _session_ttl: Session time-to-live in seconds.
        _message_ttl: Message time-to-live in seconds.
        _redis: Redis connection instance (redis backend only).
        _pubsub: Redis pubsub instance (redis backend only).
        _session_message: Temporary message storage (memory backend only).

    Examples:
        >>> backend = SessionBackend(backend='memory')
        >>> backend._backend
        'memory'
        >>> backend._session_ttl
        3600

        >>> try:
        ...     backend = SessionBackend(backend='redis')
        ... except ValueError as e:
        ...     str(e)
        'Redis backend requires redis_url'
    """

    def __init__(
        self,
        backend: str = "memory",
        redis_url: Optional[str] = None,
        database_url: Optional[str] = None,
        session_ttl: int = 3600,  # 1 hour
        message_ttl: int = 600,  # 10 min
    ):
        """Validate and store session backend configuration.

        Args:
            backend: One of 'memory', 'redis', 'database', or 'none'.
                - 'memory': in-process storage for single-worker deployments
                - 'redis': shared storage for multi-worker deployments
                - 'database': SQLAlchemy-backed shared storage
                - 'none': session tracking disabled (dummy registry)
            redis_url: Redis connection URL; mandatory for the redis backend.
            database_url: Database connection URL; mandatory for the database backend.
            session_ttl: Seconds of inactivity before a session is cleaned up
                (default 3600).
            message_ttl: Seconds before undelivered messages are dropped
                (default 600).

        Raises:
            ValueError: For an unknown backend, a missing required URL, or a
                missing optional package (redis / SQLAlchemy).

        Examples:
            >>> # Memory backend (default)
            >>> backend = SessionBackend()
            >>> backend._backend
            'memory'

            >>> # Redis backend requires URL
            >>> try:
            ...     backend = SessionBackend(backend='redis')
            ... except ValueError as e:
            ...     'redis_url' in str(e)
            True

            >>> # Invalid backend
            >>> try:
            ...     backend = SessionBackend(backend='invalid')
            ... except ValueError as e:
            ...     'Invalid backend' in str(e)
            True
        """
        self._backend = backend.lower()
        self._session_ttl = session_ttl
        self._message_ttl = message_ttl

        # Backend-specific setup / validation.
        if self._backend == "memory":
            # Single-process mode: only a slot for the last broadcast message.
            self._session_message: dict[str, Any] | None = None
        elif self._backend == "none":
            # Dummy registry - nothing is tracked at all.
            logger.info("Session registry initialized with 'none' backend - session tracking disabled")
        elif self._backend == "redis":
            # Package availability is checked before the URL so the error
            # message points at the actual missing prerequisite.
            if not REDIS_AVAILABLE:
                raise ValueError("Redis backend requested but redis package not installed")
            if not redis_url:
                raise ValueError("Redis backend requires redis_url")
            # Redis client is set in initialize() via the shared factory
            self._redis: Optional[Redis] = None
            self._pubsub = None
        elif self._backend == "database":
            if not SQLALCHEMY_AVAILABLE:
                raise ValueError("Database backend requested but SQLAlchemy not installed")
            if not database_url:
                raise ValueError("Database backend requires database_url")
        else:
            raise ValueError(f"Invalid backend: {backend}")
213class SessionRegistry(SessionBackend):
214 """Registry for SSE sessions with optional distributed state.
216 This class manages server-sent event (SSE) sessions, providing methods to add,
217 remove, and query sessions. It supports multiple backend types for different
218 deployment scenarios:
220 - **Single-process deployments**: Use 'memory' backend (default)
221 - **Multi-worker deployments**: Use 'redis' or 'database' backend
222 - **Testing/development**: Use 'none' backend to disable session tracking
224 The registry maintains a local cache of transport objects while using the
225 shared backend to track session existence across workers. This enables
226 horizontal scaling while keeping transport objects process-local.
228 Attributes:
229 _sessions: Local dictionary mapping session IDs to transport objects
230 _lock: Asyncio lock for thread-safe access to _sessions
231 _cleanup_task: Background task for cleaning up expired sessions
233 Examples:
234 >>> import asyncio
235 >>> from mcpgateway.cache.session_registry import SessionRegistry
236 >>>
237 >>> class MockTransport:
238 ... async def disconnect(self):
239 ... print("Disconnected")
240 ... async def is_connected(self):
241 ... return True
242 ... async def send_message(self, msg):
243 ... print(f"Sent: {msg}")
244 >>>
245 >>> # Create registry and add session
246 >>> reg = SessionRegistry(backend='memory')
247 >>> transport = MockTransport()
248 >>> asyncio.run(reg.add_session('test123', transport))
249 >>>
250 >>> # Retrieve session
251 >>> found = asyncio.run(reg.get_session('test123'))
252 >>> found is transport
253 True
254 >>>
255 >>> # Remove session
256 >>> asyncio.run(reg.remove_session('test123'))
257 Disconnected
258 >>> asyncio.run(reg.get_session('test123')) is None
259 True
260 """
262 def __init__(
263 self,
264 backend: str = "memory",
265 redis_url: Optional[str] = None,
266 database_url: Optional[str] = None,
267 session_ttl: int = 3600, # 1 hour
268 message_ttl: int = 600, # 10 min
269 ):
270 """Initialize session registry with specified backend.
272 Args:
273 backend: Backend type. Must be one of 'memory', 'redis', 'database', or 'none'.
274 redis_url: Redis connection URL. Required when backend='redis'.
275 database_url: Database connection URL. Required when backend='database'.
276 session_ttl: Session time-to-live in seconds. Default: 3600.
277 message_ttl: Message time-to-live in seconds. Default: 600.
279 Examples:
280 >>> # Default memory backend
281 >>> reg = SessionRegistry()
282 >>> reg._backend
283 'memory'
284 >>> isinstance(reg._sessions, dict)
285 True
287 >>> # Redis backend with custom TTL
288 >>> try:
289 ... reg = SessionRegistry(
290 ... backend='redis',
291 ... redis_url='redis://localhost:6379',
292 ... session_ttl=7200
293 ... )
294 ... except ValueError:
295 ... pass # Redis may not be available
296 """
297 super().__init__(backend=backend, redis_url=redis_url, database_url=database_url, session_ttl=session_ttl, message_ttl=message_ttl)
298 self._sessions: Dict[str, Any] = {} # Local transport cache
299 self._session_owners: Dict[str, str] = {} # Session owner email by session_id
300 self._client_capabilities: Dict[str, Dict[str, Any]] = {} # Client capabilities by session_id
301 self._respond_tasks: Dict[str, asyncio.Task] = {} # Track respond tasks for cancellation
302 self._stuck_tasks: Dict[str, asyncio.Task] = {} # Tasks that couldn't be cancelled (for monitoring)
303 self._closing_sessions: set[str] = set() # Sessions being closed - respond loop should exit
304 self._lock = asyncio.Lock()
305 self._cleanup_task: Task | None = None
306 self._stuck_task_reaper: Task | None = None # Reaper for stuck tasks
308 def register_respond_task(self, session_id: str, task: asyncio.Task) -> None:
309 """Register a respond task for later cancellation.
311 Associates an asyncio Task with a session_id so it can be cancelled
312 when the session is removed. This prevents orphaned tasks that cause
313 CPU spin loops.
315 Args:
316 session_id: Session identifier the task belongs to.
317 task: The asyncio Task to track.
318 """
319 self._respond_tasks[session_id] = task
320 logger.debug(f"Registered respond task for session {session_id}")
    async def _cancel_respond_task(self, session_id: str, timeout: float = 5.0) -> None:
        """Cancel and await a respond task with timeout.

        Safely cancels the respond task associated with a session. Uses a timeout
        to prevent hanging if the task doesn't respond to cancellation.

        If initial cancellation times out, escalates by force-disconnecting the
        transport to unblock the task, then retries cancellation (Finding 1 fix).
        If the task still does not finish, it is parked in ``_stuck_tasks`` so
        the reaper can keep retrying (Finding 2 fix).

        Args:
            session_id: Session identifier whose task should be cancelled.
            timeout: Maximum seconds to wait for task cancellation. Default 5.0.
        """
        task = self._respond_tasks.get(session_id)
        if task is None:
            # Nothing registered for this session - nothing to cancel.
            return

        if task.done():
            # Task already finished - safe to remove from tracking
            self._respond_tasks.pop(session_id, None)
            try:
                # Consume the result so any stored exception is not logged as
                # "Task exception was never retrieved" by asyncio.
                task.result()
            except asyncio.CancelledError:
                pass
            except Exception as e:
                logger.warning(f"Respond task for {session_id} failed with: {e}")
            return

        task.cancel()
        logger.debug(f"Cancelling respond task for session {session_id}")

        try:
            # Normal path: the task honors cancellation within `timeout`.
            await asyncio.wait_for(task, timeout=timeout)
            # Cancellation successful - remove from tracking
            self._respond_tasks.pop(session_id, None)
            logger.debug(f"Respond task cancelled for session {session_id}")
        except asyncio.TimeoutError:
            # ESCALATION (Finding 1): Force-disconnect transport to unblock the task
            logger.warning(f"Respond task cancellation timed out for {session_id}, " f"escalating with transport disconnect")

            # Force-disconnect the transport to unblock any pending I/O
            transport = self._sessions.get(session_id)
            if transport and hasattr(transport, "disconnect"):
                try:
                    await transport.disconnect()
                    logger.debug(f"Force-disconnected transport for {session_id}")
                except Exception as e:
                    logger.warning(f"Failed to force-disconnect transport for {session_id}: {e}")

            # Retry cancellation with shorter timeout
            if not task.done():
                try:
                    await asyncio.wait_for(task, timeout=2.0)
                    self._respond_tasks.pop(session_id, None)
                    logger.info(f"Respond task cancelled after escalation for {session_id}")
                except asyncio.TimeoutError:
                    # Still stuck - move to stuck_tasks for monitoring (Finding 2 fix)
                    self._respond_tasks.pop(session_id, None)
                    self._stuck_tasks[session_id] = task
                    logger.error(f"Respond task for {session_id} still stuck after escalation, " f"moved to stuck_tasks for monitoring (total stuck: {len(self._stuck_tasks)})")
                except asyncio.CancelledError:
                    # wait_for surfaced the task's own cancellation - done.
                    self._respond_tasks.pop(session_id, None)
                    logger.info(f"Respond task cancelled after escalation for {session_id}")
                except Exception as e:
                    self._respond_tasks.pop(session_id, None)
                    logger.warning(f"Error during retry cancellation for {session_id}: {e}")
            else:
                # Disconnecting the transport was enough to let it finish.
                self._respond_tasks.pop(session_id, None)
                logger.debug(f"Respond task completed during escalation for {session_id}")

        except asyncio.CancelledError:
            # Cancellation successful - remove from tracking
            self._respond_tasks.pop(session_id, None)
            logger.debug(f"Respond task cancelled for session {session_id}")
        except Exception as e:
            # Remove from tracking on unexpected error - task state unknown
            self._respond_tasks.pop(session_id, None)
            logger.warning(f"Error during respond task cancellation for {session_id}: {e}")
    async def _reap_stuck_tasks(self) -> None:
        """Periodically clean up stuck tasks that have completed.

        This reaper runs every 30 seconds and:
        1. Removes completed tasks from _stuck_tasks
        2. Retries cancellation for tasks that are still running
        3. Logs warnings for tasks that remain stuck

        This prevents memory leaks from tasks that eventually complete after
        being moved to _stuck_tasks during escalation.

        Raises:
            asyncio.CancelledError: If the task is cancelled during shutdown.
        """
        reap_interval = 30.0  # seconds between reap passes
        retry_timeout = 2.0  # seconds for retry cancellation

        while True:
            try:
                await asyncio.sleep(reap_interval)

                if not self._stuck_tasks:
                    continue

                # Collect completed and still-stuck tasks
                completed = []
                still_stuck = []

                # Iterate over a snapshot since the dict may be mutated below.
                for session_id, task in list(self._stuck_tasks.items()):
                    if task.done():
                        completed.append(session_id)
                        try:
                            task.result()  # Consume result to avoid warnings
                        except (asyncio.CancelledError, Exception):
                            # CancelledError listed explicitly: since 3.8 it is
                            # a BaseException, not covered by Exception.
                            pass
                    else:
                        still_stuck.append((session_id, task))

                # Remove completed tasks
                for session_id in completed:
                    self._stuck_tasks.pop(session_id, None)

                if completed:
                    logger.info(f"Reaped {len(completed)} completed stuck tasks")

                # Retry cancellation for still-stuck tasks
                for session_id, task in still_stuck:
                    task.cancel()
                    try:
                        await asyncio.wait_for(task, timeout=retry_timeout)
                        self._stuck_tasks.pop(session_id, None)
                        logger.info(f"Stuck task {session_id} finally cancelled during reap")
                    except asyncio.TimeoutError:
                        # Leave it in _stuck_tasks; the next pass retries.
                        logger.warning(f"Task {session_id} still stuck after reap retry")
                    except asyncio.CancelledError:
                        self._stuck_tasks.pop(session_id, None)
                        logger.info(f"Stuck task {session_id} cancelled during reap")
                    except Exception as e:
                        logger.warning(f"Error during stuck task reap for {session_id}: {e}")

                if self._stuck_tasks:
                    logger.warning(f"Stuck tasks remaining: {len(self._stuck_tasks)}")

            except asyncio.CancelledError:
                # Shutdown path: propagate so the awaiting caller sees it.
                logger.debug("Stuck task reaper cancelled")
                raise
            except Exception as e:
                # Never let an unexpected error kill the reaper loop.
                logger.error(f"Error in stuck task reaper: {e}")
470 async def initialize(self) -> None:
471 """Initialize the registry with async setup.
473 This method performs asynchronous initialization tasks that cannot be done
474 in __init__. It starts background cleanup tasks and sets up pubsub
475 subscriptions for distributed backends.
477 Call this during application startup after creating the registry instance.
479 Examples:
480 >>> import asyncio
481 >>> reg = SessionRegistry(backend='memory')
482 >>> asyncio.run(reg.initialize())
483 >>> reg._cleanup_task is not None
484 True
485 >>>
486 >>> # Cleanup
487 >>> asyncio.run(reg.shutdown())
488 """
489 logger.info(f"Initializing session registry with backend: {self._backend}")
491 if self._backend == "database":
492 # Start database cleanup task
493 self._cleanup_task = asyncio.create_task(self._db_cleanup_task())
494 logger.info("Database cleanup task started")
496 elif self._backend == "redis":
497 # Get shared Redis client from factory
498 self._redis = await get_redis_client()
499 if self._redis:
500 self._pubsub = self._redis.pubsub()
501 await self._pubsub.subscribe("mcp_session_events")
502 logger.info("Session registry connected to shared Redis client")
504 elif self._backend == "none":
505 # Nothing to initialize for none backend
506 pass
508 # Memory backend needs session cleanup
509 elif self._backend == "memory":
510 self._cleanup_task = asyncio.create_task(self._memory_cleanup_task())
511 logger.info("Memory cleanup task started")
513 # Start stuck task reaper for all backends
514 self._stuck_task_reaper = asyncio.create_task(self._reap_stuck_tasks())
515 logger.info("Stuck task reaper started")
517 async def shutdown(self) -> None:
518 """Shutdown the registry and clean up resources.
520 This method cancels background tasks and closes connections to external
521 services. Call this during application shutdown to ensure clean termination.
523 Examples:
524 >>> import asyncio
525 >>> reg = SessionRegistry()
526 >>> asyncio.run(reg.initialize())
527 >>> task_was_created = reg._cleanup_task is not None
528 >>> asyncio.run(reg.shutdown())
529 >>> # After shutdown, cleanup task should be handled (cancelled or done)
530 >>> task_was_created and (reg._cleanup_task.cancelled() or reg._cleanup_task.done())
531 True
532 """
533 logger.info("Shutting down session registry")
535 # Cancel cleanup task
536 if self._cleanup_task:
537 self._cleanup_task.cancel()
538 try:
539 await self._cleanup_task
540 except asyncio.CancelledError:
541 pass
543 # Cancel stuck task reaper
544 if self._stuck_task_reaper:
545 self._stuck_task_reaper.cancel()
546 try:
547 await self._stuck_task_reaper
548 except asyncio.CancelledError:
549 pass
551 # CRITICAL: Cancel ALL respond tasks to prevent CPU spin loops
552 if self._respond_tasks:
553 logger.info(f"Cancelling {len(self._respond_tasks)} respond tasks")
554 tasks_to_cancel = list(self._respond_tasks.values())
555 self._respond_tasks.clear()
557 for task in tasks_to_cancel:
558 if not task.done():
559 task.cancel()
561 if tasks_to_cancel:
562 try:
563 await asyncio.wait_for(asyncio.gather(*tasks_to_cancel, return_exceptions=True), timeout=10.0)
564 logger.info("All respond tasks cancelled successfully")
565 except asyncio.TimeoutError:
566 logger.warning("Timeout waiting for respond tasks to cancel")
568 # Also cancel any stuck tasks (tasks that previously couldn't be cancelled)
569 if self._stuck_tasks:
570 logger.warning(f"Attempting final cancellation of {len(self._stuck_tasks)} stuck tasks")
571 stuck_to_cancel = list(self._stuck_tasks.values())
572 self._stuck_tasks.clear()
574 for task in stuck_to_cancel:
575 if not task.done():
576 task.cancel()
578 if stuck_to_cancel:
579 try:
580 await asyncio.wait_for(asyncio.gather(*stuck_to_cancel, return_exceptions=True), timeout=5.0)
581 logger.info("Stuck tasks cancelled during shutdown")
582 except asyncio.TimeoutError:
583 logger.error("Some stuck tasks could not be cancelled during shutdown")
585 # Close Redis pubsub (but not the shared client)
586 # Use timeout to prevent blocking if pubsub doesn't close cleanly
587 cleanup_timeout = settings.mcp_session_pool_cleanup_timeout
588 if self._backend == "redis" and getattr(self, "_pubsub", None):
589 try:
590 await asyncio.wait_for(self._pubsub.aclose(), timeout=cleanup_timeout)
591 except asyncio.TimeoutError:
592 logger.warning("Redis pubsub close timed out - proceeding anyway")
593 except Exception as e:
594 logger.error(f"Error closing Redis pubsub: {e}")
595 # Don't close self._redis - it's the shared client managed by redis_client.py
596 self._redis = None
597 self._pubsub = None
599 async def add_session(self, session_id: str, transport: SSETransport) -> None:
600 """Add a session to the registry.
602 Stores the session in both the local cache and the distributed backend
603 (if configured). For distributed backends, this notifies other workers
604 about the new session.
606 Args:
607 session_id: Unique session identifier. Should be a UUID or similar
608 unique string to avoid collisions.
609 transport: SSE transport object for this session. Must implement
610 the SSETransport interface.
612 Examples:
613 >>> import asyncio
614 >>> from mcpgateway.cache.session_registry import SessionRegistry
615 >>>
616 >>> class MockTransport:
617 ... async def disconnect(self):
618 ... print(f"Transport disconnected")
619 ... async def is_connected(self):
620 ... return True
621 >>>
622 >>> reg = SessionRegistry()
623 >>> transport = MockTransport()
624 >>> asyncio.run(reg.add_session('test-456', transport))
625 >>>
626 >>> # Found in local cache
627 >>> found = asyncio.run(reg.get_session('test-456'))
628 >>> found is transport
629 True
630 >>>
631 >>> # Remove session
632 >>> asyncio.run(reg.remove_session('test-456'))
633 Transport disconnected
634 """
635 # Skip for none backend
636 if self._backend == "none":
637 return
639 async with self._lock:
640 self._sessions[session_id] = transport
642 if self._backend == "redis":
643 # Store session marker in Redis
644 if not self._redis:
645 logger.warning(f"Redis client not initialized, skipping distributed session tracking for {session_id}")
646 return
647 try:
648 await self._redis.setex(f"mcp:session:{session_id}", self._session_ttl, "1")
649 # Publish event to notify other workers
650 await self._redis.publish("mcp_session_events", orjson.dumps({"type": "add", "session_id": session_id, "timestamp": time.time()}))
651 except Exception as e:
652 logger.error(f"Redis error adding session {session_id}: {e}")
654 elif self._backend == "database":
655 # Store session in database
656 try:
658 def _db_add() -> None:
659 """Store session record in the database.
661 Creates a new SessionRecord entry in the database for tracking
662 distributed session state. Uses a fresh database connection from
663 the connection pool.
665 This inner function is designed to be run in a thread executor
666 to avoid blocking the async event loop during database I/O.
668 Raises:
669 Exception: Any database error is re-raised after rollback.
670 Common errors include duplicate session_id (unique constraint)
671 or database connection issues.
673 Examples:
674 >>> # This function is called internally by add_session()
675 >>> # When executed, it creates a database record:
676 >>> # SessionRecord(session_id='abc123', created_at=now())
677 """
678 db_session = next(get_db())
679 try:
680 session_record = SessionRecord(session_id=session_id)
681 db_session.add(session_record)
682 db_session.commit()
683 except Exception as ex:
684 db_session.rollback()
685 raise ex
686 finally:
687 db_session.close()
689 await asyncio.to_thread(_db_add)
690 except Exception as e:
691 logger.error(f"Database error adding session {session_id}: {e}")
693 logger.info(f"Added session: {session_id}")
695 def _session_owner_key(self, session_id: str) -> str:
696 """Return Redis key used to store session ownership.
698 Args:
699 session_id: Session identifier.
701 Returns:
702 Redis key for the session owner mapping.
703 """
704 return f"mcp:session_owner:{session_id}"
706 async def set_session_owner(self, session_id: str, owner_email: Optional[str]) -> None:
707 """Set or clear owner for a session.
709 Args:
710 session_id: Session identifier to update.
711 owner_email: Owner email to set. Passing ``None`` clears ownership.
713 Returns:
714 None.
715 """
716 # Skip for none backend
717 if self._backend == "none":
718 return
720 if owner_email:
721 self._session_owners[session_id] = owner_email
722 else:
723 self._session_owners.pop(session_id, None)
725 if self._backend == "redis":
726 if not self._redis:
727 logger.warning(f"Redis client not initialized, cannot set owner for session {session_id}")
728 return
729 try:
730 owner_key = self._session_owner_key(session_id)
731 if owner_email:
732 await self._redis.setex(owner_key, self._session_ttl, owner_email)
733 else:
734 await self._redis.delete(owner_key)
735 except Exception as e:
736 logger.error(f"Redis error setting owner for session {session_id}: {e}")
738 elif self._backend == "database":
739 try:
741 def _db_set_owner() -> None:
742 """Persist owner metadata for a session in the database backend.
744 Raises:
745 Exception: Propagates database write failures to the caller.
746 """
747 db_session = next(get_db())
748 try:
749 record = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
750 if not record:
751 record = SessionRecord(session_id=session_id, data=None)
752 db_session.add(record)
753 db_session.flush()
755 record_data: Dict[str, Any] = {}
756 if record.data:
757 try:
758 parsed = orjson.loads(record.data)
759 if isinstance(parsed, dict):
760 record_data = parsed
761 except Exception:
762 record_data = {}
764 if owner_email:
765 record_data["owner_email"] = owner_email
766 else:
767 record_data.pop("owner_email", None)
769 record.data = orjson.dumps(record_data).decode() if record_data else None
770 db_session.commit()
771 except Exception as ex:
772 db_session.rollback()
773 raise ex
774 finally:
775 db_session.close()
777 await asyncio.to_thread(_db_set_owner)
778 except Exception as e:
779 logger.error(f"Database error setting owner for session {session_id}: {e}")
    async def claim_session_owner(self, session_id: str, owner_email: str) -> Optional[str]:
        """Atomically claim ownership for a session and return the effective owner.

        This method provides compare-and-set semantics:
        - If a session owner already exists, return the existing owner.
        - If no owner exists, claim ownership for ``owner_email``.

        The redis path uses SET NX for atomicity; the database path uses an
        optimistic compare-and-set on the row's ``data`` column with up to
        three retries.

        Args:
            session_id: Session identifier to claim.
            owner_email: Requesting owner email.

        Returns:
            Effective owner email after the claim operation, or ``None`` if owner
            metadata could not be verified due to backend availability issues.
        """
        if self._backend == "none":
            # No tracking: the requester is trivially the owner.
            return owner_email

        # Fast local cache path.
        cached_owner = self._session_owners.get(session_id)
        if cached_owner:
            return cached_owner

        if self._backend == "memory":
            # Single-process: the lock alone gives us check-then-set atomicity.
            async with self._lock:
                existing_owner = self._session_owners.get(session_id)
                if existing_owner:
                    return existing_owner
                self._session_owners[session_id] = owner_email
                return owner_email

        if self._backend == "redis":
            if not self._redis:
                logger.warning("Redis client not initialized, session owner claim unavailable for %s", session_id)
                return None

            owner_key = self._session_owner_key(session_id)
            try:
                # SET NX: succeeds only when no owner key exists yet.
                claimed = await self._redis.set(owner_key, owner_email, ex=self._session_ttl, nx=True)
                if claimed:
                    self._session_owners[session_id] = owner_email
                    return owner_email

                # Somebody else holds the key - read the existing owner.
                owner_raw = await self._redis.get(owner_key)
                if owner_raw is not None:
                    existing_owner = owner_raw.decode() if isinstance(owner_raw, bytes) else str(owner_raw)
                    if existing_owner:
                        self._session_owners[session_id] = existing_owner
                        return existing_owner

                # Handle key expiry/race window by retrying once.
                claimed_retry = await self._redis.set(owner_key, owner_email, ex=self._session_ttl, nx=True)
                if claimed_retry:
                    self._session_owners[session_id] = owner_email
                    return owner_email

                owner_raw = await self._redis.get(owner_key)
                if owner_raw is None:
                    # Key vanished again between SET NX and GET - give up.
                    return None
                existing_owner = owner_raw.decode() if isinstance(owner_raw, bytes) else str(owner_raw)
                if existing_owner:
                    self._session_owners[session_id] = existing_owner
                    return existing_owner
                return None
            except Exception as e:
                logger.error("Redis error claiming owner for session %s: %s", session_id, e)
                return None

        if self._backend == "database":
            try:

                def _db_claim_owner() -> Optional[str]:
                    """Claim owner in DB with optimistic compare-and-set retries.

                    Returns:
                        Effective owner email or ``None`` when claim cannot be verified.
                    """
                    db_session = next(get_db())
                    try:
                        for _attempt in range(3):
                            record = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
                            if not record:
                                # No row yet - try to insert one carrying our claim.
                                owner_payload = {"owner_email": owner_email}
                                new_record = SessionRecord(session_id=session_id, data=orjson.dumps(owner_payload).decode())
                                db_session.add(new_record)
                                try:
                                    db_session.commit()
                                    return owner_email
                                except Exception:
                                    db_session.rollback()
                                    # Another writer may have inserted concurrently. Retry.
                                    continue

                            # Parse the row's JSON payload defensively.
                            record_data: Dict[str, Any] = {}
                            if record.data:
                                try:
                                    parsed = orjson.loads(record.data)
                                    if isinstance(parsed, dict):
                                        record_data = parsed
                                except Exception:
                                    record_data = {}

                            existing_owner = record_data.get("owner_email")
                            if isinstance(existing_owner, str) and existing_owner:
                                # Row already carries an owner - they win.
                                return existing_owner

                            # Compare-and-set: the UPDATE only matches when the
                            # data column still equals what we just read.
                            current_data = record.data
                            updated_data = dict(record_data)
                            updated_data["owner_email"] = owner_email
                            serialized = orjson.dumps(updated_data).decode()

                            update_query = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id)
                            if current_data is None:
                                update_query = update_query.filter(SessionRecord.data.is_(None))
                            else:
                                update_query = update_query.filter(SessionRecord.data == current_data)

                            updated_rows = update_query.update({"data": serialized}, synchronize_session=False)
                            if updated_rows == 1:
                                db_session.commit()
                                return owner_email

                            # CAS lost - roll back and let the loop retry.
                            db_session.rollback()

                        # All attempts exhausted without a verified claim.
                        return None
                    finally:
                        db_session.close()

                claimed_owner = await asyncio.to_thread(_db_claim_owner)
                if claimed_owner:
                    self._session_owners[session_id] = claimed_owner
                    return claimed_owner
            except Exception as e:
                logger.error("Database error claiming owner for session %s: %s", session_id, e)
                return None

        return None
919 async def get_session_owner(self, session_id: str) -> Optional[str]:
920 """Get owner email for a session.
922 Args:
923 session_id: Session identifier to resolve.
925 Returns:
926 Owner email when present, otherwise ``None``.
927 """
928 owner = self._session_owners.get(session_id)
929 if owner:
930 return owner
932 if self._backend == "redis":
933 if not self._redis:
934 return None
935 try:
936 owner_raw = await self._redis.get(self._session_owner_key(session_id))
937 if owner_raw is None:
938 return None
939 if isinstance(owner_raw, bytes):
940 owner = owner_raw.decode()
941 else:
942 owner = str(owner_raw)
943 if owner:
944 self._session_owners[session_id] = owner
945 return owner or None
946 except Exception as e:
947 logger.error(f"Redis error getting owner for session {session_id}: {e}")
948 return None
950 if self._backend == "database":
951 try:
953 def _db_get_owner() -> Optional[str]:
954 db_session = next(get_db())
955 try:
956 record = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
957 if not record or not record.data:
958 return None
959 try:
960 data = orjson.loads(record.data)
961 except Exception:
962 return None
963 if not isinstance(data, dict):
964 return None
965 owner_email = data.get("owner_email")
966 return owner_email if isinstance(owner_email, str) and owner_email else None
967 finally:
968 db_session.close()
970 owner = await asyncio.to_thread(_db_get_owner)
971 if owner:
972 self._session_owners[session_id] = owner
973 return owner
974 except Exception as e:
975 logger.error(f"Database error getting owner for session {session_id}: {e}")
976 return None
978 return None
980 async def session_exists(self, session_id: str) -> Optional[bool]:
981 """Return whether a session marker exists.
983 Args:
984 session_id: Session identifier to resolve.
986 Returns:
987 ``True`` when session exists, ``False`` when session does not exist,
988 and ``None`` when existence cannot be verified due to backend errors.
989 """
990 if self._backend == "none":
991 return False
993 async with self._lock:
994 if session_id in self._sessions:
995 return True
997 if self._backend == "memory":
998 return False
1000 if self._backend == "redis":
1001 if not self._redis:
1002 return None
1003 try:
1004 return bool(await self._redis.exists(f"mcp:session:{session_id}"))
1005 except Exception as e:
1006 logger.error("Redis error checking existence for session %s: %s", session_id, e)
1007 return None
1009 if self._backend == "database":
1010 try:
1012 def _db_exists() -> bool:
1013 """Check whether a session record exists in the database backend.
1015 Returns:
1016 ``True`` when a matching session record exists.
1017 """
1018 db_session = next(get_db())
1019 try:
1020 record = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
1021 return record is not None
1022 finally:
1023 db_session.close()
1025 return await asyncio.to_thread(_db_exists)
1026 except Exception as e:
1027 logger.error("Database error checking existence for session %s: %s", session_id, e)
1028 return None
1030 return False
1032 async def get_session(self, session_id: str) -> Any:
1033 """Get session transport by ID.
1035 First checks the local cache for the transport object. If not found locally
1036 but using a distributed backend, checks if the session exists on another
1037 worker.
1039 Args:
1040 session_id: Session identifier to look up.
1042 Returns:
1043 SSETransport object if found locally, None if not found or exists
1044 on another worker.
1046 Examples:
1047 >>> import asyncio
1048 >>> from mcpgateway.cache.session_registry import SessionRegistry
1049 >>>
1050 >>> class MockTransport:
1051 ... pass
1052 >>>
1053 >>> reg = SessionRegistry()
1054 >>> transport = MockTransport()
1055 >>> asyncio.run(reg.add_session('test-456', transport))
1056 >>>
1057 >>> # Found in local cache
1058 >>> found = asyncio.run(reg.get_session('test-456'))
1059 >>> found is transport
1060 True
1061 >>>
1062 >>> # Not found
1063 >>> asyncio.run(reg.get_session('nonexistent')) is None
1064 True
1065 """
1066 # Skip for none backend
1067 if self._backend == "none":
1068 return None
1070 # First check local cache
1071 async with self._lock:
1072 transport = self._sessions.get(session_id)
1073 if transport:
1074 logger.info(f"Session {session_id} exists in local cache")
1075 return transport
1077 # If not in local cache, check if it exists in shared backend
1078 if self._backend == "redis":
1079 if not self._redis:
1080 return None
1081 try:
1082 exists = await self._redis.exists(f"mcp:session:{session_id}")
1083 session_exists = bool(exists)
1084 if session_exists:
1085 logger.info(f"Session {session_id} exists in Redis but not in local cache")
1086 return None # We don't have the transport locally
1087 except Exception as e:
1088 logger.error(f"Redis error checking session {session_id}: {e}")
1089 return None
1091 elif self._backend == "database":
1092 try:
1094 def _db_check() -> bool:
1095 """Check if a session exists in the database.
1097 Queries the SessionRecord table to determine if a session with
1098 the given session_id exists. This is used when the session is not
1099 found in the local cache to check if it exists on another worker.
1101 This inner function is designed to be run in a thread executor
1102 to avoid blocking the async event loop during database queries.
1104 Returns:
1105 bool: True if the session exists in the database, False otherwise.
1107 Examples:
1108 >>> # This function is called internally by get_session()
1109 >>> # Returns True if SessionRecord with session_id exists
1110 >>> # Returns False if no matching record found
1111 """
1112 db_session = next(get_db())
1113 try:
1114 record = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
1115 return record is not None
1116 finally:
1117 db_session.close()
1119 exists = await asyncio.to_thread(_db_check)
1120 if exists:
1121 logger.info(f"Session {session_id} exists in database but not in local cache")
1122 return None
1123 except Exception as e:
1124 logger.error(f"Database error checking session {session_id}: {e}")
1125 return None
1127 return None
async def remove_session(self, session_id: str) -> None:
    """Remove a session from the registry.

    Removes the session from both local cache and distributed backend.
    If a transport is found locally, it will be disconnected before removal.
    For distributed backends, notifies other workers about the removal.

    Ordering is deliberate and load-bearing: the session is marked as
    closing and its respond task cancelled *before* the transport is popped
    and disconnected, and the shared-backend record is deleted last. The
    method is effectively idempotent - removing an unknown session only
    performs backend deletes that match nothing.

    Args:
        session_id: Session identifier to remove.

    Examples:
        >>> import asyncio
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>>
        >>> class MockTransport:
        ...     async def disconnect(self):
        ...         print(f"Transport disconnected")
        ...     async def is_connected(self):
        ...         return True
        >>>
        >>> reg = SessionRegistry()
        >>> transport = MockTransport()
        >>> asyncio.run(reg.add_session('remove-test', transport))
        >>> asyncio.run(reg.remove_session('remove-test'))
        Transport disconnected
        >>>
        >>> # Session no longer exists
        >>> asyncio.run(reg.get_session('remove-test')) is None
        True
    """
    # Skip for none backend
    if self._backend == "none":
        return

    # Mark session as closing FIRST so respond loop can exit early
    # This allows the loop to exit without waiting for cancellation to complete
    self._closing_sessions.add(session_id)

    try:
        # CRITICAL: Cancel respond task before any cleanup
        # This prevents orphaned tasks that cause CPU spin loops
        await self._cancel_respond_task(session_id)

        # Clean up local transport (pop under the lock; disconnect happens
        # later, outside the lock, to avoid holding it across awaits)
        transport = None
        async with self._lock:
            if session_id in self._sessions:
                transport = self._sessions.pop(session_id)
            self._session_owners.pop(session_id, None)
            # Also clean up client capabilities
            if session_id in self._client_capabilities:
                self._client_capabilities.pop(session_id)
                logger.debug(f"Removed capabilities for session {session_id}")
    finally:
        # Always remove from closing set
        self._closing_sessions.discard(session_id)

    # Disconnect transport if found (best-effort: errors are logged, not raised)
    if transport:
        try:
            await transport.disconnect()
        except Exception as e:
            logger.error(f"Error disconnecting transport for session {session_id}: {e}")

    # Remove from shared backend
    if self._backend == "redis":
        if not self._redis:
            return
        try:
            # Drop both the session marker and its owner mapping
            await self._redis.delete(f"mcp:session:{session_id}")
            await self._redis.delete(self._session_owner_key(session_id))
            # Notify other workers
            await self._redis.publish("mcp_session_events", orjson.dumps({"type": "remove", "session_id": session_id, "timestamp": time.time()}))
        except Exception as e:
            logger.error(f"Redis error removing session {session_id}: {e}")

    elif self._backend == "database":
        try:

            def _db_remove() -> None:
                """Delete session record from the database.

                Removes the SessionRecord entry with the specified session_id
                from the database. This is called when a session is being
                terminated or has expired.

                This inner function is designed to be run in a thread executor
                to avoid blocking the async event loop during database operations.

                Raises:
                    Exception: Any database error is re-raised after rollback.
                        This includes connection errors or constraint violations.

                Examples:
                    >>> # This function is called internally by remove_session()
                    >>> # Deletes the SessionRecord where session_id matches
                    >>> # No error if session_id doesn't exist (idempotent)
                """
                db_session = next(get_db())
                try:
                    db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).delete()
                    db_session.commit()
                except Exception as ex:
                    db_session.rollback()
                    raise ex
                finally:
                    db_session.close()

            await asyncio.to_thread(_db_remove)
        except Exception as e:
            logger.error(f"Database error removing session {session_id}: {e}")

    logger.info(f"Removed session: {session_id}")
async def broadcast(self, session_id: str, message: Dict[str, Any]) -> None:
    """Broadcast a message to a session via the configured backend.

    Backend behavior:

    - **memory**: stores the enveloped message in a single local slot
    - **redis**: publishes the envelope on the session's pubsub channel
    - **database**: inserts a message row for the owning worker to poll
    - **none**: no operation

    This is the inter-process communication path for distributed deployments.

    Args:
        session_id: Target session identifier.
        message: Message to broadcast. Can be a dict, list, or any JSON-serializable object.

    Examples:
        >>> import asyncio
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>>
        >>> reg = SessionRegistry(backend='memory')
        >>> message = {'method': 'tools/list', 'id': 1}
        >>> asyncio.run(reg.broadcast('session-789', message))
        >>>
        >>> # Message stored for memory backend
        >>> reg._session_message is not None
        True
        >>> reg._session_message['session_id']
        'session-789'
        >>> orjson.loads(reg._session_message['message'])['message'] == message
        True
    """
    # The "none" backend drops everything.
    if self._backend == "none":
        return

    def _wrap(msg: Any) -> str:
        """Encode *msg* in the standard broadcast envelope as a JSON string.

        Args:
            msg: Message to wrap in the envelope.

        Returns:
            JSON-encoded string containing type, message, and timestamp.
        """
        envelope = {"type": "message", "message": msg, "timestamp": time.time()}
        return orjson.dumps(envelope).decode()

    if self._backend == "memory":
        # Single-slot storage: the latest broadcast overwrites any prior one.
        self._session_message = {"session_id": session_id, "message": _wrap(message)}

    elif self._backend == "redis":
        if not self._redis:
            logger.warning(f"Redis client not initialized, cannot broadcast to {session_id}")
            return
        try:
            # Encode exactly once; the message stays in its original form
            # inside the envelope rather than being pre-serialized.
            wire_payload = orjson.dumps({"type": "message", "message": message, "timestamp": time.time()})
            await self._redis.publish(session_id, wire_payload)
        except Exception as e:
            logger.error(f"Redis error during broadcast: {e}")

    elif self._backend == "database":
        try:
            serialized = _wrap(message)

            def _persist() -> None:
                """Insert the serialized message row for cross-worker delivery.

                Runs in a thread executor so the database write does not
                block the async event loop.

                Raises:
                    Exception: Any database error is re-raised after rollback.
                """
                db_session = next(get_db())
                try:
                    db_session.add(SessionMessageRecord(session_id=session_id, message=serialized))
                    db_session.commit()
                except Exception as ex:
                    db_session.rollback()
                    raise ex
                finally:
                    db_session.close()

            await asyncio.to_thread(_persist)
        except Exception as e:
            logger.error(f"Database error during broadcast: {e}")
async def _register_session_mapping(self, session_id: str, message: Dict[str, Any], user_email: Optional[str] = None) -> None:
    """Pre-register downstream-to-upstream session affinity for tool calls.

    Runs on the worker that owns the SSE session so that the mapping from
    the downstream session ID to the upstream MCP session pool key exists
    before the call is routed. Only ``tools/call`` messages are handled -
    other methods carry no state that needs affinity. Failures are logged
    and never propagate, so a broadcast cannot be broken by this step.

    Args:
        session_id: The downstream SSE session ID.
        message: The MCP protocol message being broadcast.
        user_email: Optional user email for session isolation.
    """
    # Feature-gated: do nothing when session affinity is switched off.
    if not settings.mcpgateway_session_affinity_enabled:
        return

    # Only tools/call needs affinity; everything else is stateless.
    if message.get("method") != "tools/call":
        return

    tool_name = message.get("params", {}).get("name")
    if not tool_name:
        return

    try:
        # First-Party
        from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache  # pylint: disable=import-outside-toplevel

        tool_info = await tool_lookup_cache.get(tool_name)
        if not tool_info:
            logger.debug(f"Tool {tool_name} not found in cache, skipping session mapping registration")
            return

        # Pull the gateway coordinates out of the cached tool entry.
        gateway = tool_info.get("gateway", {})
        gateway_url = gateway.get("url")
        gateway_id = gateway.get("id")
        transport = gateway.get("transport")
        if not (gateway_url and gateway_id and transport):
            logger.debug(f"Incomplete gateway info for tool {tool_name}, skipping session mapping registration")
            return

        # First-Party
        from mcpgateway.services.mcp_session_pool import get_mcp_session_pool  # pylint: disable=import-outside-toplevel

        await get_mcp_session_pool().register_session_mapping(
            session_id,
            gateway_url,
            gateway_id,
            transport,
            user_email,
        )

        logger.debug(f"Registered session mapping for session {session_id[:8]}... -> {gateway_url} (tool: {tool_name})")
    except Exception as e:
        # Never fail the broadcast because affinity registration failed.
        logger.warning(f"Failed to register session mapping for {session_id[:8]}...: {e}")
1423 async def get_all_session_ids(self) -> list[str]:
1424 """Return a snapshot list of all known local session IDs.
1426 Returns:
1427 list[str]: A snapshot list of currently known local session IDs.
1428 """
1429 async with self._lock:
1430 return list(self._sessions.keys())
1432 def get_session_sync(self, session_id: str) -> Any:
1433 """Get session synchronously from local cache only.
1435 This is a non-blocking method that only checks the local cache,
1436 not the distributed backend. Use this when you need quick access
1437 and know the session should be local.
1439 Args:
1440 session_id: Session identifier to look up.
1442 Returns:
1443 SSETransport object if found in local cache, None otherwise.
1445 Examples:
1446 >>> from mcpgateway.cache.session_registry import SessionRegistry
1447 >>> import asyncio
1448 >>>
1449 >>> class MockTransport:
1450 ... pass
1451 >>>
1452 >>> reg = SessionRegistry()
1453 >>> transport = MockTransport()
1454 >>> asyncio.run(reg.add_session('sync-test', transport))
1455 >>>
1456 >>> # Synchronous lookup
1457 >>> found = reg.get_session_sync('sync-test')
1458 >>> found is transport
1459 True
1460 >>>
1461 >>> # Not found
1462 >>> reg.get_session_sync('nonexistent') is None
1463 True
1464 """
1465 # Skip for none backend
1466 if self._backend == "none":
1467 return None
1469 return self._sessions.get(session_id)
async def respond(
    self,
    server_id: Optional[str],
    user: Dict[str, Any],
    session_id: str,
) -> None:
    """Process and respond to broadcast messages for a session.

    This method listens for messages directed to the specified session and
    generates appropriate responses. The listening mechanism depends on the backend:

    - **memory**: Checks the temporary message storage (single shared slot)
    - **redis**: Subscribes to Redis pubsub channel, polling with a timeout
    - **database**: Polls database for new messages with adaptive backoff

    When a message is received and the transport exists locally, it processes
    the message and sends the response through the transport. For the redis
    and database backends this method runs as a long-lived task and removes
    its own entry from ``self._respond_tasks`` when it exits.

    Args:
        server_id: Optional server identifier for scoped operations.
        user: User information including authentication token.
        session_id: Session identifier to respond for.

    Raises:
        asyncio.CancelledError: When the respond task is cancelled (e.g., on session removal).

    Examples:
        >>> import asyncio
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>>
        >>> # This method is typically called internally by the SSE handler
        >>> reg = SessionRegistry()
        >>> user = {'token': 'test-token'}
        >>> # asyncio.run(reg.respond(None, user, 'session-id'))
    """

    if self._backend == "none":
        pass

    elif self._backend == "memory":
        # Memory backend: consume the single stored message, if any.
        transport = self.get_session_sync(session_id)
        if transport and self._session_message:
            message_json = self._session_message.get("message")
            if message_json:
                data = orjson.loads(message_json)
                # Unwrap the broadcast envelope when present; otherwise treat
                # the decoded payload as the message itself.
                if isinstance(data, dict) and "message" in data:
                    message = data["message"]
                else:
                    message = data
                await self.generate_response(message=message, transport=transport, server_id=server_id, user=user)
            else:
                logger.warning(f"Session message stored but message content is None for session {session_id}")

    elif self._backend == "redis":
        if not self._redis:
            logger.warning(f"Redis client not initialized, cannot respond to {session_id}")
            return
        pubsub = self._redis.pubsub()
        await pubsub.subscribe(session_id)

        # Use timeout-based polling instead of infinite listen() to allow exit checks
        # This is critical for allowing cancellation to work (Finding 2)
        poll_timeout = 1.0  # Check every second if session still exists

        try:
            while True:
                # Check if session still exists or is closing - exit early
                if session_id not in self._sessions or session_id in self._closing_sessions:
                    logger.info(f"Session {session_id} removed or closing, exiting Redis respond loop")
                    break

                # Use get_message with timeout instead of blocking listen()
                try:
                    msg = await asyncio.wait_for(
                        pubsub.get_message(ignore_subscribe_messages=True, timeout=poll_timeout), timeout=poll_timeout + 0.5  # Slightly longer to account for Redis timeout
                    )
                except asyncio.TimeoutError:
                    # No message, loop back to check session existence
                    continue

                if msg is None:
                    # CRITICAL: Sleep to prevent tight loop when get_message returns immediately
                    # This can happen in certain Redis states after disconnects
                    await asyncio.sleep(0.1)
                    continue
                if msg["type"] != "message":
                    # Sleep on non-message types to prevent spin in edge cases
                    await asyncio.sleep(0.1)
                    continue

                data = orjson.loads(msg["data"])
                message = data.get("message", {})
                transport = self.get_session_sync(session_id)
                if transport:
                    await self.generate_response(message=message, transport=transport, server_id=server_id, user=user)
        except asyncio.CancelledError:
            logger.info(f"PubSub listener for session {session_id} cancelled")
            raise  # Re-raise to properly complete cancellation
        except Exception as e:
            logger.error(f"PubSub listener error for session {session_id}: {e}")
        finally:
            # Pubsub cleanup first - use timeouts to prevent blocking
            cleanup_timeout = settings.mcp_session_pool_cleanup_timeout
            try:
                await asyncio.wait_for(pubsub.unsubscribe(session_id), timeout=cleanup_timeout)
            except asyncio.TimeoutError:
                logger.debug(f"Pubsub unsubscribe timed out for session {session_id}")
            except Exception as e:
                logger.debug(f"Error unsubscribing pubsub for session {session_id}: {e}")
            try:
                # Prefer aclose() (newer redis-py); fall back to close().
                try:
                    await asyncio.wait_for(pubsub.aclose(), timeout=cleanup_timeout)
                except AttributeError:
                    await asyncio.wait_for(pubsub.close(), timeout=cleanup_timeout)
            except asyncio.TimeoutError:
                logger.debug(f"Pubsub close timed out for session {session_id}")
            except Exception as e:
                logger.debug(f"Error closing pubsub for session {session_id}: {e}")
            logger.info(f"Cleaned up pubsub for session {session_id}")
            # Clean up task reference LAST (idempotent - may already be removed by _cancel_respond_task)
            self._respond_tasks.pop(session_id, None)

    elif self._backend == "database":

        def _db_read_session_and_message(
            session_id: str,
        ) -> tuple[SessionRecord | None, SessionMessageRecord | None]:
            """
            Check whether a session exists and retrieve its next pending message
            in a single database query.

            This function performs a LEFT OUTER JOIN between SessionRecord and
            SessionMessageRecord to determine:

            - Whether the session still exists
            - Whether there is a pending message for the session (FIFO order)

            It is used by the database-backed message polling loop to reduce
            database load by collapsing multiple reads into a single query.

            Messages are returned in FIFO order based on the message primary key.

            This function is designed to be run in a thread executor to avoid
            blocking the async event loop during database access.

            Args:
                session_id: The session identifier to look up.

            Returns:
                Tuple[SessionRecord | None, SessionMessageRecord | None]:

                - (None, None)
                    The session does not exist.

                - (SessionRecord, None)
                    The session exists but has no pending messages.

                - (SessionRecord, SessionMessageRecord)
                    The session exists and has a pending message.

            Raises:
                Exception: Any database error is re-raised after rollback.

            Examples:
                >>> # This function is called internally by message_check_loop()
                >>> # Session exists and has a pending message
                >>> # Returns (SessionRecord, SessionMessageRecord)

                >>> # Session exists but has no pending messages
                >>> # Returns (SessionRecord, None)

                >>> # Session has been removed
                >>> # Returns (None, None)
            """
            db_session = next(get_db())
            try:
                result = (
                    db_session.query(SessionRecord, SessionMessageRecord)
                    .outerjoin(
                        SessionMessageRecord,
                        SessionMessageRecord.session_id == SessionRecord.session_id,
                    )
                    .filter(SessionRecord.session_id == session_id)
                    .order_by(SessionMessageRecord.id.asc())
                    .first()
                )
                if not result:
                    return None, None
                session, message = result
                return session, message
            except Exception as ex:
                db_session.rollback()
                raise ex
            finally:
                db_session.close()

        def _db_remove(session_id: str, message: str) -> None:
            """Remove processed message from the database.

            Deletes a specific message record after it has been successfully
            processed and sent to the transport. This prevents duplicate
            message delivery.

            This inner function is designed to be run in a thread executor
            to avoid blocking the async event loop during database deletes.

            Args:
                session_id: The session identifier the message belongs to.
                message: The exact message content to remove (must match exactly).

            Raises:
                Exception: Any database error is re-raised after rollback.

            Examples:
                >>> # This function is called internally after message processing
                >>> # Deletes the specific SessionMessageRecord entry
                >>> # Log: "Removed message from mcp_messages table"
            """
            db_session = next(get_db())
            try:
                db_session.query(SessionMessageRecord).filter(SessionMessageRecord.session_id == session_id).filter(SessionMessageRecord.message == message).delete()
                db_session.commit()
                logger.info("Removed message from mcp_messages table")
            except Exception as ex:
                db_session.rollback()
                raise ex
            finally:
                db_session.close()

        async def message_check_loop(session_id: str) -> None:
            """
            Background task that polls the database for messages belonging to a session
            using adaptive polling with exponential backoff.

            The loop continues until the session is removed from the database.

            Behavior:
                - Starts with a fast polling interval for low-latency message delivery.
                - When no message is found, the polling interval increases exponentially
                  (up to a configured maximum) to reduce database load.
                - When a message is received, the polling interval is immediately reset
                  to the fast interval.
                - The loop exits as soon as the session no longer exists.

            Polling rules:
                - Message found → process message, reset polling interval.
                - No message → increase polling interval (backoff).
                - Session gone → stop polling immediately.

            Args:
                session_id (str): Unique identifier of the session to monitor.

            Raises:
                asyncio.CancelledError: When the polling loop is cancelled.

            Examples
            --------
            Adaptive backoff when no messages are present:

            >>> poll_interval = 0.1
            >>> backoff_factor = 1.5
            >>> max_interval = 5.0
            >>> poll_interval = min(poll_interval * backoff_factor, max_interval)
            >>> poll_interval
            0.15000000000000002

            Backoff continues until the maximum interval is reached:

            >>> poll_interval = 4.0
            >>> poll_interval = min(poll_interval * 1.5, 5.0)
            >>> poll_interval
            5.0

            Polling interval resets immediately when a message arrives:

            >>> poll_interval = 2.0
            >>> poll_interval = 0.1
            >>> poll_interval
            0.1

            Session termination stops polling:

            >>> session_exists = False
            >>> if not session_exists:
            ...     "polling stopped"
            'polling stopped'
            """

            poll_interval = settings.poll_interval  # start fast
            max_interval = settings.max_interval  # cap at configured maximum
            backoff_factor = settings.backoff_factor
            try:
                while True:
                    # Check if session is closing before querying DB
                    if session_id in self._closing_sessions:
                        logger.debug("Session %s closing, stopping poll loop early", session_id)
                        break

                    session, record = await asyncio.to_thread(_db_read_session_and_message, session_id)

                    # session gone → stop polling
                    if not session:
                        logger.debug("Session %s no longer exists, stopping poll loop", session_id)
                        break

                    if record:
                        poll_interval = settings.poll_interval  # reset on activity

                        data = orjson.loads(record.message)
                        # Unwrap the broadcast envelope when present; otherwise
                        # treat the decoded payload as the message itself.
                        if isinstance(data, dict) and "message" in data:
                            message = data["message"]
                        else:
                            message = data

                        transport = self.get_session_sync(session_id)
                        if transport:
                            logger.info("Ready to respond")
                            await self.generate_response(
                                message=message,
                                transport=transport,
                                server_id=server_id,
                                user=user,
                            )

                        # Delete the row only after it has been processed, so a
                        # crash before this point cannot lose the message.
                        await asyncio.to_thread(_db_remove, session_id, record.message)
                    else:
                        # no message → backoff
                        # update polling interval with backoff factor
                        poll_interval = min(poll_interval * backoff_factor, max_interval)

                    await asyncio.sleep(poll_interval)
            except asyncio.CancelledError:
                logger.info(f"Message check loop cancelled for session {session_id}")
                raise  # Re-raise to properly complete cancellation
            except Exception as e:
                logger.error(f"Message check loop error for session {session_id}: {e}")

        # CRITICAL: Await instead of fire-and-forget
        # This ensures CancelledError propagates from outer respond() task to inner loop
        # The outer task (registered from main.py) now runs until message_check_loop exits
        try:
            await message_check_loop(session_id)
        except asyncio.CancelledError:
            logger.info(f"Database respond cancelled for session {session_id}")
            raise
        finally:
            # Clean up task reference on ANY exit (normal, cancelled, or error)
            # Prevents stale done tasks from accumulating in _respond_tasks
            self._respond_tasks.pop(session_id, None)
1821 async def _refresh_redis_sessions(self) -> None:
1822 """Refresh TTLs for Redis sessions and clean up disconnected sessions.
1824 This internal method is used by the Redis backend to maintain session state.
1825 It checks all local sessions, refreshes TTLs for connected sessions, and
1826 removes disconnected ones.
1827 """
1828 if not self._redis:
1829 return
1830 try:
1831 # Check all local sessions
1832 local_transports = {}
1833 async with self._lock:
1834 local_transports = self._sessions.copy()
1836 for session_id, transport in local_transports.items():
1837 try:
1838 if await transport.is_connected():
1839 # Refresh TTL in Redis
1840 await self._redis.expire(f"mcp:session:{session_id}", self._session_ttl)
1841 else:
1842 # Remove disconnected session
1843 await self.remove_session(session_id)
1844 except Exception as e:
1845 logger.error(f"Error refreshing session {session_id}: {e}")
1847 except Exception as e:
1848 logger.error(f"Error in Redis session refresh: {e}")
async def _db_cleanup_task(self) -> None:
    """Background loop that purges expired database sessions.

    Runs forever: every 5 minutes it deletes SessionRecord rows whose
    ``last_accessed`` timestamp is older than the session TTL, then
    reconciles local sessions against the database via
    ``_cleanup_database_sessions``. Unexpected errors are logged and the
    loop backs off to a 10-minute sleep before retrying.

    Raises:
        asyncio.CancelledError: If the task is cancelled during shutdown.
    """
    logger.info("Starting database cleanup task")
    while True:
        try:

            def _purge_expired() -> int:
                """Delete session rows older than the TTL.

                Uses a Python-side datetime for the cutoff so the expiry
                calculation is database-agnostic. Runs in a thread executor
                to keep the bulk delete off the event loop.

                Returns:
                    int: Number of expired session records deleted.

                Raises:
                    Exception: Any database error is re-raised after rollback.
                """
                db_session = next(get_db())
                try:
                    cutoff = datetime.now(timezone.utc) - timedelta(seconds=self._session_ttl)
                    removed = db_session.query(SessionRecord).filter(SessionRecord.last_accessed < cutoff).delete()
                    db_session.commit()
                    return removed
                except Exception as ex:
                    db_session.rollback()
                    raise ex
                finally:
                    db_session.close()

            deleted = await asyncio.to_thread(_purge_expired)
            if deleted > 0:
                logger.info(f"Cleaned up {deleted} expired database sessions")

            # Drop local sessions whose database rows have disappeared.
            await self._cleanup_database_sessions()

            await asyncio.sleep(300)  # Run every 5 minutes
        except asyncio.CancelledError:
            logger.info("Database cleanup task cancelled")
            raise
        except Exception as e:
            logger.error(f"Error in database cleanup task: {e}")
            await asyncio.sleep(600)  # Sleep longer on error
def _refresh_session_db(self, session_id: str) -> bool:
    """Touch a session's ``last_accessed`` timestamp in the database.

    Keeps an active session from being reaped by the expiry cleanup.
    Called periodically (via a worker thread) for every local session
    whose transport is still connected.

    Args:
        session_id: Identifier of the session to refresh.

    Returns:
        bool: True when the row existed and was updated, False otherwise.

    Raises:
        Exception: Any database error is re-raised after rollback.
    """
    db = next(get_db())
    try:
        record = db.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
        if record is None:
            return False
        record.last_accessed = func.now()  # pylint: disable=not-callable
        db.commit()
        return True
    except Exception as ex:
        db.rollback()
        raise ex
    finally:
        db.close()
async def _cleanup_database_sessions(self, max_concurrent: int = 20) -> None:
    """Reconcile local sessions with the database, with bounded parallelism.

    A first pass checks transport connectivity (cheap, no DB work) and
    drops sessions whose transport has disconnected. A second pass then
    refreshes the DB timestamp of every still-connected session in
    parallel, gated by a semaphore so at most ``max_concurrent`` database
    operations run at once; sessions whose row has vanished from the
    database are removed locally as well.

    Args:
        max_concurrent: Upper bound on simultaneous DB refresh operations.
            Defaults to 20 to balance parallelism with resource usage.
    """
    async with self._lock:
        snapshot = self._sessions.copy()

    # Pass 1: connectivity check.
    live_ids: list[str] = []
    for sid, transport in snapshot.items():
        try:
            if await transport.is_connected():
                live_ids.append(sid)
            else:
                await self.remove_session(sid)
        except Exception as e:
            # Transient probe failures are logged but never evict a session.
            logger.error(f"Error checking connection for session {sid}: {e}")

    if not live_ids:
        return

    # Pass 2: parallel timestamp refresh, bounded by a semaphore.
    gate = asyncio.Semaphore(max_concurrent)

    async def _refresh_one(sid: str) -> bool:
        """Refresh one session while holding a semaphore slot.

        Args:
            sid: The session ID to refresh.

        Returns:
            True if refresh succeeded, False otherwise.
        """
        async with gate:
            return await asyncio.to_thread(self._refresh_session_db, sid)

    outcomes = await asyncio.gather(*(_refresh_one(sid) for sid in live_ids), return_exceptions=True)

    for sid, outcome in zip(live_ids, outcomes):
        try:
            if isinstance(outcome, Exception):
                # Transient DB errors: log only, keep the session.
                logger.error(f"Error refreshing session {sid}: {outcome}")
            elif not outcome:
                # Row vanished from the database — mirror that locally.
                await self.remove_session(sid)
        except Exception as e:
            logger.error(f"Error processing refresh result for session {sid}: {e}")
async def _memory_cleanup_task(self) -> None:
    """Periodic janitor for the memory backend.

    Once a minute, checks every local session and removes any whose
    transport is no longer connected, preventing disconnected transport
    objects from leaking. A session whose connectivity probe raises is
    treated as dead and removed as well. On unexpected errors the loop
    backs off to a five-minute sleep before retrying.

    Raises:
        asyncio.CancelledError: When the task is cancelled at shutdown.
    """
    logger.info("Starting memory cleanup task")
    while True:
        try:
            # Snapshot under the lock; iterate outside it.
            async with self._lock:
                snapshot = self._sessions.copy()

            for sid, transport in snapshot.items():
                try:
                    if not await transport.is_connected():
                        await self.remove_session(sid)
                except Exception as e:
                    # A failing probe counts as disconnected.
                    logger.error(f"Error checking session {sid}: {e}")
                    await self.remove_session(sid)

            await asyncio.sleep(60)  # Run every minute
        except asyncio.CancelledError:
            logger.info("Memory cleanup task cancelled")
            raise
        except Exception as e:
            logger.error(f"Error in memory cleanup task: {e}")
            await asyncio.sleep(300)  # Back off on error
def _get_oauth_experimental_config(self, server_id: str) -> Optional[Dict[str, Dict[str, Any]]]:
    """Look up a server's OAuth configuration (synchronous; run in a thread).

    Reads the server row from the database and, when OAuth is enabled,
    builds the RFC 9728-safe subset of its OAuth config suitable for
    advertising in MCP capabilities. Secrets are never copied over.

    Args:
        server_id: The server whose OAuth configuration to look up.

    Returns:
        ``{"oauth": {...}}`` containing the safe fields, or None when
        OAuth is not configured or no authorization server is known.
    """
    # First-Party
    from mcpgateway.db import Server as DbServer  # pylint: disable=import-outside-toplevel
    from mcpgateway.db import SessionLocal  # pylint: disable=import-outside-toplevel

    db = SessionLocal()
    try:
        server = db.get(DbServer, server_id)
        if not (server and getattr(server, "oauth_enabled", False) and getattr(server, "oauth_config", None)):
            return None

        cfg = server.oauth_config
        # Only RFC 9728-safe fields are exposed — never secrets.
        safe: Dict[str, Any] = {}

        # Authorization servers: prefer the plural key, fall back to singular.
        if cfg.get("authorization_servers"):
            safe["authorization_servers"] = cfg["authorization_servers"]
        elif cfg.get("authorization_server"):
            safe["authorization_servers"] = [cfg["authorization_server"]]

        supported = cfg.get("scopes_supported") or cfg.get("scopes")
        if supported:
            safe["scopes_supported"] = supported

        safe["bearer_methods_supported"] = cfg.get("bearer_methods_supported", ["header"])

        if not safe.get("authorization_servers"):
            return None
        logger.debug(f"Advertising OAuth capability for server {server_id}")
        return {"oauth": safe}
    finally:
        db.close()
# Handle initialize logic
async def handle_initialize_logic(self, body: Dict[str, Any], session_id: Optional[str] = None, server_id: Optional[str] = None) -> InitializeResult:
    """Handle the MCP protocol initialize handshake.

    Validates the requested protocol version, optionally records the
    client's declared capabilities for the session, and returns this
    gateway's capabilities and server info. When ``server_id`` is given,
    the server's OAuth configuration (RFC 9728-safe fields only) is
    advertised under experimental capabilities.

    Args:
        body: Request body; must carry 'protocol_version' (or
            'protocolVersion') and may carry 'capabilities'.
        session_id: Optional session to associate client capabilities with.
        server_id: Optional server whose OAuth config should be advertised.

    Returns:
        InitializeResult with protocol version, capabilities, and server info.

    Raises:
        HTTPException: 400 Bad Request (MCP error code -32002) when the
            protocol version is missing.

    Examples:
        >>> import asyncio
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>>
        >>> reg = SessionRegistry()
        >>> result = asyncio.run(reg.handle_initialize_logic({'protocol_version': '2025-06-18'}))
        >>> result.protocol_version
        '2025-06-18'
        >>> result.server_info.name
        'ContextForge'
        >>>
        >>> # Missing protocol version
        >>> try:
        ...     asyncio.run(reg.handle_initialize_logic({}))
        ... except HTTPException as e:
        ...     e.status_code
        400
    """
    requested_version = body.get("protocol_version") or body.get("protocolVersion")
    declared_capabilities = body.get("capabilities", {})

    if not requested_version:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing protocol version",
            headers={"MCP-Error-Code": "-32002"},
        )

    if requested_version != settings.protocol_version:
        logger.warning(f"Using non default protocol version: {requested_version}")

    # Remember the client's declared capabilities for this session.
    if session_id and declared_capabilities:
        await self.store_client_capabilities(session_id, declared_capabilities)
        logger.debug(f"Stored capabilities for session {session_id}: {declared_capabilities}")

    # Experimental capabilities: OAuth (RFC 9728) when configured for this server.
    experimental: Optional[Dict[str, Dict[str, Any]]] = None
    if server_id:
        try:
            # The OAuth lookup is a synchronous DB query; run it off the event loop.
            experimental = await asyncio.to_thread(self._get_oauth_experimental_config, server_id)
        except Exception as e:
            logger.warning(f"Failed to query OAuth config for server {server_id}: {e}")

    return InitializeResult(
        protocolVersion=requested_version,
        capabilities=ServerCapabilities(
            prompts={"listChanged": True},
            resources={"subscribe": True, "listChanged": True},
            tools={"listChanged": True},
            logging={},
            completions={},  # Advertise completions capability per MCP spec
            experimental=experimental,  # OAuth capability when configured
        ),
        serverInfo=Implementation(name=settings.app_name, version=__version__),
        instructions=("ContextForge providing federated tools, resources and prompts. Use /admin interface for configuration."),
    )
async def store_client_capabilities(self, session_id: str, capabilities: Dict[str, Any]) -> None:
    """Record the capabilities a client declared for a session.

    Args:
        session_id: The session the capabilities belong to.
        capabilities: Capabilities dict from the client's initialize request.
    """
    async with self._lock:
        self._client_capabilities[session_id] = capabilities
        logger.debug(f"Stored capabilities for session {session_id}")
async def get_client_capabilities(self, session_id: str) -> Optional[Dict[str, Any]]:
    """Look up the capabilities recorded for a session.

    Args:
        session_id: The session to look up.

    Returns:
        The stored capabilities dict, or None when the session is unknown.
    """
    async with self._lock:
        stored = self._client_capabilities.get(session_id)
    return stored
async def has_elicitation_capability(self, session_id: str) -> bool:
    """Report whether a session declared elicitation support.

    Args:
        session_id: The session to check.

    Returns:
        True when the session's stored capabilities include a truthy
        'elicitation' entry, False otherwise (including unknown sessions).
    """
    caps = await self.get_client_capabilities(session_id)
    # Unknown session or empty capability set both count as "no support".
    return bool(caps and caps.get("elicitation"))
async def get_elicitation_capable_sessions(self) -> list[str]:
    """List active sessions that declared elicitation support.

    Returns:
        Session IDs whose stored capabilities include a truthy
        'elicitation' entry and whose transport is still registered.
    """
    async with self._lock:
        # Require both a truthy elicitation capability and a live session entry.
        return [sid for sid, caps in self._client_capabilities.items() if caps.get("elicitation") and sid in self._sessions]
async def generate_response(self, message: Dict[str, Any], transport: SSETransport, server_id: Optional[str], user: Dict[str, Any]) -> None:
    """Generate and send response for incoming MCP protocol message.

    Processes MCP protocol messages and generates appropriate responses based on
    the method. The request is forwarded to this gateway's own /rpc endpoint and
    the JSON-RPC result (or error) is sent back over the SSE transport. After an
    ``initialize`` request, the standard post-initialization notifications are
    also emitted.

    Uses loopback (127.0.0.1) for the internal RPC call to avoid issues when
    the gateway is behind a reverse proxy or service mesh where the client-facing
    URL is not reachable from the server itself.

    Args:
        message: Incoming MCP message as JSON. Must contain 'method' and 'id'
            fields for a response to be generated; otherwise nothing is sent.
        transport: SSE transport to send responses through.
        server_id: Optional server ID for scoped operations (injected into params).
        user: User information containing authentication token and admin status.

    Examples:
        >>> import asyncio
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>>
        >>> class MockTransport:
        ...     async def send_message(self, msg):
        ...         print(f"Response: {msg['method'] if 'method' in msg else msg.get('result', {})}")
        >>>
        >>> reg = SessionRegistry()
        >>> transport = MockTransport()
        >>> message = {"method": "ping", "id": 1}
        >>> user = {"token": "test-token"}
        >>> # asyncio.run(reg.generate_response(message, transport, None, user))
        >>> # Response: {}
    """
    result = {}

    # Only well-formed requests (method + id) produce a response; anything
    # else is silently ignored.
    if "method" in message and "id" in message:
        method = message["method"]
        params = message.get("params", {})
        params["server_id"] = server_id
        req_id = message["id"]

        rpc_input = {
            "jsonrpc": "2.0",
            "method": method,
            "params": params,
            "id": req_id,
        }
        # Get the token from the current authentication context
        # The user object should contain auth_token, token_teams, and is_admin from the SSE endpoint
        token = None
        is_admin = user.get("is_admin", False)  # Preserve admin status from SSE endpoint

        try:
            if hasattr(user, "get") and user.get("auth_token"):
                token = user["auth_token"]
            else:
                # Fallback: create lightweight session token (teams resolved server-side by downstream /rpc)
                logger.warning("No auth token available for SSE RPC call - creating fallback session token")
                now = datetime.now(timezone.utc)
                payload = {
                    "sub": user.get("email", "system"),
                    "iss": settings.jwt_issuer,
                    "aud": settings.jwt_audience,
                    "iat": int(now.timestamp()),
                    "jti": str(uuid.uuid4()),
                    "token_use": "session",  # nosec B105 - token type marker, not a password
                    "user": {
                        "email": user.get("email", "system"),
                        "full_name": user.get("full_name", "System"),
                        "is_admin": is_admin,  # Preserve admin status for cookie-authenticated admins
                        "auth_provider": "internal",
                    },
                }
                # Generate token using centralized token creation
                token = await create_jwt_token(payload)

            # Pass downstream session id to /rpc for session affinity.
            # This is gateway-internal only; the pool strips it before contacting upstream MCP servers.
            if settings.mcpgateway_session_affinity_enabled:
                await self._register_session_mapping(transport.session_id, message, user.get("email") if hasattr(user, "get") else None)

            headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
            if settings.mcpgateway_session_affinity_enabled:
                headers["x-mcp-session-id"] = transport.session_id
            # Use loopback for internal RPC call (consistent with other self-call sites
            # in mcp_session_pool.py and streamablehttp_transport.py). This avoids
            # failures when the client-facing URL is not reachable from the server
            # (e.g., behind a reverse proxy or service mesh). See #3049.
            rpc_url = f"http://127.0.0.1:{settings.port}/rpc"

            logger.info(f"SSE RPC: Making call to {rpc_url} with method={method}, params={params}")

            async with ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify}) as client:
                logger.info(f"SSE RPC: Sending request to {rpc_url}")
                rpc_response = await client.post(
                    url=rpc_url,
                    json=rpc_input,
                    headers=headers,
                )
                logger.info(f"SSE RPC: Got response status {rpc_response.status_code}")
                result = rpc_response.json()
                logger.info(f"SSE RPC: Response content: {result}")
                result = result.get("result", {})

            response = {"jsonrpc": "2.0", "result": result, "id": req_id}
        except JSONRPCError as e:
            logger.error(f"SSE RPC: JSON-RPC error: {e}")
            result = e.to_dict()
            response = {"jsonrpc": "2.0", "error": result["error"], "id": req_id}
        except Exception as e:
            logger.error(f"SSE RPC: Exception during RPC call: {type(e).__name__}: {e}")
            logger.error(f"SSE RPC: Traceback: {traceback.format_exc()}")
            result = {"code": -32000, "message": "Internal error", "data": str(e)}
            response = {"jsonrpc": "2.0", "error": result, "id": req_id}

        # Fix: use the module logger (not the root `logging` module) so this
        # debug line goes through the app's configured logger like every
        # other log call in this method.
        logger.debug(f"Sending sse message:{response}")
        await transport.send_message(response)

        if message["method"] == "initialize":
            # Per the MCP handshake, follow the initialize response with the
            # initialized notification and list_changed notifications.
            await transport.send_message(
                {
                    "jsonrpc": "2.0",
                    "method": "notifications/initialized",
                    "params": {},
                }
            )
            notifications = [
                "tools/list_changed",
                "resources/list_changed",
                "prompts/list_changed",
            ]
            for notification in notifications:
                await transport.send_message(
                    {
                        "jsonrpc": "2.0",
                        "method": f"notifications/{notification}",
                        "params": {},
                    }
                )