Coverage for mcpgateway / services / cancellation_service.py: 99%

150 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2# mcpgateway/services/cancellation_service.py 

3"""Location: ./mcpgateway/services/cancellation_service.py 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8Service for tracking and cancelling active tool runs. 

9 

10Provides a simple in-memory registry for run metadata and an optional async 

11cancel callback that can be invoked when a cancellation is requested. This 

12service is intentionally small and designed to be a single-process helper for 

13local run lifecycle management; the gateway remains authoritative for 

14cancellation and also broadcasts a `notifications/cancelled` JSON-RPC 

15notification to connected sessions. 

16""" 

17# Future 

18from __future__ import annotations 

19 

20# Standard 

21import asyncio 

22import json 

23import time 

24from typing import Any, Awaitable, Callable, Dict, Optional 

25 

26# First-Party 

27from mcpgateway.services.logging_service import LoggingService 

28from mcpgateway.utils.redis_client import get_redis_client 

29 

# Shape of the optional per-run cancellation hook: an async callable that is
# awaited with the cancellation reason (which may be None).
CancelCallback = Callable[[Optional[str]], Awaitable[None]]

# Module logger, obtained through the gateway's shared logging service.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

34 

35 

class CancellationService:
    """Track active runs and allow cancellation requests.

    Note: This is intentionally lightweight — it does not persist state and is
    suitable for gateway-local run tracking. The gateway will also broadcast
    a `notifications/cancelled` message to connected sessions to inform remote
    peers of the cancellation request.

    Multi-worker deployments: When Redis is available, cancellation events are
    published to the "cancellation:cancel" channel to propagate across workers.
    """

    # Redis pubsub channel used to propagate cancellations between workers.
    _CHANNEL = "cancellation:cancel"

    def __init__(self) -> None:
        """Initialize the cancellation service."""
        # run_id -> entry dict with keys "name", "registered_at", "cancel_callback",
        # "cancelled" and, once cancelled, "cancelled_at" and "cancel_reason".
        self._runs: Dict[str, Dict[str, Any]] = {}
        # Serializes access to ``_runs``; callbacks and Redis I/O run outside it.
        self._lock = asyncio.Lock()
        # Client returned by get_redis_client(); None means local-only mode.
        self._redis: Any = None
        # Background task running _listen_for_cancellations(), if started.
        self._pubsub_task: Optional[asyncio.Task] = None
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize Redis pubsub if available for multi-worker support.

        Subsequent calls are no-ops until :meth:`shutdown` has run. A failure to
        obtain a Redis client is logged and leaves the service in local-only mode.
        """
        if self._initialized:
            return

        # Set before the first await so concurrent callers cannot start two listeners.
        self._initialized = True

        try:
            self._redis = await get_redis_client()
            if self._redis:
                # Start listening for cancellation events from other workers
                self._pubsub_task = asyncio.create_task(self._listen_for_cancellations())
                logger.info("CancellationService: Redis pubsub initialized for multi-worker cancellation")
        except Exception as e:
            logger.warning("CancellationService: Could not initialize Redis pubsub: %s", e)

    async def shutdown(self) -> None:
        """Stop the Redis pubsub listener and reset Redis-related state.

        Locally registered runs are kept. After shutdown, cancellations are no
        longer published to Redis, and :meth:`initialize` may be called again to
        start a fresh listener.
        """
        task = self._pubsub_task
        if task is not None and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        # Reset the state so that a later initialize() is not a silent no-op, and
        # so that _publish_cancellation stops using the client we released.
        self._pubsub_task = None
        self._redis = None
        self._initialized = False
        logger.info("CancellationService: Shutdown complete")

    async def _listen_for_cancellations(self) -> None:
        """Listen for cancellation events from other workers via Redis pubsub.

        Uses timeout-based polling instead of blocking listen() to allow proper
        cancellation handling. This prevents CPU spin loops when the task is cancelled
        but stuck waiting on the blocking async iterator.

        Any error other than task cancellation ends the listener. The error is
        logged and the listener is not restarted, so cancellations from other
        workers are ignored until the service is shut down and re-initialized.

        Raises:
            asyncio.CancelledError: When the listener task is cancelled during shutdown.
        """
        if not self._redis:
            return

        pubsub = None
        try:
            pubsub = self._redis.pubsub()
            await pubsub.subscribe(self._CHANNEL)
            logger.info("CancellationService: Subscribed to cancellation:cancel channel")

            # Use timeout-based polling instead of blocking listen() so that
            # task cancellation is observed promptly.
            poll_timeout = 1.0

            while True:
                try:
                    message = await asyncio.wait_for(
                        pubsub.get_message(ignore_subscribe_messages=True, timeout=poll_timeout),
                        # Outer guard is slightly longer than get_message's own timeout.
                        timeout=poll_timeout + 0.5,
                    )
                except asyncio.TimeoutError:
                    # No message; loop again so a pending cancellation is noticed.
                    continue

                if message is None or message["type"] != "message":
                    # Back off briefly so that an immediate None or a non-data
                    # message cannot spin the loop.
                    await asyncio.sleep(0.1)
                    continue

                await self._handle_cancellation_message(message)
        except asyncio.CancelledError:
            logger.info("CancellationService: Pubsub listener cancelled")
            raise
        except Exception as e:
            logger.error("CancellationService: Pubsub listener error: %s", e)
        finally:
            # Clean up pubsub on any exit
            if pubsub is not None:
                await self._close_pubsub(pubsub)

    async def _handle_cancellation_message(self, message: Dict[str, Any]) -> None:
        """Apply one pubsub cancellation message to the local registry.

        Malformed payloads and processing errors are logged, not raised, so that a
        single bad message cannot stop the listener.

        Args:
            message: A pubsub message whose "data" is a JSON object containing
                "run_id" and, optionally, "reason".
        """
        try:
            data = json.loads(message["data"])
            # Normalize run_id to string (handle id=0 which is valid per JSON-RPC)
            raw_run_id = data.get("run_id")
            run_id = str(raw_run_id) if raw_run_id is not None else None
            reason = data.get("reason")

            if run_id is not None:
                # Cancel locally if we have this run (don't re-publish)
                await self._cancel_run_local(run_id, reason=reason)
        except Exception as e:
            logger.warning("Error processing cancellation message: %s", e)

    async def _close_pubsub(self, pubsub: Any) -> None:
        """Unsubscribe from and close a pubsub object, logging any failure.

        Closing is attempted even if unsubscribing fails, so that a broken
        connection is still released.

        Args:
            pubsub: The pubsub object created by ``self._redis.pubsub()``.
        """
        try:
            await pubsub.unsubscribe(self._CHANNEL)
        except Exception as e:
            logger.debug("Error closing cancellation pubsub: %s", e)
        try:
            try:
                await pubsub.aclose()
            except AttributeError:
                # Fallback for pubsub objects without aclose() (presumably older
                # redis-py releases).
                await pubsub.close()
        except Exception as e:
            logger.debug("Error closing cancellation pubsub: %s", e)

    async def _mark_cancelled(self, run_id: str, reason: Optional[str]) -> tuple[Optional[Dict[str, Any]], bool]:
        """Atomically flag a registered run as cancelled.

        Args:
            run_id: Unique identifier for the run to cancel.
            reason: Optional textual reason recorded on the entry.

        Returns:
            A ``(entry, newly_cancelled)`` pair. ``entry`` is None for an unknown
            run. ``newly_cancelled`` is False when the run is unknown or was
            already marked cancelled.
        """
        async with self._lock:
            entry = self._runs.get(run_id)
            if entry is None or entry.get("cancelled"):
                return entry, False
            entry["cancelled"] = True
            entry["cancelled_at"] = time.time()
            entry["cancel_reason"] = reason
            return entry, True

    @staticmethod
    async def _invoke_cancel_callback(run_id: str, cancel_cb: Optional[CancelCallback], reason: Optional[str]) -> None:
        """Await a run's cancel callback, logging (not raising) any failure.

        Args:
            run_id: Identifier of the run, used only for logging.
            cancel_cb: The callback registered for the run, if any.
            reason: Cancellation reason passed through to the callback.
        """
        if not cancel_cb:
            return
        try:
            await cancel_cb(reason)
            logger.info("Cancel callback executed for %s", run_id)
        except Exception as e:
            logger.exception("Error in cancel callback for %s: %s", run_id, e)

    async def _cancel_run_local(self, run_id: str, reason: Optional[str] = None) -> bool:
        """Cancel a run locally without publishing to Redis (internal use).

        Args:
            run_id: Unique identifier for the run to cancel.
            reason: Optional textual reason for the cancellation request.

        Returns:
            bool: True if the run was found (whether newly or already cancelled),
            False if not found.
        """
        entry, newly_cancelled = await self._mark_cancelled(run_id, reason)
        if entry is None:
            return False
        if not newly_cancelled:
            return True

        logger.info("Tool execution cancelled (from Redis): run_id=%s, reason=%s, tool=%s", run_id, reason or "not specified", entry.get("name", "unknown"))
        await self._invoke_cancel_callback(run_id, entry.get("cancel_callback"), reason)
        return True

    async def register_run(self, run_id: str, name: Optional[str] = None, cancel_callback: Optional[CancelCallback] = None) -> None:
        """Register a run for future cancellation.

        Re-registering an existing run_id replaces its entry, including any
        cancelled state.

        Args:
            run_id: Unique run identifier (string)
            name: Optional friendly name for debugging/observability
            cancel_callback: Optional async callback called when a cancel is requested
        """
        async with self._lock:
            self._runs[run_id] = {"name": name, "registered_at": time.time(), "cancel_callback": cancel_callback, "cancelled": False}
        logger.info("Registered run %s (%s)", run_id, name)

    async def unregister_run(self, run_id: str) -> None:
        """Remove a run from tracking; unknown run IDs are ignored.

        Args:
            run_id: Unique identifier for the run to unregister.
        """
        async with self._lock:
            self._runs.pop(run_id, None)
        logger.info("Unregistered run %s", run_id)

    async def cancel_run(self, run_id: str, reason: Optional[str] = None) -> bool:
        """Attempt to cancel a run.

        A locally known run is marked cancelled, its callback is awaited, and the
        cancellation is published to Redis. An unknown run is only published, so
        that the worker owning it (if any) can cancel it.

        Args:
            run_id: Unique identifier for the run to cancel.
            reason: Optional textual reason for the cancellation request.

        Returns:
            bool: True if the run was found and cancellation was attempted (or already marked),
            False if the run was not known locally.
        """
        entry, newly_cancelled = await self._mark_cancelled(run_id, reason)

        if entry is None:
            logger.info("Cancellation requested for unknown run %s (queued for remote peers)", run_id)
            # Publish to Redis for other workers (outside the lock to avoid blocking)
            await self._publish_cancellation(run_id, reason)
            return False

        if not newly_cancelled:
            logger.debug("Run %s already cancelled", run_id)
            return True

        # Log cancellation with reason and request_id for observability
        logger.info("Tool execution cancelled: run_id=%s, reason=%s, tool=%s", run_id, reason or "not specified", entry.get("name", "unknown"))
        await self._invoke_cancel_callback(run_id, entry.get("cancel_callback"), reason)

        # Publish to Redis for other workers
        await self._publish_cancellation(run_id, reason)
        return True

    async def _publish_cancellation(self, run_id: str, reason: Optional[str] = None) -> None:
        """Publish cancellation event to Redis for other workers.

        This is a no-op when no Redis client is set. Publish failures are logged,
        not raised.

        Args:
            run_id: Unique identifier for the run being cancelled.
            reason: Optional textual reason for the cancellation.
        """
        if not self._redis:
            return

        try:
            message = json.dumps({"run_id": run_id, "reason": reason})
            await self._redis.publish(self._CHANNEL, message)
            logger.debug("Published cancellation to Redis: run_id=%s", run_id)
        except Exception as e:
            logger.warning("Failed to publish cancellation to Redis: %s", e)

    async def get_status(self, run_id: str) -> Optional[Dict[str, Any]]:
        """Return the status dict for a run if known, else None.

        The returned dict is the live registry entry, not a copy. It reflects
        later cancellation, and mutating it changes the service's state.

        Args:
            run_id: Unique identifier for the run to query.

        Returns:
            Optional[Dict[str, Any]]: The status dictionary for the run if found, otherwise None.
        """
        async with self._lock:
            return self._runs.get(run_id)

    async def is_registered(self, run_id: str) -> bool:
        """Check if a run is currently registered.

        Args:
            run_id: Unique identifier for the run to check.

        Returns:
            bool: True if the run is registered, False otherwise.
        """
        async with self._lock:
            return run_id in self._runs

299 

300 

# Process-wide shared instance; importers should use this rather than
# constructing their own CancellationService.
cancellation_service: CancellationService = CancellationService()