Coverage for mcpgateway / services / cancellation_service.py: 99%
150 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2# mcpgateway/services/cancellation_service.py
3"""Location: ./mcpgateway/services/cancellation_service.py
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8Service for tracking and cancelling active tool runs.
10Provides a simple in-memory registry for run metadata and an optional async
11cancel callback that can be invoked when a cancellation is requested. This
12service is intentionally small and designed to be a single-process helper for
13local run lifecycle management; the gateway remains authoritative for
14cancellation and also broadcasts a `notifications/cancelled` JSON-RPC
15notification to connected sessions.
16"""
17# Future
18from __future__ import annotations
20# Standard
21import asyncio
22import json
23import time
24from typing import Any, Awaitable, Callable, Dict, Optional
26# First-Party
27from mcpgateway.services.logging_service import LoggingService
28from mcpgateway.utils.redis_client import get_redis_client
# Module-level logger, obtained through the gateway's LoggingService wrapper.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

# Type of the optional per-run cancel hook: an async callable that receives the
# cancellation reason (or None) and returns nothing.
CancelCallback = Callable[[Optional[str]], Awaitable[None]]  # async callback(reason)
class CancellationService:
    """Track active runs and allow cancellation requests.

    Note: This is intentionally lightweight — it does not persist state and is
    suitable for gateway-local run tracking. The gateway will also broadcast
    a `notifications/cancelled` message to connected sessions to inform remote
    peers of the cancellation request.

    Multi-worker deployments: When Redis is available, cancellation events are
    published to the "cancellation:cancel" channel to propagate across workers.
    """

    # Redis pubsub channel used both for publishing and for listening.
    _CHANNEL = "cancellation:cancel"

    def __init__(self) -> None:
        """Initialize the cancellation service.

        No I/O happens here; Redis is only contacted from ``initialize()``.
        """
        # run_id -> metadata dict (name, registered_at, cancel_callback, cancelled, ...)
        self._runs: Dict[str, Dict[str, Any]] = {}
        # Guards every read/write of self._runs.
        self._lock = asyncio.Lock()
        # Redis client returned by get_redis_client(), or None when unavailable.
        self._redis: Optional[Any] = None
        # Background task running _listen_for_cancellations(), if started.
        self._pubsub_task: Optional[asyncio.Task] = None
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize Redis pubsub if available for multi-worker support.

        Idempotent: repeated calls are no-ops until ``shutdown()`` is awaited.
        Failures to obtain the Redis client are logged as warnings and leave the
        service in local-only mode.
        """
        if self._initialized:
            return

        # Set the flag before awaiting so a second call made while the first is
        # still waiting on Redis does not start a duplicate listener.
        self._initialized = True

        try:
            self._redis = await get_redis_client()
            if self._redis:
                # Start listening for cancellation events from other workers
                self._pubsub_task = asyncio.create_task(self._listen_for_cancellations())
                logger.info("CancellationService: Redis pubsub initialized for multi-worker cancellation")
        except Exception as e:
            logger.warning(f"CancellationService: Could not initialize Redis pubsub: {e}")

    async def shutdown(self) -> None:
        """Shutdown Redis pubsub listener.

        Cancels the listener task (if running), waits for it to finish, and
        clears the initialized flag so a later ``initialize()`` can start a
        fresh listener.
        """
        task = self._pubsub_task
        if task and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                # Expected: the listener re-raises CancelledError after cleanup.
                pass
        # Without this reset the singleton could never restart its listener.
        self._initialized = False
        logger.info("CancellationService: Shutdown complete")

    async def _listen_for_cancellations(self) -> None:
        """Listen for cancellation events from other workers via Redis pubsub.

        Uses timeout-based polling instead of blocking listen() to allow proper
        cancellation handling. This prevents CPU spin loops when the task is cancelled
        but stuck waiting on the blocking async iterator.

        Raises:
            asyncio.CancelledError: When the listener task is cancelled during shutdown.
        """
        if not self._redis:
            return

        pubsub = None
        try:
            pubsub = self._redis.pubsub()
            await pubsub.subscribe(self._CHANNEL)
            logger.info("CancellationService: Subscribed to cancellation:cancel channel")

            # Use timeout-based polling instead of blocking listen()
            # This allows the task to respond to cancellation properly
            poll_timeout = 1.0

            while True:
                try:
                    message = await asyncio.wait_for(
                        pubsub.get_message(ignore_subscribe_messages=True, timeout=poll_timeout),
                        timeout=poll_timeout + 0.5,
                    )
                except asyncio.TimeoutError:
                    # No message, continue loop to check for cancellation
                    continue

                if message is None or message["type"] != "message":
                    # Sleep briefly so an immediately-returning get_message
                    # (None or a non-message frame) cannot spin the loop.
                    await asyncio.sleep(0.1)
                    continue

                try:
                    data = json.loads(message["data"])
                    # Normalize run_id to string (handle id=0 which is valid per JSON-RPC)
                    raw_run_id = data.get("run_id")
                    run_id = str(raw_run_id) if raw_run_id is not None else None
                    reason = data.get("reason")

                    if run_id is not None:
                        # Cancel locally if we have this run (don't re-publish)
                        await self._cancel_run_local(run_id, reason=reason)
                except Exception as e:
                    # A malformed payload must not kill the listener.
                    logger.warning(f"Error processing cancellation message: {e}")
        except asyncio.CancelledError:
            logger.info("CancellationService: Pubsub listener cancelled")
            raise
        except Exception as e:
            logger.error(f"CancellationService: Pubsub listener error: {e}")
        finally:
            # Clean up pubsub on any exit
            if pubsub is not None:
                await self._close_pubsub(pubsub)

    async def _close_pubsub(self, pubsub: Any) -> None:
        """Unsubscribe and close a pubsub object, logging (not raising) failures.

        The two steps are guarded independently so that a failing unsubscribe
        does not prevent the connection from being closed.

        Args:
            pubsub: The pubsub object created by the listener.
        """
        try:
            await pubsub.unsubscribe(self._CHANNEL)
        except Exception as e:
            logger.debug(f"Error closing cancellation pubsub: {e}")

        try:
            try:
                await pubsub.aclose()
            except AttributeError:
                # Fall back to close() for clients that do not provide aclose().
                await pubsub.close()
        except Exception as e:
            logger.debug(f"Error closing cancellation pubsub: {e}")

    @staticmethod
    def _mark_cancelled(entry: Dict[str, Any], reason: Optional[str]) -> Optional[CancelCallback]:
        """Flag a run entry as cancelled and return its cancel callback.

        The caller must hold ``self._lock``.

        Args:
            entry: The run's metadata dict from ``self._runs``.
            reason: Optional textual reason for the cancellation request.

        Returns:
            Optional[CancelCallback]: The callback registered for the run, if any.
        """
        entry["cancelled"] = True
        entry["cancelled_at"] = time.time()
        entry["cancel_reason"] = reason
        return entry.get("cancel_callback")

    async def _run_cancel_callback(self, run_id: str, cancel_cb: Optional[CancelCallback], reason: Optional[str]) -> None:
        """Invoke a run's cancel callback, logging any exception it raises.

        Must be called without holding ``self._lock`` so a slow callback does
        not block other registry operations.

        Args:
            run_id: Unique identifier for the run being cancelled.
            cancel_cb: The callback to invoke; nothing happens when it is falsy.
            reason: Optional textual reason passed through to the callback.
        """
        if not cancel_cb:
            return
        try:
            await cancel_cb(reason)
            logger.info("Cancel callback executed for %s", run_id)
        except Exception as e:
            logger.exception("Error in cancel callback for %s: %s", run_id, e)

    async def _cancel_run_local(self, run_id: str, reason: Optional[str] = None) -> bool:
        """Cancel a run locally without publishing to Redis (internal use).

        Args:
            run_id: Unique identifier for the run to cancel.
            reason: Optional textual reason for the cancellation request.

        Returns:
            bool: True if the run was found and cancelled, False if not found.
        """
        async with self._lock:
            entry = self._runs.get(run_id)
            if not entry:
                return False
            if entry.get("cancelled"):
                return True
            cancel_cb = self._mark_cancelled(entry, reason)

        # "name" is always present (possibly None), so use `or` for the fallback.
        logger.info("Tool execution cancelled (from Redis): run_id=%s, reason=%s, tool=%s", run_id, reason or "not specified", entry.get("name") or "unknown")

        await self._run_cancel_callback(run_id, cancel_cb, reason)
        return True

    async def register_run(self, run_id: str, name: Optional[str] = None, cancel_callback: Optional[CancelCallback] = None) -> None:
        """Register a run for future cancellation.

        Registering an id that already exists replaces the previous entry.

        Args:
            run_id: Unique run identifier (string)
            name: Optional friendly name for debugging/observability
            cancel_callback: Optional async callback called when a cancel is requested
        """
        async with self._lock:
            self._runs[run_id] = {"name": name, "registered_at": time.time(), "cancel_callback": cancel_callback, "cancelled": False}
            logger.info("Registered run %s (%s)", run_id, name)

    async def unregister_run(self, run_id: str) -> None:
        """Remove a run from tracking.

        Unknown ids are ignored silently.

        Args:
            run_id: Unique identifier for the run to unregister.
        """
        async with self._lock:
            # Entries are always dicts, so None reliably means "was not registered".
            if self._runs.pop(run_id, None) is not None:
                logger.info("Unregistered run %s", run_id)

    async def cancel_run(self, run_id: str, reason: Optional[str] = None) -> bool:
        """Attempt to cancel a run.

        The cancellation is published to Redis (when available) both for runs
        known locally and for unknown runs, so other workers can act on it.

        Args:
            run_id: Unique identifier for the run to cancel.
            reason: Optional textual reason for the cancellation request.

        Returns:
            bool: True if the run was found and cancellation was attempted (or already marked),
            False if the run was not known locally.
        """
        cancel_cb = None

        async with self._lock:
            entry = self._runs.get(run_id)
            if entry:
                if entry.get("cancelled"):
                    logger.debug("Run %s already cancelled", run_id)
                    return True
                cancel_cb = self._mark_cancelled(entry, reason)

        # Handle unknown run case outside the lock
        if not entry:
            logger.info("Cancellation requested for unknown run %s (queued for remote peers)", run_id)
            # Publish to Redis for other workers (outside lock to avoid blocking)
            await self._publish_cancellation(run_id, reason)
            return False

        # Log cancellation with reason and request_id for observability
        logger.info("Tool execution cancelled: run_id=%s, reason=%s, tool=%s", run_id, reason or "not specified", entry.get("name") or "unknown")

        await self._run_cancel_callback(run_id, cancel_cb, reason)

        # Publish to Redis for other workers
        await self._publish_cancellation(run_id, reason)

        return True

    async def _publish_cancellation(self, run_id: str, reason: Optional[str] = None) -> None:
        """Publish cancellation event to Redis for other workers.

        Does nothing when no Redis client is configured; publish failures are
        logged as warnings and never raised.

        Args:
            run_id: Unique identifier for the run being cancelled.
            reason: Optional textual reason for the cancellation.
        """
        if not self._redis:
            return

        try:
            message = json.dumps({"run_id": run_id, "reason": reason})
            await self._redis.publish(self._CHANNEL, message)
            logger.debug("Published cancellation to Redis: run_id=%s", run_id)
        except Exception as e:
            logger.warning(f"Failed to publish cancellation to Redis: {e}")

    async def get_status(self, run_id: str) -> Optional[Dict[str, Any]]:
        """Return the status dict for a run if known, else None.

        The returned dict is a shallow copy, so callers cannot alter the
        registry's lock-protected state by mutating it.

        Args:
            run_id: Unique identifier for the run to query.

        Returns:
            Optional[Dict[str, Any]]: The status dictionary for the run if found, otherwise None.
        """
        async with self._lock:
            entry = self._runs.get(run_id)
            return dict(entry) if entry is not None else None

    async def is_registered(self, run_id: str) -> bool:
        """Check if a run is currently registered.

        Args:
            run_id: Unique identifier for the run to check.

        Returns:
            bool: True if the run is registered, False otherwise.
        """
        async with self._lock:
            return run_id in self._runs
# Module-level singleton for importers to use. Constructing it performs no I/O;
# Redis is only contacted once initialize() is awaited.
cancellation_service = CancellationService()