Coverage for mcpgateway / services / cancellation_service.py: 100%
150 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2# mcpgateway/services/cancellation_service.py
3"""Location: ./mcpgateway/services/cancellation_service.py
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8Service for tracking and cancelling active tool runs.
10Provides a simple in-memory registry for run metadata and an optional async
11cancel callback that can be invoked when a cancellation is requested. This
12service is intentionally small and designed to be a single-process helper for
13local run lifecycle management; the gateway remains authoritative for
14cancellation and also broadcasts a `notifications/cancelled` JSON-RPC
15notification to connected sessions.
16"""
18# Future
19from __future__ import annotations
21# Standard
22import asyncio
23import json
24import time
25from typing import Any, Awaitable, Callable, Dict, List, Optional
27# First-Party
28from mcpgateway.services.logging_service import LoggingService
29from mcpgateway.utils.redis_client import get_redis_client
# Module logger obtained from the project's LoggingService.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

# Type of the optional per-run cancel hook: an async callable that receives the
# (possibly None) cancellation reason and returns nothing.
CancelCallback = Callable[[Optional[str]], Awaitable[None]]  # async callback(reason)
class CancellationService:
    """Track active runs and allow cancellation requests.

    Note: This is intentionally lightweight — it does not persist state and is
    suitable for gateway-local run tracking. The gateway will also broadcast
    a `notifications/cancelled` message to connected sessions to inform remote
    peers of the cancellation request.

    Multi-worker deployments: When Redis is available, cancellation events are
    published to the "cancellation:cancel" channel to propagate across workers.
    """

    # Redis pubsub channel used to fan cancellation requests out to other workers.
    _CHANNEL = "cancellation:cancel"

    def __init__(self) -> None:
        """Initialize the cancellation service."""
        # run_id -> run metadata (keys are set in ``register_run`` / on cancellation).
        self._runs: Dict[str, Dict[str, Any]] = {}
        # Serializes all access to ``self._runs``.
        self._lock = asyncio.Lock()
        # Redis client; stays None when Redis is unavailable (local-only mode).
        self._redis = None
        self._pubsub_task: Optional[asyncio.Task] = None
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize Redis pubsub if available for multi-worker support.

        Subsequent calls are no-ops until ``shutdown()`` resets the service.
        A failure to obtain a Redis client is logged, and the service keeps
        working in local-only mode.
        """
        if self._initialized:
            return

        self._initialized = True

        try:
            self._redis = await get_redis_client()
            if self._redis:
                # Start listening for cancellation events from other workers
                self._pubsub_task = asyncio.create_task(self._listen_for_cancellations())
                logger.info("CancellationService: Redis pubsub initialized for multi-worker cancellation")
        except Exception as e:
            logger.warning(f"CancellationService: Could not initialize Redis pubsub: {e}")

    async def shutdown(self) -> None:
        """Shutdown Redis pubsub listener.

        Cancels and awaits the listener task if it is still running. It then
        resets the initialization state, so a later ``initialize()`` call
        subscribes again instead of silently doing nothing.
        """
        task = self._pubsub_task
        if task is not None and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        # Allow a clean restart (e.g. app lifespan reload) to resubscribe.
        self._pubsub_task = None
        self._initialized = False
        logger.info("CancellationService: Shutdown complete")

    async def _listen_for_cancellations(self) -> None:
        """Listen for cancellation events from other workers via Redis pubsub.

        Uses timeout-based polling instead of blocking listen() to allow proper
        cancellation handling. This prevents CPU spin loops when the task is cancelled
        but stuck waiting on the blocking async iterator.

        Any non-cancellation error is logged and ends the listener.

        Raises:
            asyncio.CancelledError: When the listener task is cancelled during shutdown.
        """
        if not self._redis:
            return

        pubsub = None
        try:
            pubsub = self._redis.pubsub()
            await pubsub.subscribe(self._CHANNEL)
            logger.info("CancellationService: Subscribed to cancellation:cancel channel")

            # Poll with a timeout instead of blocking listen() so the task
            # responds promptly to cancellation.
            poll_timeout = 1.0

            while True:
                try:
                    message = await asyncio.wait_for(
                        pubsub.get_message(ignore_subscribe_messages=True, timeout=poll_timeout),
                        timeout=poll_timeout + 0.5,
                    )
                except asyncio.TimeoutError:
                    # No message, continue loop to check for cancellation
                    continue

                if message is None or message["type"] != "message":
                    # Short sleep so an immediate None / non-data message cannot cause a spin
                    await asyncio.sleep(0.1)
                    continue

                await self._handle_cancellation_message(message["data"])
        except asyncio.CancelledError:
            logger.info("CancellationService: Pubsub listener cancelled")
            raise
        except Exception as e:
            logger.error(f"CancellationService: Pubsub listener error: {e}")
        finally:
            # Clean up pubsub on any exit
            if pubsub is not None:
                try:
                    await pubsub.unsubscribe(self._CHANNEL)
                    try:
                        await pubsub.aclose()
                    except AttributeError:
                        # Older clients expose close() instead of aclose()
                        await pubsub.close()
                except Exception as e:
                    logger.debug(f"Error closing cancellation pubsub: {e}")

    async def _handle_cancellation_message(self, raw: Any) -> None:
        """Decode one pubsub payload and apply it locally (best effort).

        Args:
            raw: The ``data`` field of a pubsub message. It is expected to be
                JSON with ``run_id`` and optional ``reason`` keys.
        """
        try:
            data = json.loads(raw)
            # Normalize run_id to string (handle id=0 which is valid per JSON-RPC)
            raw_run_id = data.get("run_id")
            run_id = str(raw_run_id) if raw_run_id is not None else None
            reason = data.get("reason")

            if run_id is not None:
                # Cancel locally if we have this run (don't re-publish)
                await self._cancel_run_local(run_id, reason=reason)
        except Exception as e:
            logger.warning(f"Error processing cancellation message: {e}")

    @staticmethod
    def _mark_cancelled(entry: Dict[str, Any], reason: Optional[str]) -> Optional[CancelCallback]:
        """Flag a run entry as cancelled and return its cancel callback.

        The caller must hold ``self._lock`` while calling this.

        Args:
            entry: The run's metadata dict from ``self._runs``.
            reason: Optional textual reason for the cancellation.

        Returns:
            Optional[CancelCallback]: The callback registered for the run, if any.
        """
        entry["cancelled"] = True
        entry["cancelled_at"] = time.time()
        entry["cancel_reason"] = reason
        return entry.get("cancel_callback")

    @staticmethod
    async def _invoke_cancel_callback(run_id: str, cancel_cb: Optional[CancelCallback], reason: Optional[str]) -> None:
        """Await a run's cancel callback, logging (not raising) any error it raises.

        Args:
            run_id: Identifier of the cancelled run (used for logging).
            cancel_cb: The callback to invoke; nothing happens when it is None.
            reason: Reason forwarded to the callback.
        """
        if not cancel_cb:
            return
        try:
            await cancel_cb(reason)
            logger.info("Cancel callback executed for %s", run_id)
        except Exception as e:
            logger.exception("Error in cancel callback for %s: %s", run_id, e)

    async def _cancel_run_local(self, run_id: str, reason: Optional[str] = None) -> bool:
        """Cancel a run locally without publishing to Redis (internal use).

        Args:
            run_id: Unique identifier for the run to cancel.
            reason: Optional textual reason for the cancellation request.

        Returns:
            bool: True if the run was found (whether newly or already cancelled),
            False if not found.
        """
        async with self._lock:
            entry = self._runs.get(run_id)
            if not entry:
                return False
            if entry.get("cancelled"):
                return True
            cancel_cb = self._mark_cancelled(entry, reason)

        logger.info("Tool execution cancelled (from Redis): run_id=%s, reason=%s, tool=%s", run_id, reason or "not specified", entry.get("name", "unknown"))
        # The callback runs outside the lock so a slow callback cannot block other operations.
        await self._invoke_cancel_callback(run_id, cancel_cb, reason)
        return True

    async def register_run(
        self,
        run_id: str,
        name: Optional[str] = None,
        cancel_callback: Optional[CancelCallback] = None,
        owner_email: Optional[str] = None,
        owner_team_ids: Optional[List[str]] = None,
    ) -> None:
        """Register a run for future cancellation.

        Registering an existing ``run_id`` replaces its entry.

        Args:
            run_id: Unique run identifier (string)
            name: Optional friendly name for debugging/observability
            cancel_callback: Optional async callback called when a cancel is requested
            owner_email: Optional email of the user who started the run
            owner_team_ids: Optional list of token team IDs associated with the run owner
        """
        async with self._lock:
            self._runs[run_id] = {
                "name": name,
                "registered_at": time.time(),
                "cancel_callback": cancel_callback,
                "cancelled": False,
                "owner_email": owner_email,
                # Copy so later mutation of the caller's list cannot change stored ownership.
                "owner_team_ids": list(owner_team_ids) if owner_team_ids else [],
            }
        logger.info("Registered run %s (%s)", run_id, name)

    async def unregister_run(self, run_id: str) -> None:
        """Remove a run from tracking; unknown run IDs are ignored.

        Args:
            run_id: Unique identifier for the run to unregister.
        """
        async with self._lock:
            removed = self._runs.pop(run_id, None)
        if removed is not None:
            logger.info("Unregistered run %s", run_id)

    async def cancel_run(self, run_id: str, reason: Optional[str] = None) -> bool:
        """Attempt to cancel a run.

        Known runs are marked cancelled and their callback is awaited. The
        request is then published to Redis (if configured) so other workers
        can cancel their copy. Unknown runs are only published.

        Args:
            run_id: Unique identifier for the run to cancel.
            reason: Optional textual reason for the cancellation request.

        Returns:
            bool: True if the run was found and cancellation was attempted (or already marked),
            False if the run was not known locally.
        """
        cancel_cb = None
        async with self._lock:
            entry = self._runs.get(run_id)
            if entry:
                if entry.get("cancelled"):
                    logger.debug("Run %s already cancelled", run_id)
                    return True
                cancel_cb = self._mark_cancelled(entry, reason)

        if not entry:
            logger.info("Cancellation requested for unknown run %s (queued for remote peers)", run_id)
            # Publish to Redis for other workers (outside lock to avoid blocking)
            await self._publish_cancellation(run_id, reason)
            return False

        # Log cancellation with reason and request_id for observability
        logger.info("Tool execution cancelled: run_id=%s, reason=%s, tool=%s", run_id, reason or "not specified", entry.get("name", "unknown"))
        await self._invoke_cancel_callback(run_id, cancel_cb, reason)

        # Publish to Redis for other workers
        await self._publish_cancellation(run_id, reason)
        return True

    async def _publish_cancellation(self, run_id: str, reason: Optional[str] = None) -> None:
        """Publish cancellation event to Redis for other workers (best effort).

        Does nothing when no Redis client is configured. Publish failures are
        logged, not raised.

        Args:
            run_id: Unique identifier for the run being cancelled.
            reason: Optional textual reason for the cancellation.
        """
        if not self._redis:
            return

        try:
            message = json.dumps({"run_id": run_id, "reason": reason})
            await self._redis.publish(self._CHANNEL, message)
            logger.debug("Published cancellation to Redis: run_id=%s", run_id)
        except Exception as e:
            logger.warning(f"Failed to publish cancellation to Redis: {e}")

    async def get_status(self, run_id: str) -> Optional[Dict[str, Any]]:
        """Return a snapshot of the status dict for a run if known, else None.

        The returned dict is a shallow copy. Mutating it does not change the
        tracked state, although nested values such as ``owner_team_ids`` are
        still shared.

        Args:
            run_id: Unique identifier for the run to query.

        Returns:
            Optional[Dict[str, Any]]: A copy of the run's status dictionary if found, otherwise None.
        """
        async with self._lock:
            entry = self._runs.get(run_id)
            return dict(entry) if entry is not None else None

    async def is_registered(self, run_id: str) -> bool:
        """Check if a run is currently registered.

        Args:
            run_id: Unique identifier for the run to check.

        Returns:
            bool: True if the run is registered, False otherwise.
        """
        async with self._lock:
            return run_id in self._runs
# Module-level singleton for importers to use.
# It is created without Redis; cross-worker propagation only starts once
# ``await cancellation_service.initialize()`` has obtained a Redis client.
cancellation_service = CancellationService()