Coverage for mcpgateway / handlers / sampling.py: 98%

88 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/handlers/sampling.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7MCP Sampling Handler Implementation. 

8This module implements the sampling handler for MCP LLM interactions. 

9It handles model selection, sampling preferences, and message generation. 

10 

11Examples: 

12 >>> import asyncio 

13 >>> from mcpgateway.common.models import ModelPreferences 

14 >>> handler = SamplingHandler() 

15 >>> asyncio.run(handler.initialize()) 

16 >>> 

17 >>> # Test model selection 

18 >>> prefs = ModelPreferences( 

19 ... cost_priority=0.2, 

20 ... speed_priority=0.3, 

21 ... intelligence_priority=0.5 

22 ... ) 

23 >>> handler._select_model(prefs) 

24 'claude-3-haiku' 

25 >>> 

26 >>> # Test message validation 

27 >>> msg = { 

28 ... "role": "user", 

29 ... "content": {"type": "text", "text": "Hello"} 

30 ... } 

31 >>> handler._validate_message(msg) 

32 True 

33 >>> 

34 >>> # Test mock sampling 

35 >>> messages = [msg] 

36 >>> response = handler._mock_sample(messages) 

37 >>> print(response) 

38 You said: Hello 

39 Here is my response... 

40 >>> 

41 >>> asyncio.run(handler.shutdown()) 

42""" 

43 

44# Standard 

45from typing import Any, Dict, List 

46 

47# Third-Party 

48from sqlalchemy.orm import Session 

49 

50# First-Party 

51from mcpgateway.common.models import CreateMessageResult, ModelPreferences, Role, TextContent 

52from mcpgateway.services.logging_service import LoggingService 

53 

# Initialize logging service first so the module-level logger exists before
# any handler code runs at import time.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)  # module-wide logger for this handler

57 

58 

class SamplingError(Exception):
    """Raised when a sampling request cannot be validated or fulfilled."""

class SamplingHandler:
    """MCP sampling request handler.

    Responsibilities:
    - Model selection based on caller preferences
    - Message sampling requests
    - Context management
    - Content validation

    Examples:
        >>> handler = SamplingHandler()
        >>> handler._supported_models['claude-3-haiku']
        (0.8, 0.9, 0.7)
        >>> len(handler._supported_models)
        4
    """

    def __init__(self):
        """Build the registry of supported models.

        Each entry maps a model name to a ``(cost, speed, intelligence)``
        capability tuple of 0..1 scores. Presumably a higher cost score
        means a cheaper model (haiku 0.8 vs opus 0.2) — confirm against
        the selection logic if these values are ever tuned.

        Examples:
            >>> handler = SamplingHandler()
            >>> isinstance(handler._supported_models, dict)
            True
            >>> 'claude-3-opus' in handler._supported_models
            True
            >>> handler._supported_models['claude-3-sonnet']
            (0.5, 0.7, 0.9)
        """
        # Capability scores per model: (cost, speed, intelligence).
        self._supported_models = {
            "claude-3-haiku": (0.8, 0.9, 0.7),
            "claude-3-sonnet": (0.5, 0.7, 0.9),
            "claude-3-opus": (0.2, 0.5, 1.0),
            "gemini-1.5-pro": (0.6, 0.8, 0.8),
        }

100 async def initialize(self) -> None: 

101 """Initialize sampling handler. 

102 

103 Examples: 

104 >>> import asyncio 

105 >>> handler = SamplingHandler() 

106 >>> asyncio.run(handler.initialize()) 

107 >>> # Handler is now initialized 

108 """ 

109 logger.info("Initializing sampling handler") 

110 

111 async def shutdown(self) -> None: 

112 """Shutdown sampling handler. 

113 

114 Examples: 

115 >>> import asyncio 

116 >>> handler = SamplingHandler() 

117 >>> asyncio.run(handler.initialize()) 

118 >>> asyncio.run(handler.shutdown()) 

119 >>> # Handler is now shut down 

120 """ 

121 logger.info("Shutting down sampling handler") 

122 

123 async def create_message(self, db: Session, request: Dict[str, Any]) -> CreateMessageResult: 

124 """Create message from sampling request. 

125 

126 Args: 

127 db: Database session 

128 request: Sampling request parameters 

129 

130 Returns: 

131 Sampled message result 

132 

133 Raises: 

134 SamplingError: If sampling fails 

135 

136 Examples: 

137 >>> import asyncio 

138 >>> from unittest.mock import Mock 

139 >>> handler = SamplingHandler() 

140 >>> db = Mock() 

141 >>> 

142 >>> # Test with valid request 

143 >>> request = { 

144 ... "messages": [{ 

145 ... "role": "user", 

146 ... "content": {"type": "text", "text": "Hello"} 

147 ... }], 

148 ... "maxTokens": 100, 

149 ... "modelPreferences": { 

150 ... "cost_priority": 0.3, 

151 ... "speed_priority": 0.3, 

152 ... "intelligence_priority": 0.4 

153 ... } 

154 ... } 

155 >>> result = asyncio.run(handler.create_message(db, request)) 

156 >>> result.role 

157 'assistant' 

158 >>> result.content.type 

159 'text' 

160 >>> result.stop_reason 

161 'maxTokens' 

162 >>> 

163 >>> # Test with no messages 

164 >>> bad_request = { 

165 ... "messages": [], 

166 ... "maxTokens": 100, 

167 ... "modelPreferences": { 

168 ... "cost_priority": 0.3, 

169 ... "speed_priority": 0.3, 

170 ... "intelligence_priority": 0.4 

171 ... } 

172 ... } 

173 >>> try: 

174 ... asyncio.run(handler.create_message(db, bad_request)) 

175 ... except SamplingError as e: 

176 ... print(str(e)) 

177 No messages provided 

178 >>> 

179 >>> # Test with no max tokens 

180 >>> bad_request = { 

181 ... "messages": [{"role": "user", "content": {"type": "text", "text": "Hi"}}], 

182 ... "modelPreferences": { 

183 ... "cost_priority": 0.3, 

184 ... "speed_priority": 0.3, 

185 ... "intelligence_priority": 0.4 

186 ... } 

187 ... } 

188 >>> try: 

189 ... asyncio.run(handler.create_message(db, bad_request)) 

190 ... except SamplingError as e: 

191 ... print(str(e)) 

192 Max tokens not specified 

193 """ 

194 try: 

195 # Extract request parameters 

196 messages = request.get("messages", []) 

197 max_tokens = request.get("maxTokens") 

198 model_prefs = ModelPreferences.model_validate(request.get("modelPreferences", {})) 

199 include_context = request.get("includeContext", "none") 

200 request.get("metadata", {}) 

201 

202 # Validate request 

203 if not messages: 

204 raise SamplingError("No messages provided") 

205 if not max_tokens: 

206 raise SamplingError("Max tokens not specified") 

207 

208 # Select model 

209 model = self._select_model(model_prefs) 

210 logger.info(f"Selected model: {model}") 

211 

212 # Include context if requested 

213 if include_context != "none": 213 ↛ 214line 213 didn't jump to line 214 because the condition on line 213 was never true

214 messages = await self._add_context(db, messages, include_context) 

215 

216 # Validate messages 

217 for msg in messages: 

218 if not self._validate_message(msg): 

219 raise SamplingError(f"Invalid message format: {msg}") 

220 

221 # TODO: Implement actual model sampling - currently returns mock response # pylint: disable=fixme 

222 # For now return mock response 

223 response = self._mock_sample(messages=messages) 

224 

225 # Convert to result 

226 return CreateMessageResult( 

227 content=TextContent(type="text", text=response), 

228 model=model, 

229 role=Role.ASSISTANT, 

230 stop_reason="maxTokens", 

231 ) 

232 

233 except Exception as e: 

234 logger.error(f"Sampling error: {e}") 

235 raise SamplingError(str(e)) 

236 

237 def _select_model(self, preferences: ModelPreferences) -> str: 

238 """Select model based on preferences. 

239 

240 Args: 

241 preferences: Model selection preferences 

242 

243 Returns: 

244 Selected model name 

245 

246 Raises: 

247 SamplingError: If no suitable model found 

248 

249 Examples: 

250 >>> from mcpgateway.common.models import ModelPreferences, ModelHint 

251 >>> handler = SamplingHandler() 

252 >>> 

253 >>> # Test intelligence priority 

254 >>> prefs = ModelPreferences( 

255 ... cost_priority=1.0, 

256 ... speed_priority=0.0, 

257 ... intelligence_priority=1.0 

258 ... ) 

259 >>> handler._select_model(prefs) 

260 'claude-3-opus' 

261 >>> 

262 >>> # Test speed priority 

263 >>> prefs = ModelPreferences( 

264 ... cost_priority=0.0, 

265 ... speed_priority=1.0, 

266 ... intelligence_priority=0.0 

267 ... ) 

268 >>> handler._select_model(prefs) 

269 'claude-3-haiku' 

270 >>> 

271 >>> # Test balanced preferences 

272 >>> prefs = ModelPreferences( 

273 ... cost_priority=0.33, 

274 ... speed_priority=0.33, 

275 ... intelligence_priority=0.34 

276 ... ) 

277 >>> model = handler._select_model(prefs) 

278 >>> model in handler._supported_models 

279 True 

280 >>> 

281 >>> # Test with model hints 

282 >>> prefs = ModelPreferences( 

283 ... hints=[ModelHint(name="opus")], 

284 ... cost_priority=0.5, 

285 ... speed_priority=0.3, 

286 ... intelligence_priority=0.2 

287 ... ) 

288 >>> handler._select_model(prefs) 

289 'claude-3-opus' 

290 >>> 

291 >>> # Test empty supported models (should raise error) 

292 >>> handler._supported_models = {} 

293 >>> try: 

294 ... handler._select_model(prefs) 

295 ... except SamplingError as e: 

296 ... print(str(e)) 

297 No suitable model found 

298 """ 

299 # Check model hints first 

300 if preferences.hints: 

301 for hint in preferences.hints: 

302 for model in self._supported_models: 

303 if hint.name and hint.name in model: 

304 return model 

305 

306 # Score models on preferences 

307 best_score = -1 

308 best_model = None 

309 

310 for model, caps in self._supported_models.items(): 

311 cost_score = caps[0] * (1 - preferences.cost_priority) 

312 speed_score = caps[1] * preferences.speed_priority 

313 intel_score = caps[2] * preferences.intelligence_priority 

314 

315 total_score = (cost_score + speed_score + intel_score) / 3 

316 

317 if total_score > best_score: 

318 best_score = total_score 

319 best_model = model 

320 

321 if not best_model: 

322 raise SamplingError("No suitable model found") 

323 

324 return best_model 

325 

326 async def _add_context(self, _db: Session, messages: List[Dict[str, Any]], _context_type: str) -> List[Dict[str, Any]]: 

327 """Add context to messages. 

328 

329 Args: 

330 _db: Database session 

331 messages: Message list 

332 _context_type: Context inclusion type 

333 

334 Returns: 

335 Messages with added context 

336 

337 Examples: 

338 >>> import asyncio 

339 >>> from unittest.mock import Mock 

340 >>> handler = SamplingHandler() 

341 >>> db = Mock() 

342 >>> 

343 >>> messages = [ 

344 ... {"role": "user", "content": {"type": "text", "text": "Hello"}}, 

345 ... {"role": "assistant", "content": {"type": "text", "text": "Hi there!"}} 

346 ... ] 

347 >>> 

348 >>> # Test with 'none' context type 

349 >>> result = asyncio.run(handler._add_context(db, messages, "none")) 

350 >>> result == messages 

351 True 

352 >>> 

353 >>> # Test with 'all' context type (currently returns same messages) 

354 >>> result = asyncio.run(handler._add_context(db, messages, "all")) 

355 >>> result == messages 

356 True 

357 >>> len(result) 

358 2 

359 """ 

360 # TODO: Implement context gathering based on type - currently no-op # pylint: disable=fixme 

361 # For now return original messages 

362 return messages 

363 

364 def _validate_message(self, message: Dict[str, Any]) -> bool: 

365 """Validate message format. 

366 

367 Args: 

368 message: Message to validate 

369 

370 Returns: 

371 True if valid 

372 

373 Examples: 

374 >>> handler = SamplingHandler() 

375 >>> 

376 >>> # Valid text message 

377 >>> msg = {"role": "user", "content": {"type": "text", "text": "Hello"}} 

378 >>> handler._validate_message(msg) 

379 True 

380 >>> 

381 >>> # Valid assistant message 

382 >>> msg = {"role": "assistant", "content": {"type": "text", "text": "Hi!"}} 

383 >>> handler._validate_message(msg) 

384 True 

385 >>> 

386 >>> # Valid image message 

387 >>> msg = { 

388 ... "role": "user", 

389 ... "content": { 

390 ... "type": "image", 

391 ... "data": "base64data", 

392 ... "mime_type": "image/png" 

393 ... } 

394 ... } 

395 >>> handler._validate_message(msg) 

396 True 

397 >>> 

398 >>> # Missing role 

399 >>> msg = {"content": {"type": "text", "text": "Hello"}} 

400 >>> handler._validate_message(msg) 

401 False 

402 >>> 

403 >>> # Invalid role 

404 >>> msg = {"role": "system", "content": {"type": "text", "text": "Hello"}} 

405 >>> handler._validate_message(msg) 

406 False 

407 >>> 

408 >>> # Missing content 

409 >>> msg = {"role": "user"} 

410 >>> handler._validate_message(msg) 

411 False 

412 >>> 

413 >>> # Invalid content type 

414 >>> msg = {"role": "user", "content": {"type": "audio"}} 

415 >>> handler._validate_message(msg) 

416 False 

417 >>> 

418 >>> # Text content not string 

419 >>> msg = {"role": "user", "content": {"type": "text", "text": 123}} 

420 >>> handler._validate_message(msg) 

421 False 

422 >>> 

423 >>> # Image missing data 

424 >>> msg = {"role": "user", "content": {"type": "image", "mime_type": "image/png"}} 

425 >>> handler._validate_message(msg) 

426 False 

427 >>> 

428 >>> # Invalid structure 

429 >>> handler._validate_message("not a dict") 

430 False 

431 """ 

432 try: 

433 # Must have role and content 

434 if "role" not in message or "content" not in message or message["role"] not in ("user", "assistant"): 

435 return False 

436 

437 # Content must be valid 

438 content = message["content"] 

439 if content.get("type") == "text": 

440 if not isinstance(content.get("text"), str): 

441 return False 

442 elif content.get("type") == "image": 

443 if not (content.get("data") and content.get("mime_type")): 

444 return False 

445 else: 

446 return False 

447 

448 return True 

449 

450 except Exception: 

451 return False 

452 

453 def _mock_sample( 

454 self, 

455 messages: List[Dict[str, Any]], 

456 ) -> str: 

457 """Mock sampling response for testing. 

458 

459 Args: 

460 messages: Input messages 

461 

462 Returns: 

463 Sampled response text 

464 

465 Examples: 

466 >>> handler = SamplingHandler() 

467 >>> 

468 >>> # Single user message 

469 >>> messages = [{"role": "user", "content": {"type": "text", "text": "Hello world"}}] 

470 >>> handler._mock_sample(messages) 

471 'You said: Hello world\\nHere is my response...' 

472 >>> 

473 >>> # Conversation with multiple messages 

474 >>> messages = [ 

475 ... {"role": "user", "content": {"type": "text", "text": "Hi"}}, 

476 ... {"role": "assistant", "content": {"type": "text", "text": "Hello!"}}, 

477 ... {"role": "user", "content": {"type": "text", "text": "How are you?"}} 

478 ... ] 

479 >>> handler._mock_sample(messages) 

480 'You said: How are you?\\nHere is my response...' 

481 >>> 

482 >>> # Image message 

483 >>> messages = [{ 

484 ... "role": "user", 

485 ... "content": {"type": "image", "data": "base64", "mime_type": "image/png"} 

486 ... }] 

487 >>> handler._mock_sample(messages) 

488 'You said: I see the image you shared.\\nHere is my response...' 

489 >>> 

490 >>> # No user messages 

491 >>> messages = [{"role": "assistant", "content": {"type": "text", "text": "Hi"}}] 

492 >>> handler._mock_sample(messages) 

493 "I'm not sure what to respond to." 

494 >>> 

495 >>> # Empty messages 

496 >>> handler._mock_sample([]) 

497 "I'm not sure what to respond to." 

498 """ 

499 # Extract last user message 

500 last_msg = None 

501 for msg in reversed(messages): 

502 if msg["role"] == "user": 

503 last_msg = msg 

504 break 

505 

506 if not last_msg: 

507 return "I'm not sure what to respond to." 

508 

509 # Get user text 

510 user_text = "" 

511 content = last_msg["content"] 

512 if content["type"] == "text": 

513 user_text = content["text"] 

514 elif content["type"] == "image": 514 ↛ 518line 514 didn't jump to line 518 because the condition on line 514 was always true

515 user_text = "I see the image you shared." 

516 

517 # Generate simple response 

518 return f"You said: {user_text}\nHere is my response..."