Coverage for mcpgateway / services / mcp_client_chat_service.py: 99%
845 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/mcp_client_chat_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Keval Mahajan
7MCP Client Service Module.
9This module provides a comprehensive client implementation for interacting with
10MCP servers, managing LLM providers, and orchestrating conversational AI agents.
11It supports multiple transport protocols and LLM providers.
13The module consists of several key components:
14- Configuration classes for MCP servers and LLM providers
15- LLM provider factory and implementations
16- MCP client for tool management
17- Chat history manager for Redis and in-memory storage
18- Chat service for conversational interactions
19"""
21# Standard
22from datetime import datetime, timezone
23import time
24from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Union
25from uuid import uuid4
27# Third-Party
28import orjson
# Optional dependency guard: the LLM chat feature pulls in the LangChain
# stack, which is only required when LLMCHAT_ENABLED=true. Import failures
# are tolerated and every name is stubbed to None so the module stays
# importable without the extras installed.
try:
    # Third-Party
    from langchain_core.language_models import BaseChatModel
    from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
    from langchain_core.tools import BaseTool
    from langchain_mcp_adapters.client import MultiServerMCPClient
    from langchain_ollama import ChatOllama, OllamaLLM
    from langchain_openai import AzureChatOpenAI, AzureOpenAI, ChatOpenAI, OpenAI
    from langgraph.prebuilt import create_react_agent

    _LLMCHAT_AVAILABLE = True
except ImportError:
    # Optional dependencies for LLM chat feature not installed
    # These are only needed if LLMCHAT_ENABLED=true
    _LLMCHAT_AVAILABLE = False
    BaseChatModel = None  # type: ignore
    AIMessage = None  # type: ignore
    BaseMessage = None  # type: ignore
    HumanMessage = None  # type: ignore
    BaseTool = None  # type: ignore
    MultiServerMCPClient = None  # type: ignore
    ChatOllama = None  # type: ignore
    OllamaLLM = None  # type: ignore
    AzureChatOpenAI = None  # type: ignore
    AzureOpenAI = None  # type: ignore
    ChatOpenAI = None  # type: ignore
    OpenAI = None  # type: ignore
    create_react_agent = None  # type: ignore

# Try to import Anthropic and Bedrock providers (they may not be installed)
try:
    # Third-Party
    from langchain_anthropic import AnthropicLLM, ChatAnthropic

    _ANTHROPIC_AVAILABLE = True
except ImportError:
    _ANTHROPIC_AVAILABLE = False
    ChatAnthropic = None  # type: ignore
    AnthropicLLM = None  # type: ignore

try:
    # Third-Party
    from langchain_aws import BedrockLLM, ChatBedrock

    _BEDROCK_AVAILABLE = True
except ImportError:
    _BEDROCK_AVAILABLE = False
    ChatBedrock = None  # type: ignore
    BedrockLLM = None  # type: ignore

try:
    # Third-Party
    from langchain_ibm import ChatWatsonx, WatsonxLLM

    _WATSONX_AVAILABLE = True
except ImportError:
    _WATSONX_AVAILABLE = False
    WatsonxLLM = None  # type: ignore
    ChatWatsonx = None  # type: ignore

# Third-Party
from pydantic import BaseModel, Field, field_validator, model_validator

# First-Party
from mcpgateway.config import settings
from mcpgateway.services.cancellation_service import cancellation_service
from mcpgateway.services.logging_service import LoggingService

# Module-level logger wired through the gateway's logging service.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)
class MCPServerConfig(BaseModel):
    """
    Configuration for MCP server connection.

    This class defines the configuration parameters required to connect to an
    MCP (Model Context Protocol) server using various transport mechanisms.

    Attributes:
        url: MCP server URL for streamable_http/sse transports.
        command: Command to run for stdio transport.
        args: Command-line arguments for stdio command.
        transport: Transport type (streamable_http, sse, or stdio).
        auth_token: Authentication token for HTTP-based transports.
        headers: Additional HTTP headers for request customization.

    Examples:
        >>> # HTTP-based transport
        >>> config = MCPServerConfig(
        ...     url="https://mcp-server.example.com/mcp",
        ...     transport="streamable_http",
        ...     auth_token="secret-token"
        ... )
        >>> config.transport
        'streamable_http'

        >>> # Stdio transport (requires explicit feature flag)
        >>> settings.mcpgateway_stdio_transport_enabled = True
        >>> config = MCPServerConfig(
        ...     command="python",
        ...     args=["server.py"],
        ...     transport="stdio"
        ... )
        >>> config.command
        'python'

    Note:
        The auth_token is automatically added to headers as a Bearer token
        for HTTP-based transports.
    """

    url: Optional[str] = Field(None, description="MCP server URL for streamable_http/sse transports")
    command: Optional[str] = Field(None, description="Command to run for stdio transport")
    args: Optional[list[str]] = Field(None, description="Arguments for stdio command")
    transport: Literal["streamable_http", "sse", "stdio"] = Field(default="streamable_http", description="Transport type for MCP connection")
    auth_token: Optional[str] = Field(None, description="Authentication token for the server")
    headers: Optional[Dict[str, str]] = Field(default=None, description="Additional headers for HTTP-based transports")

    @model_validator(mode="before")
    @classmethod
    def add_auth_to_headers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """
        Automatically add authentication token to headers if provided.

        This validator ensures that if an auth_token is provided for HTTP-based
        transports, it's automatically added to the headers as a Bearer token.
        An existing Authorization header is never overwritten.

        Args:
            values: Dictionary of field values before validation.

        Returns:
            Dict[str, Any]: Updated values with auth token in headers.

        Examples:
            >>> values = {
            ...     "url": "https://api.example.com",
            ...     "transport": "streamable_http",
            ...     "auth_token": "token123"
            ... }
            >>> result = MCPServerConfig.add_auth_to_headers(values)
            >>> result['headers']['Authorization']
            'Bearer token123'
        """
        # mode="before" validators can receive non-dict payloads (e.g. an
        # existing model instance); pass those through instead of crashing
        # on a missing .get attribute.
        if not isinstance(values, dict):
            return values

        auth_token = values.get("auth_token")
        # Field defaults have not been applied yet in a "before" validator,
        # so mirror the declared default: an omitted transport means
        # streamable_http and must still receive the Bearer header.
        transport = values.get("transport", "streamable_http")
        headers = values.get("headers") or {}

        if auth_token and transport in ["streamable_http", "sse"]:
            if "Authorization" not in headers:
                headers["Authorization"] = f"Bearer {auth_token}"
            values["headers"] = headers

        return values

    @model_validator(mode="after")
    def validate_transport_requirements(self):
        """Validate transport-specific requirements and feature flags.

        Returns:
            MCPServerConfig: Validated config instance.

        Raises:
            ValueError: If transport requirements or feature flags are violated.
        """
        if self.transport in ["streamable_http", "sse"] and not self.url:
            raise ValueError(f"URL is required for {self.transport} transport")

        if self.transport == "stdio":
            # stdio spawns a local subprocess, so it is gated behind an
            # explicit opt-in feature flag.
            if not settings.mcpgateway_stdio_transport_enabled:
                raise ValueError("stdio transport is disabled by default; set MCPGATEWAY_STDIO_TRANSPORT_ENABLED=true to enable it")
            if not self.command:
                raise ValueError("Command is required for stdio transport")

        return self

    model_config = {
        "json_schema_extra": {
            "examples": [
                {"url": "https://mcp-server.example.com/mcp", "transport": "streamable_http", "auth_token": "your-token-here"},  # nosec B105 - example placeholder
                {"command": "python", "args": ["server.py"], "transport": "stdio"},
            ]
        }
    }
class AzureOpenAIConfig(BaseModel):
    """
    Configuration for Azure OpenAI provider.

    Defines all necessary parameters to connect to and use Azure OpenAI services,
    including API credentials, endpoints, model settings, and request parameters.

    Attributes:
        api_key: Azure OpenAI API authentication key.
        azure_endpoint: Azure OpenAI service endpoint URL.
        api_version: API version to use for requests.
        azure_deployment: Name of the deployed model.
        model: Model identifier for logging and tracing.
        temperature: Sampling temperature for response generation (0.0-2.0).
        max_tokens: Maximum number of tokens to generate.
        timeout: Request timeout duration in seconds.
        max_retries: Maximum number of retry attempts for failed requests.

    Examples:
        >>> config = AzureOpenAIConfig(
        ...     api_key="your-api-key",
        ...     azure_endpoint="https://your-resource.openai.azure.com/",
        ...     azure_deployment="gpt-4",
        ...     temperature=0.7
        ... )
        >>> config.model
        'gpt-4'
        >>> config.temperature
        0.7

    Note:
        ``azure_deployment`` selects the deployed model on Azure; ``model``
        is only used for logging/tracing and does not affect routing.
    """

    api_key: str = Field(..., description="Azure OpenAI API key")
    azure_endpoint: str = Field(..., description="Azure OpenAI endpoint URL")
    api_version: str = Field(default="2024-05-01-preview", description="Azure OpenAI API version")
    azure_deployment: str = Field(..., description="Azure OpenAI deployment name")
    model: str = Field(default="gpt-4", description="Model name for tracing")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(None, gt=0, description="Maximum tokens to generate")
    timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")
    max_retries: int = Field(default=2, ge=0, description="Maximum number of retries")

    model_config = {
        "json_schema_extra": {
            "example": {
                "api_key": "your-api-key",
                "azure_endpoint": "https://your-resource.openai.azure.com/",
                "api_version": "2024-05-01-preview",
                "azure_deployment": "gpt-4",
                "model": "gpt-4",
                "temperature": 0.7,
            }
        }
    }
class OllamaConfig(BaseModel):
    """
    Configuration for Ollama provider.

    Defines parameters for connecting to a local or remote Ollama instance
    for running open-source language models.

    Attributes:
        base_url: Ollama server base URL.
        model: Name of the Ollama model to use.
        temperature: Sampling temperature for response generation (0.0-2.0).
        timeout: Request timeout duration in seconds.
        num_ctx: Context window size for the model (passed through to Ollama
            only when set).

    Examples:
        >>> config = OllamaConfig(
        ...     base_url="http://localhost:11434",
        ...     model="llama2",
        ...     temperature=0.5
        ... )
        >>> config.model
        'llama2'
        >>> config.base_url
        'http://localhost:11434'
    """

    base_url: str = Field(default="http://localhost:11434", description="Ollama base URL")
    model: str = Field(default="llama2", description="Model name to use")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")
    num_ctx: Optional[int] = Field(None, gt=0, description="Context window size")

    model_config = {"json_schema_extra": {"example": {"base_url": "http://localhost:11434", "model": "llama2", "temperature": 0.7}}}
class OpenAIConfig(BaseModel):
    """
    Configuration for OpenAI provider (non-Azure).

    Defines parameters for connecting to OpenAI API (or OpenAI-compatible endpoints).

    Attributes:
        api_key: OpenAI API authentication key.
        base_url: Optional base URL for OpenAI-compatible endpoints.
        model: Model identifier (e.g., gpt-4, gpt-3.5-turbo).
        temperature: Sampling temperature for response generation (0.0-2.0).
        max_tokens: Maximum number of tokens to generate.
        timeout: Request timeout duration in seconds.
        max_retries: Maximum number of retry attempts for failed requests.
        default_headers: Optional default HTTP headers some OpenAI-compatible
            providers require on every request.

    Examples:
        >>> config = OpenAIConfig(
        ...     api_key="sk-...",
        ...     model="gpt-4",
        ...     temperature=0.7
        ... )
        >>> config.model
        'gpt-4'
    """

    api_key: str = Field(..., description="OpenAI API key")
    base_url: Optional[str] = Field(None, description="Base URL for OpenAI-compatible endpoints")
    model: str = Field(default="gpt-4o-mini", description="Model name (e.g., gpt-4, gpt-3.5-turbo)")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(None, gt=0, description="Maximum tokens to generate")
    timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")
    max_retries: int = Field(default=2, ge=0, description="Maximum number of retries")
    default_headers: Optional[dict] = Field(None, description="optional default headers required by the provider")

    model_config = {
        "json_schema_extra": {
            "example": {
                "api_key": "sk-...",
                "model": "gpt-4o-mini",
                "temperature": 0.7,
            }
        }
    }
class AnthropicConfig(BaseModel):
    """
    Configuration for Anthropic Claude provider.

    Defines parameters for connecting to Anthropic's Claude API.

    Attributes:
        api_key: Anthropic API authentication key.
        model: Claude model identifier (e.g., claude-3-5-sonnet-20241022, claude-3-opus).
        temperature: Sampling temperature for response generation (0.0-1.0).
        max_tokens: Maximum number of tokens to generate (required by the
            Anthropic API, hence a non-optional default).
        timeout: Request timeout duration in seconds.
        max_retries: Maximum number of retry attempts for failed requests.

    Examples:
        >>> config = AnthropicConfig(
        ...     api_key="sk-ant-...",
        ...     model="claude-3-5-sonnet-20241022",
        ...     temperature=0.7
        ... )
        >>> config.model
        'claude-3-5-sonnet-20241022'

    Note:
        Unlike the OpenAI-style configs, temperature is capped at 1.0 here.
    """

    api_key: str = Field(..., description="Anthropic API key")
    model: str = Field(default="claude-3-5-sonnet-20241022", description="Claude model name")
    temperature: float = Field(default=0.7, ge=0.0, le=1.0, description="Sampling temperature")
    max_tokens: int = Field(default=4096, gt=0, description="Maximum tokens to generate")
    timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")
    max_retries: int = Field(default=2, ge=0, description="Maximum number of retries")

    model_config = {
        "json_schema_extra": {
            "example": {
                "api_key": "sk-ant-...",
                "model": "claude-3-5-sonnet-20241022",
                "temperature": 0.7,
                "max_tokens": 4096,
            }
        }
    }
class AWSBedrockConfig(BaseModel):
    """
    Configuration for AWS Bedrock provider.

    Defines parameters for connecting to AWS Bedrock LLM services.

    Attributes:
        model_id: Bedrock model identifier (e.g., anthropic.claude-v2, amazon.titan-text-express-v1).
        region_name: AWS region name (e.g., us-east-1, us-west-2).
        aws_access_key_id: Optional AWS access key ID (uses default credential chain if not provided).
        aws_secret_access_key: Optional AWS secret access key.
        aws_session_token: Optional AWS session token for temporary credentials.
        temperature: Sampling temperature for response generation (0.0-1.0).
        max_tokens: Maximum number of tokens to generate.

    Examples:
        >>> config = AWSBedrockConfig(
        ...     model_id="anthropic.claude-v2",
        ...     region_name="us-east-1",
        ...     temperature=0.7
        ... )
        >>> config.model_id
        'anthropic.claude-v2'

    Note:
        When the explicit AWS credential fields are left unset, the AWS SDK's
        default credential resolution (environment, profile, instance role)
        is expected to apply.
    """

    model_id: str = Field(..., description="Bedrock model ID")
    region_name: str = Field(default="us-east-1", description="AWS region name")
    aws_access_key_id: Optional[str] = Field(None, description="AWS access key ID")
    aws_secret_access_key: Optional[str] = Field(None, description="AWS secret access key")
    aws_session_token: Optional[str] = Field(None, description="AWS session token")
    temperature: float = Field(default=0.7, ge=0.0, le=1.0, description="Sampling temperature")
    max_tokens: int = Field(default=4096, gt=0, description="Maximum tokens to generate")

    model_config = {
        "json_schema_extra": {
            "example": {
                "model_id": "anthropic.claude-v2",
                "region_name": "us-east-1",
                "temperature": 0.7,
                "max_tokens": 4096,
            }
        }
    }
class WatsonxConfig(BaseModel):
    """
    Configuration for IBM watsonx.ai provider.

    Defines parameters for connecting to IBM watsonx.ai services.

    Attributes:
        api_key: IBM Cloud API key for authentication.
        url: IBM watsonx.ai service endpoint URL.
        project_id: IBM watsonx.ai project ID for context.
        model_id: Model identifier (e.g., ibm/granite-13b-chat-v2, meta-llama/llama-3-70b-instruct).
        temperature: Sampling temperature for response generation (0.0-2.0).
        max_new_tokens: Maximum number of tokens to generate.
        min_new_tokens: Minimum number of tokens to generate.
        decoding_method: Decoding method ('sample', 'greedy').
        top_k: Top-K sampling parameter.
        top_p: Top-P (nucleus) sampling parameter.
        timeout: Request timeout duration in seconds.

    Examples:
        >>> config = WatsonxConfig(
        ...     api_key="your-api-key",
        ...     url="https://us-south.ml.cloud.ibm.com",
        ...     project_id="your-project-id",
        ...     model_id="ibm/granite-13b-chat-v2"
        ... )
        >>> config.model_id
        'ibm/granite-13b-chat-v2'
    """

    api_key: str = Field(..., description="IBM Cloud API key")
    url: str = Field(default="https://us-south.ml.cloud.ibm.com", description="watsonx.ai endpoint URL")
    project_id: str = Field(..., description="watsonx.ai project ID")
    model_id: str = Field(default="ibm/granite-13b-chat-v2", description="Model identifier")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    max_new_tokens: Optional[int] = Field(default=1024, gt=0, description="Maximum tokens to generate")
    min_new_tokens: Optional[int] = Field(default=1, gt=0, description="Minimum tokens to generate")
    # NOTE(review): decoding_method is a free-form str; watsonx accepts
    # 'sample' or 'greedy' — an invalid value is only rejected server-side.
    decoding_method: str = Field(default="sample", description="Decoding method (sample or greedy)")
    top_k: Optional[int] = Field(default=50, gt=0, description="Top-K sampling")
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0, description="Top-P sampling")
    timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")

    model_config = {
        "json_schema_extra": {
            "example": {
                "api_key": "your-api-key",
                "url": "https://us-south.ml.cloud.ibm.com",
                "project_id": "your-project-id",
                "model_id": "ibm/granite-13b-chat-v2",
                "temperature": 0.7,
                "max_new_tokens": 1024,
            }
        }
    }
class GatewayConfig(BaseModel):
    """
    Configuration for ContextForge internal LLM provider.

    Allows LLM Chat to use models configured in the gateway's LLM Settings.
    The gateway routes requests to the appropriate configured provider.

    Attributes:
        model: Model ID (gateway model ID or provider model ID).
        base_url: Gateway internal API URL (defaults to self when unset).
        temperature: Sampling temperature for response generation (0.0-2.0).
        max_tokens: Maximum tokens to generate.
        timeout: Request timeout in seconds.

    Examples:
        >>> config = GatewayConfig(model="gpt-4o")
        >>> config.model
        'gpt-4o'
    """

    model: str = Field(..., description="Gateway model ID to use")
    base_url: Optional[str] = Field(None, description="Gateway internal API URL (optional, defaults to self)")
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(None, gt=0, description="Maximum tokens to generate")
    timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")

    model_config = {
        "json_schema_extra": {
            "example": {
                "model": "gpt-4o",
                "temperature": 0.7,
                "max_tokens": 4096,
            }
        }
    }
class LLMConfig(BaseModel):
    """
    Configuration for LLM provider.

    Unified configuration that pairs a ``provider`` discriminator with the
    matching provider-specific ``config`` object. Raw dictionaries passed as
    ``config`` are coerced to the appropriate config class based on the
    declared provider.

    Attributes:
        provider: Type of LLM provider (azure_openai, openai, anthropic, aws_bedrock, ollama, watsonx, or gateway).
        config: Provider-specific configuration object.

    Examples:
        >>> # Azure OpenAI configuration
        >>> config = LLMConfig(
        ...     provider="azure_openai",
        ...     config=AzureOpenAIConfig(
        ...         api_key="key",
        ...         azure_endpoint="https://example.com/",
        ...         azure_deployment="gpt-4"
        ...     )
        ... )
        >>> config.provider
        'azure_openai'

        >>> # OpenAI configuration
        >>> config = LLMConfig(
        ...     provider="openai",
        ...     config=OpenAIConfig(
        ...         api_key="sk-...",
        ...         model="gpt-4"
        ...     )
        ... )
        >>> config.provider
        'openai'

        >>> # Ollama configuration
        >>> config = LLMConfig(
        ...     provider="ollama",
        ...     config=OllamaConfig(model="llama2")
        ... )
        >>> config.provider
        'ollama'

        >>> # Watsonx configuration
        >>> config = LLMConfig(
        ...     provider="watsonx",
        ...     config=WatsonxConfig(
        ...         url="https://us-south.ml.cloud.ibm.com",
        ...         model_id="ibm/granite-13b-instruct-v2",
        ...         project_id="YOUR_PROJECT_ID",
        ...         api_key="YOUR_API")
        ... )
        >>> config.provider
        'watsonx'
    """

    provider: Literal["azure_openai", "openai", "anthropic", "aws_bedrock", "ollama", "watsonx", "gateway"] = Field(..., description="LLM provider type")
    config: Union[AzureOpenAIConfig, OpenAIConfig, AnthropicConfig, AWSBedrockConfig, OllamaConfig, WatsonxConfig, GatewayConfig] = Field(..., description="Provider-specific configuration")

    @field_validator("config", mode="before")
    @classmethod
    def validate_config_type(cls, v: Any, info) -> Union[AzureOpenAIConfig, OpenAIConfig, AnthropicConfig, AWSBedrockConfig, OllamaConfig, WatsonxConfig, GatewayConfig]:
        """
        Coerce a raw config dict into the class matching ``provider``.

        Non-dict values (already-constructed config objects) pass through
        untouched, as does a dict whose provider is not recognized —
        standard pydantic union validation handles those.

        Args:
            v: Configuration value (dict or config object).
            info: Validation context containing the already-validated provider.

        Returns:
            Union[AzureOpenAIConfig, OpenAIConfig, AnthropicConfig, AWSBedrockConfig, OllamaConfig, WatsonxConfig, GatewayConfig]: Validated configuration object.
        """
        # Dispatch table: provider discriminator -> config class.
        config_classes = {
            "azure_openai": AzureOpenAIConfig,
            "openai": OpenAIConfig,
            "anthropic": AnthropicConfig,
            "aws_bedrock": AWSBedrockConfig,
            "ollama": OllamaConfig,
            "watsonx": WatsonxConfig,
            "gateway": GatewayConfig,
        }

        if isinstance(v, dict):
            target = config_classes.get(info.data.get("provider"))
            if target is not None:
                return target(**v)

        return v
class MCPClientConfig(BaseModel):
    """
    Main configuration for MCP client service.

    Aggregates all configuration parameters required for the complete MCP client
    service, including server connection, LLM provider, and operational settings.

    Attributes:
        mcp_server: MCP server connection configuration.
        llm: LLM provider configuration.
        chat_history_max_messages: Maximum messages to retain in chat history.
        enable_streaming: Whether to enable streaming responses.

    Examples:
        >>> config = MCPClientConfig(
        ...     mcp_server=MCPServerConfig(
        ...         url="https://mcp-server.example.com/mcp",
        ...         transport="streamable_http"
        ...     ),
        ...     llm=LLMConfig(
        ...         provider="ollama",
        ...         config=OllamaConfig(model="llama2")
        ...     ),
        ...     chat_history_max_messages=100,
        ...     enable_streaming=True
        ... )
        >>> config.chat_history_max_messages
        100
        >>> config.enable_streaming
        True
    """

    mcp_server: MCPServerConfig = Field(..., description="MCP server configuration")
    llm: LLMConfig = Field(..., description="LLM provider configuration")
    # Declared via Field for consistency with sibling fields so the JSON
    # schema carries a description. The default is still read from settings
    # once at class-definition time, matching the previous bare assignment.
    chat_history_max_messages: int = Field(
        default=settings.llmchat_chat_history_max_messages,
        description="Maximum number of messages to retain in chat history",
    )
    enable_streaming: bool = Field(default=True, description="Enable streaming responses")

    model_config = {
        "json_schema_extra": {
            "example": {
                "mcp_server": {"url": "https://mcp-server.example.com/mcp", "transport": "streamable_http", "auth_token": "your-token"},  # nosec B105 - example placeholder
                "llm": {
                    "provider": "azure_openai",
                    "config": {"api_key": "your-key", "azure_endpoint": "https://your-resource.openai.azure.com/", "azure_deployment": "gpt-4", "api_version": "2024-05-01-preview"},
                },
            }
        }
    }
684# ==================== LLM PROVIDER IMPLEMENTATIONS ====================
class AzureOpenAIProvider:
    """
    Azure OpenAI provider implementation.

    Manages connection and interaction with Azure OpenAI services.

    Attributes:
        config: Azure OpenAI configuration object.

    Examples:
        >>> config = AzureOpenAIConfig(
        ...     api_key="key",
        ...     azure_endpoint="https://example.openai.azure.com/",
        ...     azure_deployment="gpt-4"
        ... )
        >>> provider = AzureOpenAIProvider(config)
        >>> provider.get_model_name()
        'gpt-4'

    Note:
        The LLM instance is lazily initialized on first access for improved
        startup performance. The cache tracks which model_type it was built
        for, so alternating 'chat'/'completion' requests always receive the
        correct model class (previously the first-created instance was
        returned regardless of the requested type).
    """

    def __init__(self, config: AzureOpenAIConfig):
        """
        Initialize Azure OpenAI provider.

        Args:
            config: Azure OpenAI configuration with API credentials and settings.

        Examples:
            >>> config = AzureOpenAIConfig(
            ...     api_key="key",
            ...     azure_endpoint="https://example.openai.azure.com/",
            ...     azure_deployment="gpt-4"
            ... )
            >>> provider = AzureOpenAIProvider(config)
        """
        self.config = config
        self._llm = None  # Lazily-created LLM instance
        self._llm_type: Optional[str] = None  # model_type the cached instance was built for
        logger.info(f"Initializing Azure OpenAI provider with deployment: {config.azure_deployment}")

    def get_llm(self, model_type: str = "chat") -> Union[AzureChatOpenAI, AzureOpenAI]:
        """
        Get Azure OpenAI LLM instance with lazy initialization.

        Creates and caches the model instance on first call for a given
        model_type; subsequent calls with the same model_type return the
        cached instance. Requesting a different model_type rebuilds the cache.

        Args:
            model_type: LLM inference model type: 'chat' or 'completion'.

        Returns:
            Union[AzureChatOpenAI, AzureOpenAI]: Configured Azure OpenAI model.

        Raises:
            ValueError: If model_type is not 'chat' or 'completion'.
            Exception: If LLM initialization fails (e.g., invalid credentials).
        """
        if model_type not in ("chat", "completion"):
            # Fail fast instead of silently returning None for unknown types.
            raise ValueError(f"Unsupported model_type: {model_type!r} (expected 'chat' or 'completion')")

        if self._llm is None or self._llm_type != model_type:
            try:
                model_cls = AzureChatOpenAI if model_type == "chat" else AzureOpenAI
                self._llm = model_cls(
                    api_key=self.config.api_key,
                    azure_endpoint=self.config.azure_endpoint,
                    api_version=self.config.api_version,
                    azure_deployment=self.config.azure_deployment,
                    model=self.config.model,
                    temperature=self.config.temperature,
                    max_tokens=self.config.max_tokens,
                    timeout=self.config.timeout,
                    max_retries=self.config.max_retries,
                )
                self._llm_type = model_type
                logger.info("Azure OpenAI LLM instance created successfully")
            except Exception as e:
                logger.error(f"Failed to create Azure OpenAI LLM: {e}")
                raise

        return self._llm

    def get_model_name(self) -> str:
        """
        Get the Azure OpenAI model name.

        Returns:
            str: The model name configured for this provider.

        Examples:
            >>> config = AzureOpenAIConfig(
            ...     api_key="key",
            ...     azure_endpoint="https://example.openai.azure.com/",
            ...     azure_deployment="gpt-4",
            ...     model="gpt-4"
            ... )
            >>> provider = AzureOpenAIProvider(config)
            >>> provider.get_model_name()
            'gpt-4'
        """
        return self.config.model
class OllamaProvider:
    """
    Ollama provider implementation.

    Manages connection and interaction with Ollama instances for running
    open-source language models locally or remotely.

    Attributes:
        config: Ollama configuration object.

    Examples:
        >>> config = OllamaConfig(
        ...     base_url="http://localhost:11434",
        ...     model="llama2"
        ... )
        >>> provider = OllamaProvider(config)
        >>> provider.get_model_name()
        'llama2'

    Note:
        Requires Ollama to be running and accessible at the configured
        base_url. The LLM instance is lazily created and cached per
        model_type ('chat' vs 'completion'), so alternating requests always
        receive the correct model class.
    """

    def __init__(self, config: OllamaConfig):
        """
        Initialize Ollama provider.

        Args:
            config: Ollama configuration with server URL and model settings.

        Examples:
            >>> config = OllamaConfig(model="llama2")
            >>> provider = OllamaProvider(config)
        """
        self.config = config
        self._llm = None  # Lazily-created LLM instance
        self._llm_type: Optional[str] = None  # model_type the cached instance was built for
        logger.info(f"Initializing Ollama provider with model: {config.model}")

    def get_llm(self, model_type: str = "chat") -> Union[ChatOllama, OllamaLLM]:
        """
        Get Ollama LLM instance with lazy initialization.

        Creates and caches the model instance on first call for a given
        model_type; subsequent calls with the same model_type return the
        cached instance. Requesting a different model_type rebuilds the cache.

        Args:
            model_type: LLM inference model type: 'chat' or 'completion'.

        Returns:
            Union[ChatOllama, OllamaLLM]: Configured Ollama model.

        Raises:
            ValueError: If model_type is not 'chat' or 'completion'.
            Exception: If LLM initialization fails (e.g., Ollama not running).
        """
        if model_type not in ("chat", "completion"):
            # Fail fast instead of silently returning None for unknown types.
            raise ValueError(f"Unsupported model_type: {model_type!r} (expected 'chat' or 'completion')")

        if self._llm is None or self._llm_type != model_type:
            try:
                # num_ctx is only forwarded when explicitly configured so
                # Ollama's own default context size applies otherwise.
                model_kwargs = {}
                if self.config.num_ctx is not None:
                    model_kwargs["num_ctx"] = self.config.num_ctx

                model_cls = ChatOllama if model_type == "chat" else OllamaLLM
                self._llm = model_cls(base_url=self.config.base_url, model=self.config.model, temperature=self.config.temperature, timeout=self.config.timeout, **model_kwargs)
                self._llm_type = model_type
                logger.info("Ollama LLM instance created successfully")
            except Exception as e:
                logger.error(f"Failed to create Ollama LLM: {e}")
                raise

        return self._llm

    def get_model_name(self) -> str:
        """Get the model name.

        Returns:
            str: The model name
        """
        return self.config.model
class OpenAIProvider:
    """
    OpenAI provider implementation (non-Azure).

    Manages connection and interaction with OpenAI API or OpenAI-compatible endpoints.

    Attributes:
        config: OpenAI configuration object.

    Examples:
        >>> config = OpenAIConfig(
        ...     api_key="sk-...",
        ...     model="gpt-4"
        ... )
        >>> provider = OpenAIProvider(config)
        >>> provider.get_model_name()
        'gpt-4'

    Note:
        The LLM instance is lazily initialized on first access for improved
        startup performance, and cached per model_type ('chat' vs
        'completion') so alternating requests always receive the correct
        model class.
    """

    def __init__(self, config: OpenAIConfig):
        """
        Initialize OpenAI provider.

        Args:
            config: OpenAI configuration with API key and settings.

        Examples:
            >>> config = OpenAIConfig(
            ...     api_key="sk-...",
            ...     model="gpt-4"
            ... )
            >>> provider = OpenAIProvider(config)
        """
        self.config = config
        self._llm = None  # Lazily-created LLM instance
        self._llm_type: Optional[str] = None  # model_type the cached instance was built for
        logger.info(f"Initializing OpenAI provider with model: {config.model}")

    def get_llm(self, model_type: str = "chat") -> Union[ChatOpenAI, OpenAI]:
        """
        Get OpenAI LLM instance with lazy initialization.

        Creates and caches the model instance on first call for a given
        model_type; subsequent calls with the same model_type return the
        cached instance. Requesting a different model_type rebuilds the cache.

        Args:
            model_type: LLM inference model type: 'chat' or 'completion'.

        Returns:
            Union[ChatOpenAI, OpenAI]: Configured OpenAI model.

        Raises:
            ValueError: If model_type is not 'chat' or 'completion'.
            Exception: If LLM initialization fails (e.g., invalid credentials).
        """
        if model_type not in ("chat", "completion"):
            # Fail fast instead of silently returning None for unknown types.
            raise ValueError(f"Unsupported model_type: {model_type!r} (expected 'chat' or 'completion')")

        if self._llm is None or self._llm_type != model_type:
            try:
                kwargs = {
                    "api_key": self.config.api_key,
                    "model": self.config.model,
                    "temperature": self.config.temperature,
                    "max_tokens": self.config.max_tokens,
                    "timeout": self.config.timeout,
                    "max_retries": self.config.max_retries,
                }

                # Only set base_url for OpenAI-compatible endpoints; the
                # client library's default is the official OpenAI API.
                if self.config.base_url:
                    kwargs["base_url"] = self.config.base_url

                # add default headers if present (some compatible gateways
                # require them on every request)
                if self.config.default_headers is not None:
                    kwargs["default_headers"] = self.config.default_headers

                model_cls = ChatOpenAI if model_type == "chat" else OpenAI
                self._llm = model_cls(**kwargs)
                self._llm_type = model_type

                logger.info("OpenAI LLM instance created successfully")
            except Exception as e:
                logger.error(f"Failed to create OpenAI LLM: {e}")
                raise

        return self._llm

    def get_model_name(self) -> str:
        """
        Get the OpenAI model name.

        Returns:
            str: The model name configured for this provider.

        Examples:
            >>> config = OpenAIConfig(
            ...     api_key="sk-...",
            ...     model="gpt-4"
            ... )
            >>> provider = OpenAIProvider(config)
            >>> provider.get_model_name()
            'gpt-4'
        """
        return self.config.model
class AnthropicProvider:
    """
    Anthropic Claude provider implementation.

    Manages connection and interaction with Anthropic's Claude API.

    Attributes:
        config: Anthropic configuration object.

    Examples:
        >>> config = AnthropicConfig(  # doctest: +SKIP
        ...     api_key="sk-ant-...",
        ...     model="claude-3-5-sonnet-20241022"
        ... )
        >>> provider = AnthropicProvider(config)  # doctest: +SKIP
        >>> provider.get_model_name()  # doctest: +SKIP
        'claude-3-5-sonnet-20241022'

    Note:
        Requires langchain-anthropic package to be installed.
    """

    def __init__(self, config: "AnthropicConfig"):
        """
        Initialize Anthropic provider.

        Args:
            config: Anthropic configuration with API key and settings.

        Raises:
            ImportError: If langchain-anthropic is not installed.

        Examples:
            >>> config = AnthropicConfig(  # doctest: +SKIP
            ...     api_key="sk-ant-...",
            ...     model="claude-3-5-sonnet-20241022"
            ... )
            >>> provider = AnthropicProvider(config)  # doctest: +SKIP
        """
        if not _ANTHROPIC_AVAILABLE:
            raise ImportError("Anthropic provider requires langchain-anthropic package. Install it with: pip install langchain-anthropic")

        self.config = config
        self._llm = None  # Lazily created on first get_llm() call
        logger.info(f"Initializing Anthropic provider with model: {config.model}")

    def get_llm(self, model_type: str = "chat") -> "Union[ChatAnthropic, AnthropicLLM]":
        """
        Get Anthropic LLM instance with lazy initialization.

        Creates and caches the Anthropic model instance on first call.
        Subsequent calls return the cached instance.

        Args:
            model_type: LLM inference model type: 'chat' for a chat model or
                'completion' for a text completion model.

        Returns:
            Union[ChatAnthropic, AnthropicLLM]: Configured Anthropic model.

        Raises:
            ValueError: If model_type is not 'chat' or 'completion'.
            Exception: If LLM initialization fails (e.g., invalid credentials).

        Examples:
            >>> config = AnthropicConfig(  # doctest: +SKIP
            ...     api_key="sk-ant-...",
            ...     model="claude-3-5-sonnet-20241022"
            ... )
            >>> provider = AnthropicProvider(config)  # doctest: +SKIP
            >>> # llm = provider.get_llm()  # Returns ChatAnthropic instance
        """
        # Fail fast on unknown model types; previously an unrecognized value
        # silently left the cache empty and returned None.
        if model_type not in ("chat", "completion"):
            raise ValueError(f"Unsupported model_type: {model_type!r}. Expected 'chat' or 'completion'.")

        if self._llm is None:
            try:
                # Both model classes accept the same constructor arguments,
                # so build them once instead of duplicating the call.
                common_kwargs = {
                    "api_key": self.config.api_key,
                    "model": self.config.model,
                    "temperature": self.config.temperature,
                    "max_tokens": self.config.max_tokens,
                    "timeout": self.config.timeout,
                    "max_retries": self.config.max_retries,
                }
                if model_type == "chat":
                    self._llm = ChatAnthropic(**common_kwargs)
                else:
                    self._llm = AnthropicLLM(**common_kwargs)
                logger.info("Anthropic LLM instance created successfully")
            except Exception as e:
                logger.error(f"Failed to create Anthropic LLM: {e}")
                raise

        return self._llm

    def get_model_name(self) -> str:
        """
        Get the Anthropic model name.

        Returns:
            str: The model name configured for this provider.

        Examples:
            >>> config = AnthropicConfig(  # doctest: +SKIP
            ...     api_key="sk-ant-...",
            ...     model="claude-3-5-sonnet-20241022"
            ... )
            >>> provider = AnthropicProvider(config)  # doctest: +SKIP
            >>> provider.get_model_name()  # doctest: +SKIP
            'claude-3-5-sonnet-20241022'
        """
        return self.config.model
class AWSBedrockProvider:
    """
    AWS Bedrock provider implementation.

    Manages connection and interaction with AWS Bedrock LLM services.

    Attributes:
        config: AWS Bedrock configuration object.

    Examples:
        >>> config = AWSBedrockConfig(  # doctest: +SKIP
        ...     model_id="anthropic.claude-v2",
        ...     region_name="us-east-1"
        ... )
        >>> provider = AWSBedrockProvider(config)  # doctest: +SKIP
        >>> provider.get_model_name()  # doctest: +SKIP
        'anthropic.claude-v2'

    Note:
        Requires langchain-aws package and boto3 to be installed.
        Uses AWS default credential chain if credentials not explicitly provided.
    """

    def __init__(self, config: "AWSBedrockConfig"):
        """
        Initialize AWS Bedrock provider.

        Args:
            config: AWS Bedrock configuration with model ID and settings.

        Raises:
            ImportError: If langchain-aws is not installed.

        Examples:
            >>> config = AWSBedrockConfig(  # doctest: +SKIP
            ...     model_id="anthropic.claude-v2",
            ...     region_name="us-east-1"
            ... )
            >>> provider = AWSBedrockProvider(config)  # doctest: +SKIP
        """
        if not _BEDROCK_AVAILABLE:
            raise ImportError("AWS Bedrock provider requires langchain-aws package. Install it with: pip install langchain-aws boto3")

        self.config = config
        self._llm = None  # Lazily created on first get_llm() call
        logger.info(f"Initializing AWS Bedrock provider with model: {config.model_id}")

    def get_llm(self, model_type: str = "chat") -> "Union[ChatBedrock, BedrockLLM]":
        """
        Get AWS Bedrock LLM instance with lazy initialization.

        Creates and caches the Bedrock model instance on first call.
        Subsequent calls return the cached instance.

        Args:
            model_type: LLM inference model type: 'chat' for a chat model or
                'completion' for a text completion model.

        Returns:
            Union[ChatBedrock, BedrockLLM]: Configured AWS Bedrock model.

        Raises:
            ValueError: If model_type is not 'chat' or 'completion'.
            Exception: If LLM initialization fails (e.g., invalid credentials, permissions).

        Examples:
            >>> config = AWSBedrockConfig(  # doctest: +SKIP
            ...     model_id="anthropic.claude-v2",
            ...     region_name="us-east-1"
            ... )
            >>> provider = AWSBedrockProvider(config)  # doctest: +SKIP
            >>> # llm = provider.get_llm()  # Returns ChatBedrock instance
        """
        # Fail fast on unknown model types; previously an unrecognized value
        # silently left the cache empty and returned None.
        if model_type not in ("chat", "completion"):
            raise ValueError(f"Unsupported model_type: {model_type!r}. Expected 'chat' or 'completion'.")

        if self._llm is None:
            try:
                # Only pass credentials explicitly when configured; otherwise
                # boto3's default credential chain is used.
                credentials_kwargs = {}
                if self.config.aws_access_key_id:
                    credentials_kwargs["aws_access_key_id"] = self.config.aws_access_key_id
                if self.config.aws_secret_access_key:
                    credentials_kwargs["aws_secret_access_key"] = self.config.aws_secret_access_key
                if self.config.aws_session_token:
                    credentials_kwargs["aws_session_token"] = self.config.aws_session_token

                # Both model classes take identical arguments; select the class
                # rather than duplicating the constructor call.
                model_cls = ChatBedrock if model_type == "chat" else BedrockLLM
                self._llm = model_cls(
                    model_id=self.config.model_id,
                    region_name=self.config.region_name,
                    model_kwargs={
                        "temperature": self.config.temperature,
                        "max_tokens": self.config.max_tokens,
                    },
                    **credentials_kwargs,
                )
                logger.info("AWS Bedrock LLM instance created successfully")
            except Exception as e:
                logger.error(f"Failed to create AWS Bedrock LLM: {e}")
                raise

        return self._llm

    def get_model_name(self) -> str:
        """
        Get the AWS Bedrock model ID.

        Returns:
            str: The model ID configured for this provider.

        Examples:
            >>> config = AWSBedrockConfig(  # doctest: +SKIP
            ...     model_id="anthropic.claude-v2",
            ...     region_name="us-east-1"
            ... )
            >>> provider = AWSBedrockProvider(config)  # doctest: +SKIP
            >>> provider.get_model_name()  # doctest: +SKIP
            'anthropic.claude-v2'
        """
        return self.config.model_id
class WatsonxProvider:
    """
    IBM watsonx.ai provider implementation.

    Manages connection and interaction with IBM watsonx.ai services.

    Attributes:
        config: IBM watsonx.ai configuration object.

    Examples:
        >>> config = WatsonxConfig(  # doctest: +SKIP
        ...     api_key="key",
        ...     url="https://us-south.ml.cloud.ibm.com",
        ...     project_id="project-id",
        ...     model_id="ibm/granite-13b-chat-v2"
        ... )
        >>> provider = WatsonxProvider(config)  # doctest: +SKIP
        >>> provider.get_model_name()  # doctest: +SKIP
        'ibm/granite-13b-chat-v2'

    Note:
        Requires langchain-ibm package to be installed.
    """

    def __init__(self, config: "WatsonxConfig"):
        """
        Initialize IBM watsonx.ai provider.

        Args:
            config: IBM watsonx.ai configuration with credentials and settings.

        Raises:
            ImportError: If langchain-ibm is not installed.

        Examples:
            >>> config = WatsonxConfig(  # doctest: +SKIP
            ...     api_key="key",
            ...     url="https://us-south.ml.cloud.ibm.com",
            ...     project_id="project-id",
            ...     model_id="ibm/granite-13b-chat-v2"
            ... )
            >>> provider = WatsonxProvider(config)  # doctest: +SKIP
        """
        if not _WATSONX_AVAILABLE:
            raise ImportError("IBM watsonx.ai provider requires langchain-ibm package. Install it with: pip install langchain-ibm")
        self.config = config
        # NOTE: this provider caches on the public ``llm`` attribute (the other
        # providers use ``_llm``); kept as-is for backward compatibility.
        self.llm = None
        logger.info(f"Initializing IBM watsonx.ai provider with model {config.model_id}")

    def get_llm(self, model_type: str = "chat") -> "Union[WatsonxLLM, ChatWatsonx]":
        """
        Get IBM watsonx.ai LLM instance with lazy initialization.

        Creates and caches the watsonx model instance on first call.
        Subsequent calls return the cached instance.

        Args:
            model_type: LLM inference model type: 'chat' for a chat model or
                'completion' for a text completion model.

        Returns:
            Union[WatsonxLLM, ChatWatsonx]: Configured IBM watsonx.ai model.

        Raises:
            ValueError: If model_type is not 'chat' or 'completion'.
            Exception: If LLM initialization fails (e.g., invalid credentials).

        Examples:
            >>> config = WatsonxConfig(  # doctest: +SKIP
            ...     api_key="key",
            ...     url="https://us-south.ml.cloud.ibm.com",
            ...     project_id="project-id",
            ...     model_id="ibm/granite-13b-chat-v2"
            ... )
            >>> provider = WatsonxProvider(config)  # doctest: +SKIP
            >>> #llm = provider.get_llm()  # Returns WatsonxLLM instance
        """
        # Fail fast on unknown model types; previously an unrecognized value
        # silently left the cache empty and returned None.
        if model_type not in ("chat", "completion"):
            raise ValueError(f"Unsupported model_type: {model_type!r}. Expected 'chat' or 'completion'.")

        if self.llm is None:
            try:
                # Generation parameters shared by both model classes.
                params = {
                    "decoding_method": self.config.decoding_method,
                    "temperature": self.config.temperature,
                    "max_new_tokens": self.config.max_new_tokens,
                    "min_new_tokens": self.config.min_new_tokens,
                }
                if self.config.top_k is not None:
                    params["top_k"] = self.config.top_k
                if self.config.top_p is not None:
                    params["top_p"] = self.config.top_p

                # Both model classes take identical arguments; select the class
                # rather than duplicating the constructor call.
                model_cls = ChatWatsonx if model_type == "chat" else WatsonxLLM
                self.llm = model_cls(
                    apikey=self.config.api_key,
                    url=self.config.url,
                    project_id=self.config.project_id,
                    model_id=self.config.model_id,
                    params=params,
                )
                logger.info("IBM watsonx.ai LLM instance created successfully")
            except Exception as e:
                logger.error(f"Failed to create IBM watsonx.ai LLM: {e}")
                raise
        return self.llm

    def get_model_name(self) -> str:
        """
        Get the IBM watsonx.ai model ID.

        Returns:
            str: The model ID configured for this provider.

        Examples:
            >>> config = WatsonxConfig(  # doctest: +SKIP
            ...     api_key="key",
            ...     url="https://us-south.ml.cloud.ibm.com",
            ...     project_id="project-id",
            ...     model_id="ibm/granite-13b-chat-v2"
            ... )
            >>> provider = WatsonxProvider(config)  # doctest: +SKIP
            >>> provider.get_model_name()  # doctest: +SKIP
            'ibm/granite-13b-chat-v2'
        """
        return self.config.model_id
class GatewayProvider:
    """
    Gateway provider implementation for using models configured in LLM Settings.

    Routes LLM requests through the gateway's configured providers, allowing
    users to use models set up via the Admin UI's LLM Settings without needing
    to configure credentials in environment variables or API requests.

    Attributes:
        config: Gateway configuration with model ID.
        llm: Lazily initialized LLM instance.

    Examples:
        >>> config = GatewayConfig(model="gpt-4o")  # doctest: +SKIP
        >>> provider = GatewayProvider(config)  # doctest: +SKIP
        >>> provider.get_model_name()  # doctest: +SKIP
        'gpt-4o'

    Note:
        Requires models to be configured via Admin UI -> Settings -> LLM Settings.
    """

    def __init__(self, config: GatewayConfig):
        """
        Initialize Gateway provider.

        Args:
            config: Gateway configuration with model ID and optional settings.

        Examples:
            >>> config = GatewayConfig(model="gpt-4o")  # doctest: +SKIP
            >>> provider = GatewayProvider(config)  # doctest: +SKIP
        """
        self.config = config
        self.llm = None  # lazily created LangChain model instance (see get_llm)
        self._model_name: Optional[str] = None  # resolved model_id, populated by get_llm()
        self._underlying_provider = None  # reserved; not read anywhere in this class
        logger.info(f"Initializing Gateway provider with model: {config.model}")

    def get_llm(self, model_type: str = "chat") -> Union[BaseChatModel, Any]:
        """
        Get LLM instance by looking up model from gateway's LLM Settings.

        Fetches the model configuration from the database, decrypts API keys,
        and creates the appropriate LangChain LLM instance based on provider type.

        Args:
            model_type: Type of model to return ('chat' or 'completion'). Defaults to 'chat'.

        Returns:
            Union[BaseChatModel, Any]: Configured LangChain chat or completion model instance.

        Raises:
            ValueError: If model not found or provider not enabled.
            ImportError: If required provider package not installed.

        Examples:
            >>> config = GatewayConfig(model="gpt-4o")  # doctest: +SKIP
            >>> provider = GatewayProvider(config)  # doctest: +SKIP
            >>> llm = provider.get_llm()  # doctest: +SKIP

        Note:
            The LLM instance is lazily initialized and cached as a single
            instance: the model_type of the FIRST call determines what is
            cached, and later calls return it regardless of model_type.
        """
        # Return the cached instance if one was already built.
        if self.llm is not None:
            return self.llm

        # Import here to avoid circular imports
        # First-Party
        from mcpgateway.db import LLMModel, LLMProvider, SessionLocal  # pylint: disable=import-outside-toplevel
        from mcpgateway.services.llm_provider_service import decrypt_provider_config_for_runtime  # pylint: disable=import-outside-toplevel
        from mcpgateway.utils.services_auth import decode_auth  # pylint: disable=import-outside-toplevel

        model_id = self.config.model

        with SessionLocal() as db:
            # Try to find model by UUID first, then by model_id
            model = db.query(LLMModel).filter(LLMModel.id == model_id).first()
            if not model:
                model = db.query(LLMModel).filter(LLMModel.model_id == model_id).first()

            if not model:
                raise ValueError(f"Model '{model_id}' not found in LLM Settings. Configure it via Admin UI -> Settings -> LLM Settings.")

            if not model.enabled:
                raise ValueError(f"Model '{model.model_id}' is disabled. Enable it in LLM Settings.")

            # Get the provider row the model points to; both must be enabled.
            provider = db.query(LLMProvider).filter(LLMProvider.id == model.provider_id).first()
            if not provider:
                raise ValueError(f"Provider not found for model '{model.model_id}'")

            if not provider.enabled:
                raise ValueError(f"Provider '{provider.name}' is disabled. Enable it in LLM Settings.")

            # Get decrypted API key. decode_auth may return either a dict
            # (with an "api_key" entry) or the bare key string.
            api_key = None
            if provider.api_key:
                auth_data = decode_auth(provider.api_key)
                if isinstance(auth_data, dict):
                    api_key = auth_data.get("api_key")
                else:
                    api_key = auth_data

            # Store model name for get_model_name()
            self._model_name = model.model_id

            # Get temperature - use config override, else provider default, else 0.7
            temperature = self.config.temperature if self.config.temperature is not None else (provider.default_temperature or 0.7)
            max_tokens = self.config.max_tokens or model.max_output_tokens

            # Create appropriate LLM based on provider type
            provider_type = provider.provider_type.lower()
            config = decrypt_provider_config_for_runtime(provider.config)

            # Common kwargs shared by the OpenAI-style branches below.
            kwargs: Dict[str, Any] = {
                "temperature": temperature,
                "timeout": self.config.timeout,
            }

            if provider_type == "openai":
                kwargs.update(
                    {
                        "api_key": api_key,
                        "model": model.model_id,
                        "max_tokens": max_tokens,
                    }
                )
                if provider.api_base:
                    kwargs["base_url"] = provider.api_base

                # Handle default headers: provider config wins over the
                # GatewayConfig-level setting.
                if config.get("default_headers"):
                    kwargs["default_headers"] = config["default_headers"]
                elif hasattr(self.config, "default_headers") and self.config.default_headers:  # type: ignore
                    kwargs["default_headers"] = self.config.default_headers

                if model_type == "chat":
                    self.llm = ChatOpenAI(**kwargs)
                else:
                    self.llm = OpenAI(**kwargs)

            elif provider_type == "azure_openai":
                if not provider.api_base:
                    raise ValueError("Azure OpenAI requires base_url (azure_endpoint) to be configured")

                # Deployment name defaults to the model_id when not set.
                azure_deployment = config.get("azure_deployment", model.model_id)
                api_version = config.get("api_version", "2024-05-01-preview")
                max_retries = config.get("max_retries", 2)

                kwargs.update(
                    {
                        "api_key": api_key,
                        "azure_endpoint": provider.api_base,
                        "azure_deployment": azure_deployment,
                        "api_version": api_version,
                        "model": model.model_id,
                        "max_tokens": int(max_tokens) if max_tokens is not None else None,
                        "max_retries": max_retries,
                    }
                )

                if model_type == "chat":
                    self.llm = AzureChatOpenAI(**kwargs)
                else:
                    self.llm = AzureOpenAI(**kwargs)

            elif provider_type == "anthropic":
                if not _ANTHROPIC_AVAILABLE:
                    raise ImportError("Anthropic provider requires langchain-anthropic. Install with: pip install langchain-anthropic")

                # Anthropic uses 'model_name' instead of 'model'
                # NOTE(review): both "timeout" and "default_request_timeout" are
                # passed here — confirm ChatAnthropic accepts both without conflict.
                anthropic_kwargs = {
                    "api_key": api_key,
                    "model_name": model.model_id,
                    "max_tokens": max_tokens or 4096,
                    "temperature": temperature,
                    "timeout": self.config.timeout,
                    "default_request_timeout": self.config.timeout,
                }

                if model_type == "chat":
                    self.llm = ChatAnthropic(**anthropic_kwargs)
                else:
                    # Generic Anthropic completion model if needed, though mostly chat used now
                    if AnthropicLLM:
                        llm_kwargs = anthropic_kwargs.copy()
                        llm_kwargs["model"] = llm_kwargs.pop("model_name")
                        self.llm = AnthropicLLM(**llm_kwargs)
                    else:
                        raise ImportError("Anthropic completion model (AnthropicLLM) not available")

            elif provider_type == "bedrock":
                if not _BEDROCK_AVAILABLE:
                    raise ImportError("AWS Bedrock provider requires langchain-aws. Install with: pip install langchain-aws boto3")

                region_name = config.get("region_name", "us-east-1")
                # Pass credentials only when present; otherwise boto3's default
                # credential chain applies.
                credentials_kwargs = {}
                if config.get("aws_access_key_id"):
                    credentials_kwargs["aws_access_key_id"] = config["aws_access_key_id"]
                if config.get("aws_secret_access_key"):
                    credentials_kwargs["aws_secret_access_key"] = config["aws_secret_access_key"]
                if config.get("aws_session_token"):
                    credentials_kwargs["aws_session_token"] = config["aws_session_token"]

                model_kwargs = {
                    "temperature": temperature,
                    "max_tokens": max_tokens or 4096,
                }

                if model_type == "chat":
                    self.llm = ChatBedrock(
                        model_id=model.model_id,
                        region_name=region_name,
                        model_kwargs=model_kwargs,
                        **credentials_kwargs,
                    )
                else:
                    self.llm = BedrockLLM(
                        model_id=model.model_id,
                        region_name=region_name,
                        model_kwargs=model_kwargs,
                        **credentials_kwargs,
                    )

            elif provider_type == "ollama":
                base_url = provider.api_base or "http://localhost:11434"
                num_ctx = config.get("num_ctx")

                # Explicitly construct kwargs to avoid generic unpacking issues with Pydantic models
                ollama_kwargs = {
                    "base_url": base_url,
                    "model": model.model_id,
                    "temperature": temperature,
                    "timeout": self.config.timeout,
                }
                if num_ctx:
                    ollama_kwargs["num_ctx"] = num_ctx

                if model_type == "chat":
                    self.llm = ChatOllama(**ollama_kwargs)
                else:
                    self.llm = OllamaLLM(**ollama_kwargs)

            elif provider_type == "watsonx":
                if not _WATSONX_AVAILABLE:
                    raise ImportError("IBM watsonx.ai provider requires langchain-ibm. Install with: pip install langchain-ibm")

                project_id = config.get("project_id")
                if not project_id:
                    raise ValueError("IBM watsonx.ai requires project_id in config")

                url = provider.api_base or "https://us-south.ml.cloud.ibm.com"

                params = {
                    "temperature": temperature,
                    "max_new_tokens": max_tokens or 1024,
                    "min_new_tokens": config.get("min_new_tokens", 1),
                    "decoding_method": config.get("decoding_method", "sample"),
                    "top_k": config.get("top_k", 50),
                    "top_p": config.get("top_p", 1.0),
                }

                if model_type == "chat":
                    self.llm = ChatWatsonx(
                        apikey=api_key,
                        url=url,
                        project_id=project_id,
                        model_id=model.model_id,
                        params=params,
                    )
                else:
                    self.llm = WatsonxLLM(
                        apikey=api_key,
                        url=url,
                        project_id=project_id,
                        model_id=model.model_id,
                        params=params,
                    )

            elif provider_type == "openai_compatible":
                if not provider.api_base:
                    raise ValueError("OpenAI-compatible provider requires base_url to be configured")

                kwargs.update(
                    {
                        # Some local/compatible endpoints don't need a key.
                        "api_key": api_key or "no-key-required",
                        "model": model.model_id,
                        "base_url": provider.api_base,
                        "max_tokens": max_tokens,
                    }
                )

                if model_type == "chat":
                    self.llm = ChatOpenAI(**kwargs)
                else:
                    self.llm = OpenAI(**kwargs)

            else:
                raise ValueError(f"Unsupported LLM provider: {provider_type}")

            logger.info(f"Gateway provider created LLM instance for model: {model.model_id} via {provider_type}")
        return self.llm

    def get_model_name(self) -> str:
        """
        Get the model name.

        Returns:
            str: The resolved model_id once get_llm() has run, otherwise the
            configured model identifier.

        Examples:
            >>> config = GatewayConfig(model="gpt-4o")  # doctest: +SKIP
            >>> provider = GatewayProvider(config)  # doctest: +SKIP
            >>> provider.get_model_name()  # doctest: +SKIP
            'gpt-4o'
        """
        return self._model_name or self.config.model
class LLMProviderFactory:
    """
    Factory for creating LLM providers.

    Maps a provider identifier from an LLMConfig onto the matching provider
    class and instantiates it, hiding provider-specific initialization from
    callers.

    Examples:
        >>> config = LLMConfig(
        ...     provider="ollama",
        ...     config=OllamaConfig(model="llama2")
        ... )
        >>> provider = LLMProviderFactory.create(config)
        >>> provider.get_model_name()
        'llama2'

    Note:
        Provider selection is driven by the ``provider`` field of the
        LLMConfig discriminated union, which keeps the mapping type-safe.
    """

    @staticmethod
    def create(llm_config: LLMConfig) -> Union[AzureOpenAIProvider, OpenAIProvider, AnthropicProvider, AWSBedrockProvider, OllamaProvider, WatsonxProvider, GatewayProvider]:
        """
        Instantiate the provider named by ``llm_config.provider``.

        Args:
            llm_config: LLM configuration specifying provider type and settings.

        Returns:
            Union[AzureOpenAIProvider, OpenAIProvider, AnthropicProvider, AWSBedrockProvider, OllamaProvider, WatsonxProvider, GatewayProvider]: Instantiated provider.

        Raises:
            ValueError: If provider type is not supported.
            ImportError: If required provider package is not installed.

        Examples:
            >>> config = LLMConfig(
            ...     provider="openai",
            ...     config=OpenAIConfig(
            ...         api_key="sk-...",
            ...         model="gpt-4"
            ...     )
            ... )
            >>> provider = LLMProviderFactory.create(config)
            >>> isinstance(provider, OpenAIProvider)
            True

            >>> config = LLMConfig(
            ...     provider="ollama",
            ...     config=OllamaConfig(model="llama2")
            ... )
            >>> provider = LLMProviderFactory.create(config)
            >>> isinstance(provider, OllamaProvider)
            True
        """
        # Registry of supported provider identifiers -> implementation class.
        registry = {
            "azure_openai": AzureOpenAIProvider,
            "openai": OpenAIProvider,
            "anthropic": AnthropicProvider,
            "aws_bedrock": AWSBedrockProvider,
            "ollama": OllamaProvider,
            "watsonx": WatsonxProvider,
            "gateway": GatewayProvider,
        }

        provider_class = registry.get(llm_config.provider)
        if provider_class is None:
            raise ValueError(f"Unsupported LLM provider: {llm_config.provider}. Supported providers: {list(registry.keys())}")

        logger.info(f"Creating LLM provider: {llm_config.provider}")
        return provider_class(llm_config.config)
1798# ==================== CHAT HISTORY MANAGER ====================
1801class ChatHistoryManager:
1802 """
1803 Centralized chat history management with Redis and in-memory fallback.
1805 Provides a unified interface for storing and retrieving chat histories across
1806 multiple workers using Redis, with automatic fallback to in-memory storage
1807 when Redis is not available.
1809 This class eliminates duplication between router and service layers by
1810 providing a single source of truth for all chat history operations.
1812 Attributes:
1813 redis_client: Optional Redis async client for distributed storage.
1814 max_messages: Maximum number of messages to retain per user.
1815 ttl: Time-to-live for Redis entries in seconds.
1816 _memory_store: In-memory dict fallback when Redis unavailable.
1818 Examples:
1819 >>> import asyncio
1820 >>> # Create manager without Redis (in-memory mode)
1821 >>> manager = ChatHistoryManager(redis_client=None, max_messages=50)
1822 >>> # asyncio.run(manager.save_history("user123", [{"role": "user", "content": "Hello"}]))
1823 >>> # history = asyncio.run(manager.get_history("user123"))
1824 >>> # len(history) >= 0
1825 True
1827 Note:
1828 Thread-safe for Redis operations. In-memory mode suitable for
1829 single-worker deployments only.
1830 """
1832 def __init__(self, redis_client: Optional[Any] = None, max_messages: int = 50, ttl: int = 3600):
1833 """
1834 Initialize chat history manager.
1836 Args:
1837 redis_client: Optional Redis async client. If None, uses in-memory storage.
1838 max_messages: Maximum messages to retain per user (default: 50).
1839 ttl: Time-to-live for Redis entries in seconds (default: 3600).
1841 Examples:
1842 >>> manager = ChatHistoryManager(redis_client=None, max_messages=100)
1843 >>> manager.max_messages
1844 100
1845 >>> manager.ttl
1846 3600
1847 """
1848 self.redis_client = redis_client
1849 self.max_messages = max_messages
1850 self.ttl = ttl
1851 self._memory_store: Dict[str, List[Dict[str, str]]] = {}
1853 if redis_client:
1854 logger.info("ChatHistoryManager initialized with Redis backend")
1855 else:
1856 logger.info("ChatHistoryManager initialized with in-memory backend")
1858 def _history_key(self, user_id: str) -> str:
1859 """
1860 Generate Redis key for user's chat history.
1862 Args:
1863 user_id: User identifier.
1865 Returns:
1866 str: Redis key string.
1868 Examples:
1869 >>> manager = ChatHistoryManager()
1870 >>> manager._history_key("user123")
1871 'chat_history:user123'
1872 """
1873 return f"chat_history:{user_id}"
1875 async def get_history(self, user_id: str) -> List[Dict[str, str]]:
1876 """
1877 Retrieve chat history for a user.
1879 Fetches history from Redis if available, otherwise from in-memory store.
1881 Args:
1882 user_id: User identifier.
1884 Returns:
1885 List[Dict[str, str]]: List of message dictionaries with 'role' and 'content' keys.
1886 Returns empty list if no history exists.
1888 Examples:
1889 >>> import asyncio
1890 >>> manager = ChatHistoryManager()
1891 >>> # history = asyncio.run(manager.get_history("user123"))
1892 >>> # isinstance(history, list)
1893 True
1895 Note:
1896 Automatically handles JSON deserialization errors by returning empty list.
1897 """
1898 if self.redis_client:
1899 try:
1900 data = await self.redis_client.get(self._history_key(user_id))
1901 if not data:
1902 return []
1903 return orjson.loads(data)
1904 except orjson.JSONDecodeError:
1905 logger.warning(f"Failed to decode chat history for user {user_id}")
1906 return []
1907 except Exception as e:
1908 logger.error(f"Error retrieving chat history from Redis for user {user_id}: {e}")
1909 return []
1910 else:
1911 return self._memory_store.get(user_id, [])
1913 async def save_history(self, user_id: str, history: List[Dict[str, str]]) -> None:
1914 """
1915 Save chat history for a user.
1917 Stores history in Redis (with TTL) if available, otherwise in memory.
1918 Automatically trims history to max_messages before saving.
1920 Args:
1921 user_id: User identifier.
1922 history: List of message dictionaries to save.
1924 Examples:
1925 >>> import asyncio
1926 >>> manager = ChatHistoryManager(max_messages=50)
1927 >>> messages = [{"role": "user", "content": "Hello"}]
1928 >>> # asyncio.run(manager.save_history("user123", messages))
1930 Note:
1931 History is automatically trimmed to max_messages limit before storage.
1932 """
1933 # Trim history before saving
1934 trimmed = self._trim_messages(history)
1936 if self.redis_client:
1937 try:
1938 await self.redis_client.set(self._history_key(user_id), orjson.dumps(trimmed), ex=self.ttl)
1939 except Exception as e:
1940 logger.error(f"Error saving chat history to Redis for user {user_id}: {e}")
1941 else:
1942 self._memory_store[user_id] = trimmed
1944 async def append_message(self, user_id: str, role: str, content: str) -> None:
1945 """
1946 Append a single message to user's chat history.
1948 Convenience method that fetches current history, appends the message,
1949 trims if needed, and saves back.
1951 Args:
1952 user_id: User identifier.
1953 role: Message role ('user' or 'assistant').
1954 content: Message content text.
1956 Examples:
1957 >>> import asyncio
1958 >>> manager = ChatHistoryManager()
1959 >>> # asyncio.run(manager.append_message("user123", "user", "Hello!"))
1961 Note:
1962 This method performs a read-modify-write operation which may
1963 not be atomic in distributed environments.
1964 """
1965 history = await self.get_history(user_id)
1966 history.append({"role": role, "content": content})
1967 await self.save_history(user_id, history)
async def clear_history(self, user_id: str) -> None:
    """
    Clear all chat history for a user.

    Removes the user's entry from Redis (when configured) or from the
    in-memory store. Redis failures are logged, never raised.

    Args:
        user_id: User identifier.

    Examples:
        >>> import asyncio
        >>> manager = ChatHistoryManager()
        >>> # asyncio.run(manager.clear_history("user123"))

    Note:
        This operation cannot be undone.
    """
    # In-memory backend: drop the key if present and be done.
    if not self.redis_client:
        self._memory_store.pop(user_id, None)
        return

    try:
        await self.redis_client.delete(self._history_key(user_id))
    except Exception as e:
        logger.error(f"Error clearing chat history from Redis for user {user_id}: {e}")
1994 def _trim_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
1995 """
1996 Trim message list to max_messages limit.
1998 Keeps the most recent messages up to max_messages count.
2000 Args:
2001 messages: List of message dictionaries.
2003 Returns:
2004 List[Dict[str, str]]: Trimmed message list.
2006 Examples:
2007 >>> manager = ChatHistoryManager(max_messages=2)
2008 >>> messages = [
2009 ... {"role": "user", "content": "1"},
2010 ... {"role": "assistant", "content": "2"},
2011 ... {"role": "user", "content": "3"}
2012 ... ]
2013 >>> trimmed = manager._trim_messages(messages)
2014 >>> len(trimmed)
2015 2
2016 >>> trimmed[0]["content"]
2017 '2'
2018 """
2019 if len(messages) > self.max_messages:
2020 return messages[-self.max_messages :]
2021 return messages
async def get_langchain_messages(self, user_id: str) -> List[BaseMessage]:
    """
    Get chat history as LangChain message objects.

    Converts each stored dict into a HumanMessage ('user' role) or an
    AIMessage ('assistant' role); entries with any other role are skipped.

    Args:
        user_id: User identifier.

    Returns:
        List[BaseMessage]: List of LangChain message objects.

    Examples:
        >>> import asyncio
        >>> manager = ChatHistoryManager()
        >>> # messages = asyncio.run(manager.get_langchain_messages("user123"))
        >>> # isinstance(messages, list)
        True

    Note:
        Returns empty list if LangChain is not available or history is empty.
    """
    if not _LLMCHAT_AVAILABLE:
        return []

    stored = await self.get_history(user_id)

    # Role -> message-class dispatch; unknown roles fall through silently.
    converters = {"user": HumanMessage, "assistant": AIMessage}

    converted: List[BaseMessage] = []
    for entry in stored:
        factory = converters.get(entry.get("role"))
        if factory is not None:
            converted.append(factory(content=entry.get("content", "")))

    return converted
2064# ==================== MCP CLIENT ====================
class MCPClient:
    """
    Manages MCP server connections and tool loading.

    Provides a high-level interface for connecting to MCP servers, retrieving
    available tools, and managing connection health. Supports multiple transport
    protocols including HTTP, SSE, and stdio.

    Attributes:
        config: MCP server configuration.

    Examples:
        >>> import asyncio
        >>> config = MCPServerConfig(
        ...     url="https://mcp-server.example.com/mcp",
        ...     transport="streamable_http"
        ... )
        >>> client = MCPClient(config)
        >>> client.is_connected
        False
        >>> # asyncio.run(client.connect())
        >>> # tools = asyncio.run(client.get_tools())

    Note:
        All methods are async and should be called using asyncio or within
        an async context.
    """

    def __init__(self, config: MCPServerConfig):
        """
        Initialize MCP client.

        Args:
            config: MCP server configuration with connection parameters.

        Examples:
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> client.config.transport
            'streamable_http'
        """
        self.config = config
        self._client: Optional[MultiServerMCPClient] = None
        self._tools: Optional[List[BaseTool]] = None  # cached tool list, populated lazily
        self._connected = False
        logger.info(f"MCP client initialized with transport: {config.transport}")

    async def connect(self) -> None:
        """
        Connect to the MCP server.

        Establishes connection to the configured MCP server using the specified
        transport protocol. Subsequent calls are no-ops if already connected.

        Raises:
            ConnectionError: If connection to the MCP server fails, or if the
                optional LLM-chat dependencies are not installed.

        Examples:
            >>> import asyncio
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> # asyncio.run(client.connect())
            >>> # client.is_connected -> True

        Note:
            Connection is idempotent - calling multiple times is safe.
        """
        if self._connected:
            logger.warning("MCP client already connected")
            return

        # BUGFIX: previously a missing optional dependency was only logged and
        # execution continued, calling None and surfacing a confusing
        # "'NoneType' object is not callable" wrapped in ConnectionError.
        # Fail fast with an actionable message instead (still ConnectionError,
        # so existing callers are unaffected).
        if not MultiServerMCPClient:
            logger.error("Some dependencies are missing. Install those with: pip install '.[llmchat]'")
            raise ConnectionError("MCP client dependencies are not installed. Install with: pip install '.[llmchat]'")

        try:
            logger.info(f"Connecting to MCP server via {self.config.transport}...")

            # Build server configuration for MultiServerMCPClient.
            server_config = {
                "transport": self.config.transport,
            }

            # URL-based transports carry url/headers; stdio carries command/args.
            if self.config.transport in ["streamable_http", "sse"]:
                server_config["url"] = self.config.url
                if self.config.headers:
                    server_config["headers"] = self.config.headers
            elif self.config.transport == "stdio":
                server_config["command"] = self.config.command
                if self.config.args:
                    server_config["args"] = self.config.args

            # Single-server setup registered under the name "default".
            self._client = MultiServerMCPClient({"default": server_config})
            self._connected = True
            logger.info("Successfully connected to MCP server")

        except Exception as e:
            logger.error(f"Failed to connect to MCP server: {e}")
            self._connected = False
            raise ConnectionError(f"Failed to connect to MCP server: {e}") from e

    async def disconnect(self) -> None:
        """
        Disconnect from the MCP server.

        Cleanly closes the connection and releases resources. Safe to call
        even if not connected.

        Raises:
            Exception: If cleanup operations fail.

        Examples:
            >>> import asyncio
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> # asyncio.run(client.connect())
            >>> # asyncio.run(client.disconnect())
            >>> # client.is_connected -> False

        Note:
            Clears cached tools upon disconnection.
        """
        if not self._connected:
            logger.warning("MCP client not connected")
            return

        try:
            if self._client:
                # MultiServerMCPClient manages connections internally;
                # dropping the reference is sufficient cleanup here.
                self._client = None

            self._connected = False
            self._tools = None  # invalidate cache so a reconnect reloads tools
            logger.info("Disconnected from MCP server")

        except Exception as e:
            logger.error(f"Error during disconnect: {e}")
            raise

    async def get_tools(self, force_reload: bool = False) -> List[BaseTool]:
        """
        Get tools from the MCP server.

        Retrieves available tools from the connected MCP server. Results are
        cached unless force_reload is True.

        Args:
            force_reload: Force reload tools even if cached (default: False).

        Returns:
            List[BaseTool]: List of available tools from the server.

        Raises:
            ConnectionError: If not connected to MCP server.
            Exception: If tool loading fails.

        Examples:
            >>> import asyncio
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> # asyncio.run(client.connect())
            >>> # tools = asyncio.run(client.get_tools())
            >>> # len(tools) >= 0 -> True

        Note:
            Tools are cached after first successful load for performance.
        """
        if not self._connected or not self._client:
            raise ConnectionError("Not connected to MCP server. Call connect() first.")

        if self._tools and not force_reload:
            logger.debug(f"Returning {len(self._tools)} cached tools")
            return self._tools

        try:
            logger.info("Loading tools from MCP server...")
            self._tools = await self._client.get_tools()
            logger.info(f"Successfully loaded {len(self._tools)} tools")
            return self._tools

        except Exception as e:
            logger.error(f"Failed to load tools: {e}")
            raise

    @property
    def is_connected(self) -> bool:
        """
        Check if client is connected.

        Returns:
            bool: True if connected to MCP server, False otherwise.

        Examples:
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> client.is_connected
            False
        """
        return self._connected
2283# ==================== MCP CHAT SERVICE ====================
2286class MCPChatService:
2287 """
2288 Main chat service for MCP client backend.
2289 Orchestrates chat sessions with LLM and MCP server integration.
2291 Provides a high-level interface for managing conversational AI sessions
2292 that combine LLM capabilities with MCP server tools. Handles conversation
2293 history management, tool execution, and streaming responses.
2295 This service integrates:
2296 - LLM providers (Azure OpenAI, OpenAI, Anthropic, AWS Bedrock, Ollama)
2297 - MCP server tools
2298 - Centralized chat history management (Redis or in-memory)
2299 - Streaming and non-streaming response modes
2301 Attributes:
2302 config: Complete MCP client configuration.
2303 user_id: Optional user identifier for history management.
2305 Examples:
2306 >>> import asyncio
2307 >>> config = MCPClientConfig(
2308 ... mcp_server=MCPServerConfig(
2309 ... url="https://example.com/mcp",
2310 ... transport="streamable_http"
2311 ... ),
2312 ... llm=LLMConfig(
2313 ... provider="ollama",
2314 ... config=OllamaConfig(model="llama2")
2315 ... )
2316 ... )
2317 >>> service = MCPChatService(config)
2318 >>> service.is_initialized
2319 False
2320 >>> # asyncio.run(service.initialize())
2322 Note:
2323 Must call initialize() before using chat methods.
2324 """
def __init__(self, config: MCPClientConfig, user_id: Optional[str] = None, redis_client: Optional[Any] = None):
    """
    Initialize MCP chat service.

    Args:
        config: Complete MCP client configuration.
        user_id: Optional user identifier for chat history management.
        redis_client: Optional Redis client for distributed history storage.

    Examples:
        >>> config = MCPClientConfig(
        ...     mcp_server=MCPServerConfig(
        ...         url="https://example.com/mcp",
        ...         transport="streamable_http"
        ...     ),
        ...     llm=LLMConfig(
        ...         provider="ollama",
        ...         config=OllamaConfig(model="llama2")
        ...     )
        ... )
        >>> service = MCPChatService(config, user_id="user123")
        >>> service.user_id
        'user123'
    """
    self.config = config
    self.user_id = user_id

    # Collaborators: MCP tool transport and LLM backend.
    self.mcp_client = MCPClient(config.mcp_server)
    self.llm_provider = LLMProviderFactory.create(config.llm)

    # Centralized chat history (Redis-backed when a client is provided,
    # in-memory otherwise).
    self.history_manager = ChatHistoryManager(
        redis_client=redis_client,
        max_messages=config.chat_history_max_messages,
        ttl=settings.llmchat_chat_history_ttl,
    )

    # Populated by initialize().
    self._agent = None
    self._initialized = False
    self._tools: List[BaseTool] = []

    logger.info(f"MCPChatService initialized for user: {user_id or 'anonymous'}")
async def initialize(self) -> None:
    """
    Initialize the chat service.

    Connects to the MCP server, loads its tools, builds the LLM instance,
    and wires both into a ReAct agent. Must be called before any chat
    method; a second call is a no-op.

    Raises:
        ConnectionError: If MCP server connection fails.
        Exception: If initialization fails.

    Examples:
        >>> import asyncio
        >>> config = MCPClientConfig(
        ...     mcp_server=MCPServerConfig(
        ...         url="https://example.com/mcp",
        ...         transport="streamable_http"
        ...     ),
        ...     llm=LLMConfig(
        ...         provider="ollama",
        ...         config=OllamaConfig(model="llama2")
        ...     )
        ... )
        >>> service = MCPChatService(config)
        >>> # asyncio.run(service.initialize())
        >>> # service.is_initialized -> True
    """
    if self._initialized:
        logger.warning("Chat service already initialized")
        return

    try:
        logger.info("Initializing chat service...")

        # Establish the MCP connection, then fetch the tool inventory.
        await self.mcp_client.connect()
        self._tools = await self.mcp_client.get_tools()

        # Build the model and combine it with the tools in a ReAct agent.
        model = self.llm_provider.get_llm()
        self._agent = create_react_agent(model, self._tools)

        self._initialized = True
        logger.info(f"Chat service initialized successfully with {len(self._tools)} tools")

    except Exception as e:
        logger.error(f"Failed to initialize chat service: {e}")
        self._initialized = False
        raise
async def chat(self, message: str) -> str:
    """
    Send a message and get a complete response.

    Runs the user's message through the agent (history included when a
    user_id is set), persists both turns, and returns the reply text.

    Args:
        message: User's message text.

    Returns:
        str: Complete AI response text.

    Raises:
        RuntimeError: If service not initialized.
        ValueError: If message is empty.
        Exception: If processing fails.

    Examples:
        >>> import asyncio
        >>> # Assuming service is initialized
        >>> # response = asyncio.run(service.chat("Hello!"))
        >>> # isinstance(response, str)
        True

    Note:
        Automatically saves conversation history after response.
    """
    if not self._initialized or not self._agent:
        raise RuntimeError("Chat service not initialized. Call initialize() first.")
    if not message or not message.strip():
        raise ValueError("Message cannot be empty")

    try:
        logger.debug("Processing chat message...")

        # Prior turns; anonymous sessions start from scratch.
        if self.user_id:
            conversation = await self.history_manager.get_langchain_messages(self.user_id)
        else:
            conversation = []
        conversation.append(HumanMessage(content=message))

        result = await self._agent.ainvoke({"messages": conversation})

        # The final message in the agent output is the assistant reply.
        last = result["messages"][-1]
        if hasattr(last, "content"):
            answer = last.content
        else:
            answer = str(last)

        # Persist both turns for identified users.
        if self.user_id:
            await self.history_manager.append_message(self.user_id, "user", message)
            await self.history_manager.append_message(self.user_id, "assistant", answer)

        logger.debug("Chat message processed successfully")
        return answer

    except Exception as e:
        logger.error(f"Error processing chat message: {e}")
        raise
async def chat_with_metadata(self, message: str) -> Dict[str, Any]:
    """
    Send a message and get response with metadata.

    Drains the chat_events() stream into a single summary: the full text,
    every tool invocation event, and the timing info from the final event.

    Args:
        message: User's message text.

    Returns:
        Dict[str, Any]: Dictionary containing:
            - text (str): Complete response text
            - tool_used (bool): Whether any tools were invoked
            - tools (List[str]): Names of tools that were used
            - tool_invocations (List[dict]): Detailed tool invocation data
            - elapsed_ms (int): Processing time in milliseconds

    Raises:
        RuntimeError: If service not initialized.
        ValueError: If message is empty.

    Examples:
        >>> import asyncio
        >>> # Assuming service is initialized
        >>> # result = asyncio.run(service.chat_with_metadata("What's 2+2?"))
        >>> # 'text' in result and 'elapsed_ms' in result
        True

    Note:
        This method collects all events and returns them as a single response.
    """
    chunks: list[str] = []
    invocations: list[dict[str, Any]] = []
    final_event: dict[str, Any] = {}

    async for event in self.chat_events(message):
        etype = event.get("type")
        if etype == "token":
            chunks.append(event.get("content", ""))
        elif etype in ("tool_start", "tool_end", "tool_error"):
            invocations.append(event)
        elif etype == "final":
            final_event = event

    return {
        "text": "".join(chunks),
        "tool_used": final_event.get("tool_used", False),
        "tools": final_event.get("tools", []),
        "tool_invocations": invocations,
        "elapsed_ms": final_event.get("elapsed_ms"),
    }
async def chat_stream(self, message: str) -> AsyncGenerator[str, None]:
    """
    Send a message and stream the response.

    Yields token chunks as the model produces them. When streaming is
    disabled in config, the complete answer is produced as one chunk.

    Args:
        message: User's message text.

    Yields:
        str: Chunks of AI response text.

    Raises:
        RuntimeError: If service not initialized.
        Exception: If streaming fails.

    Examples:
        >>> import asyncio
        >>> async def stream_example():
        ...     # Assuming service is initialized
        ...     chunks = []
        ...     async for chunk in service.chat_stream("Hello"):
        ...         chunks.append(chunk)
        ...     return ''.join(chunks)
        >>> # full_response = asyncio.run(stream_example())

    Note:
        Falls back to non-streaming if enable_streaming is False in config.
    """
    if not self._initialized or not self._agent:
        raise RuntimeError("Chat service not initialized. Call initialize() first.")

    if not self.config.enable_streaming:
        # Non-streaming fallback: whole answer as a single chunk.
        yield await self.chat(message)
        return

    try:
        logger.debug("Processing streaming chat message...")

        # Prior turns plus the new user message.
        conversation = await self.history_manager.get_langchain_messages(self.user_id) if self.user_id else []
        conversation.append(HumanMessage(content=message))

        accumulated = ""
        async for event in self._agent.astream_events({"messages": conversation}, version="v2"):
            # Only model token events are surfaced to the caller.
            if event["event"] != "on_chat_model_stream":
                continue
            chunk = event.get("data", {}).get("chunk")
            content = chunk.content if chunk and hasattr(chunk, "content") else None
            if content:
                accumulated += content
                yield content

        # Persist both turns once the full reply is known.
        if self.user_id and accumulated:
            await self.history_manager.append_message(self.user_id, "user", message)
            await self.history_manager.append_message(self.user_id, "assistant", accumulated)

        logger.debug("Streaming chat message processed successfully")

    except Exception as e:
        logger.error(f"Error processing streaming chat message: {e}")
        raise
async def chat_events(self, message: str) -> AsyncGenerator[Dict[str, Any], None]:
    """
    Stream structured events during chat processing.

    Provides granular visibility into the chat processing pipeline by yielding
    structured events for tokens, tool invocations, errors, and final results.

    Args:
        message: User's message text.

    Yields:
        dict: Event dictionaries with type-specific fields:
            - token: {"type": "token", "content": str}
            - tool_start: {"type": "tool_start", "id": str, "name": str,
              "input": Any, "start": str}
            - tool_end: {"type": "tool_end", "id": str, "name": str,
              "output": Any, "end": str}
            - tool_error: {"type": "tool_error", "id": str, "error": str,
              "time": str}
            - final: {"type": "final", "content": str, "tool_used": bool,
              "tools": List[str], "elapsed_ms": int}

    Raises:
        RuntimeError: If service not initialized.
        ValueError: If message is empty or whitespace only.

    Examples:
        >>> import asyncio
        >>> async def event_example():
        ...     # Assuming service is initialized
        ...     events = []
        ...     async for event in service.chat_events("Hello"):
        ...         events.append(event['type'])
        ...     return events
        >>> # event_types = asyncio.run(event_example())
        >>> # 'final' in event_types -> True

    Note:
        This is the most detailed chat method, suitable for building
        interactive UIs or detailed logging systems.
    """
    if not self._initialized or not self._agent:
        raise RuntimeError("Chat service not initialized. Call initialize() first.")

    # Validate message
    if not message or not message.strip():
        raise ValueError("Message cannot be empty")

    # Get conversation history
    lc_messages = await self.history_manager.get_langchain_messages(self.user_id) if self.user_id else []

    # Append user message
    user_message = HumanMessage(content=message)
    lc_messages.append(user_message)

    full_response = ""
    start_ts = time.time()
    # run_id -> {"name", "start", "input", ["end"], ["output"]} for every tool start seen.
    tool_runs: dict[str, dict[str, Any]] = {}
    # Buffer for out-of-order on_tool_end events (end arrives before start)
    pending_tool_ends: dict[str, dict[str, Any]] = {}
    pending_ttl_seconds = 30.0  # Max time to hold pending end events
    pending_max_size = 100  # Max number of pending end events to buffer
    # Track dropped run_ids for aggregated error (TTL-expired or buffer-full)
    dropped_tool_ends: set[str] = set()
    dropped_max_size = 200  # Max dropped IDs to track (prevents unbounded growth)
    dropped_overflow_count = 0  # Count of drops that couldn't be tracked due to full buffer

    def _extract_output(raw_output: Any) -> Any:
        """Extract output value from various LangChain output formats.

        Prefers `.content` (message-like objects), then `.dict()` (pydantic-like
        models), then falls back to `str()` for anything not JSON-serializable.

        Args:
            raw_output: The raw output from a tool execution.

        Returns:
            The extracted output value in a serializable format.
        """
        if hasattr(raw_output, "content"):
            return raw_output.content
        if hasattr(raw_output, "dict") and callable(raw_output.dict):
            return raw_output.dict()
        if not isinstance(raw_output, (str, int, float, bool, list, dict, type(None))):
            return str(raw_output)
        return raw_output

    def _cleanup_expired_pending(current_ts: float) -> None:
        """Remove expired entries from pending_tool_ends buffer and track them.

        Expired entries are recorded in dropped_tool_ends (bounded by
        dropped_max_size) so they can be reported in the aggregated orphan
        error after the stream completes.

        Args:
            current_ts: Current timestamp in seconds since epoch.
        """
        nonlocal dropped_overflow_count
        expired = [rid for rid, data in pending_tool_ends.items() if current_ts - data.get("buffered_at", 0) > pending_ttl_seconds]
        for rid in expired:
            logger.warning(f"Pending on_tool_end for run_id {rid} expired after {pending_ttl_seconds}s (orphan event)")
            if len(dropped_tool_ends) < dropped_max_size:
                dropped_tool_ends.add(rid)
            else:
                dropped_overflow_count += 1
                logger.warning(f"Dropped tool ends tracking full ({dropped_max_size}), cannot track expired run_id {rid} (overflow count: {dropped_overflow_count})")
            del pending_tool_ends[rid]

    try:
        # version="v2" selects LangChain's structured astream_events schema.
        async for event in self._agent.astream_events({"messages": lc_messages}, version="v2"):
            kind = event.get("event")
            now_iso = datetime.now(timezone.utc).isoformat()
            now_ts = time.time()

            # Periodically cleanup expired pending ends
            _cleanup_expired_pending(now_ts)

            # Per-event guard: a malformed event is logged and skipped so one
            # bad event cannot terminate the whole stream (see `continue` below).
            try:
                if kind == "on_tool_start":
                    run_id = str(event.get("run_id") or uuid4())
                    name = event.get("name") or event.get("data", {}).get("name") or event.get("data", {}).get("tool")
                    input_data = event.get("data", {}).get("input")

                    # Filter out common metadata keys injected by LangChain/LangGraph
                    if isinstance(input_data, dict):
                        input_data = {k: v for k, v in input_data.items() if k not in ["runtime", "config", "run_manager", "callbacks"]}

                    tool_runs[run_id] = {"name": name, "start": now_iso, "input": input_data}

                    # Register run for cancellation tracking with gateway-level Cancellation service
                    async def _noop_cancel_cb(reason: Optional[str]) -> None:
                        """
                        No-op cancel callback used when a run is started.

                        Args:
                            reason: Optional textual reason for cancellation.

                        Returns:
                            None
                        """
                        # Default no-op; kept for potential future intra-process cancellation
                        return None

                    # Register with cancellation service only if feature is enabled
                    if settings.mcpgateway_tool_cancellation_enabled:
                        try:
                            await cancellation_service.register_run(run_id, name=name, cancel_callback=_noop_cancel_cb)
                        except Exception:
                            logger.exception("Failed to register run %s with CancellationService", run_id)

                    yield {"type": "tool_start", "id": run_id, "tool": name, "input": input_data, "start": now_iso}

                    # NOTE: Do NOT clear from dropped_tool_ends here. If an end was dropped (TTL/buffer-full)
                    # before this start arrived, that end is permanently lost. Since tools only end once,
                    # we won't receive another end event, so this should still be reported as an orphan.

                    # Check if we have a buffered end event for this run_id (out-of-order reconciliation)
                    if run_id in pending_tool_ends:
                        buffered = pending_tool_ends.pop(run_id)
                        tool_runs[run_id]["end"] = buffered["end_time"]
                        tool_runs[run_id]["output"] = buffered["output"]
                        logger.info(f"Reconciled out-of-order on_tool_end for run_id {run_id}")

                        # Empty-string output is treated as a tool failure (same
                        # heuristic as the normal on_tool_end path below).
                        if tool_runs[run_id].get("output") == "":
                            error = "Tool execution failed: Please check if the tool is accessible"
                            yield {"type": "tool_error", "id": run_id, "tool": name, "error": error, "time": buffered["end_time"]}

                        yield {"type": "tool_end", "id": run_id, "tool": name, "output": tool_runs[run_id].get("output"), "end": buffered["end_time"]}

                elif kind == "on_tool_end":
                    run_id = str(event.get("run_id") or uuid4())
                    output = event.get("data", {}).get("output")
                    extracted_output = _extract_output(output)

                    if run_id in tool_runs:
                        # Normal case: start already received
                        tool_runs[run_id]["end"] = now_iso
                        tool_runs[run_id]["output"] = extracted_output

                        if tool_runs[run_id].get("output") == "":
                            error = "Tool execution failed: Please check if the tool is accessible"
                            yield {"type": "tool_error", "id": run_id, "tool": tool_runs.get(run_id, {}).get("name"), "error": error, "time": now_iso}

                        yield {"type": "tool_end", "id": run_id, "tool": tool_runs.get(run_id, {}).get("name"), "output": tool_runs[run_id].get("output"), "end": now_iso}
                    else:
                        # Out-of-order: buffer the end event for later reconciliation
                        if len(pending_tool_ends) < pending_max_size:
                            pending_tool_ends[run_id] = {"output": extracted_output, "end_time": now_iso, "buffered_at": now_ts}
                            logger.debug(f"Buffered out-of-order on_tool_end for run_id {run_id}, awaiting on_tool_start")
                        else:
                            logger.warning(f"Pending tool ends buffer full ({pending_max_size}), dropping on_tool_end for run_id {run_id}")
                            if len(dropped_tool_ends) < dropped_max_size:
                                dropped_tool_ends.add(run_id)
                            else:
                                dropped_overflow_count += 1
                                logger.warning(f"Dropped tool ends tracking full ({dropped_max_size}), cannot track run_id {run_id} (overflow count: {dropped_overflow_count})")

                    # Unregister run from cancellation service when finished (only if feature is enabled)
                    if settings.mcpgateway_tool_cancellation_enabled:
                        try:
                            await cancellation_service.unregister_run(run_id)
                        except Exception:
                            logger.exception("Failed to unregister run %s", run_id)

                elif kind == "on_tool_error":
                    run_id = str(event.get("run_id") or uuid4())
                    error = str(event.get("data", {}).get("error", "Unknown error"))

                    # Clear any buffered end for this run to avoid emitting both error and end
                    if run_id in pending_tool_ends:
                        del pending_tool_ends[run_id]
                        logger.debug(f"Cleared buffered on_tool_end for run_id {run_id} due to tool error")

                    # Clear from dropped set if this run was previously dropped (prevents false orphan)
                    dropped_tool_ends.discard(run_id)

                    yield {"type": "tool_error", "id": run_id, "tool": tool_runs.get(run_id, {}).get("name"), "error": error, "time": now_iso}

                    # Unregister run on error (only if feature is enabled)
                    if settings.mcpgateway_tool_cancellation_enabled:
                        try:
                            await cancellation_service.unregister_run(run_id)
                        except Exception:
                            logger.exception("Failed to unregister run %s after error", run_id)

                elif kind == "on_chat_model_stream":
                    chunk = event.get("data", {}).get("chunk")
                    if chunk and hasattr(chunk, "content"):
                        content = chunk.content
                        if content:
                            full_response += content
                            yield {"type": "token", "content": content}

            except Exception as event_error:
                logger.warning(f"Error processing event {kind}: {event_error}")
                continue

        # Emit aggregated error for any orphan/dropped tool ends
        # De-duplicate IDs (in case same ID was buffered and dropped in edge cases)
        all_orphan_ids = sorted(set(pending_tool_ends.keys()) | dropped_tool_ends)
        if all_orphan_ids or dropped_overflow_count > 0:
            buffered_count = len(pending_tool_ends)
            dropped_count = len(dropped_tool_ends)
            total_unique = len(all_orphan_ids)
            total_affected = total_unique + dropped_overflow_count
            logger.warning(
                f"Stream completed with {total_affected} orphan tool end(s): {buffered_count} buffered, {dropped_count} dropped (tracked), {dropped_overflow_count} dropped (untracked overflow)"
            )
            # Log full list at debug level for observability
            if all_orphan_ids:
                logger.debug(f"Full orphan run_id list: {', '.join(all_orphan_ids)}")
            now_iso = datetime.now(timezone.utc).isoformat()
            error_parts = []
            if buffered_count > 0:
                error_parts.append(f"{buffered_count} buffered")
            if dropped_count > 0:
                error_parts.append(f"{dropped_count} dropped (TTL expired or buffer full)")
            if dropped_overflow_count > 0:
                error_parts.append(f"{dropped_overflow_count} additional dropped (tracking overflow)")
            error_msg = f"Tool execution incomplete: {total_affected} tool end(s) received without matching start ({', '.join(error_parts)})"
            # Truncate to first 10 IDs in error message to avoid excessive payload
            if all_orphan_ids:
                max_display_ids = 10
                display_ids = all_orphan_ids[:max_display_ids]
                remaining = total_unique - len(display_ids)
                if remaining > 0:
                    error_msg += f". Run IDs (first {max_display_ids} of {total_unique}): {', '.join(display_ids)} (+{remaining} more)"
                else:
                    error_msg += f". Run IDs: {', '.join(display_ids)}"
            yield {
                "type": "tool_error",
                "id": str(uuid4()),
                "tool": None,
                "error": error_msg,
                "time": now_iso,
            }
            pending_tool_ends.clear()
            dropped_tool_ends.clear()

        # Calculate elapsed time
        elapsed_ms = int((time.time() - start_ts) * 1000)

        # Determine tool usage
        tools_used = list({tr["name"] for tr in tool_runs.values() if tr.get("name")})

        # Yield final event
        yield {"type": "final", "content": full_response, "tool_used": len(tools_used) > 0, "tools": tools_used, "elapsed_ms": elapsed_ms}

        # Save history
        if self.user_id and full_response:
            await self.history_manager.append_message(self.user_id, "user", message)
            await self.history_manager.append_message(self.user_id, "assistant", full_response)

    except Exception as e:
        logger.error(f"Error in chat_events: {e}")
        raise RuntimeError(f"Chat processing error: {e}") from e
async def get_conversation_history(self) -> List[Dict[str, str]]:
    """
    Get conversation history for the current user.

    Returns:
        List[Dict[str, str]]: Conversation messages with keys:
            - role (str): "user" or "assistant"
            - content (str): Message text

    Examples:
        >>> import asyncio
        >>> # Assuming service is initialized with user_id
        >>> # history = asyncio.run(service.get_conversation_history())
        >>> # all('role' in msg and 'content' in msg for msg in history)
        True

    Note:
        Returns empty list if no user_id set or no history exists.
    """
    # Anonymous sessions keep no persistent history.
    if not self.user_id:
        return []

    stored = await self.history_manager.get_history(self.user_id)
    return stored
async def clear_history(self) -> None:
    """
    Clear conversation history for the current user.

    Delegates deletion to the history manager; useful for starting a fresh
    conversation or reclaiming storage.

    Examples:
        >>> import asyncio
        >>> # Assuming service is initialized with user_id
        >>> # asyncio.run(service.clear_history())
        >>> # history = asyncio.run(service.get_conversation_history())
        >>> # len(history) -> 0

    Note:
        This action cannot be undone. No-op if no user_id set.
    """
    uid = self.user_id
    if not uid:
        # Anonymous session: nothing stored, nothing to clear.
        return

    await self.history_manager.clear_history(uid)
    logger.info(f"Conversation history cleared for user {self.user_id}")
2946 async def shutdown(self) -> None:
2947 """
2948 Shutdown the chat service and cleanup resources.
2950 Performs graceful shutdown by disconnecting from MCP server, clearing
2951 agent and history, and resetting initialization state.
2953 Raises:
2954 Exception: If cleanup operations fail.
2956 Examples:
2957 >>> import asyncio
2958 >>> config = MCPClientConfig(
2959 ... mcp_server=MCPServerConfig(
2960 ... url="https://example.com/mcp",
2961 ... transport="streamable_http"
2962 ... ),
2963 ... llm=LLMConfig(
2964 ... provider="ollama",
2965 ... config=OllamaConfig(model="llama2")
2966 ... )
2967 ... )
2968 >>> service = MCPChatService(config)
2969 >>> # asyncio.run(service.initialize())
2970 >>> # asyncio.run(service.shutdown())
2971 >>> # service.is_initialized -> False
2973 Note:
2974 Should be called when service is no longer needed to properly
2975 release resources and connections.
2976 """
2977 logger.info("Shutting down chat service...")
2979 try:
2980 # Disconnect from MCP server
2981 if self.mcp_client.is_connected:
2982 await self.mcp_client.disconnect()
2984 # Clear state
2985 self._agent = None
2986 self._initialized = False
2987 self._tools = []
2989 logger.info("Chat service shutdown complete")
2991 except Exception as e:
2992 logger.error(f"Error during shutdown: {e}")
2993 raise
2995 @property
2996 def is_initialized(self) -> bool:
2997 """
2998 Check if service is initialized.
3000 Returns:
3001 bool: True if service is initialized and ready, False otherwise.
3003 Examples:
3004 >>> config = MCPClientConfig(
3005 ... mcp_server=MCPServerConfig(
3006 ... url="https://example.com/mcp",
3007 ... transport="streamable_http"
3008 ... ),
3009 ... llm=LLMConfig(
3010 ... provider="ollama",
3011 ... config=OllamaConfig(model="llama2")
3012 ... )
3013 ... )
3014 >>> service = MCPChatService(config)
3015 >>> service.is_initialized
3016 False
3018 Note:
3019 Service must be initialized before calling chat methods.
3020 """
3021 return self._initialized
3023 async def reload_tools(self) -> int:
3024 """
3025 Reload tools from MCP server.
3027 Forces a reload of tools from the MCP server and recreates the agent
3028 with the updated tool set. Useful when MCP server tools have changed.
3030 Returns:
3031 int: Number of tools successfully loaded.
3033 Raises:
3034 RuntimeError: If service not initialized.
3035 Exception: If tool reloading or agent recreation fails.
3037 Examples:
3038 >>> import asyncio
3039 >>> # Assuming service is initialized
3040 >>> # tool_count = asyncio.run(service.reload_tools())
3041 >>> # tool_count >= 0 -> True
3043 Note:
3044 This operation recreates the agent, so it may briefly interrupt
3045 ongoing conversations. Conversation history is preserved.
3046 """
3047 if not self._initialized:
3048 raise RuntimeError("Chat service not initialized")
3050 try:
3051 logger.info("Reloading tools from MCP server...")
3052 tools = await self.mcp_client.get_tools(force_reload=True)
3054 # Recreate agent with new tools
3055 llm = self.llm_provider.get_llm()
3056 self._agent = create_react_agent(llm, tools)
3057 self._tools = tools
3059 logger.info(f"Reloaded {len(tools)} tools successfully")
3060 return len(tools)
3062 except Exception as e:
3063 logger.error(f"Failed to reload tools: {e}")
3064 raise