Coverage for mcpgateway / services / notification_service.py: 96%
185 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/notification_service.py
3Copyright 2026
4SPDX-License-Identifier: Apache-2.0
6Authors: Keval Mahajan
8Description:
9 MCP Notification Service for handling server notifications with debounced
10 gateway refresh. Provides centralized notification handling for MCP sessions
11 including tools/list_changed, resources/list_changed, and prompts/list_changed.
13 Key Features:
14 - Debounced refresh to prevent notification storms
15 - Flag merging during debounce (notifications within window merge their refresh flags)
16 - Per-gateway refresh locking to prevent concurrent refresh races
17 - Per-gateway refresh tracking with capability awareness
18 - Compatible with MCPSessionPool for pooled session notification handling
19 - Per-gateway session isolation ensures correct notification attribution
20 - Supports tools, resources, and prompts list_changed notifications
22 Intended to also cover other notification-driven tasks such as cancellation and progress notifications (not yet implemented here).
24Usage:
25 ```python
26 from mcpgateway.services.notification_service import NotificationService
28 # Create service instance
29 notification_service = NotificationService()
30 await notification_service.initialize()
32 # Create a message handler for a specific gateway
33 handler = notification_service.create_message_handler(gateway_id="gw-123")
35 # Pass handler to ClientSession
36 session = ClientSession(read_stream, write_stream, message_handler=handler)
37 ```
38"""
40# Future
41from __future__ import annotations
43# Standard
44import asyncio
45from dataclasses import dataclass, field
46from enum import Enum
47import time
48from typing import Any, Awaitable, Callable, Dict, Optional, Set, TYPE_CHECKING
50# Third-Party
51from mcp.shared.session import RequestResponder
52import mcp.types as mcp_types
54# First-Party
55from mcpgateway.services.logging_service import LoggingService
57if TYPE_CHECKING:
58 # First-Party
59 from mcpgateway.services.gateway_service import GatewayService
# Union of everything the message handler accepts: a RequestResponder,
# a ServerNotification, or an Exception.
_IncomingMessage = RequestResponder[mcp_types.ServerRequest, mcp_types.ClientResult] | mcp_types.ServerNotification | Exception

# Type alias for message handler callback: an async callable taking one
# incoming message and returning nothing.
MessageHandlerCallback = Callable[[_IncomingMessage], Awaitable[None]]

# Initialize logging: one LoggingService for the module, plus its named logger.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)
class NotificationType(Enum):
    """MCP ``list_changed`` notification kinds recognised by this module.

    Each member's value is the notification method string.

    Attributes:
        TOOLS_LIST_CHANGED: The server's tool list changed.
        RESOURCES_LIST_CHANGED: The server's resource list changed.
        PROMPTS_LIST_CHANGED: The server's prompt list changed.
    """

    TOOLS_LIST_CHANGED = "notifications/tools/list_changed"
    RESOURCES_LIST_CHANGED = "notifications/resources/list_changed"
    PROMPTS_LIST_CHANGED = "notifications/prompts/list_changed"
@dataclass
class GatewayCapabilities:
    """Per-gateway record of which ``list_changed`` notifications are supported.

    Every flag defaults to ``False``.

    Attributes:
        tools_list_changed: Whether the gateway supports tool list changes.
        resources_list_changed: Whether the gateway supports resource list changes.
        prompts_list_changed: Whether the gateway supports prompt list changes.
    """

    # Field order is part of the positional constructor signature; keep it.
    tools_list_changed: bool = False
    resources_list_changed: bool = False
    prompts_list_changed: bool = False
101def _empty_notification_type_set() -> Set[NotificationType]:
102 """Factory function for creating an empty set of NotificationType.
104 Returns:
105 An empty set typed for NotificationType elements.
106 """
107 return set()
@dataclass
class PendingRefresh:
    """A queued refresh request for one gateway, plus what it should cover.

    Attributes:
        gateway_id: The ID of the gateway to refresh.
        enqueued_at: ``time.time()`` value captured when the instance is created.
        include_resources: Whether the refresh should also cover resources.
        include_prompts: Whether the refresh should also cover prompts.
        triggered_by: Notification types that have contributed to this refresh.
    """

    gateway_id: str
    # Wall-clock creation timestamp.
    enqueued_at: float = field(default_factory=time.time)
    include_resources: bool = True
    include_prompts: bool = True
    # Typed factory gives every instance its own, independently mutable set.
    triggered_by: Set[NotificationType] = field(default_factory=_empty_notification_type_set)
class NotificationService:
    """Centralized service for handling MCP server notifications.

    Provides debounced refresh triggering based on list_changed notifications
    from MCP servers. Works with MCPSessionPool to handle notifications for
    pooled sessions while maintaining session isolation.

    Attributes:
        debounce_seconds: Minimum time between refresh operations for same gateway.
        _max_queue_size: Maximum pending refreshes in the queue.

    Example:
        >>> service = NotificationService(debounce_seconds=5.0)
        >>> service.debounce_seconds
        5.0
        >>> service._gateway_capabilities == {}
        True
    """

    def __init__(
        self,
        debounce_seconds: float = 5.0,
        max_queue_size: int = 100,
    ) -> None:
        """Initialize the NotificationService.

        Args:
            debounce_seconds: Minimum time between refreshes for same gateway.
            max_queue_size: Maximum number of pending refreshes in queue.

        Example:
            >>> service = NotificationService(debounce_seconds=10.0, max_queue_size=50)
            >>> service.debounce_seconds
            10.0
            >>> service._max_queue_size
            50
        """
        self.debounce_seconds = debounce_seconds
        self._max_queue_size = max_queue_size

        # Track gateway capabilities for list_changed support
        self._gateway_capabilities: Dict[str, GatewayCapabilities] = {}

        # Debounce tracking: gateway_id -> last refresh enqueue time
        self._last_refresh_enqueued: Dict[str, float] = {}

        # Track pending refreshes by gateway_id to allow flag merging during debounce
        # When a notification arrives during debounce window, we merge flags instead of dropping
        self._pending_refresh_flags: Dict[str, PendingRefresh] = {}

        # Pending refresh queue
        self._refresh_queue: asyncio.Queue[PendingRefresh] = asyncio.Queue(maxsize=max_queue_size)

        # Background worker task
        self._worker_task: Optional[asyncio.Task[None]] = None
        self._shutdown_event = asyncio.Event()

        # Reference to gateway service for refresh operations (set during initialize)
        self._gateway_service: Optional["GatewayService"] = None

        # Metrics
        self._notifications_received = 0
        self._notifications_debounced = 0
        self._refreshes_triggered = 0
        self._refreshes_failed = 0

    async def initialize(self, gateway_service: Optional["GatewayService"] = None) -> None:
        """Initialize the notification service and start background worker.

        Calling this again while the worker is still running only updates the
        gateway service reference; it does not start a second worker.

        Args:
            gateway_service: Optional GatewayService reference for triggering refreshes.
                Can be set later via set_gateway_service().

        Example:
            >>> import asyncio
            >>> async def test():
            ...     service = NotificationService()
            ...     await service.initialize()
            ...     is_running = service._worker_task is not None
            ...     await service.shutdown()
            ...     return is_running
            >>> asyncio.run(test())
            True
        """
        if gateway_service is not None:
            self._gateway_service = gateway_service

        # Idempotent start: replacing a live task reference would orphan the
        # previous worker, which shutdown() could then never cancel.
        if self._worker_task is not None and not self._worker_task.done():
            logger.debug("NotificationService already initialized; refresh worker still running")
            return

        self._shutdown_event.clear()
        self._worker_task = asyncio.create_task(self._process_refresh_queue())
        logger.info("NotificationService initialized with debounce=%ss", self.debounce_seconds)

    def set_gateway_service(self, gateway_service: "GatewayService") -> None:
        """Set the gateway service reference for refresh operations.

        Args:
            gateway_service: The GatewayService instance to use for refreshes.

        Example:
            >>> from unittest.mock import Mock
            >>> service = NotificationService()
            >>> mock_gateway_service = Mock()
            >>> service.set_gateway_service(mock_gateway_service)
        """
        self._gateway_service = gateway_service

    async def shutdown(self) -> None:
        """Shutdown the notification service and cleanup resources.

        Signals the worker to stop, cancels it, waits for it to finish, and
        clears all per-gateway state.

        Example:
            >>> import asyncio
            >>> async def test():
            ...     service = NotificationService()
            ...     await service.initialize()
            ...     await service.shutdown()
            ...     return service._worker_task is None or service._worker_task.done()
            >>> asyncio.run(test())
            True
        """
        self._shutdown_event.set()

        if self._worker_task:
            self._worker_task.cancel()
            try:
                await self._worker_task
            except asyncio.CancelledError:
                # Expected: the worker re-raises the cancellation we just requested.
                pass
            self._worker_task = None

        self._gateway_capabilities.clear()
        self._last_refresh_enqueued.clear()
        self._pending_refresh_flags.clear()
        logger.info("NotificationService shutdown complete")

    def register_gateway_capabilities(
        self,
        gateway_id: str,
        capabilities: Dict[str, Any],
    ) -> None:
        """Register list_changed capabilities for a gateway.

        Extracts and stores which list_changed notifications the gateway supports
        based on server capabilities returned during initialization. A section
        that is missing or is not a dict counts as "not supported"; a ``None``
        ``capabilities`` value is treated as an empty dict.

        Args:
            gateway_id: The gateway ID.
            capabilities: Server capabilities dict from initialize response.

        Example:
            >>> service = NotificationService()
            >>> caps = {"tools": {"listChanged": True}, "resources": {"listChanged": False}}
            >>> service.register_gateway_capabilities("gw-1", caps)
            >>> service.supports_list_changed("gw-1")
            True
            >>> service._gateway_capabilities["gw-1"].resources_list_changed
            False
        """
        source: Dict[str, Any] = capabilities or {}

        def _list_changed(section: str) -> bool:
            """Return the truthiness of ``source[section]["listChanged"]``, or False.

            Args:
                section: Capability section name ("tools", "resources", "prompts").

            Returns:
                True only when the section is a dict with a truthy ``listChanged``.
            """
            entry = source.get(section)
            if not isinstance(entry, dict):
                return False
            return bool(entry.get("listChanged", False))

        registered = GatewayCapabilities(
            tools_list_changed=_list_changed("tools"),
            resources_list_changed=_list_changed("resources"),
            prompts_list_changed=_list_changed("prompts"),
        )
        self._gateway_capabilities[gateway_id] = registered

        logger.debug(
            "Registered capabilities for gateway %s: tools=%s, resources=%s, prompts=%s",
            gateway_id,
            registered.tools_list_changed,
            registered.resources_list_changed,
            registered.prompts_list_changed,
        )

    def unregister_gateway(self, gateway_id: str) -> None:
        """Unregister a gateway and cleanup its state.

        Removes the gateway's capabilities, debounce timestamp and merge-tracking
        entry. A refresh that is already queued is not removed from the queue.

        Args:
            gateway_id: The gateway ID to unregister.

        Example:
            >>> service = NotificationService()
            >>> service.register_gateway_capabilities("gw-1", {"tools": {"listChanged": True}})
            >>> service.supports_list_changed("gw-1")
            True
            >>> service.unregister_gateway("gw-1")
            >>> service.supports_list_changed("gw-1")
            False
        """
        self._gateway_capabilities.pop(gateway_id, None)
        self._last_refresh_enqueued.pop(gateway_id, None)
        # Also drop merge tracking so no per-gateway entry outlives the gateway.
        self._pending_refresh_flags.pop(gateway_id, None)

    def supports_list_changed(self, gateway_id: str) -> bool:
        """Check if a gateway supports any list_changed notifications.

        Args:
            gateway_id: The gateway ID to check.

        Returns:
            True if gateway supports at least one list_changed notification type.

        Example:
            >>> service = NotificationService()
            >>> caps = {"tools": {"listChanged": True}}
            >>> service.register_gateway_capabilities("gw-1", caps)
            >>> service.supports_list_changed("gw-1")
            True
            >>> service.supports_list_changed("gw-unknown")
            False
        """
        caps = self._gateway_capabilities.get(gateway_id)
        if caps is None:
            return False
        return caps.tools_list_changed or caps.resources_list_changed or caps.prompts_list_changed

    def create_message_handler(
        self,
        gateway_id: str,
        gateway_url: Optional[str] = None,
    ) -> MessageHandlerCallback:
        """Create a message handler callback for a specific gateway.

        Returns a callback suitable for passing to ClientSession's message_handler
        parameter. The handler routes notifications to this service for processing.

        Args:
            gateway_id: The gateway ID this handler is for.
            gateway_url: Optional URL for logging context.

        Returns:
            Async callable suitable for ClientSession message_handler.

        Example:
            >>> service = NotificationService()
            >>> handler = service.create_message_handler("gw-123")
            >>> callable(handler)
            True
        """

        async def message_handler(
            message: RequestResponder[mcp_types.ServerRequest, mcp_types.ClientResult] | mcp_types.ServerNotification | Exception,
        ) -> None:
            """Handle incoming messages from MCP server.

            Args:
                message: The message received from the server.
            """
            # Only process ServerNotification objects
            if isinstance(message, mcp_types.ServerNotification):
                await self._handle_notification(gateway_id, message, gateway_url)
            elif isinstance(message, Exception):
                logger.warning("Received exception from MCP server %s: %s", gateway_id, message)
            # RequestResponder messages are handled by the session itself

        return message_handler

    async def _handle_notification(
        self,
        gateway_id: str,
        notification: mcp_types.ServerNotification,
        gateway_url: Optional[str] = None,
    ) -> None:
        """Process an incoming server notification.

        list_changed notifications are forwarded to the refresh queue; every
        other notification is only logged.

        Args:
            gateway_id: The gateway ID that sent the notification.
            notification: The notification object.
            gateway_url: Optional URL for logging context.
        """
        self._notifications_received += 1

        # ServerNotification has a 'root' attribute containing the actual notification
        notification_root = notification.root

        # Match notification types - check class names since mcp.types may vary
        root_class = type(notification_root).__name__

        notification_type: Optional[NotificationType] = None
        if "ToolListChangedNotification" in root_class or "ToolsListChangedNotification" in root_class:
            notification_type = NotificationType.TOOLS_LIST_CHANGED
        elif "ResourceListChangedNotification" in root_class or "ResourcesListChangedNotification" in root_class:
            notification_type = NotificationType.RESOURCES_LIST_CHANGED
        elif "PromptListChangedNotification" in root_class or "PromptsListChangedNotification" in root_class:
            notification_type = NotificationType.PROMPTS_LIST_CHANGED

        if notification_type:
            logger.info(
                "Received %s notification from gateway %s (%s)",
                notification_type.value,
                gateway_id,
                gateway_url or "unknown",
            )
            await self._enqueue_refresh(gateway_id, notification_type)
        else:
            logger.info(
                "Received notification from gateway %s: %s",
                gateway_id,
                root_class,
            )

    async def _enqueue_refresh(
        self,
        gateway_id: str,
        notification_type: NotificationType,
    ) -> None:
        """Enqueue a refresh operation with debouncing and flag merging.

        When notifications arrive during the debounce window, their flags are
        merged into the pending refresh instead of being dropped, as long as
        that refresh has not yet been picked up by the worker.

        Args:
            gateway_id: The gateway to refresh.
            notification_type: The type of notification that triggered this.
        """
        now = time.time()
        last_enqueued = self._last_refresh_enqueued.get(gateway_id, 0)

        # A tools notification includes everything as tools are always primary;
        # otherwise include only the list named by the notification.
        is_tools = notification_type == NotificationType.TOOLS_LIST_CHANGED
        include_resources = is_tools or notification_type == NotificationType.RESOURCES_LIST_CHANGED
        include_prompts = is_tools or notification_type == NotificationType.PROMPTS_LIST_CHANGED

        # Debounce: if within window, merge flags into pending refresh instead of dropping
        if now - last_enqueued < self.debounce_seconds:
            self._notifications_debounced += 1
            existing = self._pending_refresh_flags.get(gateway_id)
            if existing:
                # Merge flags - use OR to include all requested types
                existing.include_resources = existing.include_resources or include_resources
                existing.include_prompts = existing.include_prompts or include_prompts
                existing.triggered_by.add(notification_type)
                logger.debug(
                    "Merged %s into pending refresh for gateway %s (resources=%s, prompts=%s)",
                    notification_type.value,
                    gateway_id,
                    existing.include_resources,
                    existing.include_prompts,
                )
                return

            # No pending refresh to merge into: the previous one was already
            # taken off the queue by the worker.
            # NOTE(review): this notification is not retried later, so a change
            # made after that refresh ran but still inside the debounce window
            # is only picked up by a later notification or another refresh
            # path - confirm this is acceptable.
            logger.debug(
                "Debounced refresh for gateway %s (last enqueued %.1fs ago, no pending)",
                gateway_id,
                now - last_enqueued,
            )
            return

        # Create new pending refresh
        pending = PendingRefresh(
            gateway_id=gateway_id,
            include_resources=include_resources,
            include_prompts=include_prompts,
            triggered_by={notification_type},
        )

        try:
            self._refresh_queue.put_nowait(pending)
        except asyncio.QueueFull:
            logger.warning(
                "Refresh queue full, dropping refresh request for gateway %s",
                gateway_id,
            )
            return

        self._last_refresh_enqueued[gateway_id] = now
        self._pending_refresh_flags[gateway_id] = pending  # Track for flag merging
        logger.info(
            "Enqueued refresh for gateway %s (triggered by %s)",
            gateway_id,
            notification_type.value,
        )

    async def _process_refresh_queue(self) -> None:
        """Background worker that processes pending refresh operations.

        Continuously runs until shutdown is triggered, picking up pending
        refreshes from the queue and executing them.

        Raises:
            asyncio.CancelledError: If the task is cancelled during shutdown.
        """
        logger.info("NotificationService refresh worker started")

        while not self._shutdown_event.is_set():
            try:
                # Wait for pending refresh with timeout to allow shutdown check
                try:
                    pending = await asyncio.wait_for(
                        self._refresh_queue.get(),
                        timeout=1.0,
                    )
                except asyncio.TimeoutError:
                    continue

                try:
                    await self._execute_refresh(pending)
                finally:
                    # Every successful get() must be balanced by task_done(),
                    # even when the refresh raises or is cancelled; otherwise
                    # the queue's unfinished-task count never returns to zero.
                    self._refresh_queue.task_done()

            except asyncio.CancelledError:
                logger.debug("Refresh worker cancelled")
                raise
            except Exception as e:
                logger.exception("Error in refresh worker: %s", e)

        logger.info("NotificationService refresh worker stopped")

    async def _execute_refresh(self, pending: PendingRefresh) -> None:
        """Execute a refresh operation.

        Acquires the per-gateway refresh lock to prevent concurrent refreshes
        with manual refresh or health check auto-refresh. If that lock is
        already held, the refresh is skipped and counted as debounced.

        Args:
            pending: The pending refresh to execute.
        """
        # pylint: disable=protected-access
        gateway_id = pending.gateway_id

        # Clear pending flag tracking now that we're processing this refresh.
        # Only remove the entry if it is this very refresh: a newer refresh for
        # the same gateway may already be queued and must keep accepting merges.
        if self._pending_refresh_flags.get(gateway_id) is pending:
            del self._pending_refresh_flags[gateway_id]

        if not self._gateway_service:
            logger.warning(
                "Cannot execute refresh for gateway %s: GatewayService not set",
                gateway_id,
            )
            return

        # Acquire per-gateway lock to prevent concurrent refresh with manual/auto refresh
        lock = self._gateway_service._get_refresh_lock(gateway_id)  # pyright: ignore[reportPrivateUsage]

        # Skip if lock is already held (another refresh in progress)
        if lock.locked():
            logger.debug(
                "Skipping event-driven refresh for gateway %s: lock held (refresh in progress)",
                gateway_id,
            )
            self._notifications_debounced += 1
            return

        async with lock:
            logger.info(
                "Executing event-driven refresh for gateway %s (resources=%s, prompts=%s)",
                pending.gateway_id,
                pending.include_resources,
                pending.include_prompts,
            )

            try:
                # Use the existing refresh method (lock already held)
                result = await self._gateway_service._refresh_gateway_tools_resources_prompts(  # pyright: ignore[reportPrivateUsage]
                    gateway_id=pending.gateway_id,
                    created_via="notification_service",
                    include_resources=pending.include_resources,
                    include_prompts=pending.include_prompts,
                )

                self._refreshes_triggered += 1

                if result.get("success"):
                    logger.info(
                        "Event-driven refresh completed for gateway %s: tools_added=%d, tools_removed=%d",
                        pending.gateway_id,
                        result.get("tools_added", 0),
                        result.get("tools_removed", 0),
                    )
                else:
                    self._refreshes_failed += 1
                    logger.warning(
                        "Event-driven refresh failed for gateway %s: %s",
                        pending.gateway_id,
                        result.get("error"),
                    )

            except Exception as e:
                # Best-effort: a failed refresh is recorded and logged, never
                # propagated to the worker loop.
                self._refreshes_failed += 1
                logger.exception(
                    "Error during event-driven refresh for gateway %s: %s",
                    pending.gateway_id,
                    e,
                )

    def get_metrics(self) -> Dict[str, Any]:
        """Return notification service metrics.

        Returns:
            Dict containing notification and refresh metrics.

        Example:
            >>> service = NotificationService()
            >>> metrics = service.get_metrics()
            >>> "notifications_received" in metrics
            True
        """
        return {
            "notifications_received": self._notifications_received,
            "notifications_debounced": self._notifications_debounced,
            "refreshes_triggered": self._refreshes_triggered,
            "refreshes_failed": self._refreshes_failed,
            "pending_refreshes": self._refresh_queue.qsize(),
            "registered_gateways": len(self._gateway_capabilities),
            "debounce_seconds": self.debounce_seconds,
        }
# Module-level singleton instance. It stays None until init_notification_service()
# creates it, and is reset to None by close_notification_service().
_notification_service: Optional[NotificationService] = None
def get_notification_service() -> NotificationService:
    """Return the module-wide NotificationService singleton.

    Returns:
        The instance previously created by init_notification_service().

    Raises:
        RuntimeError: If service has not been initialized.

    Example:
        >>> try:
        ...     _ = init_notification_service()
        ...     service = get_notification_service()
        ...     result = isinstance(service, NotificationService)
        ... except RuntimeError:
        ...     result = False
        >>> result
        True
    """
    current = _notification_service
    if current is not None:
        return current
    raise RuntimeError("NotificationService not initialized. Call init_notification_service() first.")
def init_notification_service(
    debounce_seconds: float = 5.0,
    max_queue_size: int = 100,
) -> NotificationService:
    """Initialize the global NotificationService.

    Creates a new instance and stores it as the module-level singleton,
    replacing any previous one. The new instance's background worker is not
    started here; that happens in ``NotificationService.initialize()``.

    Args:
        debounce_seconds: Minimum time between refreshes for same gateway.
        max_queue_size: Maximum number of pending refreshes in queue.

    Returns:
        The initialized NotificationService instance.

    Example:
        >>> service = init_notification_service(debounce_seconds=10.0)
        >>> service.debounce_seconds
        10.0
    """
    global _notification_service  # pylint: disable=global-statement

    # Replacing a singleton whose worker is still running would leave that
    # worker unreachable; this function is sync and cannot await shutdown(),
    # so surface the problem instead of hiding it.
    previous = _notification_service
    if previous is not None:
        worker = previous._worker_task  # pylint: disable=protected-access
        if worker is not None and not worker.done():
            logger.warning("Replacing a running global NotificationService; call close_notification_service() first to stop its worker")

    _notification_service = NotificationService(
        debounce_seconds=debounce_seconds,
        max_queue_size=max_queue_size,
    )
    logger.info("Global NotificationService created")
    return _notification_service
async def close_notification_service() -> None:
    """Close the global NotificationService.

    Shuts the singleton down and resets the module-level reference to None.
    The reference is reset even if shutdown raises, so a later
    get_notification_service() never returns a half-closed instance.
    Does nothing when no singleton exists.

    Example:
        >>> import asyncio
        >>> async def test():
        ...     init_notification_service()
        ...     await close_notification_service()
        ...     try:
        ...         get_notification_service()
        ...     except RuntimeError:
        ...         return True
        ...     return False
        >>> asyncio.run(test())
        True
    """
    global _notification_service  # pylint: disable=global-statement
    service = _notification_service
    if service is None:
        return

    try:
        await service.shutdown()
    finally:
        # Always drop the reference, including on error or cancellation.
        _notification_service = None
    logger.info("Global NotificationService closed")