Coverage for mcpgateway / services / mcp_client_chat_service.py: 98%
888 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/mcp_client_chat_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Keval Mahajan
7MCP Client Service Module.
9This module provides a comprehensive client implementation for interacting with
10MCP servers, managing LLM providers, and orchestrating conversational AI agents.
11It supports multiple transport protocols and LLM providers.
13The module consists of several key components:
14- Configuration classes for MCP servers and LLM providers
15- LLM provider factory and implementations
16- MCP client for tool management
17- Chat history manager for Redis and in-memory storage
18- Chat service for conversational interactions
19"""
21# Standard
22from datetime import datetime, timezone
23import time
24from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Union
25from uuid import uuid4
27# Third-Party
28import orjson
30try:
31 # Third-Party
32 from langchain_core.language_models import BaseChatModel
33 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
34 from langchain_core.tools import BaseTool
35 from langchain_mcp_adapters.client import MultiServerMCPClient
36 from langchain_ollama import ChatOllama, OllamaLLM
37 from langchain_openai import AzureChatOpenAI, AzureOpenAI, ChatOpenAI, OpenAI
38 from langgraph.prebuilt import create_react_agent
40 _LLMCHAT_AVAILABLE = True
41except ImportError:
42 # Optional dependencies for LLM chat feature not installed
43 # These are only needed if LLMCHAT_ENABLED=true
44 _LLMCHAT_AVAILABLE = False
45 BaseChatModel = None # type: ignore
46 AIMessage = None # type: ignore
47 BaseMessage = None # type: ignore
48 HumanMessage = None # type: ignore
49 BaseTool = None # type: ignore
50 MultiServerMCPClient = None # type: ignore
51 ChatOllama = None # type: ignore
52 OllamaLLM = None
53 AzureChatOpenAI = None # type: ignore
54 AzureOpenAI = None
55 ChatOpenAI = None # type: ignore
56 OpenAI = None
57 create_react_agent = None # type: ignore
59# Try to import Anthropic and Bedrock providers (they may not be installed)
60try:
61 # Third-Party
62 from langchain_anthropic import AnthropicLLM, ChatAnthropic
64 _ANTHROPIC_AVAILABLE = True
65except ImportError:
66 _ANTHROPIC_AVAILABLE = False
67 ChatAnthropic = None # type: ignore
68 AnthropicLLM = None
70try:
71 # Third-Party
72 from langchain_aws import BedrockLLM, ChatBedrock
74 _BEDROCK_AVAILABLE = True
75except ImportError:
76 _BEDROCK_AVAILABLE = False
77 ChatBedrock = None # type: ignore
78 BedrockLLM = None
80try:
81 # Third-Party
82 from langchain_ibm import ChatWatsonx, WatsonxLLM
84 _WATSONX_AVAILABLE = True
85except ImportError:
86 _WATSONX_AVAILABLE = False
87 WatsonxLLM = None # type: ignore
88 ChatWatsonx = None
90# Third-Party
91from pydantic import BaseModel, Field, field_validator, model_validator
93# First-Party
94from mcpgateway.common.validators import SecurityValidator
95from mcpgateway.config import settings
96from mcpgateway.observability import create_span, set_span_attribute
97from mcpgateway.services.cancellation_service import cancellation_service
98from mcpgateway.services.logging_service import LoggingService
99from mcpgateway.utils.trace_redaction import is_input_capture_enabled, is_output_capture_enabled, serialize_trace_payload
101logging_service = LoggingService()
102logger = logging_service.get_logger(__name__)
def _llm_system_name(service: "MCPChatService") -> str:
    """Derive the provider/system label used for trace attributes.

    The label is the class name of ``service.llm_provider`` with every
    occurrence of ``"Provider"`` removed and the remainder lower-cased.

    Args:
        service: Chat service instance holding the active LLM provider.

    Returns:
        Lowercase provider name for GenAI trace attributes, or ``"unknown"``
        when stripping ``"Provider"`` leaves an empty string.
    """
    provider_cls = type(service.llm_provider)
    label = provider_cls.__name__.replace("Provider", "").lower()
    if label:
        return label
    return "unknown"
def _set_usage_attributes(span: Any, ai_message: Any) -> None:
    """Copy token usage counters from a provider response onto a span.

    Nothing is recorded unless ``ai_message.usage_metadata`` exists and is a
    ``dict``. Each counter is written only when its value is not ``None``.

    Args:
        span: Active span object to enrich.
        ai_message: Provider response object that may expose ``usage_metadata``.
    """
    usage = getattr(ai_message, "usage_metadata", None)
    if isinstance(usage, dict):
        # (key in usage_metadata, span attribute name), in emission order.
        attribute_map = (
            ("input_tokens", "gen_ai.usage.prompt_tokens"),
            ("output_tokens", "gen_ai.usage.completion_tokens"),
            ("total_tokens", "gen_ai.usage.total_tokens"),
        )
        for usage_key, attribute_name in attribute_map:
            token_count = usage.get(usage_key)
            if token_count is not None:
                set_span_attribute(span, attribute_name, token_count)
class ChatProcessingError(RuntimeError):
    """Recoverable chat-streaming failure.

    Wraps tool, parsing, or model errors that occur during chat streaming.
    """
class MCPServerConfig(BaseModel):
    """
    Configuration for MCP server connection.

    This class defines the configuration parameters required to connect to an
    MCP (Model Context Protocol) server using various transport mechanisms.

    Attributes:
        url: MCP server URL for streamable_http/sse transports.
        command: Command to run for stdio transport.
        args: Command-line arguments for stdio command.
        transport: Transport type (streamable_http, sse, or stdio).
        auth_token: Authentication token for HTTP-based transports.
        headers: Additional HTTP headers for request customization.

    Examples:
        >>> # HTTP-based transport
        >>> config = MCPServerConfig(
        ...     url="https://mcp-server.example.com/mcp",
        ...     transport="streamable_http",
        ...     auth_token="secret-token"
        ... )
        >>> config.transport
        'streamable_http'

        >>> # Stdio transport (requires explicit feature flag)
        >>> settings.mcpgateway_stdio_transport_enabled = True
        >>> config = MCPServerConfig(
        ...     command="python",
        ...     args=["server.py"],
        ...     transport="stdio"
        ... )
        >>> config.command
        'python'

    Note:
        The auth_token is automatically added to headers as a Bearer token
        for HTTP-based transports, including when ``transport`` is omitted
        and the default HTTP transport applies.
    """

    url: Optional[str] = Field(None, description="MCP server URL for streamable_http/sse transports")
    command: Optional[str] = Field(None, description="Command to run for stdio transport")
    args: Optional[list[str]] = Field(None, description="Arguments for stdio command")
    transport: Literal["streamable_http", "sse", "stdio"] = Field(default="streamable_http", description="Transport type for MCP connection")
    auth_token: Optional[str] = Field(None, description="Authentication token for the server")
    headers: Optional[Dict[str, str]] = Field(default=None, description="Additional headers for HTTP-based transports")

    @model_validator(mode="before")
    @classmethod
    def add_auth_to_headers(cls, values: Any) -> Any:
        """
        Automatically add authentication token to headers if provided.

        This validator ensures that if an auth_token is provided for HTTP-based
        transports, it's automatically added to the headers as a Bearer token.
        An existing ``Authorization`` header (matched case-insensitively) is
        never overwritten. The caller's input mapping is not mutated; an
        updated copy is returned instead.

        Args:
            values: Raw input before field validation. Anything that is not a
                dict is returned untouched.

        Returns:
            Dict[str, Any]: Values with the auth token merged into headers when
            applicable; otherwise the input unchanged.

        Examples:
            >>> values = {
            ...     "url": "https://api.example.com",
            ...     "transport": "streamable_http",
            ...     "auth_token": "token123"
            ... }
            >>> result = MCPServerConfig.add_auth_to_headers(values)
            >>> result['headers']['Authorization']
            'Bearer token123'
            >>> "headers" in values  # the input is left untouched
            False

            >>> # Omitting transport falls back to the field default
            >>> MCPServerConfig.add_auth_to_headers(
            ...     {"url": "https://api.example.com", "auth_token": "t"}
            ... )["headers"]
            {'Authorization': 'Bearer t'}
        """
        if not isinstance(values, dict):
            # "before" validators can receive non-dict input; nothing to merge.
            return values

        auth_token = values.get("auth_token")
        # When the caller omits ``transport`` the field default applies, so use
        # the same default here; otherwise the token would be silently dropped.
        transport = values.get("transport", cls.model_fields["transport"].default)
        if not auth_token or transport not in ("streamable_http", "sse"):
            return values

        headers = values.get("headers") or {}
        if not isinstance(headers, dict):
            # Leave malformed headers for field validation to report.
            return values

        # HTTP header names are case-insensitive: respect any explicit header.
        if any(isinstance(name, str) and name.lower() == "authorization" for name in headers):
            return values

        return {**values, "headers": {**headers, "Authorization": f"Bearer {auth_token}"}}

    @model_validator(mode="after")
    def validate_transport_requirements(self) -> "MCPServerConfig":
        """Validate transport-specific requirements and feature flags.

        Returns:
            MCPServerConfig: Validated config instance.

        Raises:
            ValueError: If transport requirements or feature flags are violated.
        """
        if self.transport in ["streamable_http", "sse"] and not self.url:
            raise ValueError(f"URL is required for {self.transport} transport")

        if self.transport == "stdio":
            # stdio launches a local command, so it is gated behind a setting.
            if not settings.mcpgateway_stdio_transport_enabled:
                raise ValueError("stdio transport is disabled by default; set MCPGATEWAY_STDIO_TRANSPORT_ENABLED=true to enable it")
            if not self.command:
                raise ValueError("Command is required for stdio transport")

        return self

    model_config = {
        "json_schema_extra": {
            "examples": [
                {"url": "https://mcp-server.example.com/mcp", "transport": "streamable_http", "auth_token": "your-token-here"},  # nosec B105 - example placeholder
                {"command": "python", "args": ["server.py"], "transport": "stdio"},
            ]
        }
    }
class AzureOpenAIConfig(BaseModel):
    """
    Settings for the Azure OpenAI provider.

    Holds the credentials, endpoint, deployment, and request parameters used
    to talk to an Azure OpenAI deployment.

    Attributes:
        api_key: Azure OpenAI API authentication key.
        azure_endpoint: Azure OpenAI service endpoint URL.
        api_version: API version to use for requests.
        azure_deployment: Name of the deployed model.
        model: Model identifier for logging and tracing.
        temperature: Sampling temperature for response generation (0.0-2.0).
        max_tokens: Maximum number of tokens to generate (positive, optional).
        timeout: Request timeout duration in seconds (positive, optional).
        max_retries: Maximum number of retry attempts for failed requests.

    Examples:
        >>> config = AzureOpenAIConfig(
        ...     api_key="your-api-key",
        ...     azure_endpoint="https://your-resource.openai.azure.com/",
        ...     azure_deployment="gpt-4",
        ...     temperature=0.7
        ... )
        >>> config.model
        'gpt-4'
        >>> config.temperature
        0.7
    """

    api_key: str = Field(..., description="Azure OpenAI API key")
    azure_endpoint: str = Field(..., description="Azure OpenAI endpoint URL")
    api_version: str = Field(default="2024-05-01-preview", description="Azure OpenAI API version")
    azure_deployment: str = Field(..., description="Azure OpenAI deployment name")
    model: str = Field(default="gpt-4", description="Model name for tracing")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(default=None, gt=0, description="Maximum tokens to generate")
    timeout: Optional[float] = Field(default=None, gt=0, description="Request timeout in seconds")
    max_retries: int = Field(default=2, ge=0, description="Maximum number of retries")

    model_config = {
        "json_schema_extra": {
            "example": {
                "api_key": "your-api-key",
                "azure_endpoint": "https://your-resource.openai.azure.com/",
                "api_version": "2024-05-01-preview",
                "azure_deployment": "gpt-4",
                "model": "gpt-4",
                "temperature": 0.7,
            }
        }
    }
class OllamaConfig(BaseModel):
    """
    Settings for the Ollama provider.

    Describes how to reach a local or remote Ollama instance and which
    open-source model to run on it.

    Attributes:
        base_url: Ollama server base URL.
        model: Name of the Ollama model to use.
        temperature: Sampling temperature for response generation (0.0-2.0).
        timeout: Request timeout duration in seconds (positive, optional).
        num_ctx: Context window size for the model (positive, optional).

    Examples:
        >>> config = OllamaConfig(
        ...     base_url="http://localhost:11434",
        ...     model="llama2",
        ...     temperature=0.5
        ... )
        >>> config.model
        'llama2'
        >>> config.base_url
        'http://localhost:11434'
    """

    base_url: str = Field(default="http://localhost:11434", description="Ollama base URL")
    model: str = Field(default="llama2", description="Model name to use")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    timeout: Optional[float] = Field(default=None, gt=0, description="Request timeout in seconds")
    num_ctx: Optional[int] = Field(default=None, gt=0, description="Context window size")

    model_config = {
        "json_schema_extra": {
            "example": {
                "base_url": "http://localhost:11434",
                "model": "llama2",
                "temperature": 0.7,
            }
        }
    }
class OpenAIConfig(BaseModel):
    """
    Settings for the OpenAI provider (non-Azure).

    Covers the OpenAI API as well as OpenAI-compatible endpoints reached
    through a custom ``base_url``.

    Attributes:
        api_key: OpenAI API authentication key.
        base_url: Optional base URL for OpenAI-compatible endpoints.
        model: Model identifier (e.g., gpt-4, gpt-3.5-turbo).
        temperature: Sampling temperature for response generation (0.0-2.0).
        max_tokens: Maximum number of tokens to generate (positive, optional).
        timeout: Request timeout duration in seconds (positive, optional).
        max_retries: Maximum number of retry attempts for failed requests.
        default_headers: Optional default headers required by the provider.

    Examples:
        >>> config = OpenAIConfig(
        ...     api_key="sk-...",
        ...     model="gpt-4",
        ...     temperature=0.7
        ... )
        >>> config.model
        'gpt-4'
    """

    api_key: str = Field(..., description="OpenAI API key")
    base_url: Optional[str] = Field(default=None, description="Base URL for OpenAI-compatible endpoints")
    model: str = Field(default="gpt-4o-mini", description="Model name (e.g., gpt-4, gpt-3.5-turbo)")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(default=None, gt=0, description="Maximum tokens to generate")
    timeout: Optional[float] = Field(default=None, gt=0, description="Request timeout in seconds")
    max_retries: int = Field(default=2, ge=0, description="Maximum number of retries")
    default_headers: Optional[dict] = Field(default=None, description="optional default headers required by the provider")

    model_config = {
        "json_schema_extra": {
            "example": {
                "api_key": "sk-...",
                "model": "gpt-4o-mini",
                "temperature": 0.7,
            }
        }
    }
class AnthropicConfig(BaseModel):
    """
    Settings for the Anthropic Claude provider.

    Holds the credentials and request parameters used to call Anthropic's
    Claude API.

    Attributes:
        api_key: Anthropic API authentication key.
        model: Claude model identifier (e.g., claude-3-5-sonnet-20241022, claude-3-opus).
        temperature: Sampling temperature for response generation (0.0-1.0).
        max_tokens: Maximum number of tokens to generate (positive).
        timeout: Request timeout duration in seconds (positive, optional).
        max_retries: Maximum number of retry attempts for failed requests.

    Examples:
        >>> config = AnthropicConfig(
        ...     api_key="sk-ant-...",
        ...     model="claude-3-5-sonnet-20241022",
        ...     temperature=0.7
        ... )
        >>> config.model
        'claude-3-5-sonnet-20241022'
    """

    api_key: str = Field(..., description="Anthropic API key")
    model: str = Field(default="claude-3-5-sonnet-20241022", description="Claude model name")
    temperature: float = Field(default=0.7, ge=0.0, le=1.0, description="Sampling temperature")
    max_tokens: int = Field(default=4096, gt=0, description="Maximum tokens to generate")
    timeout: Optional[float] = Field(default=None, gt=0, description="Request timeout in seconds")
    max_retries: int = Field(default=2, ge=0, description="Maximum number of retries")

    model_config = {
        "json_schema_extra": {
            "example": {
                "api_key": "sk-ant-...",
                "model": "claude-3-5-sonnet-20241022",
                "temperature": 0.7,
                "max_tokens": 4096,
            }
        }
    }
class AWSBedrockConfig(BaseModel):
    """
    Settings for the AWS Bedrock provider.

    Identifies the Bedrock model and region, plus optional explicit AWS
    credentials and generation parameters.

    Attributes:
        model_id: Bedrock model identifier (e.g., anthropic.claude-v2, amazon.titan-text-express-v1).
        region_name: AWS region name (e.g., us-east-1, us-west-2).
        aws_access_key_id: Optional AWS access key ID (uses default credential chain if not provided).
        aws_secret_access_key: Optional AWS secret access key.
        aws_session_token: Optional AWS session token for temporary credentials.
        temperature: Sampling temperature for response generation (0.0-1.0).
        max_tokens: Maximum number of tokens to generate (positive).

    Examples:
        >>> config = AWSBedrockConfig(
        ...     model_id="anthropic.claude-v2",
        ...     region_name="us-east-1",
        ...     temperature=0.7
        ... )
        >>> config.model_id
        'anthropic.claude-v2'
    """

    model_id: str = Field(..., description="Bedrock model ID")
    region_name: str = Field(default="us-east-1", description="AWS region name")
    aws_access_key_id: Optional[str] = Field(default=None, description="AWS access key ID")
    aws_secret_access_key: Optional[str] = Field(default=None, description="AWS secret access key")
    aws_session_token: Optional[str] = Field(default=None, description="AWS session token")
    temperature: float = Field(default=0.7, ge=0.0, le=1.0, description="Sampling temperature")
    max_tokens: int = Field(default=4096, gt=0, description="Maximum tokens to generate")

    model_config = {
        "json_schema_extra": {
            "example": {
                "model_id": "anthropic.claude-v2",
                "region_name": "us-east-1",
                "temperature": 0.7,
                "max_tokens": 4096,
            }
        }
    }
class WatsonxConfig(BaseModel):
    """
    Settings for the IBM watsonx.ai provider.

    Holds the credentials, project, model, and decoding parameters used to
    call IBM watsonx.ai services.

    Attributes:
        api_key: IBM Cloud API key for authentication.
        url: IBM watsonx.ai service endpoint URL.
        project_id: IBM watsonx.ai project ID for context.
        model_id: Model identifier (e.g., ibm/granite-13b-chat-v2, meta-llama/llama-3-70b-instruct).
        temperature: Sampling temperature for response generation (0.0-2.0).
        max_new_tokens: Maximum number of tokens to generate.
        min_new_tokens: Minimum number of tokens to generate.
        decoding_method: Decoding method ('sample', 'greedy').
        top_k: Top-K sampling parameter.
        top_p: Top-P (nucleus) sampling parameter (0.0-1.0).
        timeout: Request timeout duration in seconds (positive, optional).

    Examples:
        >>> config = WatsonxConfig(
        ...     api_key="your-api-key",
        ...     url="https://us-south.ml.cloud.ibm.com",
        ...     project_id="your-project-id",
        ...     model_id="ibm/granite-13b-chat-v2"
        ... )
        >>> config.model_id
        'ibm/granite-13b-chat-v2'
    """

    api_key: str = Field(..., description="IBM Cloud API key")
    url: str = Field(default="https://us-south.ml.cloud.ibm.com", description="watsonx.ai endpoint URL")
    project_id: str = Field(..., description="watsonx.ai project ID")
    model_id: str = Field(default="ibm/granite-13b-chat-v2", description="Model identifier")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    max_new_tokens: Optional[int] = Field(default=1024, gt=0, description="Maximum tokens to generate")
    min_new_tokens: Optional[int] = Field(default=1, gt=0, description="Minimum tokens to generate")
    decoding_method: str = Field(default="sample", description="Decoding method (sample or greedy)")
    top_k: Optional[int] = Field(default=50, gt=0, description="Top-K sampling")
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0, description="Top-P sampling")
    timeout: Optional[float] = Field(default=None, gt=0, description="Request timeout in seconds")

    model_config = {
        "json_schema_extra": {
            "example": {
                "api_key": "your-api-key",
                "url": "https://us-south.ml.cloud.ibm.com",
                "project_id": "your-project-id",
                "model_id": "ibm/granite-13b-chat-v2",
                "temperature": 0.7,
                "max_new_tokens": 1024,
            }
        }
    }
class GatewayConfig(BaseModel):
    """
    Settings for the ContextForge internal LLM provider.

    Lets LLM Chat use models configured in the gateway's LLM Settings; the
    gateway routes requests to the appropriate configured provider.

    Attributes:
        model: Model ID (gateway model ID or provider model ID).
        base_url: Gateway internal API URL (defaults to self).
        temperature: Sampling temperature for response generation (0.0-2.0).
        max_tokens: Maximum tokens to generate (positive, optional).
        timeout: Request timeout in seconds (positive, optional).

    Examples:
        >>> config = GatewayConfig(model="gpt-4o")
        >>> config.model
        'gpt-4o'
    """

    model: str = Field(..., description="Gateway model ID to use")
    base_url: Optional[str] = Field(default=None, description="Gateway internal API URL (optional, defaults to self)")
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(default=None, gt=0, description="Maximum tokens to generate")
    timeout: Optional[float] = Field(default=None, gt=0, description="Request timeout in seconds")

    model_config = {
        "json_schema_extra": {
            "example": {
                "model": "gpt-4o",
                "temperature": 0.7,
                "max_tokens": 4096,
            }
        }
    }
class LLMConfig(BaseModel):
    """
    Configuration for LLM provider.

    Unified configuration class that pairs a provider name with the matching
    provider-specific configuration object.

    Attributes:
        provider: Type of LLM provider (azure_openai, openai, anthropic,
            aws_bedrock, ollama, watsonx, or gateway).
        config: Provider-specific configuration object.

    Examples:
        >>> # Azure OpenAI configuration
        >>> config = LLMConfig(
        ...     provider="azure_openai",
        ...     config=AzureOpenAIConfig(
        ...         api_key="key",
        ...         azure_endpoint="https://example.com/",
        ...         azure_deployment="gpt-4"
        ...     )
        ... )
        >>> config.provider
        'azure_openai'

        >>> # OpenAI configuration
        >>> config = LLMConfig(
        ...     provider="openai",
        ...     config=OpenAIConfig(
        ...         api_key="sk-...",
        ...         model="gpt-4"
        ...     )
        ... )
        >>> config.provider
        'openai'

        >>> # Ollama configuration
        >>> config = LLMConfig(
        ...     provider="ollama",
        ...     config=OllamaConfig(model="llama2")
        ... )
        >>> config.provider
        'ollama'

        >>> # Watsonx configuration
        >>> config = LLMConfig(
        ...     provider="watsonx",
        ...     config=WatsonxConfig(
        ...         url="https://us-south.ml.cloud.ibm.com",
        ...         model_id="ibm/granite-13b-instruct-v2",
        ...         project_id="YOUR_PROJECT_ID",
        ...         api_key="YOUR_API")
        ... )
        >>> config.provider
        'watsonx'
    """

    provider: Literal["azure_openai", "openai", "anthropic", "aws_bedrock", "ollama", "watsonx", "gateway"] = Field(..., description="LLM provider type")
    config: Union[AzureOpenAIConfig, OpenAIConfig, AnthropicConfig, AWSBedrockConfig, OllamaConfig, WatsonxConfig, GatewayConfig] = Field(..., description="Provider-specific configuration")

    @field_validator("config", mode="before")
    @classmethod
    def validate_config_type(cls, v: Any, info) -> Union[AzureOpenAIConfig, OpenAIConfig, AnthropicConfig, AWSBedrockConfig, OllamaConfig, WatsonxConfig, GatewayConfig]:
        """
        Validate and convert config dictionary to appropriate provider type.

        A dict is instantiated as the config class that matches the already
        validated ``provider`` value. Anything else (including a dict whose
        provider is not recognised here) is returned unchanged.

        Args:
            v: Configuration value (dict or config object).
            info: Validation context containing provider information.

        Returns:
            Union[AzureOpenAIConfig, OpenAIConfig, AnthropicConfig, AWSBedrockConfig, OllamaConfig, WatsonxConfig, GatewayConfig]:
            Provider-specific configuration object, or ``v`` unchanged when no
            conversion applies.

        Examples:
            >>> # Automatically converts dict to appropriate config type
            >>> config_dict = {
            ...     "api_key": "key",
            ...     "azure_endpoint": "https://example.com/",
            ...     "azure_deployment": "gpt-4"
            ... }
            >>> # Used internally by Pydantic during validation
        """
        provider = info.data.get("provider")

        if not isinstance(v, dict):
            return v

        # Provider name -> config class used to build the dict payload.
        config_classes = {
            "azure_openai": AzureOpenAIConfig,
            "openai": OpenAIConfig,
            "anthropic": AnthropicConfig,
            "aws_bedrock": AWSBedrockConfig,
            "ollama": OllamaConfig,
            "watsonx": WatsonxConfig,
            "gateway": GatewayConfig,
        }
        config_cls = config_classes.get(provider)
        if config_cls is None:
            return v
        return config_cls(**v)
class MCPClientConfig(BaseModel):
    """
    Top-level configuration for the MCP client service.

    Bundles the MCP server connection, the LLM provider, and the operational
    settings needed by the client service.

    Attributes:
        mcp_server: MCP server connection configuration.
        llm: LLM provider configuration.
        chat_history_max_messages: Maximum messages to retain in chat history;
            defaults to ``settings.llmchat_chat_history_max_messages``.
        enable_streaming: Whether to enable streaming responses.

    Examples:
        >>> config = MCPClientConfig(
        ...     mcp_server=MCPServerConfig(
        ...         url="https://mcp-server.example.com/mcp",
        ...         transport="streamable_http"
        ...     ),
        ...     llm=LLMConfig(
        ...         provider="ollama",
        ...         config=OllamaConfig(model="llama2")
        ...     ),
        ...     chat_history_max_messages=100,
        ...     enable_streaming=True
        ... )
        >>> config.chat_history_max_messages
        100
        >>> config.enable_streaming
        True
    """

    mcp_server: MCPServerConfig = Field(..., description="MCP server configuration")
    llm: LLMConfig = Field(..., description="LLM provider configuration")
    # Default is read from settings once, when this class body is evaluated.
    chat_history_max_messages: int = settings.llmchat_chat_history_max_messages
    enable_streaming: bool = Field(default=True, description="Enable streaming responses")

    model_config = {
        "json_schema_extra": {
            "example": {
                "mcp_server": {"url": "https://mcp-server.example.com/mcp", "transport": "streamable_http", "auth_token": "your-token"},  # nosec B105 - example placeholder
                "llm": {
                    "provider": "azure_openai",
                    "config": {
                        "api_key": "your-key",
                        "azure_endpoint": "https://your-resource.openai.azure.com/",
                        "azure_deployment": "gpt-4",
                        "api_version": "2024-05-01-preview",
                    },
                },
            }
        }
    }
726# ==================== LLM PROVIDER IMPLEMENTATIONS ====================
class AzureOpenAIProvider:
    """
    Azure OpenAI provider implementation.

    Manages connection and interaction with Azure OpenAI services.

    Attributes:
        config: Azure OpenAI configuration object.

    Examples:
        >>> config = AzureOpenAIConfig(
        ...     api_key="key",
        ...     azure_endpoint="https://example.openai.azure.com/",
        ...     azure_deployment="gpt-4"
        ... )
        >>> provider = AzureOpenAIProvider(config)
        >>> provider.get_model_name()
        'gpt-4'

    Note:
        The LLM instance is lazily initialized on first access for
        improved startup performance.
    """

    # Values accepted by ``get_llm(model_type=...)``.
    _SUPPORTED_MODEL_TYPES = ("chat", "completion")

    def __init__(self, config: AzureOpenAIConfig):
        """
        Initialize Azure OpenAI provider.

        Args:
            config: Azure OpenAI configuration with API credentials and settings.

        Examples:
            >>> config = AzureOpenAIConfig(
            ...     api_key="key",
            ...     azure_endpoint="https://example.openai.azure.com/",
            ...     azure_deployment="gpt-4"
            ... )
            >>> provider = AzureOpenAIProvider(config)
        """
        self.config = config
        self._llm = None
        # model_type the cached ``_llm`` was built for (None = nothing built yet).
        self._llm_model_type: Optional[str] = None
        logger.info(f"Initializing Azure OpenAI provider with deployment: {config.azure_deployment}")

    def get_llm(self, model_type: str = "chat") -> Union[AzureChatOpenAI, AzureOpenAI]:
        """
        Get Azure OpenAI LLM instance with lazy initialization.

        Creates and caches the model instance on first call. Subsequent calls
        with the same ``model_type`` return the cached instance; asking for the
        other model type builds (and caches) a matching instance instead of
        handing back a model of the wrong kind.

        Args:
            model_type: LLM inference model type: 'chat' for a chat model or
                'completion' for a text completion model.

        Returns:
            AzureChatOpenAI for 'chat', AzureOpenAI for 'completion'.

        Raises:
            ValueError: If ``model_type`` is neither 'chat' nor 'completion'.
            Exception: If LLM initialization fails (e.g., invalid credentials).

        Examples:
            >>> config = AzureOpenAIConfig(
            ...     api_key="key",
            ...     azure_endpoint="https://example.openai.azure.com/",
            ...     azure_deployment="gpt-4"
            ... )
            >>> provider = AzureOpenAIProvider(config)
            >>> # llm = provider.get_llm()  # Returns AzureChatOpenAI instance
        """
        # An instance assigned to ``_llm`` directly has no recorded type and is
        # treated as valid for any request, preserving the previous behavior.
        cached_type = getattr(self, "_llm_model_type", None)
        if self._llm is not None and cached_type in (None, model_type):
            return self._llm

        if model_type not in self._SUPPORTED_MODEL_TYPES:
            # Previously this fell through, logged success and returned None.
            raise ValueError(f"Unsupported model_type '{model_type}'; expected 'chat' or 'completion'")

        try:
            llm_cls = AzureChatOpenAI if model_type == "chat" else AzureOpenAI
            llm = llm_cls(
                api_key=self.config.api_key,
                azure_endpoint=self.config.azure_endpoint,
                api_version=self.config.api_version,
                azure_deployment=self.config.azure_deployment,
                model=self.config.model,
                temperature=self.config.temperature,
                max_tokens=self.config.max_tokens,
                timeout=self.config.timeout,
                max_retries=self.config.max_retries,
            )
        except Exception as e:
            logger.error(f"Failed to create Azure OpenAI LLM: {e}")
            raise

        self._llm = llm
        self._llm_model_type = model_type
        logger.info("Azure OpenAI LLM instance created successfully")
        return llm

    def get_model_name(self) -> str:
        """
        Get the Azure OpenAI model name.

        Returns:
            str: The model name configured for this provider.

        Examples:
            >>> config = AzureOpenAIConfig(
            ...     api_key="key",
            ...     azure_endpoint="https://example.openai.azure.com/",
            ...     azure_deployment="gpt-4",
            ...     model="gpt-4"
            ... )
            >>> provider = AzureOpenAIProvider(config)
            >>> provider.get_model_name()
            'gpt-4'
        """
        return self.config.model
class OllamaProvider:
    """
    Ollama provider implementation.

    Manages connection and interaction with Ollama instances for running
    open-source language models locally or remotely.

    Attributes:
        config: Ollama configuration object.

    Examples:
        >>> config = OllamaConfig(
        ...     base_url="http://localhost:11434",
        ...     model="llama2"
        ... )
        >>> provider = OllamaProvider(config)
        >>> provider.get_model_name()
        'llama2'

    Note:
        Requires Ollama to be running and accessible at the configured base_url.
    """

    # Values accepted by ``get_llm(model_type=...)``.
    _SUPPORTED_MODEL_TYPES = ("chat", "completion")

    def __init__(self, config: OllamaConfig):
        """
        Initialize Ollama provider.

        Args:
            config: Ollama configuration with server URL and model settings.

        Examples:
            >>> config = OllamaConfig(model="llama2")
            >>> provider = OllamaProvider(config)
        """
        self.config = config
        self._llm = None
        # model_type the cached ``_llm`` was built for (None = nothing built yet).
        self._llm_model_type: Optional[str] = None
        logger.info(f"Initializing Ollama provider with model: {config.model}")

    def get_llm(self, model_type: str = "chat") -> Union[ChatOllama, OllamaLLM]:
        """
        Get Ollama LLM instance with lazy initialization.

        Creates and caches the model instance on first call. Subsequent calls
        with the same ``model_type`` return the cached instance; asking for the
        other model type builds (and caches) a matching instance instead of
        handing back a model of the wrong kind.

        Args:
            model_type: LLM inference model type: 'chat' for a chat model or
                'completion' for a text completion model.

        Returns:
            ChatOllama for 'chat', OllamaLLM for 'completion'.

        Raises:
            ValueError: If ``model_type`` is neither 'chat' nor 'completion'.
            Exception: If LLM initialization fails (e.g., Ollama not running).

        Examples:
            >>> config = OllamaConfig(model="llama2")
            >>> provider = OllamaProvider(config)
            >>> # llm = provider.get_llm()  # Returns ChatOllama instance
        """
        # An instance assigned to ``_llm`` directly has no recorded type and is
        # treated as valid for any request, preserving the previous behavior.
        cached_type = getattr(self, "_llm_model_type", None)
        if self._llm is not None and cached_type in (None, model_type):
            return self._llm

        if model_type not in self._SUPPORTED_MODEL_TYPES:
            # Previously this fell through, logged success and returned None.
            raise ValueError(f"Unsupported model_type '{model_type}'; expected 'chat' or 'completion'")

        try:
            # Only forward num_ctx when it was explicitly configured.
            model_kwargs = {}
            if self.config.num_ctx is not None:
                model_kwargs["num_ctx"] = self.config.num_ctx

            llm_cls = ChatOllama if model_type == "chat" else OllamaLLM
            llm = llm_cls(
                base_url=self.config.base_url,
                model=self.config.model,
                temperature=self.config.temperature,
                timeout=self.config.timeout,
                **model_kwargs,
            )
        except Exception as e:
            logger.error(f"Failed to create Ollama LLM: {e}")
            raise

        self._llm = llm
        self._llm_model_type = model_type
        logger.info("Ollama LLM instance created successfully")
        return llm

    def get_model_name(self) -> str:
        """Get the model name.

        Returns:
            str: The model name
        """
        return self.config.model
class OpenAIProvider:
    """
    OpenAI provider implementation (non-Azure).

    Manages connection and interaction with OpenAI API or OpenAI-compatible endpoints.

    Attributes:
        config: OpenAI configuration object.

    Examples:
        >>> config = OpenAIConfig(
        ...     api_key="sk-...",
        ...     model="gpt-4"
        ... )
        >>> provider = OpenAIProvider(config)
        >>> provider.get_model_name()
        'gpt-4'

    Note:
        The LLM instance is not built in the constructor; it is created on
        the first get_llm() call and then reused.
    """

    def __init__(self, config: OpenAIConfig):
        """
        Initialize OpenAI provider.

        Args:
            config: OpenAI configuration with API key and settings.

        Examples:
            >>> config = OpenAIConfig(
            ...     api_key="sk-...",
            ...     model="gpt-4"
            ... )
            >>> provider = OpenAIProvider(config)
        """
        self.config = config
        # Filled in lazily by get_llm().
        self._llm = None
        logger.info(f"Initializing OpenAI provider with model: {config.model}")

    def get_llm(self, model_type: str = "chat") -> Union[ChatOpenAI, OpenAI]:
        """
        Get OpenAI LLM instance with lazy initialization.

        Builds the OpenAI model on the first successful call and caches it.
        Later calls return the cached instance without looking at
        ``model_type`` again, so the first call decides which kind is cached.

        Args:
            model_type: LLM inference model type: 'chat' builds a ChatOpenAI,
                'completion' builds an OpenAI text-completion model.

        Returns:
            Union[ChatOpenAI, OpenAI]: Configured (and cached) OpenAI model.

        Raises:
            ValueError: If no instance is cached yet and model_type is neither
                'chat' nor 'completion'.
            Exception: Any error raised while constructing the model is logged
                and re-raised unchanged.

        Examples:
            >>> config = OpenAIConfig(
            ...     api_key="sk-...",
            ...     model="gpt-4"
            ... )
            >>> provider = OpenAIProvider(config)
            >>> # llm = provider.get_llm()  # Returns ChatOpenAI instance
        """
        if self._llm is not None:
            return self._llm

        try:
            kwargs = {
                "api_key": self.config.api_key,
                "model": self.config.model,
                "temperature": self.config.temperature,
                "max_tokens": self.config.max_tokens,
                "timeout": self.config.timeout,
                "max_retries": self.config.max_retries,
            }

            # Optional settings are only passed when configured.
            if self.config.base_url:
                kwargs["base_url"] = self.config.base_url

            if self.config.default_headers is not None:
                kwargs["default_headers"] = self.config.default_headers

            if model_type == "chat":
                self._llm = ChatOpenAI(**kwargs)
            elif model_type == "completion":
                self._llm = OpenAI(**kwargs)
            else:
                # Previously this fell through, logged success and returned None.
                raise ValueError(f"Unsupported model_type '{model_type}'. Expected 'chat' or 'completion'.")

            logger.info("OpenAI LLM instance created successfully")
        except Exception as e:
            logger.error(f"Failed to create OpenAI LLM: {e}")
            raise

        return self._llm

    def get_model_name(self) -> str:
        """
        Get the OpenAI model name.

        Returns:
            str: The model name configured for this provider.

        Examples:
            >>> config = OpenAIConfig(
            ...     api_key="sk-...",
            ...     model="gpt-4"
            ... )
            >>> provider = OpenAIProvider(config)
            >>> provider.get_model_name()
            'gpt-4'
        """
        return self.config.model
class AnthropicProvider:
    """
    Anthropic Claude provider implementation.

    Manages connection and interaction with Anthropic's Claude API.

    Attributes:
        config: Anthropic configuration object.

    Examples:
        >>> config = AnthropicConfig(  # doctest: +SKIP
        ...     api_key="sk-ant-...",
        ...     model="claude-3-5-sonnet-20241022"
        ... )
        >>> provider = AnthropicProvider(config)  # doctest: +SKIP
        >>> provider.get_model_name()  # doctest: +SKIP
        'claude-3-5-sonnet-20241022'

    Note:
        Requires langchain-anthropic package to be installed.
    """

    def __init__(self, config: AnthropicConfig):
        """
        Initialize Anthropic provider.

        Args:
            config: Anthropic configuration with API key and settings.

        Raises:
            ImportError: If langchain-anthropic is not installed.

        Examples:
            >>> config = AnthropicConfig(  # doctest: +SKIP
            ...     api_key="sk-ant-...",
            ...     model="claude-3-5-sonnet-20241022"
            ... )
            >>> provider = AnthropicProvider(config)  # doctest: +SKIP
        """
        # Fail at construction time when the optional dependency is missing.
        if not _ANTHROPIC_AVAILABLE:
            raise ImportError("Anthropic provider requires langchain-anthropic package. Install it with: pip install langchain-anthropic")

        self.config = config
        # Filled in lazily by get_llm().
        self._llm = None
        logger.info(f"Initializing Anthropic provider with model: {config.model}")

    def get_llm(self, model_type: str = "chat") -> Union[ChatAnthropic, AnthropicLLM]:
        """
        Get Anthropic LLM instance with lazy initialization.

        Builds the Anthropic model on the first successful call and caches it.
        Later calls return the cached instance without looking at
        ``model_type`` again, so the first call decides which kind is cached.

        Args:
            model_type: LLM inference model type: 'chat' builds a
                ChatAnthropic, 'completion' builds an AnthropicLLM.

        Returns:
            Union[ChatAnthropic, AnthropicLLM]: Configured (and cached) Anthropic model.

        Raises:
            ValueError: If no instance is cached yet and model_type is neither
                'chat' nor 'completion'.
            Exception: Any error raised while constructing the model is logged
                and re-raised unchanged.

        Examples:
            >>> config = AnthropicConfig(  # doctest: +SKIP
            ...     api_key="sk-ant-...",
            ...     model="claude-3-5-sonnet-20241022"
            ... )
            >>> provider = AnthropicProvider(config)  # doctest: +SKIP
            >>> # llm = provider.get_llm()  # Returns ChatAnthropic instance
        """
        if self._llm is not None:
            return self._llm

        try:
            if model_type == "chat":
                llm_class = ChatAnthropic
            elif model_type == "completion":
                llm_class = AnthropicLLM
            else:
                # Previously this fell through, logged success and returned None.
                raise ValueError(f"Unsupported model_type '{model_type}'. Expected 'chat' or 'completion'.")

            # Both model kinds take the same constructor arguments.
            self._llm = llm_class(
                api_key=self.config.api_key,
                model=self.config.model,
                temperature=self.config.temperature,
                max_tokens=self.config.max_tokens,
                timeout=self.config.timeout,
                max_retries=self.config.max_retries,
            )
            logger.info("Anthropic LLM instance created successfully")
        except Exception as e:
            logger.error(f"Failed to create Anthropic LLM: {e}")
            raise

        return self._llm

    def get_model_name(self) -> str:
        """
        Get the Anthropic model name.

        Returns:
            str: The model name configured for this provider.

        Examples:
            >>> config = AnthropicConfig(  # doctest: +SKIP
            ...     api_key="sk-ant-...",
            ...     model="claude-3-5-sonnet-20241022"
            ... )
            >>> provider = AnthropicProvider(config)  # doctest: +SKIP
            >>> provider.get_model_name()  # doctest: +SKIP
            'claude-3-5-sonnet-20241022'
        """
        return self.config.model
class AWSBedrockProvider:
    """
    AWS Bedrock provider implementation.

    Manages connection and interaction with AWS Bedrock LLM services.

    Attributes:
        config: AWS Bedrock configuration object.

    Examples:
        >>> config = AWSBedrockConfig(  # doctest: +SKIP
        ...     model_id="anthropic.claude-v2",
        ...     region_name="us-east-1"
        ... )
        >>> provider = AWSBedrockProvider(config)  # doctest: +SKIP
        >>> provider.get_model_name()  # doctest: +SKIP
        'anthropic.claude-v2'

    Note:
        Requires langchain-aws package and boto3 to be installed.
        Explicit AWS credentials are passed to the model only when they are
        set in the config; otherwise credential resolution is left to the
        underlying Bedrock client.
    """

    def __init__(self, config: AWSBedrockConfig):
        """
        Initialize AWS Bedrock provider.

        Args:
            config: AWS Bedrock configuration with model ID and settings.

        Raises:
            ImportError: If langchain-aws is not installed.

        Examples:
            >>> config = AWSBedrockConfig(  # doctest: +SKIP
            ...     model_id="anthropic.claude-v2",
            ...     region_name="us-east-1"
            ... )
            >>> provider = AWSBedrockProvider(config)  # doctest: +SKIP
        """
        # Fail at construction time when the optional dependency is missing.
        if not _BEDROCK_AVAILABLE:
            raise ImportError("AWS Bedrock provider requires langchain-aws package. Install it with: pip install langchain-aws boto3")

        self.config = config
        # Filled in lazily by get_llm().
        self._llm = None
        logger.info(f"Initializing AWS Bedrock provider with model: {config.model_id}")

    def get_llm(self, model_type: str = "chat") -> Union[ChatBedrock, BedrockLLM]:
        """
        Get AWS Bedrock LLM instance with lazy initialization.

        Builds the Bedrock model on the first successful call and caches it.
        Later calls return the cached instance without looking at
        ``model_type`` again, so the first call decides which kind is cached.

        Args:
            model_type: LLM inference model type: 'chat' builds a ChatBedrock,
                'completion' builds a BedrockLLM.

        Returns:
            Union[ChatBedrock, BedrockLLM]: Configured (and cached) AWS Bedrock model.

        Raises:
            ValueError: If no instance is cached yet and model_type is neither
                'chat' nor 'completion'.
            Exception: Any error raised while constructing the model is logged
                and re-raised unchanged.

        Examples:
            >>> config = AWSBedrockConfig(  # doctest: +SKIP
            ...     model_id="anthropic.claude-v2",
            ...     region_name="us-east-1"
            ... )
            >>> provider = AWSBedrockProvider(config)  # doctest: +SKIP
            >>> # llm = provider.get_llm()  # Returns ChatBedrock instance
        """
        if self._llm is not None:
            return self._llm

        try:
            # Only forward the credential fields that are actually set.
            credentials_kwargs = {}
            if self.config.aws_access_key_id:
                credentials_kwargs["aws_access_key_id"] = self.config.aws_access_key_id
            if self.config.aws_secret_access_key:
                credentials_kwargs["aws_secret_access_key"] = self.config.aws_secret_access_key
            if self.config.aws_session_token:
                credentials_kwargs["aws_session_token"] = self.config.aws_session_token

            if model_type == "chat":
                llm_class = ChatBedrock
            elif model_type == "completion":
                llm_class = BedrockLLM
            else:
                # Previously this fell through, logged success and returned None.
                raise ValueError(f"Unsupported model_type '{model_type}'. Expected 'chat' or 'completion'.")

            # Both model kinds take the same constructor arguments.
            self._llm = llm_class(
                model_id=self.config.model_id,
                region_name=self.config.region_name,
                model_kwargs={
                    "temperature": self.config.temperature,
                    "max_tokens": self.config.max_tokens,
                },
                **credentials_kwargs,
            )
            logger.info("AWS Bedrock LLM instance created successfully")
        except Exception as e:
            logger.error(f"Failed to create AWS Bedrock LLM: {e}")
            raise

        return self._llm

    def get_model_name(self) -> str:
        """
        Get the AWS Bedrock model ID.

        Returns:
            str: The model ID configured for this provider.

        Examples:
            >>> config = AWSBedrockConfig(  # doctest: +SKIP
            ...     model_id="anthropic.claude-v2",
            ...     region_name="us-east-1"
            ... )
            >>> provider = AWSBedrockProvider(config)  # doctest: +SKIP
            >>> provider.get_model_name()  # doctest: +SKIP
            'anthropic.claude-v2'
        """
        return self.config.model_id
class WatsonxProvider:
    """
    IBM watsonx.ai provider implementation.

    Manages connection and interaction with IBM watsonx.ai services.

    Attributes:
        config: IBM watsonx.ai configuration object.
        llm: Cached model instance; None until get_llm() first succeeds.

    Examples:
        >>> config = WatsonxConfig(  # doctest: +SKIP
        ...     api_key="key",
        ...     url="https://us-south.ml.cloud.ibm.com",
        ...     project_id="project-id",
        ...     model_id="ibm/granite-13b-chat-v2"
        ... )
        >>> provider = WatsonxProvider(config)  # doctest: +SKIP
        >>> provider.get_model_name()  # doctest: +SKIP
        'ibm/granite-13b-chat-v2'

    Note:
        Requires langchain-ibm package to be installed.
    """

    def __init__(self, config: WatsonxConfig):
        """
        Initialize IBM watsonx.ai provider.

        Args:
            config: IBM watsonx.ai configuration with credentials and settings.

        Raises:
            ImportError: If langchain-ibm is not installed.

        Examples:
            >>> config = WatsonxConfig(  # doctest: +SKIP
            ...     api_key="key",
            ...     url="https://us-south.ml.cloud.ibm.com",
            ...     project_id="project-id",
            ...     model_id="ibm/granite-13b-chat-v2"
            ... )
            >>> provider = WatsonxProvider(config)  # doctest: +SKIP
        """
        # Fail at construction time when the optional dependency is missing.
        if not _WATSONX_AVAILABLE:
            raise ImportError("IBM watsonx.ai provider requires langchain-ibm package. Install it with: pip install langchain-ibm")
        self.config = config
        # Filled in lazily by get_llm().
        self.llm = None
        logger.info(f"Initializing IBM watsonx.ai provider with model {config.model_id}")

    def get_llm(self, model_type: str = "chat") -> Union[WatsonxLLM, ChatWatsonx]:
        """
        Get IBM watsonx.ai LLM instance with lazy initialization.

        Builds the watsonx model on the first successful call and caches it.
        Later calls return the cached instance without looking at
        ``model_type`` again, so the first call decides which kind is cached.

        Args:
            model_type: LLM inference model type: 'chat' builds a ChatWatsonx,
                'completion' builds a WatsonxLLM.

        Returns:
            Union[WatsonxLLM, ChatWatsonx]: Configured (and cached) IBM watsonx.ai model.

        Raises:
            ValueError: If no instance is cached yet and model_type is neither
                'chat' nor 'completion'.
            Exception: Any error raised while constructing the model is logged
                and re-raised unchanged.

        Examples:
            >>> config = WatsonxConfig(  # doctest: +SKIP
            ...     api_key="key",
            ...     url="https://us-south.ml.cloud.ibm.com",
            ...     project_id="project-id",
            ...     model_id="ibm/granite-13b-chat-v2"
            ... )
            >>> provider = WatsonxProvider(config)  # doctest: +SKIP
            >>> #llm = provider.get_llm()  # Returns WatsonxLLM instance
        """
        if self.llm is not None:
            return self.llm

        try:
            # Generation parameters; top_k / top_p are added only when configured.
            params = {
                "decoding_method": self.config.decoding_method,
                "temperature": self.config.temperature,
                "max_new_tokens": self.config.max_new_tokens,
                "min_new_tokens": self.config.min_new_tokens,
            }

            if self.config.top_k is not None:
                params["top_k"] = self.config.top_k
            if self.config.top_p is not None:
                params["top_p"] = self.config.top_p

            if model_type == "completion":
                llm_class = WatsonxLLM
            elif model_type == "chat":
                llm_class = ChatWatsonx
            else:
                # Previously this fell through, logged success and returned None.
                raise ValueError(f"Unsupported model_type '{model_type}'. Expected 'chat' or 'completion'.")

            # Both model kinds take the same constructor arguments.
            self.llm = llm_class(
                apikey=self.config.api_key,
                url=self.config.url,
                project_id=self.config.project_id,
                model_id=self.config.model_id,
                params=params,
            )
            logger.info("IBM watsonx.ai LLM instance created successfully")
        except Exception as e:
            logger.error(f"Failed to create IBM watsonx.ai LLM: {e}")
            raise
        return self.llm

    def get_model_name(self) -> str:
        """
        Get the IBM watsonx.ai model ID.

        Returns:
            str: The model ID configured for this provider.

        Examples:
            >>> config = WatsonxConfig(  # doctest: +SKIP
            ...     api_key="key",
            ...     url="https://us-south.ml.cloud.ibm.com",
            ...     project_id="project-id",
            ...     model_id="ibm/granite-13b-chat-v2"
            ... )
            >>> provider = WatsonxProvider(config)  # doctest: +SKIP
            >>> provider.get_model_name()  # doctest: +SKIP
            'ibm/granite-13b-chat-v2'
        """
        return self.config.model_id
class GatewayProvider:
    """
    Gateway provider implementation for using models configured in LLM Settings.

    Routes LLM requests through the gateway's configured providers, allowing
    users to use models set up via the Admin UI's LLM Settings without needing
    to configure credentials in environment variables or API requests.

    Attributes:
        config: Gateway configuration with model ID.
        llm: Lazily initialized LLM instance.

    Examples:
        >>> config = GatewayConfig(model="gpt-4o")  # doctest: +SKIP
        >>> provider = GatewayProvider(config)  # doctest: +SKIP
        >>> provider.get_model_name()  # doctest: +SKIP
        'gpt-4o'

    Note:
        Requires models to be configured via Admin UI -> Settings -> LLM Settings.
    """

    def __init__(self, config: GatewayConfig):
        """
        Initialize Gateway provider.

        Args:
            config: Gateway configuration with model ID and optional settings.

        Examples:
            >>> config = GatewayConfig(model="gpt-4o")  # doctest: +SKIP
            >>> provider = GatewayProvider(config)  # doctest: +SKIP
        """
        self.config = config
        # Filled in lazily by get_llm().
        self.llm = None
        # Resolved model_id from the database; set once get_llm() has found the model.
        self._model_name: Optional[str] = None
        self._underlying_provider = None
        logger.info(f"Initializing Gateway provider with model: {config.model}")

    def get_llm(self, model_type: str = "chat") -> Union[BaseChatModel, Any]:
        """
        Get LLM instance by looking up model from gateway's LLM Settings.

        Fetches the model and its provider from the database, decodes the
        provider API key, and creates the LangChain LLM instance that matches
        the provider type.

        Args:
            model_type: Type of model to return. 'chat' builds a chat model;
                any other value builds the provider's completion model.
                Defaults to 'chat'.

        Returns:
            Union[BaseChatModel, Any]: Configured LangChain chat or completion model instance.

        Raises:
            ValueError: If the model or provider is missing or disabled, the
                provider type is unsupported, or required provider settings
                (Azure/OpenAI-compatible base URL, watsonx project_id) are absent.
            ImportError: If required provider package not installed.

        Examples:
            >>> config = GatewayConfig(model="gpt-4o")  # doctest: +SKIP
            >>> provider = GatewayProvider(config)  # doctest: +SKIP
            >>> llm = provider.get_llm()  # doctest: +SKIP

        Note:
            A single LLM instance is lazily initialized and cached. Once it
            exists it is returned for every call regardless of model_type, so
            the first successful call decides which kind of model is cached.
        """
        if self.llm is not None:
            return self.llm

        # Import here to avoid circular imports
        # First-Party
        from mcpgateway.db import LLMModel, LLMProvider, SessionLocal  # pylint: disable=import-outside-toplevel
        from mcpgateway.services.llm_provider_service import decrypt_provider_config_for_runtime  # pylint: disable=import-outside-toplevel
        from mcpgateway.utils.services_auth import decode_auth  # pylint: disable=import-outside-toplevel

        model_id = self.config.model

        # NOTE(review): the whole build is kept inside the session scope so the
        # ORM rows stay attached while their attributes are read - confirm this
        # matches the intended session lifetime.
        with SessionLocal() as db:
            # Try to find model by UUID first, then by model_id
            model = db.query(LLMModel).filter(LLMModel.id == model_id).first()
            if not model:
                model = db.query(LLMModel).filter(LLMModel.model_id == model_id).first()

            if not model:
                raise ValueError(f"Model '{model_id}' not found in LLM Settings. Configure it via Admin UI -> Settings -> LLM Settings.")

            if not model.enabled:
                raise ValueError(f"Model '{model.model_id}' is disabled. Enable it in LLM Settings.")

            # Get the provider
            provider = db.query(LLMProvider).filter(LLMProvider.id == model.provider_id).first()
            if not provider:
                raise ValueError(f"Provider not found for model '{model.model_id}'")

            if not provider.enabled:
                raise ValueError(f"Provider '{provider.name}' is disabled. Enable it in LLM Settings.")

            # Get decrypted API key; decode_auth may yield a dict or the bare key
            api_key = None
            if provider.api_key:
                auth_data = decode_auth(provider.api_key)
                if isinstance(auth_data, dict):
                    api_key = auth_data.get("api_key")
                else:
                    api_key = auth_data

            # Store model name for get_model_name()
            self._model_name = model.model_id

            # Temperature precedence: request config, then provider default, then 0.7.
            # Compared against None so an explicit 0.0 default is honoured
            # (a plain `or` would silently replace it with 0.7).
            if self.config.temperature is not None:
                temperature = self.config.temperature
            elif provider.default_temperature is not None:
                temperature = provider.default_temperature
            else:
                temperature = 0.7
            max_tokens = self.config.max_tokens or model.max_output_tokens

            # Create appropriate LLM based on provider type
            provider_type = provider.provider_type.lower()
            config = decrypt_provider_config_for_runtime(provider.config)

            # Common kwargs
            kwargs: Dict[str, Any] = {
                "temperature": temperature,
                "timeout": self.config.timeout,
            }

            if provider_type == "openai":
                kwargs.update(
                    {
                        "api_key": api_key,
                        "model": model.model_id,
                        "max_tokens": max_tokens,
                    }
                )
                if provider.api_base:
                    kwargs["base_url"] = provider.api_base

                # Handle default headers: provider config wins over the request config
                if config.get("default_headers"):
                    kwargs["default_headers"] = config["default_headers"]
                elif hasattr(self.config, "default_headers") and self.config.default_headers:  # type: ignore
                    kwargs["default_headers"] = self.config.default_headers

                if model_type == "chat":
                    self.llm = ChatOpenAI(**kwargs)
                else:
                    self.llm = OpenAI(**kwargs)

            elif provider_type == "azure_openai":
                if not provider.api_base:
                    raise ValueError("Azure OpenAI requires base_url (azure_endpoint) to be configured")

                azure_deployment = config.get("azure_deployment", model.model_id)
                api_version = config.get("api_version", "2024-05-01-preview")
                max_retries = config.get("max_retries", 2)

                kwargs.update(
                    {
                        "api_key": api_key,
                        "azure_endpoint": provider.api_base,
                        "azure_deployment": azure_deployment,
                        "api_version": api_version,
                        "model": model.model_id,
                        "max_tokens": int(max_tokens) if max_tokens is not None else None,
                        "max_retries": max_retries,
                    }
                )

                if model_type == "chat":
                    self.llm = AzureChatOpenAI(**kwargs)
                else:
                    self.llm = AzureOpenAI(**kwargs)

            elif provider_type == "anthropic":
                if not _ANTHROPIC_AVAILABLE:
                    raise ImportError("Anthropic provider requires langchain-anthropic. Install with: pip install langchain-anthropic")

                # Anthropic uses 'model_name' instead of 'model'
                anthropic_kwargs = {
                    "api_key": api_key,
                    "model_name": model.model_id,
                    "max_tokens": max_tokens or 4096,
                    "temperature": temperature,
                    "timeout": self.config.timeout,
                    "default_request_timeout": self.config.timeout,
                }

                if model_type == "chat":
                    self.llm = ChatAnthropic(**anthropic_kwargs)
                else:
                    # Generic Anthropic completion model if needed, though mostly chat used now
                    if AnthropicLLM:
                        llm_kwargs = anthropic_kwargs.copy()
                        llm_kwargs["model"] = llm_kwargs.pop("model_name")
                        self.llm = AnthropicLLM(**llm_kwargs)
                    else:
                        raise ImportError("Anthropic completion model (AnthropicLLM) not available")

            elif provider_type == "bedrock":
                if not _BEDROCK_AVAILABLE:
                    raise ImportError("AWS Bedrock provider requires langchain-aws. Install with: pip install langchain-aws boto3")

                # Map DB schema keys to boto3 kwargs.
                # DB stores: region, access_key_id, secret_access_key, session_token, profile_name
                # (see llm_provider_configs.AWSBedrockConfig)
                region_name = config.get("region", "us-east-1")
                credentials_kwargs = {}
                if config.get("access_key_id"):
                    credentials_kwargs["aws_access_key_id"] = config["access_key_id"]
                if config.get("secret_access_key"):
                    credentials_kwargs["aws_secret_access_key"] = config["secret_access_key"]
                if config.get("session_token"):
                    credentials_kwargs["aws_session_token"] = config["session_token"]
                if config.get("profile_name"):
                    credentials_kwargs["credentials_profile_name"] = config["profile_name"]

                model_kwargs = {
                    "temperature": temperature,
                    "max_tokens": max_tokens or 4096,
                }

                if model_type == "chat":
                    self.llm = ChatBedrock(
                        model_id=model.model_id,
                        region_name=region_name,
                        model_kwargs=model_kwargs,
                        **credentials_kwargs,
                    )
                else:
                    self.llm = BedrockLLM(
                        model_id=model.model_id,
                        region_name=region_name,
                        model_kwargs=model_kwargs,
                        **credentials_kwargs,
                    )

            elif provider_type == "ollama":
                base_url = provider.api_base or "http://localhost:11434"
                num_ctx = config.get("num_ctx")

                # Explicitly construct kwargs to avoid generic unpacking issues with Pydantic models
                ollama_kwargs = {
                    "base_url": base_url,
                    "model": model.model_id,
                    "temperature": temperature,
                    "timeout": self.config.timeout,
                }
                if num_ctx:
                    ollama_kwargs["num_ctx"] = num_ctx

                if model_type == "chat":
                    self.llm = ChatOllama(**ollama_kwargs)
                else:
                    self.llm = OllamaLLM(**ollama_kwargs)

            elif provider_type == "watsonx":
                if not _WATSONX_AVAILABLE:
                    raise ImportError("IBM watsonx.ai provider requires langchain-ibm. Install with: pip install langchain-ibm")

                project_id = config.get("project_id")
                if not project_id:
                    raise ValueError("IBM watsonx.ai requires project_id in config")

                url = provider.api_base or "https://us-south.ml.cloud.ibm.com"

                params = {
                    "temperature": temperature,
                    "max_new_tokens": max_tokens or 1024,
                    "min_new_tokens": config.get("min_new_tokens", 1),
                    "decoding_method": config.get("decoding_method", "sample"),
                    "top_k": config.get("top_k", 50),
                    "top_p": config.get("top_p", 1.0),
                }

                if model_type == "chat":
                    self.llm = ChatWatsonx(
                        apikey=api_key,
                        url=url,
                        project_id=project_id,
                        model_id=model.model_id,
                        params=params,
                    )
                else:
                    self.llm = WatsonxLLM(
                        apikey=api_key,
                        url=url,
                        project_id=project_id,
                        model_id=model.model_id,
                        params=params,
                    )

            elif provider_type == "openai_compatible":
                if not provider.api_base:
                    raise ValueError("OpenAI-compatible provider requires base_url to be configured")

                kwargs.update(
                    {
                        # Placeholder key is sent when the provider has no API key stored
                        "api_key": api_key or "no-key-required",
                        "model": model.model_id,
                        "base_url": provider.api_base,
                        "max_tokens": max_tokens,
                    }
                )

                if model_type == "chat":
                    self.llm = ChatOpenAI(**kwargs)
                else:
                    self.llm = OpenAI(**kwargs)

            else:
                raise ValueError(f"Unsupported LLM provider: {provider_type}")

            logger.info(f"Gateway provider created LLM instance for model: {model.model_id} via {provider_type}")
            return self.llm

    def get_model_name(self) -> str:
        """
        Get the model name.

        Returns the model_id resolved from the database once get_llm() has
        located the model; before that, falls back to the configured model value.

        Returns:
            str: The model name/ID.

        Examples:
            >>> config = GatewayConfig(model="gpt-4o")  # doctest: +SKIP
            >>> provider = GatewayProvider(config)  # doctest: +SKIP
            >>> provider.get_model_name()  # doctest: +SKIP
            'gpt-4o'
        """
        return self._model_name or self.config.model
class LLMProviderFactory:
    """
    Factory that turns an LLMConfig into the matching provider object.

    Callers hand over a configuration and receive a provider instance; the
    mapping from provider key to provider class lives in one place here.

    Examples:
        >>> config = LLMConfig(
        ...     provider="ollama",
        ...     config=OllamaConfig(model="llama2")
        ... )
        >>> provider = LLMProviderFactory.create(config)
        >>> provider.get_model_name()
        'llama2'

    Note:
        The set of supported providers is the fixed table inside create().
    """

    @staticmethod
    def create(llm_config: LLMConfig) -> Union[AzureOpenAIProvider, OpenAIProvider, AnthropicProvider, AWSBedrockProvider, OllamaProvider, WatsonxProvider, GatewayProvider]:
        """
        Instantiate the provider selected by ``llm_config.provider``.

        The provider class is called with ``llm_config.config`` as its only
        argument.

        Args:
            llm_config: LLM configuration specifying provider type and settings.

        Returns:
            Union[AzureOpenAIProvider, OpenAIProvider, AnthropicProvider, AWSBedrockProvider, OllamaProvider, WatsonxProvider, GatewayProvider]: Instantiated provider.

        Raises:
            ValueError: If provider type is not supported.
            ImportError: If required provider package is not installed.

        Examples:
            >>> # Create Azure OpenAI provider
            >>> config = LLMConfig(
            ...     provider="azure_openai",
            ...     config=AzureOpenAIConfig(
            ...         api_key="key",
            ...         azure_endpoint="https://example.com/",
            ...         azure_deployment="gpt-4"
            ...     )
            ... )
            >>> provider = LLMProviderFactory.create(config)
            >>> isinstance(provider, AzureOpenAIProvider)
            True

            >>> # Create OpenAI provider
            >>> config = LLMConfig(
            ...     provider="openai",
            ...     config=OpenAIConfig(
            ...         api_key="sk-...",
            ...         model="gpt-4"
            ...     )
            ... )
            >>> provider = LLMProviderFactory.create(config)
            >>> isinstance(provider, OpenAIProvider)
            True

            >>> # Create Ollama provider
            >>> config = LLMConfig(
            ...     provider="ollama",
            ...     config=OllamaConfig(model="llama2")
            ... )
            >>> provider = LLMProviderFactory.create(config)
            >>> isinstance(provider, OllamaProvider)
            True
        """
        # Provider key -> provider class.
        provider_map = {
            "azure_openai": AzureOpenAIProvider,
            "openai": OpenAIProvider,
            "anthropic": AnthropicProvider,
            "aws_bedrock": AWSBedrockProvider,
            "ollama": OllamaProvider,
            "watsonx": WatsonxProvider,
            "gateway": GatewayProvider,
        }

        selected_cls = provider_map.get(llm_config.provider)

        if selected_cls:
            logger.info(f"Creating LLM provider: {llm_config.provider}")
            return selected_cls(llm_config.config)

        # Unknown key: report it together with the accepted ones.
        raise ValueError(f"Unsupported LLM provider: {llm_config.provider}. Supported providers: {list(provider_map.keys())}")
1845# ==================== CHAT HISTORY MANAGER ====================
1848class ChatHistoryManager:
1849 """
1850 Centralized chat history management with Redis and in-memory fallback.
1852 Provides a unified interface for storing and retrieving chat histories across
1853 multiple workers using Redis, with automatic fallback to in-memory storage
1854 when Redis is not available.
1856 This class eliminates duplication between router and service layers by
1857 providing a single source of truth for all chat history operations.
1859 Attributes:
1860 redis_client: Optional Redis async client for distributed storage.
1861 max_messages: Maximum number of messages to retain per user.
1862 ttl: Time-to-live for Redis entries in seconds.
1863 _memory_store: In-memory dict fallback when Redis unavailable.
1865 Examples:
1866 >>> import asyncio
1867 >>> # Create manager without Redis (in-memory mode)
1868 >>> manager = ChatHistoryManager(redis_client=None, max_messages=50)
1869 >>> # asyncio.run(manager.save_history("user123", [{"role": "user", "content": "Hello"}]))
1870 >>> # history = asyncio.run(manager.get_history("user123"))
1871 >>> # len(history) >= 0
1872 True
1874 Note:
1875 Thread-safe for Redis operations. In-memory mode suitable for
1876 single-worker deployments only.
1877 """
def __init__(self, redis_client: Optional[Any] = None, max_messages: int = 50, ttl: int = 3600):
    """
    Create a chat history manager.

    When a truthy ``redis_client`` is supplied it is used as the storage
    backend; otherwise histories are kept in a dict on this instance.

    Args:
        redis_client: Optional Redis async client. If None, uses in-memory storage.
        max_messages: Maximum messages to retain per user (default: 50).
        ttl: Time-to-live for Redis entries in seconds (default: 3600).

    Examples:
        >>> manager = ChatHistoryManager(redis_client=None, max_messages=100)
        >>> manager.max_messages
        100
        >>> manager.ttl
        3600
    """
    self.redis_client = redis_client
    self.max_messages = max_messages
    self.ttl = ttl
    # In-memory fallback store, keyed by user id.
    self._memory_store: Dict[str, List[Dict[str, str]]] = {}

    startup_message = "ChatHistoryManager initialized with Redis backend" if redis_client else "ChatHistoryManager initialized with in-memory backend"
    logger.info(startup_message)
def _history_key(self, user_id: str) -> str:
    """
    Build the storage key under which a user's chat history is kept.

    Args:
        user_id: User identifier.

    Returns:
        str: The user id prefixed with ``chat_history:``.

    Examples:
        >>> manager = ChatHistoryManager()
        >>> manager._history_key("user123")
        'chat_history:user123'
    """
    return "chat_history:{}".format(user_id)
async def get_history(self, user_id: str) -> List[Dict[str, str]]:
    """
    Retrieve chat history for a user.

    Fetches history from Redis if a client is configured, otherwise from
    the in-memory store.

    Args:
        user_id: User identifier.

    Returns:
        List[Dict[str, str]]: List of message dictionaries with 'role' and 'content' keys.
            Returns empty list if no history exists. In in-memory mode the
            list is a shallow copy, so appending to or removing from it does
            not change the stored history until save_history() is called.

    Examples:
        >>> import asyncio
        >>> manager = ChatHistoryManager()
        >>> # history = asyncio.run(manager.get_history("user123"))
        >>> # isinstance(history, list)
        True

    Note:
        Redis errors, including JSON decoding errors, are logged and
        reported as an empty list; they are not raised to the caller.
    """
    if not self.redis_client:
        # Copy so callers cannot mutate the stored list behind the manager's
        # back; this matches Redis mode, which returns a fresh list each time.
        return list(self._memory_store.get(user_id, []))

    try:
        data = await self.redis_client.get(self._history_key(user_id))
        if not data:
            return []
        return orjson.loads(data)
    except orjson.JSONDecodeError:
        logger.warning(f"Failed to decode chat history for user {SecurityValidator.sanitize_log_message(user_id)}")
        return []
    except Exception as e:
        logger.error(f"Error retrieving chat history from Redis for user {SecurityValidator.sanitize_log_message(user_id)}: {e}")
        return []
async def save_history(self, user_id: str, history: List[Dict[str, str]]) -> None:
    """
    Persist a user's chat history.

    The history is first passed through ``_trim_messages``. The result is
    written to Redis with the configured TTL when a client is present, and
    to the in-memory store otherwise.

    Args:
        user_id: User identifier.
        history: List of message dictionaries to save.

    Examples:
        >>> import asyncio
        >>> manager = ChatHistoryManager(max_messages=50)
        >>> messages = [{"role": "user", "content": "Hello"}]
        >>> # asyncio.run(manager.save_history("user123", messages))

    Note:
        A failed Redis write is logged and not raised to the caller.
    """
    bounded = self._trim_messages(history)

    # In-memory mode: store the trimmed list and finish.
    if not self.redis_client:
        self._memory_store[user_id] = bounded
        return

    try:
        redis_key = self._history_key(user_id)
        payload = orjson.dumps(bounded)
        await self.redis_client.set(redis_key, payload, ex=self.ttl)
    except Exception as e:
        logger.error(f"Error saving chat history to Redis for user {SecurityValidator.sanitize_log_message(user_id)}: {e}")
async def append_message(self, user_id: str, role: str, content: str) -> None:
    """
    Add one message to the end of a user's chat history.

    Loads the current history via ``get_history``, appends a
    ``{"role": ..., "content": ...}`` entry and writes the result back
    through ``save_history`` (which applies trimming).

    Args:
        user_id: User identifier.
        role: Message role ('user' or 'assistant').
        content: Message content text.

    Examples:
        >>> import asyncio
        >>> manager = ChatHistoryManager()
        >>> # asyncio.run(manager.append_message("user123", "user", "Hello!"))

    Note:
        This is a read-modify-write sequence and may not be atomic in
        distributed environments.
    """
    entry = dict(role=role, content=content)
    messages = await self.get_history(user_id)
    messages.append(entry)
    await self.save_history(user_id, messages)
async def clear_history(self, user_id: str) -> None:
    """
    Remove every stored message for a user.

    Deletes the user's history key from Redis when a Redis client is
    configured, otherwise drops the user's entry from the in-memory store.

    Args:
        user_id: User identifier.

    Examples:
        >>> import asyncio
        >>> manager = ChatHistoryManager()
        >>> # asyncio.run(manager.clear_history("user123"))

    Note:
        This operation cannot be undone. Redis failures are logged and not
        re-raised.
    """
    if not self.redis_client:
        # A missing entry is fine: pop with a default never raises.
        self._memory_store.pop(user_id, None)
        return

    try:
        await self.redis_client.delete(self._history_key(user_id))
    except Exception as e:
        logger.error(f"Error clearing chat history from Redis for user {SecurityValidator.sanitize_log_message(user_id)}: {e}")
def _trim_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """
    Trim message list to max_messages limit.

    Keeps the most recent messages up to max_messages count. A limit of
    zero or less keeps nothing.

    Args:
        messages: List of message dictionaries.

    Returns:
        List[Dict[str, str]]: Trimmed message list. The input list itself is
        returned when it is already within the limit.

    Examples:
        >>> manager = ChatHistoryManager(max_messages=2)
        >>> messages = [
        ...     {"role": "user", "content": "1"},
        ...     {"role": "assistant", "content": "2"},
        ...     {"role": "user", "content": "3"}
        ... ]
        >>> trimmed = manager._trim_messages(messages)
        >>> len(trimmed)
        2
        >>> trimmed[0]["content"]
        '2'
        >>> ChatHistoryManager(max_messages=0)._trim_messages(messages)
        []
    """
    limit = self.max_messages

    # ``messages[-0:]`` is the whole list and a negative limit would slice
    # from the front, so non-positive limits must be handled explicitly.
    if limit <= 0:
        return []

    if len(messages) > limit:
        return messages[-limit:]
    return messages
async def get_langchain_messages(self, user_id: str) -> List[BaseMessage]:
    """
    Return a user's chat history as LangChain message objects.

    Each stored entry with role ``"user"`` becomes a ``HumanMessage`` and
    each entry with role ``"assistant"`` becomes an ``AIMessage``; entries
    with any other role are left out. A missing ``content`` is treated as
    an empty string.

    Args:
        user_id: User identifier.

    Returns:
        List[BaseMessage]: List of LangChain message objects.

    Examples:
        >>> import asyncio
        >>> manager = ChatHistoryManager()
        >>> # messages = asyncio.run(manager.get_langchain_messages("user123"))
        >>> # isinstance(messages, list)
        True

    Note:
        Returns empty list if LangChain is not available or history is empty.
    """
    # Without the optional LLM chat dependencies the message classes are None.
    if not _LLMCHAT_AVAILABLE:
        return []

    converted = []
    for entry in await self.get_history(user_id):
        speaker = entry.get("role")
        text = entry.get("content", "")

        if speaker == "user":
            message_cls = HumanMessage
        elif speaker == "assistant":
            message_cls = AIMessage
        else:
            continue
        converted.append(message_cls(content=text))

    return converted
2111# ==================== MCP CLIENT ====================
class MCPClient:
    """
    Manages MCP server connections and tool loading.

    Provides a high-level interface for connecting to MCP servers, retrieving
    available tools, and managing connection health. Supports multiple transport
    protocols including HTTP, SSE, and stdio.

    Attributes:
        config: MCP server configuration.

    Examples:
        >>> import asyncio
        >>> config = MCPServerConfig(
        ...     url="https://mcp-server.example.com/mcp",
        ...     transport="streamable_http"
        ... )
        >>> client = MCPClient(config)
        >>> client.is_connected
        False
        >>> # asyncio.run(client.connect())
        >>> # tools = asyncio.run(client.get_tools())

    Note:
        All methods are async and should be called using asyncio or within
        an async context.
    """

    def __init__(self, config: MCPServerConfig):
        """
        Initialize MCP client.

        Args:
            config: MCP server configuration with connection parameters.

        Examples:
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> client.config.transport
            'streamable_http'
        """
        self.config = config
        self._client: Optional[MultiServerMCPClient] = None
        # None means "never loaded"; an empty list is a valid cached result.
        self._tools: Optional[List[BaseTool]] = None
        self._connected = False
        logger.info(f"MCP client initialized with transport: {config.transport}")

    async def connect(self) -> None:
        """
        Connect to the MCP server.

        Builds the per-transport server configuration and creates the
        underlying MultiServerMCPClient. Subsequent calls are no-ops if
        already connected.

        Raises:
            ConnectionError: If connection to MCP server fails, including when
                the optional LLM chat dependencies are not installed.

        Examples:
            >>> import asyncio
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> # asyncio.run(client.connect())
            >>> # client.is_connected -> True

        Note:
            Connection is idempotent - calling multiple times is safe.
        """
        if self._connected:
            logger.warning("MCP client already connected")
            return

        try:
            logger.info(f"Connecting to MCP server via {self.config.transport}...")

            # Build server configuration for MultiServerMCPClient
            server_config: Dict[str, Any] = {
                "transport": self.config.transport,
            }

            if self.config.transport in ["streamable_http", "sse"]:
                server_config["url"] = self.config.url
                if self.config.headers:
                    server_config["headers"] = self.config.headers
            elif self.config.transport == "stdio":
                server_config["command"] = self.config.command
                if self.config.args:
                    server_config["args"] = self.config.args

            if not MultiServerMCPClient:
                # Fail with the actionable message instead of falling through
                # to "'NoneType' object is not callable"; the handler below
                # still converts this into ConnectionError for callers.
                missing = "Some dependencies are missing. Install those with: pip install '.[llmchat]'"
                logger.error(missing)
                raise ImportError(missing)

            # Create MultiServerMCPClient with single server
            self._client = MultiServerMCPClient({"default": server_config})
            self._connected = True
            logger.info("Successfully connected to MCP server")

        except Exception as e:
            logger.error(f"Failed to connect to MCP server: {e}")
            self._connected = False
            raise ConnectionError(f"Failed to connect to MCP server: {e}") from e

    async def disconnect(self) -> None:
        """
        Disconnect from the MCP server.

        Drops the underlying client reference and resets connection state.
        Safe to call even if not connected.

        Raises:
            Exception: If cleanup operations fail.

        Examples:
            >>> import asyncio
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> # asyncio.run(client.connect())
            >>> # asyncio.run(client.disconnect())
            >>> # client.is_connected -> False

        Note:
            Clears cached tools upon disconnection.
        """
        if not self._connected:
            logger.warning("MCP client not connected")
            return

        try:
            if self._client:
                # MultiServerMCPClient manages connections internally
                self._client = None

            self._connected = False
            self._tools = None
            logger.info("Disconnected from MCP server")

        except Exception as e:
            logger.error(f"Error during disconnect: {e}")
            raise

    async def get_tools(self, force_reload: bool = False) -> List[BaseTool]:
        """
        Get tools from the MCP server.

        Retrieves available tools from the connected MCP server. Results are
        cached unless force_reload is True; an empty tool list is cached too.

        Args:
            force_reload: Force reload tools even if cached (default: False).

        Returns:
            List[BaseTool]: List of available tools from the server.

        Raises:
            ConnectionError: If not connected to MCP server.
            Exception: If tool loading fails.

        Examples:
            >>> import asyncio
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> # asyncio.run(client.connect())
            >>> # tools = asyncio.run(client.get_tools())
            >>> # len(tools) >= 0 -> True

        Note:
            Tools are cached after first successful load for performance.
        """
        if not self._connected or not self._client:
            raise ConnectionError("Not connected to MCP server. Call connect() first.")

        # Compare against None rather than truthiness: a server that exposes
        # no tools yields [], which must not trigger a reload on every call.
        if self._tools is not None and not force_reload:
            logger.debug(f"Returning {len(self._tools)} cached tools")
            return self._tools

        try:
            logger.info("Loading tools from MCP server...")
            self._tools = await self._client.get_tools()
            logger.info(f"Successfully loaded {len(self._tools)} tools")
            return self._tools

        except Exception as e:
            logger.error(f"Failed to load tools: {e}")
            raise

    @property
    def is_connected(self) -> bool:
        """
        Check if client is connected.

        Returns:
            bool: True if connected to MCP server, False otherwise.

        Examples:
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> client.is_connected
            False
        """
        return self._connected
2330# ==================== MCP CHAT SERVICE ====================
2333class MCPChatService:
2334 """
2335 Main chat service for MCP client backend.
2336 Orchestrates chat sessions with LLM and MCP server integration.
2338 Provides a high-level interface for managing conversational AI sessions
2339 that combine LLM capabilities with MCP server tools. Handles conversation
2340 history management, tool execution, and streaming responses.
2342 This service integrates:
2343 - LLM providers (Azure OpenAI, OpenAI, Anthropic, AWS Bedrock, Ollama)
2344 - MCP server tools
2345 - Centralized chat history management (Redis or in-memory)
2346 - Streaming and non-streaming response modes
2348 Attributes:
2349 config: Complete MCP client configuration.
2350 user_id: Optional user identifier for history management.
2352 Examples:
2353 >>> import asyncio
2354 >>> config = MCPClientConfig(
2355 ... mcp_server=MCPServerConfig(
2356 ... url="https://example.com/mcp",
2357 ... transport="streamable_http"
2358 ... ),
2359 ... llm=LLMConfig(
2360 ... provider="ollama",
2361 ... config=OllamaConfig(model="llama2")
2362 ... )
2363 ... )
2364 >>> service = MCPChatService(config)
2365 >>> service.is_initialized
2366 False
2367 >>> # asyncio.run(service.initialize())
2369 Note:
2370 Must call initialize() before using chat methods.
2371 """
def __init__(self, config: MCPClientConfig, user_id: Optional[str] = None, redis_client: Optional[Any] = None):
    """
    Set up a chat service instance without connecting to anything yet.

    Creates the MCP client from ``config.mcp_server``, the LLM provider from
    ``config.llm`` and a ChatHistoryManager bounded by
    ``config.chat_history_max_messages`` with the TTL taken from
    ``settings.llmchat_chat_history_ttl``. The agent is created later by
    ``initialize()``.

    Args:
        config: Complete MCP client configuration.
        user_id: Optional user identifier for chat history management.
        redis_client: Optional Redis client for distributed history storage.

    Examples:
        >>> config = MCPClientConfig(
        ...     mcp_server=MCPServerConfig(
        ...         url="https://example.com/mcp",
        ...         transport="streamable_http"
        ...     ),
        ...     llm=LLMConfig(
        ...         provider="ollama",
        ...         config=OllamaConfig(model="llama2")
        ...     )
        ... )
        >>> service = MCPChatService(config, user_id="user123")
        >>> service.user_id
        'user123'
    """
    # Plain state: nothing is connected or loaded at construction time.
    self.config = config
    self.user_id = user_id
    self._agent = None
    self._initialized = False
    self._tools: List[BaseTool] = []

    # Collaborators used by initialize() and the chat methods.
    self.mcp_client = MCPClient(config.mcp_server)
    self.llm_provider = LLMProviderFactory.create(config.llm)
    self.history_manager = ChatHistoryManager(
        redis_client=redis_client,
        max_messages=config.chat_history_max_messages,
        ttl=settings.llmchat_chat_history_ttl,
    )

    display_user = user_id or "anonymous"
    logger.info(f"MCPChatService initialized for user: {display_user}")
async def initialize(self) -> None:
    """
    Prepare the service for chatting.

    Connects the MCP client, loads its tools, obtains the LLM from the
    provider and builds a ReAct agent over that LLM and those tools. Calling
    it again once initialized only logs a warning.

    Raises:
        ImportError: If LLM chat dependencies are missing.
        ConnectionError: If MCP server connection fails.
        Exception: If initialization fails.

    Examples:
        >>> import asyncio
        >>> config = MCPClientConfig(
        ...     mcp_server=MCPServerConfig(
        ...         url="https://example.com/mcp",
        ...         transport="streamable_http"
        ...     ),
        ...     llm=LLMConfig(
        ...         provider="ollama",
        ...         config=OllamaConfig(model="llama2")
        ...     )
        ... )
        >>> service = MCPChatService(config)
        >>> # asyncio.run(service.initialize())
        >>> # service.is_initialized -> True

    Note:
        Any failure is logged, leaves the service uninitialized and is
        re-raised unchanged.
    """
    if self._initialized:
        logger.warning("Chat service already initialized")
        return

    if not _LLMCHAT_AVAILABLE:
        raise ImportError("LLM chat dependencies are missing. Install them with: pip install '.[llmchat]'")

    try:
        logger.info("Initializing chat service...")

        # Tools come from the MCP server, so it has to be connected first.
        await self.mcp_client.connect()
        loaded_tools = await self.mcp_client.get_tools()
        self._tools = loaded_tools

        # ReAct agent = chat model + the tools just loaded.
        chat_model = self.llm_provider.get_llm()
        self._agent = create_react_agent(chat_model, loaded_tools)

        self._initialized = True
        logger.info(f"Chat service initialized successfully with {len(loaded_tools)} tools")

    except Exception as e:
        logger.error(f"Failed to initialize chat service: {e}")
        self._initialized = False
        raise
async def chat(self, message: str) -> str:
    """
    Run one non-streaming chat turn and return the full reply.

    The stored conversation (when a user_id is set) plus the new user
    message is passed to the agent; the content of the last message in the
    agent's result is returned. The exchange is appended to the user's
    history, and the call is wrapped in an ``llm.chat`` tracing span.

    Args:
        message: User's message text.

    Returns:
        str: Complete AI response text.

    Raises:
        RuntimeError: If service not initialized.
        ValueError: If message is empty.
        Exception: If processing fails.

    Examples:
        >>> import asyncio
        >>> # Assuming service is initialized
        >>> # response = asyncio.run(service.chat("Hello!"))
        >>> # isinstance(response, str)
        True

    Note:
        History is only written when the service has a user_id.
    """
    if not (self._initialized and self._agent):
        raise RuntimeError("Chat service not initialized. Call initialize() first.")

    if not (message and message.strip()):
        raise ValueError("Message cannot be empty")

    trace_attrs = {
        "langfuse.observation.type": "generation",
        "gen_ai.system": _llm_system_name(self),
        "gen_ai.request.model": self.llm_provider.get_model_name(),
    }
    if is_input_capture_enabled("llm.chat"):
        trace_attrs["langfuse.observation.input"] = serialize_trace_payload({"message": message})

    with create_span("llm.chat", trace_attrs) as span:
        try:
            logger.debug("Processing chat message...")

            # Anonymous sessions start from an empty conversation.
            if self.user_id:
                conversation = await self.history_manager.get_langchain_messages(self.user_id)
            else:
                conversation = []
            conversation.append(HumanMessage(content=message))

            result = await self._agent.ainvoke({"messages": conversation})
            final_message = result["messages"][-1]
            if hasattr(final_message, "content"):
                reply = final_message.content
            else:
                reply = str(final_message)

            if span:
                _set_usage_attributes(span, final_message)
                if is_output_capture_enabled("llm.chat"):
                    set_span_attribute(span, "langfuse.observation.output", serialize_trace_payload({"response": reply}))

            # Record the user turn first, then the assistant turn.
            if self.user_id:
                for role, text in (("user", message), ("assistant", reply)):
                    await self.history_manager.append_message(self.user_id, role, text)

            logger.debug("Chat message processed successfully")
            return reply

        except Exception as e:
            logger.error(f"Error processing chat message: {e}")
            raise
async def chat_with_metadata(self, message: str) -> Dict[str, Any]:
    """
    Run a chat turn and return the reply together with tool metadata.

    Consumes every event from ``chat_events(message)``: token contents are
    concatenated into the reply text, tool start/end/error events are
    collected verbatim, and the last ``final`` event supplies the summary
    fields.

    Args:
        message: User's message text.

    Returns:
        Dict[str, Any]: Dictionary containing:
            - text (str): Complete response text
            - tool_used (bool): Whether any tools were invoked
            - tools (List[str]): Names of tools that were used
            - tool_invocations (List[dict]): Detailed tool invocation data
            - elapsed_ms (int): Processing time in milliseconds

    Raises:
        RuntimeError: If service not initialized.
        ValueError: If message is empty.

    Examples:
        >>> import asyncio
        >>> # Assuming service is initialized
        >>> # result = asyncio.run(service.chat_with_metadata("What's 2+2?"))
        >>> # 'text' in result and 'elapsed_ms' in result
        True

    Note:
        If no ``final`` event arrives, tool_used is False, tools is empty
        and elapsed_ms is None.
    """
    tool_event_types = ("tool_start", "tool_end", "tool_error")

    response_text = ""
    invocations: list[dict[str, Any]] = []
    summary: dict[str, Any] = {}

    async for event in self.chat_events(message):
        event_type = event.get("type")
        if event_type == "final":
            summary = event
        elif event_type == "token":
            response_text += event.get("content", "")
        elif event_type in tool_event_types:
            invocations.append(event)

    result: dict[str, Any] = {"text": response_text}
    result["tool_used"] = summary.get("tool_used", False)
    result["tools"] = summary.get("tools", [])
    result["tool_invocations"] = invocations
    result["elapsed_ms"] = summary.get("elapsed_ms")
    return result
async def chat_stream(self, message: str) -> AsyncGenerator[str, None]:
    """
    Send a message and stream the response.

    Yields response chunks as they're generated, enabling real-time display
    of the AI's response. When streaming is disabled in the configuration,
    the complete response from ``chat()`` is yielded as a single chunk.

    Args:
        message: User's message text.

    Yields:
        str: Chunks of AI response text.

    Raises:
        RuntimeError: If service not initialized.
        ValueError: If message is empty or whitespace only.
        Exception: If streaming fails.

    Examples:
        >>> import asyncio
        >>> async def stream_example():
        ...     # Assuming service is initialized
        ...     chunks = []
        ...     async for chunk in service.chat_stream("Hello"):
        ...         chunks.append(chunk)
        ...     return ''.join(chunks)
        >>> # full_response = asyncio.run(stream_example())

    Note:
        Falls back to non-streaming if enable_streaming is False in config.
        History is saved only when a user_id is set and some response text
        was produced.
    """
    if not self._initialized or not self._agent:
        raise RuntimeError("Chat service not initialized. Call initialize() first.")

    # Reject blank input here as chat() and chat_events() do; previously the
    # streaming path forwarded an empty message straight to the agent.
    if not message or not message.strip():
        raise ValueError("Message cannot be empty")

    if not self.config.enable_streaming:
        # Fall back to non-streaming
        response = await self.chat(message)
        yield response
        return

    try:
        logger.debug("Processing streaming chat message...")

        # Get conversation history
        lc_messages = await self.history_manager.get_langchain_messages(self.user_id) if self.user_id else []

        # Add user message
        user_message = HumanMessage(content=message)
        lc_messages.append(user_message)

        # Stream agent response
        full_response = ""
        async for event in self._agent.astream_events({"messages": lc_messages}, version="v2"):
            kind = event["event"]

            # Only chat-model token events carry user-visible text.
            if kind == "on_chat_model_stream":
                chunk = event.get("data", {}).get("chunk")
                if chunk and hasattr(chunk, "content"):
                    content = chunk.content
                    if content:
                        full_response += content
                        yield content

        # Save history
        if self.user_id and full_response:
            await self.history_manager.append_message(self.user_id, "user", message)
            await self.history_manager.append_message(self.user_id, "assistant", full_response)

        logger.debug("Streaming chat message processed successfully")

    except Exception as e:
        logger.error(f"Error processing streaming chat message: {e}")
        raise
2667 async def chat_events(self, message: str) -> AsyncGenerator[Dict[str, Any], None]:
2668 """
2669 Stream structured events during chat processing.
2671 Provides granular visibility into the chat processing pipeline by yielding
2672 structured events for tokens, tool invocations, errors, and final results.
2674 Args:
2675 message: User's message text.
2677 Yields:
2678 dict: Event dictionaries with type-specific fields:
2679 - token: {"type": "token", "content": str}
                - tool_start: {"type": "tool_start", "id": str, "tool": str,
                               "input": Any, "start": str}
                - tool_end: {"type": "tool_end", "id": str, "tool": str,
                             "output": Any, "end": str}
                - tool_error: {"type": "tool_error", "id": str, "tool": Optional[str],
                               "error": str, "time": str}
2686 - final: {"type": "final", "content": str, "tool_used": bool,
2687 "tools": List[str], "elapsed_ms": int}
2689 Raises:
2690 RuntimeError: If service not initialized.
2691 ValueError: If message is empty or whitespace only.
2692 ConnectionError: If the underlying MCP connection is lost.
2693 TimeoutError: If the LLM request times out.
2694 ChatProcessingError: If a tool, parsing, or model error occurs during streaming.
2696 Examples:
2697 >>> import asyncio
2698 >>> async def event_example():
2699 ... # Assuming service is initialized
2700 ... events = []
2701 ... async for event in service.chat_events("Hello"):
2702 ... events.append(event['type'])
2703 ... return events
2704 >>> # event_types = asyncio.run(event_example())
2705 >>> # 'final' in event_types -> True
2707 Note:
2708 This is the most detailed chat method, suitable for building
2709 interactive UIs or detailed logging systems.
2710 """
2711 if not self._initialized or not self._agent:
2712 raise RuntimeError("Chat service not initialized. Call initialize() first.")
2714 # Validate message
2715 if not message or not message.strip():
2716 raise ValueError("Message cannot be empty")
2718 # Get conversation history
2719 lc_messages = await self.history_manager.get_langchain_messages(self.user_id) if self.user_id else []
2721 # Append user message
2722 user_message = HumanMessage(content=message)
2723 lc_messages.append(user_message)
2725 full_response = ""
2726 start_ts = time.time()
2727 tool_runs: dict[str, dict[str, Any]] = {}
2728 # Buffer for out-of-order on_tool_end events (end arrives before start)
2729 pending_tool_ends: dict[str, dict[str, Any]] = {}
2730 pending_ttl_seconds = 30.0 # Max time to hold pending end events
2731 pending_max_size = 100 # Max number of pending end events to buffer
2732 # Track dropped run_ids for aggregated error (TTL-expired or buffer-full)
2733 dropped_tool_ends: set[str] = set()
2734 dropped_max_size = 200 # Max dropped IDs to track (prevents unbounded growth)
2735 dropped_overflow_count = 0 # Count of drops that couldn't be tracked due to full buffer
2737 def _extract_output(raw_output: Any) -> Any:
2738 """Extract output value from various LangChain output formats.
2740 Args:
2741 raw_output: The raw output from a tool execution.
2743 Returns:
2744 The extracted output value in a serializable format.
2745 """
2746 if hasattr(raw_output, "content"):
2747 return raw_output.content
2748 if hasattr(raw_output, "dict") and callable(raw_output.dict):
2749 return raw_output.dict()
2750 if not isinstance(raw_output, (str, int, float, bool, list, dict, type(None))):
2751 return str(raw_output)
2752 return raw_output
2754 def _cleanup_expired_pending(current_ts: float) -> None:
2755 """Remove expired entries from pending_tool_ends buffer and track them.
2757 Args:
2758 current_ts: Current timestamp in seconds since epoch.
2759 """
2760 nonlocal dropped_overflow_count
2761 expired = [rid for rid, data in pending_tool_ends.items() if current_ts - data.get("buffered_at", 0) > pending_ttl_seconds]
2762 for rid in expired:
2763 logger.warning(f"Pending on_tool_end for run_id {rid} expired after {pending_ttl_seconds}s (orphan event)")
2764 if len(dropped_tool_ends) < dropped_max_size:
2765 dropped_tool_ends.add(rid)
2766 else:
2767 dropped_overflow_count += 1
2768 logger.warning(f"Dropped tool ends tracking full ({dropped_max_size}), cannot track expired run_id {rid} (overflow count: {dropped_overflow_count})")
2769 del pending_tool_ends[rid]
2771 span_attributes = {
2772 "langfuse.observation.type": "generation",
2773 "gen_ai.system": _llm_system_name(self),
2774 "gen_ai.request.model": self.llm_provider.get_model_name(),
2775 "llm.stream": True,
2776 }
2777 if is_input_capture_enabled("llm.chat"):
2778 span_attributes["langfuse.observation.input"] = serialize_trace_payload({"message": message})
2780 with create_span("llm.chat", span_attributes) as span:
2781 try:
2782 async for event in self._agent.astream_events({"messages": lc_messages}, version="v2"):
2783 kind = event.get("event")
2784 now_iso = datetime.now(timezone.utc).isoformat()
2785 now_ts = time.time()
2787 # Periodically cleanup expired pending ends
2788 _cleanup_expired_pending(now_ts)
2790 try:
2791 if kind == "on_tool_start":
2792 run_id = str(event.get("run_id") or uuid4())
2793 name = event.get("name") or event.get("data", {}).get("name") or event.get("data", {}).get("tool")
2794 input_data = event.get("data", {}).get("input")
2796 # Filter out common metadata keys injected by LangChain/LangGraph
2797 if isinstance(input_data, dict):
2798 input_data = {k: v for k, v in input_data.items() if k not in ["runtime", "config", "run_manager", "callbacks"]}
2800 tool_runs[run_id] = {"name": name, "start": now_iso, "input": input_data}
2802 # Register run for cancellation tracking with gateway-level Cancellation service
2803 async def _noop_cancel_cb(reason: Optional[str]) -> None:
2804 """
2805 No-op cancel callback used when a run is started.
2807 Args:
2808 reason: Optional textual reason for cancellation.
2810 Returns:
2811 None
2812 """
2813 # Default no-op; kept for potential future intra-process cancellation
2814 return None
2816 # Register with cancellation service only if feature is enabled
2817 if settings.mcpgateway_tool_cancellation_enabled:
2818 try:
2819 await cancellation_service.register_run(run_id, name=name, cancel_callback=_noop_cancel_cb)
2820 except Exception:
2821 logger.exception("Failed to register run %s with CancellationService", run_id)
2823 yield {"type": "tool_start", "id": run_id, "tool": name, "input": input_data, "start": now_iso}
2825 # NOTE: Do NOT clear from dropped_tool_ends here. If an end was dropped (TTL/buffer-full)
2826 # before this start arrived, that end is permanently lost. Since tools only end once,
2827 # we won't receive another end event, so this should still be reported as an orphan.
2829 # Check if we have a buffered end event for this run_id (out-of-order reconciliation)
2830 if run_id in pending_tool_ends:
2831 buffered = pending_tool_ends.pop(run_id)
2832 tool_runs[run_id]["end"] = buffered["end_time"]
2833 tool_runs[run_id]["output"] = buffered["output"]
2834 logger.info(f"Reconciled out-of-order on_tool_end for run_id {run_id}")
2836 if tool_runs[run_id].get("output") == "":
2837 error = "Tool execution failed: Please check if the tool is accessible"
2838 yield {"type": "tool_error", "id": run_id, "tool": name, "error": error, "time": buffered["end_time"]}
2840 yield {"type": "tool_end", "id": run_id, "tool": name, "output": tool_runs[run_id].get("output"), "end": buffered["end_time"]}
2842 elif kind == "on_tool_end":
2843 run_id = str(event.get("run_id") or uuid4())
2844 output = event.get("data", {}).get("output")
2845 extracted_output = _extract_output(output)
2847 if run_id in tool_runs:
2848 # Normal case: start already received
2849 tool_runs[run_id]["end"] = now_iso
2850 tool_runs[run_id]["output"] = extracted_output
2852 if tool_runs[run_id].get("output") == "":
2853 error = "Tool execution failed: Please check if the tool is accessible"
2854 yield {"type": "tool_error", "id": run_id, "tool": tool_runs.get(run_id, {}).get("name"), "error": error, "time": now_iso}
2856 yield {"type": "tool_end", "id": run_id, "tool": tool_runs.get(run_id, {}).get("name"), "output": tool_runs[run_id].get("output"), "end": now_iso}
2857 else:
2858 # Out-of-order: buffer the end event for later reconciliation
2859 if len(pending_tool_ends) < pending_max_size:
2860 pending_tool_ends[run_id] = {"output": extracted_output, "end_time": now_iso, "buffered_at": now_ts}
2861 logger.debug(f"Buffered out-of-order on_tool_end for run_id {run_id}, awaiting on_tool_start")
2862 else:
2863 logger.warning(f"Pending tool ends buffer full ({pending_max_size}), dropping on_tool_end for run_id {run_id}")
2864 if len(dropped_tool_ends) < dropped_max_size:
2865 dropped_tool_ends.add(run_id)
2866 else:
2867 dropped_overflow_count += 1
2868 logger.warning(f"Dropped tool ends tracking full ({dropped_max_size}), cannot track run_id {run_id} (overflow count: {dropped_overflow_count})")
2870 # Unregister run from cancellation service when finished (only if feature is enabled)
2871 if settings.mcpgateway_tool_cancellation_enabled:
2872 try:
2873 await cancellation_service.unregister_run(run_id)
2874 except Exception:
2875 logger.exception("Failed to unregister run %s", run_id)
2877 elif kind == "on_tool_error":
2878 run_id = str(event.get("run_id") or uuid4())
2879 error = str(event.get("data", {}).get("error", "Unknown error"))
2881 # Clear any buffered end for this run to avoid emitting both error and end
2882 if run_id in pending_tool_ends:
2883 del pending_tool_ends[run_id]
2884 logger.debug(f"Cleared buffered on_tool_end for run_id {run_id} due to tool error")
2886 # Clear from dropped set if this run was previously dropped (prevents false orphan)
2887 dropped_tool_ends.discard(run_id)
2889 yield {"type": "tool_error", "id": run_id, "tool": tool_runs.get(run_id, {}).get("name"), "error": error, "time": now_iso}
2891 # Unregister run on error (only if feature is enabled)
2892 if settings.mcpgateway_tool_cancellation_enabled:
2893 try:
2894 await cancellation_service.unregister_run(run_id)
2895 except Exception:
2896 logger.exception("Failed to unregister run %s after error", run_id)
2898 elif kind == "on_chat_model_stream":
2899 chunk = event.get("data", {}).get("chunk")
2900 if chunk and hasattr(chunk, "content"):
2901 content = chunk.content
2902 if content:
2903 full_response += content
2904 yield {"type": "token", "content": content}
2906 except Exception as event_error:
2907 logger.warning(f"Error processing event {kind}: {event_error}")
2908 continue
2910 all_orphan_ids = sorted(set(pending_tool_ends.keys()) | dropped_tool_ends)
2911 if all_orphan_ids or dropped_overflow_count > 0:
2912 buffered_count = len(pending_tool_ends)
2913 dropped_count = len(dropped_tool_ends)
2914 total_unique = len(all_orphan_ids)
2915 total_affected = total_unique + dropped_overflow_count
2916 logger.warning(
2917 f"Stream completed with {total_affected} orphan tool end(s): {buffered_count} buffered, {dropped_count} dropped (tracked), {dropped_overflow_count} dropped (untracked overflow)"
2918 )
2919 if all_orphan_ids:
2920 logger.debug(f"Full orphan run_id list: {', '.join(all_orphan_ids)}")
2921 now_iso = datetime.now(timezone.utc).isoformat()
2922 error_parts = []
2923 if buffered_count > 0:
2924 error_parts.append(f"{buffered_count} buffered")
2925 if dropped_count > 0:
2926 error_parts.append(f"{dropped_count} dropped (TTL expired or buffer full)")
2927 if dropped_overflow_count > 0:
2928 error_parts.append(f"{dropped_overflow_count} additional dropped (tracking overflow)")
2929 error_msg = f"Tool execution incomplete: {total_affected} tool end(s) received without matching start ({', '.join(error_parts)})"
2930 if all_orphan_ids:
2931 max_display_ids = 10
2932 display_ids = all_orphan_ids[:max_display_ids]
2933 remaining = total_unique - len(display_ids)
2934 if remaining > 0:
2935 error_msg += f". Run IDs (first {max_display_ids} of {total_unique}): {', '.join(display_ids)} (+{remaining} more)"
2936 else:
2937 error_msg += f". Run IDs: {', '.join(display_ids)}"
2938 yield {
2939 "type": "tool_error",
2940 "id": str(uuid4()),
2941 "tool": None,
2942 "error": error_msg,
2943 "time": now_iso,
2944 }
2945 pending_tool_ends.clear()
2946 dropped_tool_ends.clear()
2948 elapsed_ms = int((time.time() - start_ts) * 1000)
2950 tools_used = list({tr["name"] for tr in tool_runs.values() if tr.get("name")})
2952 if span and is_output_capture_enabled("llm.chat"):
2953 set_span_attribute(span, "langfuse.observation.output", serialize_trace_payload({"response": full_response}))
2955 yield {"type": "final", "content": full_response, "tool_used": len(tools_used) > 0, "tools": tools_used, "elapsed_ms": elapsed_ms}
2957 if self.user_id and full_response:
2958 await self.history_manager.append_message(self.user_id, "user", message)
2959 await self.history_manager.append_message(self.user_id, "assistant", full_response)
2961 except (ConnectionError, TimeoutError) as e:
2962 logger.error(f"Error in chat_events: {e}")
2963 raise
2964 except Exception as e:
2965 logger.error(f"Error in chat_events: {e}")
2966 raise ChatProcessingError(f"Chat processing error: {e}") from e
async def get_conversation_history(self) -> List[Dict[str, str]]:
    """
    Fetch the stored conversation for the user bound to this service.

    The lookup is delegated to ``self.history_manager.get_history`` and its
    result is returned unchanged.

    Returns:
        List[Dict[str, str]]: The history manager's result for
            ``self.user_id``, or an empty list when ``user_id`` is falsy.
            Entries are expected to carry:
            - role (str): "user" or "assistant"
            - content (str): Message text

    Examples:
        >>> import asyncio
        >>> # Assuming service is initialized with user_id
        >>> # history = asyncio.run(service.get_conversation_history())
        >>> # all('role' in msg and 'content' in msg for msg in history)  # -> True

    Note:
        Without a user_id the history manager is never consulted. What is
        returned for a user with no stored messages is decided by the
        history manager (presumably an empty list; confirm there).
    """
    owner = self.user_id
    if owner:
        # Only identified users have persisted history to look up.
        return await self.history_manager.get_history(owner)

    return []
async def clear_history(self) -> None:
    """
    Delete the stored conversation for the user bound to this service.

    Delegates to ``self.history_manager.clear_history`` and then logs the
    action at INFO level.

    Examples:
        >>> import asyncio
        >>> # Assuming service is initialized with user_id
        >>> # asyncio.run(service.clear_history())
        >>> # history = asyncio.run(service.get_conversation_history())
        >>> # len(history) -> 0

    Note:
        Nothing happens (no manager call, no log line) when ``user_id`` is
        falsy. This method offers no way to restore cleared history.
    """
    if self.user_id:
        await self.history_manager.clear_history(self.user_id)
        logger.info(f"Conversation history cleared for user {self.user_id}")
async def shutdown(self) -> None:
    """
    Disconnect from the MCP server and reset the service state.

    If the MCP client reports an active connection it is disconnected.
    Afterwards the agent reference is dropped, the initialized flag is set
    to False and the cached tool list is emptied.

    Raises:
        Exception: Any error raised while disconnecting or resetting state
            is logged and re-raised unchanged.

    Examples:
        >>> import asyncio
        >>> config = MCPClientConfig(
        ...     mcp_server=MCPServerConfig(
        ...         url="https://example.com/mcp",
        ...         transport="streamable_http"
        ...     ),
        ...     llm=LLMConfig(
        ...         provider="ollama",
        ...         config=OllamaConfig(model="llama2")
        ...     )
        ... )
        >>> service = MCPChatService(config)
        >>> # asyncio.run(service.initialize())
        >>> # asyncio.run(service.shutdown())
        >>> # service.is_initialized -> False

    Note:
        If disconnecting fails, the state reset below is skipped, so the
        service keeps its previous agent, tools and initialized flag.
        Stored conversation history is not touched by this method.
    """
    logger.info("Shutting down chat service...")

    try:
        client = self.mcp_client
        if client.is_connected:
            await client.disconnect()

        # Connection is closed (or was never open): drop per-session state.
        self._agent, self._initialized, self._tools = None, False, []

        logger.info("Chat service shutdown complete")

    except Exception as e:
        logger.error(f"Error during shutdown: {e}")
        raise
@property
def is_initialized(self) -> bool:
    """
    Report whether the service has been initialized.

    Read-only view of the internal ``_initialized`` flag.

    Returns:
        bool: The current value of ``self._initialized``.

    Examples:
        >>> config = MCPClientConfig(
        ...     mcp_server=MCPServerConfig(
        ...         url="https://example.com/mcp",
        ...         transport="streamable_http"
        ...     ),
        ...     llm=LLMConfig(
        ...         provider="ollama",
        ...         config=OllamaConfig(model="llama2")
        ...     )
        ... )
        >>> service = MCPChatService(config)
        >>> service.is_initialized
        False

    Note:
        ``shutdown()`` sets the flag back to False, and ``reload_tools()``
        refuses to run while it is False.
    """
    ready = self._initialized
    return ready
async def reload_tools(self) -> int:
    """
    Re-fetch the MCP tool list and rebuild the agent around it.

    Asks the MCP client for its tools with ``force_reload=True``, creates a
    new ReAct agent from the provider's LLM and those tools, and stores both
    the agent and the tool list on the service.

    Returns:
        int: Number of tools returned by the MCP client.

    Raises:
        RuntimeError: If the service has not been initialized.
        ImportError: If the optional LLM chat dependencies are not installed.
        Exception: Any failure while fetching tools or building the agent is
            logged and re-raised unchanged.

    Examples:
        >>> import asyncio
        >>> # Assuming service is initialized
        >>> # tool_count = asyncio.run(service.reload_tools())
        >>> # tool_count >= 0 -> True

    Note:
        The agent object is replaced, so callers holding the previous agent
        keep using the old tool set. This method does not touch the
        conversation history. The agent and tool list are only replaced
        once the new agent has been created successfully.
    """
    # Guards run in this order on purpose: an uninitialized service is
    # reported before a missing optional dependency.
    if not self._initialized:
        raise RuntimeError("Chat service not initialized")

    if not _LLMCHAT_AVAILABLE:
        raise ImportError("LLM chat dependencies are missing. Install them with: pip install '.[llmchat]'")

    try:
        logger.info("Reloading tools from MCP server...")
        tools = await self.mcp_client.get_tools(force_reload=True)

        # Build the replacement agent first; state is swapped only on success.
        self._agent = create_react_agent(self.llm_provider.get_llm(), tools)
        self._tools = tools

        logger.info(f"Reloaded {len(tools)} tools successfully")

    except Exception as e:
        logger.error(f"Failed to reload tools: {e}")
        raise

    return len(tools)