Coverage for mcpgateway / services / cancellation_service.py: 100%

150 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2# mcpgateway/services/cancellation_service.py 

3"""Location: ./mcpgateway/services/cancellation_service.py 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8Service for tracking and cancelling active tool runs. 

9 

10Provides a simple in-memory registry for run metadata and an optional async 

11cancel callback that can be invoked when a cancellation is requested. This 

12service is intentionally small and designed to be a single-process helper for 

13local run lifecycle management; the gateway remains authoritative for 

14cancellation and also broadcasts a `notifications/cancelled` JSON-RPC 

15notification to connected sessions. 

16""" 

17# Future 

18from __future__ import annotations 

19 

20# Standard 

21import asyncio 

22import json 

23import time 

24from typing import Any, Awaitable, Callable, Dict, List, Optional 

25 

26# First-Party 

27from mcpgateway.services.logging_service import LoggingService 

28from mcpgateway.utils.redis_client import get_redis_client 

29 

# Module-wide logger, obtained through the gateway's LoggingService rather than
# the stdlib ``logging`` module directly.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

# Shape of a cancel hook: an async callable that is handed the optional,
# human-readable cancellation reason and returns nothing.
CancelCallback = Callable[[Optional[str]], Awaitable[None]]

34 

35 

class CancellationService:
    """Track active runs and allow cancellation requests.

    Note: This is intentionally lightweight — run state lives only in an
    in-memory dict and is never persisted, so it is suitable for gateway-local
    run tracking. The gateway will also broadcast a `notifications/cancelled`
    message to connected sessions to inform remote peers of the cancellation
    request.

    Multi-worker deployments: When Redis is available, cancellation events are
    published to the "cancellation:cancel" channel to propagate across workers.
    """

    # Redis pub/sub channel shared by the publisher and the listener.
    _CANCEL_CHANNEL = "cancellation:cancel"

    def __init__(self) -> None:
        """Initialize the cancellation service with an empty run registry."""
        # run_id -> metadata dict (name, timestamps, callback, cancelled flag, owner info)
        self._runs: Dict[str, Dict[str, Any]] = {}
        # Guards every read/write of ``_runs``.
        self._lock = asyncio.Lock()
        # Redis client; stays None when Redis is unavailable.
        self._redis = None
        # Background task running ``_listen_for_cancellations``.
        self._pubsub_task: Optional[asyncio.Task] = None
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize Redis pubsub if available for multi-worker support.

        Only the first call does any work. Failure to obtain a Redis client is
        logged as a warning and otherwise ignored; it is not retried on later
        calls because the initialized flag is set before the attempt.
        """
        if self._initialized:
            return

        self._initialized = True

        try:
            self._redis = await get_redis_client()
            if self._redis:
                # Start listening for cancellation events from other workers
                self._pubsub_task = asyncio.create_task(self._listen_for_cancellations())
                logger.info("CancellationService: Redis pubsub initialized for multi-worker cancellation")
        except Exception as e:
            logger.warning("CancellationService: Could not initialize Redis pubsub: %s", e)

    async def shutdown(self) -> None:
        """Shutdown Redis pubsub listener.

        Cancels the listener task if it is still running and waits for it to
        finish unwinding (its ``finally`` clause closes the pubsub connection).
        """
        if self._pubsub_task and not self._pubsub_task.done():
            self._pubsub_task.cancel()
            try:
                await self._pubsub_task
            except asyncio.CancelledError:
                # Expected: the task re-raises CancelledError after cleanup.
                pass
        logger.info("CancellationService: Shutdown complete")

    async def _listen_for_cancellations(self) -> None:
        """Listen for cancellation events from other workers via Redis pubsub.

        Uses timeout-based polling instead of blocking listen() to allow proper
        cancellation handling. This prevents CPU spin loops when the task is cancelled
        but stuck waiting on the blocking async iterator.

        Any non-cancellation error ends the listener after logging it; the
        pubsub connection is cleaned up on every exit path.

        Raises:
            asyncio.CancelledError: When the listener task is cancelled during shutdown.
        """
        if not self._redis:
            return

        pubsub = None
        try:
            pubsub = self._redis.pubsub()
            await pubsub.subscribe(self._CANCEL_CHANNEL)
            logger.info("CancellationService: Subscribed to cancellation:cancel channel")

            # Use timeout-based polling instead of blocking listen()
            # This allows the task to respond to cancellation properly
            poll_timeout = 1.0

            while True:
                try:
                    message = await asyncio.wait_for(
                        pubsub.get_message(ignore_subscribe_messages=True, timeout=poll_timeout),
                        timeout=poll_timeout + 0.5,
                    )
                except asyncio.TimeoutError:
                    # No message, continue loop to check for cancellation
                    continue

                if message is None or message["type"] != "message":
                    # Sleep briefly so an immediately-returning get_message
                    # (None or a non-message event) cannot spin the loop.
                    await asyncio.sleep(0.1)
                    continue

                await self._handle_cancellation_message(message)
        except asyncio.CancelledError:
            logger.info("CancellationService: Pubsub listener cancelled")
            raise
        except Exception as e:
            logger.error("CancellationService: Pubsub listener error: %s", e)
        finally:
            # Clean up pubsub on any exit
            if pubsub is not None:
                await self._close_pubsub(pubsub)

    async def _handle_cancellation_message(self, message: Dict[str, Any]) -> None:
        """Decode one pubsub message and cancel the matching local run, if any.

        Malformed payloads and callback failures are logged and swallowed so a
        single bad message cannot stop the listener.

        Args:
            message: Pubsub message whose ``data`` field holds a JSON object
                with ``run_id`` and optional ``reason`` keys.
        """
        try:
            data = json.loads(message["data"])
            # Normalize run_id to string (handle id=0 which is valid per JSON-RPC)
            raw_run_id = data.get("run_id")
            if raw_run_id is not None:
                # Cancel locally if we have this run (don't re-publish)
                await self._cancel_run_local(str(raw_run_id), reason=data.get("reason"))
        except Exception as e:
            logger.warning("Error processing cancellation message: %s", e)

    async def _close_pubsub(self, pubsub: Any) -> None:
        """Best-effort unsubscribe and close of a pubsub object.

        The two steps are attempted independently so that a failing
        unsubscribe (e.g. on an already-broken connection) does not prevent the
        connection from being closed.

        Args:
            pubsub: The pubsub object created by the Redis client.
        """
        try:
            await pubsub.unsubscribe(self._CANCEL_CHANNEL)
        except Exception as e:
            logger.debug("Error unsubscribing cancellation pubsub: %s", e)

        # Prefer aclose() when the client provides it, otherwise fall back to
        # close(). Looking the method up first keeps an AttributeError raised
        # *inside* aclose() from being mistaken for "method missing".
        close = getattr(pubsub, "aclose", None) or getattr(pubsub, "close", None)
        if close is None:
            return
        try:
            await close()
        except Exception as e:
            logger.debug("Error closing cancellation pubsub: %s", e)

    @staticmethod
    def _mark_cancelled(entry: Dict[str, Any], reason: Optional[str]) -> Optional[CancelCallback]:
        """Flag a run entry as cancelled and return its cancel callback.

        Must be called while holding ``self._lock``.

        Args:
            entry: The registry entry of the run being cancelled.
            reason: Optional textual reason for the cancellation request.

        Returns:
            Optional[CancelCallback]: The callback registered for the run, if any.
        """
        entry["cancelled"] = True
        entry["cancelled_at"] = time.time()
        entry["cancel_reason"] = reason
        return entry.get("cancel_callback")

    async def _invoke_cancel_callback(self, run_id: str, cancel_cb: Optional[CancelCallback], reason: Optional[str]) -> None:
        """Await a run's cancel callback, logging (not propagating) any failure.

        Args:
            run_id: Unique identifier for the run, used for logging only.
            cancel_cb: The callback to await; nothing happens when it is falsy.
            reason: Optional textual reason passed through to the callback.
        """
        if not cancel_cb:
            return
        try:
            await cancel_cb(reason)
            logger.info("Cancel callback executed for %s", run_id)
        except Exception as e:
            logger.exception("Error in cancel callback for %s: %s", run_id, e)

    async def _cancel_run_local(self, run_id: str, reason: Optional[str] = None) -> bool:
        """Cancel a run locally without publishing to Redis (internal use).

        Args:
            run_id: Unique identifier for the run to cancel.
            reason: Optional textual reason for the cancellation request.

        Returns:
            bool: True if the run was found (newly cancelled or already
            cancelled), False if not found.
        """
        async with self._lock:
            entry = self._runs.get(run_id)
            if not entry:
                return False
            if entry.get("cancelled"):
                return True
            cancel_cb = self._mark_cancelled(entry, reason)

        # Callback runs outside the lock so it may call back into this service.
        logger.info("Tool execution cancelled (from Redis): run_id=%s, reason=%s, tool=%s", run_id, reason or "not specified", entry.get("name", "unknown"))
        await self._invoke_cancel_callback(run_id, cancel_cb, reason)
        return True

    async def register_run(
        self,
        run_id: str,
        name: Optional[str] = None,
        cancel_callback: Optional[CancelCallback] = None,
        owner_email: Optional[str] = None,
        owner_team_ids: Optional[List[str]] = None,
    ) -> None:
        """Register a run for future cancellation.

        Registering an existing ``run_id`` replaces its previous entry.

        Args:
            run_id: Unique run identifier (string)
            name: Optional friendly name for debugging/observability
            cancel_callback: Optional async callback called when a cancel is requested
            owner_email: Optional email of the user who started the run
            owner_team_ids: Optional list of token team IDs associated with the run owner
        """
        async with self._lock:
            self._runs[run_id] = {
                "name": name,
                "registered_at": time.time(),
                "cancel_callback": cancel_callback,
                "cancelled": False,
                "owner_email": owner_email,
                # Copy so later mutation of the caller's list cannot alter the registry.
                "owner_team_ids": list(owner_team_ids) if owner_team_ids else [],
            }
            logger.info("Registered run %s (%s)", run_id, name)

    async def unregister_run(self, run_id: str) -> None:
        """Remove a run from tracking.

        Unknown run ids are ignored silently.

        Args:
            run_id: Unique identifier for the run to unregister.
        """
        async with self._lock:
            removed = self._runs.pop(run_id, None)
        if removed is not None:
            logger.info("Unregistered run %s", run_id)

    async def cancel_run(self, run_id: str, reason: Optional[str] = None) -> bool:
        """Attempt to cancel a run.

        The cancellation is also published to Redis (when configured) so other
        workers can cancel the run if they own it — including when the run is
        not known locally.

        Args:
            run_id: Unique identifier for the run to cancel.
            reason: Optional textual reason for the cancellation request.

        Returns:
            bool: True if the run was found and cancellation was attempted (or already marked),
            False if the run was not known locally.
        """
        cancel_cb = None

        async with self._lock:
            entry = self._runs.get(run_id)
            if entry:
                if entry.get("cancelled"):
                    logger.debug("Run %s already cancelled", run_id)
                    return True
                cancel_cb = self._mark_cancelled(entry, reason)

        # Handle unknown run case outside the lock
        if not entry:
            logger.info("Cancellation requested for unknown run %s (queued for remote peers)", run_id)
            # Publish to Redis for other workers (outside lock to avoid blocking)
            await self._publish_cancellation(run_id, reason)
            return False

        # Log cancellation with reason and request_id for observability
        logger.info("Tool execution cancelled: run_id=%s, reason=%s, tool=%s", run_id, reason or "not specified", entry.get("name", "unknown"))

        await self._invoke_cancel_callback(run_id, cancel_cb, reason)

        # Publish to Redis for other workers
        await self._publish_cancellation(run_id, reason)

        return True

    async def _publish_cancellation(self, run_id: str, reason: Optional[str] = None) -> None:
        """Publish cancellation event to Redis for other workers.

        A no-op when no Redis client is configured; publish failures are
        logged as warnings and not raised.

        Args:
            run_id: Unique identifier for the run being cancelled.
            reason: Optional textual reason for the cancellation.
        """
        if not self._redis:
            return

        try:
            payload = json.dumps({"run_id": run_id, "reason": reason})
            await self._redis.publish(self._CANCEL_CHANNEL, payload)
            logger.debug("Published cancellation to Redis: run_id=%s", run_id)
        except Exception as e:
            logger.warning("Failed to publish cancellation to Redis: %s", e)

    async def get_status(self, run_id: str) -> Optional[Dict[str, Any]]:
        """Return the status dict for a run if known, else None.

        The returned dict is the registry's own entry (not a copy), so it
        includes the ``cancel_callback`` and reflects later updates.

        Args:
            run_id: Unique identifier for the run to query.

        Returns:
            Optional[Dict[str, Any]]: The status dictionary for the run if found, otherwise None.
        """
        async with self._lock:
            return self._runs.get(run_id)

    async def is_registered(self, run_id: str) -> bool:
        """Check if a run is currently registered.

        Args:
            run_id: Unique identifier for the run to check.

        Returns:
            bool: True if the run is registered, False otherwise.
        """
        async with self._lock:
            return run_id in self._runs

315 

316 

# Shared, import-ready instance of the service; constructed once when this
# module is first imported.
cancellation_service = CancellationService()