Coverage for mcpgateway / services / elicitation_service.py: 97%

133 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Elicitation service for tracking and routing elicitation requests. 

3 

4This service manages the lifecycle of MCP elicitation requests, which allow 

5servers to request structured user input through connected clients. 

6 

7Per MCP specification 2025-06-18, elicitation follows a server→client request 

8pattern where servers send elicitation/create requests, and clients respond 

9with user input (accept/decline/cancel actions). 

10""" 

11 

12# Standard 

13import asyncio 

14from dataclasses import dataclass, field 

15import logging 

16import time 

17from typing import Any, Dict, Optional 

18from uuid import uuid4 

19 

20# First-Party 

21from mcpgateway.common.models import ElicitResult 

22 

23logger = logging.getLogger(__name__) 

24 

25 

@dataclass
class PendingElicitation:
    """Bookkeeping record for one elicitation that has not been answered yet.

    One instance is stored per outstanding request; the ``future`` field is
    the hand-off point between the code waiting for an answer and the code
    delivering it.

    Attributes:
        request_id: Unique identifier for this elicitation request
        upstream_session_id: Session that initiated the request (server)
        downstream_session_id: Session handling the request (client)
        created_at: Unix timestamp when request was created
        timeout: Maximum wait time in seconds
        message: User-facing message describing what input is needed
        schema: JSON Schema defining expected response structure
        future: AsyncIO future that resolves to ElicitResult when complete
    """

    # Identity and routing
    request_id: str
    upstream_session_id: str
    downstream_session_id: str

    # Lifetime (both in seconds; created_at is an absolute Unix timestamp)
    created_at: float
    timeout: float

    # What is being asked of the user
    message: str
    schema: Dict[str, Any]

    # A fresh future is created per instance when the caller does not pass one.
    future: asyncio.Future = field(default_factory=asyncio.Future)

49 

50 

class ElicitationService:
    """Service for managing elicitation request lifecycle.

    This service provides:
    - Tracking of pending elicitation requests
    - Response routing back to original requesters
    - Timeout enforcement and cleanup
    - Schema validation per MCP spec (primitive types only)
    - Concurrency limits to prevent resource exhaustion

    Pending requests are kept in an in-memory registry keyed by request ID.
    Each request is bounded by its own timeout inside ``create_elicitation``;
    in addition, a background task sweeps the registry every
    ``cleanup_interval`` seconds and fails any entry that has outlived its
    timeout.
    """

    def __init__(
        self,
        default_timeout: int = 60,
        max_concurrent: int = 100,
        cleanup_interval: int = 300,  # 5 minutes
    ):
        """Initialize the elicitation service.

        Args:
            default_timeout: Default timeout for elicitation requests (seconds)
            max_concurrent: Maximum number of concurrent elicitations
            cleanup_interval: How often the background cleanup task runs (seconds)
        """
        self.default_timeout = default_timeout
        self.max_concurrent = max_concurrent
        self.cleanup_interval = cleanup_interval
        # request_id -> tracking record for every elicitation still awaiting a reply
        self._pending: Dict[str, PendingElicitation] = {}
        # Created lazily by start(); None until then
        self._cleanup_task: Optional[asyncio.Task] = None
        logger.info(f"ElicitationService initialized: timeout={default_timeout}s, " f"max_concurrent={max_concurrent}, cleanup_interval={cleanup_interval}s")

    async def start(self):
        """Start the background cleanup task.

        Does nothing if a cleanup task already exists and has not finished.
        """
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            logger.info("Elicitation cleanup task started")

    async def shutdown(self):
        """Shutdown service and cancel all pending requests.

        Cancels the cleanup task (if one was started), fails every unresolved
        pending future with ``RuntimeError`` and empties the registry.
        """
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                # Expected: the task was cancelled on the line above.
                pass

        # Fail all requests that are still waiting for an answer
        cancelled_count = 0
        for elicitation in list(self._pending.values()):
            if not elicitation.future.done():
                elicitation.future.set_exception(RuntimeError("ElicitationService shutting down"))
                cancelled_count += 1

        self._pending.clear()
        logger.info(f"ElicitationService shutdown complete (cancelled {cancelled_count} pending requests)")

    async def create_elicitation(self, upstream_session_id: str, downstream_session_id: str, message: str, requested_schema: Dict[str, Any], timeout: Optional[float] = None) -> ElicitResult:
        """Create and track an elicitation request.

        This method validates the schema, registers the request, and then
        waits for ``complete_elicitation`` to deliver the client's response,
        giving up after the timeout. The registry entry is removed on every
        exit path.

        Args:
            upstream_session_id: Session that initiated the request (server)
            downstream_session_id: Session that will handle the request (client)
            message: Message to present to user
            requested_schema: JSON Schema for expected response
            timeout: Optional timeout override (default: self.default_timeout)

        Returns:
            ElicitResult from the client containing action and optional content

        Raises:
            ValueError: If max concurrent limit reached or invalid schema
            asyncio.TimeoutError: If request times out waiting for response
            RuntimeError: If the service is shut down while the request is pending
        """
        # Check concurrent limit before doing any other work
        if len(self._pending) >= self.max_concurrent:
            logger.warning(f"Max concurrent elicitations reached: {self.max_concurrent}")
            raise ValueError(f"Maximum concurrent elicitations ({self.max_concurrent}) reached")

        # Validate schema (primitive types only per MCP spec)
        self._validate_schema(requested_schema)

        # Create tracking entry
        request_id = str(uuid4())
        timeout_val = timeout if timeout is not None else self.default_timeout
        # loop.create_future() is the documented way to create a future; we are
        # inside a coroutine, so a running loop is guaranteed to exist.
        future: asyncio.Future = asyncio.get_running_loop().create_future()

        elicitation = PendingElicitation(
            request_id=request_id,
            upstream_session_id=upstream_session_id,
            downstream_session_id=downstream_session_id,
            created_at=time.time(),
            timeout=timeout_val,
            message=message,
            schema=requested_schema,
            future=future,
        )

        self._pending[request_id] = elicitation
        logger.info(f"Created elicitation request {request_id}: upstream={upstream_session_id}, downstream={downstream_session_id}, timeout={timeout_val}s")

        try:
            # Wait for response with timeout
            result = await asyncio.wait_for(future, timeout=timeout_val)
            logger.info(f"Elicitation {request_id} completed: action={result.action}")
            return result
        except asyncio.TimeoutError:
            logger.warning(f"Elicitation {request_id} timed out after {timeout_val}s")
            raise
        finally:
            # Always drop the registry entry, whatever the outcome
            self._pending.pop(request_id, None)

    def complete_elicitation(self, request_id: str, result: ElicitResult) -> bool:
        """Complete a pending elicitation with a result from the client.

        Args:
            request_id: ID of the elicitation request to complete
            result: The client's response (action + optional content)

        Returns:
            True if request was found and completed, False if the ID is
            unknown or its future was already resolved
        """
        elicitation = self._pending.get(request_id)
        if not elicitation:
            logger.warning(f"Attempted to complete unknown elicitation: {request_id}")
            return False

        if elicitation.future.done():
            logger.warning(f"Elicitation {request_id} already completed")
            return False

        elicitation.future.set_result(result)
        logger.debug(f"Completed elicitation {request_id}: action={result.action}")
        return True

    def get_pending_elicitation(self, request_id: str) -> Optional[PendingElicitation]:
        """Get a pending elicitation by ID.

        Args:
            request_id: The elicitation request ID to lookup

        Returns:
            PendingElicitation if found, None otherwise
        """
        return self._pending.get(request_id)

    def get_pending_count(self) -> int:
        """Get count of pending elicitations.

        Returns:
            Number of currently pending elicitation requests
        """
        return len(self._pending)

    def get_pending_for_session(self, session_id: str) -> list[PendingElicitation]:
        """Get all pending elicitations for a specific session.

        Args:
            session_id: Session ID to filter by (upstream or downstream)

        Returns:
            List of PendingElicitation objects involving this session
        """
        return [e for e in self._pending.values() if session_id in (e.upstream_session_id, e.downstream_session_id)]

    async def _cleanup_loop(self):
        """Background task to periodically clean up expired elicitations.

        Sleeps ``self.cleanup_interval`` seconds between sweeps. Errors raised
        by a sweep are logged and the loop keeps running; only cancellation
        ends it.

        Raises:
            asyncio.CancelledError: If the task is cancelled during shutdown.
        """
        while True:
            try:
                # Honour the configured interval rather than a fixed delay
                await asyncio.sleep(self.cleanup_interval)
                await self._cleanup_expired()
            except asyncio.CancelledError:
                logger.info("Elicitation cleanup loop cancelled")
                raise
            except Exception as e:
                logger.error(f"Error in elicitation cleanup loop: {e}", exc_info=True)

    async def _cleanup_expired(self):
        """Remove expired elicitation requests that have timed out.

        An entry is expired when its age exceeds its own ``timeout``. Expired
        entries whose future is still unresolved are failed with
        ``asyncio.TimeoutError`` before being removed from the registry.
        """
        now = time.time()
        expired = []

        for request_id, elicitation in self._pending.items():
            age = now - elicitation.created_at
            if age > elicitation.timeout:
                expired.append(request_id)
                if not elicitation.future.done():
                    elicitation.future.set_exception(asyncio.TimeoutError(f"Elicitation expired after {age:.1f}s"))

        # Remove after the scan so the dict is not mutated while iterating
        for request_id in expired:
            self._pending.pop(request_id, None)

        if expired:
            logger.info(f"Cleaned up {len(expired)} expired elicitations")

    def _validate_schema(self, schema: Dict[str, Any]):
        """Validate that schema only contains primitive types per MCP spec.

        MCP spec restricts elicitation schemas to flat objects with primitive properties:
        - string (with optional format: email, uri, date, date-time)
        - number / integer (with optional min/max)
        - boolean
        - enum (array of string values)

        Complex types (nested objects, arrays, refs) are not allowed to keep
        client implementation simple. An unrecognised string ``format`` is
        only logged as a warning; it does not fail validation.

        Args:
            schema: JSON Schema object to validate

        Raises:
            ValueError: If schema contains complex types or invalid structure
        """
        if not isinstance(schema, dict):
            raise ValueError("Schema must be an object")

        if schema.get("type") != "object":
            raise ValueError("Top-level schema must be type 'object'")

        properties = schema.get("properties", {})
        if not isinstance(properties, dict):
            raise ValueError("Schema properties must be an object")

        # Validate each property is primitive
        allowed_types = {"string", "number", "integer", "boolean"}
        allowed_formats = {"email", "uri", "date", "date-time"}

        for prop_name, prop_schema in properties.items():
            if not isinstance(prop_schema, dict):
                raise ValueError(f"Property '{prop_name}' schema must be an object")

            prop_type = prop_schema.get("type")
            if prop_type not in allowed_types:
                raise ValueError(f"Property '{prop_name}' has invalid type '{prop_type}'. " f"Only primitive types allowed: {allowed_types}")

            # Check for nested structures (not allowed per spec)
            if "properties" in prop_schema or "items" in prop_schema:
                raise ValueError(f"Property '{prop_name}' contains nested structure. " "MCP elicitation schemas must be flat.")

            # Validate string format if present (warn only, do not reject)
            if prop_type == "string" and "format" in prop_schema:
                fmt = prop_schema["format"]
                if fmt not in allowed_formats:
                    logger.warning(f"Property '{prop_name}' has non-standard format '{fmt}'. " f"Allowed formats: {allowed_formats}")

        logger.debug(f"Schema validation passed: {len(properties)} properties")

307 

308 

# Process-wide singleton; stays None until first requested or explicitly set.
_elicitation_service: Optional[ElicitationService] = None


def get_elicitation_service() -> ElicitationService:
    """Return the shared ElicitationService, creating it on first use.

    The instance is built with the default constructor arguments the first
    time this function is called while no instance is registered.

    Returns:
        The global ElicitationService instance
    """
    global _elicitation_service  # pylint: disable=global-statement
    if _elicitation_service is not None:
        return _elicitation_service

    _elicitation_service = ElicitationService()
    return _elicitation_service

323 

324 

def set_elicitation_service(service: ElicitationService):
    """Replace the global ElicitationService instance.

    Intended mainly for tests, which use it to inject a mock service.
    Whatever instance was registered before is simply overwritten.

    Args:
        service: The ElicitationService instance to use globally
    """
    global _elicitation_service  # pylint: disable=global-statement
    _elicitation_service = service