Coverage for mcpgateway / services / elicitation_service.py: 94%
133 statements
Generated by coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
# -*- coding: utf-8 -*-
"""Elicitation service for tracking and routing elicitation requests.

This service manages the lifecycle of MCP elicitation requests, which allow
servers to request structured user input through connected clients.

Per MCP specification 2025-06-18, elicitation follows a server→client request
pattern where servers send elicitation/create requests, and clients respond
with user input (accept/decline/cancel actions).
"""
12# Standard
13import asyncio
14from dataclasses import dataclass, field
15import logging
16import time
17from typing import Any, Dict, Optional
18from uuid import uuid4
20# First-Party
21from mcpgateway.common.models import ElicitResult
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@dataclass
class PendingElicitation:
    """Record of one elicitation request that is still awaiting a client response."""

    # Unique identifier for this elicitation request.
    request_id: str
    # Session that initiated the request (server side).
    upstream_session_id: str
    # Session handling the request (client side).
    downstream_session_id: str
    # Unix timestamp taken when the request was created.
    created_at: float
    # Maximum wait time, in seconds.
    timeout: float
    # User-facing message describing what input is needed.
    message: str
    # JSON Schema defining the expected response structure.
    schema: Dict[str, Any]
    # AsyncIO future that resolves to the ElicitResult once the request completes.
    future: asyncio.Future = field(default_factory=asyncio.Future)
class ElicitationService:
    """Service for managing elicitation request lifecycle.

    This service provides:
    - Tracking of pending elicitation requests
    - Response routing back to original requesters
    - Timeout enforcement and cleanup
    - Schema validation per MCP spec (primitive types only)
    - Concurrency limits to prevent resource exhaustion

    Pending requests are kept in an in-memory registry keyed by request ID.
    Each request is removed when its awaiting call finishes, and a background
    task periodically sweeps entries that have outlived their timeout.
    """

    def __init__(
        self,
        default_timeout: int = 60,
        max_concurrent: int = 100,
        cleanup_interval: int = 300,  # 5 minutes
    ):
        """Initialize the elicitation service.

        Args:
            default_timeout: Default timeout for elicitation requests (seconds)
            max_concurrent: Maximum number of concurrent elicitations
            cleanup_interval: Seconds the background cleanup task sleeps
                between sweeps for expired requests
        """
        self.default_timeout = default_timeout
        self.max_concurrent = max_concurrent
        self.cleanup_interval = cleanup_interval
        self._pending: Dict[str, PendingElicitation] = {}
        self._cleanup_task: Optional[asyncio.Task] = None
        logger.info(f"ElicitationService initialized: timeout={default_timeout}s, " f"max_concurrent={max_concurrent}, cleanup_interval={cleanup_interval}s")

    async def start(self):
        """Start the background cleanup task if it is not already running."""
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            logger.info("Elicitation cleanup task started")

    async def shutdown(self):
        """Shutdown service and cancel all pending requests.

        Cancels the cleanup task (if one was started), fails every unresolved
        pending future with a RuntimeError, and empties the registry.
        """
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass

        # Fail all pending requests so their awaiters are released.
        cancelled_count = 0
        for elicitation in list(self._pending.values()):
            if not elicitation.future.done():
                elicitation.future.set_exception(RuntimeError("ElicitationService shutting down"))
                cancelled_count += 1

        self._pending.clear()
        logger.info(f"ElicitationService shutdown complete (cancelled {cancelled_count} pending requests)")

    async def create_elicitation(self, upstream_session_id: str, downstream_session_id: str, message: str, requested_schema: Dict[str, Any], timeout: Optional[float] = None) -> ElicitResult:
        """Create and track an elicitation request.

        This method validates the schema, registers the request, and awaits
        the client's response with a timeout. The registry entry is removed
        when this call finishes, whether it succeeds, times out, or fails.

        Args:
            upstream_session_id: Session that initiated the request (server)
            downstream_session_id: Session that will handle the request (client)
            message: Message to present to user
            requested_schema: JSON Schema for expected response
            timeout: Optional timeout override (default: self.default_timeout)

        Returns:
            ElicitResult from the client containing action and optional content

        Raises:
            ValueError: If max concurrent limit reached or invalid schema
            asyncio.TimeoutError: If request times out waiting for response
        """
        # Check concurrent limit
        if len(self._pending) >= self.max_concurrent:
            logger.warning(f"Max concurrent elicitations reached: {self.max_concurrent}")
            raise ValueError(f"Maximum concurrent elicitations ({self.max_concurrent}) reached")

        # Validate schema (primitive types only per MCP spec)
        self._validate_schema(requested_schema)

        # Create tracking entry
        request_id = str(uuid4())
        timeout_val = timeout if timeout is not None else self.default_timeout
        # Bind the future explicitly to the loop running this coroutine rather
        # than relying on asyncio.Future()'s implicit event-loop lookup.
        future: asyncio.Future = asyncio.get_running_loop().create_future()

        elicitation = PendingElicitation(
            request_id=request_id,
            upstream_session_id=upstream_session_id,
            downstream_session_id=downstream_session_id,
            created_at=time.time(),
            timeout=timeout_val,
            message=message,
            schema=requested_schema,
            future=future,
        )

        self._pending[request_id] = elicitation
        logger.info(f"Created elicitation request {request_id}: upstream={upstream_session_id}, downstream={downstream_session_id}, timeout={timeout_val}s")

        try:
            # Wait for response with timeout
            result = await asyncio.wait_for(future, timeout=timeout_val)
            logger.info(f"Elicitation {request_id} completed: action={result.action}")
            return result
        except asyncio.TimeoutError:
            logger.warning(f"Elicitation {request_id} timed out after {timeout_val}s")
            raise
        finally:
            # Always drop the registry entry, whatever the outcome.
            self._pending.pop(request_id, None)

    def complete_elicitation(self, request_id: str, result: ElicitResult) -> bool:
        """Complete a pending elicitation with a result from the client.

        Args:
            request_id: ID of the elicitation request to complete
            result: The client's response (action + optional content)

        Returns:
            True if request was found and completed, False if the ID is
            unknown or its future was already resolved
        """
        elicitation = self._pending.get(request_id)
        if not elicitation:
            logger.warning(f"Attempted to complete unknown elicitation: {request_id}")
            return False

        if elicitation.future.done():
            logger.warning(f"Elicitation {request_id} already completed")
            return False

        elicitation.future.set_result(result)
        logger.debug(f"Completed elicitation {request_id}: action={result.action}")
        return True

    def get_pending_elicitation(self, request_id: str) -> Optional[PendingElicitation]:
        """Get a pending elicitation by ID.

        Args:
            request_id: The elicitation request ID to lookup

        Returns:
            PendingElicitation if found, None otherwise
        """
        return self._pending.get(request_id)

    def get_pending_count(self) -> int:
        """Get count of pending elicitations.

        Returns:
            Number of currently pending elicitation requests
        """
        return len(self._pending)

    def get_pending_for_session(self, session_id: str) -> list[PendingElicitation]:
        """Get all pending elicitations for a specific session.

        Args:
            session_id: Session ID to filter by (upstream or downstream)

        Returns:
            List of PendingElicitation objects involving this session
        """
        return [e for e in self._pending.values() if session_id in (e.upstream_session_id, e.downstream_session_id)]

    async def _cleanup_loop(self):
        """Background task to periodically clean up expired elicitations.

        Sleeps ``self.cleanup_interval`` seconds between sweeps. Stops when
        cancelled; any other exception is logged and the loop continues.
        """
        while True:
            try:
                # Honor the configured interval instead of a fixed 60 seconds.
                await asyncio.sleep(self.cleanup_interval)
                await self._cleanup_expired()
            except asyncio.CancelledError:
                logger.info("Elicitation cleanup loop cancelled")
                break
            except Exception as e:
                logger.error(f"Error in elicitation cleanup loop: {e}", exc_info=True)

    async def _cleanup_expired(self):
        """Remove expired elicitation requests that have timed out.

        A request is expired when its age exceeds its own timeout. Unresolved
        futures of expired requests are failed with asyncio.TimeoutError.
        """
        now = time.time()
        expired = []

        for request_id, elicitation in self._pending.items():
            age = now - elicitation.created_at
            if age > elicitation.timeout:
                expired.append(request_id)
                if not elicitation.future.done():
                    elicitation.future.set_exception(asyncio.TimeoutError(f"Elicitation expired after {age:.1f}s"))

        # Remove after the scan so the dict is not mutated while iterating.
        for request_id in expired:
            self._pending.pop(request_id, None)

        if expired:
            logger.info(f"Cleaned up {len(expired)} expired elicitations")

    def _validate_schema(self, schema: Dict[str, Any]):
        """Validate that schema only contains primitive types per MCP spec.

        MCP spec restricts elicitation schemas to flat objects with primitive properties:
        - string (with optional format: email, uri, date, date-time)
        - number / integer (with optional min/max)
        - boolean
        - enum (array of string values)

        Complex types (nested objects, arrays, refs) are not allowed to keep
        client implementation simple. A string format outside the listed set
        is only logged as a warning, not rejected.

        Args:
            schema: JSON Schema object to validate

        Raises:
            ValueError: If schema contains complex types or invalid structure
        """
        if not isinstance(schema, dict):
            raise ValueError("Schema must be an object")

        if schema.get("type") != "object":
            raise ValueError("Top-level schema must be type 'object'")

        properties = schema.get("properties", {})
        if not isinstance(properties, dict):
            raise ValueError("Schema properties must be an object")

        # Validate each property is primitive
        allowed_types = {"string", "number", "integer", "boolean"}
        allowed_formats = {"email", "uri", "date", "date-time"}

        for prop_name, prop_schema in properties.items():
            if not isinstance(prop_schema, dict):
                raise ValueError(f"Property '{prop_name}' schema must be an object")

            prop_type = prop_schema.get("type")
            # The isinstance guard keeps unhashable values (e.g. a JSON Schema
            # type array such as ["string", "null"]) from raising TypeError in
            # the set membership test; they are rejected as ValueError instead.
            if not isinstance(prop_type, str) or prop_type not in allowed_types:
                raise ValueError(f"Property '{prop_name}' has invalid type '{prop_type}'. " f"Only primitive types allowed: {allowed_types}")

            # Check for nested structures (not allowed per spec)
            if "properties" in prop_schema or "items" in prop_schema:
                raise ValueError(f"Property '{prop_name}' contains nested structure. " "MCP elicitation schemas must be flat.")

            # Validate string format if present
            if prop_type == "string" and "format" in prop_schema:
                fmt = prop_schema["format"]
                if fmt not in allowed_formats:
                    logger.warning(f"Property '{prop_name}' has non-standard format '{fmt}'. " f"Allowed formats: {allowed_formats}")

        logger.debug(f"Schema validation passed: {len(properties)} properties")
# Global singleton instance. Starts as None; get_elicitation_service() creates
# it on first use and set_elicitation_service() replaces it.
_elicitation_service: Optional[ElicitationService] = None
def get_elicitation_service() -> ElicitationService:
    """Return the process-wide ElicitationService, creating it on first use.

    Returns:
        The shared ElicitationService instance; a default-configured one is
        constructed and stored if none has been set yet.
    """
    global _elicitation_service  # pylint: disable=global-statement
    current = _elicitation_service
    if current is not None:
        return current
    current = ElicitationService()
    _elicitation_service = current
    return current
def set_elicitation_service(service: ElicitationService) -> None:
    """Replace the process-wide ElicitationService instance.

    Mainly intended for tests that need to inject a mock service.

    Args:
        service: The ElicitationService instance that later calls to
            get_elicitation_service() should return
    """
    global _elicitation_service  # pylint: disable=global-statement
    _elicitation_service = service