Coverage for mcpgateway / services / elicitation_service.py: 94%

133 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Elicitation service for tracking and routing elicitation requests. 

3 

4This service manages the lifecycle of MCP elicitation requests, which allow 

5servers to request structured user input through connected clients. 

6 

7Per MCP specification 2025-06-18, elicitation follows a server→client request 

8pattern where servers send elicitation/create requests, and clients respond 

9with user input (accept/decline/cancel actions). 

10""" 

11 

12# Standard 

13import asyncio 

14from dataclasses import dataclass, field 

15import logging 

16import time 

17from typing import Any, Dict, Optional 

18from uuid import uuid4 

19 

20# First-Party 

21from mcpgateway.common.models import ElicitResult 

22 

# Module-scoped logger: records are emitted under this module's dotted import name.
logger = logging.getLogger(name=__name__)

24 

25 

@dataclass
class PendingElicitation:
    """Bookkeeping record for one elicitation request still awaiting an answer.

    Attributes:
        request_id: Unique identifier for this elicitation request
        upstream_session_id: Session that initiated the request (server)
        downstream_session_id: Session handling the request (client)
        created_at: Unix timestamp when request was created
        timeout: Maximum wait time in seconds
        message: User-facing message describing what input is needed
        schema: JSON Schema defining expected response structure
        future: AsyncIO future that resolves to ElicitResult when complete
    """

    # --- identity / routing ---
    request_id: str
    upstream_session_id: str
    downstream_session_id: str
    # --- timing ---
    created_at: float
    timeout: float
    # --- prompt shown to the user ---
    message: str
    schema: Dict[str, Any]
    # --- completion signal; a fresh asyncio.Future() is made when not supplied ---
    future: asyncio.Future = field(default_factory=asyncio.Future)

49 

50 

class ElicitationService:
    """Service for managing elicitation request lifecycle.

    This service provides:
    - Tracking of pending elicitation requests
    - Response routing back to original requesters
    - Timeout enforcement and cleanup
    - Schema validation per MCP spec (primitive types only)
    - Concurrency limits to prevent resource exhaustion

    Pending requests are kept in an in-memory dict keyed by request ID.
    Each request's timeout is enforced by ``asyncio.wait_for`` inside
    :meth:`create_elicitation`. The background task launched by
    :meth:`start` additionally sweeps entries older than their timeout
    every ``cleanup_interval`` seconds.
    """

    def __init__(
        self,
        default_timeout: int = 60,
        max_concurrent: int = 100,
        cleanup_interval: int = 300,  # 5 minutes
    ):
        """Initialize the elicitation service.

        Args:
            default_timeout: Default timeout for elicitation requests (seconds)
            max_concurrent: Maximum number of concurrent elicitations
            cleanup_interval: How often the background cleanup task runs (seconds)
        """
        self.default_timeout = default_timeout
        self.max_concurrent = max_concurrent
        self.cleanup_interval = cleanup_interval
        self._pending: Dict[str, PendingElicitation] = {}
        self._cleanup_task: Optional[asyncio.Task] = None
        logger.info(f"ElicitationService initialized: timeout={default_timeout}s, " f"max_concurrent={max_concurrent}, cleanup_interval={cleanup_interval}s")

    async def start(self):
        """Start the background cleanup task unless one is already running."""
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            logger.info("Elicitation cleanup task started")

    async def shutdown(self):
        """Shutdown service and cancel all pending requests.

        Cancels the background cleanup task (if any) and waits for it to
        finish. It then fails every still-unresolved request future with
        ``RuntimeError``, so coroutines awaiting :meth:`create_elicitation`
        are released. Finally it empties the registry.
        """
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass

        # Fail all pending requests so their waiters do not hang forever
        cancelled_count = 0
        for elicitation in list(self._pending.values()):
            if not elicitation.future.done():
                elicitation.future.set_exception(RuntimeError("ElicitationService shutting down"))
                cancelled_count += 1

        self._pending.clear()
        logger.info(f"ElicitationService shutdown complete (cancelled {cancelled_count} pending requests)")

    async def create_elicitation(self, upstream_session_id: str, downstream_session_id: str, message: str, requested_schema: Dict[str, Any], timeout: Optional[float] = None) -> ElicitResult:
        """Create and track an elicitation request.

        This method initiates an elicitation request, validates the schema,
        tracks the request, and awaits the client's response with timeout.
        The request is removed from the registry when this call returns or
        raises, whatever the outcome.

        Args:
            upstream_session_id: Session that initiated the request (server)
            downstream_session_id: Session that will handle the request (client)
            message: Message to present to user
            requested_schema: JSON Schema for expected response
            timeout: Optional timeout override (default: self.default_timeout)

        Returns:
            ElicitResult from the client containing action and optional content

        Raises:
            ValueError: If max concurrent limit reached or invalid schema
            asyncio.TimeoutError: If request times out waiting for response
                (either via ``wait_for`` or the background expiry sweep)
            RuntimeError: If the service is shut down while the request is pending
        """
        # Check concurrent limit
        if len(self._pending) >= self.max_concurrent:
            logger.warning(f"Max concurrent elicitations reached: {self.max_concurrent}")
            raise ValueError(f"Maximum concurrent elicitations ({self.max_concurrent}) reached")

        # Validate schema (primitive types only per MCP spec)
        self._validate_schema(requested_schema)

        # Create tracking entry
        request_id = str(uuid4())
        timeout_val = timeout if timeout is not None else self.default_timeout
        # loop.create_future() is the documented way to create a Future and binds it
        # to the loop that is actually running this coroutine.
        future: asyncio.Future = asyncio.get_running_loop().create_future()

        elicitation = PendingElicitation(
            request_id=request_id,
            upstream_session_id=upstream_session_id,
            downstream_session_id=downstream_session_id,
            created_at=time.time(),
            timeout=timeout_val,
            message=message,
            schema=requested_schema,
            future=future,
        )

        self._pending[request_id] = elicitation
        logger.info(f"Created elicitation request {request_id}: upstream={upstream_session_id}, downstream={downstream_session_id}, timeout={timeout_val}s")

        try:
            # Wait for response with timeout
            result = await asyncio.wait_for(future, timeout=timeout_val)
            logger.info(f"Elicitation {request_id} completed: action={result.action}")
            return result
        except asyncio.TimeoutError:
            logger.warning(f"Elicitation {request_id} timed out after {timeout_val}s")
            raise
        finally:
            # Always drop the registry entry, whatever the outcome
            self._pending.pop(request_id, None)

    def complete_elicitation(self, request_id: str, result: ElicitResult) -> bool:
        """Complete a pending elicitation with a result from the client.

        Args:
            request_id: ID of the elicitation request to complete
            result: The client's response (action + optional content)

        Returns:
            True if request was found and completed, False if the ID is unknown
            or its future is already done (resolved, failed, or cancelled)
        """
        elicitation = self._pending.get(request_id)
        if not elicitation:
            logger.warning(f"Attempted to complete unknown elicitation: {request_id}")
            return False

        if elicitation.future.done():
            logger.warning(f"Elicitation {request_id} already completed")
            return False

        elicitation.future.set_result(result)
        logger.debug(f"Completed elicitation {request_id}: action={result.action}")
        return True

    def get_pending_elicitation(self, request_id: str) -> Optional[PendingElicitation]:
        """Get a pending elicitation by ID.

        Args:
            request_id: The elicitation request ID to lookup

        Returns:
            PendingElicitation if found, None otherwise
        """
        return self._pending.get(request_id)

    def get_pending_count(self) -> int:
        """Get count of pending elicitations.

        Returns:
            Number of currently pending elicitation requests
        """
        return len(self._pending)

    def get_pending_for_session(self, session_id: str) -> list[PendingElicitation]:
        """Get all pending elicitations for a specific session.

        Args:
            session_id: Session ID to filter by (upstream or downstream)

        Returns:
            List of PendingElicitation objects involving this session
        """
        return [e for e in self._pending.values() if session_id in (e.upstream_session_id, e.downstream_session_id)]

    async def _cleanup_loop(self):
        """Periodically purge expired elicitations until cancelled.

        Sleeps ``self.cleanup_interval`` seconds between sweeps, as configured
        in ``__init__``. Unexpected errors in a sweep are logged and the loop
        keeps running. Cancellation ends the loop.
        """
        while True:
            try:
                # Honor the configured interval (previously a hard-coded 60s ignored it)
                await asyncio.sleep(self.cleanup_interval)
                await self._cleanup_expired()
            except asyncio.CancelledError:
                logger.info("Elicitation cleanup loop cancelled")
                break
            except Exception as e:
                logger.error(f"Error in elicitation cleanup loop: {e}", exc_info=True)

    async def _cleanup_expired(self):
        """Remove expired elicitation requests that have timed out.

        Any entry older than its own ``timeout`` is dropped from the registry.
        If its future is still unresolved, the future is failed with
        ``asyncio.TimeoutError``.
        """
        now = time.time()
        expired = []

        # Collect first and pop afterwards so the dict is not mutated while iterating
        for request_id, elicitation in self._pending.items():
            age = now - elicitation.created_at
            if age > elicitation.timeout:
                expired.append(request_id)
                if not elicitation.future.done():
                    elicitation.future.set_exception(asyncio.TimeoutError(f"Elicitation expired after {age:.1f}s"))

        for request_id in expired:
            self._pending.pop(request_id, None)

        if expired:
            logger.info(f"Cleaned up {len(expired)} expired elicitations")

    def _validate_schema(self, schema: Dict[str, Any]):
        """Validate that schema only contains primitive types per MCP spec.

        MCP spec restricts elicitation schemas to flat objects with primitive properties:
        - string (with optional format: email, uri, date, date-time)
        - number / integer (with optional min/max)
        - boolean
        - enum (array of string values)

        Complex types (nested objects, arrays, refs) are not allowed to keep
        client implementation simple. This check does not inspect ``enum``
        lists. An unknown or non-string ``format`` is only logged as a
        warning, not rejected.

        Args:
            schema: JSON Schema object to validate

        Raises:
            ValueError: If schema contains complex types or invalid structure,
                including a property whose ``type`` is missing, not a string
                (e.g. a JSON Schema type list), or not a primitive type
        """
        if not isinstance(schema, dict):
            raise ValueError("Schema must be an object")

        if schema.get("type") != "object":
            raise ValueError("Top-level schema must be type 'object'")

        properties = schema.get("properties", {})
        if not isinstance(properties, dict):
            raise ValueError("Schema properties must be an object")

        # Validate each property is primitive
        allowed_types = {"string", "number", "integer", "boolean"}
        allowed_formats = {"email", "uri", "date", "date-time"}

        for prop_name, prop_schema in properties.items():
            if not isinstance(prop_schema, dict):
                raise ValueError(f"Property '{prop_name}' schema must be an object")

            prop_type = prop_schema.get("type")
            # A list/dict "type" is unhashable; check isinstance first so such
            # input yields the documented ValueError rather than a TypeError.
            if not isinstance(prop_type, str) or prop_type not in allowed_types:
                raise ValueError(f"Property '{prop_name}' has invalid type '{prop_type}'. " f"Only primitive types allowed: {allowed_types}")

            # Check for nested structures (not allowed per spec)
            if "properties" in prop_schema or "items" in prop_schema:
                raise ValueError(f"Property '{prop_name}' contains nested structure. " "MCP elicitation schemas must be flat.")

            # Validate string format if present (non-standard formats only warn)
            if prop_type == "string" and "format" in prop_schema:
                fmt = prop_schema["format"]
                # Same unhashable-value guard as for "type" above
                if not isinstance(fmt, str) or fmt not in allowed_formats:
                    logger.warning(f"Property '{prop_name}' has non-standard format '{fmt}'. " f"Allowed formats: {allowed_formats}")

        logger.debug(f"Schema validation passed: {len(properties)} properties")

303 

304 

# Process-wide ElicitationService slot. It starts empty, is filled lazily by
# get_elicitation_service(), and can be swapped via set_elicitation_service().
_elicitation_service: Optional[ElicitationService] = None

307 

308 

def get_elicitation_service() -> ElicitationService:
    """Return the process-wide ElicitationService, building it on first use.

    The first call constructs an ElicitationService with default settings
    and stores it in the module-level slot. Every later call returns that
    stored object, or whatever was installed with ``set_elicitation_service``.

    Returns:
        The global ElicitationService instance
    """
    global _elicitation_service  # pylint: disable=global-statement
    if _elicitation_service is not None:
        return _elicitation_service
    _elicitation_service = ElicitationService()
    return _elicitation_service

319 

320 

def set_elicitation_service(service: ElicitationService):
    """Install ``service`` as the process-wide ElicitationService.

    Intended mainly for tests that need to inject a custom or mock service.
    Later ``get_elicitation_service`` calls return this object.

    Args:
        service: The ElicitationService instance to use globally
    """
    global _elicitation_service  # pylint: disable=global-statement
    _elicitation_service = service