Coverage for mcpgateway / services / cancellation_service.py: 100%
150 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2# mcpgateway/services/cancellation_service.py
3"""Location: ./mcpgateway/services/cancellation_service.py
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8Service for tracking and cancelling active tool runs.
10Provides a simple in-memory registry for run metadata and an optional async
11cancel callback that can be invoked when a cancellation is requested. This
12service is intentionally small and designed to be a single-process helper for
13local run lifecycle management; the gateway remains authoritative for
14cancellation and also broadcasts a `notifications/cancelled` JSON-RPC
15notification to connected sessions.
16"""
17# Future
18from __future__ import annotations
20# Standard
21import asyncio
22import json
23import time
24from typing import Any, Awaitable, Callable, Dict, List, Optional
26# First-Party
27from mcpgateway.services.logging_service import LoggingService
28from mcpgateway.utils.redis_client import get_redis_client
# Logging service instance created at import time; `logger` is the logger this
# module obtains from it, keyed by the module's dotted name.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

# Type of the optional per-run cancel hook: an async callable that receives the
# cancellation reason (a string or None) and returns nothing.
CancelCallback = Callable[[Optional[str]], Awaitable[None]]  # async callback(reason)
class CancellationService:
    """Track active runs and allow cancellation requests.

    Note: This is intentionally lightweight — it does not persist state and is
    suitable for gateway-local run tracking. The gateway will also broadcast
    a `notifications/cancelled` message to connected sessions to inform remote
    peers of the cancellation request.

    Multi-worker deployments: When Redis is available, cancellation events are
    published to the "cancellation:cancel" channel to propagate across workers.

    Each tracked run is stored in ``_runs`` as a dict with the keys ``name``,
    ``registered_at``, ``cancel_callback``, ``cancelled``, ``owner_email`` and
    ``owner_team_ids``; ``cancelled_at`` and ``cancel_reason`` are added once
    the run has been cancelled. All access to ``_runs`` is guarded by ``_lock``.
    """

    # Redis pubsub channel used to fan cancellation requests out to other workers.
    _CHANNEL = "cancellation:cancel"
    # Seconds a single pubsub poll may block before the loop iterates again.
    _POLL_TIMEOUT = 1.0
    # Short pause after an empty/non-message poll so the loop cannot spin.
    _IDLE_SLEEP = 0.1

    def __init__(self) -> None:
        """Initialize the cancellation service."""
        self._runs: Dict[str, Dict[str, Any]] = {}
        self._lock = asyncio.Lock()
        self._redis = None
        self._pubsub_task: Optional[asyncio.Task] = None
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize Redis pubsub if available for multi-worker support.

        Repeated calls are no-ops until ``shutdown()`` has been called. Failure
        to obtain a Redis client is logged as a warning and leaves the service
        in local-only mode.
        """
        if self._initialized:
            return

        # Set before awaiting so overlapping initialize() calls start one listener.
        self._initialized = True

        try:
            self._redis = await get_redis_client()
            if self._redis:
                # Start listening for cancellation events from other workers
                self._pubsub_task = asyncio.create_task(self._listen_for_cancellations())
                logger.info("CancellationService: Redis pubsub initialized for multi-worker cancellation")
        except Exception as e:
            logger.warning("CancellationService: Could not initialize Redis pubsub: %s", e)

    async def shutdown(self) -> None:
        """Shutdown Redis pubsub listener.

        Cancels the listener task (if it is still running), waits for it to
        finish, and resets the initialization state so that a later
        ``initialize()`` call starts a fresh listener instead of being a no-op.
        """
        task = self._pubsub_task
        if task and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        # Drop the finished task and allow re-initialization.
        self._pubsub_task = None
        self._initialized = False
        logger.info("CancellationService: Shutdown complete")

    async def _listen_for_cancellations(self) -> None:
        """Listen for cancellation events from other workers via Redis pubsub.

        Uses timeout-based polling instead of blocking listen() to allow proper
        cancellation handling. This prevents CPU spin loops when the task is cancelled
        but stuck waiting on the blocking async iterator.

        Raises:
            asyncio.CancelledError: When the listener task is cancelled during shutdown.
        """
        if not self._redis:
            return

        pubsub = None
        try:
            pubsub = self._redis.pubsub()
            await pubsub.subscribe(self._CHANNEL)
            logger.info("CancellationService: Subscribed to cancellation:cancel channel")

            while True:
                try:
                    # The outer wait_for is a safety net in case get_message
                    # does not honour its own timeout.
                    message = await asyncio.wait_for(
                        pubsub.get_message(ignore_subscribe_messages=True, timeout=self._POLL_TIMEOUT),
                        timeout=self._POLL_TIMEOUT + 0.5,
                    )
                except asyncio.TimeoutError:
                    # No message, continue loop to check for cancellation
                    continue

                if message is None or message["type"] != "message":
                    # Prevent spin if get_message returns immediately with
                    # nothing (or with a non-message event).
                    await asyncio.sleep(self._IDLE_SLEEP)
                    continue

                await self._handle_cancellation_message(message)
        except asyncio.CancelledError:
            logger.info("CancellationService: Pubsub listener cancelled")
            raise
        except Exception as e:
            logger.error("CancellationService: Pubsub listener error: %s", e)
        finally:
            # Clean up pubsub on any exit
            if pubsub is not None:
                await self._close_pubsub(pubsub)

    async def _handle_cancellation_message(self, message: Dict[str, Any]) -> None:
        """Decode one pubsub message and cancel the referenced run locally.

        Malformed payloads and errors raised while cancelling are logged as
        warnings and never propagate, so one bad message cannot stop the listener.

        Args:
            message: Pubsub message whose ``data`` field holds a JSON object
                with ``run_id`` and an optional ``reason``.
        """
        try:
            data = json.loads(message["data"])
            # Normalize run_id to string (handle id=0 which is valid per JSON-RPC)
            raw_run_id = data.get("run_id")
            reason = data.get("reason")
            if raw_run_id is not None:
                # Cancel locally if we have this run (don't re-publish)
                await self._cancel_run_local(str(raw_run_id), reason=reason)
        except Exception as e:
            logger.warning("Error processing cancellation message: %s", e)

    async def _close_pubsub(self, pubsub: Any) -> None:
        """Unsubscribe and close a pubsub object, logging (not raising) failures.

        The two steps are independent so that a failing ``unsubscribe`` (e.g. on
        a broken connection) does not prevent the pubsub from being closed.

        Args:
            pubsub: The pubsub object created by the listener.
        """
        try:
            await pubsub.unsubscribe(self._CHANNEL)
        except Exception as e:
            logger.debug("Error unsubscribing cancellation pubsub: %s", e)

        try:
            try:
                await pubsub.aclose()
            except AttributeError:
                # Fall back for pubsub objects that only provide close().
                await pubsub.close()
        except Exception as e:
            logger.debug("Error closing cancellation pubsub: %s", e)

    @staticmethod
    def _mark_cancelled(entry: Dict[str, Any], reason: Optional[str]) -> Optional[CancelCallback]:
        """Flag a run entry as cancelled and return its cancel callback.

        Must be called while holding ``_lock``.

        Args:
            entry: The run's registry entry (mutated in place).
            reason: Optional textual reason for the cancellation request.

        Returns:
            Optional[CancelCallback]: The callback registered for the run, if any.
        """
        entry["cancelled"] = True
        entry["cancelled_at"] = time.time()
        entry["cancel_reason"] = reason
        return entry.get("cancel_callback")

    @staticmethod
    async def _invoke_cancel_callback(run_id: str, cancel_cb: Optional[CancelCallback], reason: Optional[str]) -> None:
        """Await the run's cancel callback, logging any error it raises.

        Called outside ``_lock`` so a slow callback does not block the registry.

        Args:
            run_id: Unique identifier for the run being cancelled.
            cancel_cb: The callback to invoke; nothing happens when it is falsy.
            reason: Optional textual reason passed to the callback.
        """
        if not cancel_cb:
            return
        try:
            await cancel_cb(reason)
            logger.info("Cancel callback executed for %s", run_id)
        except Exception as e:
            logger.exception("Error in cancel callback for %s: %s", run_id, e)

    async def _cancel_run_local(self, run_id: str, reason: Optional[str] = None) -> bool:
        """Cancel a run locally without publishing to Redis (internal use).

        Args:
            run_id: Unique identifier for the run to cancel.
            reason: Optional textual reason for the cancellation request.

        Returns:
            bool: True if the run was found and cancelled, False if not found.
        """
        async with self._lock:
            entry = self._runs.get(run_id)
            if not entry:
                return False
            if entry.get("cancelled"):
                return True
            cancel_cb = self._mark_cancelled(entry, reason)

        # "name" is always present (possibly None), so fall back with `or`.
        logger.info("Tool execution cancelled (from Redis): run_id=%s, reason=%s, tool=%s", run_id, reason or "not specified", entry.get("name") or "unknown")

        await self._invoke_cancel_callback(run_id, cancel_cb, reason)
        return True

    async def register_run(
        self,
        run_id: str,
        name: Optional[str] = None,
        cancel_callback: Optional[CancelCallback] = None,
        owner_email: Optional[str] = None,
        owner_team_ids: Optional[List[str]] = None,
    ) -> None:
        """Register a run for future cancellation.

        Registering an existing ``run_id`` replaces its previous entry.

        Args:
            run_id: Unique run identifier (string)
            name: Optional friendly name for debugging/observability
            cancel_callback: Optional async callback called when a cancel is requested
            owner_email: Optional email of the user who started the run
            owner_team_ids: Optional list of token team IDs associated with the run owner
        """
        entry = {
            "name": name,
            "registered_at": time.time(),
            "cancel_callback": cancel_callback,
            "cancelled": False,
            "owner_email": owner_email,
            # Copy so later mutation of the caller's list cannot change ownership.
            "owner_team_ids": list(owner_team_ids) if owner_team_ids else [],
        }
        async with self._lock:
            self._runs[run_id] = entry
        logger.info("Registered run %s (%s)", run_id, name)

    async def unregister_run(self, run_id: str) -> None:
        """Remove a run from tracking.

        Unknown run ids are ignored silently.

        Args:
            run_id: Unique identifier for the run to unregister.
        """
        async with self._lock:
            removed = run_id in self._runs
            if removed:
                del self._runs[run_id]
        if removed:
            logger.info("Unregistered run %s", run_id)

    async def cancel_run(self, run_id: str, reason: Optional[str] = None) -> bool:
        """Attempt to cancel a run.

        A run known to this worker is marked cancelled, its callback is invoked
        and the cancellation is published to Redis. A run unknown to this worker
        is only published so that another worker may act on it. A run that is
        already cancelled is neither re-invoked nor re-published.

        Args:
            run_id: Unique identifier for the run to cancel.
            reason: Optional textual reason for the cancellation request.

        Returns:
            bool: True if the run was found and cancellation was attempted (or already marked),
            False if the run was not known locally.
        """
        cancel_cb = None

        async with self._lock:
            entry = self._runs.get(run_id)
            if entry:
                if entry.get("cancelled"):
                    logger.debug("Run %s already cancelled", run_id)
                    return True
                cancel_cb = self._mark_cancelled(entry, reason)

        # Handle unknown run case outside the lock
        if not entry:
            logger.info("Cancellation requested for unknown run %s (queued for remote peers)", run_id)
            # Publish to Redis for other workers (outside lock to avoid blocking)
            await self._publish_cancellation(run_id, reason)
            return False

        # Log cancellation with reason and request_id for observability.
        # "name" is always present (possibly None), so fall back with `or`.
        logger.info("Tool execution cancelled: run_id=%s, reason=%s, tool=%s", run_id, reason or "not specified", entry.get("name") or "unknown")

        await self._invoke_cancel_callback(run_id, cancel_cb, reason)

        # Publish to Redis for other workers
        await self._publish_cancellation(run_id, reason)

        return True

    async def _publish_cancellation(self, run_id: str, reason: Optional[str] = None) -> None:
        """Publish cancellation event to Redis for other workers.

        Does nothing when no Redis client is configured; publish failures are
        logged as warnings and never raised.

        Args:
            run_id: Unique identifier for the run being cancelled.
            reason: Optional textual reason for the cancellation.
        """
        if not self._redis:
            return

        try:
            message = json.dumps({"run_id": run_id, "reason": reason})
            await self._redis.publish(self._CHANNEL, message)
            logger.debug("Published cancellation to Redis: run_id=%s", run_id)
        except Exception as e:
            logger.warning("Failed to publish cancellation to Redis: %s", e)

    async def get_status(self, run_id: str) -> Optional[Dict[str, Any]]:
        """Return the status dict for a run if known, else None.

        The returned dict is a shallow copy of the internal entry, so callers
        may freely add or remove keys (for example dropping ``cancel_callback``
        before serialising) without affecting the tracked run.

        Args:
            run_id: Unique identifier for the run to query.

        Returns:
            Optional[Dict[str, Any]]: The status dictionary for the run if found, otherwise None.
        """
        async with self._lock:
            entry = self._runs.get(run_id)
            return dict(entry) if entry is not None else None

    async def is_registered(self, run_id: str) -> bool:
        """Check if a run is currently registered.

        Args:
            run_id: Unique identifier for the run to check.

        Returns:
            bool: True if the run is registered, False otherwise.
        """
        async with self._lock:
            return run_id in self._runs
# Module-level singleton for importers to use. Constructing it only sets up
# in-memory state; the Redis listener is started when `initialize()` is awaited.
cancellation_service = CancellationService()