Coverage for mcpgateway / cache / session_registry.py: 100%
974 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/cache/session_registry.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Session Registry with optional distributed state.
8This module provides a registry for SSE sessions with support for distributed deployment
9using Redis or SQLAlchemy as optional backends for shared state between workers.
11The SessionRegistry class manages server-sent event (SSE) sessions across multiple
12worker processes, enabling horizontal scaling of MCP gateway deployments. It supports
13three backend modes:
15- **memory**: In-memory storage for single-process deployments (default)
16- **redis**: Redis-backed shared storage for multi-worker deployments
17- **database**: SQLAlchemy-backed shared storage using any supported database
19In distributed mode (redis/database), session existence is tracked in the shared
20backend while transport objects remain local to each worker process. This allows
21workers to know about sessions on other workers and route messages appropriately.
23Examples:
24 Basic usage with memory backend:
26 >>> from mcpgateway.cache.session_registry import SessionRegistry
27 >>> class DummyTransport:
28 ... async def disconnect(self):
29 ... pass
30 ... async def is_connected(self):
31 ... return True
32 >>> import asyncio
33 >>> reg = SessionRegistry(backend='memory')
34 >>> transport = DummyTransport()
35 >>> asyncio.run(reg.add_session('sid123', transport))
36 >>> found = asyncio.run(reg.get_session('sid123'))
37 >>> isinstance(found, DummyTransport)
38 True
39 >>> asyncio.run(reg.remove_session('sid123'))
40 >>> asyncio.run(reg.get_session('sid123')) is None
41 True
43 Broadcasting messages:
45 >>> reg = SessionRegistry(backend='memory')
46 >>> asyncio.run(reg.broadcast('sid123', {'method': 'ping', 'id': 1}))
47 >>> reg._session_message is not None
48 True
49"""
51# Standard
52import asyncio
53from asyncio import Task
54from datetime import datetime, timedelta, timezone
55import logging
56import time
57import traceback
58from typing import Any, Dict, Optional
59import uuid
61# Third-Party
62from fastapi import HTTPException, status
63import orjson
65# First-Party
66from mcpgateway import __version__
67from mcpgateway.common.models import Implementation, InitializeResult, ServerCapabilities
68from mcpgateway.config import settings
69from mcpgateway.db import get_db, SessionMessageRecord, SessionRecord
70from mcpgateway.services import PromptService, ResourceService, ToolService
71from mcpgateway.services.logging_service import LoggingService
72from mcpgateway.transports import SSETransport
73from mcpgateway.utils.create_jwt_token import create_jwt_token
74from mcpgateway.utils.internal_http import internal_loopback_base_url, internal_loopback_verify
75from mcpgateway.utils.redis_client import get_redis_client
76from mcpgateway.utils.retry_manager import ResilientHttpClient
77from mcpgateway.validation.jsonrpc import JSONRPCError
# Initialize logging service first; the module-level ``logger`` below is what
# every function in this module logs through.
logging_service: LoggingService = LoggingService()
logger = logging_service.get_logger(__name__)

# Service instances created once, at import time.
tool_service: ToolService = ToolService()
resource_service: ResourceService = ResourceService()
prompt_service: PromptService = PromptService()

# Optional dependency probe: SessionBackend refuses the 'redis' backend when
# REDIS_AVAILABLE is False.
try:
    # Third-Party
    from redis.asyncio import Redis

    REDIS_AVAILABLE = True
except ImportError:
    REDIS_AVAILABLE = False

# Optional dependency probe: SessionBackend refuses the 'database' backend
# when SQLALCHEMY_AVAILABLE is False.
try:
    # Third-Party
    from sqlalchemy import func

    SQLALCHEMY_AVAILABLE = True
except ImportError:
    SQLALCHEMY_AVAILABLE = False
class SessionBackend:
    """Validate and hold the storage-backend configuration for a session registry.

    The constructor normalises the backend name, records the TTL settings and
    creates the placeholder attributes that the selected backend needs. It
    performs no I/O; connections are established later by the registry.

    Attributes:
        _backend: Lower-cased backend name ('memory', 'redis', 'database' or 'none').
        _session_ttl: Time-to-live for sessions in seconds.
        _message_ttl: Time-to-live for messages in seconds.
        _redis: Redis client placeholder (redis backend only).
        _pubsub: Redis pubsub placeholder (redis backend only).
        _session_message: Temporary message storage (memory backend only).

    Examples:
        >>> backend = SessionBackend(backend='memory')
        >>> backend._backend
        'memory'
        >>> backend._session_ttl
        3600

        >>> try:
        ...     backend = SessionBackend(backend='redis')
        ... except ValueError as e:
        ...     str(e)
        'Redis backend requires redis_url'
    """

    def __init__(
        self,
        backend: str = "memory",
        redis_url: Optional[str] = None,
        database_url: Optional[str] = None,
        session_ttl: int = 3600,  # 1 hour
        message_ttl: int = 600,  # 10 min
    ):
        """Record the backend choice and check its prerequisites.

        Args:
            backend: One of 'memory', 'redis', 'database' or 'none'
                (case-insensitive).
                - 'memory': in-process storage
                - 'redis': Redis-backed shared storage
                - 'database': SQLAlchemy-backed shared storage
                - 'none': session tracking disabled
            redis_url: Redis connection URL. Must be non-empty when
                backend='redis'.
            database_url: Database connection URL. Must be non-empty when
                backend='database'.
            session_ttl: Session time-to-live in seconds. Default: 3600.
            message_ttl: Message time-to-live in seconds. Default: 600.

        Raises:
            ValueError: If the backend name is unknown, the URL required by the
                backend is missing, or the package required by the backend is
                not installed.

        Examples:
            >>> # Memory backend (default)
            >>> backend = SessionBackend()
            >>> backend._backend
            'memory'

            >>> # Redis backend requires URL
            >>> try:
            ...     backend = SessionBackend(backend='redis')
            ... except ValueError as e:
            ...     'redis_url' in str(e)
            True

            >>> # Invalid backend
            >>> try:
            ...     backend = SessionBackend(backend='invalid')
            ... except ValueError as e:
            ...     'Invalid backend' in str(e)
            True
        """
        backend_name = backend.lower()
        self._backend = backend_name
        self._session_ttl = session_ttl
        self._message_ttl = message_ttl

        if backend_name == "memory":
            # Only the memory backend carries the in-process message slot.
            self._session_message: dict[str, Any] | None = None
            return

        if backend_name == "none":
            # Dummy registry: nothing to set up.
            logger.info("Session registry initialized with 'none' backend - session tracking disabled")
            return

        if backend_name == "redis":
            if not REDIS_AVAILABLE:
                raise ValueError("Redis backend requested but redis package not installed")
            if not redis_url:
                raise ValueError("Redis backend requires redis_url")
            # Placeholders only; the client itself is attached later.
            self._redis: Optional[Redis] = None
            self._pubsub = None
            return

        if backend_name == "database":
            if not SQLALCHEMY_AVAILABLE:
                raise ValueError("Database backend requested but SQLAlchemy not installed")
            if not database_url:
                raise ValueError("Database backend requires database_url")
            return

        # The error reports the name exactly as the caller supplied it.
        raise ValueError(f"Invalid backend: {backend}")
214class SessionRegistry(SessionBackend):
215 """Registry for SSE sessions with optional distributed state.
217 This class manages server-sent event (SSE) sessions, providing methods to add,
218 remove, and query sessions. It supports multiple backend types for different
219 deployment scenarios:
221 - **Single-process deployments**: Use 'memory' backend (default)
222 - **Multi-worker deployments**: Use 'redis' or 'database' backend
223 - **Testing/development**: Use 'none' backend to disable session tracking
225 The registry maintains a local cache of transport objects while using the
226 shared backend to track session existence across workers. This enables
227 horizontal scaling while keeping transport objects process-local.
229 Attributes:
230 _sessions: Local dictionary mapping session IDs to transport objects
231 _lock: Asyncio lock that serializes coroutine access to _sessions (asyncio locks are not thread-safe)
232 _cleanup_task: Background task for cleaning up expired sessions
234 Examples:
235 >>> import asyncio
236 >>> from mcpgateway.cache.session_registry import SessionRegistry
237 >>>
238 >>> class MockTransport:
239 ... async def disconnect(self):
240 ... print("Disconnected")
241 ... async def is_connected(self):
242 ... return True
243 ... async def send_message(self, msg):
244 ... print(f"Sent: {msg}")
245 >>>
246 >>> # Create registry and add session
247 >>> reg = SessionRegistry(backend='memory')
248 >>> transport = MockTransport()
249 >>> asyncio.run(reg.add_session('test123', transport))
250 >>>
251 >>> # Retrieve session
252 >>> found = asyncio.run(reg.get_session('test123'))
253 >>> found is transport
254 True
255 >>>
256 >>> # Remove session
257 >>> asyncio.run(reg.remove_session('test123'))
258 Disconnected
259 >>> asyncio.run(reg.get_session('test123')) is None
260 True
261 """
263 def __init__(
264 self,
265 backend: str = "memory",
266 redis_url: Optional[str] = None,
267 database_url: Optional[str] = None,
268 session_ttl: int = 3600, # 1 hour
269 message_ttl: int = 600, # 10 min
270 ):
271 """Initialize session registry with specified backend.
273 Args:
274 backend: Backend type. Must be one of 'memory', 'redis', 'database', or 'none'.
275 redis_url: Redis connection URL. Required when backend='redis'.
276 database_url: Database connection URL. Required when backend='database'.
277 session_ttl: Session time-to-live in seconds. Default: 3600.
278 message_ttl: Message time-to-live in seconds. Default: 600.
280 Examples:
281 >>> # Default memory backend
282 >>> reg = SessionRegistry()
283 >>> reg._backend
284 'memory'
285 >>> isinstance(reg._sessions, dict)
286 True
288 >>> # Redis backend with custom TTL
289 >>> try:
290 ... reg = SessionRegistry(
291 ... backend='redis',
292 ... redis_url='redis://localhost:6379',
293 ... session_ttl=7200
294 ... )
295 ... except ValueError:
296 ... pass # Redis may not be available
297 """
298 super().__init__(backend=backend, redis_url=redis_url, database_url=database_url, session_ttl=session_ttl, message_ttl=message_ttl)
299 self._sessions: Dict[str, Any] = {} # Local transport cache
300 self._session_owners: Dict[str, str] = {} # Session owner email by session_id
301 self._client_capabilities: Dict[str, Dict[str, Any]] = {} # Client capabilities by session_id
302 self._respond_tasks: Dict[str, asyncio.Task] = {} # Track respond tasks for cancellation
303 self._stuck_tasks: Dict[str, asyncio.Task] = {} # Tasks that couldn't be cancelled (for monitoring)
304 self._closing_sessions: set[str] = set() # Sessions being closed - respond loop should exit
305 self._lock = asyncio.Lock()
306 self._cleanup_task: Task | None = None
307 self._stuck_task_reaper: Task | None = None # Reaper for stuck tasks
309 def register_respond_task(self, session_id: str, task: asyncio.Task) -> None:
310 """Register a respond task for later cancellation.
312 Associates an asyncio Task with a session_id so it can be cancelled
313 when the session is removed. This prevents orphaned tasks that cause
314 CPU spin loops.
316 Args:
317 session_id: Session identifier the task belongs to.
318 task: The asyncio Task to track.
319 """
320 self._respond_tasks[session_id] = task
321 logger.debug(f"Registered respond task for session {session_id}")
323 async def _cancel_respond_task(self, session_id: str, timeout: float = 5.0) -> None:
324 """Cancel and await a respond task with timeout.
326 Safely cancels the respond task associated with a session. Uses a timeout
327 to prevent hanging if the task doesn't respond to cancellation.
329 If initial cancellation times out, escalates by force-disconnecting the
330 transport to unblock the task, then retries cancellation (Finding 1 fix).
332 Args:
333 session_id: Session identifier whose task should be cancelled.
334 timeout: Maximum seconds to wait for task cancellation. Default 5.0.
335 """
336 task = self._respond_tasks.get(session_id)
337 if task is None:
338 return
340 if task.done():
341 # Task already finished - safe to remove from tracking
342 self._respond_tasks.pop(session_id, None)
343 try:
344 task.result()
345 except asyncio.CancelledError:
346 pass
347 except Exception as e:
348 logger.warning(f"Respond task for {session_id} failed with: {e}")
349 return
351 task.cancel()
352 logger.debug(f"Cancelling respond task for session {session_id}")
354 try:
355 await asyncio.wait_for(task, timeout=timeout)
356 # Cancellation successful - remove from tracking
357 self._respond_tasks.pop(session_id, None)
358 logger.debug(f"Respond task cancelled for session {session_id}")
359 except asyncio.TimeoutError:
360 # ESCALATION (Finding 1): Force-disconnect transport to unblock the task
361 logger.warning(f"Respond task cancellation timed out for {session_id}, " f"escalating with transport disconnect")
363 # Force-disconnect the transport to unblock any pending I/O
364 transport = self._sessions.get(session_id)
365 if transport and hasattr(transport, "disconnect"):
366 try:
367 await transport.disconnect()
368 logger.debug(f"Force-disconnected transport for {session_id}")
369 except Exception as e:
370 logger.warning(f"Failed to force-disconnect transport for {session_id}: {e}")
372 # Retry cancellation with shorter timeout
373 if not task.done():
374 try:
375 await asyncio.wait_for(task, timeout=2.0)
376 self._respond_tasks.pop(session_id, None)
377 logger.info(f"Respond task cancelled after escalation for {session_id}")
378 except asyncio.TimeoutError:
379 # Still stuck - move to stuck_tasks for monitoring (Finding 2 fix)
380 self._respond_tasks.pop(session_id, None)
381 self._stuck_tasks[session_id] = task
382 logger.error(f"Respond task for {session_id} still stuck after escalation, " f"moved to stuck_tasks for monitoring (total stuck: {len(self._stuck_tasks)})")
383 except asyncio.CancelledError:
384 self._respond_tasks.pop(session_id, None)
385 logger.info(f"Respond task cancelled after escalation for {session_id}")
386 except Exception as e:
387 self._respond_tasks.pop(session_id, None)
388 logger.warning(f"Error during retry cancellation for {session_id}: {e}")
389 else:
390 self._respond_tasks.pop(session_id, None)
391 logger.debug(f"Respond task completed during escalation for {session_id}")
393 except asyncio.CancelledError:
394 # Cancellation successful - remove from tracking
395 self._respond_tasks.pop(session_id, None)
396 logger.debug(f"Respond task cancelled for session {session_id}")
397 except Exception as e:
398 # Remove from tracking on unexpected error - task state unknown
399 self._respond_tasks.pop(session_id, None)
400 logger.warning(f"Error during respond task cancellation for {session_id}: {e}")
402 async def _reap_stuck_tasks(self) -> None:
403 """Periodically clean up stuck tasks that have completed.
405 This reaper runs every 30 seconds and:
406 1. Removes completed tasks from _stuck_tasks
407 2. Retries cancellation for tasks that are still running
408 3. Logs warnings for tasks that remain stuck
410 This prevents memory leaks from tasks that eventually complete after
411 being moved to _stuck_tasks during escalation.
413 Raises:
414 asyncio.CancelledError: If the task is cancelled during shutdown.
415 """
416 reap_interval = 30.0 # seconds
417 retry_timeout = 2.0 # seconds for retry cancellation
419 while True:
420 try:
421 await asyncio.sleep(reap_interval)
423 if not self._stuck_tasks:
424 continue
426 # Collect completed and still-stuck tasks
427 completed = []
428 still_stuck = []
430 for session_id, task in list(self._stuck_tasks.items()):
431 if task.done():
432 completed.append(session_id)
433 try:
434 task.result() # Consume result to avoid warnings
435 except (asyncio.CancelledError, Exception):
436 pass
437 else:
438 still_stuck.append((session_id, task))
440 # Remove completed tasks
441 for session_id in completed:
442 self._stuck_tasks.pop(session_id, None)
444 if completed:
445 logger.info(f"Reaped {len(completed)} completed stuck tasks")
447 # Retry cancellation for still-stuck tasks
448 for session_id, task in still_stuck:
449 task.cancel()
450 try:
451 await asyncio.wait_for(task, timeout=retry_timeout)
452 self._stuck_tasks.pop(session_id, None)
453 logger.info(f"Stuck task {session_id} finally cancelled during reap")
454 except asyncio.TimeoutError:
455 logger.warning(f"Task {session_id} still stuck after reap retry")
456 except asyncio.CancelledError:
457 self._stuck_tasks.pop(session_id, None)
458 logger.info(f"Stuck task {session_id} cancelled during reap")
459 except Exception as e:
460 logger.warning(f"Error during stuck task reap for {session_id}: {e}")
462 if self._stuck_tasks:
463 logger.warning(f"Stuck tasks remaining: {len(self._stuck_tasks)}")
465 except asyncio.CancelledError:
466 logger.debug("Stuck task reaper cancelled")
467 raise
468 except Exception as e:
469 logger.error(f"Error in stuck task reaper: {e}")
471 async def initialize(self) -> None:
472 """Initialize the registry with async setup.
474 This method performs asynchronous initialization tasks that cannot be done
475 in __init__. It starts background cleanup tasks and sets up pubsub
476 subscriptions for distributed backends.
478 Call this during application startup after creating the registry instance.
480 Examples:
481 >>> import asyncio
482 >>> reg = SessionRegistry(backend='memory')
483 >>> asyncio.run(reg.initialize())
484 >>> reg._cleanup_task is not None
485 True
486 >>>
487 >>> # Cleanup
488 >>> asyncio.run(reg.shutdown())
489 """
490 logger.info(f"Initializing session registry with backend: {self._backend}")
492 if self._backend == "database":
493 # Start database cleanup task
494 self._cleanup_task = asyncio.create_task(self._db_cleanup_task())
495 logger.info("Database cleanup task started")
497 elif self._backend == "redis":
498 # Get shared Redis client from factory
499 self._redis = await get_redis_client()
500 if self._redis:
501 self._pubsub = self._redis.pubsub()
502 await self._pubsub.subscribe("mcp_session_events")
503 logger.info("Session registry connected to shared Redis client")
505 elif self._backend == "none":
506 # Nothing to initialize for none backend
507 pass
509 # Memory backend needs session cleanup
510 elif self._backend == "memory":
511 self._cleanup_task = asyncio.create_task(self._memory_cleanup_task())
512 logger.info("Memory cleanup task started")
514 # Start stuck task reaper for all backends
515 self._stuck_task_reaper = asyncio.create_task(self._reap_stuck_tasks())
516 logger.info("Stuck task reaper started")
518 async def shutdown(self) -> None:
519 """Shutdown the registry and clean up resources.
521 This method cancels background tasks and closes connections to external
522 services. Call this during application shutdown to ensure clean termination.
524 Examples:
525 >>> import asyncio
526 >>> reg = SessionRegistry()
527 >>> asyncio.run(reg.initialize())
528 >>> task_was_created = reg._cleanup_task is not None
529 >>> asyncio.run(reg.shutdown())
530 >>> # After shutdown, cleanup task should be handled (cancelled or done)
531 >>> task_was_created and (reg._cleanup_task.cancelled() or reg._cleanup_task.done())
532 True
533 """
534 logger.info("Shutting down session registry")
536 # Cancel cleanup task
537 if self._cleanup_task:
538 self._cleanup_task.cancel()
539 try:
540 await self._cleanup_task
541 except asyncio.CancelledError:
542 pass
544 # Cancel stuck task reaper
545 if self._stuck_task_reaper:
546 self._stuck_task_reaper.cancel()
547 try:
548 await self._stuck_task_reaper
549 except asyncio.CancelledError:
550 pass
552 # CRITICAL: Cancel ALL respond tasks to prevent CPU spin loops
553 if self._respond_tasks:
554 logger.info(f"Cancelling {len(self._respond_tasks)} respond tasks")
555 tasks_to_cancel = list(self._respond_tasks.values())
556 self._respond_tasks.clear()
558 for task in tasks_to_cancel:
559 if not task.done():
560 task.cancel()
562 if tasks_to_cancel:
563 try:
564 await asyncio.wait_for(asyncio.gather(*tasks_to_cancel, return_exceptions=True), timeout=10.0)
565 logger.info("All respond tasks cancelled successfully")
566 except asyncio.TimeoutError:
567 logger.warning("Timeout waiting for respond tasks to cancel")
569 # Also cancel any stuck tasks (tasks that previously couldn't be cancelled)
570 if self._stuck_tasks:
571 logger.warning(f"Attempting final cancellation of {len(self._stuck_tasks)} stuck tasks")
572 stuck_to_cancel = list(self._stuck_tasks.values())
573 self._stuck_tasks.clear()
575 for task in stuck_to_cancel:
576 if not task.done():
577 task.cancel()
579 if stuck_to_cancel:
580 try:
581 await asyncio.wait_for(asyncio.gather(*stuck_to_cancel, return_exceptions=True), timeout=5.0)
582 logger.info("Stuck tasks cancelled during shutdown")
583 except asyncio.TimeoutError:
584 logger.error("Some stuck tasks could not be cancelled during shutdown")
586 # Close Redis pubsub (but not the shared client)
587 # Use timeout to prevent blocking if pubsub doesn't close cleanly
588 cleanup_timeout = settings.mcp_session_pool_cleanup_timeout
589 if self._backend == "redis" and getattr(self, "_pubsub", None):
590 try:
591 await asyncio.wait_for(self._pubsub.aclose(), timeout=cleanup_timeout)
592 except asyncio.TimeoutError:
593 logger.warning("Redis pubsub close timed out - proceeding anyway")
594 except Exception as e:
595 logger.error(f"Error closing Redis pubsub: {e}")
596 # Don't close self._redis - it's the shared client managed by redis_client.py
597 self._redis = None
598 self._pubsub = None
600 async def add_session(self, session_id: str, transport: SSETransport) -> None:
601 """Add a session to the registry.
603 Stores the session in both the local cache and the distributed backend
604 (if configured). For distributed backends, this notifies other workers
605 about the new session.
607 Args:
608 session_id: Unique session identifier. Should be a UUID or similar
609 unique string to avoid collisions.
610 transport: SSE transport object for this session. Must implement
611 the SSETransport interface.
613 Examples:
614 >>> import asyncio
615 >>> from mcpgateway.cache.session_registry import SessionRegistry
616 >>>
617 >>> class MockTransport:
618 ... async def disconnect(self):
619 ... print(f"Transport disconnected")
620 ... async def is_connected(self):
621 ... return True
622 >>>
623 >>> reg = SessionRegistry()
624 >>> transport = MockTransport()
625 >>> asyncio.run(reg.add_session('test-456', transport))
626 >>>
627 >>> # Found in local cache
628 >>> found = asyncio.run(reg.get_session('test-456'))
629 >>> found is transport
630 True
631 >>>
632 >>> # Remove session
633 >>> asyncio.run(reg.remove_session('test-456'))
634 Transport disconnected
635 """
636 # Skip for none backend
637 if self._backend == "none":
638 return
640 async with self._lock:
641 self._sessions[session_id] = transport
643 if self._backend == "redis":
644 # Store session marker in Redis
645 if not self._redis:
646 logger.warning(f"Redis client not initialized, skipping distributed session tracking for {session_id}")
647 return
648 try:
649 await self._redis.setex(f"mcp:session:{session_id}", self._session_ttl, "1")
650 # Publish event to notify other workers
651 await self._redis.publish("mcp_session_events", orjson.dumps({"type": "add", "session_id": session_id, "timestamp": time.time()}))
652 except Exception as e:
653 logger.error(f"Redis error adding session {session_id}: {e}")
655 elif self._backend == "database":
656 # Store session in database
657 try:
659 def _db_add() -> None:
660 """Store session record in the database.
662 Creates a new SessionRecord entry in the database for tracking
663 distributed session state. Uses a fresh database connection from
664 the connection pool.
666 This inner function is designed to be run in a thread executor
667 to avoid blocking the async event loop during database I/O.
669 Raises:
670 Exception: Any database error is re-raised after rollback.
671 Common errors include duplicate session_id (unique constraint)
672 or database connection issues.
674 Examples:
675 >>> # This function is called internally by add_session()
676 >>> # When executed, it creates a database record:
677 >>> # SessionRecord(session_id='abc123', created_at=now())
678 """
679 db_session = next(get_db())
680 try:
681 session_record = SessionRecord(session_id=session_id)
682 db_session.add(session_record)
683 db_session.commit()
684 except Exception as ex:
685 db_session.rollback()
686 raise ex
687 finally:
688 db_session.close()
690 await asyncio.to_thread(_db_add)
691 except Exception as e:
692 logger.error(f"Database error adding session {session_id}: {e}")
694 logger.info(f"Added session: {session_id}")
696 def _session_owner_key(self, session_id: str) -> str:
697 """Return Redis key used to store session ownership.
699 Args:
700 session_id: Session identifier.
702 Returns:
703 Redis key for the session owner mapping.
704 """
705 return f"mcp:session_owner:{session_id}"
707 async def set_session_owner(self, session_id: str, owner_email: Optional[str]) -> None:
708 """Set or clear owner for a session.
710 Args:
711 session_id: Session identifier to update.
712 owner_email: Owner email to set. Passing ``None`` clears ownership.
714 Returns:
715 None.
716 """
717 # Skip for none backend
718 if self._backend == "none":
719 return
721 if owner_email:
722 self._session_owners[session_id] = owner_email
723 else:
724 self._session_owners.pop(session_id, None)
726 if self._backend == "redis":
727 if not self._redis:
728 logger.warning(f"Redis client not initialized, cannot set owner for session {session_id}")
729 return
730 try:
731 owner_key = self._session_owner_key(session_id)
732 if owner_email:
733 await self._redis.setex(owner_key, self._session_ttl, owner_email)
734 else:
735 await self._redis.delete(owner_key)
736 except Exception as e:
737 logger.error(f"Redis error setting owner for session {session_id}: {e}")
739 elif self._backend == "database":
740 try:
742 def _db_set_owner() -> None:
743 """Persist owner metadata for a session in the database backend.
745 Raises:
746 Exception: Propagates database write failures to the caller.
747 """
748 db_session = next(get_db())
749 try:
750 record = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
751 if not record:
752 record = SessionRecord(session_id=session_id, data=None)
753 db_session.add(record)
754 db_session.flush()
756 record_data: Dict[str, Any] = {}
757 if record.data:
758 try:
759 parsed = orjson.loads(record.data)
760 if isinstance(parsed, dict):
761 record_data = parsed
762 except Exception:
763 record_data = {}
765 if owner_email:
766 record_data["owner_email"] = owner_email
767 else:
768 record_data.pop("owner_email", None)
770 record.data = orjson.dumps(record_data).decode() if record_data else None
771 db_session.commit()
772 except Exception as ex:
773 db_session.rollback()
774 raise ex
775 finally:
776 db_session.close()
778 await asyncio.to_thread(_db_set_owner)
779 except Exception as e:
780 logger.error(f"Database error setting owner for session {session_id}: {e}")
782 async def claim_session_owner(self, session_id: str, owner_email: str) -> Optional[str]:
783 """Atomically claim ownership for a session and return the effective owner.
785 This method provides compare-and-set semantics:
786 - If a session owner already exists, return the existing owner.
787 - If no owner exists, claim ownership for ``owner_email``.
789 Args:
790 session_id: Session identifier to claim.
791 owner_email: Requesting owner email.
793 Returns:
794 Effective owner email after the claim operation, or ``None`` if owner
795 metadata could not be verified due to backend availability issues.
796 """
797 if self._backend == "none":
798 return owner_email
800 # Fast local cache path.
801 cached_owner = self._session_owners.get(session_id)
802 if cached_owner:
803 return cached_owner
805 if self._backend == "memory":
806 async with self._lock:
807 existing_owner = self._session_owners.get(session_id)
808 if existing_owner:
809 return existing_owner
810 self._session_owners[session_id] = owner_email
811 return owner_email
813 if self._backend == "redis":
814 if not self._redis:
815 logger.warning("Redis client not initialized, session owner claim unavailable for %s", session_id)
816 return None
818 owner_key = self._session_owner_key(session_id)
819 try:
820 claimed = await self._redis.set(owner_key, owner_email, ex=self._session_ttl, nx=True)
821 if claimed:
822 self._session_owners[session_id] = owner_email
823 return owner_email
825 owner_raw = await self._redis.get(owner_key)
826 if owner_raw is not None:
827 existing_owner = owner_raw.decode() if isinstance(owner_raw, bytes) else str(owner_raw)
828 if existing_owner:
829 self._session_owners[session_id] = existing_owner
830 return existing_owner
832 # Handle key expiry/race window by retrying once.
833 claimed_retry = await self._redis.set(owner_key, owner_email, ex=self._session_ttl, nx=True)
834 if claimed_retry:
835 self._session_owners[session_id] = owner_email
836 return owner_email
838 owner_raw = await self._redis.get(owner_key)
839 if owner_raw is None:
840 return None
841 existing_owner = owner_raw.decode() if isinstance(owner_raw, bytes) else str(owner_raw)
842 if existing_owner:
843 self._session_owners[session_id] = existing_owner
844 return existing_owner
845 return None
846 except Exception as e:
847 logger.error("Redis error claiming owner for session %s: %s", session_id, e)
848 return None
850 if self._backend == "database":
851 try:
853 def _db_claim_owner() -> Optional[str]:
854 """Claim owner in DB with optimistic compare-and-set retries.
856 Returns:
857 Effective owner email or ``None`` when claim cannot be verified.
858 """
859 db_session = next(get_db())
860 try:
861 for _attempt in range(3):
862 record = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
863 if not record:
864 owner_payload = {"owner_email": owner_email}
865 new_record = SessionRecord(session_id=session_id, data=orjson.dumps(owner_payload).decode())
866 db_session.add(new_record)
867 try:
868 db_session.commit()
869 return owner_email
870 except Exception:
871 db_session.rollback()
872 # Another writer may have inserted concurrently. Retry.
873 continue
875 record_data: Dict[str, Any] = {}
876 if record.data:
877 try:
878 parsed = orjson.loads(record.data)
879 if isinstance(parsed, dict):
880 record_data = parsed
881 except Exception:
882 record_data = {}
884 existing_owner = record_data.get("owner_email")
885 if isinstance(existing_owner, str) and existing_owner:
886 return existing_owner
888 current_data = record.data
889 updated_data = dict(record_data)
890 updated_data["owner_email"] = owner_email
891 serialized = orjson.dumps(updated_data).decode()
893 update_query = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id)
894 if current_data is None:
895 update_query = update_query.filter(SessionRecord.data.is_(None))
896 else:
897 update_query = update_query.filter(SessionRecord.data == current_data)
899 updated_rows = update_query.update({"data": serialized}, synchronize_session=False)
900 if updated_rows == 1:
901 db_session.commit()
902 return owner_email
904 db_session.rollback()
906 return None
907 finally:
908 db_session.close()
910 claimed_owner = await asyncio.to_thread(_db_claim_owner)
911 if claimed_owner:
912 self._session_owners[session_id] = claimed_owner
913 return claimed_owner
914 except Exception as e:
915 logger.error("Database error claiming owner for session %s: %s", session_id, e)
916 return None
918 return None
async def get_session_owner(self, session_id: str) -> Optional[str]:
    """Resolve the owner email recorded for a session.

    The in-process owner cache is consulted first. On a miss, the configured
    shared backend (Redis or database) is queried and any non-empty result is
    written back into the cache before being returned.

    Args:
        session_id: Session identifier to resolve.

    Returns:
        Owner email when present, otherwise ``None``. ``None`` is also
        returned when the backend lookup raises.
    """
    cached_owner = self._session_owners.get(session_id)
    if cached_owner:
        return cached_owner

    backend = self._backend

    if backend == "redis":
        if not self._redis:
            return None
        try:
            raw_value = await self._redis.get(self._session_owner_key(session_id))
            if raw_value is None:
                return None
            # Redis clients may hand back bytes or an already-decoded value.
            resolved = raw_value.decode() if isinstance(raw_value, bytes) else str(raw_value)
            if not resolved:
                return None
            self._session_owners[session_id] = resolved
            return resolved
        except Exception as e:
            logger.error(f"Redis error getting owner for session {session_id}: {e}")
            return None

    if backend == "database":

        def _db_get_owner() -> Optional[str]:
            """Read ``owner_email`` from the session record's JSON payload.

            Returns:
                The stored owner email, or ``None`` when the record, its
                payload, or a non-empty string owner is missing.
            """
            db_session = next(get_db())
            try:
                row = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
                if not row or not row.data:
                    return None
                try:
                    payload = orjson.loads(row.data)
                except Exception:
                    # Unparseable payload is treated the same as "no owner".
                    return None
                stored = payload.get("owner_email") if isinstance(payload, dict) else None
                if isinstance(stored, str) and stored:
                    return stored
                return None
            finally:
                db_session.close()

        try:
            db_owner = await asyncio.to_thread(_db_get_owner)
            if db_owner:
                self._session_owners[session_id] = db_owner
            return db_owner
        except Exception as e:
            logger.error(f"Database error getting owner for session {session_id}: {e}")
            return None

    return None
async def session_exists(self, session_id: str) -> Optional[bool]:
    """Report whether a session is known locally or in the shared backend.

    Args:
        session_id: Session identifier to resolve.

    Returns:
        ``True`` when session exists, ``False`` when session does not exist,
        and ``None`` when existence cannot be verified due to backend errors
        (including a Redis backend without an initialised client).
    """
    if self._backend == "none":
        return False

    # A transport held by this worker is sufficient proof of existence.
    async with self._lock:
        held_locally = session_id in self._sessions
    if held_locally:
        return True

    if self._backend == "memory":
        return False

    if self._backend == "redis":
        if not self._redis:
            return None
        try:
            marker_count = await self._redis.exists(f"mcp:session:{session_id}")
            return bool(marker_count)
        except Exception as e:
            logger.error("Redis error checking existence for session %s: %s", session_id, e)
            return None

    if self._backend == "database":

        def _db_exists() -> bool:
            """Check whether a session record exists in the database backend.

            Returns:
                ``True`` when a matching session record exists.
            """
            db_session = next(get_db())
            try:
                row = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
                return row is not None
            finally:
                db_session.close()

        try:
            return await asyncio.to_thread(_db_exists)
        except Exception as e:
            logger.error("Database error checking existence for session %s: %s", session_id, e)
            return None

    return False
async def get_session(self, session_id: str) -> Any:
    """Look up the transport for a session.

    Only transports held by this worker can be returned. When the session is
    not local and a distributed backend is configured, the backend is still
    queried, but purely to log that the session lives elsewhere; the result
    is ``None`` either way.

    Args:
        session_id: Session identifier to look up.

    Returns:
        SSETransport object if found locally, None if not found or exists
        on another worker.

    Examples:
        >>> import asyncio
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>>
        >>> class MockTransport:
        ...     pass
        >>>
        >>> reg = SessionRegistry()
        >>> transport = MockTransport()
        >>> asyncio.run(reg.add_session('test-456', transport))
        >>>
        >>> # Found in local cache
        >>> found = asyncio.run(reg.get_session('test-456'))
        >>> found is transport
        True
        >>>
        >>> # Not found
        >>> asyncio.run(reg.get_session('nonexistent')) is None
        True
    """
    if self._backend == "none":
        return None

    # Local cache lookup happens under the registry lock.
    async with self._lock:
        local_transport = self._sessions.get(session_id)
        if local_transport:
            logger.info(f"Session {session_id} exists in local cache")
            return local_transport

    # Past this point there is no local transport, so every path yields None;
    # the backend is consulted only for the informational log line.
    if self._backend == "redis":
        if self._redis:
            try:
                if await self._redis.exists(f"mcp:session:{session_id}"):
                    logger.info(f"Session {session_id} exists in Redis but not in local cache")
            except Exception as e:
                logger.error(f"Redis error checking session {session_id}: {e}")

    elif self._backend == "database":

        def _db_check() -> bool:
            """Check if a session record exists in the database.

            Intended to run in a worker thread so the query does not block
            the event loop.

            Returns:
                bool: True if the session exists in the database, False otherwise.
            """
            db_session = next(get_db())
            try:
                row = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
                return row is not None
            finally:
                db_session.close()

        try:
            if await asyncio.to_thread(_db_check):
                logger.info(f"Session {session_id} exists in database but not in local cache")
        except Exception as e:
            logger.error(f"Database error checking session {session_id}: {e}")

    return None
async def remove_session(self, session_id: str) -> None:
    """Remove a session from the local registry and the shared backend.

    The respond task for the session is cancelled, local state (transport,
    cached owner, client capabilities) is dropped, the transport is
    disconnected when one was held locally, and finally the session is
    deleted from Redis or the database. With Redis, a ``remove`` event is
    also published on ``mcp_session_events``.

    Args:
        session_id: Session identifier to remove.

    Examples:
        >>> import asyncio
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>>
        >>> class MockTransport:
        ...     async def disconnect(self):
        ...         print(f"Transport disconnected")
        ...     async def is_connected(self):
        ...         return True
        >>>
        >>> reg = SessionRegistry()
        >>> transport = MockTransport()
        >>> asyncio.run(reg.add_session('remove-test', transport))
        >>> asyncio.run(reg.remove_session('remove-test'))
        Transport disconnected
        >>>
        >>> # Session no longer exists
        >>> asyncio.run(reg.get_session('remove-test')) is None
        True
    """
    if self._backend == "none":
        return

    # Flag the session as closing before anything else: the respond loops
    # check this set and can stop on their own without waiting for the
    # cancellation below to finish.
    self._closing_sessions.add(session_id)
    try:
        # The respond task has to be cancelled before state is torn down,
        # otherwise it could be left orphaned and spinning.
        await self._cancel_respond_task(session_id)

        async with self._lock:
            transport = self._sessions.pop(session_id, None)
            self._session_owners.pop(session_id, None)
            if session_id in self._client_capabilities:
                self._client_capabilities.pop(session_id)
                logger.debug(f"Removed capabilities for session {session_id}")
    finally:
        # The closing flag is cleared on every exit path.
        self._closing_sessions.discard(session_id)

    if transport:
        try:
            await transport.disconnect()
        except Exception as e:
            logger.error(f"Error disconnecting transport for session {session_id}: {e}")

    if self._backend == "redis":
        if not self._redis:
            return
        try:
            await self._redis.delete(f"mcp:session:{session_id}")
            await self._redis.delete(self._session_owner_key(session_id))
            # Let the other workers know the session is gone.
            removal_event = {"type": "remove", "session_id": session_id, "timestamp": time.time()}
            await self._redis.publish("mcp_session_events", orjson.dumps(removal_event))
        except Exception as e:
            logger.error(f"Redis error removing session {session_id}: {e}")

    elif self._backend == "database":

        def _db_remove() -> None:
            """Delete the session record from the database.

            Intended to run in a worker thread. Deleting a session id that
            has no record is not an error.

            Raises:
                Exception: Any database error is re-raised after rollback.
            """
            db_session = next(get_db())
            try:
                db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).delete()
                db_session.commit()
            except Exception:
                db_session.rollback()
                raise
            finally:
                db_session.close()

        try:
            await asyncio.to_thread(_db_remove)
        except Exception as e:
            logger.error(f"Database error removing session {session_id}: {e}")

    logger.info(f"Removed session: {session_id}")
async def broadcast(self, session_id: str, message: Dict[str, Any]) -> None:
    """Send a message to a session through the configured backend.

    The message is wrapped in an envelope carrying ``type``, ``message`` and
    ``timestamp`` and then handed over according to the backend:

    - **memory**: Stores message temporarily for local delivery
    - **redis**: Publishes message to Redis channel for the session
    - **database**: Stores message in database for polling by worker with session
    - **none**: No operation

    Redis and database failures are logged and not raised.

    Args:
        session_id: Target session identifier.
        message: Message to broadcast. Can be a dict, list, or any JSON-serializable object.

    Examples:
        >>> import asyncio
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>>
        >>> reg = SessionRegistry(backend='memory')
        >>> message = {'method': 'tools/list', 'id': 1}
        >>> asyncio.run(reg.broadcast('session-789', message))
        >>>
        >>> # Message stored for memory backend
        >>> reg._session_message is not None
        True
        >>> reg._session_message['session_id']
        'session-789'
        >>> orjson.loads(reg._session_message['message'])['message'] == message
        True
    """
    backend = self._backend
    if backend == "none":
        return

    def _build_payload(msg: Any) -> str:
        """Wrap a message in the broadcast envelope and serialise it.

        Args:
            msg: Message to wrap in payload envelope.

        Returns:
            JSON-encoded string containing type, message, and timestamp.
        """
        envelope = {"type": "message", "message": msg, "timestamp": time.time()}
        return orjson.dumps(envelope).decode()

    if backend == "memory":
        self._session_message = {"session_id": session_id, "message": _build_payload(message)}
        return

    if backend == "redis":
        if not self._redis:
            logger.warning(f"Redis client not initialized, cannot broadcast to {session_id}")
            return
        try:
            # The message stays in its original form inside the envelope so
            # the whole thing is JSON-encoded exactly once.
            redis_envelope = {
                "type": "message",
                "message": message,
                "timestamp": time.time(),
            }
            await self._redis.publish(session_id, orjson.dumps(redis_envelope))
        except Exception as e:
            logger.error(f"Redis error during broadcast: {e}")
        return

    if backend == "database":
        try:
            msg_json = _build_payload(message)

            def _db_add() -> None:
                """Insert the serialised message as a ``SessionMessageRecord``.

                Intended to run in a worker thread so the write does not
                block the event loop.

                Raises:
                    Exception: Any database error is re-raised after rollback.
                """
                db_session = next(get_db())
                try:
                    db_session.add(SessionMessageRecord(session_id=session_id, message=msg_json))
                    db_session.commit()
                except Exception:
                    db_session.rollback()
                    raise
                finally:
                    db_session.close()

            await asyncio.to_thread(_db_add)
        except Exception as e:
            logger.error(f"Database error during broadcast: {e}")
async def _register_session_mapping(self, session_id: str, message: Dict[str, Any], user_email: Optional[str] = None) -> None:
    """Register session mapping for session affinity when tools are called.

    For a ``tools/call`` message, the tool is looked up in the tool lookup
    cache and, when its gateway URL, id and transport are all known, the
    downstream session id is registered with the MCP session pool.

    Nothing is registered when session affinity is disabled, when the
    message is not a ``tools/call``, when no tool name can be read from the
    message, or when the tool or its gateway details are unavailable. Lookup
    and registration failures are logged as warnings and never raised.

    Args:
        session_id: The downstream SSE session ID.
        message: The MCP protocol message being broadcast.
        user_email: Optional user email for session isolation.
    """
    # Skip if session affinity is disabled
    if not settings.mcpgateway_session_affinity_enabled:
        return

    # Only tools/call needs a mapping
    if message.get("method") != "tools/call":
        return

    # "params" may be absent, null, or a positional array. Only an object can
    # carry a tool name, so anything else is skipped instead of raising
    # AttributeError back into the caller.
    params = message.get("params")
    if not isinstance(params, dict):
        return

    tool_name = params.get("name")
    if not tool_name:
        return

    try:
        # Look up tool in cache to get gateway info
        # First-Party
        from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache  # pylint: disable=import-outside-toplevel

        tool_info = await tool_lookup_cache.get(tool_name)
        if not tool_info:
            logger.debug(f"Tool {tool_name} not found in cache, skipping session mapping registration")
            return

        # Extract gateway information
        gateway = tool_info.get("gateway", {})
        gateway_url = gateway.get("url")
        gateway_id = gateway.get("id")
        transport = gateway.get("transport")

        if not gateway_url or not gateway_id or not transport:
            logger.debug(f"Incomplete gateway info for tool {tool_name}, skipping session mapping registration")
            return

        # Register the session mapping with the pool
        # First-Party
        from mcpgateway.services.mcp_session_pool import get_mcp_session_pool  # pylint: disable=import-outside-toplevel

        pool = get_mcp_session_pool()
        await pool.register_session_mapping(
            session_id,
            gateway_url,
            gateway_id,
            transport,
            user_email,
        )

        logger.debug(f"Registered session mapping for session {session_id[:8]}... -> {gateway_url} (tool: {tool_name})")

    except Exception as e:
        # Don't fail the broadcast if session mapping registration fails
        logger.warning(f"Failed to register session mapping for {session_id[:8]}...: {e}")
async def get_all_session_ids(self) -> list[str]:
    """Take a snapshot of the session IDs held by this worker.

    Returns:
        list[str]: A snapshot list of currently known local session IDs.
    """
    # Copy under the lock so the caller gets a list that later registry
    # changes cannot mutate.
    async with self._lock:
        snapshot = list(self._sessions)
    return snapshot
def get_session_sync(self, session_id: str) -> Any:
    """Return the locally held transport for a session without awaiting.

    Only the local cache is read; the distributed backend is never queried
    and the registry lock is not taken. With the ``none`` backend the result
    is always ``None``.

    Args:
        session_id: Session identifier to look up.

    Returns:
        SSETransport object if found in local cache, None otherwise.

    Examples:
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>> import asyncio
        >>>
        >>> class MockTransport:
        ...     pass
        >>>
        >>> reg = SessionRegistry()
        >>> transport = MockTransport()
        >>> asyncio.run(reg.add_session('sync-test', transport))
        >>>
        >>> # Synchronous lookup
        >>> found = reg.get_session_sync('sync-test')
        >>> found is transport
        True
        >>>
        >>> # Not found
        >>> reg.get_session_sync('nonexistent') is None
        True
    """
    return None if self._backend == "none" else self._sessions.get(session_id)
async def respond(
    self,
    server_id: Optional[str],
    user: Dict[str, Any],
    session_id: str,
) -> None:
    """Listen for messages addressed to a session and answer them.

    How messages are picked up depends on the configured backend:

    - **none**: Nothing is done.
    - **memory**: The message currently held in ``_session_message`` is
      processed once.
    - **redis**: A pubsub subscription on the ``session_id`` channel is
      polled until the session leaves the local registry or is marked as
      closing.
    - **database**: Queued message records are polled with a growing
      interval until the session record disappears or the session is marked
      as closing.

    A response is generated through ``generate_response`` only when the
    transport for the session is held by this worker. For the Redis and
    database backends the entry in ``_respond_tasks`` is dropped on exit.

    Args:
        server_id: Optional server identifier for scoped operations.
        user: User information including authentication token.
        session_id: Session identifier to respond for.

    Raises:
        asyncio.CancelledError: When the respond task is cancelled (e.g., on session removal).

    Examples:
        >>> import asyncio
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>>
        >>> # This method is typically called internally by the SSE handler
        >>> reg = SessionRegistry()
        >>> user = {'token': 'test-token'}
        >>> # asyncio.run(reg.respond(None, user, 'session-id'))
    """
    backend = self._backend

    if backend == "none":
        return

    if backend == "memory":
        local_transport = self.get_session_sync(session_id)
        if not (local_transport and self._session_message):
            return
        stored_json = self._session_message.get("message")
        if not stored_json:
            logger.warning(f"Session message stored but message content is None for session {session_id}")
            return
        envelope = orjson.loads(stored_json)
        # Unwrap the broadcast envelope when present; otherwise the decoded
        # value itself is the message.
        payload = envelope["message"] if isinstance(envelope, dict) and "message" in envelope else envelope
        await self.generate_response(message=payload, transport=local_transport, server_id=server_id, user=user)
        return

    if backend == "redis":
        if not self._redis:
            logger.warning(f"Redis client not initialized, cannot respond to {session_id}")
            return
        pubsub = self._redis.pubsub()
        await pubsub.subscribe(session_id)

        # Poll with a timeout rather than blocking indefinitely, so the exit
        # conditions at the top of the loop are re-checked regularly.
        poll_timeout = 1.0

        try:
            while True:
                if session_id not in self._sessions or session_id in self._closing_sessions:
                    logger.info(f"Session {session_id} removed or closing, exiting Redis respond loop")
                    break

                try:
                    # The outer timeout is slightly longer than the one given to Redis.
                    incoming = await asyncio.wait_for(
                        pubsub.get_message(ignore_subscribe_messages=True, timeout=poll_timeout),
                        timeout=poll_timeout + 0.5,
                    )
                except asyncio.TimeoutError:
                    # Nothing arrived; go back and re-check the session.
                    continue

                if incoming is None or incoming["type"] != "message":
                    # Short pause so an immediately-returning get_message
                    # cannot turn this into a busy loop.
                    await asyncio.sleep(0.1)
                    continue

                decoded = orjson.loads(incoming["data"])
                payload = decoded.get("message", {})
                local_transport = self.get_session_sync(session_id)
                if local_transport:
                    await self.generate_response(message=payload, transport=local_transport, server_id=server_id, user=user)
        except asyncio.CancelledError:
            logger.info(f"PubSub listener for session {session_id} cancelled")
            raise  # cancellation must keep propagating
        except Exception as e:
            logger.error(f"PubSub listener error for session {session_id}: {e}")
        finally:
            # Pubsub teardown is bounded by timeouts so it cannot hang.
            cleanup_timeout = settings.mcp_session_pool_cleanup_timeout
            try:
                await asyncio.wait_for(pubsub.unsubscribe(session_id), timeout=cleanup_timeout)
            except asyncio.TimeoutError:
                logger.debug(f"Pubsub unsubscribe timed out for session {session_id}")
            except Exception as e:
                logger.debug(f"Error unsubscribing pubsub for session {session_id}: {e}")
            try:
                try:
                    await asyncio.wait_for(pubsub.aclose(), timeout=cleanup_timeout)
                except AttributeError:
                    # Fall back to close() when aclose() is not available.
                    await asyncio.wait_for(pubsub.close(), timeout=cleanup_timeout)
            except asyncio.TimeoutError:
                logger.debug(f"Pubsub close timed out for session {session_id}")
            except Exception as e:
                logger.debug(f"Error closing pubsub for session {session_id}: {e}")
            logger.info(f"Cleaned up pubsub for session {session_id}")
            # Drop the task reference last; a missing entry is fine.
            self._respond_tasks.pop(session_id, None)
        return

    if backend == "database":

        def _db_read_session_and_message(
            session_id: str,
        ) -> tuple[SessionRecord | None, SessionMessageRecord | None]:
            """Fetch the session record and its oldest pending message in one query.

            A LEFT OUTER JOIN from ``SessionRecord`` to
            ``SessionMessageRecord`` ordered by message id is used, and only
            the first row is read. Intended to run in a worker thread.

            Args:
                session_id: The session identifier to look up.

            Returns:
                Tuple[SessionRecord | None, SessionMessageRecord | None]:
                ``(None, None)`` when the session does not exist,
                ``(SessionRecord, None)`` when it exists without pending
                messages, and ``(SessionRecord, SessionMessageRecord)`` when
                a message is pending.

            Raises:
                Exception: Any database error is re-raised after rollback.
            """
            db_session = next(get_db())
            try:
                first_row = (
                    db_session.query(SessionRecord, SessionMessageRecord)
                    .outerjoin(
                        SessionMessageRecord,
                        SessionMessageRecord.session_id == SessionRecord.session_id,
                    )
                    .filter(SessionRecord.session_id == session_id)
                    .order_by(SessionMessageRecord.id.asc())
                    .first()
                )
                if not first_row:
                    return None, None
                session_row, message_row = first_row
                return session_row, message_row
            except Exception:
                db_session.rollback()
                raise
            finally:
                db_session.close()

        def _db_remove(session_id: str, message: str) -> None:
            """Delete a processed message from the database.

            Rows are matched on both the session id and the exact message
            text. Intended to run in a worker thread.

            Args:
                session_id: The session identifier the message belongs to.
                message: The exact message content to remove (must match exactly).

            Raises:
                Exception: Any database error is re-raised after rollback.
            """
            db_session = next(get_db())
            try:
                pending = db_session.query(SessionMessageRecord).filter(SessionMessageRecord.session_id == session_id)
                pending.filter(SessionMessageRecord.message == message).delete()
                db_session.commit()
                logger.info("Removed message from mcp_messages table")
            except Exception:
                db_session.rollback()
                raise
            finally:
                db_session.close()

        async def message_check_loop(session_id: str) -> None:
            """Poll the database for messages of a session with backoff.

            Polling starts at ``settings.poll_interval``. Each poll that
            finds no message multiplies the interval by
            ``settings.backoff_factor``, capped at ``settings.max_interval``;
            a poll that finds a message resets it to the starting value. The
            loop ends when the session is marked as closing or its record no
            longer exists. Errors other than cancellation are logged and end
            the loop.

            Args:
                session_id (str): Unique identifier of the session to monitor.

            Raises:
                asyncio.CancelledError: When the polling loop is cancelled.
            """
            poll_interval = settings.poll_interval
            max_interval = settings.max_interval
            backoff_factor = settings.backoff_factor
            try:
                while True:
                    # Skip the query entirely once the session is closing.
                    if session_id in self._closing_sessions:
                        logger.debug("Session %s closing, stopping poll loop early", session_id)
                        break

                    session_row, record = await asyncio.to_thread(_db_read_session_and_message, session_id)

                    if not session_row:
                        logger.debug("Session %s no longer exists, stopping poll loop", session_id)
                        break

                    if not record:
                        # Idle poll: back off, up to the configured maximum.
                        poll_interval = min(poll_interval * backoff_factor, max_interval)
                    else:
                        # Activity: return to the fast interval.
                        poll_interval = settings.poll_interval

                        decoded = orjson.loads(record.message)
                        payload = decoded["message"] if isinstance(decoded, dict) and "message" in decoded else decoded

                        local_transport = self.get_session_sync(session_id)
                        if local_transport:
                            logger.info("Ready to respond")
                            await self.generate_response(
                                message=payload,
                                transport=local_transport,
                                server_id=server_id,
                                user=user,
                            )

                            # The record is removed once a response has been generated.
                            await asyncio.to_thread(_db_remove, session_id, record.message)

                    await asyncio.sleep(poll_interval)
            except asyncio.CancelledError:
                logger.info(f"Message check loop cancelled for session {session_id}")
                raise  # cancellation must keep propagating
            except Exception as e:
                logger.error(f"Message check loop error for session {session_id}: {e}")

        # The loop is awaited (not scheduled separately) so that cancelling
        # this respond() call also cancels the polling loop.
        try:
            await message_check_loop(session_id)
        except asyncio.CancelledError:
            logger.info(f"Database respond cancelled for session {session_id}")
            raise
        finally:
            # Drop the task reference on every exit path.
            self._respond_tasks.pop(session_id, None)
1822 async def _refresh_redis_sessions(self) -> None:
1823 """Refresh TTLs for Redis sessions and clean up disconnected sessions.
1825 This internal method is used by the Redis backend to maintain session state.
1826 It checks all local sessions, refreshes TTLs for connected sessions, and
1827 removes disconnected ones.
1828 """
1829 if not self._redis:
1830 return
1831 try:
1832 # Check all local sessions
1833 local_transports = {}
1834 async with self._lock:
1835 local_transports = self._sessions.copy()
1837 for session_id, transport in local_transports.items():
1838 try:
1839 if await transport.is_connected():
1840 # Refresh TTL in Redis
1841 await self._redis.expire(f"mcp:session:{session_id}", self._session_ttl)
1842 else:
1843 # Remove disconnected session
1844 await self.remove_session(session_id)
1845 except Exception as e:
1846 logger.error(f"Error refreshing session {session_id}: {e}")
1848 except Exception as e:
1849 logger.error(f"Error in Redis session refresh: {e}")
async def _db_cleanup_task(self) -> None:
    """Run the periodic maintenance loop for the database backend.

    Each pass deletes ``SessionRecord`` rows whose ``last_accessed`` is older
    than the session TTL (in a worker thread), then calls
    ``_cleanup_database_sessions`` to reconcile the locally held sessions
    with the database. The loop sleeps 300 seconds after a successful pass
    and 600 seconds after a failed one.

    Raises:
        asyncio.CancelledError: Re-raised when the task is cancelled, after
            the cancellation has been logged.
    """

    def _purge_expired() -> int:
        """Delete expired session rows and report how many were removed.

        Runs synchronously and is meant to be dispatched through
        ``asyncio.to_thread`` so the bulk delete does not block the event loop.
        The cutoff is computed in Python (timezone-aware UTC now minus the
        session TTL).

        Returns:
            int: Number of rows deleted.

        Raises:
            Exception: Any database error, re-raised after a rollback.
        """
        db = next(get_db())
        try:
            cutoff = datetime.now(timezone.utc) - timedelta(seconds=self._session_ttl)
            removed = db.query(SessionRecord).filter(SessionRecord.last_accessed < cutoff).delete()
            db.commit()
            return removed
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

    logger.info("Starting database cleanup task")
    while True:
        try:
            purged = await asyncio.to_thread(_purge_expired)
            if purged > 0:
                logger.info(f"Cleaned up {purged} expired database sessions")

            # Reconcile local transports with what the database still knows about.
            await self._cleanup_database_sessions()

            await asyncio.sleep(300)  # Run every 5 minutes

        except asyncio.CancelledError:
            logger.info("Database cleanup task cancelled")
            raise
        except Exception as e:
            logger.error(f"Error in database cleanup task: {e}")
            await asyncio.sleep(600)  # Sleep longer on error
def _refresh_session_db(self, session_id: str) -> bool:
    """Touch a session row so it is not treated as expired.

    Looks up the ``SessionRecord`` for ``session_id`` and, when it exists,
    sets its ``last_accessed`` column to the database's current time and
    commits. This is a blocking call; ``_cleanup_database_sessions`` runs it
    through ``asyncio.to_thread``.

    Args:
        session_id: The session identifier to refresh.

    Returns:
        bool: True when the row was found and updated, False when no row
        exists for this session ID.

    Raises:
        Exception: Any database error, re-raised after a rollback.
    """
    db = next(get_db())
    try:
        record = db.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
        if not record:
            return False
        record.last_accessed = func.now()  # pylint: disable=not-callable
        db.commit()
        return True
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()
async def _cleanup_database_sessions(self, max_concurrent: int = 20) -> None:
    """Reconcile locally held sessions with the database.

    Works in two phases on a snapshot of the local transports:

    1. Each transport is asked whether it is still connected. Disconnected
       sessions are removed; a failing check is only logged, so the session
       is kept.
    2. The connected sessions are refreshed in the database concurrently,
       with a semaphore capping the number of refreshes in flight. A session
       whose row no longer exists is removed locally; a refresh that raised
       is only logged.

    Args:
        max_concurrent: Upper bound on simultaneous database refresh
            operations. Defaults to 20.
    """
    async with self._lock:
        snapshot = self._sessions.copy()

    # Phase 1: connection checks.
    alive: list[str] = []
    for sid, transport in snapshot.items():
        try:
            if await transport.is_connected():
                alive.append(sid)
            else:
                await self.remove_session(sid)
        except Exception as e:
            # Only log error, don't remove session on transient errors
            logger.error(f"Error checking connection for session {sid}: {e}")

    if not alive:
        return

    # Phase 2: bounded parallel refresh of the sessions that are still connected.
    gate = asyncio.Semaphore(max_concurrent)

    async def _refresh_one(sid: str) -> bool:
        """Refresh one session row while holding a semaphore slot.

        Args:
            sid: The session ID to refresh.

        Returns:
            True if the row was found and updated, False otherwise.
        """
        async with gate:
            return await asyncio.to_thread(self._refresh_session_db, sid)

    outcomes = await asyncio.gather(*(_refresh_one(sid) for sid in alive), return_exceptions=True)

    for sid, outcome in zip(alive, outcomes):
        try:
            if isinstance(outcome, Exception):
                # Only log error, don't remove session on transient DB errors
                logger.error(f"Error refreshing session {sid}: {outcome}")
                continue
            if not outcome:
                # Row is gone from the database, so drop the local session too.
                await self.remove_session(sid)
        except Exception as e:
            logger.error(f"Error processing refresh result for session {sid}: {e}")
async def _memory_cleanup_task(self) -> None:
    """Run the periodic maintenance loop for the in-memory backend.

    Each pass snapshots the local sessions and removes every session whose
    transport reports it is no longer connected. A session is also removed
    when checking it raises. The loop sleeps 60 seconds after a successful
    pass and 300 seconds after a pass that failed as a whole.

    Raises:
        asyncio.CancelledError: Re-raised when the task is cancelled, after
            the cancellation has been logged.
    """
    logger.info("Starting memory cleanup task")
    while True:
        try:
            # Copy under the lock; the checks below run without holding it.
            async with self._lock:
                snapshot = self._sessions.copy()

            for sid, transport in snapshot.items():
                try:
                    connected = await transport.is_connected()
                    if connected:
                        continue
                    await self.remove_session(sid)
                except Exception as e:
                    # Unlike the database backend, a failing check evicts the session.
                    logger.error(f"Error checking session {sid}: {e}")
                    await self.remove_session(sid)

            await asyncio.sleep(60)  # Run every minute

        except asyncio.CancelledError:
            logger.info("Memory cleanup task cancelled")
            raise
        except Exception as e:
            logger.error(f"Error in memory cleanup task: {e}")
            await asyncio.sleep(300)  # Sleep longer on error
def _get_oauth_experimental_config(self, server_id: str) -> Optional[Dict[str, Dict[str, Any]]]:
    """Build the OAuth entry advertised in MCP capabilities for a server.

    Synchronous database lookup; ``handle_initialize_logic`` runs it through
    ``asyncio.to_thread``. Only an allow-listed subset of the stored OAuth
    configuration is copied into the result (authorization servers, supported
    scopes, bearer methods), so secrets held in the configuration are never
    exposed.

    Args:
        server_id: The server ID to query OAuth configuration for.

    Returns:
        ``{"oauth": {...}}`` when the server exists, has OAuth enabled, has a
        configuration and that configuration names at least one authorization
        server; otherwise None.
    """
    # First-Party
    from mcpgateway.db import Server as DbServer  # pylint: disable=import-outside-toplevel
    from mcpgateway.db import SessionLocal  # pylint: disable=import-outside-toplevel

    db = SessionLocal()
    try:
        server = db.get(DbServer, server_id)
        if not server or not getattr(server, "oauth_enabled", False) or not getattr(server, "oauth_config", None):
            return None

        stored = server.oauth_config
        advertised: Dict[str, Any] = {}

        # Accept either the plural list or the legacy single-server key.
        auth_servers = stored.get("authorization_servers")
        single_server = None if auth_servers else stored.get("authorization_server")
        if auth_servers:
            advertised["authorization_servers"] = auth_servers
        elif single_server:
            advertised["authorization_servers"] = [single_server]

        supported_scopes = stored.get("scopes_supported") or stored.get("scopes")
        if supported_scopes:
            advertised["scopes_supported"] = supported_scopes

        advertised["bearer_methods_supported"] = stored.get("bearer_methods_supported", ["header"])

        # Without an authorization server there is nothing useful to advertise.
        if not advertised.get("authorization_servers"):
            return None

        logger.debug(f"Advertising OAuth capability for server {server_id}")
        return {"oauth": advertised}
    finally:
        db.close()
2086 # Handle initialize logic
async def handle_initialize_logic(self, body: Dict[str, Any], session_id: Optional[str] = None, server_id: Optional[str] = None) -> InitializeResult:
    """Handle the MCP ``initialize`` handshake.

    Reads the requested protocol version from the body, records the client's
    capabilities against the session (when both are present), optionally looks
    up the server's OAuth configuration, and returns the gateway's
    capabilities and identity. All of this runs inside an ``mcp.initialize``
    span.

    Args:
        body: Request body. The protocol version is read from
            'protocol_version' or 'protocolVersion'; client capabilities from
            'capabilities'.
        session_id: Optional session ID to associate client capabilities with.
        server_id: Optional server ID whose OAuth configuration should be
            advertised under the experimental capabilities.

    Returns:
        InitializeResult containing protocol version, server capabilities, and server info.

    Raises:
        HTTPException: If no protocol version is supplied (400 Bad Request
            with header ``MCP-Error-Code: -32002``).

    Examples:
        >>> import asyncio
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>>
        >>> reg = SessionRegistry()
        >>> body = {'protocol_version': '2025-06-18'}
        >>> result = asyncio.run(reg.handle_initialize_logic(body))
        >>> result.protocol_version
        '2025-06-18'
        >>> result.server_info.name
        'ContextForge'
        >>>
        >>> # Missing protocol version
        >>> try:
        ...     asyncio.run(reg.handle_initialize_logic({}))
        ... except HTTPException as e:
        ...     e.status_code
        400
    """
    # First-Party
    from mcpgateway.observability import create_span  # pylint: disable=import-outside-toplevel

    requested_version = body.get("protocol_version") or body.get("protocolVersion")
    client_caps = body.get("capabilities", {})
    attrs: Dict[str, Any] = {
        "mcp.protocol_version": requested_version,
        "mcp.session_id": session_id,
        "server.id": server_id,
    }

    with create_span("mcp.initialize", attrs):
        # Validation happens inside the span so a rejected handshake is traced too.
        if not requested_version:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Missing protocol version",
                headers={"MCP-Error-Code": "-32002"},
            )

        # A non-default version is accepted; it is only logged.
        if requested_version != settings.protocol_version:
            logger.warning(f"Using non default protocol version: {requested_version}")

        if session_id and client_caps:
            await self.store_client_capabilities(session_id, client_caps)
            logger.debug(f"Stored capabilities for session {session_id}: {client_caps}")

        # Experimental capabilities stay None unless an OAuth config is found.
        oauth_capability: Optional[Dict[str, Dict[str, Any]]] = None
        if server_id:
            try:
                # The lookup is synchronous, so run it off the event loop.
                oauth_capability = await asyncio.to_thread(self._get_oauth_experimental_config, server_id)
            except Exception as e:
                logger.warning(f"Failed to query OAuth config for server {server_id}: {e}")

        server_capabilities = ServerCapabilities(
            prompts={"listChanged": True},
            resources={"subscribe": True, "listChanged": True},
            tools={"listChanged": True},
            logging={},
            completions={},  # Advertise completions capability per MCP spec
            experimental=oauth_capability,  # OAuth capability when configured
        )
        return InitializeResult(
            protocolVersion=requested_version,
            capabilities=server_capabilities,
            serverInfo=Implementation(name=settings.app_name, version=__version__),
            instructions=("ContextForge providing federated tools, resources and prompts. Use /admin interface for configuration."),
        )
async def store_client_capabilities(self, session_id: str, capabilities: Dict[str, Any]) -> None:
    """Remember the capabilities a client declared for its session.

    The mapping is written under the registry lock and replaces any
    capabilities stored earlier for the same session ID.

    Args:
        session_id: The session ID the capabilities belong to.
        capabilities: Client capabilities dictionary from the initialize request.
    """
    async with self._lock:
        self._client_capabilities[session_id] = capabilities
        logger.debug(f"Stored capabilities for session {session_id}")
async def get_client_capabilities(self, session_id: str) -> Optional[Dict[str, Any]]:
    """Look up the capabilities stored for a session.

    The lookup is performed under the registry lock.

    Args:
        session_id: The session ID to look up.

    Returns:
        The stored client capabilities dictionary, or None when nothing has
        been stored for this session ID.
    """
    async with self._lock:
        stored = self._client_capabilities.get(session_id)
    return stored
async def has_elicitation_capability(self, session_id: str) -> bool:
    """Report whether a session's client declared elicitation support.

    Args:
        session_id: The session ID to check.

    Returns:
        True when capabilities are stored for the session and their
        'elicitation' entry is truthy; False otherwise (including when no
        capabilities are stored at all).
    """
    declared = await self.get_client_capabilities(session_id)
    # Missing or empty capabilities short-circuit to False before the key lookup.
    return bool(declared and declared.get("elicitation"))
async def get_elicitation_capable_sessions(self) -> list[str]:
    """List the local sessions whose clients declared elicitation support.

    A session is included only when its stored capabilities have a truthy
    'elicitation' entry and the session is still present in the local
    session table. The scan runs under the registry lock.

    Returns:
        List of session IDs with elicitation capability.
    """
    async with self._lock:
        return [
            sid
            for sid, declared in self._client_capabilities.items()
            # Capabilities may outlive their session, so check the session still exists.
            if declared.get("elicitation") and sid in self._sessions
        ]
async def generate_response(self, message: Dict[str, Any], transport: SSETransport, server_id: Optional[str], user: Dict[str, Any]) -> None:
    """Relay an incoming MCP message to the internal ``/rpc`` endpoint and send the reply.

    The message is wrapped in a JSON-RPC 2.0 request (with ``server_id`` added
    to its params), posted to the gateway's own ``/rpc`` endpoint via the
    internal loopback base URL, and the outcome is sent back through the
    transport. The loopback URL is used so the call still works when the
    client-facing URL is not reachable from the server itself (reverse proxy,
    service mesh).

    Messages lacking either 'method' or 'id' are ignored: nothing is sent.

    Outcome mapping:

    - downstream body with an ``error`` member -> forwarded as a JSON-RPC error;
    - any other downstream body -> its ``result`` member (``{}`` when absent);
    - ``JSONRPCError`` raised locally -> its ``error`` object;
    - any other exception -> error code -32000 "Internal error".

    After an 'initialize' request the method also emits
    ``notifications/initialized`` followed by the three ``list_changed``
    notifications.

    Args:
        message: Incoming MCP message as JSON. Must contain 'method' and 'id' fields.
            When it carries a 'params' dict, that dict is updated in place with
            'server_id'.
        transport: SSE transport to send responses through.
        server_id: Optional server ID for scoped operations.
        user: User information. 'auth_token' is used as the bearer token when
            present; otherwise a fallback session token is minted from 'email',
            'full_name' and 'is_admin'. '_passthrough_headers', when present,
            are forwarded after filtering.

    Examples:
        >>> import asyncio
        >>> from mcpgateway.cache.session_registry import SessionRegistry
        >>>
        >>> class MockTransport:
        ...     async def send_message(self, msg):
        ...         print(f"Response: {msg['method'] if 'method' in msg else msg.get('result', {})}")
        >>>
        >>> reg = SessionRegistry()
        >>> transport = MockTransport()
        >>> message = {"method": "ping", "id": 1}
        >>> user = {"token": "test-token"}
        >>> # asyncio.run(reg.generate_response(message, transport, None, user))
        >>> # Response: {}
    """
    if "method" in message and "id" in message:
        method = message["method"]
        params = message.get("params")
        if params is None:
            # Covers both a missing key and an explicit ``"params": null``; the
            # latter used to raise TypeError on the assignment below.
            params = {}
        params["server_id"] = server_id
        req_id = message["id"]

        rpc_input = {
            "jsonrpc": "2.0",
            "method": method,
            "params": params,
            "id": req_id,
        }
        # Get the token from the current authentication context
        # The user object should contain auth_token, token_teams, and is_admin from the SSE endpoint
        token = None
        is_admin = user.get("is_admin", False)  # Preserve admin status from SSE endpoint

        try:
            if hasattr(user, "get") and user.get("auth_token"):
                token = user["auth_token"]
            else:
                # Fallback: create lightweight session token (teams resolved server-side by downstream /rpc)
                logger.warning("No auth token available for SSE RPC call - creating fallback session token")
                now = datetime.now(timezone.utc)
                payload = {
                    "sub": user.get("email", "system"),
                    "iss": settings.jwt_issuer,
                    "aud": settings.jwt_audience,
                    "iat": int(now.timestamp()),
                    "jti": str(uuid.uuid4()),
                    "token_use": "session",  # nosec B105 - token type marker, not a password
                    "user": {
                        "email": user.get("email", "system"),
                        "full_name": user.get("full_name", "System"),
                        "is_admin": is_admin,  # Preserve admin status for cookie-authenticated admins
                        "auth_provider": "internal",
                    },
                }
                # Generate token using centralized token creation
                token = await create_jwt_token(payload)

            # Pass downstream session id to /rpc for session affinity.
            # This is gateway-internal only; the pool strips it before contacting upstream MCP servers.
            if settings.mcpgateway_session_affinity_enabled:
                await self._register_session_mapping(transport.session_id, message, user.get("email") if hasattr(user, "get") else None)

            headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
            if settings.mcpgateway_session_affinity_enabled:
                headers["x-mcp-session-id"] = transport.session_id
            # Forward passthrough headers captured at SSE connection time (see #3640).
            # This ensures X-Upstream-Authorization and other client passthrough headers
            # reach the /rpc endpoint, which then forwards them to upstream MCP servers.
            # Defense-in-depth: filter via filter_loopback_skip_headers() so passthrough
            # can never override the gateway's internal JWT, content-type, or session/routing headers.
            # First-Party
            from mcpgateway.utils.passthrough_headers import filter_loopback_skip_headers  # pylint: disable=import-outside-toplevel

            passthrough = user.get("_passthrough_headers") or {}
            if passthrough and isinstance(passthrough, dict):
                headers.update(filter_loopback_skip_headers(passthrough))
            # Use loopback for internal RPC call (consistent with other self-call sites
            # in mcp_session_pool.py and streamablehttp_transport.py). This avoids
            # failures when the client-facing URL is not reachable from the server
            # (e.g., behind a reverse proxy or service mesh). See #3049.
            rpc_url = f"{internal_loopback_base_url()}/rpc"

            logger.info(f"SSE RPC: Making call to {rpc_url} with method={method}, params={params}")

            async with ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": internal_loopback_verify()}) as client:
                logger.info(f"SSE RPC: Sending request to {rpc_url}")
                rpc_response = await client.post(
                    url=rpc_url,
                    json=rpc_input,
                    headers=headers,
                )
                logger.info(f"SSE RPC: Got response status {rpc_response.status_code}")
                rpc_body = rpc_response.json()
                logger.info(f"SSE RPC: Response content: {rpc_body}")

            # A JSON-RPC response carries either "result" or "error". Forward a
            # downstream error as an error instead of masking it as an empty,
            # successful result. A non-dict body raises here and is reported by
            # the generic handler below.
            downstream_error = rpc_body.get("error")
            if downstream_error is not None:
                response = {"jsonrpc": "2.0", "error": downstream_error, "id": req_id}
            else:
                response = {"jsonrpc": "2.0", "result": rpc_body.get("result", {}), "id": req_id}
        except JSONRPCError as e:
            logger.error(f"SSE RPC: JSON-RPC error: {e}")
            error_payload = e.to_dict()
            response = {"jsonrpc": "2.0", "error": error_payload["error"], "id": req_id}
        except Exception as e:
            logger.error(f"SSE RPC: Exception during RPC call: {type(e).__name__}: {e}")
            logger.error(f"SSE RPC: Traceback: {traceback.format_exc()}")
            error_payload = {"code": -32000, "message": "Internal error", "data": str(e)}
            response = {"jsonrpc": "2.0", "error": error_payload, "id": req_id}

        # Use the module logger (not the root ``logging`` module) so this line
        # follows the same logger configuration as the rest of the file.
        logger.debug(f"Sending sse message:{response}")
        await transport.send_message(response)

        if message["method"] == "initialize":
            await transport.send_message(
                {
                    "jsonrpc": "2.0",
                    "method": "notifications/initialized",
                    "params": {},
                }
            )
            notifications = [
                "tools/list_changed",
                "resources/list_changed",
                "prompts/list_changed",
            ]
            for notification in notifications:
                await transport.send_message(
                    {
                        "jsonrpc": "2.0",
                        "method": f"notifications/{notification}",
                        "params": {},
                    }
                )