Coverage for mcpgateway / services / mcp_client_chat_service.py: 97%
848 statements
coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/mcp_client_chat_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Keval Mahajan
7MCP Client Service Module.
9This module provides a comprehensive client implementation for interacting with
10MCP servers, managing LLM providers, and orchestrating conversational AI agents.
11It supports multiple transport protocols and LLM providers.
13The module consists of several key components:
14- Configuration classes for MCP servers and LLM providers
15- LLM provider factory and implementations
16- MCP client for tool management
17- Chat history manager for Redis and in-memory storage
18- Chat service for conversational interactions
19"""
21# Standard
22from datetime import datetime, timezone
23import time
24from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Union
25from uuid import uuid4
27# Third-Party
28import orjson
30try:
31 # Third-Party
32 from langchain_core.language_models import BaseChatModel
33 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
34 from langchain_core.tools import BaseTool
35 from langchain_mcp_adapters.client import MultiServerMCPClient
36 from langchain_ollama import ChatOllama, OllamaLLM
37 from langchain_openai import AzureChatOpenAI, AzureOpenAI, ChatOpenAI, OpenAI
38 from langgraph.prebuilt import create_react_agent
40 _LLMCHAT_AVAILABLE = True
41except ImportError:
42 # Optional dependencies for LLM chat feature not installed
43 # These are only needed if LLMCHAT_ENABLED=true
44 _LLMCHAT_AVAILABLE = False
45 BaseChatModel = None # type: ignore
46 AIMessage = None # type: ignore
47 BaseMessage = None # type: ignore
48 HumanMessage = None # type: ignore
49 BaseTool = None # type: ignore
50 MultiServerMCPClient = None # type: ignore
51 ChatOllama = None # type: ignore
52 OllamaLLM = None
53 AzureChatOpenAI = None # type: ignore
54 AzureOpenAI = None
55 ChatOpenAI = None # type: ignore
56 OpenAI = None
57 create_react_agent = None # type: ignore
59# Try to import Anthropic and Bedrock providers (they may not be installed)
60try:
61 # Third-Party
62 from langchain_anthropic import AnthropicLLM, ChatAnthropic
64 _ANTHROPIC_AVAILABLE = True
65except ImportError:
66 _ANTHROPIC_AVAILABLE = False
67 ChatAnthropic = None # type: ignore
68 AnthropicLLM = None
70try:
71 # Third-Party
72 from langchain_aws import BedrockLLM, ChatBedrock
74 _BEDROCK_AVAILABLE = True
75except ImportError:
76 _BEDROCK_AVAILABLE = False
77 ChatBedrock = None # type: ignore
78 BedrockLLM = None
80try:
81 # Third-Party
82 from langchain_ibm import ChatWatsonx, WatsonxLLM
84 _WATSONX_AVAILABLE = True
85except ImportError:
86 _WATSONX_AVAILABLE = False
87 WatsonxLLM = None # type: ignore
88 ChatWatsonx = None
90# Third-Party
91from pydantic import BaseModel, Field, field_validator, model_validator
93# First-Party
94from mcpgateway.config import settings
95from mcpgateway.services.cancellation_service import cancellation_service
96from mcpgateway.services.logging_service import LoggingService
98logging_service = LoggingService()
99logger = logging_service.get_logger(__name__)
102class MCPServerConfig(BaseModel):
103 """
104 Configuration for MCP server connection.
106 This class defines the configuration parameters required to connect to an
107 MCP (Model Context Protocol) server using various transport mechanisms.
109 Attributes:
110 url: MCP server URL for streamable_http/sse transports.
111 command: Command to run for stdio transport.
112 args: Command-line arguments for stdio command.
113 transport: Transport type (streamable_http, sse, or stdio).
114 auth_token: Authentication token for HTTP-based transports.
115 headers: Additional HTTP headers for request customization.
117 Examples:
118 >>> # HTTP-based transport
119 >>> config = MCPServerConfig(
120 ... url="https://mcp-server.example.com/mcp",
121 ... transport="streamable_http",
122 ... auth_token="secret-token"
123 ... )
124 >>> config.transport
125 'streamable_http'
127 >>> # Stdio transport
128 >>> config = MCPServerConfig(
129 ... command="python",
130 ... args=["server.py"],
131 ... transport="stdio"
132 ... )
133 >>> config.command
134 'python'
136 Note:
137 The auth_token is automatically added to headers as a Bearer token
138 for HTTP-based transports.
139 """
141 url: Optional[str] = Field(None, description="MCP server URL for streamable_http/sse transports")
142 command: Optional[str] = Field(None, description="Command to run for stdio transport")
143 args: Optional[list[str]] = Field(None, description="Arguments for stdio command")
144 transport: Literal["streamable_http", "sse", "stdio"] = Field(default="streamable_http", description="Transport type for MCP connection")
145 auth_token: Optional[str] = Field(None, description="Authentication token for the server")
146 headers: Optional[Dict[str, str]] = Field(default=None, description="Additional headers for HTTP-based transports")
148 @model_validator(mode="before")
149 @classmethod
150 def add_auth_to_headers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
151 """
152 Automatically add authentication token to headers if provided.
154 This validator ensures that if an auth_token is provided for HTTP-based
155 transports, it's automatically added to the headers as a Bearer token.
157 Args:
158 values: Dictionary of field values before validation.
160 Returns:
161 Dict[str, Any]: Updated values with auth token in headers.
163 Examples:
164 >>> values = {
165 ... "url": "https://api.example.com",
166 ... "transport": "streamable_http",
167 ... "auth_token": "token123"
168 ... }
169 >>> result = MCPServerConfig.add_auth_to_headers(values)
170 >>> result['headers']['Authorization']
171 'Bearer token123'
172 """
173 auth_token = values.get("auth_token")
174 transport = values.get("transport")
175 headers = values.get("headers") or {}
177 if auth_token and transport in ["streamable_http", "sse"]:
178 if "Authorization" not in headers:
179 headers["Authorization"] = f"Bearer {auth_token}"
180 values["headers"] = headers
182 return values
184 @field_validator("url")
185 @classmethod
186 def validate_url_for_transport(cls, v: Optional[str], info) -> Optional[str]:
187 """
188 Validate that URL is provided for HTTP-based transports.
190 Args:
191 v: The URL value to validate.
192 info: Validation context containing other field values.
194 Returns:
195 Optional[str]: The validated URL.
197 Raises:
198 ValueError: If URL is missing for streamable_http or sse transport.
200 Examples:
201 >>> # Valid case
202 >>> MCPServerConfig(
203 ... url="https://example.com",
204 ... transport="streamable_http"
205 ... ).url
206 'https://example.com'
207 """
208 transport = info.data.get("transport")
209 if transport in ["streamable_http", "sse"] and not v:
210 raise ValueError(f"URL is required for {transport} transport")
211 return v
213 @field_validator("command")
214 @classmethod
215 def validate_command_for_stdio(cls, v: Optional[str], info) -> Optional[str]:
216 """
217 Validate that command is provided for stdio transport.
219 Args:
220 v: The command value to validate.
221 info: Validation context containing other field values.
223 Returns:
224 Optional[str]: The validated command.
226 Raises:
227 ValueError: If command is missing for stdio transport.
229 Examples:
230 >>> config = MCPServerConfig(
231 ... command="python",
232 ... args=["server.py"],
233 ... transport="stdio"
234 ... )
235 >>> config.command
236 'python'
237 """
238 transport = info.data.get("transport")
239 if transport == "stdio" and not v:
240 raise ValueError("Command is required for stdio transport")
241 return v
243 model_config = {
244 "json_schema_extra": {
245 "examples": [
246 {"url": "https://mcp-server.example.com/mcp", "transport": "streamable_http", "auth_token": "your-token-here"}, # nosec B105 - example placeholder
247 {"command": "python", "args": ["server.py"], "transport": "stdio"},
248 ]
249 }
250 }
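# --- Illustrative sketch, not part of the original module. It shows how an
# MCPServerConfig might map onto the per-server connection dict consumed by
# langchain_mcp_adapters' MultiServerMCPClient; the server name, URL, and token
# are hypothetical, and the exact connection keys should be verified against the
# installed langchain-mcp-adapters version. By construction time,
# add_auth_to_headers has already injected the Bearer token into config.headers.


async def _example_mcp_connection() -> list:
    config = MCPServerConfig(
        url="https://mcp-server.example.com/mcp",
        transport="streamable_http",
        auth_token="secret-token",  # nosec B106 - example placeholder
    )
    connection = {"transport": config.transport, "url": config.url, "headers": config.headers}
    client = MultiServerMCPClient({"gateway": connection})
    return await client.get_tools()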
253class AzureOpenAIConfig(BaseModel):
254 """
255 Configuration for Azure OpenAI provider.
257 Defines all necessary parameters to connect to and use Azure OpenAI services,
258 including API credentials, endpoints, model settings, and request parameters.
260 Attributes:
261 api_key: Azure OpenAI API authentication key.
262 azure_endpoint: Azure OpenAI service endpoint URL.
263 api_version: API version to use for requests.
264 azure_deployment: Name of the deployed model.
265 model: Model identifier for logging and tracing.
266 temperature: Sampling temperature for response generation (0.0-2.0).
267 max_tokens: Maximum number of tokens to generate.
268 timeout: Request timeout duration in seconds.
269 max_retries: Maximum number of retry attempts for failed requests.
271 Examples:
272 >>> config = AzureOpenAIConfig(
273 ... api_key="your-api-key",
274 ... azure_endpoint="https://your-resource.openai.azure.com/",
275 ... azure_deployment="gpt-4",
276 ... temperature=0.7
277 ... )
278 >>> config.model
279 'gpt-4'
280 >>> config.temperature
281 0.7
282 """
284 api_key: str = Field(..., description="Azure OpenAI API key")
285 azure_endpoint: str = Field(..., description="Azure OpenAI endpoint URL")
286 api_version: str = Field(default="2024-05-01-preview", description="Azure OpenAI API version")
287 azure_deployment: str = Field(..., description="Azure OpenAI deployment name")
288 model: str = Field(default="gpt-4", description="Model name for tracing")
289 temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
290 max_tokens: Optional[int] = Field(None, gt=0, description="Maximum tokens to generate")
291 timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")
292 max_retries: int = Field(default=2, ge=0, description="Maximum number of retries")
294 model_config = {
295 "json_schema_extra": {
296 "example": {
297 "api_key": "your-api-key",
298 "azure_endpoint": "https://your-resource.openai.azure.com/",
299 "api_version": "2024-05-01-preview",
300 "azure_deployment": "gpt-4",
301 "model": "gpt-4",
302 "temperature": 0.7,
303 }
304 }
305 }
308class OllamaConfig(BaseModel):
309 """
310 Configuration for Ollama provider.
312 Defines parameters for connecting to a local or remote Ollama instance
313 for running open-source language models.
315 Attributes:
316 base_url: Ollama server base URL.
317 model: Name of the Ollama model to use.
318 temperature: Sampling temperature for response generation (0.0-2.0).
319 timeout: Request timeout duration in seconds.
320 num_ctx: Context window size for the model.
322 Examples:
323 >>> config = OllamaConfig(
324 ... base_url="http://localhost:11434",
325 ... model="llama2",
326 ... temperature=0.5
327 ... )
328 >>> config.model
329 'llama2'
330 >>> config.base_url
331 'http://localhost:11434'
332 """
334 base_url: str = Field(default="http://localhost:11434", description="Ollama base URL")
335 model: str = Field(default="llama2", description="Model name to use")
336 temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
337 timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")
338 num_ctx: Optional[int] = Field(None, gt=0, description="Context window size")
340 model_config = {"json_schema_extra": {"example": {"base_url": "http://localhost:11434", "model": "llama2", "temperature": 0.7}}}
343class OpenAIConfig(BaseModel):
344 """
345 Configuration for OpenAI provider (non-Azure).
347 Defines parameters for connecting to OpenAI API (or OpenAI-compatible endpoints).
349 Attributes:
350 api_key: OpenAI API authentication key.
351 base_url: Optional base URL for OpenAI-compatible endpoints.
352 model: Model identifier (e.g., gpt-4, gpt-3.5-turbo).
353 temperature: Sampling temperature for response generation (0.0-2.0).
354 max_tokens: Maximum number of tokens to generate.
355 timeout: Request timeout duration in seconds.
356 max_retries: Maximum number of retry attempts for failed requests.
default_headers: Optional default HTTP headers sent with every request.
358 Examples:
359 >>> config = OpenAIConfig(
360 ... api_key="sk-...",
361 ... model="gpt-4",
362 ... temperature=0.7
363 ... )
364 >>> config.model
365 'gpt-4'
366 """
368 api_key: str = Field(..., description="OpenAI API key")
369 base_url: Optional[str] = Field(None, description="Base URL for OpenAI-compatible endpoints")
370 model: str = Field(default="gpt-4o-mini", description="Model name (e.g., gpt-4, gpt-3.5-turbo)")
371 temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
372 max_tokens: Optional[int] = Field(None, gt=0, description="Maximum tokens to generate")
373 timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")
374 max_retries: int = Field(default=2, ge=0, description="Maximum number of retries")
375 default_headers: Optional[dict] = Field(None, description="Optional default headers required by the provider")
377 model_config = {
378 "json_schema_extra": {
379 "example": {
380 "api_key": "sk-...",
381 "model": "gpt-4o-mini",
382 "temperature": 0.7,
383 }
384 }
385 }
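# --- Illustrative sketch, not part of the original module. OpenAIConfig can also
# target an OpenAI-compatible endpoint (for example a locally hosted vLLM or
# LiteLLM server) by setting base_url and, where a fronting proxy requires them,
# default_headers. The URL, model name, and header below are hypothetical.

_example_openai_compatible = OpenAIConfig(
    api_key="not-needed-locally",  # nosec B106 - example placeholder
    base_url="http://localhost:8000/v1",
    model="llama-3-8b-instruct",
    temperature=0.2,
    default_headers={"X-Routing-Tenant": "demo"},
)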
388class AnthropicConfig(BaseModel):
389 """
390 Configuration for Anthropic Claude provider.
392 Defines parameters for connecting to Anthropic's Claude API.
394 Attributes:
395 api_key: Anthropic API authentication key.
396 model: Claude model identifier (e.g., claude-3-5-sonnet-20241022, claude-3-opus).
397 temperature: Sampling temperature for response generation (0.0-1.0).
398 max_tokens: Maximum number of tokens to generate.
399 timeout: Request timeout duration in seconds.
400 max_retries: Maximum number of retry attempts for failed requests.
402 Examples:
403 >>> config = AnthropicConfig(
404 ... api_key="sk-ant-...",
405 ... model="claude-3-5-sonnet-20241022",
406 ... temperature=0.7
407 ... )
408 >>> config.model
409 'claude-3-5-sonnet-20241022'
410 """
412 api_key: str = Field(..., description="Anthropic API key")
413 model: str = Field(default="claude-3-5-sonnet-20241022", description="Claude model name")
414 temperature: float = Field(default=0.7, ge=0.0, le=1.0, description="Sampling temperature")
415 max_tokens: int = Field(default=4096, gt=0, description="Maximum tokens to generate")
416 timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")
417 max_retries: int = Field(default=2, ge=0, description="Maximum number of retries")
419 model_config = {
420 "json_schema_extra": {
421 "example": {
422 "api_key": "sk-ant-...",
423 "model": "claude-3-5-sonnet-20241022",
424 "temperature": 0.7,
425 "max_tokens": 4096,
426 }
427 }
428 }
431class AWSBedrockConfig(BaseModel):
432 """
433 Configuration for AWS Bedrock provider.
435 Defines parameters for connecting to AWS Bedrock LLM services.
437 Attributes:
438 model_id: Bedrock model identifier (e.g., anthropic.claude-v2, amazon.titan-text-express-v1).
439 region_name: AWS region name (e.g., us-east-1, us-west-2).
440 aws_access_key_id: Optional AWS access key ID (uses default credential chain if not provided).
441 aws_secret_access_key: Optional AWS secret access key.
442 aws_session_token: Optional AWS session token for temporary credentials.
443 temperature: Sampling temperature for response generation (0.0-1.0).
444 max_tokens: Maximum number of tokens to generate.
446 Examples:
447 >>> config = AWSBedrockConfig(
448 ... model_id="anthropic.claude-v2",
449 ... region_name="us-east-1",
450 ... temperature=0.7
451 ... )
452 >>> config.model_id
453 'anthropic.claude-v2'
454 """
456 model_id: str = Field(..., description="Bedrock model ID")
457 region_name: str = Field(default="us-east-1", description="AWS region name")
458 aws_access_key_id: Optional[str] = Field(None, description="AWS access key ID")
459 aws_secret_access_key: Optional[str] = Field(None, description="AWS secret access key")
460 aws_session_token: Optional[str] = Field(None, description="AWS session token")
461 temperature: float = Field(default=0.7, ge=0.0, le=1.0, description="Sampling temperature")
462 max_tokens: int = Field(default=4096, gt=0, description="Maximum tokens to generate")
464 model_config = {
465 "json_schema_extra": {
466 "example": {
467 "model_id": "anthropic.claude-v2",
468 "region_name": "us-east-1",
469 "temperature": 0.7,
470 "max_tokens": 4096,
471 }
472 }
473 }
476class WatsonxConfig(BaseModel):
477 """
478 Configuration for IBM watsonx.ai provider.
480 Defines parameters for connecting to IBM watsonx.ai services.
482 Attributes:
483 api_key: IBM Cloud API key for authentication.
484 url: IBM watsonx.ai service endpoint URL.
485 project_id: IBM watsonx.ai project ID for context.
486 model_id: Model identifier (e.g., ibm/granite-13b-chat-v2, meta-llama/llama-3-70b-instruct).
487 temperature: Sampling temperature for response generation (0.0-2.0).
488 max_new_tokens: Maximum number of tokens to generate.
489 min_new_tokens: Minimum number of tokens to generate.
490 decoding_method: Decoding method ('sample', 'greedy').
491 top_k: Top-K sampling parameter.
492 top_p: Top-P (nucleus) sampling parameter.
493 timeout: Request timeout duration in seconds.
495 Examples:
496 >>> config = WatsonxConfig(
497 ... api_key="your-api-key",
498 ... url="https://us-south.ml.cloud.ibm.com",
499 ... project_id="your-project-id",
500 ... model_id="ibm/granite-13b-chat-v2"
501 ... )
502 >>> config.model_id
503 'ibm/granite-13b-chat-v2'
504 """
506 api_key: str = Field(..., description="IBM Cloud API key")
507 url: str = Field(default="https://us-south.ml.cloud.ibm.com", description="watsonx.ai endpoint URL")
508 project_id: str = Field(..., description="watsonx.ai project ID")
509 model_id: str = Field(default="ibm/granite-13b-chat-v2", description="Model identifier")
510 temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
511 max_new_tokens: Optional[int] = Field(default=1024, gt=0, description="Maximum tokens to generate")
512 min_new_tokens: Optional[int] = Field(default=1, gt=0, description="Minimum tokens to generate")
513 decoding_method: str = Field(default="sample", description="Decoding method (sample or greedy)")
514 top_k: Optional[int] = Field(default=50, gt=0, description="Top-K sampling")
515 top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0, description="Top-P sampling")
516 timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")
518 model_config = {
519 "json_schema_extra": {
520 "example": {
521 "api_key": "your-api-key",
522 "url": "https://us-south.ml.cloud.ibm.com",
523 "project_id": "your-project-id",
524 "model_id": "ibm/granite-13b-chat-v2",
525 "temperature": 0.7,
526 "max_new_tokens": 1024,
527 }
528 }
529 }
532class GatewayConfig(BaseModel):
533 """
534 Configuration for MCP Gateway internal LLM provider.
536 Allows LLM Chat to use models configured in the gateway's LLM Settings.
537 The gateway routes requests to the appropriate configured provider.
539 Attributes:
540 model: Model ID (gateway model ID or provider model ID).
541 base_url: Gateway internal API URL (defaults to self).
542 temperature: Sampling temperature for response generation.
543 max_tokens: Maximum tokens to generate.
544 timeout: Request timeout in seconds.
546 Examples:
547 >>> config = GatewayConfig(model="gpt-4o")
548 >>> config.model
549 'gpt-4o'
550 """
552 model: str = Field(..., description="Gateway model ID to use")
553 base_url: Optional[str] = Field(None, description="Gateway internal API URL (optional, defaults to self)")
554 temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
555 max_tokens: Optional[int] = Field(None, gt=0, description="Maximum tokens to generate")
556 timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")
558 model_config = {
559 "json_schema_extra": {
560 "example": {
561 "model": "gpt-4o",
562 "temperature": 0.7,
563 "max_tokens": 4096,
564 }
565 }
566 }
569class LLMConfig(BaseModel):
570 """
571 Configuration for LLM provider.
573 Unified configuration class that supports multiple LLM providers through
574 a discriminated union pattern.
576 Attributes:
577 provider: Type of LLM provider (azure_openai, openai, anthropic, aws_bedrock, ollama, watsonx, or gateway).
578 config: Provider-specific configuration object.
580 Examples:
581 >>> # Azure OpenAI configuration
582 >>> config = LLMConfig(
583 ... provider="azure_openai",
584 ... config=AzureOpenAIConfig(
585 ... api_key="key",
586 ... azure_endpoint="https://example.com/",
587 ... azure_deployment="gpt-4"
588 ... )
589 ... )
590 >>> config.provider
591 'azure_openai'
593 >>> # OpenAI configuration
594 >>> config = LLMConfig(
595 ... provider="openai",
596 ... config=OpenAIConfig(
597 ... api_key="sk-...",
598 ... model="gpt-4"
599 ... )
600 ... )
601 >>> config.provider
602 'openai'
604 >>> # Ollama configuration
605 >>> config = LLMConfig(
606 ... provider="ollama",
607 ... config=OllamaConfig(model="llama2")
608 ... )
609 >>> config.provider
610 'ollama'
612 >>> # Watsonx configuration
613 >>> config = LLMConfig(
614 ... provider="watsonx",
615 ... config=WatsonxConfig(
616 ... url="https://us-south.ml.cloud.ibm.com",
617 ... model_id="ibm/granite-13b-instruct-v2",
618 ... project_id="YOUR_PROJECT_ID",
619 ... api_key="YOUR_API")
620 ... )
621 >>> config.provider
622 'watsonx'
623 """
625 provider: Literal["azure_openai", "openai", "anthropic", "aws_bedrock", "ollama", "watsonx", "gateway"] = Field(..., description="LLM provider type")
626 config: Union[AzureOpenAIConfig, OpenAIConfig, AnthropicConfig, AWSBedrockConfig, OllamaConfig, WatsonxConfig, GatewayConfig] = Field(..., description="Provider-specific configuration")
628 @field_validator("config", mode="before")
629 @classmethod
630 def validate_config_type(cls, v: Any, info) -> Union[AzureOpenAIConfig, OpenAIConfig, AnthropicConfig, AWSBedrockConfig, OllamaConfig, WatsonxConfig, GatewayConfig]:
631 """
632 Validate and convert config dictionary to appropriate provider type.
634 Args:
635 v: Configuration value (dict or config object).
636 info: Validation context containing provider information.
638 Returns:
639 Union[AzureOpenAIConfig, OpenAIConfig, AnthropicConfig, AWSBedrockConfig, OllamaConfig, WatsonxConfig, GatewayConfig]: Validated configuration object.
641 Examples:
642 >>> # Automatically converts dict to appropriate config type
643 >>> config_dict = {
644 ... "api_key": "key",
645 ... "azure_endpoint": "https://example.com/",
646 ... "azure_deployment": "gpt-4"
647 ... }
648 >>> # Used internally by Pydantic during validation
649 """
650 provider = info.data.get("provider")
652 if isinstance(v, dict):
653 if provider == "azure_openai":
654 return AzureOpenAIConfig(**v)
655 if provider == "openai":
656 return OpenAIConfig(**v)
657 if provider == "anthropic":
658 return AnthropicConfig(**v)
659 if provider == "aws_bedrock":
660 return AWSBedrockConfig(**v)
661 if provider == "ollama":
662 return OllamaConfig(**v)
663 if provider == "watsonx":
664 return WatsonxConfig(**v)
665 if provider == "gateway":  # branch 665 ↛ 668 not taken: condition on line 665 was always true
666 return GatewayConfig(**v)
668 return v
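# --- Illustrative sketch, not part of the original module. Because
# validate_config_type runs in "before" mode, the config field accepts a plain
# dict (for example a raw JSON request body) and coerces it into the
# provider-specific model selected by the provider field.

_example_llm_config = LLMConfig(provider="ollama", config={"model": "llama2", "temperature": 0.2})
assert isinstance(_example_llm_config.config, OllamaConfig)

_example_gateway_llm_config = LLMConfig(provider="gateway", config={"model": "gpt-4o"})
assert isinstance(_example_gateway_llm_config.config, GatewayConfig)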
671class MCPClientConfig(BaseModel):
672 """
673 Main configuration for MCP client service.
675 Aggregates all configuration parameters required for the complete MCP client
676 service, including server connection, LLM provider, and operational settings.
678 Attributes:
679 mcp_server: MCP server connection configuration.
680 llm: LLM provider configuration.
681 chat_history_max_messages: Maximum messages to retain in chat history.
682 enable_streaming: Whether to enable streaming responses.
684 Examples:
685 >>> config = MCPClientConfig(
686 ... mcp_server=MCPServerConfig(
687 ... url="https://mcp-server.example.com/mcp",
688 ... transport="streamable_http"
689 ... ),
690 ... llm=LLMConfig(
691 ... provider="ollama",
692 ... config=OllamaConfig(model="llama2")
693 ... ),
694 ... chat_history_max_messages=100,
695 ... enable_streaming=True
696 ... )
697 >>> config.chat_history_max_messages
698 100
699 >>> config.enable_streaming
700 True
701 """
703 mcp_server: MCPServerConfig = Field(..., description="MCP server configuration")
704 llm: LLMConfig = Field(..., description="LLM provider configuration")
705 chat_history_max_messages: int = settings.llmchat_chat_history_max_messages
706 enable_streaming: bool = Field(default=True, description="Enable streaming responses")
708 model_config = {
709 "json_schema_extra": {
710 "example": {
711 "mcp_server": {"url": "https://mcp-server.example.com/mcp", "transport": "streamable_http", "auth_token": "your-token"}, # nosec B105 - example placeholder
712 "llm": {
713 "provider": "azure_openai",
714 "config": {"api_key": "your-key", "azure_endpoint": "https://your-resource.openai.azure.com/", "azure_deployment": "gpt-4", "api_version": "2024-05-01-preview"},
715 },
716 }
717 }
718 }
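# --- Illustrative sketch, not part of the original module. It assembles the full
# client configuration the way an API request body might arrive, mixing a typed
# MCPServerConfig with a plain-dict LLM config. The server URL, token, and API
# key are hypothetical.

_example_client_config = MCPClientConfig(
    mcp_server=MCPServerConfig(url="https://mcp-server.example.com/mcp", transport="streamable_http", auth_token="your-token"),  # nosec B106 - example placeholder
    llm=LLMConfig(provider="openai", config={"api_key": "sk-example", "model": "gpt-4o-mini"}),
    enable_streaming=True,
)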
721# ==================== LLM PROVIDER IMPLEMENTATIONS ====================
724class AzureOpenAIProvider:
725 """
726 Azure OpenAI provider implementation.
728 Manages connection and interaction with Azure OpenAI services.
730 Attributes:
731 config: Azure OpenAI configuration object.
733 Examples:
734 >>> config = AzureOpenAIConfig(
735 ... api_key="key",
736 ... azure_endpoint="https://example.openai.azure.com/",
737 ... azure_deployment="gpt-4"
738 ... )
739 >>> provider = AzureOpenAIProvider(config)
740 >>> provider.get_model_name()
741 'gpt-4'
743 Note:
744 The LLM instance is lazily initialized on first access for
745 improved startup performance.
746 """
748 def __init__(self, config: AzureOpenAIConfig):
749 """
750 Initialize Azure OpenAI provider.
752 Args:
753 config: Azure OpenAI configuration with API credentials and settings.
755 Examples:
756 >>> config = AzureOpenAIConfig(
757 ... api_key="key",
758 ... azure_endpoint="https://example.openai.azure.com/",
759 ... azure_deployment="gpt-4"
760 ... )
761 >>> provider = AzureOpenAIProvider(config)
762 """
763 self.config = config
764 self._llm = None
765 logger.info(f"Initializing Azure OpenAI provider with deployment: {config.azure_deployment}")
767 def get_llm(self, model_type: str = "chat") -> Union[AzureChatOpenAI, AzureOpenAI]:
768 """
769 Get Azure OpenAI LLM instance with lazy initialization.
771 Creates and caches the Azure OpenAI chat model instance on first call.
772 Subsequent calls return the cached instance.
774 Args:
775 model_type: LLM inference model type: 'chat' for a chat model or 'completion' for a text completion model.
777 Returns:
778 Union[AzureChatOpenAI, AzureOpenAI]: The configured Azure OpenAI model for the requested model_type.
780 Raises:
781 Exception: If LLM initialization fails (e.g., invalid credentials).
783 Examples:
784 >>> config = AzureOpenAIConfig(
785 ... api_key="key",
786 ... azure_endpoint="https://example.openai.azure.com/",
787 ... azure_deployment="gpt-4"
788 ... )
789 >>> provider = AzureOpenAIProvider(config)
790 >>> # llm = provider.get_llm() # Returns AzureChatOpenAI instance
791 """
792 if self._llm is None:
793 try:
794 if model_type == "chat":
795 self._llm = AzureChatOpenAI(
796 api_key=self.config.api_key,
797 azure_endpoint=self.config.azure_endpoint,
798 api_version=self.config.api_version,
799 azure_deployment=self.config.azure_deployment,
800 model=self.config.model,
801 temperature=self.config.temperature,
802 max_tokens=self.config.max_tokens,
803 timeout=self.config.timeout,
804 max_retries=self.config.max_retries,
805 )
806 elif model_type == "completion":  # branch 806 ↛ 818 not taken: condition on line 806 was always true
807 self._llm = AzureOpenAI(
808 api_key=self.config.api_key,
809 azure_endpoint=self.config.azure_endpoint,
810 api_version=self.config.api_version,
811 azure_deployment=self.config.azure_deployment,
812 model=self.config.model,
813 temperature=self.config.temperature,
814 max_tokens=self.config.max_tokens,
815 timeout=self.config.timeout,
816 max_retries=self.config.max_retries,
817 )
818 logger.info("Azure OpenAI LLM instance created successfully")
819 except Exception as e:
820 logger.error(f"Failed to create Azure OpenAI LLM: {e}")
821 raise
823 return self._llm
825 def get_model_name(self) -> str:
826 """
827 Get the Azure OpenAI model name.
829 Returns:
830 str: The model name configured for this provider.
832 Examples:
833 >>> config = AzureOpenAIConfig(
834 ... api_key="key",
835 ... azure_endpoint="https://example.openai.azure.com/",
836 ... azure_deployment="gpt-4",
837 ... model="gpt-4"
838 ... )
839 >>> provider = AzureOpenAIProvider(config)
840 >>> provider.get_model_name()
841 'gpt-4'
842 """
843 return self.config.model
846class OllamaProvider:
847 """
848 Ollama provider implementation.
850 Manages connection and interaction with Ollama instances for running
851 open-source language models locally or remotely.
853 Attributes:
854 config: Ollama configuration object.
856 Examples:
857 >>> config = OllamaConfig(
858 ... base_url="http://localhost:11434",
859 ... model="llama2"
860 ... )
861 >>> provider = OllamaProvider(config)
862 >>> provider.get_model_name()
863 'llama2'
865 Note:
866 Requires Ollama to be running and accessible at the configured base_url.
867 """
869 def __init__(self, config: OllamaConfig):
870 """
871 Initialize Ollama provider.
873 Args:
874 config: Ollama configuration with server URL and model settings.
876 Examples:
877 >>> config = OllamaConfig(model="llama2")
878 >>> provider = OllamaProvider(config)
879 """
880 self.config = config
881 self._llm = None
882 logger.info(f"Initializing Ollama provider with model: {config.model}")
884 def get_llm(self, model_type: str = "chat") -> Union[ChatOllama, OllamaLLM]:
885 """
886 Get Ollama LLM instance with lazy initialization.
888 Creates and caches the Ollama chat model instance on first call.
889 Subsequent calls return the cached instance.
891 Args:
892 model_type: LLM inference model type: 'chat' for a chat model or 'completion' for a text completion model.
894 Returns:
895 Union[ChatOllama, OllamaLLM]: The configured Ollama model for the requested model_type.
897 Raises:
898 Exception: If LLM initialization fails (e.g., Ollama not running).
900 Examples:
901 >>> config = OllamaConfig(model="llama2")
902 >>> provider = OllamaProvider(config)
903 >>> # llm = provider.get_llm() # Returns ChatOllama instance
904 """
905 if self._llm is None:  # branch 905 ↛ 921 not taken: condition on line 905 was always true
906 try:
907 # Build model kwargs
908 model_kwargs = {}
909 if self.config.num_ctx is not None:
910 model_kwargs["num_ctx"] = self.config.num_ctx
912 if model_type == "chat":
913 self._llm = ChatOllama(base_url=self.config.base_url, model=self.config.model, temperature=self.config.temperature, timeout=self.config.timeout, **model_kwargs)
914 elif model_type == "completion":  # branch 914 ↛ 916 not taken: condition on line 914 was always true
915 self._llm = OllamaLLM(base_url=self.config.base_url, model=self.config.model, temperature=self.config.temperature, timeout=self.config.timeout, **model_kwargs)
916 logger.info("Ollama LLM instance created successfully")
917 except Exception as e:
918 logger.error(f"Failed to create Ollama LLM: {e}")
919 raise
921 return self._llm
923 def get_model_name(self) -> str:
924 """Get the model name.
926 Returns:
927 str: The model name
928 """
929 return self.config.model
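# --- Illustrative sketch, not part of the original module. Each provider builds
# its LangChain model lazily and caches a single instance, so repeated get_llm()
# calls return the same object and the model_type of the first call wins for
# later calls. Constructing ChatOllama should not contact the Ollama server;
# requests only go out when the model is invoked. The base_url shown is the
# library default.


def _example_ollama_lazy_init() -> None:
    provider = OllamaProvider(OllamaConfig(base_url="http://localhost:11434", model="llama2", num_ctx=4096))
    chat_llm = provider.get_llm(model_type="chat")  # first call creates and caches ChatOllama
    assert chat_llm is provider.get_llm()  # later calls return the cached instance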
932class OpenAIProvider:
933 """
934 OpenAI provider implementation (non-Azure).
936 Manages connection and interaction with OpenAI API or OpenAI-compatible endpoints.
938 Attributes:
939 config: OpenAI configuration object.
941 Examples:
942 >>> config = OpenAIConfig(
943 ... api_key="sk-...",
944 ... model="gpt-4"
945 ... )
946 >>> provider = OpenAIProvider(config)
947 >>> provider.get_model_name()
948 'gpt-4'
950 Note:
951 The LLM instance is lazily initialized on first access for
952 improved startup performance.
953 """
955 def __init__(self, config: OpenAIConfig):
956 """
957 Initialize OpenAI provider.
959 Args:
960 config: OpenAI configuration with API key and settings.
962 Examples:
963 >>> config = OpenAIConfig(
964 ... api_key="sk-...",
965 ... model="gpt-4"
966 ... )
967 >>> provider = OpenAIProvider(config)
968 """
969 self.config = config
970 self._llm = None
971 logger.info(f"Initializing OpenAI provider with model: {config.model}")
973 def get_llm(self, model_type="chat") -> Union[ChatOpenAI, OpenAI]:
974 """
975 Get OpenAI LLM instance with lazy initialization.
977 Creates and caches the OpenAI chat model instance on first call.
978 Subsequent calls return the cached instance.
980 Args:
981 model_type: LLM inference model type: 'chat' for a chat model or 'completion' for a text completion model.
983 Returns:
984 Union[ChatOpenAI, OpenAI]: The configured OpenAI model for the requested model_type.
986 Raises:
987 Exception: If LLM initialization fails (e.g., invalid credentials).
989 Examples:
990 >>> config = OpenAIConfig(
991 ... api_key="sk-...",
992 ... model="gpt-4"
993 ... )
994 >>> provider = OpenAIProvider(config)
995 >>> # llm = provider.get_llm() # Returns ChatOpenAI instance
996 """
997 if self._llm is None:  # branch 997 ↛ 1025 not taken: condition on line 997 was always true
998 try:
999 kwargs = {
1000 "api_key": self.config.api_key,
1001 "model": self.config.model,
1002 "temperature": self.config.temperature,
1003 "max_tokens": self.config.max_tokens,
1004 "timeout": self.config.timeout,
1005 "max_retries": self.config.max_retries,
1006 }
1008 if self.config.base_url:
1009 kwargs["base_url"] = self.config.base_url
1011 # Add default headers if present
1012 if self.config.default_headers is not None:
1013 kwargs["default_headers"] = self.config.default_headers
1015 if model_type == "chat":
1016 self._llm = ChatOpenAI(**kwargs)
1017 elif model_type == "completion":  # branch 1017 ↛ 1020 not taken: condition on line 1017 was always true
1018 self._llm = OpenAI(**kwargs)
1020 logger.info("OpenAI LLM instance created successfully")
1021 except Exception as e:
1022 logger.error(f"Failed to create OpenAI LLM: {e}")
1023 raise
1025 return self._llm
1027 def get_model_name(self) -> str:
1028 """
1029 Get the OpenAI model name.
1031 Returns:
1032 str: The model name configured for this provider.
1034 Examples:
1035 >>> config = OpenAIConfig(
1036 ... api_key="sk-...",
1037 ... model="gpt-4"
1038 ... )
1039 >>> provider = OpenAIProvider(config)
1040 >>> provider.get_model_name()
1041 'gpt-4'
1042 """
1043 return self.config.model
1046class AnthropicProvider:
1047 """
1048 Anthropic Claude provider implementation.
1050 Manages connection and interaction with Anthropic's Claude API.
1052 Attributes:
1053 config: Anthropic configuration object.
1055 Examples:
1056 >>> config = AnthropicConfig( # doctest: +SKIP
1057 ... api_key="sk-ant-...",
1058 ... model="claude-3-5-sonnet-20241022"
1059 ... )
1060 >>> provider = AnthropicProvider(config) # doctest: +SKIP
1061 >>> provider.get_model_name() # doctest: +SKIP
1062 'claude-3-5-sonnet-20241022'
1064 Note:
1065 Requires langchain-anthropic package to be installed.
1066 """
1068 def __init__(self, config: AnthropicConfig):
1069 """
1070 Initialize Anthropic provider.
1072 Args:
1073 config: Anthropic configuration with API key and settings.
1075 Raises:
1076 ImportError: If langchain-anthropic is not installed.
1078 Examples:
1079 >>> config = AnthropicConfig( # doctest: +SKIP
1080 ... api_key="sk-ant-...",
1081 ... model="claude-3-5-sonnet-20241022"
1082 ... )
1083 >>> provider = AnthropicProvider(config) # doctest: +SKIP
1084 """
1085 if not _ANTHROPIC_AVAILABLE:
1086 raise ImportError("Anthropic provider requires langchain-anthropic package. Install it with: pip install langchain-anthropic")
1088 self.config = config
1089 self._llm = None
1090 logger.info(f"Initializing Anthropic provider with model: {config.model}")
1092 def get_llm(self, model_type: str = "chat") -> Union[ChatAnthropic, AnthropicLLM]:
1093 """
1094 Get Anthropic LLM instance with lazy initialization.
1096 Creates and caches the Anthropic chat model instance on first call.
1097 Subsequent calls return the cached instance.
1099 Args:
1100 model_type: LLM inference model type: 'chat' for a chat model or 'completion' for a text completion model.
1102 Returns:
1103 Union[ChatAnthropic, AnthropicLLM]: The configured Anthropic model for the requested model_type.
1105 Raises:
1106 Exception: If LLM initialization fails (e.g., invalid credentials).
1108 Examples:
1109 >>> config = AnthropicConfig( # doctest: +SKIP
1110 ... api_key="sk-ant-...",
1111 ... model="claude-3-5-sonnet-20241022"
1112 ... )
1113 >>> provider = AnthropicProvider(config) # doctest: +SKIP
1114 >>> # llm = provider.get_llm() # Returns ChatAnthropic instance
1115 """
1116 if self._llm is None:  # branch 1116 ↛ 1141 not taken: condition on line 1116 was always true
1117 try:
1118 if model_type == "chat":
1119 self._llm = ChatAnthropic(
1120 api_key=self.config.api_key,
1121 model=self.config.model,
1122 temperature=self.config.temperature,
1123 max_tokens=self.config.max_tokens,
1124 timeout=self.config.timeout,
1125 max_retries=self.config.max_retries,
1126 )
1127 elif model_type == "completion":  # branch 1127 ↛ 1136 not taken: condition on line 1127 was always true
1128 self._llm = AnthropicLLM(
1129 api_key=self.config.api_key,
1130 model=self.config.model,
1131 temperature=self.config.temperature,
1132 max_tokens=self.config.max_tokens,
1133 timeout=self.config.timeout,
1134 max_retries=self.config.max_retries,
1135 )
1136 logger.info("Anthropic LLM instance created successfully")
1137 except Exception as e:
1138 logger.error(f"Failed to create Anthropic LLM: {e}")
1139 raise
1141 return self._llm
1143 def get_model_name(self) -> str:
1144 """
1145 Get the Anthropic model name.
1147 Returns:
1148 str: The model name configured for this provider.
1150 Examples:
1151 >>> config = AnthropicConfig( # doctest: +SKIP
1152 ... api_key="sk-ant-...",
1153 ... model="claude-3-5-sonnet-20241022"
1154 ... )
1155 >>> provider = AnthropicProvider(config) # doctest: +SKIP
1156 >>> provider.get_model_name() # doctest: +SKIP
1157 'claude-3-5-sonnet-20241022'
1158 """
1159 return self.config.model
1162class AWSBedrockProvider:
1163 """
1164 AWS Bedrock provider implementation.
1166 Manages connection and interaction with AWS Bedrock LLM services.
1168 Attributes:
1169 config: AWS Bedrock configuration object.
1171 Examples:
1172 >>> config = AWSBedrockConfig( # doctest: +SKIP
1173 ... model_id="anthropic.claude-v2",
1174 ... region_name="us-east-1"
1175 ... )
1176 >>> provider = AWSBedrockProvider(config) # doctest: +SKIP
1177 >>> provider.get_model_name() # doctest: +SKIP
1178 'anthropic.claude-v2'
1180 Note:
1181 Requires langchain-aws package and boto3 to be installed.
1182 Uses AWS default credential chain if credentials not explicitly provided.
1183 """
1185 def __init__(self, config: AWSBedrockConfig):
1186 """
1187 Initialize AWS Bedrock provider.
1189 Args:
1190 config: AWS Bedrock configuration with model ID and settings.
1192 Raises:
1193 ImportError: If langchain-aws is not installed.
1195 Examples:
1196 >>> config = AWSBedrockConfig( # doctest: +SKIP
1197 ... model_id="anthropic.claude-v2",
1198 ... region_name="us-east-1"
1199 ... )
1200 >>> provider = AWSBedrockProvider(config) # doctest: +SKIP
1201 """
1202 if not _BEDROCK_AVAILABLE:
1203 raise ImportError("AWS Bedrock provider requires langchain-aws package. Install it with: pip install langchain-aws boto3")
1205 self.config = config
1206 self._llm = None
1207 logger.info(f"Initializing AWS Bedrock provider with model: {config.model_id}")
1209 def get_llm(self, model_type: str = "chat") -> Union[ChatBedrock, BedrockLLM]:
1210 """
1211 Get AWS Bedrock LLM instance with lazy initialization.
1213 Creates and caches the Bedrock chat model instance on first call.
1214 Subsequent calls return the cached instance.
1216 Args:
1217 model_type: LLM inference model type: 'chat' for a chat model or 'completion' for a text completion model.
1219 Returns:
1220 Union[ChatBedrock, BedrockLLM]: The configured AWS Bedrock model for the requested model_type.
1222 Raises:
1223 Exception: If LLM initialization fails (e.g., invalid credentials, permissions).
1225 Examples:
1226 >>> config = AWSBedrockConfig( # doctest: +SKIP
1227 ... model_id="anthropic.claude-v2",
1228 ... region_name="us-east-1"
1229 ... )
1230 >>> provider = AWSBedrockProvider(config) # doctest: +SKIP
1231 >>> # llm = provider.get_llm() # Returns ChatBedrock instance
1232 """
1233 if self._llm is None:  # branch 1233 ↛ 1269 not taken: condition on line 1233 was always true
1234 try:
1235 # Build credentials dict if provided
1236 credentials_kwargs = {}
1237 if self.config.aws_access_key_id:
1238 credentials_kwargs["aws_access_key_id"] = self.config.aws_access_key_id
1239 if self.config.aws_secret_access_key:
1240 credentials_kwargs["aws_secret_access_key"] = self.config.aws_secret_access_key
1241 if self.config.aws_session_token:
1242 credentials_kwargs["aws_session_token"] = self.config.aws_session_token
1244 if model_type == "chat":
1245 self._llm = ChatBedrock(
1246 model_id=self.config.model_id,
1247 region_name=self.config.region_name,
1248 model_kwargs={
1249 "temperature": self.config.temperature,
1250 "max_tokens": self.config.max_tokens,
1251 },
1252 **credentials_kwargs,
1253 )
1254 elif model_type == "completion":  # branch 1254 ↛ 1264 not taken: condition on line 1254 was always true
1255 self._llm = BedrockLLM(
1256 model_id=self.config.model_id,
1257 region_name=self.config.region_name,
1258 model_kwargs={
1259 "temperature": self.config.temperature,
1260 "max_tokens": self.config.max_tokens,
1261 },
1262 **credentials_kwargs,
1263 )
1264 logger.info("AWS Bedrock LLM instance created successfully")
1265 except Exception as e:
1266 logger.error(f"Failed to create AWS Bedrock LLM: {e}")
1267 raise
1269 return self._llm
1271 def get_model_name(self) -> str:
1272 """
1273 Get the AWS Bedrock model ID.
1275 Returns:
1276 str: The model ID configured for this provider.
1278 Examples:
1279 >>> config = AWSBedrockConfig( # doctest: +SKIP
1280 ... model_id="anthropic.claude-v2",
1281 ... region_name="us-east-1"
1282 ... )
1283 >>> provider = AWSBedrockProvider(config) # doctest: +SKIP
1284 >>> provider.get_model_name() # doctest: +SKIP
1285 'anthropic.claude-v2'
1286 """
1287 return self.config.model_id
1290class WatsonxProvider:
1291 """
1292 IBM watsonx.ai provider implementation.
1294 Manages connection and interaction with IBM watsonx.ai services.
1296 Attributes:
1297 config: IBM watsonx.ai configuration object.
1299 Examples:
1300 >>> config = WatsonxConfig( # doctest: +SKIP
1301 ... api_key="key",
1302 ... url="https://us-south.ml.cloud.ibm.com",
1303 ... project_id="project-id",
1304 ... model_id="ibm/granite-13b-chat-v2"
1305 ... )
1306 >>> provider = WatsonxProvider(config) # doctest: +SKIP
1307 >>> provider.get_model_name() # doctest: +SKIP
1308 'ibm/granite-13b-chat-v2'
1310 Note:
1311 Requires langchain-ibm package to be installed.
1312 """
1314 def __init__(self, config: WatsonxConfig):
1315 """
1316 Initialize IBM watsonx.ai provider.
1318 Args:
1319 config: IBM watsonx.ai configuration with credentials and settings.
1321 Raises:
1322 ImportError: If langchain-ibm is not installed.
1324 Examples:
1325 >>> config = WatsonxConfig( # doctest: +SKIP
1326 ... api_key="key",
1327 ... url="https://us-south.ml.cloud.ibm.com",
1328 ... project_id="project-id",
1329 ... model_id="ibm/granite-13b-chat-v2"
1330 ... )
1331 >>> provider = WatsonxProvider(config) # doctest: +SKIP
1332 """
1333 if not _WATSONX_AVAILABLE:
1334 raise ImportError("IBM watsonx.ai provider requires langchain-ibm package. Install it with: pip install langchain-ibm")
1335 self.config = config
1336 self.llm = None
1337 logger.info(f"Initializing IBM watsonx.ai provider with model {config.model_id}")
1339 def get_llm(self, model_type="chat") -> Union[WatsonxLLM, ChatWatsonx]:
1340 """
1341 Get IBM watsonx.ai LLM instance with lazy initialization.
1343 Creates and caches the watsonx LLM instance on first call.
1344 Subsequent calls return the cached instance.
1346 Args:
1347 model_type: LLM inference model type: 'chat' for a chat model or 'completion' for a text completion model.
1349 Returns:
1350 Union[ChatWatsonx, WatsonxLLM]: The configured IBM watsonx.ai model for the requested model_type.
1352 Raises:
1353 Exception: If LLM initialization fails (e.g., invalid credentials).
1355 Examples:
1356 >>> config = WatsonxConfig( # doctest: +SKIP
1357 ... api_key="key",
1358 ... url="https://us-south.ml.cloud.ibm.com",
1359 ... project_id="project-id",
1360 ... model_id="ibm/granite-13b-chat-v2"
1361 ... )
1362 >>> provider = WatsonxProvider(config) # doctest: +SKIP
1363 >>> # llm = provider.get_llm() # Returns WatsonxLLM instance
1364 """
1365 if self.llm is None:  # branch 1365 ↛ 1401 not taken: condition on line 1365 was always true
1366 try:
1367 # Build parameters dict
1368 params = {
1369 "decoding_method": self.config.decoding_method,
1370 "temperature": self.config.temperature,
1371 "max_new_tokens": self.config.max_new_tokens,
1372 "min_new_tokens": self.config.min_new_tokens,
1373 }
1375 if self.config.top_k is not None:  # branch 1375 ↛ 1377 not taken: condition on line 1375 was always true
1376 params["top_k"] = self.config.top_k
1377 if self.config.top_p is not None:  # branch 1377 ↛ 1379 not taken: condition on line 1377 was always true
1378 params["top_p"] = self.config.top_p
1379 if model_type == "completion":
1380 # Initialize WatsonxLLM
1381 self.llm = WatsonxLLM(
1382 apikey=self.config.api_key,
1383 url=self.config.url,
1384 project_id=self.config.project_id,
1385 model_id=self.config.model_id,
1386 params=params,
1387 )
1388 elif model_type == "chat":  # branch 1388 ↛ 1397 not taken: condition on line 1388 was always true
1389 # Initialize Chat WatsonxLLM
1390 self.llm = ChatWatsonx(
1391 apikey=self.config.api_key,
1392 url=self.config.url,
1393 project_id=self.config.project_id,
1394 model_id=self.config.model_id,
1395 params=params,
1396 )
1397 logger.info("IBM watsonx.ai LLM instance created successfully")
1398 except Exception as e:
1399 logger.error(f"Failed to create IBM watsonx.ai LLM: {e}")
1400 raise
1401 return self.llm
1403 def get_model_name(self) -> str:
1404 """
1405 Get the IBM watsonx.ai model ID.
1407 Returns:
1408 str: The model ID configured for this provider.
1410 Examples:
1411 >>> config = WatsonxConfig( # doctest: +SKIP
1412 ... api_key="key",
1413 ... url="https://us-south.ml.cloud.ibm.com",
1414 ... project_id="project-id",
1415 ... model_id="ibm/granite-13b-chat-v2"
1416 ... )
1417 >>> provider = WatsonxProvider(config) # doctest: +SKIP
1418 >>> provider.get_model_name() # doctest: +SKIP
1419 'ibm/granite-13b-chat-v2'
1420 """
1421 return self.config.model_id
1424class GatewayProvider:
1425 """
1426 Gateway provider implementation for using models configured in LLM Settings.
1428 Routes LLM requests through the gateway's configured providers, allowing
1429 users to use models set up via the Admin UI's LLM Settings without needing
1430 to configure credentials in environment variables or API requests.
1432 Attributes:
1433 config: Gateway configuration with model ID.
1434 llm: Lazily initialized LLM instance.
1436 Examples:
1437 >>> config = GatewayConfig(model="gpt-4o") # doctest: +SKIP
1438 >>> provider = GatewayProvider(config) # doctest: +SKIP
1439 >>> provider.get_model_name() # doctest: +SKIP
1440 'gpt-4o'
1442 Note:
1443 Requires models to be configured via Admin UI -> Settings -> LLM Settings.
1444 """
1446 def __init__(self, config: GatewayConfig):
1447 """
1448 Initialize Gateway provider.
1450 Args:
1451 config: Gateway configuration with model ID and optional settings.
1453 Examples:
1454 >>> config = GatewayConfig(model="gpt-4o") # doctest: +SKIP
1455 >>> provider = GatewayProvider(config) # doctest: +SKIP
1456 """
1457 self.config = config
1458 self.llm = None
1459 self._model_name: Optional[str] = None
1460 self._underlying_provider = None
1461 logger.info(f"Initializing Gateway provider with model: {config.model}")
1463 def get_llm(self, model_type: str = "chat") -> Union[BaseChatModel, Any]:
1464 """
1465 Get LLM instance by looking up model from gateway's LLM Settings.
1467 Fetches the model configuration from the database, decrypts API keys,
1468 and creates the appropriate LangChain LLM instance based on provider type.
1470 Args:
1471 model_type: Type of model to return ('chat' or 'completion'). Defaults to 'chat'.
1473 Returns:
1474 Union[BaseChatModel, Any]: Configured LangChain chat or completion model instance.
1476 Raises:
1477 ValueError: If model not found or provider not enabled.
1478 ImportError: If required provider package not installed.
1480 Examples:
1481 >>> config = GatewayConfig(model="gpt-4o") # doctest: +SKIP
1482 >>> provider = GatewayProvider(config) # doctest: +SKIP
1483 >>> llm = provider.get_llm() # doctest: +SKIP
1485 Note:
1486 The LLM instance is lazily initialized and cached; the model_type of the first call determines the cached instance.
1487 """
1488 if self.llm is not None:
1489 return self.llm
1491 # Import here to avoid circular imports
1492 # First-Party
1493 from mcpgateway.db import LLMModel, LLMProvider, SessionLocal # pylint: disable=import-outside-toplevel
1494 from mcpgateway.utils.services_auth import decode_auth # pylint: disable=import-outside-toplevel
1496 model_id = self.config.model
1498 with SessionLocal() as db:
1499 # Try to find model by UUID first, then by model_id
1500 model = db.query(LLMModel).filter(LLMModel.id == model_id).first()
1501 if not model:
1502 model = db.query(LLMModel).filter(LLMModel.model_id == model_id).first()
1504 if not model:
1505 raise ValueError(f"Model '{model_id}' not found in LLM Settings. Configure it via Admin UI -> Settings -> LLM Settings.")
1507 if not model.enabled:
1508 raise ValueError(f"Model '{model.model_id}' is disabled. Enable it in LLM Settings.")
1510 # Get the provider
1511 provider = db.query(LLMProvider).filter(LLMProvider.id == model.provider_id).first()
1512 if not provider:
1513 raise ValueError(f"Provider not found for model '{model.model_id}'")
1515 if not provider.enabled:
1516 raise ValueError(f"Provider '{provider.name}' is disabled. Enable it in LLM Settings.")
1518 # Get decrypted API key
1519 api_key = None
1520 if provider.api_key:
1521 auth_data = decode_auth(provider.api_key)
1522 if isinstance(auth_data, dict):
1523 api_key = auth_data.get("api_key")
1524 else:
1525 api_key = auth_data
1527 # Store model name for get_model_name()
1528 self._model_name = model.model_id
1530 # Get temperature - use config override or provider default
1531 temperature = self.config.temperature if self.config.temperature is not None else (provider.default_temperature or 0.7)
1532 max_tokens = self.config.max_tokens or model.max_output_tokens
1534 # Create appropriate LLM based on provider type
1535 provider_type = provider.provider_type.lower()
1536 config = provider.config or {}
1538 # Common kwargs
1539 kwargs: Dict[str, Any] = {
1540 "temperature": temperature,
1541 "timeout": self.config.timeout,
1542 }
1544 if provider_type == "openai":
1545 kwargs.update(
1546 {
1547 "api_key": api_key,
1548 "model": model.model_id,
1549 "max_tokens": max_tokens,
1550 }
1551 )
1552 if provider.api_base:
1553 kwargs["base_url"] = provider.api_base
1555 # Handle default headers
1556 if config.get("default_headers"):
1557 kwargs["default_headers"] = config["default_headers"]
1558 elif hasattr(self.config, "default_headers") and self.config.default_headers:  # type: ignore  # branch 1558 ↛ 1559 not taken: condition on line 1558 was never true
1559 kwargs["default_headers"] = self.config.default_headers
1561 if model_type == "chat":
1562 self.llm = ChatOpenAI(**kwargs)
1563 else:
1564 self.llm = OpenAI(**kwargs)
1566 elif provider_type == "azure_openai":
1567 if not provider.api_base:
1568 raise ValueError("Azure OpenAI requires base_url (azure_endpoint) to be configured")
1570 azure_deployment = config.get("azure_deployment", model.model_id)
1571 api_version = config.get("api_version", "2024-05-01-preview")
1572 max_retries = config.get("max_retries", 2)
1574 kwargs.update(
1575 {
1576 "api_key": api_key,
1577 "azure_endpoint": provider.api_base,
1578 "azure_deployment": azure_deployment,
1579 "api_version": api_version,
1580 "model": model.model_id,
1581 "max_tokens": int(max_tokens) if max_tokens is not None else None,
1582 "max_retries": max_retries,
1583 }
1584 )
1586 if model_type == "chat":
1587 self.llm = AzureChatOpenAI(**kwargs)
1588 else:
1589 self.llm = AzureOpenAI(**kwargs)
1591 elif provider_type == "anthropic":
1592 if not _ANTHROPIC_AVAILABLE:
1593 raise ImportError("Anthropic provider requires langchain-anthropic. Install with: pip install langchain-anthropic")
1595 # Anthropic uses 'model_name' instead of 'model'
1596 anthropic_kwargs = {
1597 "api_key": api_key,
1598 "model_name": model.model_id,
1599 "max_tokens": max_tokens or 4096,
1600 "temperature": temperature,
1601 "timeout": self.config.timeout,
1602 "default_request_timeout": self.config.timeout,
1603 }
1605 if model_type == "chat":
1606 self.llm = ChatAnthropic(**anthropic_kwargs)
1607 else:
1608 # Fall back to the generic Anthropic completion model; chat is the usual path
1609 if AnthropicLLM:
1610 llm_kwargs = anthropic_kwargs.copy()
1611 llm_kwargs["model"] = llm_kwargs.pop("model_name")
1612 self.llm = AnthropicLLM(**llm_kwargs)
1613 else:
1614 raise ImportError("Anthropic completion model (AnthropicLLM) not available")
1616 elif provider_type == "bedrock":
1617 if not _BEDROCK_AVAILABLE:
1618 raise ImportError("AWS Bedrock provider requires langchain-aws. Install with: pip install langchain-aws boto3")
1620 region_name = config.get("region_name", "us-east-1")
1621 credentials_kwargs = {}
1622 if config.get("aws_access_key_id"):
1623 credentials_kwargs["aws_access_key_id"] = config["aws_access_key_id"]
1624 if config.get("aws_secret_access_key"):
1625 credentials_kwargs["aws_secret_access_key"] = config["aws_secret_access_key"]
1626 if config.get("aws_session_token"):
1627 credentials_kwargs["aws_session_token"] = config["aws_session_token"]
1629 model_kwargs = {
1630 "temperature": temperature,
1631 "max_tokens": max_tokens or 4096,
1632 }
1634 if model_type == "chat":
1635 self.llm = ChatBedrock(
1636 model_id=model.model_id,
1637 region_name=region_name,
1638 model_kwargs=model_kwargs,
1639 **credentials_kwargs,
1640 )
1641 else:
1642 self.llm = BedrockLLM(
1643 model_id=model.model_id,
1644 region_name=region_name,
1645 model_kwargs=model_kwargs,
1646 **credentials_kwargs,
1647 )
1649 elif provider_type == "ollama":
1650 base_url = provider.api_base or "http://localhost:11434"
1651 num_ctx = config.get("num_ctx")
1653 # Explicitly construct kwargs to avoid generic unpacking issues with Pydantic models
1654 ollama_kwargs = {
1655 "base_url": base_url,
1656 "model": model.model_id,
1657 "temperature": temperature,
1658 "timeout": self.config.timeout,
1659 }
1660 if num_ctx:  # branch 1660 ↛ 1663 not taken: condition on line 1660 was always true
1661 ollama_kwargs["num_ctx"] = num_ctx
1663 if model_type == "chat":
1664 self.llm = ChatOllama(**ollama_kwargs)
1665 else:
1666 self.llm = OllamaLLM(**ollama_kwargs)
1668 elif provider_type == "watsonx":
1669 if not _WATSONX_AVAILABLE:
1670 raise ImportError("IBM watsonx.ai provider requires langchain-ibm. Install with: pip install langchain-ibm")
1672 project_id = config.get("project_id")
1673 if not project_id:
1674 raise ValueError("IBM watsonx.ai requires project_id in config")
1676 url = provider.api_base or "https://us-south.ml.cloud.ibm.com"
1678 params = {
1679 "temperature": temperature,
1680 "max_new_tokens": max_tokens or 1024,
1681 "min_new_tokens": config.get("min_new_tokens", 1),
1682 "decoding_method": config.get("decoding_method", "sample"),
1683 "top_k": config.get("top_k", 50),
1684 "top_p": config.get("top_p", 1.0),
1685 }
1687 if model_type == "chat":
1688 self.llm = ChatWatsonx(
1689 apikey=api_key,
1690 url=url,
1691 project_id=project_id,
1692 model_id=model.model_id,
1693 params=params,
1694 )
1695 else:
1696 self.llm = WatsonxLLM(
1697 apikey=api_key,
1698 url=url,
1699 project_id=project_id,
1700 model_id=model.model_id,
1701 params=params,
1702 )
1704 elif provider_type == "openai_compatible":
1705 if not provider.api_base:
1706 raise ValueError("OpenAI-compatible provider requires base_url to be configured")
1708 kwargs.update(
1709 {
1710 "api_key": api_key or "no-key-required",
1711 "model": model.model_id,
1712 "base_url": provider.api_base,
1713 "max_tokens": max_tokens,
1714 }
1715 )
1717 if model_type == "chat":
1718 self.llm = ChatOpenAI(**kwargs)
1719 else:
1720 self.llm = OpenAI(**kwargs)
1722 else:
1723 raise ValueError(f"Unsupported LLM provider: {provider_type}")
1725 logger.info(f"Gateway provider created LLM instance for model: {model.model_id} via {provider_type}")
1726 return self.llm
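# Illustrative sketch (not part of the service): the explicit-kwargs pattern used for the
# Ollama branch above, built by hand. The base URL, model tag, and timeout values are
# placeholder assumptions for a local Ollama instance.
def build_local_ollama_chat_model():
    example_kwargs = {
        "base_url": "http://localhost:11434",  # assumed local Ollama endpoint
        "model": "llama3",                     # assumed model tag
        "temperature": 0.2,
        "timeout": 60,
    }
    return ChatOllama(**example_kwargs)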
1728 def get_model_name(self) -> str:
1729 """
1730 Get the model name.
1732 Returns:
1733 str: The model name/ID.
1735 Examples:
1736 >>> config = GatewayConfig(model="gpt-4o") # doctest: +SKIP
1737 >>> provider = GatewayProvider(config) # doctest: +SKIP
1738 >>> provider.get_model_name() # doctest: +SKIP
1739 'gpt-4o'
1740 """
1741 return self._model_name or self.config.model
1744class LLMProviderFactory:
1745 """
1746 Factory for creating LLM providers.
1748 Implements the Factory pattern to instantiate the appropriate LLM provider
1749 based on configuration, abstracting away provider-specific initialization.
1751 Examples:
1752 >>> config = LLMConfig(
1753 ... provider="ollama",
1754 ... config=OllamaConfig(model="llama2")
1755 ... )
1756 >>> provider = LLMProviderFactory.create(config)
1757 >>> provider.get_model_name()
1758 'llama2'
1760 Note:
1761 Provider selection uses a fixed provider map, and type safety is
1762 enforced through the LLMConfig discriminated union.
1763 """
1765 @staticmethod
1766 def create(llm_config: LLMConfig) -> Union[AzureOpenAIProvider, OpenAIProvider, AnthropicProvider, AWSBedrockProvider, OllamaProvider, WatsonxProvider, GatewayProvider]:
1767 """
1768 Create an LLM provider based on configuration.
1770 Args:
1771 llm_config: LLM configuration specifying provider type and settings.
1773 Returns:
1774 Union[AzureOpenAIProvider, OpenAIProvider, AnthropicProvider, AWSBedrockProvider, OllamaProvider, WatsonxProvider, GatewayProvider]: Instantiated provider.
1776 Raises:
1777 ValueError: If provider type is not supported.
1778 ImportError: If required provider package is not installed.
1780 Examples:
1781 >>> # Create Azure OpenAI provider
1782 >>> config = LLMConfig(
1783 ... provider="azure_openai",
1784 ... config=AzureOpenAIConfig(
1785 ... api_key="key",
1786 ... azure_endpoint="https://example.com/",
1787 ... azure_deployment="gpt-4"
1788 ... )
1789 ... )
1790 >>> provider = LLMProviderFactory.create(config)
1791 >>> isinstance(provider, AzureOpenAIProvider)
1792 True
1794 >>> # Create OpenAI provider
1795 >>> config = LLMConfig(
1796 ... provider="openai",
1797 ... config=OpenAIConfig(
1798 ... api_key="sk-...",
1799 ... model="gpt-4"
1800 ... )
1801 ... )
1802 >>> provider = LLMProviderFactory.create(config)
1803 >>> isinstance(provider, OpenAIProvider)
1804 True
1806 >>> # Create Ollama provider
1807 >>> config = LLMConfig(
1808 ... provider="ollama",
1809 ... config=OllamaConfig(model="llama2")
1810 ... )
1811 >>> provider = LLMProviderFactory.create(config)
1812 >>> isinstance(provider, OllamaProvider)
1813 True
1814 """
1815 provider_map = {
1816 "azure_openai": AzureOpenAIProvider,
1817 "openai": OpenAIProvider,
1818 "anthropic": AnthropicProvider,
1819 "aws_bedrock": AWSBedrockProvider,
1820 "ollama": OllamaProvider,
1821 "watsonx": WatsonxProvider,
1822 "gateway": GatewayProvider,
1823 }
1825 provider_class = provider_map.get(llm_config.provider)
1827 if not provider_class:
1828 raise ValueError(f"Unsupported LLM provider: {llm_config.provider}. Supported providers: {list(provider_map.keys())}")
1830 logger.info(f"Creating LLM provider: {llm_config.provider}")
1831 return provider_class(llm_config.config)
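# Illustrative sketch: an unsupported provider name surfaces as a ValueError from the
# factory, so callers can report a configuration problem instead of failing later on.
def create_provider_or_log(llm_config: LLMConfig):
    try:
        return LLMProviderFactory.create(llm_config)
    except ValueError as exc:
        logger.error(f"Rejected LLM configuration: {exc}")
        raise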
1834# ==================== CHAT HISTORY MANAGER ====================
1837class ChatHistoryManager:
1838 """
1839 Centralized chat history management with Redis and in-memory fallback.
1841 Provides a unified interface for storing and retrieving chat histories across
1842 multiple workers using Redis, with automatic fallback to in-memory storage
1843 when Redis is not available.
1845 This class eliminates duplication between router and service layers by
1846 providing a single source of truth for all chat history operations.
1848 Attributes:
1849 redis_client: Optional Redis async client for distributed storage.
1850 max_messages: Maximum number of messages to retain per user.
1851 ttl: Time-to-live for Redis entries in seconds.
1852 _memory_store: In-memory dict fallback when Redis unavailable.
1854 Examples:
1855 >>> import asyncio
1856 >>> # Create manager without Redis (in-memory mode)
1857 >>> manager = ChatHistoryManager(redis_client=None, max_messages=50)
1858 >>> # asyncio.run(manager.save_history("user123", [{"role": "user", "content": "Hello"}]))
1859 >>> # history = asyncio.run(manager.get_history("user123"))
1860 >>> # len(history) >= 0 -> True
1863 Note:
1864 Thread-safe for Redis operations. In-memory mode suitable for
1865 single-worker deployments only.
1866 """
1868 def __init__(self, redis_client: Optional[Any] = None, max_messages: int = 50, ttl: int = 3600):
1869 """
1870 Initialize chat history manager.
1872 Args:
1873 redis_client: Optional Redis async client. If None, uses in-memory storage.
1874 max_messages: Maximum messages to retain per user (default: 50).
1875 ttl: Time-to-live for Redis entries in seconds (default: 3600).
1877 Examples:
1878 >>> manager = ChatHistoryManager(redis_client=None, max_messages=100)
1879 >>> manager.max_messages
1880 100
1881 >>> manager.ttl
1882 3600
1883 """
1884 self.redis_client = redis_client
1885 self.max_messages = max_messages
1886 self.ttl = ttl
1887 self._memory_store: Dict[str, List[Dict[str, str]]] = {}
1889 if redis_client:
1890 logger.info("ChatHistoryManager initialized with Redis backend")
1891 else:
1892 logger.info("ChatHistoryManager initialized with in-memory backend")
1894 def _history_key(self, user_id: str) -> str:
1895 """
1896 Generate Redis key for user's chat history.
1898 Args:
1899 user_id: User identifier.
1901 Returns:
1902 str: Redis key string.
1904 Examples:
1905 >>> manager = ChatHistoryManager()
1906 >>> manager._history_key("user123")
1907 'chat_history:user123'
1908 """
1909 return f"chat_history:{user_id}"
1911 async def get_history(self, user_id: str) -> List[Dict[str, str]]:
1912 """
1913 Retrieve chat history for a user.
1915 Fetches history from Redis if available, otherwise from in-memory store.
1917 Args:
1918 user_id: User identifier.
1920 Returns:
1921 List[Dict[str, str]]: List of message dictionaries with 'role' and 'content' keys.
1922 Returns empty list if no history exists.
1924 Examples:
1925 >>> import asyncio
1926 >>> manager = ChatHistoryManager()
1927 >>> # history = asyncio.run(manager.get_history("user123"))
1928 >>> # isinstance(history, list) -> True
1931 Note:
1932 Automatically handles JSON deserialization errors by returning empty list.
1933 """
1934 if self.redis_client:
1935 try:
1936 data = await self.redis_client.get(self._history_key(user_id))
1937 if not data:
1938 return []
1939 return orjson.loads(data)
1940 except orjson.JSONDecodeError:
1941 logger.warning(f"Failed to decode chat history for user {user_id}")
1942 return []
1943 except Exception as e:
1944 logger.error(f"Error retrieving chat history from Redis for user {user_id}: {e}")
1945 return []
1946 else:
1947 return self._memory_store.get(user_id, [])
1949 async def save_history(self, user_id: str, history: List[Dict[str, str]]) -> None:
1950 """
1951 Save chat history for a user.
1953 Stores history in Redis (with TTL) if available, otherwise in memory.
1954 Automatically trims history to max_messages before saving.
1956 Args:
1957 user_id: User identifier.
1958 history: List of message dictionaries to save.
1960 Examples:
1961 >>> import asyncio
1962 >>> manager = ChatHistoryManager(max_messages=50)
1963 >>> messages = [{"role": "user", "content": "Hello"}]
1964 >>> # asyncio.run(manager.save_history("user123", messages))
1966 Note:
1967 History is automatically trimmed to max_messages limit before storage.
1968 """
1969 # Trim history before saving
1970 trimmed = self._trim_messages(history)
1972 if self.redis_client:
1973 try:
1974 await self.redis_client.set(self._history_key(user_id), orjson.dumps(trimmed), ex=self.ttl)
1975 except Exception as e:
1976 logger.error(f"Error saving chat history to Redis for user {user_id}: {e}")
1977 else:
1978 self._memory_store[user_id] = trimmed
1980 async def append_message(self, user_id: str, role: str, content: str) -> None:
1981 """
1982 Append a single message to user's chat history.
1984 Convenience method that fetches current history, appends the message,
1985 trims if needed, and saves back.
1987 Args:
1988 user_id: User identifier.
1989 role: Message role ('user' or 'assistant').
1990 content: Message content text.
1992 Examples:
1993 >>> import asyncio
1994 >>> manager = ChatHistoryManager()
1995 >>> # asyncio.run(manager.append_message("user123", "user", "Hello!"))
1997 Note:
1998 This method performs a read-modify-write operation which may
1999 not be atomic in distributed environments.
2000 """
2001 history = await self.get_history(user_id)
2002 history.append({"role": role, "content": content})
2003 await self.save_history(user_id, history)
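# Design note (illustrative alternative, not how ChatHistoryManager stores data): a Redis
# list with RPUSH + LTRIM keeps each append atomic, avoiding the read-modify-write race
# mentioned above. The key prefix, client, and limits are assumptions.
async def append_message_atomic(redis_client, user_id: str, role: str, content: str, max_messages: int = 50, ttl: int = 3600) -> None:
    key = f"chat_history_list:{user_id}"
    await redis_client.rpush(key, orjson.dumps({"role": role, "content": content}))
    await redis_client.ltrim(key, -max_messages, -1)  # keep only the newest messages
    await redis_client.expire(key, ttl)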
2005 async def clear_history(self, user_id: str) -> None:
2006 """
2007 Clear all chat history for a user.
2009 Deletes history from Redis or memory store.
2011 Args:
2012 user_id: User identifier.
2014 Examples:
2015 >>> import asyncio
2016 >>> manager = ChatHistoryManager()
2017 >>> # asyncio.run(manager.clear_history("user123"))
2019 Note:
2020 This operation cannot be undone.
2021 """
2022 if self.redis_client:
2023 try:
2024 await self.redis_client.delete(self._history_key(user_id))
2025 except Exception as e:
2026 logger.error(f"Error clearing chat history from Redis for user {user_id}: {e}")
2027 else:
2028 self._memory_store.pop(user_id, None)
2030 def _trim_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
2031 """
2032 Trim message list to max_messages limit.
2034 Keeps the most recent messages up to max_messages count.
2036 Args:
2037 messages: List of message dictionaries.
2039 Returns:
2040 List[Dict[str, str]]: Trimmed message list.
2042 Examples:
2043 >>> manager = ChatHistoryManager(max_messages=2)
2044 >>> messages = [
2045 ... {"role": "user", "content": "1"},
2046 ... {"role": "assistant", "content": "2"},
2047 ... {"role": "user", "content": "3"}
2048 ... ]
2049 >>> trimmed = manager._trim_messages(messages)
2050 >>> len(trimmed)
2051 2
2052 >>> trimmed[0]["content"]
2053 '2'
2054 """
2055 if len(messages) > self.max_messages:
2056 return messages[-self.max_messages :]
2057 return messages
2059 async def get_langchain_messages(self, user_id: str) -> List[BaseMessage]:
2060 """
2061 Get chat history as LangChain message objects.
2063 Converts stored history dictionaries to LangChain HumanMessage and
2064 AIMessage objects for use with LangChain agents.
2066 Args:
2067 user_id: User identifier.
2069 Returns:
2070 List[BaseMessage]: List of LangChain message objects.
2072 Examples:
2073 >>> import asyncio
2074 >>> manager = ChatHistoryManager()
2075 >>> # messages = asyncio.run(manager.get_langchain_messages("user123"))
2076 >>> # isinstance(messages, list) -> True
2079 Note:
2080 Returns empty list if LangChain is not available or history is empty.
2081 """
2082 if not _LLMCHAT_AVAILABLE:
2083 return []
2085 history = await self.get_history(user_id)
2086 lc_messages = []
2088 for msg in history:
2089 role = msg.get("role")
2090 content = msg.get("content", "")
2092 if role == "user":
2093 lc_messages.append(HumanMessage(content=content))
2094 elif role == "assistant":  2094 ↛ 2088  line 2094 didn't jump to line 2088 because the condition on line 2094 was always true
2095 lc_messages.append(AIMessage(content=content))
2097 return lc_messages
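# Usage sketch (the user id is illustrative): persist a short exchange, then read it back
# as LangChain messages ready to seed an agent invocation.
async def history_roundtrip_example(manager: ChatHistoryManager):
    await manager.append_message("demo-user", "user", "Hello")
    await manager.append_message("demo-user", "assistant", "Hi there!")
    return await manager.get_langchain_messages("demo-user")  # [HumanMessage, AIMessage]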
2100# ==================== MCP CLIENT ====================
2103class MCPClient:
2104 """
2105 Manages MCP server connections and tool loading.
2107 Provides a high-level interface for connecting to MCP servers, retrieving
2108 available tools, and managing connection health. Supports multiple transport
2109 protocols including HTTP, SSE, and stdio.
2111 Attributes:
2112 config: MCP server configuration.
2114 Examples:
2115 >>> import asyncio
2116 >>> config = MCPServerConfig(
2117 ... url="https://mcp-server.example.com/mcp",
2118 ... transport="streamable_http"
2119 ... )
2120 >>> client = MCPClient(config)
2121 >>> client.is_connected
2122 False
2123 >>> # asyncio.run(client.connect())
2124 >>> # tools = asyncio.run(client.get_tools())
2126 Note:
2127 All methods are async and should be called using asyncio or within
2128 an async context.
2129 """
2131 def __init__(self, config: MCPServerConfig):
2132 """
2133 Initialize MCP client.
2135 Args:
2136 config: MCP server configuration with connection parameters.
2138 Examples:
2139 >>> config = MCPServerConfig(
2140 ... url="https://example.com/mcp",
2141 ... transport="streamable_http"
2142 ... )
2143 >>> client = MCPClient(config)
2144 >>> client.config.transport
2145 'streamable_http'
2146 """
2147 self.config = config
2148 self._client: Optional[MultiServerMCPClient] = None
2149 self._tools: Optional[List[BaseTool]] = None
2150 self._connected = False
2151 logger.info(f"MCP client initialized with transport: {config.transport}")
2153 async def connect(self) -> None:
2154 """
2155 Connect to the MCP server.
2157 Establishes connection to the configured MCP server using the specified
2158 transport protocol. Subsequent calls are no-ops if already connected.
2160 Raises:
2161 ConnectionError: If connection to MCP server fails.
2163 Examples:
2164 >>> import asyncio
2165 >>> config = MCPServerConfig(
2166 ... url="https://example.com/mcp",
2167 ... transport="streamable_http"
2168 ... )
2169 >>> client = MCPClient(config)
2170 >>> # asyncio.run(client.connect())
2171 >>> # client.is_connected -> True
2173 Note:
2174 Connection is idempotent - calling multiple times is safe.
2175 """
2176 if self._connected:
2177 logger.warning("MCP client already connected")
2178 return
2180 try:
2181 logger.info(f"Connecting to MCP server via {self.config.transport}...")
2183 # Build server configuration for MultiServerMCPClient
2184 server_config = {
2185 "transport": self.config.transport,
2186 }
2188 if self.config.transport in ["streamable_http", "sse"]:
2189 server_config["url"] = self.config.url
2190 if self.config.headers:
2191 server_config["headers"] = self.config.headers
2192 elif self.config.transport == "stdio":  2192 ↛ 2197  line 2192 didn't jump to line 2197 because the condition on line 2192 was always true
2193 server_config["command"] = self.config.command
2194 if self.config.args:  2194 ↛ 2197  line 2194 didn't jump to line 2197 because the condition on line 2194 was always true
2195 server_config["args"] = self.config.args
2197 if not MultiServerMCPClient:
2198 raise ImportError("MCP client dependencies are missing. Install them with: pip install '.[llmchat]'")
2200 # Create MultiServerMCPClient with single server
2201 self._client = MultiServerMCPClient({"default": server_config})
2202 self._connected = True
2203 logger.info("Successfully connected to MCP server")
2205 except Exception as e:
2206 logger.error(f"Failed to connect to MCP server: {e}")
2207 self._connected = False
2208 raise ConnectionError(f"Failed to connect to MCP server: {e}") from e
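# Illustrative configurations matching the two transport branches above; the URL, bearer
# token, and stdio command/args are placeholders, not values the gateway requires.
example_http_config = MCPServerConfig(
    url="https://example.com/mcp",
    transport="streamable_http",
    headers={"Authorization": "Bearer <token>"},
)
example_stdio_config = MCPServerConfig(
    transport="stdio",
    command="python",
    args=["-m", "my_mcp_server"],
)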
2210 async def disconnect(self) -> None:
2211 """
2212 Disconnect from the MCP server.
2214 Cleanly closes the connection and releases resources. Safe to call
2215 even if not connected.
2217 Raises:
2218 Exception: If cleanup operations fail.
2220 Examples:
2221 >>> import asyncio
2222 >>> config = MCPServerConfig(
2223 ... url="https://example.com/mcp",
2224 ... transport="streamable_http"
2225 ... )
2226 >>> client = MCPClient(config)
2227 >>> # asyncio.run(client.connect())
2228 >>> # asyncio.run(client.disconnect())
2229 >>> # client.is_connected -> False
2231 Note:
2232 Clears cached tools upon disconnection.
2233 """
2234 if not self._connected:
2235 logger.warning("MCP client not connected")
2236 return
2238 try:
2239 if self._client:  2239 ↛ 2243  line 2239 didn't jump to line 2243 because the condition on line 2239 was always true
2240 # MultiServerMCPClient manages connections internally
2241 self._client = None
2243 self._connected = False
2244 self._tools = None
2245 logger.info("Disconnected from MCP server")
2247 except Exception as e:
2248 logger.error(f"Error during disconnect: {e}")
2249 raise
2251 async def get_tools(self, force_reload: bool = False) -> List[BaseTool]:
2252 """
2253 Get tools from the MCP server.
2255 Retrieves available tools from the connected MCP server. Results are
2256 cached unless force_reload is True.
2258 Args:
2259 force_reload: Force reload tools even if cached (default: False).
2261 Returns:
2262 List[BaseTool]: List of available tools from the server.
2264 Raises:
2265 ConnectionError: If not connected to MCP server.
2266 Exception: If tool loading fails.
2268 Examples:
2269 >>> import asyncio
2270 >>> config = MCPServerConfig(
2271 ... url="https://example.com/mcp",
2272 ... transport="streamable_http"
2273 ... )
2274 >>> client = MCPClient(config)
2275 >>> # asyncio.run(client.connect())
2276 >>> # tools = asyncio.run(client.get_tools())
2277 >>> # len(tools) >= 0 -> True
2279 Note:
2280 Tools are cached after first successful load for performance.
2281 """
2282 if not self._connected or not self._client:
2283 raise ConnectionError("Not connected to MCP server. Call connect() first.")
2285 if self._tools and not force_reload:
2286 logger.debug(f"Returning {len(self._tools)} cached tools")
2287 return self._tools
2289 try:
2290 logger.info("Loading tools from MCP server...")
2291 self._tools = await self._client.get_tools()
2292 logger.info(f"Successfully loaded {len(self._tools)} tools")
2293 return self._tools
2295 except Exception as e:
2296 logger.error(f"Failed to load tools: {e}")
2297 raise
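# Usage sketch (the server config is assumed to be built elsewhere): connect, list the
# exposed tool names, and always disconnect. get_tools() serves cached results until
# force_reload=True is passed.
async def list_tool_names(config: MCPServerConfig) -> List[str]:
    client = MCPClient(config)
    await client.connect()
    try:
        tools = await client.get_tools()
        return [tool.name for tool in tools]
    finally:
        await client.disconnect()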
2299 @property
2300 def is_connected(self) -> bool:
2301 """
2302 Check if client is connected.
2304 Returns:
2305 bool: True if connected to MCP server, False otherwise.
2307 Examples:
2308 >>> config = MCPServerConfig(
2309 ... url="https://example.com/mcp",
2310 ... transport="streamable_http"
2311 ... )
2312 >>> client = MCPClient(config)
2313 >>> client.is_connected
2314 False
2315 """
2316 return self._connected
2319# ==================== MCP CHAT SERVICE ====================
2322class MCPChatService:
2323 """
2324 Main chat service for MCP client backend.
2325 Orchestrates chat sessions with LLM and MCP server integration.
2327 Provides a high-level interface for managing conversational AI sessions
2328 that combine LLM capabilities with MCP server tools. Handles conversation
2329 history management, tool execution, and streaming responses.
2331 This service integrates:
2332 - LLM providers (Azure OpenAI, OpenAI, Anthropic, AWS Bedrock, Ollama, IBM watsonx.ai, and gateway-managed models)
2333 - MCP server tools
2334 - Centralized chat history management (Redis or in-memory)
2335 - Streaming and non-streaming response modes
2337 Attributes:
2338 config: Complete MCP client configuration.
2339 user_id: Optional user identifier for history management.
2341 Examples:
2342 >>> import asyncio
2343 >>> config = MCPClientConfig(
2344 ... mcp_server=MCPServerConfig(
2345 ... url="https://example.com/mcp",
2346 ... transport="streamable_http"
2347 ... ),
2348 ... llm=LLMConfig(
2349 ... provider="ollama",
2350 ... config=OllamaConfig(model="llama2")
2351 ... )
2352 ... )
2353 >>> service = MCPChatService(config)
2354 >>> service.is_initialized
2355 False
2356 >>> # asyncio.run(service.initialize())
2358 Note:
2359 Must call initialize() before using chat methods.
2360 """
2362 def __init__(self, config: MCPClientConfig, user_id: Optional[str] = None, redis_client: Optional[Any] = None):
2363 """
2364 Initialize MCP chat service.
2366 Args:
2367 config: Complete MCP client configuration.
2368 user_id: Optional user identifier for chat history management.
2369 redis_client: Optional Redis client for distributed history storage.
2371 Examples:
2372 >>> config = MCPClientConfig(
2373 ... mcp_server=MCPServerConfig(
2374 ... url="https://example.com/mcp",
2375 ... transport="streamable_http"
2376 ... ),
2377 ... llm=LLMConfig(
2378 ... provider="ollama",
2379 ... config=OllamaConfig(model="llama2")
2380 ... )
2381 ... )
2382 >>> service = MCPChatService(config, user_id="user123")
2383 >>> service.user_id
2384 'user123'
2385 """
2386 self.config = config
2387 self.user_id = user_id
2388 self.mcp_client = MCPClient(config.mcp_server)
2389 self.llm_provider = LLMProviderFactory.create(config.llm)
2391 # Initialize centralized chat history manager
2392 self.history_manager = ChatHistoryManager(redis_client=redis_client, max_messages=config.chat_history_max_messages, ttl=settings.llmchat_chat_history_ttl)
2394 self._agent = None
2395 self._initialized = False
2396 self._tools: List[BaseTool] = []
2398 logger.info(f"MCPChatService initialized for user: {user_id or 'anonymous'}")
2400 async def initialize(self) -> None:
2401 """
2402 Initialize the chat service.
2404 Connects to MCP server, loads tools, initializes LLM, and creates the
2405 conversational agent. Must be called before using chat functionality.
2407 Raises:
2408 ConnectionError: If MCP server connection fails.
2409 Exception: If initialization fails.
2411 Examples:
2412 >>> import asyncio
2413 >>> config = MCPClientConfig(
2414 ... mcp_server=MCPServerConfig(
2415 ... url="https://example.com/mcp",
2416 ... transport="streamable_http"
2417 ... ),
2418 ... llm=LLMConfig(
2419 ... provider="ollama",
2420 ... config=OllamaConfig(model="llama2")
2421 ... )
2422 ... )
2423 >>> service = MCPChatService(config)
2424 >>> # asyncio.run(service.initialize())
2425 >>> # service.is_initialized -> True
2427 Note:
2428 Automatically loads tools from MCP server and creates agent.
2429 """
2430 if self._initialized:
2431 logger.warning("Chat service already initialized")
2432 return
2434 try:
2435 logger.info("Initializing chat service...")
2437 # Connect to MCP server and load tools
2438 await self.mcp_client.connect()
2439 self._tools = await self.mcp_client.get_tools()
2441 # Create LLM instance
2442 llm = self.llm_provider.get_llm()
2444 # Create ReAct agent with tools
2445 self._agent = create_react_agent(llm, self._tools)
2447 self._initialized = True
2448 logger.info(f"Chat service initialized successfully with {len(self._tools)} tools")
2450 except Exception as e:
2451 logger.error(f"Failed to initialize chat service: {e}")
2452 self._initialized = False
2453 raise
2455 async def chat(self, message: str) -> str:
2456 """
2457 Send a message and get a complete response.
2459 Processes the user's message through the LLM with tool access,
2460 manages conversation history, and returns the complete response.
2462 Args:
2463 message: User's message text.
2465 Returns:
2466 str: Complete AI response text.
2468 Raises:
2469 RuntimeError: If service not initialized.
2470 ValueError: If message is empty.
2471 Exception: If processing fails.
2473 Examples:
2474 >>> import asyncio
2475 >>> # Assuming service is initialized
2476 >>> # response = asyncio.run(service.chat("Hello!"))
2477 >>> # isinstance(response, str) -> True
2480 Note:
2481 Automatically saves conversation history after response.
2482 """
2483 if not self._initialized or not self._agent:
2484 raise RuntimeError("Chat service not initialized. Call initialize() first.")
2486 if not message or not message.strip():
2487 raise ValueError("Message cannot be empty")
2489 try:
2490 logger.debug("Processing chat message...")
2492 # Get conversation history from manager
2493 lc_messages = await self.history_manager.get_langchain_messages(self.user_id) if self.user_id else []
2495 # Add user message
2496 user_message = HumanMessage(content=message)
2497 lc_messages.append(user_message)
2499 # Invoke agent
2500 response = await self._agent.ainvoke({"messages": lc_messages})
2502 # Extract AI response
2503 ai_message = response["messages"][-1]
2504 response_text = ai_message.content if hasattr(ai_message, "content") else str(ai_message)
2506 # Save history if user_id provided
2507 if self.user_id:
2508 await self.history_manager.append_message(self.user_id, "user", message)
2509 await self.history_manager.append_message(self.user_id, "assistant", response_text)
2511 logger.debug("Chat message processed successfully")
2512 return response_text
2514 except Exception as e:
2515 logger.error(f"Error processing chat message: {e}")
2516 raise
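# Usage sketch: when a user_id is set, consecutive chat() calls share history through the
# ChatHistoryManager, so a later turn can refer back to an earlier one.
async def two_turn_example(service: MCPChatService) -> str:
    await service.chat("My name is Ada.")
    return await service.chat("What is my name?")  # history carries the earlier turn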
2518 async def chat_with_metadata(self, message: str) -> Dict[str, Any]:
2519 """
2520 Send a message and get response with metadata.
2522 Similar to chat() but collects all events and returns detailed
2523 information about tool usage and timing.
2525 Args:
2526 message: User's message text.
2528 Returns:
2529 Dict[str, Any]: Dictionary containing:
2530 - text (str): Complete response text
2531 - tool_used (bool): Whether any tools were invoked
2532 - tools (List[str]): Names of tools that were used
2533 - tool_invocations (List[dict]): Detailed tool invocation data
2534 - elapsed_ms (int): Processing time in milliseconds
2536 Raises:
2537 RuntimeError: If service not initialized.
2538 ValueError: If message is empty.
2540 Examples:
2541 >>> import asyncio
2542 >>> # Assuming service is initialized
2543 >>> # result = asyncio.run(service.chat_with_metadata("What's 2+2?"))
2544 >>> # 'text' in result and 'elapsed_ms' in result -> True
2547 Note:
2548 This method collects all events and returns them as a single response.
2549 """
2550 text = ""
2551 tool_invocations: list[dict[str, Any]] = []
2552 final: dict[str, Any] = {}
2554 async for ev in self.chat_events(message):
2555 t = ev.get("type")
2556 if t == "token":
2557 text += ev.get("content", "")
2558 elif t in ("tool_start", "tool_end", "tool_error"):
2559 tool_invocations.append(ev)
2560 elif t == "final":  2560 ↛ 2554  line 2560 didn't jump to line 2554 because the condition on line 2560 was always true
2561 final = ev
2563 return {
2564 "text": text,
2565 "tool_used": final.get("tool_used", False),
2566 "tools": final.get("tools", []),
2567 "tool_invocations": tool_invocations,
2568 "elapsed_ms": final.get("elapsed_ms"),
2569 }
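# Usage sketch: inspect tool usage and latency from the aggregated result before
# returning the text to the caller.
async def chat_and_log_example(service: MCPChatService, message: str) -> str:
    result = await service.chat_with_metadata(message)
    if result["tool_used"]:
        logger.info(f"Tools invoked: {result['tools']} in {result['elapsed_ms']} ms")
    return result["text"]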
2571 async def chat_stream(self, message: str) -> AsyncGenerator[str, None]:
2572 """
2573 Send a message and stream the response.
2575 Yields response chunks as they're generated, enabling real-time display
2576 of the AI's response.
2578 Args:
2579 message: User's message text.
2581 Yields:
2582 str: Chunks of AI response text.
2584 Raises:
2585 RuntimeError: If service not initialized.
2586 Exception: If streaming fails.
2588 Examples:
2589 >>> import asyncio
2590 >>> async def stream_example():
2591 ... # Assuming service is initialized
2592 ... chunks = []
2593 ... async for chunk in service.chat_stream("Hello"):
2594 ... chunks.append(chunk)
2595 ... return ''.join(chunks)
2596 >>> # full_response = asyncio.run(stream_example())
2598 Note:
2599 Falls back to non-streaming if enable_streaming is False in config.
2600 """
2601 if not self._initialized or not self._agent:
2602 raise RuntimeError("Chat service not initialized. Call initialize() first.")
2604 if not self.config.enable_streaming:
2605 # Fall back to non-streaming
2606 response = await self.chat(message)
2607 yield response
2608 return
2610 try:
2611 logger.debug("Processing streaming chat message...")
2613 # Get conversation history
2614 lc_messages = await self.history_manager.get_langchain_messages(self.user_id) if self.user_id else []
2616 # Add user message
2617 user_message = HumanMessage(content=message)
2618 lc_messages.append(user_message)
2620 # Stream agent response
2621 full_response = ""
2622 async for event in self._agent.astream_events({"messages": lc_messages}, version="v2"):
2623 kind = event["event"]
2625 # Stream LLM tokens
2626 if kind == "on_chat_model_stream":  2626 ↛ 2622  line 2626 didn't jump to line 2622 because the condition on line 2626 was always true
2627 chunk = event.get("data", {}).get("chunk")
2628 if chunk and hasattr(chunk, "content"):  2628 ↛ 2622  line 2628 didn't jump to line 2622 because the condition on line 2628 was always true
2629 content = chunk.content
2630 if content:  2630 ↛ 2622  line 2630 didn't jump to line 2622 because the condition on line 2630 was always true
2631 full_response += content
2632 yield content
2634 # Save history
2635 if self.user_id and full_response: 2635 ↛ 2639line 2635 didn't jump to line 2639 because the condition on line 2635 was always true
2636 await self.history_manager.append_message(self.user_id, "user", message)
2637 await self.history_manager.append_message(self.user_id, "assistant", full_response)
2639 logger.debug("Streaming chat message processed successfully")
2641 except Exception as e:
2642 logger.error(f"Error processing streaming chat message: {e}")
2643 raise
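# Usage sketch: consume the stream incrementally, e.g. to print tokens as they arrive or
# to feed a server-sent-events response; printing to stdout here is illustrative only.
async def stream_to_stdout_example(service: MCPChatService, message: str) -> str:
    full = ""
    async for chunk in service.chat_stream(message):
        full += chunk
        print(chunk, end="", flush=True)
    return full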
2645 async def chat_events(self, message: str) -> AsyncGenerator[Dict[str, Any], None]:
2646 """
2647 Stream structured events during chat processing.
2649 Provides granular visibility into the chat processing pipeline by yielding
2650 structured events for tokens, tool invocations, errors, and final results.
2652 Args:
2653 message: User's message text.
2655 Yields:
2656 dict: Event dictionaries with type-specific fields:
2657 - token: {"type": "token", "content": str}
2658 - tool_start: {"type": "tool_start", "id": str, "name": str,
2659 "input": Any, "start": str}
2660 - tool_end: {"type": "tool_end", "id": str, "name": str,
2661 "output": Any, "end": str}
2662 - tool_error: {"type": "tool_error", "id": str, "error": str,
2663 "time": str}
2664 - final: {"type": "final", "content": str, "tool_used": bool,
2665 "tools": List[str], "elapsed_ms": int}
2667 Raises:
2668 RuntimeError: If service not initialized.
2669 ValueError: If message is empty or whitespace only.
2671 Examples:
2672 >>> import asyncio
2673 >>> async def event_example():
2674 ... # Assuming service is initialized
2675 ... events = []
2676 ... async for event in service.chat_events("Hello"):
2677 ... events.append(event['type'])
2678 ... return events
2679 >>> # event_types = asyncio.run(event_example())
2680 >>> # 'final' in event_types -> True
2682 Note:
2683 This is the most detailed chat method, suitable for building
2684 interactive UIs or detailed logging systems.
2685 """
2686 if not self._initialized or not self._agent:
2687 raise RuntimeError("Chat service not initialized. Call initialize() first.")
2689 # Validate message
2690 if not message or not message.strip():
2691 raise ValueError("Message cannot be empty")
2693 # Get conversation history
2694 lc_messages = await self.history_manager.get_langchain_messages(self.user_id) if self.user_id else []
2696 # Append user message
2697 user_message = HumanMessage(content=message)
2698 lc_messages.append(user_message)
2700 full_response = ""
2701 start_ts = time.time()
2702 tool_runs: dict[str, dict[str, Any]] = {}
2703 # Buffer for out-of-order on_tool_end events (end arrives before start)
2704 pending_tool_ends: dict[str, dict[str, Any]] = {}
2705 pending_ttl_seconds = 30.0 # Max time to hold pending end events
2706 pending_max_size = 100 # Max number of pending end events to buffer
2707 # Track dropped run_ids for aggregated error (TTL-expired or buffer-full)
2708 dropped_tool_ends: set[str] = set()
2709 dropped_max_size = 200 # Max dropped IDs to track (prevents unbounded growth)
2710 dropped_overflow_count = 0 # Count of drops that couldn't be tracked due to full buffer
2712 def _extract_output(raw_output: Any) -> Any:
2713 """Extract output value from various LangChain output formats.
2715 Args:
2716 raw_output: The raw output from a tool execution.
2718 Returns:
2719 The extracted output value in a serializable format.
2720 """
2721 if hasattr(raw_output, "content"):
2722 return raw_output.content
2723 if hasattr(raw_output, "dict") and callable(raw_output.dict):
2724 return raw_output.dict()
2725 if not isinstance(raw_output, (str, int, float, bool, list, dict, type(None))):
2726 return str(raw_output)
2727 return raw_output
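# Illustrative behaviour of _extract_output (values are examples, not executed tests):
#   _extract_output("plain text")             -> "plain text"                    (JSON-safe passthrough)
#   _extract_output(AIMessage(content="42"))  -> "42"                            (content attribute wins)
#   _extract_output(object())                 -> "<object object at 0x...>"      (stringified fallback)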
2729 def _cleanup_expired_pending(current_ts: float) -> None:
2730 """Remove expired entries from pending_tool_ends buffer and track them.
2732 Args:
2733 current_ts: Current timestamp in seconds since epoch.
2734 """
2735 nonlocal dropped_overflow_count
2736 expired = [rid for rid, data in pending_tool_ends.items() if current_ts - data.get("buffered_at", 0) > pending_ttl_seconds]
2737 for rid in expired:
2738 logger.warning(f"Pending on_tool_end for run_id {rid} expired after {pending_ttl_seconds}s (orphan event)")
2739 if len(dropped_tool_ends) < dropped_max_size:
2740 dropped_tool_ends.add(rid)
2741 else:
2742 dropped_overflow_count += 1
2743 logger.warning(f"Dropped tool ends tracking full ({dropped_max_size}), cannot track expired run_id {rid} (overflow count: {dropped_overflow_count})")
2744 del pending_tool_ends[rid]
2746 try:
2747 async for event in self._agent.astream_events({"messages": lc_messages}, version="v2"):
2748 kind = event.get("event")
2749 now_iso = datetime.now(timezone.utc).isoformat()
2750 now_ts = time.time()
2752 # Periodically cleanup expired pending ends
2753 _cleanup_expired_pending(now_ts)
2755 try:
2756 if kind == "on_tool_start":
2757 run_id = str(event.get("run_id") or uuid4())
2758 name = event.get("name") or event.get("data", {}).get("name") or event.get("data", {}).get("tool")
2759 input_data = event.get("data", {}).get("input")
2761 # Filter out common metadata keys injected by LangChain/LangGraph
2762 if isinstance(input_data, dict):  2762 ↛ 2765  line 2762 didn't jump to line 2765 because the condition on line 2762 was always true
2763 input_data = {k: v for k, v in input_data.items() if k not in ["runtime", "config", "run_manager", "callbacks"]}
2765 tool_runs[run_id] = {"name": name, "start": now_iso, "input": input_data}
2767 # Register run for cancellation tracking with gateway-level Cancellation service
2768 async def _noop_cancel_cb(reason: Optional[str]) -> None:
2769 """
2770 No-op cancel callback used when a run is started.
2772 Args:
2773 reason: Optional textual reason for cancellation.
2775 Returns:
2776 None
2777 """
2778 # Default no-op; kept for potential future intra-process cancellation
2779 return None
2781 # Register with cancellation service only if feature is enabled
2782 if settings.mcpgateway_tool_cancellation_enabled:
2783 try:
2784 await cancellation_service.register_run(run_id, name=name, cancel_callback=_noop_cancel_cb)
2785 except Exception:
2786 logger.exception("Failed to register run %s with CancellationService", run_id)
2788 yield {"type": "tool_start", "id": run_id, "tool": name, "input": input_data, "start": now_iso}
2790 # NOTE: Do NOT clear from dropped_tool_ends here. If an end was dropped (TTL/buffer-full)
2791 # before this start arrived, that end is permanently lost. Since tools only end once,
2792 # we won't receive another end event, so this should still be reported as an orphan.
2794 # Check if we have a buffered end event for this run_id (out-of-order reconciliation)
2795 if run_id in pending_tool_ends:
2796 buffered = pending_tool_ends.pop(run_id)
2797 tool_runs[run_id]["end"] = buffered["end_time"]
2798 tool_runs[run_id]["output"] = buffered["output"]
2799 logger.info(f"Reconciled out-of-order on_tool_end for run_id {run_id}")
2801 if tool_runs[run_id].get("output") == "":
2802 error = "Tool execution failed: Please check if the tool is accessible"
2803 yield {"type": "tool_error", "id": run_id, "tool": name, "error": error, "time": buffered["end_time"]}
2805 yield {"type": "tool_end", "id": run_id, "tool": name, "output": tool_runs[run_id].get("output"), "end": buffered["end_time"]}
2807 elif kind == "on_tool_end":
2808 run_id = str(event.get("run_id") or uuid4())
2809 output = event.get("data", {}).get("output")
2810 extracted_output = _extract_output(output)
2812 if run_id in tool_runs:
2813 # Normal case: start already received
2814 tool_runs[run_id]["end"] = now_iso
2815 tool_runs[run_id]["output"] = extracted_output
2817 if tool_runs[run_id].get("output") == "":
2818 error = "Tool execution failed: Please check if the tool is accessible"
2819 yield {"type": "tool_error", "id": run_id, "tool": tool_runs.get(run_id, {}).get("name"), "error": error, "time": now_iso}
2821 yield {"type": "tool_end", "id": run_id, "tool": tool_runs.get(run_id, {}).get("name"), "output": tool_runs[run_id].get("output"), "end": now_iso}
2822 else:
2823 # Out-of-order: buffer the end event for later reconciliation
2824 if len(pending_tool_ends) < pending_max_size:
2825 pending_tool_ends[run_id] = {"output": extracted_output, "end_time": now_iso, "buffered_at": now_ts}
2826 logger.debug(f"Buffered out-of-order on_tool_end for run_id {run_id}, awaiting on_tool_start")
2827 else:
2828 logger.warning(f"Pending tool ends buffer full ({pending_max_size}), dropping on_tool_end for run_id {run_id}")
2829 if len(dropped_tool_ends) < dropped_max_size:
2830 dropped_tool_ends.add(run_id)
2831 else:
2832 dropped_overflow_count += 1
2833 logger.warning(f"Dropped tool ends tracking full ({dropped_max_size}), cannot track run_id {run_id} (overflow count: {dropped_overflow_count})")
2835 # Unregister run from cancellation service when finished (only if feature is enabled)
2836 if settings.mcpgateway_tool_cancellation_enabled:
2837 try:
2838 await cancellation_service.unregister_run(run_id)
2839 except Exception:
2840 logger.exception("Failed to unregister run %s", run_id)
2842 elif kind == "on_tool_error":
2843 run_id = str(event.get("run_id") or uuid4())
2844 error = str(event.get("data", {}).get("error", "Unknown error"))
2846 # Clear any buffered end for this run to avoid emitting both error and end
2847 if run_id in pending_tool_ends:
2848 del pending_tool_ends[run_id]
2849 logger.debug(f"Cleared buffered on_tool_end for run_id {run_id} due to tool error")
2851 # Clear from dropped set if this run was previously dropped (prevents false orphan)
2852 dropped_tool_ends.discard(run_id)
2854 yield {"type": "tool_error", "id": run_id, "tool": tool_runs.get(run_id, {}).get("name"), "error": error, "time": now_iso}
2856 # Unregister run on error (only if feature is enabled)
2857 if settings.mcpgateway_tool_cancellation_enabled:  2857 ↛ 2747  line 2857 didn't jump to line 2747 because the condition on line 2857 was always true
2858 try:
2859 await cancellation_service.unregister_run(run_id)
2860 except Exception:
2861 logger.exception("Failed to unregister run %s after error", run_id)
2863 elif kind == "on_chat_model_stream":  2863 ↛ 2747  line 2863 didn't jump to line 2747 because the condition on line 2863 was always true
2864 chunk = event.get("data", {}).get("chunk")
2865 if chunk and hasattr(chunk, "content"):  2865 ↛ 2747  line 2865 didn't jump to line 2747 because the condition on line 2865 was always true
2866 content = chunk.content
2867 if content:  2867 ↛ 2747  line 2867 didn't jump to line 2747 because the condition on line 2867 was always true
2868 full_response += content
2869 yield {"type": "token", "content": content}
2871 except Exception as event_error:
2872 logger.warning(f"Error processing event {kind}: {event_error}")
2873 continue
2875 # Emit aggregated error for any orphan/dropped tool ends
2876 # De-duplicate IDs (in case same ID was buffered and dropped in edge cases)
2877 all_orphan_ids = sorted(set(pending_tool_ends.keys()) | dropped_tool_ends)
2878 if all_orphan_ids or dropped_overflow_count > 0:
2879 buffered_count = len(pending_tool_ends)
2880 dropped_count = len(dropped_tool_ends)
2881 total_unique = len(all_orphan_ids)
2882 total_affected = total_unique + dropped_overflow_count
2883 logger.warning(
2884 f"Stream completed with {total_affected} orphan tool end(s): {buffered_count} buffered, {dropped_count} dropped (tracked), {dropped_overflow_count} dropped (untracked overflow)"
2885 )
2886 # Log full list at debug level for observability
2887 if all_orphan_ids:  2887 ↛ 2889  line 2887 didn't jump to line 2889 because the condition on line 2887 was always true
2888 logger.debug(f"Full orphan run_id list: {', '.join(all_orphan_ids)}")
2889 now_iso = datetime.now(timezone.utc).isoformat()
2890 error_parts = []
2891 if buffered_count > 0:
2892 error_parts.append(f"{buffered_count} buffered")
2893 if dropped_count > 0:
2894 error_parts.append(f"{dropped_count} dropped (TTL expired or buffer full)")
2895 if dropped_overflow_count > 0:
2896 error_parts.append(f"{dropped_overflow_count} additional dropped (tracking overflow)")
2897 error_msg = f"Tool execution incomplete: {total_affected} tool end(s) received without matching start ({', '.join(error_parts)})"
2898 # Truncate to first 10 IDs in error message to avoid excessive payload
2899 if all_orphan_ids:  2899 ↛ 2907  line 2899 didn't jump to line 2907 because the condition on line 2899 was always true
2900 max_display_ids = 10
2901 display_ids = all_orphan_ids[:max_display_ids]
2902 remaining = total_unique - len(display_ids)
2903 if remaining > 0:
2904 error_msg += f". Run IDs (first {max_display_ids} of {total_unique}): {', '.join(display_ids)} (+{remaining} more)"
2905 else:
2906 error_msg += f". Run IDs: {', '.join(display_ids)}"
2907 yield {
2908 "type": "tool_error",
2909 "id": str(uuid4()),
2910 "tool": None,
2911 "error": error_msg,
2912 "time": now_iso,
2913 }
2914 pending_tool_ends.clear()
2915 dropped_tool_ends.clear()
2917 # Calculate elapsed time
2918 elapsed_ms = int((time.time() - start_ts) * 1000)
2920 # Determine tool usage
2921 tools_used = list({tr["name"] for tr in tool_runs.values() if tr.get("name")})
2923 # Yield final event
2924 yield {"type": "final", "content": full_response, "tool_used": len(tools_used) > 0, "tools": tools_used, "elapsed_ms": elapsed_ms}
2926 # Save history
2927 if self.user_id and full_response:
2928 await self.history_manager.append_message(self.user_id, "user", message)
2929 await self.history_manager.append_message(self.user_id, "assistant", full_response)
2931 except Exception as e:
2932 logger.error(f"Error in chat_events: {e}")
2933 raise RuntimeError(f"Chat processing error: {e}") from e
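# Usage sketch: forward structured events as JSON lines, e.g. to back an SSE or websocket
# endpoint. The framing is illustrative and not part of the gateway API; default=str
# guards against tool inputs/outputs that orjson cannot serialize directly.
async def events_as_json_lines_example(service: MCPChatService, message: str):
    async for event in service.chat_events(message):
        yield orjson.dumps(event, default=str) + b"\n"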
2935 async def get_conversation_history(self) -> List[Dict[str, str]]:
2936 """
2937 Get conversation history for the current user.
2939 Returns:
2940 List[Dict[str, str]]: Conversation messages with keys:
2941 - role (str): "user" or "assistant"
2942 - content (str): Message text
2944 Examples:
2945 >>> import asyncio
2946 >>> # Assuming service is initialized with user_id
2947 >>> # history = asyncio.run(service.get_conversation_history())
2948 >>> # all('role' in msg and 'content' in msg for msg in history) -> True
2951 Note:
2952 Returns empty list if no user_id set or no history exists.
2953 """
2954 if not self.user_id:
2955 return []
2957 return await self.history_manager.get_history(self.user_id)
2959 async def clear_history(self) -> None:
2960 """
2961 Clear conversation history for the current user.
2963 Removes all messages from the conversation history. Useful for starting
2964 fresh conversations or managing memory usage.
2966 Examples:
2967 >>> import asyncio
2968 >>> # Assuming service is initialized with user_id
2969 >>> # asyncio.run(service.clear_history())
2970 >>> # history = asyncio.run(service.get_conversation_history())
2971 >>> # len(history) -> 0
2973 Note:
2974 This action cannot be undone. No-op if no user_id set.
2975 """
2976 if not self.user_id:
2977 return
2979 await self.history_manager.clear_history(self.user_id)
2980 logger.info(f"Conversation history cleared for user {self.user_id}")
2982 async def shutdown(self) -> None:
2983 """
2984 Shutdown the chat service and cleanup resources.
2986 Performs graceful shutdown by disconnecting from MCP server, clearing
2987 agent and history, and resetting initialization state.
2989 Raises:
2990 Exception: If cleanup operations fail.
2992 Examples:
2993 >>> import asyncio
2994 >>> config = MCPClientConfig(
2995 ... mcp_server=MCPServerConfig(
2996 ... url="https://example.com/mcp",
2997 ... transport="streamable_http"
2998 ... ),
2999 ... llm=LLMConfig(
3000 ... provider="ollama",
3001 ... config=OllamaConfig(model="llama2")
3002 ... )
3003 ... )
3004 >>> service = MCPChatService(config)
3005 >>> # asyncio.run(service.initialize())
3006 >>> # asyncio.run(service.shutdown())
3007 >>> # service.is_initialized -> False
3009 Note:
3010 Should be called when service is no longer needed to properly
3011 release resources and connections.
3012 """
3013 logger.info("Shutting down chat service...")
3015 try:
3016 # Disconnect from MCP server
3017 if self.mcp_client.is_connected:  3017 ↛ 3021  line 3017 didn't jump to line 3021 because the condition on line 3017 was always true
3018 await self.mcp_client.disconnect()
3020 # Clear state
3021 self._agent = None
3022 self._initialized = False
3023 self._tools = []
3025 logger.info("Chat service shutdown complete")
3027 except Exception as e:
3028 logger.error(f"Error during shutdown: {e}")
3029 raise
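# Usage sketch: pair initialize() with shutdown() so MCP connections are released even if
# the chat turn fails; the config and user id are placeholders.
async def run_single_turn_example(config: MCPClientConfig, message: str) -> str:
    service = MCPChatService(config, user_id="demo-user")
    await service.initialize()
    try:
        return await service.chat(message)
    finally:
        await service.shutdown()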
3031 @property
3032 def is_initialized(self) -> bool:
3033 """
3034 Check if service is initialized.
3036 Returns:
3037 bool: True if service is initialized and ready, False otherwise.
3039 Examples:
3040 >>> config = MCPClientConfig(
3041 ... mcp_server=MCPServerConfig(
3042 ... url="https://example.com/mcp",
3043 ... transport="streamable_http"
3044 ... ),
3045 ... llm=LLMConfig(
3046 ... provider="ollama",
3047 ... config=OllamaConfig(model="llama2")
3048 ... )
3049 ... )
3050 >>> service = MCPChatService(config)
3051 >>> service.is_initialized
3052 False
3054 Note:
3055 Service must be initialized before calling chat methods.
3056 """
3057 return self._initialized
3059 async def reload_tools(self) -> int:
3060 """
3061 Reload tools from MCP server.
3063 Forces a reload of tools from the MCP server and recreates the agent
3064 with the updated tool set. Useful when MCP server tools have changed.
3066 Returns:
3067 int: Number of tools successfully loaded.
3069 Raises:
3070 RuntimeError: If service not initialized.
3071 Exception: If tool reloading or agent recreation fails.
3073 Examples:
3074 >>> import asyncio
3075 >>> # Assuming service is initialized
3076 >>> # tool_count = asyncio.run(service.reload_tools())
3077 >>> # tool_count >= 0 -> True
3079 Note:
3080 This operation recreates the agent, so it may briefly interrupt
3081 ongoing conversations. Conversation history is preserved.
3082 """
3083 if not self._initialized:
3084 raise RuntimeError("Chat service not initialized")
3086 try:
3087 logger.info("Reloading tools from MCP server...")
3088 tools = await self.mcp_client.get_tools(force_reload=True)
3090 # Recreate agent with new tools
3091 llm = self.llm_provider.get_llm()
3092 self._agent = create_react_agent(llm, tools)
3093 self._tools = tools
3095 logger.info(f"Reloaded {len(tools)} tools successfully")
3096 return len(tools)
3098 except Exception as e:
3099 logger.error(f"Failed to reload tools: {e}")
3100 raise
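# Usage sketch: refresh the agent's tool set after the MCP server's catalog changes; the
# returned count can be logged or surfaced to an admin endpoint.
async def refresh_tools_example(service: MCPChatService) -> int:
    count = await service.reload_tools()
    logger.info(f"Agent rebuilt with {count} tools")
    return count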