Coverage for mcpgateway / services / notification_service.py: 95%
187 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/notification_service.py
3Copyright 2026
4SPDX-License-Identifier: Apache-2.0
6Authors: Keval Mahajan
8Description:
9 MCP Notification Service for handling server notifications with debounced
10 gateway refresh. Provides centralized notification handling for MCP sessions
11 including tools/list_changed, resources/list_changed, and prompts/list_changed.
13 Key Features:
14 - Debounced refresh to prevent notification storms
15 - Flag merging during debounce (notifications within window merge their refresh flags)
16 - Per-gateway refresh locking to prevent concurrent refresh races
17 - Per-gateway refresh tracking with capability awareness
18 - Compatible with MCPSessionPool for pooled session notification handling
19 - Per-gateway session isolation ensures correct notification attribution
20 - Supports tools, resources, and prompts list_changed notifications
22 Designed to be extended to other notification kinds (e.g. cancellation and progress notifications); these are not implemented yet.
24Usage:
25 ```python
26 from mcpgateway.services.notification_service import NotificationService
28 # Create service instance
29 notification_service = NotificationService()
30 await notification_service.initialize()
32 # Create a message handler for a specific gateway
33 handler = notification_service.create_message_handler(gateway_id="gw-123")
35 # Pass handler to ClientSession
36 session = ClientSession(read_stream, write_stream, message_handler=handler)
37 ```
38"""
40# Future
41from __future__ import annotations
43# Standard
44import asyncio
45from dataclasses import dataclass, field
46from enum import Enum
47import time
48from typing import Any, Awaitable, Callable, Dict, Optional, Set, TYPE_CHECKING
50# Third-Party
51from mcp.shared.session import RequestResponder
52import mcp.types as mcp_types
54# First-Party
55from mcpgateway.services.logging_service import LoggingService
57if TYPE_CHECKING: 57 ↛ 59line 57 didn't jump to line 59 because the condition on line 57 was never true
58 # First-Party
59 from mcpgateway.services.gateway_service import GatewayService
# Everything an MCP ClientSession may pass to its ``message_handler``:
# a server request responder, a server notification, or an exception.
_IncomingServerMessage = (
    RequestResponder[mcp_types.ServerRequest, mcp_types.ClientResult]
    | mcp_types.ServerNotification
    | Exception
)

# Type alias for message handler callback: async, takes one message, returns nothing.
MessageHandlerCallback = Callable[[_IncomingServerMessage], Awaitable[None]]

# Module-level logging service and logger shared by everything in this module.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)
class NotificationType(Enum):
    """MCP ``list_changed`` notification kinds this service reacts to.

    Each member's value is the notification's method string
    (for example ``notifications/tools/list_changed``).

    Attributes:
        TOOLS_LIST_CHANGED: The server's tool list changed.
        RESOURCES_LIST_CHANGED: The server's resource list changed.
        PROMPTS_LIST_CHANGED: The server's prompt list changed.
    """

    # Tool catalogue changed.
    TOOLS_LIST_CHANGED = "notifications/tools/list_changed"
    # Resource catalogue changed.
    RESOURCES_LIST_CHANGED = "notifications/resources/list_changed"
    # Prompt catalogue changed.
    PROMPTS_LIST_CHANGED = "notifications/prompts/list_changed"
@dataclass
class GatewayCapabilities:
    """Which ``list_changed`` notifications a gateway advertised at initialization.

    All flags default to ``False`` (unsupported) until registered.

    Attributes:
        tools_list_changed: Gateway announced tool list change notifications.
        resources_list_changed: Gateway announced resource list change notifications.
        prompts_list_changed: Gateway announced prompt list change notifications.
    """

    # Field order is part of the positional constructor signature; keep it stable.
    tools_list_changed: bool = False
    resources_list_changed: bool = False
    prompts_list_changed: bool = False
101def _empty_notification_type_set() -> Set[NotificationType]:
102 """Factory function for creating an empty set of NotificationType.
104 Returns:
105 An empty set typed for NotificationType elements.
106 """
107 return set()
@dataclass
class PendingRefresh:
    """A queued gateway refresh plus the flags accumulated while it waits.

    Attributes:
        gateway_id: ID of the gateway to refresh.
        enqueued_at: ``time.time()`` value captured when the instance was created.
        include_resources: Whether the refresh should also cover resources.
        include_prompts: Whether the refresh should also cover prompts.
        triggered_by: Notification types that contributed to this refresh.
    """

    gateway_id: str
    # Creation timestamp (wall clock seconds).
    enqueued_at: float = field(default_factory=time.time)
    include_resources: bool = True
    include_prompts: bool = True
    # Per-instance set of the notification types merged into this refresh.
    triggered_by: Set[NotificationType] = field(default_factory=_empty_notification_type_set)
class NotificationService:
    """Centralized service for handling MCP server notifications.

    Provides debounced refresh triggering based on list_changed notifications
    from MCP servers. Works with MCPSessionPool to handle notifications for
    pooled sessions while maintaining session isolation.

    A single background worker (started by :meth:`initialize`) takes
    :class:`PendingRefresh` items off a queue and runs them one at a time
    through the gateway service's refresh routine.

    Attributes:
        debounce_seconds: Minimum time between refresh operations for same gateway.
        _max_queue_size: Maximum pending refreshes in the queue.

    Example:
        >>> service = NotificationService(debounce_seconds=5.0)
        >>> service.debounce_seconds
        5.0
        >>> service._gateway_capabilities == {}
        True
    """

    def __init__(
        self,
        debounce_seconds: float = 5.0,
        max_queue_size: int = 100,
    ) -> None:
        """Initialize the NotificationService.

        Args:
            debounce_seconds: Minimum time between refreshes for same gateway.
            max_queue_size: Maximum number of pending refreshes in queue.

        Example:
            >>> service = NotificationService(debounce_seconds=10.0, max_queue_size=50)
            >>> service.debounce_seconds
            10.0
            >>> service._max_queue_size
            50
        """
        self.debounce_seconds = debounce_seconds
        self._max_queue_size = max_queue_size

        # Track gateway capabilities for list_changed support
        self._gateway_capabilities: Dict[str, GatewayCapabilities] = {}

        # Debounce tracking: gateway_id -> time.time() of the last successful enqueue
        self._last_refresh_enqueued: Dict[str, float] = {}

        # Refreshes that are queued but not yet picked up by the worker, keyed by
        # gateway_id. Notifications arriving inside the debounce window merge their
        # flags into these instead of being dropped.
        self._pending_refresh_flags: Dict[str, PendingRefresh] = {}

        # Pending refresh queue (asyncio treats maxsize <= 0 as unbounded)
        self._refresh_queue: asyncio.Queue[PendingRefresh] = asyncio.Queue(maxsize=max_queue_size)

        # Background worker task and the event that tells it to stop
        self._worker_task: Optional[asyncio.Task[None]] = None
        self._shutdown_event = asyncio.Event()

        # Reference to gateway service for refresh operations (set during initialize)
        self._gateway_service: Optional["GatewayService"] = None

        # Metrics (reported by get_metrics)
        self._notifications_received = 0
        self._notifications_debounced = 0
        self._refreshes_triggered = 0
        self._refreshes_failed = 0

    async def initialize(self, gateway_service: Optional["GatewayService"] = None) -> None:
        """Initialize the notification service and start background worker.

        Each call starts a new worker task; call :meth:`shutdown` before
        initializing the same instance again.

        Args:
            gateway_service: Optional GatewayService reference for triggering refreshes.
                Can be set later via set_gateway_service().

        Example:
            >>> import asyncio
            >>> async def test():
            ...     service = NotificationService()
            ...     await service.initialize()
            ...     is_running = service._worker_task is not None
            ...     await service.shutdown()
            ...     return is_running
            >>> asyncio.run(test())
            True
        """
        if gateway_service is not None:
            self._gateway_service = gateway_service

        self._shutdown_event.clear()
        self._worker_task = asyncio.create_task(self._process_refresh_queue())
        logger.info("NotificationService initialized with debounce=%ss", self.debounce_seconds)

    def set_gateway_service(self, gateway_service: "GatewayService") -> None:
        """Set the gateway service reference for refresh operations.

        Args:
            gateway_service: The GatewayService instance to use for refreshes.

        Example:
            >>> from unittest.mock import Mock
            >>> service = NotificationService()
            >>> mock_gateway_service = Mock()
            >>> service.set_gateway_service(mock_gateway_service)
        """
        self._gateway_service = gateway_service

    async def shutdown(self) -> None:
        """Shutdown the notification service and cleanup resources.

        Stops the worker, discards refreshes still waiting in the queue, and
        clears all per-gateway tracking state.

        Example:
            >>> import asyncio
            >>> async def test():
            ...     service = NotificationService()
            ...     await service.initialize()
            ...     await service.shutdown()
            ...     return service._worker_task is None or service._worker_task.done()
            >>> asyncio.run(test())
            True
        """
        self._shutdown_event.set()

        if self._worker_task is not None:
            self._worker_task.cancel()
            try:
                await self._worker_task
            except asyncio.CancelledError:
                pass
            self._worker_task = None

        # Queued refreshes would otherwise be replayed by a later initialize()
        # even though their tracking state is cleared below.
        self._drain_refresh_queue()

        self._gateway_capabilities.clear()
        self._last_refresh_enqueued.clear()
        self._pending_refresh_flags.clear()
        logger.info("NotificationService shutdown complete")

    def _drain_refresh_queue(self) -> int:
        """Remove every queued refresh without executing it.

        Returns:
            Number of refreshes discarded.
        """
        drained = 0
        while True:
            try:
                self._refresh_queue.get_nowait()
            except asyncio.QueueEmpty:
                return drained
            # Balance the get() so any Queue.join() waiter is not left hanging.
            self._refresh_queue.task_done()
            drained += 1

    def register_gateway_capabilities(
        self,
        gateway_id: str,
        capabilities: Dict[str, Any],
    ) -> None:
        """Register list_changed capabilities for a gateway.

        Extracts and stores which list_changed notifications the gateway supports
        based on server capabilities returned during initialization. A section
        that is missing or not a dict counts as unsupported.

        Args:
            gateway_id: The gateway ID.
            capabilities: Server capabilities dict from initialize response.
                An empty or ``None`` value registers no support.

        Example:
            >>> service = NotificationService()
            >>> caps = {"tools": {"listChanged": True}, "resources": {"listChanged": False}}
            >>> service.register_gateway_capabilities("gw-1", caps)
            >>> service.supports_list_changed("gw-1")
            True
            >>> service._gateway_capabilities["gw-1"].resources_list_changed
            False
        """
        raw_caps: Dict[str, Any] = capabilities or {}

        def _list_changed(section: str) -> bool:
            # Only a dict section can carry a "listChanged" flag.
            section_caps = raw_caps.get(section)
            return isinstance(section_caps, dict) and bool(section_caps.get("listChanged", False))

        caps = GatewayCapabilities(
            tools_list_changed=_list_changed("tools"),
            resources_list_changed=_list_changed("resources"),
            prompts_list_changed=_list_changed("prompts"),
        )
        self._gateway_capabilities[gateway_id] = caps

        logger.debug(
            "Registered capabilities for gateway %s: tools=%s, resources=%s, prompts=%s",
            gateway_id,
            caps.tools_list_changed,
            caps.resources_list_changed,
            caps.prompts_list_changed,
        )

    def unregister_gateway(self, gateway_id: str) -> None:
        """Unregister a gateway and cleanup its state.

        Forgets the gateway's capabilities, its debounce timestamp and the
        tracking entry for any queued refresh. A refresh that is already in the
        queue is still executed by the worker.

        Args:
            gateway_id: The gateway ID to unregister.

        Example:
            >>> service = NotificationService()
            >>> service.register_gateway_capabilities("gw-1", {"tools": {"listChanged": True}})
            >>> service.supports_list_changed("gw-1")
            True
            >>> service.unregister_gateway("gw-1")
            >>> service.supports_list_changed("gw-1")
            False
        """
        self._gateway_capabilities.pop(gateway_id, None)
        self._last_refresh_enqueued.pop(gateway_id, None)
        # Without this, a re-registered gateway could merge new notifications
        # into a stale pending refresh.
        self._pending_refresh_flags.pop(gateway_id, None)

    def supports_list_changed(self, gateway_id: str) -> bool:
        """Check if a gateway supports any list_changed notifications.

        Args:
            gateway_id: The gateway ID to check.

        Returns:
            True if gateway supports at least one list_changed notification type.

        Example:
            >>> service = NotificationService()
            >>> caps = {"tools": {"listChanged": True}}
            >>> service.register_gateway_capabilities("gw-1", caps)
            >>> service.supports_list_changed("gw-1")
            True
            >>> service.supports_list_changed("gw-unknown")
            False
        """
        caps = self._gateway_capabilities.get(gateway_id)
        if not caps:
            return False
        return caps.tools_list_changed or caps.resources_list_changed or caps.prompts_list_changed

    def create_message_handler(
        self,
        gateway_id: str,
        gateway_url: Optional[str] = None,
    ) -> MessageHandlerCallback:
        """Create a message handler callback for a specific gateway.

        Returns a callback suitable for passing to ClientSession's message_handler
        parameter. The handler routes notifications to this service for processing.

        Args:
            gateway_id: The gateway ID this handler is for.
            gateway_url: Optional URL for logging context.

        Returns:
            Async callable suitable for ClientSession message_handler.

        Example:
            >>> service = NotificationService()
            >>> handler = service.create_message_handler("gw-123")
            >>> callable(handler)
            True
        """

        async def message_handler(
            message: RequestResponder[mcp_types.ServerRequest, mcp_types.ClientResult] | mcp_types.ServerNotification | Exception,
        ) -> None:
            """Handle incoming messages from MCP server.

            Args:
                message: The message received from the server.
            """
            # Only process ServerNotification objects
            if isinstance(message, mcp_types.ServerNotification):
                await self._handle_notification(gateway_id, message, gateway_url)
            elif isinstance(message, Exception):
                logger.warning("Received exception from MCP server %s: %s", gateway_id, message)
            # RequestResponder messages are handled by the session itself

        return message_handler

    @staticmethod
    def _classify_notification(notification_root: Any) -> Optional[NotificationType]:
        """Map a concrete notification object to a list_changed type.

        Class names are checked first, since mcp.types naming may vary. If none
        match, the object's ``method`` string is compared with the
        NotificationType values.

        Args:
            notification_root: The object held in ``ServerNotification.root``.

        Returns:
            The matching NotificationType, or None if it is not a list_changed
            notification.
        """
        root_class = type(notification_root).__name__
        if "ToolListChangedNotification" in root_class or "ToolsListChangedNotification" in root_class:
            return NotificationType.TOOLS_LIST_CHANGED
        if "ResourceListChangedNotification" in root_class or "ResourcesListChangedNotification" in root_class:
            return NotificationType.RESOURCES_LIST_CHANGED
        if "PromptListChangedNotification" in root_class or "PromptsListChangedNotification" in root_class:
            return NotificationType.PROMPTS_LIST_CHANGED

        method = getattr(notification_root, "method", None)
        if isinstance(method, str):
            try:
                return NotificationType(method)
            except ValueError:
                return None
        return None

    async def _handle_notification(
        self,
        gateway_id: str,
        notification: mcp_types.ServerNotification,
        gateway_url: Optional[str] = None,
    ) -> None:
        """Process an incoming server notification.

        list_changed notifications enqueue a (debounced) refresh; any other
        notification is only logged.

        Args:
            gateway_id: The gateway ID that sent the notification.
            notification: The notification object.
            gateway_url: Optional URL for logging context.
        """
        self._notifications_received += 1

        # ServerNotification has a 'root' attribute containing the actual notification
        notification_root = notification.root
        notification_type = self._classify_notification(notification_root)

        if notification_type is None:
            logger.info(
                "Received notification from gateway %s: %s",
                gateway_id,
                type(notification_root).__name__,
            )
            return

        logger.info(
            "Received %s notification from gateway %s (%s)",
            notification_type.value,
            gateway_id,
            gateway_url or "unknown",
        )
        await self._enqueue_refresh(gateway_id, notification_type)

    async def _enqueue_refresh(
        self,
        gateway_id: str,
        notification_type: NotificationType,
    ) -> None:
        """Enqueue a refresh operation with debouncing and flag merging.

        Inside the debounce window, a notification is merged into the gateway's
        pending refresh if one is still waiting in the queue. For example, a
        resources/list_changed followed by tools/list_changed results in a
        single refresh covering both. If the worker has already picked the
        refresh up, the notification is counted as debounced and dropped.

        Args:
            gateway_id: The gateway to refresh.
            notification_type: The type of notification that triggered this.
        """
        now = time.time()
        last_enqueued = self._last_refresh_enqueued.get(gateway_id, 0.0)

        # Tools notifications include everything, as tools are always primary;
        # resources/prompts notifications only add their own category.
        include_resources = notification_type in (NotificationType.TOOLS_LIST_CHANGED, NotificationType.RESOURCES_LIST_CHANGED)
        include_prompts = notification_type in (NotificationType.TOOLS_LIST_CHANGED, NotificationType.PROMPTS_LIST_CHANGED)

        if now - last_enqueued < self.debounce_seconds:
            existing = self._pending_refresh_flags.get(gateway_id)
            if existing is not None:
                # Merge flags - use OR to include all requested types
                existing.include_resources = existing.include_resources or include_resources
                existing.include_prompts = existing.include_prompts or include_prompts
                existing.triggered_by.add(notification_type)
                self._notifications_debounced += 1
                logger.debug(
                    "Merged %s into pending refresh for gateway %s (resources=%s, prompts=%s)",
                    notification_type.value,
                    gateway_id,
                    existing.include_resources,
                    existing.include_prompts,
                )
                return

            # NOTE(review): _execute_refresh stops tracking a refresh as soon as the
            # worker dequeues it, so for the rest of the window there is nothing to
            # merge into and the notification is dropped here. A change announced
            # after the refresh fetched its lists is not picked up until a later
            # notification or refresh. A trailing-edge refresh would close this gap.
            self._notifications_debounced += 1
            logger.debug(
                "Debounced refresh for gateway %s (last enqueued %.1fs ago, no pending)",
                gateway_id,
                now - last_enqueued,
            )
            return

        pending = PendingRefresh(
            gateway_id=gateway_id,
            include_resources=include_resources,
            include_prompts=include_prompts,
            triggered_by={notification_type},
        )

        try:
            self._refresh_queue.put_nowait(pending)
        except asyncio.QueueFull:
            logger.warning(
                "Refresh queue full, dropping refresh request for gateway %s",
                gateway_id,
            )
            return

        self._last_refresh_enqueued[gateway_id] = now
        self._pending_refresh_flags[gateway_id] = pending  # Track for flag merging
        logger.info(
            "Enqueued refresh for gateway %s (triggered by %s)",
            gateway_id,
            notification_type.value,
        )

    async def _process_refresh_queue(self) -> None:
        """Background worker that processes pending refresh operations.

        Runs until shutdown is signalled or the task is cancelled. It takes
        refreshes from the queue and executes them one at a time. Unexpected
        errors are logged and the loop continues.
        """
        logger.info("NotificationService refresh worker started")

        while not self._shutdown_event.is_set():
            try:
                # Wait for pending refresh with timeout to allow shutdown check
                try:
                    pending = await asyncio.wait_for(
                        self._refresh_queue.get(),
                        timeout=1.0,
                    )
                except asyncio.TimeoutError:
                    continue

                try:
                    await self._execute_refresh(pending)
                finally:
                    # Always balance the successful get(), even if the refresh raised
                    # or the task was cancelled mid-refresh.
                    self._refresh_queue.task_done()

            except asyncio.CancelledError:
                logger.debug("Refresh worker cancelled")
                break
            except Exception as e:
                logger.exception("Error in refresh worker: %s", e)

        logger.info("NotificationService refresh worker stopped")

    async def _execute_refresh(self, pending: PendingRefresh) -> None:
        """Execute a refresh operation.

        Acquires the per-gateway refresh lock to prevent concurrent refreshes
        with manual refresh or health check auto-refresh. If the lock is
        already held, this refresh is skipped and counted as debounced.

        Args:
            pending: The pending refresh to execute.
        """
        # pylint: disable=protected-access
        gateway_id = pending.gateway_id

        # Stop tracking this refresh for flag merging now that it is being processed.
        # Only remove the entry if it is this exact object: a newer pending refresh
        # for the same gateway must keep receiving merged flags.
        if self._pending_refresh_flags.get(gateway_id) is pending:
            del self._pending_refresh_flags[gateway_id]

        if not self._gateway_service:
            logger.warning(
                "Cannot execute refresh for gateway %s: GatewayService not set",
                gateway_id,
            )
            return

        # Acquire per-gateway lock to prevent concurrent refresh with manual/auto refresh
        lock = self._gateway_service._get_refresh_lock(gateway_id)  # pyright: ignore[reportPrivateUsage]

        # Skip if lock is already held (another refresh in progress)
        if lock.locked():
            logger.debug(
                "Skipping event-driven refresh for gateway %s: lock held (refresh in progress)",
                gateway_id,
            )
            self._notifications_debounced += 1
            return

        async with lock:
            logger.info(
                "Executing event-driven refresh for gateway %s (resources=%s, prompts=%s)",
                gateway_id,
                pending.include_resources,
                pending.include_prompts,
            )

            try:
                # Use the existing refresh method (lock already held)
                result = await self._gateway_service._refresh_gateway_tools_resources_prompts(  # pyright: ignore[reportPrivateUsage]
                    gateway_id=gateway_id,
                    created_via="notification_service",
                    include_resources=pending.include_resources,
                    include_prompts=pending.include_prompts,
                )

                self._refreshes_triggered += 1

                if result.get("success"):
                    logger.info(
                        "Event-driven refresh completed for gateway %s: tools_added=%d, tools_removed=%d",
                        gateway_id,
                        result.get("tools_added", 0),
                        result.get("tools_removed", 0),
                    )
                else:
                    self._refreshes_failed += 1
                    logger.warning(
                        "Event-driven refresh failed for gateway %s: %s",
                        gateway_id,
                        result.get("error"),
                    )

            except Exception as e:
                self._refreshes_failed += 1
                logger.exception(
                    "Error during event-driven refresh for gateway %s: %s",
                    gateway_id,
                    e,
                )

    def get_metrics(self) -> Dict[str, Any]:
        """Return notification service metrics.

        Returns:
            Dict containing notification and refresh metrics.

        Example:
            >>> service = NotificationService()
            >>> metrics = service.get_metrics()
            >>> "notifications_received" in metrics
            True
        """
        return {
            "notifications_received": self._notifications_received,
            "notifications_debounced": self._notifications_debounced,
            "refreshes_triggered": self._refreshes_triggered,
            "refreshes_failed": self._refreshes_failed,
            "pending_refreshes": self._refresh_queue.qsize(),
            "registered_gateways": len(self._gateway_capabilities),
            "debounce_seconds": self.debounce_seconds,
        }
# Process-wide NotificationService; stays None until init_notification_service() runs.
_notification_service: NotificationService | None = None
def get_notification_service() -> NotificationService:
    """Return the process-wide NotificationService.

    Returns:
        The instance created by init_notification_service().

    Raises:
        RuntimeError: If init_notification_service() has not been called yet
            (or the service was closed).

    Example:
        >>> try:
        ...     _ = init_notification_service()
        ...     service = get_notification_service()
        ...     result = isinstance(service, NotificationService)
        ... except RuntimeError:
        ...     result = False
        >>> result
        True
    """
    current = _notification_service
    if current is not None:
        return current
    raise RuntimeError("NotificationService not initialized. Call init_notification_service() first.")
def init_notification_service(
    debounce_seconds: float = 5.0,
    max_queue_size: int = 100,
) -> NotificationService:
    """Create the process-wide NotificationService and store it globally.

    Any previously stored instance is replaced (not shut down). The returned
    service still needs ``await service.initialize()`` to start its worker.

    Args:
        debounce_seconds: Minimum time between refreshes for same gateway.
        max_queue_size: Maximum number of pending refreshes in queue.

    Returns:
        The newly created NotificationService instance.

    Example:
        >>> service = init_notification_service(debounce_seconds=10.0)
        >>> service.debounce_seconds
        10.0
    """
    global _notification_service  # pylint: disable=global-statement
    created = NotificationService(debounce_seconds=debounce_seconds, max_queue_size=max_queue_size)
    _notification_service = created
    logger.info("Global NotificationService created")
    return created
async def close_notification_service() -> None:
    """Shut down and forget the process-wide NotificationService.

    Does nothing when no service has been created. The global reference is
    cleared only after ``shutdown()`` completes.

    Example:
        >>> import asyncio
        >>> async def test():
        ...     init_notification_service()
        ...     await close_notification_service()
        ...     try:
        ...         get_notification_service()
        ...     except RuntimeError:
        ...         return True
        ...     return False
        >>> asyncio.run(test())
        True
    """
    global _notification_service  # pylint: disable=global-statement
    active = _notification_service
    if active is None:
        return
    await active.shutdown()
    _notification_service = None
    logger.info("Global NotificationService closed")