Coverage for mcpgateway / services / cancellation_service.py: 100%

150 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2# mcpgateway/services/cancellation_service.py 

3"""Location: ./mcpgateway/services/cancellation_service.py 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8Service for tracking and cancelling active tool runs. 

9 

10Provides a simple in-memory registry for run metadata and an optional async 

11cancel callback that can be invoked when a cancellation is requested. This 

12service is intentionally small and designed to be a single-process helper for 

13local run lifecycle management; the gateway remains authoritative for 

14cancellation and also broadcasts a `notifications/cancelled` JSON-RPC 

15notification to connected sessions. 

16""" 

17 

18# Future 

19from __future__ import annotations 

20 

21# Standard 

22import asyncio 

23import json 

24import time 

25from typing import Any, Awaitable, Callable, Dict, List, Optional 

26 

27# First-Party 

28from mcpgateway.services.logging_service import LoggingService 

29from mcpgateway.utils.redis_client import get_redis_client 

30 

# Module-wide logging: one LoggingService instance, from which the logger
# named after this module is obtained.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

# Type of a cancel hook: an async callable awaited with the (optional)
# cancellation reason and returning nothing.
CancelCallback = Callable[[Optional[str]], Awaitable[None]]

35 

36 

class CancellationService:
    """Track active runs and allow cancellation requests.

    Note: This is intentionally lightweight — it does not persist state and is
    suitable for gateway-local run tracking. The gateway will also broadcast
    a `notifications/cancelled` message to connected sessions to inform remote
    peers of the cancellation request.

    Multi-worker deployments: When Redis is available, cancellation events are
    published to the "cancellation:cancel" channel to propagate across workers.
    """

    def __init__(self) -> None:
        """Initialize the cancellation service."""
        # run_id -> metadata dict (name, registered_at, cancel_callback,
        # cancelled, owner_email, owner_team_ids, plus cancelled_at and
        # cancel_reason once a cancellation has been recorded).
        self._runs: Dict[str, Dict[str, Any]] = {}
        # Guards every read/write of self._runs.
        self._lock = asyncio.Lock()
        self._redis = None
        self._pubsub_task: Optional[asyncio.Task] = None
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize Redis pubsub if available for multi-worker support.

        Only the first call does any work. The service is flagged as
        initialized before Redis is contacted, so later calls are no-ops even
        when the Redis setup fails (the failure is logged as a warning).
        """
        if self._initialized:
            return

        self._initialized = True

        try:
            self._redis = await get_redis_client()
            if self._redis:
                # Start listening for cancellation events from other workers
                self._pubsub_task = asyncio.create_task(self._listen_for_cancellations())
                logger.info("CancellationService: Redis pubsub initialized for multi-worker cancellation")
        except Exception as e:
            logger.warning(f"CancellationService: Could not initialize Redis pubsub: {e}")

    async def shutdown(self) -> None:
        """Shutdown Redis pubsub listener.

        Cancels the listener task if it is still running and waits for it to
        finish; the resulting ``CancelledError`` is expected and swallowed.
        """
        if self._pubsub_task and not self._pubsub_task.done():
            self._pubsub_task.cancel()
            try:
                await self._pubsub_task
            except asyncio.CancelledError:
                pass
        logger.info("CancellationService: Shutdown complete")

    async def _listen_for_cancellations(self) -> None:
        """Listen for cancellation events from other workers via Redis pubsub.

        Uses timeout-based polling instead of blocking listen() to allow proper
        cancellation handling. This prevents CPU spin loops when the task is cancelled
        but stuck waiting on the blocking async iterator.

        Raises:
            asyncio.CancelledError: When the listener task is cancelled during shutdown.
        """
        if not self._redis:
            return

        pubsub = None
        try:
            pubsub = self._redis.pubsub()
            await pubsub.subscribe("cancellation:cancel")
            logger.info("CancellationService: Subscribed to cancellation:cancel channel")

            # Use timeout-based polling instead of blocking listen()
            # This allows the task to respond to cancellation properly
            poll_timeout = 1.0

            while True:
                try:
                    message = await asyncio.wait_for(
                        pubsub.get_message(ignore_subscribe_messages=True, timeout=poll_timeout),
                        timeout=poll_timeout + 0.5,
                    )
                except asyncio.TimeoutError:
                    # No message, continue loop to check for cancellation
                    continue

                if message is None:
                    # Prevent spin if get_message returns None immediately
                    await asyncio.sleep(0.1)
                    continue

                if message["type"] != "message":
                    # Sleep on non-message types to prevent spin
                    await asyncio.sleep(0.1)
                    continue

                try:
                    data = json.loads(message["data"])
                    # Normalize run_id to string (handle id=0 which is valid per JSON-RPC)
                    raw_run_id = data.get("run_id")
                    run_id = str(raw_run_id) if raw_run_id is not None else None
                    reason = data.get("reason")

                    if run_id is not None:
                        # Cancel locally if we have this run (don't re-publish)
                        await self._cancel_run_local(run_id, reason=reason)
                except Exception as e:
                    # A malformed payload must not take the listener down.
                    logger.warning(f"Error processing cancellation message: {e}")
        except asyncio.CancelledError:
            logger.info("CancellationService: Pubsub listener cancelled")
            raise
        except Exception as e:
            logger.error(f"CancellationService: Pubsub listener error: {e}")
        finally:
            # Clean up pubsub on any exit
            if pubsub is not None:
                await self._close_pubsub(pubsub)

    async def _close_pubsub(self, pubsub: Any) -> None:
        """Best-effort unsubscribe and close of a pubsub object.

        The two steps are guarded independently so that a failing
        ``unsubscribe`` cannot prevent the connection from being closed.

        Args:
            pubsub: Pubsub object previously subscribed to "cancellation:cancel".
        """
        try:
            await pubsub.unsubscribe("cancellation:cancel")
        except Exception as e:
            logger.debug("Error unsubscribing cancellation pubsub: %s", e)

        try:
            try:
                await pubsub.aclose()
            except AttributeError:
                # Fall back for pubsub objects that only expose close().
                await pubsub.close()
        except Exception as e:
            logger.debug(f"Error closing cancellation pubsub: {e}")

    async def _mark_cancelled(self, run_id: str, reason: Optional[str]) -> tuple[Optional[Dict[str, Any]], Optional[CancelCallback], bool]:
        """Flag a registered run as cancelled while holding the registry lock.

        Args:
            run_id: Unique identifier for the run to cancel.
            reason: Optional textual reason for the cancellation request.

        Returns:
            A ``(entry, cancel_cb, newly_marked)`` tuple. ``entry`` is None when
            the run is not registered locally. ``newly_marked`` is True only
            when this call performed the transition to cancelled; in that case
            ``cancel_cb`` is the run's registered callback (possibly None).
        """
        async with self._lock:
            entry = self._runs.get(run_id)
            if not entry:
                return None, None, False
            if entry.get("cancelled"):
                return entry, None, False
            entry["cancelled"] = True
            entry["cancelled_at"] = time.time()
            entry["cancel_reason"] = reason
            return entry, entry.get("cancel_callback"), True

    async def _run_cancel_callback(self, run_id: str, cancel_cb: Optional[CancelCallback], reason: Optional[str]) -> None:
        """Await a run's cancel callback, logging (not raising) any failure.

        Args:
            run_id: Unique identifier of the run, used for logging only.
            cancel_cb: Callback to await, or None to do nothing.
            reason: Optional textual reason forwarded to the callback.
        """
        if not cancel_cb:
            return
        try:
            await cancel_cb(reason)
            logger.info("Cancel callback executed for %s", run_id)
        except Exception as e:
            logger.exception("Error in cancel callback for %s: %s", run_id, e)

    async def _cancel_run_local(self, run_id: str, reason: Optional[str] = None) -> bool:
        """Cancel a run locally without publishing to Redis (internal use).

        Args:
            run_id: Unique identifier for the run to cancel.
            reason: Optional textual reason for the cancellation request.

        Returns:
            bool: True if the run was found and cancelled, False if not found.
        """
        entry, cancel_cb, newly_marked = await self._mark_cancelled(run_id, reason)
        if entry is None:
            return False
        if not newly_marked:
            return True

        # "name" is always present (possibly None), so fall back with `or`.
        logger.info("Tool execution cancelled (from Redis): run_id=%s, reason=%s, tool=%s", run_id, reason or "not specified", entry.get("name") or "unknown")

        await self._run_cancel_callback(run_id, cancel_cb, reason)
        return True

    async def register_run(
        self,
        run_id: str,
        name: Optional[str] = None,
        cancel_callback: Optional[CancelCallback] = None,
        owner_email: Optional[str] = None,
        owner_team_ids: Optional[List[str]] = None,
    ) -> None:
        """Register a run for future cancellation.

        Registering an already-known run_id replaces its previous entry.

        Args:
            run_id: Unique run identifier (string)
            name: Optional friendly name for debugging/observability
            cancel_callback: Optional async callback called when a cancel is requested
            owner_email: Optional email of the user who started the run
            owner_team_ids: Optional list of token team IDs associated with the run owner
        """
        async with self._lock:
            self._runs[run_id] = {
                "name": name,
                "registered_at": time.time(),
                "cancel_callback": cancel_callback,
                "cancelled": False,
                "owner_email": owner_email,
                "owner_team_ids": owner_team_ids or [],
            }
            logger.info("Registered run %s (%s)", run_id, name)

    async def unregister_run(self, run_id: str) -> None:
        """Remove a run from tracking.

        Unknown run IDs are ignored silently.

        Args:
            run_id: Unique identifier for the run to unregister.
        """
        async with self._lock:
            # Entries are always dicts, so None reliably means "was not present".
            if self._runs.pop(run_id, None) is not None:
                logger.info("Unregistered run %s", run_id)

    async def cancel_run(self, run_id: str, reason: Optional[str] = None) -> bool:
        """Attempt to cancel a run.

        The cancellation is published to Redis (when configured) both for runs
        cancelled locally and for runs unknown to this worker, so that other
        workers can act on it. Runs that were already cancelled are not
        re-published.

        Args:
            run_id: Unique identifier for the run to cancel.
            reason: Optional textual reason for the cancellation request.

        Returns:
            bool: True if the run was found and cancellation was attempted (or already marked),
            False if the run was not known locally.
        """
        entry, cancel_cb, newly_marked = await self._mark_cancelled(run_id, reason)

        # Everything below runs outside the lock so slow callbacks or Redis
        # calls cannot block other registry operations.
        if entry is None:
            logger.info("Cancellation requested for unknown run %s (queued for remote peers)", run_id)
            await self._publish_cancellation(run_id, reason)
            return False

        if not newly_marked:
            logger.debug("Run %s already cancelled", run_id)
            return True

        # Log cancellation with reason and request_id for observability.
        # "name" is always present (possibly None), so fall back with `or`.
        logger.info("Tool execution cancelled: run_id=%s, reason=%s, tool=%s", run_id, reason or "not specified", entry.get("name") or "unknown")

        await self._run_cancel_callback(run_id, cancel_cb, reason)

        # Publish to Redis for other workers
        await self._publish_cancellation(run_id, reason)

        return True

    async def _publish_cancellation(self, run_id: str, reason: Optional[str] = None) -> None:
        """Publish cancellation event to Redis for other workers.

        Does nothing when no Redis client is configured; publish failures are
        logged as warnings and not raised.

        Args:
            run_id: Unique identifier for the run being cancelled.
            reason: Optional textual reason for the cancellation.
        """
        if not self._redis:
            return

        try:
            message = json.dumps({"run_id": run_id, "reason": reason})
            await self._redis.publish("cancellation:cancel", message)
            logger.debug("Published cancellation to Redis: run_id=%s", run_id)
        except Exception as e:
            logger.warning(f"Failed to publish cancellation to Redis: {e}")

    async def get_status(self, run_id: str) -> Optional[Dict[str, Any]]:
        """Return the status dict for a run if known, else None.

        The returned dict is the registry's own entry, not a copy.

        Args:
            run_id: Unique identifier for the run to query.

        Returns:
            Optional[Dict[str, Any]]: The status dictionary for the run if found, otherwise None.
        """
        async with self._lock:
            return self._runs.get(run_id)

    async def is_registered(self, run_id: str) -> bool:
        """Check if a run is currently registered.

        Args:
            run_id: Unique identifier for the run to check.

        Returns:
            bool: True if the run is registered, False otherwise.
        """
        async with self._lock:
            return run_id in self._runs

316 

317 

# Shared instance created at import time; other modules import this object
# rather than constructing their own CancellationService.
cancellation_service = CancellationService()