# Coverage for mcpgateway/services/elicitation_service.py: 97% (133 statements)
# Report generated by coverage.py v7.13.4 on 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Elicitation service for tracking and routing elicitation requests.
4This service manages the lifecycle of MCP elicitation requests, which allow
5servers to request structured user input through connected clients.
7Per MCP specification 2025-06-18, elicitation follows a server→client request
8pattern where servers send elicitation/create requests, and clients respond
9with user input (accept/decline/cancel actions).
10"""
12# Standard
13import asyncio
14from dataclasses import dataclass, field
15import logging
16import time
17from typing import Any, Dict, Optional
18from uuid import uuid4
20# First-Party
21from mcpgateway.common.models import ElicitResult
23logger = logging.getLogger(__name__)
@dataclass
class PendingElicitation:
    """Record of one in-flight elicitation request that has not been answered yet.

    Attributes:
        request_id: Unique identifier for this elicitation request
        upstream_session_id: Session that initiated the request (server)
        downstream_session_id: Session handling the request (client)
        created_at: Unix timestamp when request was created
        timeout: Maximum wait time in seconds
        message: User-facing message describing what input is needed
        schema: JSON Schema defining expected response structure
        future: AsyncIO future that resolves to ElicitResult when complete
    """

    # --- identity and routing ---
    request_id: str
    upstream_session_id: str
    downstream_session_id: str

    # --- timing (created_at is wall-clock seconds, timeout is a duration) ---
    created_at: float
    timeout: float

    # --- payload presented to the user ---
    message: str
    schema: Dict[str, Any]

    # --- completion handle; a fresh future is created when none is supplied ---
    future: asyncio.Future = field(default_factory=asyncio.Future)
class ElicitationService:
    """Service for managing elicitation request lifecycle.

    This service provides:
    - Tracking of pending elicitation requests
    - Response routing back to original requesters
    - Timeout enforcement and cleanup
    - Schema validation per MCP spec (primitive types only)
    - Concurrency limits to prevent resource exhaustion

    Each instance keeps a registry of pending requests keyed by request ID.
    Requests are removed when they complete or time out in
    ``create_elicitation``; a background task additionally sweeps the
    registry every ``cleanup_interval`` seconds as a backstop.
    """

    def __init__(
        self,
        default_timeout: int = 60,
        max_concurrent: int = 100,
        cleanup_interval: int = 300,  # 5 minutes
    ):
        """Initialize the elicitation service.

        Args:
            default_timeout: Default timeout for elicitation requests (seconds)
            max_concurrent: Maximum number of concurrent elicitations
            cleanup_interval: How often to run cleanup task (seconds)
        """
        self.default_timeout = default_timeout
        self.max_concurrent = max_concurrent
        self.cleanup_interval = cleanup_interval
        self._pending: Dict[str, PendingElicitation] = {}
        self._cleanup_task: Optional[asyncio.Task] = None
        logger.info(f"ElicitationService initialized: timeout={default_timeout}s, max_concurrent={max_concurrent}, cleanup_interval={cleanup_interval}s")

    async def start(self):
        """Start background cleanup task.

        Calling this while a cleanup task is already running is a no-op; a
        finished or cancelled task is replaced by a new one.
        """
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            logger.info("Elicitation cleanup task started")

    async def shutdown(self):
        """Shutdown service and cancel all pending requests.

        The cleanup task is cancelled and awaited, then every unresolved
        pending future is failed with ``RuntimeError`` so that callers blocked
        in ``create_elicitation`` are released.
        """
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                # Expected: the loop re-raises the cancellation we just requested.
                pass

        # Cancel all pending requests
        cancelled_count = 0
        for elicitation in list(self._pending.values()):
            if not elicitation.future.done():
                elicitation.future.set_exception(RuntimeError("ElicitationService shutting down"))
                cancelled_count += 1

        self._pending.clear()
        logger.info(f"ElicitationService shutdown complete (cancelled {cancelled_count} pending requests)")

    async def create_elicitation(self, upstream_session_id: str, downstream_session_id: str, message: str, requested_schema: Dict[str, Any], timeout: Optional[float] = None) -> ElicitResult:
        """Create and track an elicitation request.

        This method initiates an elicitation request, validates the schema,
        tracks the request, and awaits the client's response with timeout.

        Args:
            upstream_session_id: Session that initiated the request (server)
            downstream_session_id: Session that will handle the request (client)
            message: Message to present to user
            requested_schema: JSON Schema for expected response
            timeout: Optional timeout override (default: self.default_timeout)

        Returns:
            ElicitResult from the client containing action and optional content

        Raises:
            ValueError: If max concurrent limit reached or invalid schema
            asyncio.TimeoutError: If request times out waiting for response
        """
        # Check concurrent limit
        if len(self._pending) >= self.max_concurrent:
            logger.warning(f"Max concurrent elicitations reached: {self.max_concurrent}")
            raise ValueError(f"Maximum concurrent elicitations ({self.max_concurrent}) reached")

        # Validate schema (primitive types only per MCP spec)
        self._validate_schema(requested_schema)

        # Create tracking entry
        request_id = str(uuid4())
        timeout_val = timeout if timeout is not None else self.default_timeout
        # Bind the future explicitly to the loop running this coroutine rather
        # than relying on implicit event-loop lookup.
        future: asyncio.Future = asyncio.get_running_loop().create_future()

        elicitation = PendingElicitation(
            request_id=request_id,
            upstream_session_id=upstream_session_id,
            downstream_session_id=downstream_session_id,
            created_at=time.time(),
            timeout=timeout_val,
            message=message,
            schema=requested_schema,
            future=future,
        )

        self._pending[request_id] = elicitation
        logger.info(f"Created elicitation request {request_id}: upstream={upstream_session_id}, downstream={downstream_session_id}, timeout={timeout_val}s")

        try:
            # Wait for response with timeout
            result = await asyncio.wait_for(future, timeout=timeout_val)
            logger.info(f"Elicitation {request_id} completed: action={result.action}")
            return result
        except asyncio.TimeoutError:
            logger.warning(f"Elicitation {request_id} timed out after {timeout_val}s")
            raise
        finally:
            # Always drop the registry entry, whether we succeeded, timed out,
            # were cancelled, or the service shut down underneath us.
            self._pending.pop(request_id, None)

    def complete_elicitation(self, request_id: str, result: ElicitResult) -> bool:
        """Complete a pending elicitation with a result from the client.

        Args:
            request_id: ID of the elicitation request to complete
            result: The client's response (action + optional content)

        Returns:
            True if request was found and completed, False otherwise
        """
        elicitation = self._pending.get(request_id)
        if not elicitation:
            logger.warning(f"Attempted to complete unknown elicitation: {request_id}")
            return False

        if elicitation.future.done():
            logger.warning(f"Elicitation {request_id} already completed")
            return False

        elicitation.future.set_result(result)
        logger.debug(f"Completed elicitation {request_id}: action={result.action}")
        return True

    def get_pending_elicitation(self, request_id: str) -> Optional[PendingElicitation]:
        """Get a pending elicitation by ID.

        Args:
            request_id: The elicitation request ID to lookup

        Returns:
            PendingElicitation if found, None otherwise
        """
        return self._pending.get(request_id)

    def get_pending_count(self) -> int:
        """Get count of pending elicitations.

        Returns:
            Number of currently pending elicitation requests
        """
        return len(self._pending)

    def get_pending_for_session(self, session_id: str) -> list[PendingElicitation]:
        """Get all pending elicitations for a specific session.

        Args:
            session_id: Session ID to filter by (upstream or downstream)

        Returns:
            List of PendingElicitation objects involving this session
        """
        return [e for e in self._pending.values() if session_id in (e.upstream_session_id, e.downstream_session_id)]

    async def _cleanup_loop(self):
        """Background task to periodically clean up expired elicitations.

        Sleeps ``self.cleanup_interval`` seconds between sweeps. Unexpected
        errors from a sweep are logged and the loop keeps running.

        Raises:
            asyncio.CancelledError: If the task is cancelled during shutdown.
        """
        while True:
            try:
                # Honour the configured interval instead of a fixed delay.
                await asyncio.sleep(self.cleanup_interval)
                await self._cleanup_expired()
            except asyncio.CancelledError:
                logger.info("Elicitation cleanup loop cancelled")
                raise
            except Exception as e:
                logger.error(f"Error in elicitation cleanup loop: {e}", exc_info=True)

    async def _cleanup_expired(self):
        """Remove expired elicitation requests that have timed out.

        An entry is expired when its age exceeds its own ``timeout``. Any
        still-unresolved future is failed with ``asyncio.TimeoutError`` before
        the entry is dropped from the registry.
        """
        now = time.time()
        expired = []

        # Iterate over a snapshot so the registry can be modified safely below.
        for request_id, elicitation in list(self._pending.items()):
            age = now - elicitation.created_at
            if age > elicitation.timeout:
                expired.append(request_id)
                if not elicitation.future.done():
                    elicitation.future.set_exception(asyncio.TimeoutError(f"Elicitation expired after {age:.1f}s"))

        for request_id in expired:
            self._pending.pop(request_id, None)

        if expired:
            logger.info(f"Cleaned up {len(expired)} expired elicitations")

    def _validate_schema(self, schema: Dict[str, Any]):
        """Validate that schema only contains primitive types per MCP spec.

        MCP spec restricts elicitation schemas to flat objects with primitive properties:
        - string (with optional format: email, uri, date, date-time)
        - number / integer (with optional min/max)
        - boolean
        - enum (array of string values)

        Complex types (nested objects, arrays, refs) are not allowed to keep
        client implementation simple.

        Args:
            schema: JSON Schema object to validate

        Raises:
            ValueError: If schema contains complex types or invalid structure
        """
        if not isinstance(schema, dict):
            raise ValueError("Schema must be an object")

        if schema.get("type") != "object":
            raise ValueError("Top-level schema must be type 'object'")

        properties = schema.get("properties", {})
        if not isinstance(properties, dict):
            raise ValueError("Schema properties must be an object")

        # Validate each property is primitive
        allowed_types = {"string", "number", "integer", "boolean"}
        allowed_formats = {"email", "uri", "date", "date-time"}

        for prop_name, prop_schema in properties.items():
            if not isinstance(prop_schema, dict):
                raise ValueError(f"Property '{prop_name}' schema must be an object")

            prop_type = prop_schema.get("type")
            # The isinstance guard matters: JSON Schema permits a list of types
            # (e.g. ["string", "null"]), which is unhashable and would raise
            # TypeError from the set membership test instead of ValueError.
            if not isinstance(prop_type, str) or prop_type not in allowed_types:
                raise ValueError(f"Property '{prop_name}' has invalid type '{prop_type}'. " f"Only primitive types allowed: {allowed_types}")

            # Check for nested structures (not allowed per spec)
            if "properties" in prop_schema or "items" in prop_schema:
                raise ValueError(f"Property '{prop_name}' contains nested structure. " "MCP elicitation schemas must be flat.")

            # Validate string format if present; unknown formats only warn.
            if prop_type == "string" and "format" in prop_schema:
                fmt = prop_schema["format"]
                if fmt not in allowed_formats:
                    logger.warning(f"Property '{prop_name}' has non-standard format '{fmt}'. " f"Allowed formats: {allowed_formats}")

        logger.debug(f"Schema validation passed: {len(properties)} properties")
# Process-wide ElicitationService; created lazily on first access.
_elicitation_service: Optional[ElicitationService] = None


def get_elicitation_service() -> ElicitationService:
    """Return the shared ElicitationService, creating it on first use.

    Returns:
        The global ElicitationService instance
    """
    global _elicitation_service  # pylint: disable=global-statement
    current = _elicitation_service
    if current is None:
        # First access: build a service with default settings and remember it.
        current = ElicitationService()
        _elicitation_service = current
    return current
def set_elicitation_service(service: ElicitationService):
    """Replace the global ElicitationService instance.

    Intended mainly for tests that need to inject a mock or preconfigured
    service in place of the lazily created default.

    Args:
        service: The ElicitationService instance to use globally
    """
    global _elicitation_service  # pylint: disable=global-statement
    _elicitation_service = service