Coverage for mcpgateway / services / mcp_client_chat_service.py: 98%

888 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/mcp_client_chat_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Keval Mahajan 

6 

7MCP Client Service Module. 

8 

9This module provides a comprehensive client implementation for interacting with 

10MCP servers, managing LLM providers, and orchestrating conversational AI agents. 

11It supports multiple transport protocols and LLM providers. 

12 

13The module consists of several key components: 

14- Configuration classes for MCP servers and LLM providers 

15- LLM provider factory and implementations 

16- MCP client for tool management 

17- Chat history manager for Redis and in-memory storage 

18- Chat service for conversational interactions 

19""" 

20 

21# Standard 

22from datetime import datetime, timezone 

23import time 

24from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Union 

25from uuid import uuid4 

26 

27# Third-Party 

28import orjson 

29 

30try: 

31 # Third-Party 

32 from langchain_core.language_models import BaseChatModel 

33 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage 

34 from langchain_core.tools import BaseTool 

35 from langchain_mcp_adapters.client import MultiServerMCPClient 

36 from langchain_ollama import ChatOllama, OllamaLLM 

37 from langchain_openai import AzureChatOpenAI, AzureOpenAI, ChatOpenAI, OpenAI 

38 from langgraph.prebuilt import create_react_agent 

39 

40 _LLMCHAT_AVAILABLE = True 

41except ImportError: 

42 # Optional dependencies for LLM chat feature not installed 

43 # These are only needed if LLMCHAT_ENABLED=true 

44 _LLMCHAT_AVAILABLE = False 

45 BaseChatModel = None # type: ignore 

46 AIMessage = None # type: ignore 

47 BaseMessage = None # type: ignore 

48 HumanMessage = None # type: ignore 

49 BaseTool = None # type: ignore 

50 MultiServerMCPClient = None # type: ignore 

51 ChatOllama = None # type: ignore 

52 OllamaLLM = None 

53 AzureChatOpenAI = None # type: ignore 

54 AzureOpenAI = None 

55 ChatOpenAI = None # type: ignore 

56 OpenAI = None 

57 create_react_agent = None # type: ignore 

58 

59# Try to import Anthropic and Bedrock providers (they may not be installed) 

60try: 

61 # Third-Party 

62 from langchain_anthropic import AnthropicLLM, ChatAnthropic 

63 

64 _ANTHROPIC_AVAILABLE = True 

65except ImportError: 

66 _ANTHROPIC_AVAILABLE = False 

67 ChatAnthropic = None # type: ignore 

68 AnthropicLLM = None 

69 

70try: 

71 # Third-Party 

72 from langchain_aws import BedrockLLM, ChatBedrock 

73 

74 _BEDROCK_AVAILABLE = True 

75except ImportError: 

76 _BEDROCK_AVAILABLE = False 

77 ChatBedrock = None # type: ignore 

78 BedrockLLM = None 

79 

80try: 

81 # Third-Party 

82 from langchain_ibm import ChatWatsonx, WatsonxLLM 

83 

84 _WATSONX_AVAILABLE = True 

85except ImportError: 

86 _WATSONX_AVAILABLE = False 

87 WatsonxLLM = None # type: ignore 

88 ChatWatsonx = None 

89 

90# Third-Party 

91from pydantic import BaseModel, Field, field_validator, model_validator 

92 

93# First-Party 

94from mcpgateway.common.validators import SecurityValidator 

95from mcpgateway.config import settings 

96from mcpgateway.observability import create_span, set_span_attribute 

97from mcpgateway.services.cancellation_service import cancellation_service 

98from mcpgateway.services.logging_service import LoggingService 

99from mcpgateway.utils.trace_redaction import is_input_capture_enabled, is_output_capture_enabled, serialize_trace_payload 

100 

# Module-level logging: a LoggingService instance and the logger this module writes through.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

103 

104 

105def _llm_system_name(service: "MCPChatService") -> str: 

106 """Return a stable provider/system label for trace attributes. 

107 

108 Args: 

109 service: Chat service instance holding the active LLM provider. 

110 

111 Returns: 

112 Lowercase provider name for GenAI trace attributes. 

113 """ 

114 provider_name = type(service.llm_provider).__name__.replace("Provider", "") 

115 return provider_name.lower() or "unknown" 

116 

117 

118def _set_usage_attributes(span: Any, ai_message: Any) -> None: 

119 """Attach token usage metadata to a span when available. 

120 

121 Args: 

122 span: Active span object to enrich. 

123 ai_message: Provider response object that may expose ``usage_metadata``. 

124 """ 

125 usage = getattr(ai_message, "usage_metadata", None) 

126 if not isinstance(usage, dict): 

127 return 

128 

129 input_tokens = usage.get("input_tokens") 

130 output_tokens = usage.get("output_tokens") 

131 total_tokens = usage.get("total_tokens") 

132 if input_tokens is not None: 

133 set_span_attribute(span, "gen_ai.usage.prompt_tokens", input_tokens) 

134 if output_tokens is not None: 

135 set_span_attribute(span, "gen_ai.usage.completion_tokens", output_tokens) 

136 if total_tokens is not None: 

137 set_span_attribute(span, "gen_ai.usage.total_tokens", total_tokens) 

138 

139 

class ChatProcessingError(RuntimeError):
    """Recoverable error wrapping tool, parsing, or model failures during chat streaming.

    Derives from ``RuntimeError``, so existing ``except RuntimeError`` handlers
    also catch it. It adds no attributes of its own.
    """

142 

143 

class MCPServerConfig(BaseModel):
    """
    Configuration for MCP server connection.

    This class defines the configuration parameters required to connect to an
    MCP (Model Context Protocol) server using various transport mechanisms.

    Attributes:
        url: MCP server URL for streamable_http/sse transports.
        command: Command to run for stdio transport.
        args: Command-line arguments for stdio command.
        transport: Transport type (streamable_http, sse, or stdio); defaults to streamable_http.
        auth_token: Authentication token for HTTP-based transports.
        headers: Additional HTTP headers for request customization.

    Examples:
        >>> # HTTP-based transport
        >>> config = MCPServerConfig(
        ...     url="https://mcp-server.example.com/mcp",
        ...     transport="streamable_http",
        ...     auth_token="secret-token"
        ... )
        >>> config.transport
        'streamable_http'
        >>> config.headers
        {'Authorization': 'Bearer secret-token'}

        >>> # The token is applied even when the default transport is used implicitly
        >>> MCPServerConfig(url="https://mcp-server.example.com/mcp", auth_token="t").headers
        {'Authorization': 'Bearer t'}

        >>> # Caller-owned header dicts are copied, never mutated
        >>> shared = {"X-Team": "a"}
        >>> MCPServerConfig(url="https://a.example.com/mcp", auth_token="t1", headers=shared).headers
        {'X-Team': 'a', 'Authorization': 'Bearer t1'}
        >>> shared
        {'X-Team': 'a'}

        >>> # Stdio transport (requires explicit feature flag)
        >>> settings.mcpgateway_stdio_transport_enabled = True
        >>> config = MCPServerConfig(
        ...     command="python",
        ...     args=["server.py"],
        ...     transport="stdio"
        ... )
        >>> config.command
        'python'

    Note:
        The auth_token is automatically added to headers as a Bearer token
        for HTTP-based transports, unless an Authorization header (any casing)
        is already supplied.
    """

    url: Optional[str] = Field(None, description="MCP server URL for streamable_http/sse transports")
    command: Optional[str] = Field(None, description="Command to run for stdio transport")
    args: Optional[list[str]] = Field(None, description="Arguments for stdio command")
    transport: Literal["streamable_http", "sse", "stdio"] = Field(default="streamable_http", description="Transport type for MCP connection")
    auth_token: Optional[str] = Field(None, description="Authentication token for the server")
    headers: Optional[Dict[str, str]] = Field(default=None, description="Additional headers for HTTP-based transports")

    @model_validator(mode="before")
    @classmethod
    def add_auth_to_headers(cls, values: Any) -> Any:
        """
        Automatically add authentication token to headers if provided.

        If an auth_token is provided for an HTTP-based transport (including the
        default ``streamable_http`` when ``transport`` is omitted), it is added
        to the headers as a Bearer token. An existing Authorization header
        (matched case-insensitively) is never overwritten. The caller's input
        dict and headers dict are not mutated; a new mapping is returned.

        Args:
            values: Raw input before field validation (normally a dict).

        Returns:
            Any: Updated values with the auth token in headers, or the input
            unchanged when it is not a dict or headers cannot be copied (field
            validation then reports the problem).

        Examples:
            >>> values = {
            ...     "url": "https://api.example.com",
            ...     "transport": "streamable_http",
            ...     "auth_token": "token123"
            ... }
            >>> result = MCPServerConfig.add_auth_to_headers(values)
            >>> result['headers']['Authorization']
            'Bearer token123'
            >>> "headers" in values
            False
            >>> MCPServerConfig.add_auth_to_headers({"url": "u", "auth_token": "x"})["headers"]
            {'Authorization': 'Bearer x'}
        """
        # Before-validators can receive non-dict input; leave it for normal validation.
        if not isinstance(values, dict):
            return values

        auth_token = values.get("auth_token")
        # Field defaults are applied after this validator runs, so mirror the
        # transport default here; otherwise an omitted transport skips the token.
        transport = values.get("transport", "streamable_http")

        if auth_token and transport in ("streamable_http", "sse"):
            try:
                # Copy so a caller-owned (possibly shared) headers dict is never mutated.
                headers = dict(values.get("headers") or {})
            except (TypeError, ValueError):
                return values  # malformed headers: let field validation report it
            # HTTP header names are case-insensitive.
            if not any(str(name).lower() == "authorization" for name in headers):
                headers["Authorization"] = f"Bearer {auth_token}"
            values = {**values, "headers": headers}

        return values

    @model_validator(mode="after")
    def validate_transport_requirements(self) -> "MCPServerConfig":
        """Validate transport-specific requirements and feature flags.

        Returns:
            MCPServerConfig: Validated config instance.

        Raises:
            ValueError: If an HTTP transport has no URL, or stdio is used while
                ``settings.mcpgateway_stdio_transport_enabled`` is false or no
                command is given.
        """
        if self.transport in ["streamable_http", "sse"] and not self.url:
            raise ValueError(f"URL is required for {self.transport} transport")

        if self.transport == "stdio":
            if not settings.mcpgateway_stdio_transport_enabled:
                raise ValueError("stdio transport is disabled by default; set MCPGATEWAY_STDIO_TRANSPORT_ENABLED=true to enable it")
            if not self.command:
                raise ValueError("Command is required for stdio transport")

        return self

    model_config = {
        "json_schema_extra": {
            "examples": [
                {"url": "https://mcp-server.example.com/mcp", "transport": "streamable_http", "auth_token": "your-token-here"},  # nosec B105 - example placeholder
                {"command": "python", "args": ["server.py"], "transport": "stdio"},
            ]
        }
    }

256 

257 

class AzureOpenAIConfig(BaseModel):
    """
    Settings for talking to an Azure OpenAI deployment.

    Groups the credentials, endpoint, deployment name and generation
    parameters that the Azure OpenAI provider passes to its client.

    Attributes:
        api_key: Key used to authenticate against Azure OpenAI (required).
        azure_endpoint: Base URL of the Azure OpenAI resource (required).
        api_version: Azure OpenAI API version; defaults to ``2024-05-01-preview``.
        azure_deployment: Name of the model deployment to call (required).
        model: Model name recorded for logging and tracing; defaults to ``gpt-4``.
        temperature: Sampling temperature within 0.0-2.0; defaults to 0.7.
        max_tokens: Optional positive cap on generated tokens.
        timeout: Optional positive request timeout in seconds.
        max_retries: Non-negative retry budget for failed requests; defaults to 2.

    Examples:
        >>> cfg = AzureOpenAIConfig(
        ...     api_key="your-api-key",
        ...     azure_endpoint="https://your-resource.openai.azure.com/",
        ...     azure_deployment="gpt-4",
        ...     temperature=0.7,
        ... )
        >>> cfg.model
        'gpt-4'
        >>> cfg.temperature
        0.7
        >>> cfg.api_version
        '2024-05-01-preview'
    """

    api_key: str = Field(default=..., description="Azure OpenAI API key")
    azure_endpoint: str = Field(default=..., description="Azure OpenAI endpoint URL")
    api_version: str = Field("2024-05-01-preview", description="Azure OpenAI API version")
    azure_deployment: str = Field(default=..., description="Azure OpenAI deployment name")
    model: str = Field("gpt-4", description="Model name for tracing")
    temperature: float = Field(0.7, description="Sampling temperature", ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(default=None, description="Maximum tokens to generate", gt=0)
    timeout: Optional[float] = Field(default=None, description="Request timeout in seconds", gt=0)
    max_retries: int = Field(2, description="Maximum number of retries", ge=0)

    model_config = {"json_schema_extra": {"example": {"api_key": "your-api-key", "azure_endpoint": "https://your-resource.openai.azure.com/", "api_version": "2024-05-01-preview", "azure_deployment": "gpt-4", "model": "gpt-4", "temperature": 0.7}}}

311 

312 

class OllamaConfig(BaseModel):
    """
    Settings for an Ollama server, local or remote.

    Every field has a default, so ``OllamaConfig()`` targets ``llama2`` on
    ``http://localhost:11434``.

    Attributes:
        base_url: Root URL of the Ollama server; defaults to ``http://localhost:11434``.
        model: Ollama model name; defaults to ``llama2``.
        temperature: Sampling temperature within 0.0-2.0; defaults to 0.7.
        timeout: Optional positive request timeout in seconds.
        num_ctx: Optional positive context-window size.

    Examples:
        >>> cfg = OllamaConfig(base_url="http://localhost:11434", model="llama2", temperature=0.5)
        >>> cfg.model
        'llama2'
        >>> cfg.base_url
        'http://localhost:11434'
        >>> OllamaConfig().num_ctx is None
        True
    """

    base_url: str = Field("http://localhost:11434", description="Ollama base URL")
    model: str = Field("llama2", description="Model name to use")
    temperature: float = Field(0.7, description="Sampling temperature", ge=0.0, le=2.0)
    timeout: Optional[float] = Field(default=None, description="Request timeout in seconds", gt=0)
    num_ctx: Optional[int] = Field(default=None, description="Context window size", gt=0)

    model_config = {
        "json_schema_extra": {
            "example": {
                "base_url": "http://localhost:11434",
                "model": "llama2",
                "temperature": 0.7,
            }
        }
    }

346 

347 

class OpenAIConfig(BaseModel):
    """
    Settings for the OpenAI API or any OpenAI-compatible endpoint (non-Azure).

    Attributes:
        api_key: OpenAI API key (required).
        base_url: Optional base URL override for OpenAI-compatible endpoints.
        model: Model identifier; defaults to ``gpt-4o-mini``.
        temperature: Sampling temperature within 0.0-2.0; defaults to 0.7.
        max_tokens: Optional positive cap on generated tokens.
        timeout: Optional positive request timeout in seconds.
        max_retries: Non-negative retry budget for failed requests; defaults to 2.
        default_headers: Optional dict of extra headers to send with every
            request, for providers that require them.

    Examples:
        >>> cfg = OpenAIConfig(api_key="sk-...", model="gpt-4", temperature=0.7)
        >>> cfg.model
        'gpt-4'
        >>> cfg.base_url is None and cfg.default_headers is None
        True
    """

    api_key: str = Field(default=..., description="OpenAI API key")
    base_url: Optional[str] = Field(default=None, description="Base URL for OpenAI-compatible endpoints")
    model: str = Field("gpt-4o-mini", description="Model name (e.g., gpt-4, gpt-3.5-turbo)")
    temperature: float = Field(0.7, description="Sampling temperature", ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(default=None, description="Maximum tokens to generate", gt=0)
    timeout: Optional[float] = Field(default=None, description="Request timeout in seconds", gt=0)
    max_retries: int = Field(2, description="Maximum number of retries", ge=0)
    default_headers: Optional[dict] = Field(default=None, description="optional default headers required by the provider")

    model_config = {"json_schema_extra": {"example": {"api_key": "sk-...", "model": "gpt-4o-mini", "temperature": 0.7}}}

391 

392 

class AnthropicConfig(BaseModel):
    """
    Settings for Anthropic's Claude API.

    Attributes:
        api_key: Anthropic API key (required).
        model: Claude model identifier; defaults to ``claude-3-5-sonnet-20241022``.
        temperature: Sampling temperature within 0.0-1.0 (narrower than the
            OpenAI-style configs); defaults to 0.7.
        max_tokens: Positive cap on generated tokens; defaults to 4096.
        timeout: Optional positive request timeout in seconds.
        max_retries: Non-negative retry budget for failed requests; defaults to 2.

    Examples:
        >>> cfg = AnthropicConfig(api_key="sk-ant-...", model="claude-3-5-sonnet-20241022", temperature=0.7)
        >>> cfg.model
        'claude-3-5-sonnet-20241022'
        >>> cfg.max_tokens
        4096
    """

    api_key: str = Field(default=..., description="Anthropic API key")
    model: str = Field("claude-3-5-sonnet-20241022", description="Claude model name")
    temperature: float = Field(0.7, description="Sampling temperature", ge=0.0, le=1.0)
    max_tokens: int = Field(4096, description="Maximum tokens to generate", gt=0)
    timeout: Optional[float] = Field(default=None, description="Request timeout in seconds", gt=0)
    max_retries: int = Field(2, description="Maximum number of retries", ge=0)

    model_config = {"json_schema_extra": {"example": {"api_key": "sk-ant-...", "model": "claude-3-5-sonnet-20241022", "temperature": 0.7, "max_tokens": 4096}}}

434 

435 

class AWSBedrockConfig(BaseModel):
    """
    Settings for AWS Bedrock LLM services.

    Only ``model_id`` is required. The three AWS credential fields are
    optional and default to ``None``. When they are omitted, credential
    resolution is left to the underlying AWS client (presumably the default
    credential chain; confirm in the provider implementation).

    Attributes:
        model_id: Bedrock model identifier (e.g. ``anthropic.claude-v2``).
        region_name: AWS region; defaults to ``us-east-1``.
        aws_access_key_id: Optional AWS access key ID.
        aws_secret_access_key: Optional AWS secret access key.
        aws_session_token: Optional session token for temporary credentials.
        temperature: Sampling temperature within 0.0-1.0; defaults to 0.7.
        max_tokens: Positive cap on generated tokens; defaults to 4096.

    Examples:
        >>> cfg = AWSBedrockConfig(model_id="anthropic.claude-v2", region_name="us-east-1", temperature=0.7)
        >>> cfg.model_id
        'anthropic.claude-v2'
        >>> cfg.aws_access_key_id is None
        True
    """

    model_id: str = Field(default=..., description="Bedrock model ID")
    region_name: str = Field("us-east-1", description="AWS region name")
    aws_access_key_id: Optional[str] = Field(default=None, description="AWS access key ID")
    aws_secret_access_key: Optional[str] = Field(default=None, description="AWS secret access key")
    aws_session_token: Optional[str] = Field(default=None, description="AWS session token")
    temperature: float = Field(0.7, description="Sampling temperature", ge=0.0, le=1.0)
    max_tokens: int = Field(4096, description="Maximum tokens to generate", gt=0)

    model_config = {"json_schema_extra": {"example": {"model_id": "anthropic.claude-v2", "region_name": "us-east-1", "temperature": 0.7, "max_tokens": 4096}}}

479 

480 

class WatsonxConfig(BaseModel):
    """
    Configuration for IBM watsonx.ai provider.

    Defines parameters for connecting to IBM watsonx.ai services.

    Attributes:
        api_key: IBM Cloud API key for authentication.
        url: IBM watsonx.ai service endpoint URL.
        project_id: IBM watsonx.ai project ID for context.
        model_id: Model identifier (e.g., ibm/granite-13b-chat-v2, meta-llama/llama-3-70b-instruct).
        temperature: Sampling temperature for response generation (0.0-2.0).
        max_new_tokens: Maximum number of tokens to generate (positive).
        min_new_tokens: Minimum number of tokens to generate (0 or more; 0 means no minimum).
        decoding_method: Decoding method ('sample', 'greedy').
        top_k: Top-K sampling parameter.
        top_p: Top-P (nucleus) sampling parameter.
        timeout: Request timeout duration in seconds.

    Examples:
        >>> config = WatsonxConfig(
        ...     api_key="your-api-key",
        ...     url="https://us-south.ml.cloud.ibm.com",
        ...     project_id="your-project-id",
        ...     model_id="ibm/granite-13b-chat-v2"
        ... )
        >>> config.model_id
        'ibm/granite-13b-chat-v2'
        >>> WatsonxConfig(api_key="k", project_id="p", min_new_tokens=0).min_new_tokens
        0
    """

    api_key: str = Field(..., description="IBM Cloud API key")
    url: str = Field(default="https://us-south.ml.cloud.ibm.com", description="watsonx.ai endpoint URL")
    project_id: str = Field(..., description="watsonx.ai project ID")
    model_id: str = Field(default="ibm/granite-13b-chat-v2", description="Model identifier")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    max_new_tokens: Optional[int] = Field(default=1024, gt=0, description="Maximum tokens to generate")
    # watsonx.ai accepts 0 here (its own default), so allow it; gt=0 wrongly rejected it.
    min_new_tokens: Optional[int] = Field(default=1, ge=0, description="Minimum tokens to generate")
    decoding_method: str = Field(default="sample", description="Decoding method (sample or greedy)")
    top_k: Optional[int] = Field(default=50, gt=0, description="Top-K sampling")
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0, description="Top-P sampling")
    timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")

    model_config = {
        "json_schema_extra": {
            "example": {
                "api_key": "your-api-key",
                "url": "https://us-south.ml.cloud.ibm.com",
                "project_id": "your-project-id",
                "model_id": "ibm/granite-13b-chat-v2",
                "temperature": 0.7,
                "max_new_tokens": 1024,
            }
        }
    }

535 

536 

class GatewayConfig(BaseModel):
    """
    Settings for routing LLM Chat through ContextForge's own configured models.

    The chat feature uses a model registered in the gateway's LLM Settings,
    and the gateway forwards the request to whichever provider backs it.

    Attributes:
        model: Gateway model ID or provider model ID (required).
        base_url: Optional internal API URL; ``None`` means the gateway itself.
        temperature: Optional sampling temperature within 0.0-2.0; defaults to 0.7.
        max_tokens: Optional positive cap on generated tokens.
        timeout: Optional positive request timeout in seconds.

    Examples:
        >>> cfg = GatewayConfig(model="gpt-4o")
        >>> cfg.model
        'gpt-4o'
        >>> cfg.base_url is None
        True
    """

    model: str = Field(default=..., description="Gateway model ID to use")
    base_url: Optional[str] = Field(default=None, description="Gateway internal API URL (optional, defaults to self)")
    temperature: Optional[float] = Field(0.7, description="Sampling temperature", ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(default=None, description="Maximum tokens to generate", gt=0)
    timeout: Optional[float] = Field(default=None, description="Request timeout in seconds", gt=0)

    model_config = {"json_schema_extra": {"example": {"model": "gpt-4o", "temperature": 0.7, "max_tokens": 4096}}}

572 

573 

class LLMConfig(BaseModel):
    """
    Configuration for LLM provider.

    Unified configuration class that supports multiple LLM providers. The
    ``provider`` field selects which provider-specific config model ``config``
    must be; a dict ``config`` is converted to that model automatically.

    Attributes:
        provider: Type of LLM provider (azure_openai, openai, anthropic, aws_bedrock, ollama, watsonx, or gateway).
        config: Provider-specific configuration object matching ``provider``.

    Examples:
        >>> # Azure OpenAI configuration
        >>> config = LLMConfig(
        ...     provider="azure_openai",
        ...     config=AzureOpenAIConfig(
        ...         api_key="key",
        ...         azure_endpoint="https://example.com/",
        ...         azure_deployment="gpt-4"
        ...     )
        ... )
        >>> config.provider
        'azure_openai'

        >>> # OpenAI configuration
        >>> config = LLMConfig(
        ...     provider="openai",
        ...     config=OpenAIConfig(
        ...         api_key="sk-...",
        ...         model="gpt-4"
        ...     )
        ... )
        >>> config.provider
        'openai'

        >>> # Ollama configuration
        >>> config = LLMConfig(
        ...     provider="ollama",
        ...     config=OllamaConfig(model="llama2")
        ... )
        >>> config.provider
        'ollama'

        >>> # Watsonx configuration
        >>> config = LLMConfig(
        ...     provider="watsonx",
        ...     config=WatsonxConfig(
        ...         url="https://us-south.ml.cloud.ibm.com",
        ...         model_id="ibm/granite-13b-instruct-v2",
        ...         project_id="YOUR_PROJECT_ID",
        ...         api_key="YOUR_API")
        ... )
        >>> config.provider
        'watsonx'

        >>> # A config object of the wrong type for the provider is rejected
        >>> try:
        ...     LLMConfig(provider="openai", config=OllamaConfig(model="llama2"))
        ... except ValueError:
        ...     print("rejected")
        rejected
    """

    provider: Literal["azure_openai", "openai", "anthropic", "aws_bedrock", "ollama", "watsonx", "gateway"] = Field(..., description="LLM provider type")
    config: Union[AzureOpenAIConfig, OpenAIConfig, AnthropicConfig, AWSBedrockConfig, OllamaConfig, WatsonxConfig, GatewayConfig] = Field(..., description="Provider-specific configuration")

    @field_validator("config", mode="before")
    @classmethod
    def validate_config_type(cls, v: Any, info) -> Union[AzureOpenAIConfig, OpenAIConfig, AnthropicConfig, AWSBedrockConfig, OllamaConfig, WatsonxConfig, GatewayConfig]:
        """
        Validate and convert config dictionary to appropriate provider type.

        A dict is built into the config model registered for ``provider``. A
        config model instance must be of that type (or a subclass).

        Args:
            v: Configuration value (dict or config object).
            info: Validation context containing provider information.

        Returns:
            Union[AzureOpenAIConfig, OpenAIConfig, AnthropicConfig, AWSBedrockConfig, OllamaConfig, WatsonxConfig, GatewayConfig]:
            Validated configuration object (or ``v`` unchanged when the provider is
            unknown, leaving the union validation to report the problem).

        Raises:
            ValueError: If a config model instance does not match ``provider``
                (pydantic reports it as a ValidationError).

        Examples:
            >>> # Automatically converts dict to appropriate config type
            >>> cfg = LLMConfig(provider="azure_openai", config={
            ...     "api_key": "key",
            ...     "azure_endpoint": "https://example.com/",
            ...     "azure_deployment": "gpt-4"
            ... })
            >>> type(cfg.config).__name__
            'AzureOpenAIConfig'
        """
        config_types = {
            "azure_openai": AzureOpenAIConfig,
            "openai": OpenAIConfig,
            "anthropic": AnthropicConfig,
            "aws_bedrock": AWSBedrockConfig,
            "ollama": OllamaConfig,
            "watsonx": WatsonxConfig,
            "gateway": GatewayConfig,
        }
        provider = info.data.get("provider")
        expected_type = config_types.get(provider)

        # Provider missing/invalid: it already failed validation; leave config as-is.
        if expected_type is None:
            return v

        if isinstance(v, dict):
            return expected_type(**v)

        # Without this check a mismatched instance (e.g. provider="openai" with an
        # OllamaConfig) would satisfy the Union and fail much later at use time.
        if isinstance(v, BaseModel) and not isinstance(v, expected_type):
            raise ValueError(f"config of type {type(v).__name__} does not match provider '{provider}' (expected {expected_type.__name__})")

        return v

674 

675 

class MCPClientConfig(BaseModel):
    """
    Main configuration for MCP client service.

    Aggregates all configuration parameters required for the complete MCP client
    service, including server connection, LLM provider, and operational settings.

    Attributes:
        mcp_server: MCP server connection configuration.
        llm: LLM provider configuration.
        chat_history_max_messages: Maximum messages to retain in chat history;
            defaults to ``settings.llmchat_chat_history_max_messages`` (read
            once, when this module is imported).
        enable_streaming: Whether to enable streaming responses.

    Examples:
        >>> config = MCPClientConfig(
        ...     mcp_server=MCPServerConfig(
        ...         url="https://mcp-server.example.com/mcp",
        ...         transport="streamable_http"
        ...     ),
        ...     llm=LLMConfig(
        ...         provider="ollama",
        ...         config=OllamaConfig(model="llama2")
        ...     ),
        ...     chat_history_max_messages=100,
        ...     enable_streaming=True
        ... )
        >>> config.chat_history_max_messages
        100
        >>> config.enable_streaming
        True
    """

    mcp_server: MCPServerConfig = Field(..., description="MCP server configuration")
    llm: LLMConfig = Field(..., description="LLM provider configuration")
    # Same default as before, now declared with Field so it is documented in the schema like its siblings.
    chat_history_max_messages: int = Field(default=settings.llmchat_chat_history_max_messages, description="Maximum number of messages to retain in chat history")
    enable_streaming: bool = Field(default=True, description="Enable streaming responses")

    model_config = {
        "json_schema_extra": {
            "example": {
                "mcp_server": {"url": "https://mcp-server.example.com/mcp", "transport": "streamable_http", "auth_token": "your-token"},  # nosec B105 - example placeholder
                "llm": {
                    "provider": "azure_openai",
                    "config": {"api_key": "your-key", "azure_endpoint": "https://your-resource.openai.azure.com/", "azure_deployment": "gpt-4", "api_version": "2024-05-01-preview"},
                },
            }
        }
    }

724 

725 

726# ==================== LLM PROVIDER IMPLEMENTATIONS ==================== 

727 

728 

class AzureOpenAIProvider:
    """
    Azure OpenAI provider implementation.

    Manages connection and interaction with Azure OpenAI services.

    Attributes:
        config: Azure OpenAI configuration object.

    Examples:
        >>> config = AzureOpenAIConfig(
        ...     api_key="key",
        ...     azure_endpoint="https://example.openai.azure.com/",
        ...     azure_deployment="gpt-4"
        ... )
        >>> provider = AzureOpenAIProvider(config)
        >>> provider.get_model_name()
        'gpt-4'

    Note:
        The LLM instance is lazily initialized on first access for
        improved startup performance.
    """

    def __init__(self, config: AzureOpenAIConfig):
        """
        Initialize Azure OpenAI provider.

        Args:
            config: Azure OpenAI configuration with API credentials and settings.

        Examples:
            >>> config = AzureOpenAIConfig(
            ...     api_key="key",
            ...     azure_endpoint="https://example.openai.azure.com/",
            ...     azure_deployment="gpt-4"
            ... )
            >>> provider = AzureOpenAIProvider(config)
        """
        self.config = config
        self._llm = None
        # model_type ("chat"/"completion") the cached ``_llm`` was built for.
        self._llm_type: Optional[str] = None
        logger.info(f"Initializing Azure OpenAI provider with deployment: {config.azure_deployment}")

    def get_llm(self, model_type: str = "chat") -> Union[AzureChatOpenAI, AzureOpenAI]:
        """
        Get Azure OpenAI LLM instance with lazy initialization.

        Creates and caches the model instance on first call. Later calls with the
        same ``model_type`` return the cached instance. A call with a different
        ``model_type`` replaces it, so callers always get the kind they asked for.

        Args:
            model_type: LLM inference model type: ``"chat"`` (AzureChatOpenAI) or
                ``"completion"`` (AzureOpenAI text completion).

        Returns:
            Union[AzureChatOpenAI, AzureOpenAI]: Configured Azure OpenAI model.

        Raises:
            ValueError: If ``model_type`` is not ``"chat"`` or ``"completion"``.
            Exception: If LLM initialization fails (e.g., invalid credentials).

        Examples:
            >>> config = AzureOpenAIConfig(
            ...     api_key="key",
            ...     azure_endpoint="https://example.openai.azure.com/",
            ...     azure_deployment="gpt-4"
            ... )
            >>> provider = AzureOpenAIProvider(config)
            >>> # llm = provider.get_llm() # Returns AzureChatOpenAI instance
        """
        if model_type not in ("chat", "completion"):
            # Previously an unknown type silently returned None after logging success.
            raise ValueError(f"Unsupported model_type '{model_type}'; expected 'chat' or 'completion'")

        # Discard a cached model of the other kind. An instance injected
        # externally (no recorded type) is kept as-is.
        cached_type = getattr(self, "_llm_type", None)
        if self._llm is not None and cached_type is not None and cached_type != model_type:
            self._llm = None

        if self._llm is None:
            llm_class = AzureChatOpenAI if model_type == "chat" else AzureOpenAI
            try:
                self._llm = llm_class(
                    api_key=self.config.api_key,
                    azure_endpoint=self.config.azure_endpoint,
                    api_version=self.config.api_version,
                    azure_deployment=self.config.azure_deployment,
                    model=self.config.model,
                    temperature=self.config.temperature,
                    max_tokens=self.config.max_tokens,
                    timeout=self.config.timeout,
                    max_retries=self.config.max_retries,
                )
                self._llm_type = model_type
                logger.info("Azure OpenAI LLM instance created successfully")
            except Exception as e:
                logger.error(f"Failed to create Azure OpenAI LLM: {e}")
                raise

        return self._llm

    def get_model_name(self) -> str:
        """
        Get the Azure OpenAI model name.

        Returns:
            str: The model name configured for this provider.

        Examples:
            >>> config = AzureOpenAIConfig(
            ...     api_key="key",
            ...     azure_endpoint="https://example.openai.azure.com/",
            ...     azure_deployment="gpt-4",
            ...     model="gpt-4"
            ... )
            >>> provider = AzureOpenAIProvider(config)
            >>> provider.get_model_name()
            'gpt-4'
        """
        return self.config.model

849 

850 

class OllamaProvider:
    """
    Ollama provider implementation.

    Manages connection and interaction with Ollama instances for running
    open-source language models locally or remotely.

    Attributes:
        config: Ollama configuration object.

    Examples:
        >>> config = OllamaConfig(
        ...     base_url="http://localhost:11434",
        ...     model="llama2"
        ... )
        >>> provider = OllamaProvider(config)
        >>> provider.get_model_name()
        'llama2'

    Note:
        Requires Ollama to be running and accessible at the configured base_url.
        The LLM instance is created on the first get_llm() call and cached; later
        calls return that cached instance whatever model_type they pass.
    """

    def __init__(self, config: OllamaConfig):
        """
        Initialize Ollama provider.

        Args:
            config: Ollama configuration with server URL and model settings.

        Examples:
            >>> config = OllamaConfig(model="llama2")
            >>> provider = OllamaProvider(config)
        """
        self.config = config
        self._llm = None
        logger.info(f"Initializing Ollama provider with model: {config.model}")

    def get_llm(self, model_type: str = "chat") -> Union[ChatOllama, OllamaLLM]:
        """
        Get Ollama LLM instance with lazy initialization.

        Creates and caches the Ollama model instance on first call.
        Subsequent calls return the cached instance.

        Args:
            model_type: LLM inference model type: 'chat' (ChatOllama) or
                'completion' (OllamaLLM).

        Returns:
            Union[ChatOllama, OllamaLLM]: Configured Ollama chat or completion model.

        Raises:
            ValueError: If model_type is neither 'chat' nor 'completion'.
            Exception: If LLM initialization fails (e.g., Ollama not running).

        Examples:
            >>> config = OllamaConfig(model="llama2")
            >>> provider = OllamaProvider(config)
            >>> # llm = provider.get_llm() # Returns ChatOllama instance
        """
        if self._llm is None:
            llm_classes = {"chat": ChatOllama, "completion": OllamaLLM}
            # Reject unknown types up front so callers never get None back
            # from a silently skipped branch.
            if model_type not in llm_classes:
                raise ValueError(f"Unsupported model_type '{model_type}' for Ollama provider. Expected one of: {', '.join(llm_classes)}")
            try:
                # num_ctx is only forwarded when explicitly configured.
                model_kwargs = {}
                if self.config.num_ctx is not None:
                    model_kwargs["num_ctx"] = self.config.num_ctx

                self._llm = llm_classes[model_type](
                    base_url=self.config.base_url,
                    model=self.config.model,
                    temperature=self.config.temperature,
                    timeout=self.config.timeout,
                    **model_kwargs,
                )
                logger.info("Ollama LLM instance created successfully")
            except Exception as e:
                logger.error(f"Failed to create Ollama LLM: {e}")
                raise

        return self._llm

    def get_model_name(self) -> str:
        """Get the model name.

        Returns:
            str: The model name
        """
        return self.config.model

935 

936 

class OpenAIProvider:
    """
    OpenAI provider implementation (non-Azure).

    Manages connection and interaction with OpenAI API or OpenAI-compatible endpoints.

    Attributes:
        config: OpenAI configuration object.

    Examples:
        >>> config = OpenAIConfig(
        ...     api_key="sk-...",
        ...     model="gpt-4"
        ... )
        >>> provider = OpenAIProvider(config)
        >>> provider.get_model_name()
        'gpt-4'

    Note:
        The LLM instance is lazily initialized on first access for
        improved startup performance, then cached; later get_llm() calls
        return the cached instance whatever model_type they pass.
    """

    def __init__(self, config: OpenAIConfig):
        """
        Initialize OpenAI provider.

        Args:
            config: OpenAI configuration with API key and settings.

        Examples:
            >>> config = OpenAIConfig(
            ...     api_key="sk-...",
            ...     model="gpt-4"
            ... )
            >>> provider = OpenAIProvider(config)
        """
        self.config = config
        self._llm = None
        logger.info(f"Initializing OpenAI provider with model: {config.model}")

    def get_llm(self, model_type: str = "chat") -> Union[ChatOpenAI, OpenAI]:
        """
        Get OpenAI LLM instance with lazy initialization.

        Creates and caches the OpenAI model instance on first call.
        Subsequent calls return the cached instance.

        Args:
            model_type: LLM inference model type: 'chat' (ChatOpenAI) or
                'completion' (OpenAI).

        Returns:
            Union[ChatOpenAI, OpenAI]: Configured OpenAI chat or completion model.

        Raises:
            ValueError: If model_type is neither 'chat' nor 'completion'.
            Exception: If LLM initialization fails (e.g., invalid credentials).

        Examples:
            >>> config = OpenAIConfig(
            ...     api_key="sk-...",
            ...     model="gpt-4"
            ... )
            >>> provider = OpenAIProvider(config)
            >>> # llm = provider.get_llm() # Returns ChatOpenAI instance
        """
        if self._llm is None:
            llm_classes = {"chat": ChatOpenAI, "completion": OpenAI}
            # Reject unknown types up front so callers never get None back
            # from a silently skipped branch.
            if model_type not in llm_classes:
                raise ValueError(f"Unsupported model_type '{model_type}' for OpenAI provider. Expected one of: {', '.join(llm_classes)}")
            try:
                kwargs = {
                    "api_key": self.config.api_key,
                    "model": self.config.model,
                    "temperature": self.config.temperature,
                    "max_tokens": self.config.max_tokens,
                    "timeout": self.config.timeout,
                    "max_retries": self.config.max_retries,
                }

                # Optional settings are only passed when configured.
                if self.config.base_url:
                    kwargs["base_url"] = self.config.base_url
                if self.config.default_headers is not None:
                    kwargs["default_headers"] = self.config.default_headers

                self._llm = llm_classes[model_type](**kwargs)

                logger.info("OpenAI LLM instance created successfully")
            except Exception as e:
                logger.error(f"Failed to create OpenAI LLM: {e}")
                raise

        return self._llm

    def get_model_name(self) -> str:
        """
        Get the OpenAI model name.

        Returns:
            str: The model name configured for this provider.

        Examples:
            >>> config = OpenAIConfig(
            ...     api_key="sk-...",
            ...     model="gpt-4"
            ... )
            >>> provider = OpenAIProvider(config)
            >>> provider.get_model_name()
            'gpt-4'
        """
        return self.config.model

1049 

1050 

class AnthropicProvider:
    """
    Anthropic Claude provider implementation.

    Manages connection and interaction with Anthropic's Claude API.

    Attributes:
        config: Anthropic configuration object.

    Examples:
        >>> config = AnthropicConfig(  # doctest: +SKIP
        ...     api_key="sk-ant-...",
        ...     model="claude-3-5-sonnet-20241022"
        ... )
        >>> provider = AnthropicProvider(config)  # doctest: +SKIP
        >>> provider.get_model_name()  # doctest: +SKIP
        'claude-3-5-sonnet-20241022'

    Note:
        Requires langchain-anthropic package to be installed. The LLM instance
        is created on the first get_llm() call and cached; later calls return
        that cached instance whatever model_type they pass.
    """

    def __init__(self, config: AnthropicConfig):
        """
        Initialize Anthropic provider.

        Args:
            config: Anthropic configuration with API key and settings.

        Raises:
            ImportError: If langchain-anthropic is not installed.

        Examples:
            >>> config = AnthropicConfig(  # doctest: +SKIP
            ...     api_key="sk-ant-...",
            ...     model="claude-3-5-sonnet-20241022"
            ... )
            >>> provider = AnthropicProvider(config)  # doctest: +SKIP
        """
        if not _ANTHROPIC_AVAILABLE:
            raise ImportError("Anthropic provider requires langchain-anthropic package. Install it with: pip install langchain-anthropic")

        self.config = config
        self._llm = None
        logger.info(f"Initializing Anthropic provider with model: {config.model}")

    def get_llm(self, model_type: str = "chat") -> Union[ChatAnthropic, AnthropicLLM]:
        """
        Get Anthropic LLM instance with lazy initialization.

        Creates and caches the Anthropic model instance on first call.
        Subsequent calls return the cached instance.

        Args:
            model_type: LLM inference model type: 'chat' (ChatAnthropic) or
                'completion' (AnthropicLLM).

        Returns:
            Union[ChatAnthropic, AnthropicLLM]: Configured Anthropic chat or completion model.

        Raises:
            ValueError: If model_type is neither 'chat' nor 'completion'.
            Exception: If LLM initialization fails (e.g., invalid credentials).

        Examples:
            >>> config = AnthropicConfig(  # doctest: +SKIP
            ...     api_key="sk-ant-...",
            ...     model="claude-3-5-sonnet-20241022"
            ... )
            >>> provider = AnthropicProvider(config)  # doctest: +SKIP
            >>> # llm = provider.get_llm() # Returns ChatAnthropic instance
        """
        if self._llm is None:
            llm_classes = {"chat": ChatAnthropic, "completion": AnthropicLLM}
            # Reject unknown types up front so callers never get None back
            # from a silently skipped branch.
            if model_type not in llm_classes:
                raise ValueError(f"Unsupported model_type '{model_type}' for Anthropic provider. Expected one of: {', '.join(llm_classes)}")
            try:
                self._llm = llm_classes[model_type](
                    api_key=self.config.api_key,
                    model=self.config.model,
                    temperature=self.config.temperature,
                    max_tokens=self.config.max_tokens,
                    timeout=self.config.timeout,
                    max_retries=self.config.max_retries,
                )
                logger.info("Anthropic LLM instance created successfully")
            except Exception as e:
                logger.error(f"Failed to create Anthropic LLM: {e}")
                raise

        return self._llm

    def get_model_name(self) -> str:
        """
        Get the Anthropic model name.

        Returns:
            str: The model name configured for this provider.

        Examples:
            >>> config = AnthropicConfig(  # doctest: +SKIP
            ...     api_key="sk-ant-...",
            ...     model="claude-3-5-sonnet-20241022"
            ... )
            >>> provider = AnthropicProvider(config)  # doctest: +SKIP
            >>> provider.get_model_name()  # doctest: +SKIP
            'claude-3-5-sonnet-20241022'
        """
        return self.config.model

1165 

1166 

class AWSBedrockProvider:
    """
    AWS Bedrock provider implementation.

    Manages connection and interaction with AWS Bedrock LLM services.

    Attributes:
        config: AWS Bedrock configuration object.

    Examples:
        >>> config = AWSBedrockConfig(  # doctest: +SKIP
        ...     model_id="anthropic.claude-v2",
        ...     region_name="us-east-1"
        ... )
        >>> provider = AWSBedrockProvider(config)  # doctest: +SKIP
        >>> provider.get_model_name()  # doctest: +SKIP
        'anthropic.claude-v2'

    Note:
        Requires langchain-aws package and boto3 to be installed.
        Explicit credentials are only passed when configured; otherwise the
        client library resolves credentials itself. The LLM instance is created
        on the first get_llm() call and cached regardless of model_type.
    """

    def __init__(self, config: AWSBedrockConfig):
        """
        Initialize AWS Bedrock provider.

        Args:
            config: AWS Bedrock configuration with model ID and settings.

        Raises:
            ImportError: If langchain-aws is not installed.

        Examples:
            >>> config = AWSBedrockConfig(  # doctest: +SKIP
            ...     model_id="anthropic.claude-v2",
            ...     region_name="us-east-1"
            ... )
            >>> provider = AWSBedrockProvider(config)  # doctest: +SKIP
        """
        if not _BEDROCK_AVAILABLE:
            raise ImportError("AWS Bedrock provider requires langchain-aws package. Install it with: pip install langchain-aws boto3")

        self.config = config
        self._llm = None
        logger.info(f"Initializing AWS Bedrock provider with model: {config.model_id}")

    def get_llm(self, model_type: str = "chat") -> Union[ChatBedrock, BedrockLLM]:
        """
        Get AWS Bedrock LLM instance with lazy initialization.

        Creates and caches the Bedrock model instance on first call.
        Subsequent calls return the cached instance.

        Args:
            model_type: LLM inference model type: 'chat' (ChatBedrock) or
                'completion' (BedrockLLM).

        Returns:
            Union[ChatBedrock, BedrockLLM]: Configured AWS Bedrock chat or completion model.

        Raises:
            ValueError: If model_type is neither 'chat' nor 'completion'.
            Exception: If LLM initialization fails (e.g., invalid credentials, permissions).

        Examples:
            >>> config = AWSBedrockConfig(  # doctest: +SKIP
            ...     model_id="anthropic.claude-v2",
            ...     region_name="us-east-1"
            ... )
            >>> provider = AWSBedrockProvider(config)  # doctest: +SKIP
            >>> # llm = provider.get_llm() # Returns ChatBedrock instance
        """
        if self._llm is None:
            llm_classes = {"chat": ChatBedrock, "completion": BedrockLLM}
            # Reject unknown types up front so callers never get None back
            # from a silently skipped branch.
            if model_type not in llm_classes:
                raise ValueError(f"Unsupported model_type '{model_type}' for AWS Bedrock provider. Expected one of: {', '.join(llm_classes)}")
            try:
                # Only forward credentials that are actually set.
                credentials_kwargs = {}
                if self.config.aws_access_key_id:
                    credentials_kwargs["aws_access_key_id"] = self.config.aws_access_key_id
                if self.config.aws_secret_access_key:
                    credentials_kwargs["aws_secret_access_key"] = self.config.aws_secret_access_key
                if self.config.aws_session_token:
                    credentials_kwargs["aws_session_token"] = self.config.aws_session_token

                self._llm = llm_classes[model_type](
                    model_id=self.config.model_id,
                    region_name=self.config.region_name,
                    model_kwargs={
                        "temperature": self.config.temperature,
                        "max_tokens": self.config.max_tokens,
                    },
                    **credentials_kwargs,
                )
                logger.info("AWS Bedrock LLM instance created successfully")
            except Exception as e:
                logger.error(f"Failed to create AWS Bedrock LLM: {e}")
                raise

        return self._llm

    def get_model_name(self) -> str:
        """
        Get the AWS Bedrock model ID.

        Returns:
            str: The model ID configured for this provider.

        Examples:
            >>> config = AWSBedrockConfig(  # doctest: +SKIP
            ...     model_id="anthropic.claude-v2",
            ...     region_name="us-east-1"
            ... )
            >>> provider = AWSBedrockProvider(config)  # doctest: +SKIP
            >>> provider.get_model_name()  # doctest: +SKIP
            'anthropic.claude-v2'
        """
        return self.config.model_id

1293 

1294 

class WatsonxProvider:
    """
    IBM watsonx.ai provider implementation.

    Manages connection and interaction with IBM watsonx.ai services.

    Attributes:
        config: IBM watsonx.ai configuration object.
        llm: Lazily created LLM instance (None until get_llm() first succeeds).

    Examples:
        >>> config = WatsonxConfig(  # doctest: +SKIP
        ...     api_key="key",
        ...     url="https://us-south.ml.cloud.ibm.com",
        ...     project_id="project-id",
        ...     model_id="ibm/granite-13b-chat-v2"
        ... )
        >>> provider = WatsonxProvider(config)  # doctest: +SKIP
        >>> provider.get_model_name()  # doctest: +SKIP
        'ibm/granite-13b-chat-v2'

    Note:
        Requires langchain-ibm package to be installed. The LLM instance is
        cached after the first get_llm() call regardless of model_type.
    """

    def __init__(self, config: WatsonxConfig):
        """
        Initialize IBM watsonx.ai provider.

        Args:
            config: IBM watsonx.ai configuration with credentials and settings.

        Raises:
            ImportError: If langchain-ibm is not installed.

        Examples:
            >>> config = WatsonxConfig(  # doctest: +SKIP
            ...     api_key="key",
            ...     url="https://us-south.ml.cloud.ibm.com",
            ...     project_id="project-id",
            ...     model_id="ibm/granite-13b-chat-v2"
            ... )
            >>> provider = WatsonxProvider(config)  # doctest: +SKIP
        """
        if not _WATSONX_AVAILABLE:
            raise ImportError("IBM watsonx.ai provider requires langchain-ibm package. Install it with: pip install langchain-ibm")
        self.config = config
        self.llm = None
        logger.info(f"Initializing IBM watsonx.ai provider with model {config.model_id}")

    def get_llm(self, model_type: str = "chat") -> Union[WatsonxLLM, ChatWatsonx]:
        """
        Get IBM watsonx.ai LLM instance with lazy initialization.

        Creates and caches the watsonx LLM instance on first call.
        Subsequent calls return the cached instance.

        Args:
            model_type: LLM inference model type: 'chat' (ChatWatsonx) or
                'completion' (WatsonxLLM).

        Returns:
            Union[WatsonxLLM, ChatWatsonx]: ChatWatsonx for 'chat', WatsonxLLM for 'completion'.

        Raises:
            ValueError: If model_type is neither 'chat' nor 'completion'.
            Exception: If LLM initialization fails (e.g., invalid credentials).

        Examples:
            >>> config = WatsonxConfig(  # doctest: +SKIP
            ...     api_key="key",
            ...     url="https://us-south.ml.cloud.ibm.com",
            ...     project_id="project-id",
            ...     model_id="ibm/granite-13b-chat-v2"
            ... )
            >>> provider = WatsonxProvider(config)  # doctest: +SKIP
            >>> #llm = provider.get_llm() # Returns ChatWatsonx instance
        """
        if self.llm is None:
            llm_classes = {"completion": WatsonxLLM, "chat": ChatWatsonx}
            # Reject unknown types up front so callers never get None back
            # from a silently skipped branch.
            if model_type not in llm_classes:
                raise ValueError(f"Unsupported model_type '{model_type}' for IBM watsonx.ai provider. Expected one of: {', '.join(llm_classes)}")
            try:
                params = {
                    "decoding_method": self.config.decoding_method,
                    "temperature": self.config.temperature,
                    "max_new_tokens": self.config.max_new_tokens,
                    "min_new_tokens": self.config.min_new_tokens,
                }

                # Sampling knobs are only sent when explicitly configured.
                if self.config.top_k is not None:
                    params["top_k"] = self.config.top_k
                if self.config.top_p is not None:
                    params["top_p"] = self.config.top_p

                self.llm = llm_classes[model_type](
                    apikey=self.config.api_key,
                    url=self.config.url,
                    project_id=self.config.project_id,
                    model_id=self.config.model_id,
                    params=params,
                )
                logger.info("IBM watsonx.ai LLM instance created successfully")
            except Exception as e:
                logger.error(f"Failed to create IBM watsonx.ai LLM: {e}")
                raise
        return self.llm

    def get_model_name(self) -> str:
        """
        Get the IBM watsonx.ai model ID.

        Returns:
            str: The model ID configured for this provider.

        Examples:
            >>> config = WatsonxConfig(  # doctest: +SKIP
            ...     api_key="key",
            ...     url="https://us-south.ml.cloud.ibm.com",
            ...     project_id="project-id",
            ...     model_id="ibm/granite-13b-chat-v2"
            ... )
            >>> provider = WatsonxProvider(config)  # doctest: +SKIP
            >>> provider.get_model_name()  # doctest: +SKIP
            'ibm/granite-13b-chat-v2'
        """
        return self.config.model_id

1427 

1428 

class GatewayProvider:
    """
    Gateway provider implementation for using models configured in LLM Settings.

    Routes LLM requests through the gateway's configured providers, allowing
    users to use models set up via the Admin UI's LLM Settings without needing
    to configure credentials in environment variables or API requests.

    Attributes:
        config: Gateway configuration with model ID.
        llm: Lazily initialized LLM instance.

    Examples:
        >>> config = GatewayConfig(model="gpt-4o")  # doctest: +SKIP
        >>> provider = GatewayProvider(config)  # doctest: +SKIP
        >>> provider.get_model_name()  # doctest: +SKIP
        'gpt-4o'

    Note:
        Requires models to be configured via Admin UI -> Settings -> LLM Settings.
    """

    def __init__(self, config: GatewayConfig):
        """
        Initialize Gateway provider.

        Args:
            config: Gateway configuration with model ID and optional settings.

        Examples:
            >>> config = GatewayConfig(model="gpt-4o")  # doctest: +SKIP
            >>> provider = GatewayProvider(config)  # doctest: +SKIP
        """
        self.config = config
        self.llm = None
        self._model_name: Optional[str] = None
        self._underlying_provider = None
        logger.info(f"Initializing Gateway provider with model: {config.model}")

    def get_llm(self, model_type: str = "chat") -> Union[BaseChatModel, Any]:
        """
        Get LLM instance by looking up model from gateway's LLM Settings.

        Fetches the model configuration from the database, decrypts API keys,
        and creates the appropriate LangChain LLM instance based on provider type.

        Args:
            model_type: Type of model to return ('chat' or 'completion'). Defaults to 'chat'.
                Any value other than 'chat' yields the completion variant.

        Returns:
            Union[BaseChatModel, Any]: Configured LangChain chat or completion model instance.

        Raises:
            ValueError: If model not found or provider not enabled.
            ImportError: If required provider package not installed.

        Examples:
            >>> config = GatewayConfig(model="gpt-4o")  # doctest: +SKIP
            >>> provider = GatewayProvider(config)  # doctest: +SKIP
            >>> llm = provider.get_llm()  # doctest: +SKIP

        Note:
            The LLM instance is created on the first call and cached in ``self.llm``;
            subsequent calls return that cached instance regardless of ``model_type``.
        """
        if self.llm is not None:
            return self.llm

        # Import here to avoid circular imports
        # First-Party
        from mcpgateway.db import LLMModel, LLMProvider, SessionLocal  # pylint: disable=import-outside-toplevel
        from mcpgateway.services.llm_provider_service import decrypt_provider_config_for_runtime  # pylint: disable=import-outside-toplevel
        from mcpgateway.utils.services_auth import decode_auth  # pylint: disable=import-outside-toplevel

        model_id = self.config.model

        with SessionLocal() as db:
            # Try to find model by UUID first, then by model_id
            model = db.query(LLMModel).filter(LLMModel.id == model_id).first()
            if not model:
                model = db.query(LLMModel).filter(LLMModel.model_id == model_id).first()

            if not model:
                raise ValueError(f"Model '{model_id}' not found in LLM Settings. Configure it via Admin UI -> Settings -> LLM Settings.")

            if not model.enabled:
                raise ValueError(f"Model '{model.model_id}' is disabled. Enable it in LLM Settings.")

            # Get the provider
            provider = db.query(LLMProvider).filter(LLMProvider.id == model.provider_id).first()
            if not provider:
                raise ValueError(f"Provider not found for model '{model.model_id}'")

            if not provider.enabled:
                raise ValueError(f"Provider '{provider.name}' is disabled. Enable it in LLM Settings.")

            # Get decrypted API key; decode_auth may return a dict or the raw value.
            api_key = None
            if provider.api_key:
                auth_data = decode_auth(provider.api_key)
                if isinstance(auth_data, dict):
                    api_key = auth_data.get("api_key")
                else:
                    api_key = auth_data

            # Store model name for get_model_name()
            self._model_name = model.model_id

            # Temperature precedence: config override, then provider default, then 0.7.
            # Explicit None checks keep a configured 0.0 instead of letting its
            # falsiness replace it with the fallback.
            if self.config.temperature is not None:
                temperature = self.config.temperature
            elif provider.default_temperature is not None:
                temperature = provider.default_temperature
            else:
                temperature = 0.7
            max_tokens = self.config.max_tokens or model.max_output_tokens

            # Create appropriate LLM based on provider type
            provider_type = provider.provider_type.lower()
            # Guard so the .get() lookups below are safe if no runtime config
            # is returned (assumes a dict otherwise).
            config = decrypt_provider_config_for_runtime(provider.config) or {}

            # Common kwargs
            kwargs: Dict[str, Any] = {
                "temperature": temperature,
                "timeout": self.config.timeout,
            }

            if provider_type == "openai":
                kwargs.update(
                    {
                        "api_key": api_key,
                        "model": model.model_id,
                        "max_tokens": max_tokens,
                    }
                )
                if provider.api_base:
                    kwargs["base_url"] = provider.api_base

                # Provider-stored headers take precedence over the request config's headers.
                if config.get("default_headers"):
                    kwargs["default_headers"] = config["default_headers"]
                elif hasattr(self.config, "default_headers") and self.config.default_headers:  # type: ignore
                    kwargs["default_headers"] = self.config.default_headers

                if model_type == "chat":
                    self.llm = ChatOpenAI(**kwargs)
                else:
                    self.llm = OpenAI(**kwargs)

            elif provider_type == "azure_openai":
                if not provider.api_base:
                    raise ValueError("Azure OpenAI requires base_url (azure_endpoint) to be configured")

                azure_deployment = config.get("azure_deployment", model.model_id)
                api_version = config.get("api_version", "2024-05-01-preview")
                max_retries = config.get("max_retries", 2)

                kwargs.update(
                    {
                        "api_key": api_key,
                        "azure_endpoint": provider.api_base,
                        "azure_deployment": azure_deployment,
                        "api_version": api_version,
                        "model": model.model_id,
                        "max_tokens": int(max_tokens) if max_tokens is not None else None,
                        "max_retries": max_retries,
                    }
                )

                if model_type == "chat":
                    self.llm = AzureChatOpenAI(**kwargs)
                else:
                    self.llm = AzureOpenAI(**kwargs)

            elif provider_type == "anthropic":
                if not _ANTHROPIC_AVAILABLE:
                    raise ImportError("Anthropic provider requires langchain-anthropic. Install with: pip install langchain-anthropic")

                # Anthropic uses 'model_name' instead of 'model'
                anthropic_kwargs = {
                    "api_key": api_key,
                    "model_name": model.model_id,
                    "max_tokens": max_tokens or 4096,
                    "temperature": temperature,
                    "timeout": self.config.timeout,
                    "default_request_timeout": self.config.timeout,
                }

                if model_type == "chat":
                    self.llm = ChatAnthropic(**anthropic_kwargs)
                else:
                    # Generic Anthropic completion model if needed, though mostly chat used now
                    if AnthropicLLM:
                        llm_kwargs = anthropic_kwargs.copy()
                        llm_kwargs["model"] = llm_kwargs.pop("model_name")
                        self.llm = AnthropicLLM(**llm_kwargs)
                    else:
                        raise ImportError("Anthropic completion model (AnthropicLLM) not available")

            elif provider_type == "bedrock":
                if not _BEDROCK_AVAILABLE:
                    raise ImportError("AWS Bedrock provider requires langchain-aws. Install with: pip install langchain-aws boto3")

                # Map DB schema keys to boto3 kwargs.
                # DB stores: region, access_key_id, secret_access_key, session_token, profile_name
                # (see llm_provider_configs.AWSBedrockConfig)
                region_name = config.get("region", "us-east-1")
                credentials_kwargs = {}
                if config.get("access_key_id"):
                    credentials_kwargs["aws_access_key_id"] = config["access_key_id"]
                if config.get("secret_access_key"):
                    credentials_kwargs["aws_secret_access_key"] = config["secret_access_key"]
                if config.get("session_token"):
                    credentials_kwargs["aws_session_token"] = config["session_token"]
                if config.get("profile_name"):
                    credentials_kwargs["credentials_profile_name"] = config["profile_name"]

                model_kwargs = {
                    "temperature": temperature,
                    "max_tokens": max_tokens or 4096,
                }

                if model_type == "chat":
                    self.llm = ChatBedrock(
                        model_id=model.model_id,
                        region_name=region_name,
                        model_kwargs=model_kwargs,
                        **credentials_kwargs,
                    )
                else:
                    self.llm = BedrockLLM(
                        model_id=model.model_id,
                        region_name=region_name,
                        model_kwargs=model_kwargs,
                        **credentials_kwargs,
                    )

            elif provider_type == "ollama":
                base_url = provider.api_base or "http://localhost:11434"
                num_ctx = config.get("num_ctx")

                # Explicitly construct kwargs to avoid generic unpacking issues with Pydantic models
                ollama_kwargs = {
                    "base_url": base_url,
                    "model": model.model_id,
                    "temperature": temperature,
                    "timeout": self.config.timeout,
                }
                if num_ctx:
                    ollama_kwargs["num_ctx"] = num_ctx

                if model_type == "chat":
                    self.llm = ChatOllama(**ollama_kwargs)
                else:
                    self.llm = OllamaLLM(**ollama_kwargs)

            elif provider_type == "watsonx":
                if not _WATSONX_AVAILABLE:
                    raise ImportError("IBM watsonx.ai provider requires langchain-ibm. Install with: pip install langchain-ibm")

                project_id = config.get("project_id")
                if not project_id:
                    raise ValueError("IBM watsonx.ai requires project_id in config")

                url = provider.api_base or "https://us-south.ml.cloud.ibm.com"

                params = {
                    "temperature": temperature,
                    "max_new_tokens": max_tokens or 1024,
                    "min_new_tokens": config.get("min_new_tokens", 1),
                    "decoding_method": config.get("decoding_method", "sample"),
                    "top_k": config.get("top_k", 50),
                    "top_p": config.get("top_p", 1.0),
                }

                if model_type == "chat":
                    self.llm = ChatWatsonx(
                        apikey=api_key,
                        url=url,
                        project_id=project_id,
                        model_id=model.model_id,
                        params=params,
                    )
                else:
                    self.llm = WatsonxLLM(
                        apikey=api_key,
                        url=url,
                        project_id=project_id,
                        model_id=model.model_id,
                        params=params,
                    )

            elif provider_type == "openai_compatible":
                if not provider.api_base:
                    raise ValueError("OpenAI-compatible provider requires base_url to be configured")

                kwargs.update(
                    {
                        # Many self-hosted endpoints ignore the key, but the client requires one.
                        "api_key": api_key or "no-key-required",
                        "model": model.model_id,
                        "base_url": provider.api_base,
                        "max_tokens": max_tokens,
                    }
                )

                if model_type == "chat":
                    self.llm = ChatOpenAI(**kwargs)
                else:
                    self.llm = OpenAI(**kwargs)

            else:
                raise ValueError(f"Unsupported LLM provider: {provider_type}")

            logger.info(f"Gateway provider created LLM instance for model: {model.model_id} via {provider_type}")
            return self.llm

    def get_model_name(self) -> str:
        """
        Get the model name.

        Returns:
            str: The resolved model ID after get_llm() has run, otherwise the configured model.

        Examples:
            >>> config = GatewayConfig(model="gpt-4o")  # doctest: +SKIP
            >>> provider = GatewayProvider(config)  # doctest: +SKIP
            >>> provider.get_model_name()  # doctest: +SKIP
            'gpt-4o'
        """
        return self._model_name or self.config.model

1753 

1754 

class LLMProviderFactory:
    """
    Factory that turns an ``LLMConfig`` into a concrete LLM provider object.

    The provider key in the configuration selects which provider class is
    instantiated; the nested provider-specific configuration is handed to
    that class unchanged. Callers therefore never deal with provider-specific
    construction details.

    Examples:
        >>> config = LLMConfig(
        ...     provider="ollama",
        ...     config=OllamaConfig(model="llama2")
        ... )
        >>> provider = LLMProviderFactory.create(config)
        >>> provider.get_model_name()
        'llama2'

    Note:
        The set of supported provider keys is fixed inside ``create``.
    """

    @staticmethod
    def create(llm_config: LLMConfig) -> Union[AzureOpenAIProvider, OpenAIProvider, AnthropicProvider, AWSBedrockProvider, OllamaProvider, WatsonxProvider, GatewayProvider]:
        """
        Build the provider instance matching ``llm_config.provider``.

        Args:
            llm_config: LLM configuration naming the provider and carrying its settings.

        Returns:
            Union[AzureOpenAIProvider, OpenAIProvider, AnthropicProvider, AWSBedrockProvider, OllamaProvider, WatsonxProvider, GatewayProvider]: The newly constructed provider.

        Raises:
            ValueError: If the provider key is not one of the supported providers.
            ImportError: If the selected provider's optional package is missing.

        Examples:
            >>> # Create Azure OpenAI provider
            >>> config = LLMConfig(
            ...     provider="azure_openai",
            ...     config=AzureOpenAIConfig(
            ...         api_key="key",
            ...         azure_endpoint="https://example.com/",
            ...         azure_deployment="gpt-4"
            ...     )
            ... )
            >>> provider = LLMProviderFactory.create(config)
            >>> isinstance(provider, AzureOpenAIProvider)
            True

            >>> # Create OpenAI provider
            >>> config = LLMConfig(
            ...     provider="openai",
            ...     config=OpenAIConfig(
            ...         api_key="sk-...",
            ...         model="gpt-4"
            ...     )
            ... )
            >>> provider = LLMProviderFactory.create(config)
            >>> isinstance(provider, OpenAIProvider)
            True

            >>> # Create Ollama provider
            >>> config = LLMConfig(
            ...     provider="ollama",
            ...     config=OllamaConfig(model="llama2")
            ... )
            >>> provider = LLMProviderFactory.create(config)
            >>> isinstance(provider, OllamaProvider)
            True
        """
        registry = {
            "azure_openai": AzureOpenAIProvider,
            "openai": OpenAIProvider,
            "anthropic": AnthropicProvider,
            "aws_bedrock": AWSBedrockProvider,
            "ollama": OllamaProvider,
            "watsonx": WatsonxProvider,
            "gateway": GatewayProvider,
        }

        requested = llm_config.provider
        if requested not in registry:
            raise ValueError(f"Unsupported LLM provider: {requested}. Supported providers: {list(registry.keys())}")

        logger.info(f"Creating LLM provider: {requested}")
        chosen_class = registry[requested]
        return chosen_class(llm_config.config)

1843 

1844 

1845# ==================== CHAT HISTORY MANAGER ==================== 

1846 

1847 

class ChatHistoryManager:
    """
    Centralized chat history management with Redis and in-memory fallback.

    Provides one interface for storing and retrieving per-user chat histories.
    When a Redis client is supplied, histories are stored as JSON under
    ``chat_history:<user_id>`` keys with a TTL so they can be shared between
    workers; otherwise they live in a per-instance dictionary.

    This class is the single source of truth for chat history operations used
    by both the router and service layers.

    Attributes:
        redis_client: Optional Redis async client for distributed storage.
        max_messages: Maximum number of messages to retain per user.
        ttl: Time-to-live for Redis entries in seconds.
        _memory_store: In-memory dict fallback when Redis is unavailable.

    Examples:
        >>> manager = ChatHistoryManager(redis_client=None, max_messages=50)
        >>> manager.max_messages
        50
        >>> # asyncio.run(manager.save_history("user123", [{"role": "user", "content": "Hello"}]))
        >>> # asyncio.run(manager.get_history("user123")) -> [{'role': 'user', 'content': 'Hello'}]

    Note:
        In-memory mode is local to one process and therefore only suitable for
        single-worker deployments. With Redis, ``append_message`` is a
        read-modify-write and is not atomic across workers.
    """

    def __init__(self, redis_client: Optional[Any] = None, max_messages: int = 50, ttl: int = 3600):
        """
        Initialize chat history manager.

        Args:
            redis_client: Optional Redis async client. If None, uses in-memory storage.
            max_messages: Maximum messages to retain per user (default: 50). A value
                of 0 or less retains no messages.
            ttl: Time-to-live for Redis entries in seconds (default: 3600).

        Examples:
            >>> manager = ChatHistoryManager(redis_client=None, max_messages=100)
            >>> manager.max_messages
            100
            >>> manager.ttl
            3600
        """
        self.redis_client = redis_client
        self.max_messages = max_messages
        self.ttl = ttl
        self._memory_store: Dict[str, List[Dict[str, str]]] = {}

        if redis_client:
            logger.info("ChatHistoryManager initialized with Redis backend")
        else:
            logger.info("ChatHistoryManager initialized with in-memory backend")

    def _history_key(self, user_id: str) -> str:
        """
        Generate Redis key for user's chat history.

        Args:
            user_id: User identifier.

        Returns:
            str: Redis key string.

        Examples:
            >>> manager = ChatHistoryManager()
            >>> manager._history_key("user123")
            'chat_history:user123'
        """
        return f"chat_history:{user_id}"

    async def get_history(self, user_id: str) -> List[Dict[str, str]]:
        """
        Retrieve chat history for a user.

        Reads from Redis when a client is configured, otherwise from the
        in-memory store. The in-memory path returns a copy of the stored list,
        so mutating the result (as ``append_message`` does) never alters the
        stored history behind the manager's back.

        Args:
            user_id: User identifier.

        Returns:
            List[Dict[str, str]]: Message dictionaries with 'role' and 'content'
            keys, oldest first. Empty if no history exists, if the stored Redis
            value cannot be decoded or is not a JSON list, or if the Redis call fails.

        Examples:
            >>> manager = ChatHistoryManager()
            >>> # asyncio.run(manager.get_history("user123")) -> []
        """
        if not self.redis_client:
            # Copy so callers cannot mutate the stored list in place.
            return list(self._memory_store.get(user_id, []))

        try:
            data = await self.redis_client.get(self._history_key(user_id))
            if not data:
                return []
            history = orjson.loads(data)
        except orjson.JSONDecodeError:
            logger.warning(f"Failed to decode chat history for user {SecurityValidator.sanitize_log_message(user_id)}")
            return []
        except Exception as e:
            logger.error(f"Error retrieving chat history from Redis for user {SecurityValidator.sanitize_log_message(user_id)}: {e}")
            return []

        # Valid JSON of the wrong shape (e.g. an object) would break callers that append to the result.
        if not isinstance(history, list):
            logger.warning(f"Ignoring malformed chat history for user {SecurityValidator.sanitize_log_message(user_id)}: expected a JSON list")
            return []
        return history

    async def save_history(self, user_id: str, history: List[Dict[str, str]]) -> None:
        """
        Save chat history for a user.

        Stores history in Redis (with TTL) if available, otherwise in memory.
        History is trimmed to ``max_messages`` before saving. In memory mode a
        private copy is stored, so later changes to ``history`` by the caller do
        not leak into the store.

        Args:
            user_id: User identifier.
            history: List of message dictionaries to save.

        Examples:
            >>> manager = ChatHistoryManager(max_messages=50)
            >>> messages = [{"role": "user", "content": "Hello"}]
            >>> # asyncio.run(manager.save_history("user123", messages))

        Note:
            Redis errors are logged and swallowed; the history is then not saved.
        """
        trimmed = self._trim_messages(history)

        if self.redis_client:
            try:
                await self.redis_client.set(self._history_key(user_id), orjson.dumps(trimmed), ex=self.ttl)
            except Exception as e:
                logger.error(f"Error saving chat history to Redis for user {SecurityValidator.sanitize_log_message(user_id)}: {e}")
        else:
            # _trim_messages may hand back the caller's own list; store a copy.
            self._memory_store[user_id] = list(trimmed)

    async def append_message(self, user_id: str, role: str, content: str) -> None:
        """
        Append a single message to user's chat history.

        Fetches the current history, appends the message, and saves it back
        (which trims it to ``max_messages``).

        Args:
            user_id: User identifier.
            role: Message role ('user' or 'assistant').
            content: Message content text.

        Examples:
            >>> manager = ChatHistoryManager()
            >>> # asyncio.run(manager.append_message("user123", "user", "Hello!"))

        Note:
            This is a read-modify-write operation, which is not atomic when
            several workers share the same Redis history.
        """
        history = await self.get_history(user_id)
        history.append({"role": role, "content": content})
        await self.save_history(user_id, history)

    async def clear_history(self, user_id: str) -> None:
        """
        Clear all chat history for a user.

        Deletes history from Redis or the memory store. Redis errors are logged
        and swallowed.

        Args:
            user_id: User identifier.

        Examples:
            >>> manager = ChatHistoryManager()
            >>> # asyncio.run(manager.clear_history("user123"))

        Note:
            This operation cannot be undone.
        """
        if self.redis_client:
            try:
                await self.redis_client.delete(self._history_key(user_id))
            except Exception as e:
                logger.error(f"Error clearing chat history from Redis for user {SecurityValidator.sanitize_log_message(user_id)}: {e}")
        else:
            self._memory_store.pop(user_id, None)

    def _trim_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Trim message list to max_messages limit.

        Keeps the most recent messages up to ``max_messages``. A limit of 0 or
        less keeps nothing: a bare ``messages[-0:]`` slice would return the whole
        list, so that case is handled explicitly.

        Args:
            messages: List of message dictionaries.

        Returns:
            List[Dict[str, str]]: Trimmed message list (the input list itself
            when no trimming is needed).

        Examples:
            >>> manager = ChatHistoryManager(max_messages=2)
            >>> messages = [
            ...     {"role": "user", "content": "1"},
            ...     {"role": "assistant", "content": "2"},
            ...     {"role": "user", "content": "3"}
            ... ]
            >>> trimmed = manager._trim_messages(messages)
            >>> len(trimmed)
            2
            >>> trimmed[0]["content"]
            '2'
            >>> ChatHistoryManager(max_messages=0)._trim_messages(messages)
            []
        """
        if self.max_messages <= 0:
            return []
        if len(messages) > self.max_messages:
            return messages[-self.max_messages :]
        return messages

    async def get_langchain_messages(self, user_id: str) -> "List[BaseMessage]":
        """
        Get chat history as LangChain message objects.

        Converts stored history dictionaries to LangChain HumanMessage and
        AIMessage objects for use with LangChain agents. Entries whose role is
        neither 'user' nor 'assistant' are skipped.

        Args:
            user_id: User identifier.

        Returns:
            List[BaseMessage]: List of LangChain message objects.

        Examples:
            >>> manager = ChatHistoryManager()
            >>> # asyncio.run(manager.get_langchain_messages("user123")) -> []

        Note:
            Returns empty list if LangChain is not available or history is empty.
        """
        if not _LLMCHAT_AVAILABLE:
            return []

        history = await self.get_history(user_id)
        lc_messages = []

        for msg in history:
            role = msg.get("role")
            content = msg.get("content", "")

            if role == "user":
                lc_messages.append(HumanMessage(content=content))
            elif role == "assistant":
                lc_messages.append(AIMessage(content=content))

        return lc_messages

2109 

2110 

2111# ==================== MCP CLIENT ==================== 

2112 

2113 

class MCPClient:
    """
    Manages a single MCP server connection and its tool list.

    Wraps ``MultiServerMCPClient`` for one server registered under the name
    ``"default"``, builds its settings from ``MCPServerConfig``, tracks a
    connected flag, and caches the tools returned by the server. The visible
    code handles the ``streamable_http`` and ``sse`` transports (URL-based) and
    ``stdio`` (command-based).

    Attributes:
        config: MCP server configuration.

    Examples:
        >>> config = MCPServerConfig(
        ...     url="https://mcp-server.example.com/mcp",
        ...     transport="streamable_http"
        ... )
        >>> client = MCPClient(config)
        >>> client.is_connected
        False
        >>> # asyncio.run(client.connect())
        >>> # tools = asyncio.run(client.get_tools())

    Note:
        All methods except the ``is_connected`` property are coroutines.
    """

    def __init__(self, config: MCPServerConfig):
        """
        Initialize MCP client.

        Args:
            config: MCP server configuration with connection parameters.

        Examples:
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> client.config.transport
            'streamable_http'
        """
        self.config = config
        self._client: Optional[MultiServerMCPClient] = None
        # None means "not loaded yet"; an empty list is a valid cached result.
        self._tools: Optional[List[BaseTool]] = None
        self._connected = False
        logger.info(f"MCP client initialized with transport: {config.transport}")

    async def connect(self) -> None:
        """
        Connect to the MCP server.

        Builds the per-server settings for the configured transport and creates
        the underlying ``MultiServerMCPClient``. The visible code only constructs
        the client object; network failures may therefore only surface later,
        e.g. in ``get_tools`` (NOTE(review): confirm against langchain-mcp-adapters).
        Calling this again while connected logs a warning and does nothing.

        Raises:
            ConnectionError: If the client cannot be created, including when the
                optional ``llmchat`` dependencies are not installed.

        Examples:
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> # asyncio.run(client.connect())
            >>> # client.is_connected -> True
        """
        if self._connected:
            logger.warning("MCP client already connected")
            return

        try:
            logger.info(f"Connecting to MCP server via {self.config.transport}...")

            # Build server configuration for MultiServerMCPClient
            server_config: Dict[str, Any] = {"transport": self.config.transport}

            if self.config.transport in ("streamable_http", "sse"):
                server_config["url"] = self.config.url
                if self.config.headers:
                    server_config["headers"] = self.config.headers
            elif self.config.transport == "stdio":
                server_config["command"] = self.config.command
                if self.config.args:
                    server_config["args"] = self.config.args

            if not MultiServerMCPClient:
                # The optional import failed at module load. Fail with an explicit
                # message instead of "'NoneType' object is not callable"; the handler
                # below still surfaces it to callers as ConnectionError.
                raise ImportError("Some dependencies are missing. Install those with: pip install '.[llmchat]'")

            # Create MultiServerMCPClient with single server
            self._client = MultiServerMCPClient({"default": server_config})
            self._connected = True
            logger.info("Successfully connected to MCP server")

        except Exception as e:
            logger.error(f"Failed to connect to MCP server: {e}")
            self._connected = False
            raise ConnectionError(f"Failed to connect to MCP server: {e}") from e

    async def disconnect(self) -> None:
        """
        Disconnect from the MCP server.

        Drops the client reference, clears the connected flag and the tool
        cache. Calling it while not connected logs a warning and does nothing.

        Raises:
            Exception: If cleanup operations fail (logged, then re-raised).

        Examples:
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> # asyncio.run(client.connect())
            >>> # asyncio.run(client.disconnect())
            >>> # client.is_connected -> False
        """
        if not self._connected:
            logger.warning("MCP client not connected")
            return

        try:
            # MultiServerMCPClient manages connections internally; dropping the
            # reference is the only cleanup performed here.
            self._client = None
            self._connected = False
            self._tools = None
            logger.info("Disconnected from MCP server")

        except Exception as e:
            logger.error(f"Error during disconnect: {e}")
            raise

    async def get_tools(self, force_reload: bool = False) -> List[BaseTool]:
        """
        Get tools from the MCP server.

        The first successful load is cached, including an empty tool list, and
        returned on later calls unless ``force_reload`` is True.

        Args:
            force_reload: Force reload tools even if cached (default: False).

        Returns:
            List[BaseTool]: List of available tools from the server.

        Raises:
            ConnectionError: If not connected to MCP server.
            Exception: If tool loading fails (logged, then re-raised; the cache
                is left unchanged).

        Examples:
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> # asyncio.run(client.connect())
            >>> # tools = asyncio.run(client.get_tools())
            >>> # len(tools) >= 0 -> True
        """
        if not self._connected or not self._client:
            raise ConnectionError("Not connected to MCP server. Call connect() first.")

        # `is not None` so a server exposing zero tools is not re-queried on every call.
        if self._tools is not None and not force_reload:
            logger.debug(f"Returning {len(self._tools)} cached tools")
            return self._tools

        try:
            logger.info("Loading tools from MCP server...")
            self._tools = await self._client.get_tools()
            logger.info(f"Successfully loaded {len(self._tools)} tools")
            return self._tools

        except Exception as e:
            logger.error(f"Failed to load tools: {e}")
            raise

    @property
    def is_connected(self) -> bool:
        """
        Check if client is connected.

        Returns:
            bool: True if connected to MCP server, False otherwise.

        Examples:
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> client.is_connected
            False
        """
        return self._connected

2328 

2329 

2330# ==================== MCP CHAT SERVICE ==================== 

2331 

2332 

2333class MCPChatService: 

2334 """ 

2335 Main chat service for MCP client backend. 

2336 Orchestrates chat sessions with LLM and MCP server integration. 

2337 

2338 Provides a high-level interface for managing conversational AI sessions 

2339 that combine LLM capabilities with MCP server tools. Handles conversation 

2340 history management, tool execution, and streaming responses. 

2341 

2342 This service integrates: 

2343 - LLM providers (Azure OpenAI, OpenAI, Anthropic, AWS Bedrock, Ollama) 

2344 - MCP server tools 

2345 - Centralized chat history management (Redis or in-memory) 

2346 - Streaming and non-streaming response modes 

2347 

2348 Attributes: 

2349 config: Complete MCP client configuration. 

2350 user_id: Optional user identifier for history management. 

2351 

2352 Examples: 

2353 >>> import asyncio 

2354 >>> config = MCPClientConfig( 

2355 ... mcp_server=MCPServerConfig( 

2356 ... url="https://example.com/mcp", 

2357 ... transport="streamable_http" 

2358 ... ), 

2359 ... llm=LLMConfig( 

2360 ... provider="ollama", 

2361 ... config=OllamaConfig(model="llama2") 

2362 ... ) 

2363 ... ) 

2364 >>> service = MCPChatService(config) 

2365 >>> service.is_initialized 

2366 False 

2367 >>> # asyncio.run(service.initialize()) 

2368 

2369 Note: 

2370 Must call initialize() before using chat methods. 

2371 """ 

2372 

def __init__(self, config: MCPClientConfig, user_id: Optional[str] = None, redis_client: Optional[Any] = None):
    """
    Build the chat service and its collaborators without performing any I/O.

    Creates the MCP client wrapper, resolves the LLM provider through
    ``LLMProviderFactory``, and sets up a ``ChatHistoryManager``. The manager's
    message cap comes from ``config.chat_history_max_messages`` and its TTL from
    ``settings.llmchat_chat_history_ttl``. The agent itself is only created by
    ``initialize()``.

    Args:
        config: Complete MCP client configuration.
        user_id: Optional user identifier keying the chat history; when falsy,
            the chat methods neither load nor save history.
        redis_client: Optional Redis client handed to the history manager for
            distributed storage.

    Examples:
        >>> config = MCPClientConfig(
        ...     mcp_server=MCPServerConfig(
        ...         url="https://example.com/mcp",
        ...         transport="streamable_http"
        ...     ),
        ...     llm=LLMConfig(
        ...         provider="ollama",
        ...         config=OllamaConfig(model="llama2")
        ...     )
        ... )
        >>> service = MCPChatService(config, user_id="user123")
        >>> service.user_id
        'user123'
    """
    self.config = config
    self.user_id = user_id

    # Collaborators derived from the configuration.
    self.mcp_client = MCPClient(config.mcp_server)
    self.llm_provider = LLMProviderFactory.create(config.llm)
    self.history_manager = ChatHistoryManager(
        redis_client=redis_client,
        max_messages=config.chat_history_max_messages,
        ttl=settings.llmchat_chat_history_ttl,
    )

    # Runtime state populated by initialize().
    self._agent = None
    self._initialized = False
    self._tools: List[BaseTool] = []

    display_user = user_id or "anonymous"
    logger.info(f"MCPChatService initialized for user: {display_user}")

2410 

async def initialize(self) -> None:
    """
    Prepare the service for chatting.

    Connects the MCP client, loads its tools, obtains the LLM from the
    provider, and builds a ReAct agent over those tools. A second call after a
    successful initialization only logs a warning.

    Raises:
        ImportError: If the optional LLM chat dependencies are not installed.
        ConnectionError: If connecting to the MCP server fails.
        Exception: Any other failure during setup (logged, then re-raised with
            the service left uninitialized).

    Examples:
        >>> config = MCPClientConfig(
        ...     mcp_server=MCPServerConfig(
        ...         url="https://example.com/mcp",
        ...         transport="streamable_http"
        ...     ),
        ...     llm=LLMConfig(
        ...         provider="ollama",
        ...         config=OllamaConfig(model="llama2")
        ...     )
        ... )
        >>> service = MCPChatService(config)
        >>> # asyncio.run(service.initialize())
        >>> # service.is_initialized -> True
    """
    if self._initialized:
        logger.warning("Chat service already initialized")
        return

    if not _LLMCHAT_AVAILABLE:
        raise ImportError("LLM chat dependencies are missing. Install them with: pip install '.[llmchat]'")

    try:
        logger.info("Initializing chat service...")

        await self.mcp_client.connect()
        loaded_tools = await self.mcp_client.get_tools()
        self._tools = loaded_tools

        # The agent is built from the provider's LLM and the MCP tools.
        self._agent = create_react_agent(self.llm_provider.get_llm(), self._tools)

        self._initialized = True
        logger.info(f"Chat service initialized successfully with {len(self._tools)} tools")
    except Exception as exc:
        logger.error(f"Failed to initialize chat service: {exc}")
        self._initialized = False
        raise

2469 

async def chat(self, message: str) -> str:
    """
    Answer a message in one shot, using the agent and its MCP tools.

    The user's stored history (if ``user_id`` is set) plus the new message are
    sent to the agent; the content of the last message it returns is the reply.
    The call is wrapped in an ``llm.chat`` tracing span. On success, the user
    message and the reply are appended to the history.

    Args:
        message: User's message text.

    Returns:
        str: Complete AI response text.

    Raises:
        RuntimeError: If service not initialized.
        ValueError: If message is empty or whitespace only.
        Exception: Any agent or history failure (logged, then re-raised).

    Examples:
        >>> # Assuming service is initialized
        >>> # response = asyncio.run(service.chat("Hello!"))
        >>> # isinstance(response, str) -> True
    """
    if not (self._initialized and self._agent):
        raise RuntimeError("Chat service not initialized. Call initialize() first.")
    if not (message and message.strip()):
        raise ValueError("Message cannot be empty")

    span_attributes = {
        "langfuse.observation.type": "generation",
        "gen_ai.system": _llm_system_name(self),
        "gen_ai.request.model": self.llm_provider.get_model_name(),
    }
    if is_input_capture_enabled("llm.chat"):
        span_attributes["langfuse.observation.input"] = serialize_trace_payload({"message": message})

    with create_span("llm.chat", span_attributes) as span:
        try:
            logger.debug("Processing chat message...")

            prior = await self.history_manager.get_langchain_messages(self.user_id) if self.user_id else []
            conversation = [*prior, HumanMessage(content=message)]

            result = await self._agent.ainvoke({"messages": conversation})
            last_message = result["messages"][-1]
            if hasattr(last_message, "content"):
                response_text = last_message.content
            else:
                response_text = str(last_message)

            if span:
                _set_usage_attributes(span, last_message)
                if is_output_capture_enabled("llm.chat"):
                    set_span_attribute(span, "langfuse.observation.output", serialize_trace_payload({"response": response_text}))

            if self.user_id:
                for role, text in (("user", message), ("assistant", response_text)):
                    await self.history_manager.append_message(self.user_id, role, text)

            logger.debug("Chat message processed successfully")
            return response_text

        except Exception as e:
            logger.error(f"Error processing chat message: {e}")
            raise

2539 

async def chat_with_metadata(self, message: str) -> Dict[str, Any]:
    """
    Run a chat turn and summarize it as a single dictionary.

    Consumes every event from ``chat_events``. Token contents are concatenated
    into the reply text, tool start/end/error events are kept in arrival order,
    and the tool summary and timing are taken from the ``final`` event (the
    last one seen, if several arrive).

    Args:
        message: User's message text.

    Returns:
        Dict[str, Any]: Dictionary containing:
            - text (str): Concatenated token content
            - tool_used (bool): From the final event; False if none arrived
            - tools (List[str]): From the final event; [] if none arrived
            - tool_invocations (List[dict]): Raw tool_start/tool_end/tool_error events
            - elapsed_ms (int | None): From the final event; None if none arrived

    Raises:
        RuntimeError: If service not initialized (raised by ``chat_events``).
        ValueError: If message is empty (raised by ``chat_events``).

    Examples:
        >>> # Assuming service is initialized
        >>> # result = asyncio.run(service.chat_with_metadata("What's 2+2?"))
        >>> # 'text' in result and 'elapsed_ms' in result -> True
    """
    collected_text = ""
    invocations: list[dict[str, Any]] = []
    final_event: dict[str, Any] = {}
    tool_event_types = ("tool_start", "tool_end", "tool_error")

    async for event in self.chat_events(message):
        event_type = event.get("type")
        if event_type == "token":
            collected_text += event.get("content", "")
        elif event_type in tool_event_types:
            invocations.append(event)
        elif event_type == "final":
            final_event = event

    summary = {
        "text": collected_text,
        "tool_used": final_event.get("tool_used", False),
        "tools": final_event.get("tools", []),
        "tool_invocations": invocations,
        "elapsed_ms": final_event.get("elapsed_ms"),
    }
    return summary

2592 

async def chat_stream(self, message: str) -> AsyncGenerator[str, None]:
    """
    Send a message and stream the response.

    Yields the LLM's text chunks as they arrive from the agent's
    ``astream_events`` (``on_chat_model_stream`` events only). If
    ``config.enable_streaming`` is False, the complete ``chat()`` response is
    yielded as a single chunk instead. History (when ``user_id`` is set) is
    saved only if a non-empty response was streamed.

    Args:
        message: User's message text.

    Yields:
        str: Chunks of AI response text.

    Raises:
        RuntimeError: If service not initialized.
        ValueError: If message is empty or whitespace only.
        Exception: If streaming fails (logged, then re-raised).

    Examples:
        >>> async def stream_example():
        ...     # Assuming service is initialized
        ...     chunks = []
        ...     async for chunk in service.chat_stream("Hello"):
        ...         chunks.append(chunk)
        ...     return ''.join(chunks)
        >>> # full_response = asyncio.run(stream_example())

    Note:
        Like every async generator, nothing (including validation) runs until
        the caller starts iterating.
    """
    if not self._initialized or not self._agent:
        raise RuntimeError("Chat service not initialized. Call initialize() first.")

    # Reject empty input here, as chat() and chat_events() do, instead of
    # sending a blank HumanMessage to the agent on the streaming path.
    if not message or not message.strip():
        raise ValueError("Message cannot be empty")

    if not self.config.enable_streaming:
        # Fall back to non-streaming: emit the whole response as one chunk.
        response = await self.chat(message)
        yield response
        return

    try:
        logger.debug("Processing streaming chat message...")

        # Get conversation history and add the new user message
        lc_messages = await self.history_manager.get_langchain_messages(self.user_id) if self.user_id else []
        lc_messages.append(HumanMessage(content=message))

        # Stream agent response, forwarding only LLM token chunks
        full_response = ""
        async for event in self._agent.astream_events({"messages": lc_messages}, version="v2"):
            if event.get("event") != "on_chat_model_stream":
                continue
            chunk = event.get("data", {}).get("chunk")
            if chunk and hasattr(chunk, "content"):
                content = chunk.content
                if content:
                    full_response += content
                    yield content

        # Save history only when something was actually produced
        if self.user_id and full_response:
            await self.history_manager.append_message(self.user_id, "user", message)
            await self.history_manager.append_message(self.user_id, "assistant", full_response)

        logger.debug("Streaming chat message processed successfully")

    except Exception as e:
        logger.error(f"Error processing streaming chat message: {e}")
        raise

2666 

2667 async def chat_events(self, message: str) -> AsyncGenerator[Dict[str, Any], None]: 

2668 """ 

2669 Stream structured events during chat processing. 

2670 

2671 Provides granular visibility into the chat processing pipeline by yielding 

2672 structured events for tokens, tool invocations, errors, and final results. 

2673 

2674 Args: 

2675 message: User's message text. 

2676 

2677 Yields: 

2678 dict: Event dictionaries with type-specific fields: 

2679 - token: {"type": "token", "content": str} 

2680 - tool_start: {"type": "tool_start", "id": str, "name": str, 

2681 "input": Any, "start": str} 

2682 - tool_end: {"type": "tool_end", "id": str, "name": str, 

2683 "output": Any, "end": str} 

2684 - tool_error: {"type": "tool_error", "id": str, "error": str, 

2685 "time": str} 

2686 - final: {"type": "final", "content": str, "tool_used": bool, 

2687 "tools": List[str], "elapsed_ms": int} 

2688 

2689 Raises: 

2690 RuntimeError: If service not initialized. 

2691 ValueError: If message is empty or whitespace only. 

2692 ConnectionError: If the underlying MCP connection is lost. 

2693 TimeoutError: If the LLM request times out. 

2694 ChatProcessingError: If a tool, parsing, or model error occurs during streaming. 

2695 

2696 Examples: 

2697 >>> import asyncio 

2698 >>> async def event_example(): 

2699 ... # Assuming service is initialized 

2700 ... events = [] 

2701 ... async for event in service.chat_events("Hello"): 

2702 ... events.append(event['type']) 

2703 ... return events 

2704 >>> # event_types = asyncio.run(event_example()) 

2705 >>> # 'final' in event_types -> True 

2706 

2707 Note: 

2708 This is the most detailed chat method, suitable for building 

2709 interactive UIs or detailed logging systems. 

2710 """ 

2711 if not self._initialized or not self._agent: 

2712 raise RuntimeError("Chat service not initialized. Call initialize() first.") 

2713 

2714 # Validate message 

2715 if not message or not message.strip(): 

2716 raise ValueError("Message cannot be empty") 

2717 

2718 # Get conversation history 

2719 lc_messages = await self.history_manager.get_langchain_messages(self.user_id) if self.user_id else [] 

2720 

2721 # Append user message 

2722 user_message = HumanMessage(content=message) 

2723 lc_messages.append(user_message) 

2724 

2725 full_response = "" 

2726 start_ts = time.time() 

2727 tool_runs: dict[str, dict[str, Any]] = {} 

2728 # Buffer for out-of-order on_tool_end events (end arrives before start) 

2729 pending_tool_ends: dict[str, dict[str, Any]] = {} 

2730 pending_ttl_seconds = 30.0 # Max time to hold pending end events 

2731 pending_max_size = 100 # Max number of pending end events to buffer 

2732 # Track dropped run_ids for aggregated error (TTL-expired or buffer-full) 

2733 dropped_tool_ends: set[str] = set() 

2734 dropped_max_size = 200 # Max dropped IDs to track (prevents unbounded growth) 

2735 dropped_overflow_count = 0 # Count of drops that couldn't be tracked due to full buffer 

2736 

2737 def _extract_output(raw_output: Any) -> Any: 

2738 """Extract output value from various LangChain output formats. 

2739 

2740 Args: 

2741 raw_output: The raw output from a tool execution. 

2742 

2743 Returns: 

2744 The extracted output value in a serializable format. 

2745 """ 

2746 if hasattr(raw_output, "content"): 

2747 return raw_output.content 

2748 if hasattr(raw_output, "dict") and callable(raw_output.dict): 

2749 return raw_output.dict() 

2750 if not isinstance(raw_output, (str, int, float, bool, list, dict, type(None))): 

2751 return str(raw_output) 

2752 return raw_output 

2753 

2754 def _cleanup_expired_pending(current_ts: float) -> None: 

2755 """Remove expired entries from pending_tool_ends buffer and track them. 

2756 

2757 Args: 

2758 current_ts: Current timestamp in seconds since epoch. 

2759 """ 

2760 nonlocal dropped_overflow_count 

2761 expired = [rid for rid, data in pending_tool_ends.items() if current_ts - data.get("buffered_at", 0) > pending_ttl_seconds] 

2762 for rid in expired: 

2763 logger.warning(f"Pending on_tool_end for run_id {rid} expired after {pending_ttl_seconds}s (orphan event)") 

2764 if len(dropped_tool_ends) < dropped_max_size: 

2765 dropped_tool_ends.add(rid) 

2766 else: 

2767 dropped_overflow_count += 1 

2768 logger.warning(f"Dropped tool ends tracking full ({dropped_max_size}), cannot track expired run_id {rid} (overflow count: {dropped_overflow_count})") 

2769 del pending_tool_ends[rid] 

2770 

2771 span_attributes = { 

2772 "langfuse.observation.type": "generation", 

2773 "gen_ai.system": _llm_system_name(self), 

2774 "gen_ai.request.model": self.llm_provider.get_model_name(), 

2775 "llm.stream": True, 

2776 } 

2777 if is_input_capture_enabled("llm.chat"): 

2778 span_attributes["langfuse.observation.input"] = serialize_trace_payload({"message": message}) 

2779 

2780 with create_span("llm.chat", span_attributes) as span: 

2781 try: 

2782 async for event in self._agent.astream_events({"messages": lc_messages}, version="v2"): 

2783 kind = event.get("event") 

2784 now_iso = datetime.now(timezone.utc).isoformat() 

2785 now_ts = time.time() 

2786 

2787 # Periodically cleanup expired pending ends 

2788 _cleanup_expired_pending(now_ts) 

2789 

2790 try: 

2791 if kind == "on_tool_start": 

2792 run_id = str(event.get("run_id") or uuid4()) 

2793 name = event.get("name") or event.get("data", {}).get("name") or event.get("data", {}).get("tool") 

2794 input_data = event.get("data", {}).get("input") 

2795 

2796 # Filter out common metadata keys injected by LangChain/LangGraph 

2797 if isinstance(input_data, dict): 

2798 input_data = {k: v for k, v in input_data.items() if k not in ["runtime", "config", "run_manager", "callbacks"]} 

2799 

2800 tool_runs[run_id] = {"name": name, "start": now_iso, "input": input_data} 

2801 

2802 # Register run for cancellation tracking with gateway-level Cancellation service 

2803 async def _noop_cancel_cb(reason: Optional[str]) -> None: 

2804 """ 

2805 No-op cancel callback used when a run is started. 

2806 

2807 Args: 

2808 reason: Optional textual reason for cancellation. 

2809 

2810 Returns: 

2811 None 

2812 """ 

2813 # Default no-op; kept for potential future intra-process cancellation 

2814 return None 

2815 

2816 # Register with cancellation service only if feature is enabled 

2817 if settings.mcpgateway_tool_cancellation_enabled: 

2818 try: 

2819 await cancellation_service.register_run(run_id, name=name, cancel_callback=_noop_cancel_cb) 

2820 except Exception: 

2821 logger.exception("Failed to register run %s with CancellationService", run_id) 

2822 

2823 yield {"type": "tool_start", "id": run_id, "tool": name, "input": input_data, "start": now_iso} 

2824 

2825 # NOTE: Do NOT clear from dropped_tool_ends here. If an end was dropped (TTL/buffer-full) 

2826 # before this start arrived, that end is permanently lost. Since tools only end once, 

2827 # we won't receive another end event, so this should still be reported as an orphan. 

2828 

2829 # Check if we have a buffered end event for this run_id (out-of-order reconciliation) 

2830 if run_id in pending_tool_ends: 

2831 buffered = pending_tool_ends.pop(run_id) 

2832 tool_runs[run_id]["end"] = buffered["end_time"] 

2833 tool_runs[run_id]["output"] = buffered["output"] 

2834 logger.info(f"Reconciled out-of-order on_tool_end for run_id {run_id}") 

2835 

2836 if tool_runs[run_id].get("output") == "": 

2837 error = "Tool execution failed: Please check if the tool is accessible" 

2838 yield {"type": "tool_error", "id": run_id, "tool": name, "error": error, "time": buffered["end_time"]} 

2839 

2840 yield {"type": "tool_end", "id": run_id, "tool": name, "output": tool_runs[run_id].get("output"), "end": buffered["end_time"]} 

2841 

2842 elif kind == "on_tool_end": 

2843 run_id = str(event.get("run_id") or uuid4()) 

2844 output = event.get("data", {}).get("output") 

2845 extracted_output = _extract_output(output) 

2846 

2847 if run_id in tool_runs: 

2848 # Normal case: start already received 

2849 tool_runs[run_id]["end"] = now_iso 

2850 tool_runs[run_id]["output"] = extracted_output 

2851 

2852 if tool_runs[run_id].get("output") == "": 

2853 error = "Tool execution failed: Please check if the tool is accessible" 

2854 yield {"type": "tool_error", "id": run_id, "tool": tool_runs.get(run_id, {}).get("name"), "error": error, "time": now_iso} 

2855 

2856 yield {"type": "tool_end", "id": run_id, "tool": tool_runs.get(run_id, {}).get("name"), "output": tool_runs[run_id].get("output"), "end": now_iso} 

2857 else: 

2858 # Out-of-order: buffer the end event for later reconciliation 

2859 if len(pending_tool_ends) < pending_max_size: 

2860 pending_tool_ends[run_id] = {"output": extracted_output, "end_time": now_iso, "buffered_at": now_ts} 

2861 logger.debug(f"Buffered out-of-order on_tool_end for run_id {run_id}, awaiting on_tool_start") 

2862 else: 

2863 logger.warning(f"Pending tool ends buffer full ({pending_max_size}), dropping on_tool_end for run_id {run_id}") 

2864 if len(dropped_tool_ends) < dropped_max_size: 

2865 dropped_tool_ends.add(run_id) 

2866 else: 

2867 dropped_overflow_count += 1 

2868 logger.warning(f"Dropped tool ends tracking full ({dropped_max_size}), cannot track run_id {run_id} (overflow count: {dropped_overflow_count})") 

2869 

2870 # Unregister run from cancellation service when finished (only if feature is enabled) 

2871 if settings.mcpgateway_tool_cancellation_enabled: 

2872 try: 

2873 await cancellation_service.unregister_run(run_id) 

2874 except Exception: 

2875 logger.exception("Failed to unregister run %s", run_id) 

2876 

2877 elif kind == "on_tool_error": 

2878 run_id = str(event.get("run_id") or uuid4()) 

2879 error = str(event.get("data", {}).get("error", "Unknown error")) 

2880 

2881 # Clear any buffered end for this run to avoid emitting both error and end 

2882 if run_id in pending_tool_ends: 

2883 del pending_tool_ends[run_id] 

2884 logger.debug(f"Cleared buffered on_tool_end for run_id {run_id} due to tool error") 

2885 

2886 # Clear from dropped set if this run was previously dropped (prevents false orphan) 

2887 dropped_tool_ends.discard(run_id) 

2888 

2889 yield {"type": "tool_error", "id": run_id, "tool": tool_runs.get(run_id, {}).get("name"), "error": error, "time": now_iso} 

2890 

2891 # Unregister run on error (only if feature is enabled) 

2892 if settings.mcpgateway_tool_cancellation_enabled: 

2893 try: 

2894 await cancellation_service.unregister_run(run_id) 

2895 except Exception: 

2896 logger.exception("Failed to unregister run %s after error", run_id) 

2897 

2898 elif kind == "on_chat_model_stream": 

2899 chunk = event.get("data", {}).get("chunk") 

2900 if chunk and hasattr(chunk, "content"): 

2901 content = chunk.content 

2902 if content: 

2903 full_response += content 

2904 yield {"type": "token", "content": content} 

2905 

2906 except Exception as event_error: 

2907 logger.warning(f"Error processing event {kind}: {event_error}") 

2908 continue 

2909 

2910 all_orphan_ids = sorted(set(pending_tool_ends.keys()) | dropped_tool_ends) 

2911 if all_orphan_ids or dropped_overflow_count > 0: 

2912 buffered_count = len(pending_tool_ends) 

2913 dropped_count = len(dropped_tool_ends) 

2914 total_unique = len(all_orphan_ids) 

2915 total_affected = total_unique + dropped_overflow_count 

2916 logger.warning( 

2917 f"Stream completed with {total_affected} orphan tool end(s): {buffered_count} buffered, {dropped_count} dropped (tracked), {dropped_overflow_count} dropped (untracked overflow)" 

2918 ) 

2919 if all_orphan_ids: 

2920 logger.debug(f"Full orphan run_id list: {', '.join(all_orphan_ids)}") 

2921 now_iso = datetime.now(timezone.utc).isoformat() 

2922 error_parts = [] 

2923 if buffered_count > 0: 

2924 error_parts.append(f"{buffered_count} buffered") 

2925 if dropped_count > 0: 

2926 error_parts.append(f"{dropped_count} dropped (TTL expired or buffer full)") 

2927 if dropped_overflow_count > 0: 

2928 error_parts.append(f"{dropped_overflow_count} additional dropped (tracking overflow)") 

2929 error_msg = f"Tool execution incomplete: {total_affected} tool end(s) received without matching start ({', '.join(error_parts)})" 

2930 if all_orphan_ids: 

2931 max_display_ids = 10 

2932 display_ids = all_orphan_ids[:max_display_ids] 

2933 remaining = total_unique - len(display_ids) 

2934 if remaining > 0: 

2935 error_msg += f". Run IDs (first {max_display_ids} of {total_unique}): {', '.join(display_ids)} (+{remaining} more)" 

2936 else: 

2937 error_msg += f". Run IDs: {', '.join(display_ids)}" 

2938 yield { 

2939 "type": "tool_error", 

2940 "id": str(uuid4()), 

2941 "tool": None, 

2942 "error": error_msg, 

2943 "time": now_iso, 

2944 } 

2945 pending_tool_ends.clear() 

2946 dropped_tool_ends.clear() 

2947 

2948 elapsed_ms = int((time.time() - start_ts) * 1000) 

2949 

2950 tools_used = list({tr["name"] for tr in tool_runs.values() if tr.get("name")}) 

2951 

2952 if span and is_output_capture_enabled("llm.chat"): 

2953 set_span_attribute(span, "langfuse.observation.output", serialize_trace_payload({"response": full_response})) 

2954 

2955 yield {"type": "final", "content": full_response, "tool_used": len(tools_used) > 0, "tools": tools_used, "elapsed_ms": elapsed_ms} 

2956 

2957 if self.user_id and full_response: 

2958 await self.history_manager.append_message(self.user_id, "user", message) 

2959 await self.history_manager.append_message(self.user_id, "assistant", full_response) 

2960 

2961 except (ConnectionError, TimeoutError) as e: 

2962 logger.error(f"Error in chat_events: {e}") 

2963 raise 

2964 except Exception as e: 

2965 logger.error(f"Error in chat_events: {e}") 

2966 raise ChatProcessingError(f"Chat processing error: {e}") from e 

2967 

async def get_conversation_history(self) -> List[Dict[str, str]]:
    """
    Return the stored conversation for the service's current user.

    The lookup is delegated to ``self.history_manager.get_history``. When no
    ``user_id`` is configured (``None`` or an empty string), the history
    manager is not consulted and an empty list is returned immediately.

    Returns:
        List[Dict[str, str]]: Conversation messages as returned by the history
        manager, each expected to carry:
            - role (str): "user" or "assistant"
            - content (str): Message text

    Examples:
        >>> import asyncio
        >>> # Assuming service is initialized with user_id
        >>> # history = asyncio.run(service.get_conversation_history())
        >>> # all('role' in msg and 'content' in msg for msg in history) -> True

    Note:
        Returns an empty list when no user_id is set; when the user has no
        stored messages the result is whatever the history manager returns
        (presumably an empty list).
    """
    # Sessions without a user id have nothing persisted, so skip the backend call.
    return await self.history_manager.get_history(self.user_id) if self.user_id else []

2991 

async def clear_history(self) -> None:
    """
    Delete the stored conversation for the service's current user.

    Delegates to ``self.history_manager.clear_history`` and logs the action.
    Useful for starting a fresh conversation or for releasing stored memory.

    Examples:
        >>> import asyncio
        >>> # Assuming service is initialized with user_id
        >>> # asyncio.run(service.clear_history())
        >>> # history = asyncio.run(service.get_conversation_history())
        >>> # len(history) -> 0

    Note:
        Irreversible. Does nothing at all when no user_id is set.
    """
    # Only users with an id have persisted history worth clearing.
    if self.user_id:
        await self.history_manager.clear_history(self.user_id)
        logger.info(f"Conversation history cleared for user {self.user_id}")

3014 

async def shutdown(self) -> None:
    """
    Shutdown the chat service and cleanup resources.

    Disconnects from the MCP server when the client reports a live
    connection. Then drops the cached agent and tool list and marks the
    service as uninitialized. The local state is reset even when the
    disconnect fails, so a failed shutdown never leaves the service
    reporting ``is_initialized == True`` with a half torn-down client.
    Conversation history held by the history manager is NOT touched; use
    ``clear_history()`` for that.

    Raises:
        Exception: Any error raised while checking or closing the MCP
            connection is logged and re-raised (after local state reset).

    Examples:
        >>> import asyncio
        >>> config = MCPClientConfig(
        ...     mcp_server=MCPServerConfig(
        ...         url="https://example.com/mcp",
        ...         transport="streamable_http"
        ...     ),
        ...     llm=LLMConfig(
        ...         provider="ollama",
        ...         config=OllamaConfig(model="llama2")
        ...     )
        ... )
        >>> service = MCPChatService(config)
        >>> # asyncio.run(service.initialize())
        >>> # asyncio.run(service.shutdown())
        >>> # service.is_initialized -> False

    Note:
        Should be called when the service is no longer needed to properly
        release resources and connections.
    """
    logger.info("Shutting down chat service...")

    try:
        # Disconnect from MCP server
        if self.mcp_client.is_connected:
            await self.mcp_client.disconnect()
    except Exception as e:
        logger.error(f"Error during shutdown: {e}")
        raise
    finally:
        # Always reset local state: even if the disconnect failed, the agent
        # must not keep being used as if the service were still initialized.
        self._agent = None
        self._initialized = False
        self._tools = []

    logger.info("Chat service shutdown complete")

3063 

@property
def is_initialized(self) -> bool:
    """
    Report whether the service has been initialized and is usable.

    The value is the internal ``_initialized`` flag; ``shutdown()`` resets
    it to False.

    Returns:
        bool: True when the service is ready for chat calls, False otherwise.

    Examples:
        >>> config = MCPClientConfig(
        ...     mcp_server=MCPServerConfig(
        ...         url="https://example.com/mcp",
        ...         transport="streamable_http"
        ...     ),
        ...     llm=LLMConfig(
        ...         provider="ollama",
        ...         config=OllamaConfig(model="llama2")
        ...     )
        ... )
        >>> service = MCPChatService(config)
        >>> service.is_initialized
        False

    Note:
        ``chat_events()`` and ``reload_tools()`` raise RuntimeError while the
        service is not initialized.
    """
    ready = self._initialized
    return ready

3091 

async def reload_tools(self) -> int:
    """
    Re-fetch the tool list from the MCP server and rebuild the agent.

    Requests a forced reload from ``self.mcp_client.get_tools``. It then
    builds a new ReAct agent from the provider's LLM and the fresh tools, and
    stores both the agent and the tool list on the service.

    Returns:
        int: Number of tools returned by the MCP server.

    Raises:
        RuntimeError: If the service has not been initialized.
        ImportError: If the optional LLM chat dependencies are not installed.
        Exception: Any error from fetching tools or building the agent,
            logged and re-raised unchanged.

    Examples:
        >>> import asyncio
        >>> # Assuming service is initialized
        >>> # tool_count = asyncio.run(service.reload_tools())
        >>> # tool_count >= 0 -> True

    Note:
        The agent object is replaced, which may briefly interrupt ongoing
        conversations. Conversation history is not modified by this method.
    """
    if not self._initialized:
        raise RuntimeError("Chat service not initialized")

    # langgraph / langchain imports are optional at module import time.
    if not _LLMCHAT_AVAILABLE:
        raise ImportError("LLM chat dependencies are missing. Install them with: pip install '.[llmchat]'")

    try:
        logger.info("Reloading tools from MCP server...")
        tools = await self.mcp_client.get_tools(force_reload=True)

        # Swap in an agent bound to the refreshed tool set; the tool cache is
        # updated only after the agent was built successfully.
        self._agent = create_react_agent(self.llm_provider.get_llm(), tools)
        self._tools = tools

        logger.info(f"Reloaded {len(tools)} tools successfully")
        return len(tools)
    except Exception as e:
        logger.error(f"Failed to reload tools: {e}")
        raise