Coverage for mcpgateway/handlers/sampling.py: 98% (88 statements)
« prev ^ index » next — coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/handlers/sampling.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7MCP Sampling Handler Implementation.
8This module implements the sampling handler for MCP LLM interactions.
9It handles model selection, sampling preferences, and message generation.
11Examples:
12 >>> import asyncio
13 >>> from mcpgateway.common.models import ModelPreferences
14 >>> handler = SamplingHandler()
15 >>> asyncio.run(handler.initialize())
16 >>>
17 >>> # Test model selection
18 >>> prefs = ModelPreferences(
19 ... cost_priority=0.2,
20 ... speed_priority=0.3,
21 ... intelligence_priority=0.5
22 ... )
23 >>> handler._select_model(prefs)
24 'claude-3-haiku'
25 >>>
26 >>> # Test message validation
27 >>> msg = {
28 ... "role": "user",
29 ... "content": {"type": "text", "text": "Hello"}
30 ... }
31 >>> handler._validate_message(msg)
32 True
33 >>>
34 >>> # Test mock sampling
35 >>> messages = [msg]
36 >>> response = handler._mock_sample(messages)
37 >>> print(response)
38 You said: Hello
39 Here is my response...
40 >>>
41 >>> asyncio.run(handler.shutdown())
42"""
44# Standard
45from typing import Any, Dict, List
47# Third-Party
48from sqlalchemy.orm import Session
50# First-Party
51from mcpgateway.common.models import CreateMessageResult, ModelPreferences, Role, TextContent
52from mcpgateway.services.logging_service import LoggingService
# Initialize the logging service first so the module-level ``logger`` exists
# before any of the handler classes below are defined and used.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)
class SamplingError(Exception):
    """Raised when an MCP sampling request cannot be fulfilled.

    Serves as the base class for all sampling-related errors in this module.
    """
class SamplingHandler:
    """MCP sampling request handler.

    Handles:
    - Model selection based on preferences
    - Message sampling requests
    - Context management
    - Content validation

    Examples:
        >>> handler = SamplingHandler()
        >>> handler._supported_models['claude-3-haiku']
        (0.8, 0.9, 0.7)
        >>> len(handler._supported_models)
        4
    """

    def __init__(self):
        """Initialize sampling handler.

        Examples:
            >>> handler = SamplingHandler()
            >>> isinstance(handler._supported_models, dict)
            True
            >>> 'claude-3-opus' in handler._supported_models
            True
            >>> handler._supported_models['claude-3-sonnet']
            (0.5, 0.7, 0.9)
        """
        # Maps model names to capability scores (cost, speed, intelligence).
        # Each score is in [0, 1]; higher means better on that axis
        # (for cost, higher means cheaper).
        self._supported_models = {
            "claude-3-haiku": (0.8, 0.9, 0.7),
            "claude-3-sonnet": (0.5, 0.7, 0.9),
            "claude-3-opus": (0.2, 0.5, 1.0),
            "gemini-1.5-pro": (0.6, 0.8, 0.8),
        }

    async def initialize(self) -> None:
        """Initialize sampling handler.

        Currently a no-op apart from logging; kept async for interface symmetry
        with other service lifecycles.

        Examples:
            >>> import asyncio
            >>> handler = SamplingHandler()
            >>> asyncio.run(handler.initialize())
            >>> # Handler is now initialized
        """
        logger.info("Initializing sampling handler")

    async def shutdown(self) -> None:
        """Shutdown sampling handler.

        Currently a no-op apart from logging.

        Examples:
            >>> import asyncio
            >>> handler = SamplingHandler()
            >>> asyncio.run(handler.initialize())
            >>> asyncio.run(handler.shutdown())
            >>> # Handler is now shut down
        """
        logger.info("Shutting down sampling handler")

    async def create_message(self, db: Session, request: Dict[str, Any]) -> CreateMessageResult:
        """Create message from sampling request.

        Args:
            db: Database session
            request: Sampling request parameters

        Returns:
            Sampled message result

        Raises:
            SamplingError: If sampling fails

        Examples:
            >>> import asyncio
            >>> from unittest.mock import Mock
            >>> handler = SamplingHandler()
            >>> db = Mock()
            >>>
            >>> # Test with valid request
            >>> request = {
            ...     "messages": [{
            ...         "role": "user",
            ...         "content": {"type": "text", "text": "Hello"}
            ...     }],
            ...     "maxTokens": 100,
            ...     "modelPreferences": {
            ...         "cost_priority": 0.3,
            ...         "speed_priority": 0.3,
            ...         "intelligence_priority": 0.4
            ...     }
            ... }
            >>> result = asyncio.run(handler.create_message(db, request))
            >>> result.role
            'assistant'
            >>> result.content.type
            'text'
            >>> result.stop_reason
            'maxTokens'
            >>>
            >>> # Test with no messages
            >>> bad_request = {
            ...     "messages": [],
            ...     "maxTokens": 100,
            ...     "modelPreferences": {
            ...         "cost_priority": 0.3,
            ...         "speed_priority": 0.3,
            ...         "intelligence_priority": 0.4
            ...     }
            ... }
            >>> try:
            ...     asyncio.run(handler.create_message(db, bad_request))
            ... except SamplingError as e:
            ...     print(str(e))
            No messages provided
            >>>
            >>> # Test with no max tokens
            >>> bad_request = {
            ...     "messages": [{"role": "user", "content": {"type": "text", "text": "Hi"}}],
            ...     "modelPreferences": {
            ...         "cost_priority": 0.3,
            ...         "speed_priority": 0.3,
            ...         "intelligence_priority": 0.4
            ...     }
            ... }
            >>> try:
            ...     asyncio.run(handler.create_message(db, bad_request))
            ... except SamplingError as e:
            ...     print(str(e))
            Max tokens not specified
        """
        try:
            # Extract request parameters
            messages = request.get("messages", [])
            max_tokens = request.get("maxTokens")
            model_prefs = ModelPreferences.model_validate(request.get("modelPreferences", {}))
            include_context = request.get("includeContext", "none")
            # NOTE: request["metadata"] is accepted by the protocol but currently unused.

            # Validate request
            if not messages:
                raise SamplingError("No messages provided")
            if not max_tokens:
                raise SamplingError("Max tokens not specified")

            # Select model
            model = self._select_model(model_prefs)
            logger.info(f"Selected model: {model}")

            # Include context if requested
            if include_context != "none":
                messages = await self._add_context(db, messages, include_context)

            # Validate messages
            for msg in messages:
                if not self._validate_message(msg):
                    raise SamplingError(f"Invalid message format: {msg}")

            # TODO: Implement actual model sampling - currently returns mock response  # pylint: disable=fixme
            response = self._mock_sample(messages=messages)

            # Convert to result
            return CreateMessageResult(
                content=TextContent(type="text", text=response),
                model=model,
                role=Role.ASSISTANT,
                stop_reason="maxTokens",
            )

        except SamplingError as e:
            # Domain errors already carry a meaningful message; log and
            # propagate as-is instead of re-wrapping into a new SamplingError.
            logger.error(f"Sampling error: {e}")
            raise
        except Exception as e:
            logger.error(f"Sampling error: {e}")
            # Chain the original cause for debuggability.
            raise SamplingError(str(e)) from e

    def _select_model(self, preferences: ModelPreferences) -> str:
        """Select model based on preferences.

        Args:
            preferences: Model selection preferences

        Returns:
            Selected model name

        Raises:
            SamplingError: If no suitable model found

        Examples:
            >>> from mcpgateway.common.models import ModelPreferences, ModelHint
            >>> handler = SamplingHandler()
            >>>
            >>> # Test intelligence priority
            >>> prefs = ModelPreferences(
            ...     cost_priority=1.0,
            ...     speed_priority=0.0,
            ...     intelligence_priority=1.0
            ... )
            >>> handler._select_model(prefs)
            'claude-3-opus'
            >>>
            >>> # Test speed priority
            >>> prefs = ModelPreferences(
            ...     cost_priority=0.0,
            ...     speed_priority=1.0,
            ...     intelligence_priority=0.0
            ... )
            >>> handler._select_model(prefs)
            'claude-3-haiku'
            >>>
            >>> # Test balanced preferences
            >>> prefs = ModelPreferences(
            ...     cost_priority=0.33,
            ...     speed_priority=0.33,
            ...     intelligence_priority=0.34
            ... )
            >>> model = handler._select_model(prefs)
            >>> model in handler._supported_models
            True
            >>>
            >>> # Test with model hints
            >>> prefs = ModelPreferences(
            ...     hints=[ModelHint(name="opus")],
            ...     cost_priority=0.5,
            ...     speed_priority=0.3,
            ...     intelligence_priority=0.2
            ... )
            >>> handler._select_model(prefs)
            'claude-3-opus'
            >>>
            >>> # Test empty supported models (should raise error)
            >>> handler._supported_models = {}
            >>> try:
            ...     handler._select_model(prefs)
            ... except SamplingError as e:
            ...     print(str(e))
            No suitable model found
        """
        # Check model hints first: the first hint whose name is a substring of
        # a supported model name wins outright, bypassing scoring.
        if preferences.hints:
            for hint in preferences.hints:
                for model in self._supported_models:
                    if hint.name and hint.name in model:
                        return model

        # Score models on preferences.
        # NOTE(review): the cost contribution uses (1 - cost_priority), which
        # shrinks as cost_priority grows — this looks inverted relative to the
        # speed/intelligence terms, but the behavior is pinned by the doctests
        # above, so it is preserved as-is. TODO: confirm intended semantics.
        best_score = -1
        best_model = None

        for model, caps in self._supported_models.items():
            cost_score = caps[0] * (1 - preferences.cost_priority)
            speed_score = caps[1] * preferences.speed_priority
            intel_score = caps[2] * preferences.intelligence_priority

            total_score = (cost_score + speed_score + intel_score) / 3

            if total_score > best_score:
                best_score = total_score
                best_model = model

        if not best_model:
            raise SamplingError("No suitable model found")

        return best_model

    async def _add_context(self, _db: Session, messages: List[Dict[str, Any]], _context_type: str) -> List[Dict[str, Any]]:
        """Add context to messages.

        Args:
            _db: Database session (unused until context gathering is implemented)
            messages: Message list
            _context_type: Context inclusion type (unused until implemented)

        Returns:
            Messages with added context

        Examples:
            >>> import asyncio
            >>> from unittest.mock import Mock
            >>> handler = SamplingHandler()
            >>> db = Mock()
            >>>
            >>> messages = [
            ...     {"role": "user", "content": {"type": "text", "text": "Hello"}},
            ...     {"role": "assistant", "content": {"type": "text", "text": "Hi there!"}}
            ... ]
            >>>
            >>> # Test with 'none' context type
            >>> result = asyncio.run(handler._add_context(db, messages, "none"))
            >>> result == messages
            True
            >>>
            >>> # Test with 'all' context type (currently returns same messages)
            >>> result = asyncio.run(handler._add_context(db, messages, "all"))
            >>> result == messages
            True
            >>> len(result)
            2
        """
        # TODO: Implement context gathering based on type - currently no-op  # pylint: disable=fixme
        # For now return original messages unchanged.
        return messages

    def _validate_message(self, message: Dict[str, Any]) -> bool:
        """Validate message format.

        A valid message is a dict with a ``role`` of ``"user"`` or
        ``"assistant"`` and a ``content`` dict of type ``"text"`` (with a
        string ``text``) or ``"image"`` (with truthy ``data`` and
        ``mime_type``).

        Args:
            message: Message to validate

        Returns:
            True if valid

        Examples:
            >>> handler = SamplingHandler()
            >>>
            >>> # Valid text message
            >>> msg = {"role": "user", "content": {"type": "text", "text": "Hello"}}
            >>> handler._validate_message(msg)
            True
            >>>
            >>> # Valid assistant message
            >>> msg = {"role": "assistant", "content": {"type": "text", "text": "Hi!"}}
            >>> handler._validate_message(msg)
            True
            >>>
            >>> # Valid image message
            >>> msg = {
            ...     "role": "user",
            ...     "content": {
            ...         "type": "image",
            ...         "data": "base64data",
            ...         "mime_type": "image/png"
            ...     }
            ... }
            >>> handler._validate_message(msg)
            True
            >>>
            >>> # Missing role
            >>> msg = {"content": {"type": "text", "text": "Hello"}}
            >>> handler._validate_message(msg)
            False
            >>>
            >>> # Invalid role
            >>> msg = {"role": "system", "content": {"type": "text", "text": "Hello"}}
            >>> handler._validate_message(msg)
            False
            >>>
            >>> # Missing content
            >>> msg = {"role": "user"}
            >>> handler._validate_message(msg)
            False
            >>>
            >>> # Invalid content type
            >>> msg = {"role": "user", "content": {"type": "audio"}}
            >>> handler._validate_message(msg)
            False
            >>>
            >>> # Text content not string
            >>> msg = {"role": "user", "content": {"type": "text", "text": 123}}
            >>> handler._validate_message(msg)
            False
            >>>
            >>> # Image missing data
            >>> msg = {"role": "user", "content": {"type": "image", "mime_type": "image/png"}}
            >>> handler._validate_message(msg)
            False
            >>>
            >>> # Invalid structure
            >>> handler._validate_message("not a dict")
            False
        """
        try:
            # Must have role and content, and role must be user/assistant
            if "role" not in message or "content" not in message or message["role"] not in ("user", "assistant"):
                return False

            # Content must be a valid text or image payload
            content = message["content"]
            if content.get("type") == "text":
                if not isinstance(content.get("text"), str):
                    return False
            elif content.get("type") == "image":
                if not (content.get("data") and content.get("mime_type")):
                    return False
            else:
                return False

            return True

        except Exception:
            # Any structural surprise (non-dict message/content, non-string
            # keys, etc.) is simply an invalid message, not an error.
            return False

    def _mock_sample(
        self,
        messages: List[Dict[str, Any]],
    ) -> str:
        """Mock sampling response for testing.

        Echoes the most recent user message (or acknowledges an image) and
        appends a canned response line. Placeholder until real model sampling
        is wired in.

        Args:
            messages: Input messages

        Returns:
            Sampled response text

        Examples:
            >>> handler = SamplingHandler()
            >>>
            >>> # Single user message
            >>> messages = [{"role": "user", "content": {"type": "text", "text": "Hello world"}}]
            >>> handler._mock_sample(messages)
            'You said: Hello world\\nHere is my response...'
            >>>
            >>> # Conversation with multiple messages
            >>> messages = [
            ...     {"role": "user", "content": {"type": "text", "text": "Hi"}},
            ...     {"role": "assistant", "content": {"type": "text", "text": "Hello!"}},
            ...     {"role": "user", "content": {"type": "text", "text": "How are you?"}}
            ... ]
            >>> handler._mock_sample(messages)
            'You said: How are you?\\nHere is my response...'
            >>>
            >>> # Image message
            >>> messages = [{
            ...     "role": "user",
            ...     "content": {"type": "image", "data": "base64", "mime_type": "image/png"}
            ... }]
            >>> handler._mock_sample(messages)
            'You said: I see the image you shared.\\nHere is my response...'
            >>>
            >>> # No user messages
            >>> messages = [{"role": "assistant", "content": {"type": "text", "text": "Hi"}}]
            >>> handler._mock_sample(messages)
            "I'm not sure what to respond to."
            >>>
            >>> # Empty messages
            >>> handler._mock_sample([])
            "I'm not sure what to respond to."
        """
        # Extract last user message
        last_msg = None
        for msg in reversed(messages):
            if msg["role"] == "user":
                last_msg = msg
                break

        if not last_msg:
            return "I'm not sure what to respond to."

        # Get user text (messages are assumed pre-validated, so content type
        # is either "text" or "image")
        user_text = ""
        content = last_msg["content"]
        if content["type"] == "text":
            user_text = content["text"]
        elif content["type"] == "image":
            user_text = "I see the image you shared."

        # Generate simple response
        return f"You said: {user_text}\nHere is my response..."