Coverage for mcpgateway / services / mcp_client_chat_service.py: 99%

845 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/mcp_client_chat_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Keval Mahajan 

6 

7MCP Client Service Module. 

8 

9This module provides a comprehensive client implementation for interacting with 

10MCP servers, managing LLM providers, and orchestrating conversational AI agents. 

11It supports multiple transport protocols and LLM providers. 

12 

13The module consists of several key components: 

14- Configuration classes for MCP servers and LLM providers 

15- LLM provider factory and implementations 

16- MCP client for tool management 

17- Chat history manager for Redis and in-memory storage 

18- Chat service for conversational interactions 

19""" 

20 

21# Standard 

22from datetime import datetime, timezone 

23import time 

24from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Union 

25from uuid import uuid4 

26 

27# Third-Party 

28import orjson 

29 

30try: 

31 # Third-Party 

32 from langchain_core.language_models import BaseChatModel 

33 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage 

34 from langchain_core.tools import BaseTool 

35 from langchain_mcp_adapters.client import MultiServerMCPClient 

36 from langchain_ollama import ChatOllama, OllamaLLM 

37 from langchain_openai import AzureChatOpenAI, AzureOpenAI, ChatOpenAI, OpenAI 

38 from langgraph.prebuilt import create_react_agent 

39 

40 _LLMCHAT_AVAILABLE = True 

41except ImportError: 

42 # Optional dependencies for LLM chat feature not installed 

43 # These are only needed if LLMCHAT_ENABLED=true 

44 _LLMCHAT_AVAILABLE = False 

45 BaseChatModel = None # type: ignore 

46 AIMessage = None # type: ignore 

47 BaseMessage = None # type: ignore 

48 HumanMessage = None # type: ignore 

49 BaseTool = None # type: ignore 

50 MultiServerMCPClient = None # type: ignore 

51 ChatOllama = None # type: ignore 

52 OllamaLLM = None 

53 AzureChatOpenAI = None # type: ignore 

54 AzureOpenAI = None 

55 ChatOpenAI = None # type: ignore 

56 OpenAI = None 

57 create_react_agent = None # type: ignore 

58 

59# Try to import Anthropic and Bedrock providers (they may not be installed) 

60try: 

61 # Third-Party 

62 from langchain_anthropic import AnthropicLLM, ChatAnthropic 

63 

64 _ANTHROPIC_AVAILABLE = True 

65except ImportError: 

66 _ANTHROPIC_AVAILABLE = False 

67 ChatAnthropic = None # type: ignore 

68 AnthropicLLM = None 

69 

70try: 

71 # Third-Party 

72 from langchain_aws import BedrockLLM, ChatBedrock 

73 

74 _BEDROCK_AVAILABLE = True 

75except ImportError: 

76 _BEDROCK_AVAILABLE = False 

77 ChatBedrock = None # type: ignore 

78 BedrockLLM = None 

79 

80try: 

81 # Third-Party 

82 from langchain_ibm import ChatWatsonx, WatsonxLLM 

83 

84 _WATSONX_AVAILABLE = True 

85except ImportError: 

86 _WATSONX_AVAILABLE = False 

87 WatsonxLLM = None # type: ignore 

88 ChatWatsonx = None 

89 

90# Third-Party 

91from pydantic import BaseModel, Field, field_validator, model_validator 

92 

93# First-Party 

94from mcpgateway.config import settings 

95from mcpgateway.services.cancellation_service import cancellation_service 

96from mcpgateway.services.logging_service import LoggingService 

97 

98logging_service = LoggingService() 

99logger = logging_service.get_logger(__name__) 

100 

101 

class MCPServerConfig(BaseModel):
    """
    Configuration for MCP server connection.

    This class defines the configuration parameters required to connect to an
    MCP (Model Context Protocol) server using various transport mechanisms.

    Attributes:
        url: MCP server URL for streamable_http/sse transports.
        command: Command to run for stdio transport.
        args: Command-line arguments for stdio command.
        transport: Transport type (streamable_http, sse, or stdio).
        auth_token: Authentication token for HTTP-based transports.
        headers: Additional HTTP headers for request customization.

    Examples:
        >>> # HTTP-based transport
        >>> config = MCPServerConfig(
        ...     url="https://mcp-server.example.com/mcp",
        ...     transport="streamable_http",
        ...     auth_token="secret-token"
        ... )
        >>> config.transport
        'streamable_http'

        >>> # Stdio transport (requires explicit feature flag)
        >>> settings.mcpgateway_stdio_transport_enabled = True
        >>> config = MCPServerConfig(
        ...     command="python",
        ...     args=["server.py"],
        ...     transport="stdio"
        ... )
        >>> config.command
        'python'

    Note:
        The auth_token is automatically added to headers as a Bearer token
        for HTTP-based transports.
    """

    url: Optional[str] = Field(None, description="MCP server URL for streamable_http/sse transports")
    command: Optional[str] = Field(None, description="Command to run for stdio transport")
    args: Optional[list[str]] = Field(None, description="Arguments for stdio command")
    transport: Literal["streamable_http", "sse", "stdio"] = Field(default="streamable_http", description="Transport type for MCP connection")
    auth_token: Optional[str] = Field(None, description="Authentication token for the server")
    headers: Optional[Dict[str, str]] = Field(default=None, description="Additional headers for HTTP-based transports")

    @model_validator(mode="before")
    @classmethod
    def add_auth_to_headers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """
        Automatically add authentication token to headers if provided.

        This validator ensures that if an auth_token is provided for HTTP-based
        transports, it's automatically added to the headers as a Bearer token.

        Args:
            values: Dictionary of field values before validation.

        Returns:
            Dict[str, Any]: Updated values with auth token in headers.

        Examples:
            >>> values = {
            ...     "url": "https://api.example.com",
            ...     "transport": "streamable_http",
            ...     "auth_token": "token123"
            ... }
            >>> result = MCPServerConfig.add_auth_to_headers(values)
            >>> result['headers']['Authorization']
            'Bearer token123'
        """
        # mode="before" validators can receive non-dict input (e.g. an existing
        # model instance during validation) — pass it through untouched.
        if not isinstance(values, dict):
            return values

        auth_token = values.get("auth_token")
        # BUG FIX: this validator runs BEFORE field defaults are applied, so
        # "transport" is absent when the caller relies on the default
        # ("streamable_http"). Without this fallback, a config with only
        # url + auth_token would silently skip the Authorization header.
        transport = values.get("transport") or "streamable_http"
        headers = values.get("headers") or {}

        if auth_token and transport in ("streamable_http", "sse"):
            # Respect an Authorization header the caller supplied explicitly.
            if "Authorization" not in headers:
                headers["Authorization"] = f"Bearer {auth_token}"
            values["headers"] = headers

        return values

    @model_validator(mode="after")
    def validate_transport_requirements(self):
        """Validate transport-specific requirements and feature flags.

        Returns:
            MCPServerConfig: Validated config instance.

        Raises:
            ValueError: If transport requirements or feature flags are violated.
        """
        if self.transport in ["streamable_http", "sse"] and not self.url:
            raise ValueError(f"URL is required for {self.transport} transport")

        if self.transport == "stdio":
            # stdio spawns a local subprocess — gated behind an explicit opt-in.
            if not settings.mcpgateway_stdio_transport_enabled:
                raise ValueError("stdio transport is disabled by default; set MCPGATEWAY_STDIO_TRANSPORT_ENABLED=true to enable it")
            if not self.command:
                raise ValueError("Command is required for stdio transport")

        return self

    model_config = {
        "json_schema_extra": {
            "examples": [
                {"url": "https://mcp-server.example.com/mcp", "transport": "streamable_http", "auth_token": "your-token-here"},  # nosec B105 - example placeholder
                {"command": "python", "args": ["server.py"], "transport": "stdio"},
            ]
        }
    }

214 

215 

class AzureOpenAIConfig(BaseModel):
    """Connection settings for the Azure OpenAI service.

    Captures the credentials, endpoint, deployment name, and generation
    parameters needed to issue requests against an Azure-hosted OpenAI
    deployment.

    Attributes:
        api_key: Azure OpenAI API authentication key.
        azure_endpoint: Azure OpenAI service endpoint URL.
        api_version: API version string sent with each request.
        azure_deployment: Name of the deployed model in Azure.
        model: Model identifier used for logging and tracing.
        temperature: Sampling temperature for generation (0.0-2.0).
        max_tokens: Optional cap on tokens generated per response.
        timeout: Optional request timeout in seconds.
        max_retries: Number of retry attempts for failed requests.

    Examples:
        >>> cfg = AzureOpenAIConfig(
        ...     api_key="your-api-key",
        ...     azure_endpoint="https://your-resource.openai.azure.com/",
        ...     azure_deployment="gpt-4",
        ...     temperature=0.7
        ... )
        >>> cfg.model
        'gpt-4'
        >>> cfg.temperature
        0.7
    """

    api_key: str = Field(..., description="Azure OpenAI API key")
    azure_endpoint: str = Field(..., description="Azure OpenAI endpoint URL")
    api_version: str = Field(default="2024-05-01-preview", description="Azure OpenAI API version")
    azure_deployment: str = Field(..., description="Azure OpenAI deployment name")
    model: str = Field(default="gpt-4", description="Model name for tracing")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(None, gt=0, description="Maximum tokens to generate")
    timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")
    max_retries: int = Field(default=2, ge=0, description="Maximum number of retries")

    model_config = {
        "json_schema_extra": {
            "example": {
                "api_key": "your-api-key",
                "azure_endpoint": "https://your-resource.openai.azure.com/",
                "api_version": "2024-05-01-preview",
                "azure_deployment": "gpt-4",
                "model": "gpt-4",
                "temperature": 0.7,
            }
        }
    }

269 

270 

class OllamaConfig(BaseModel):
    """Connection settings for an Ollama instance.

    Describes how to reach a local or remote Ollama server and which
    open-source model to run on it.

    Attributes:
        base_url: Base URL of the Ollama server.
        model: Name of the Ollama model to use.
        temperature: Sampling temperature for generation (0.0-2.0).
        timeout: Optional request timeout in seconds.
        num_ctx: Optional context window size for the model.

    Examples:
        >>> cfg = OllamaConfig(
        ...     base_url="http://localhost:11434",
        ...     model="llama2",
        ...     temperature=0.5
        ... )
        >>> cfg.model
        'llama2'
        >>> cfg.base_url
        'http://localhost:11434'
    """

    base_url: str = Field(default="http://localhost:11434", description="Ollama base URL")
    model: str = Field(default="llama2", description="Model name to use")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")
    num_ctx: Optional[int] = Field(None, gt=0, description="Context window size")

    model_config = {
        "json_schema_extra": {
            "example": {"base_url": "http://localhost:11434", "model": "llama2", "temperature": 0.7}
        }
    }

304 

305 

class OpenAIConfig(BaseModel):
    """Connection settings for the OpenAI API (non-Azure).

    Also usable against OpenAI-compatible endpoints by supplying a
    custom ``base_url``.

    Attributes:
        api_key: OpenAI API authentication key.
        base_url: Optional base URL for OpenAI-compatible endpoints.
        model: Model identifier (e.g., gpt-4, gpt-3.5-turbo).
        temperature: Sampling temperature for generation (0.0-2.0).
        max_tokens: Optional cap on tokens generated per response.
        timeout: Optional request timeout in seconds.
        max_retries: Number of retry attempts for failed requests.
        default_headers: Optional default headers required by the provider.

    Examples:
        >>> cfg = OpenAIConfig(
        ...     api_key="sk-...",
        ...     model="gpt-4",
        ...     temperature=0.7
        ... )
        >>> cfg.model
        'gpt-4'
    """

    api_key: str = Field(..., description="OpenAI API key")
    base_url: Optional[str] = Field(None, description="Base URL for OpenAI-compatible endpoints")
    model: str = Field(default="gpt-4o-mini", description="Model name (e.g., gpt-4, gpt-3.5-turbo)")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(None, gt=0, description="Maximum tokens to generate")
    timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")
    max_retries: int = Field(default=2, ge=0, description="Maximum number of retries")
    default_headers: Optional[dict] = Field(None, description="optional default headers required by the provider")

    model_config = {
        "json_schema_extra": {
            "example": {
                "api_key": "sk-...",
                "model": "gpt-4o-mini",
                "temperature": 0.7,
            }
        }
    }

349 

350 

class AnthropicConfig(BaseModel):
    """Connection settings for Anthropic's Claude API.

    Attributes:
        api_key: Anthropic API authentication key.
        model: Claude model identifier (e.g., claude-3-5-sonnet-20241022).
        temperature: Sampling temperature for generation (0.0-1.0; note the
            narrower range than the OpenAI-style providers).
        max_tokens: Cap on tokens generated per response.
        timeout: Optional request timeout in seconds.
        max_retries: Number of retry attempts for failed requests.

    Examples:
        >>> cfg = AnthropicConfig(
        ...     api_key="sk-ant-...",
        ...     model="claude-3-5-sonnet-20241022",
        ...     temperature=0.7
        ... )
        >>> cfg.model
        'claude-3-5-sonnet-20241022'
    """

    api_key: str = Field(..., description="Anthropic API key")
    model: str = Field(default="claude-3-5-sonnet-20241022", description="Claude model name")
    temperature: float = Field(default=0.7, ge=0.0, le=1.0, description="Sampling temperature")
    max_tokens: int = Field(default=4096, gt=0, description="Maximum tokens to generate")
    timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")
    max_retries: int = Field(default=2, ge=0, description="Maximum number of retries")

    model_config = {
        "json_schema_extra": {
            "example": {
                "api_key": "sk-ant-...",
                "model": "claude-3-5-sonnet-20241022",
                "temperature": 0.7,
                "max_tokens": 4096,
            }
        }
    }

392 

393 

class AWSBedrockConfig(BaseModel):
    """Connection settings for AWS Bedrock LLM services.

    When no explicit credentials are supplied, the default AWS credential
    chain (environment, shared config, instance profile) is used.

    Attributes:
        model_id: Bedrock model identifier (e.g., anthropic.claude-v2).
        region_name: AWS region name (e.g., us-east-1).
        aws_access_key_id: Optional AWS access key ID.
        aws_secret_access_key: Optional AWS secret access key.
        aws_session_token: Optional session token for temporary credentials.
        temperature: Sampling temperature for generation (0.0-1.0).
        max_tokens: Cap on tokens generated per response.

    Examples:
        >>> cfg = AWSBedrockConfig(
        ...     model_id="anthropic.claude-v2",
        ...     region_name="us-east-1",
        ...     temperature=0.7
        ... )
        >>> cfg.model_id
        'anthropic.claude-v2'
    """

    model_id: str = Field(..., description="Bedrock model ID")
    region_name: str = Field(default="us-east-1", description="AWS region name")
    aws_access_key_id: Optional[str] = Field(None, description="AWS access key ID")
    aws_secret_access_key: Optional[str] = Field(None, description="AWS secret access key")
    aws_session_token: Optional[str] = Field(None, description="AWS session token")
    temperature: float = Field(default=0.7, ge=0.0, le=1.0, description="Sampling temperature")
    max_tokens: int = Field(default=4096, gt=0, description="Maximum tokens to generate")

    model_config = {
        "json_schema_extra": {
            "example": {
                "model_id": "anthropic.claude-v2",
                "region_name": "us-east-1",
                "temperature": 0.7,
                "max_tokens": 4096,
            }
        }
    }

437 

438 

class WatsonxConfig(BaseModel):
    """Connection settings for IBM watsonx.ai.

    Attributes:
        api_key: IBM Cloud API key for authentication.
        url: watsonx.ai service endpoint URL.
        project_id: watsonx.ai project ID providing the request context.
        model_id: Model identifier (e.g., ibm/granite-13b-chat-v2).
        temperature: Sampling temperature for generation (0.0-2.0).
        max_new_tokens: Optional maximum number of tokens to generate.
        min_new_tokens: Optional minimum number of tokens to generate.
        decoding_method: Decoding method ('sample' or 'greedy').
        top_k: Optional top-K sampling parameter.
        top_p: Optional top-P (nucleus) sampling parameter.
        timeout: Optional request timeout in seconds.

    Examples:
        >>> cfg = WatsonxConfig(
        ...     api_key="your-api-key",
        ...     url="https://us-south.ml.cloud.ibm.com",
        ...     project_id="your-project-id",
        ...     model_id="ibm/granite-13b-chat-v2"
        ... )
        >>> cfg.model_id
        'ibm/granite-13b-chat-v2'
    """

    api_key: str = Field(..., description="IBM Cloud API key")
    url: str = Field(default="https://us-south.ml.cloud.ibm.com", description="watsonx.ai endpoint URL")
    project_id: str = Field(..., description="watsonx.ai project ID")
    model_id: str = Field(default="ibm/granite-13b-chat-v2", description="Model identifier")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    max_new_tokens: Optional[int] = Field(default=1024, gt=0, description="Maximum tokens to generate")
    min_new_tokens: Optional[int] = Field(default=1, gt=0, description="Minimum tokens to generate")
    decoding_method: str = Field(default="sample", description="Decoding method (sample or greedy)")
    top_k: Optional[int] = Field(default=50, gt=0, description="Top-K sampling")
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0, description="Top-P sampling")
    timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")

    model_config = {
        "json_schema_extra": {
            "example": {
                "api_key": "your-api-key",
                "url": "https://us-south.ml.cloud.ibm.com",
                "project_id": "your-project-id",
                "model_id": "ibm/granite-13b-chat-v2",
                "temperature": 0.7,
                "max_new_tokens": 1024,
            }
        }
    }

493 

494 

class GatewayConfig(BaseModel):
    """Settings for the ContextForge internal (gateway-routed) LLM provider.

    Lets LLM Chat use models configured in the gateway's own LLM Settings;
    the gateway forwards each request to the appropriate configured backend.

    Attributes:
        model: Gateway model ID (or provider model ID) to route to.
        base_url: Optional gateway internal API URL; defaults to self.
        temperature: Optional sampling temperature for generation.
        max_tokens: Optional cap on tokens generated per response.
        timeout: Optional request timeout in seconds.

    Examples:
        >>> cfg = GatewayConfig(model="gpt-4o")
        >>> cfg.model
        'gpt-4o'
    """

    model: str = Field(..., description="Gateway model ID to use")
    base_url: Optional[str] = Field(None, description="Gateway internal API URL (optional, defaults to self)")
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(None, gt=0, description="Maximum tokens to generate")
    timeout: Optional[float] = Field(None, gt=0, description="Request timeout in seconds")

    model_config = {
        "json_schema_extra": {
            "example": {
                "model": "gpt-4o",
                "temperature": 0.7,
                "max_tokens": 4096,
            }
        }
    }

530 

531 

class LLMConfig(BaseModel):
    """Unified LLM provider configuration.

    Pairs a provider discriminator with the matching provider-specific
    configuration object so that a single field can describe any of the
    supported backends.

    Attributes:
        provider: One of azure_openai, openai, anthropic, aws_bedrock,
            ollama, watsonx, or gateway.
        config: Provider-specific configuration object.

    Examples:
        >>> # Azure OpenAI configuration
        >>> config = LLMConfig(
        ...     provider="azure_openai",
        ...     config=AzureOpenAIConfig(
        ...         api_key="key",
        ...         azure_endpoint="https://example.com/",
        ...         azure_deployment="gpt-4"
        ...     )
        ... )
        >>> config.provider
        'azure_openai'

        >>> # OpenAI configuration
        >>> config = LLMConfig(
        ...     provider="openai",
        ...     config=OpenAIConfig(api_key="sk-...", model="gpt-4")
        ... )
        >>> config.provider
        'openai'

        >>> # Ollama configuration
        >>> config = LLMConfig(provider="ollama", config=OllamaConfig(model="llama2"))
        >>> config.provider
        'ollama'

        >>> # Watsonx configuration
        >>> config = LLMConfig(
        ...     provider="watsonx",
        ...     config=WatsonxConfig(
        ...         url="https://us-south.ml.cloud.ibm.com",
        ...         model_id="ibm/granite-13b-instruct-v2",
        ...         project_id="YOUR_PROJECT_ID",
        ...         api_key="YOUR_API")
        ... )
        >>> config.provider
        'watsonx'
    """

    provider: Literal["azure_openai", "openai", "anthropic", "aws_bedrock", "ollama", "watsonx", "gateway"] = Field(..., description="LLM provider type")
    config: Union[AzureOpenAIConfig, OpenAIConfig, AnthropicConfig, AWSBedrockConfig, OllamaConfig, WatsonxConfig, GatewayConfig] = Field(..., description="Provider-specific configuration")

    @field_validator("config", mode="before")
    @classmethod
    def validate_config_type(cls, v: Any, info) -> Union[AzureOpenAIConfig, OpenAIConfig, AnthropicConfig, AWSBedrockConfig, OllamaConfig, WatsonxConfig, GatewayConfig]:
        """Coerce a raw dict into the config model matching ``provider``.

        Args:
            v: Configuration value (dict or already-built config object).
            info: Validation context; ``info.data['provider']`` selects the
                target config class.

        Returns:
            The provider-specific configuration object, or ``v`` unchanged
            when it is not a dict or the provider is unrecognized (Pydantic's
            union resolution then handles it).

        Examples:
            >>> # Automatically converts dict to appropriate config type
            >>> config_dict = {
            ...     "api_key": "key",
            ...     "azure_endpoint": "https://example.com/",
            ...     "azure_deployment": "gpt-4"
            ... }
            >>> # Used internally by Pydantic during validation
        """
        if not isinstance(v, dict):
            return v

        # Dispatch table replaces the long if/elif chain; unknown providers
        # deliberately fall through and return the dict untouched.
        dispatch = {
            "azure_openai": AzureOpenAIConfig,
            "openai": OpenAIConfig,
            "anthropic": AnthropicConfig,
            "aws_bedrock": AWSBedrockConfig,
            "ollama": OllamaConfig,
            "watsonx": WatsonxConfig,
            "gateway": GatewayConfig,
        }
        target = dispatch.get(info.data.get("provider"))
        return target(**v) if target is not None else v

632 

633 

class MCPClientConfig(BaseModel):
    """
    Main configuration for MCP client service.

    Aggregates all configuration parameters required for the complete MCP client
    service, including server connection, LLM provider, and operational settings.

    Attributes:
        mcp_server: MCP server connection configuration.
        llm: LLM provider configuration.
        chat_history_max_messages: Maximum messages to retain in chat history.
        enable_streaming: Whether to enable streaming responses.

    Examples:
        >>> config = MCPClientConfig(
        ...     mcp_server=MCPServerConfig(
        ...         url="https://mcp-server.example.com/mcp",
        ...         transport="streamable_http"
        ...     ),
        ...     llm=LLMConfig(
        ...         provider="ollama",
        ...         config=OllamaConfig(model="llama2")
        ...     ),
        ...     chat_history_max_messages=100,
        ...     enable_streaming=True
        ... )
        >>> config.chat_history_max_messages
        100
        >>> config.enable_streaming
        True
    """

    mcp_server: MCPServerConfig = Field(..., description="MCP server configuration")
    llm: LLMConfig = Field(..., description="LLM provider configuration")
    # Declared with Field(...) for consistency with the other fields so the
    # description shows up in the generated JSON schema. The default is read
    # from gateway settings once, at class-definition time.
    chat_history_max_messages: int = Field(default=settings.llmchat_chat_history_max_messages, ge=1, description="Maximum messages to retain in chat history")
    enable_streaming: bool = Field(default=True, description="Enable streaming responses")

    model_config = {
        "json_schema_extra": {
            "example": {
                "mcp_server": {"url": "https://mcp-server.example.com/mcp", "transport": "streamable_http", "auth_token": "your-token"},  # nosec B105 - example placeholder
                "llm": {
                    "provider": "azure_openai",
                    "config": {"api_key": "your-key", "azure_endpoint": "https://your-resource.openai.azure.com/", "azure_deployment": "gpt-4", "api_version": "2024-05-01-preview"},
                },
            }
        }
    }

682 

683 

684# ==================== LLM PROVIDER IMPLEMENTATIONS ==================== 

685 

686 

class AzureOpenAIProvider:
    """
    Azure OpenAI provider implementation.

    Manages connection and interaction with Azure OpenAI services.

    Attributes:
        config: Azure OpenAI configuration object.

    Examples:
        >>> config = AzureOpenAIConfig(
        ...     api_key="key",
        ...     azure_endpoint="https://example.openai.azure.com/",
        ...     azure_deployment="gpt-4"
        ... )
        >>> provider = AzureOpenAIProvider(config)
        >>> provider.get_model_name()
        'gpt-4'

    Note:
        The LLM instance is lazily initialized on first access for
        improved startup performance.
    """

    def __init__(self, config: "AzureOpenAIConfig"):
        """
        Initialize Azure OpenAI provider.

        Args:
            config: Azure OpenAI configuration with API credentials and settings.

        Examples:
            >>> config = AzureOpenAIConfig(
            ...     api_key="key",
            ...     azure_endpoint="https://example.openai.azure.com/",
            ...     azure_deployment="gpt-4"
            ... )
            >>> provider = AzureOpenAIProvider(config)
        """
        self.config = config
        self._llm = None  # lazily created on first get_llm() call
        logger.info(f"Initializing Azure OpenAI provider with deployment: {config.azure_deployment}")

    # String annotations keep the signature lazy: AzureChatOpenAI/AzureOpenAI
    # are None when the optional LLM-chat dependencies are not installed.
    def get_llm(self, model_type: str = "chat") -> "Union[AzureChatOpenAI, AzureOpenAI]":
        """
        Get Azure OpenAI LLM instance with lazy initialization.

        Creates and caches the Azure OpenAI model instance on first call.
        Subsequent calls return the cached instance.

        Args:
            model_type: LLM inference model type, either 'chat' or 'completion'.

        Returns:
            AzureChatOpenAI or AzureOpenAI: Configured Azure OpenAI model.

        Raises:
            ValueError: If model_type is not 'chat' or 'completion'.
            Exception: If LLM initialization fails (e.g., invalid credentials).

        Examples:
            >>> config = AzureOpenAIConfig(
            ...     api_key="key",
            ...     azure_endpoint="https://example.openai.azure.com/",
            ...     azure_deployment="gpt-4"
            ... )
            >>> provider = AzureOpenAIProvider(config)
            >>> # llm = provider.get_llm()  # Returns AzureChatOpenAI instance
        """
        # BUG FIX: previously an unrecognized model_type fell through both
        # branches and silently returned None. Validate up front (outside the
        # try block so the ValueError is not re-logged as a creation failure).
        if model_type not in ("chat", "completion"):
            raise ValueError(f"Unsupported model_type: {model_type!r} (expected 'chat' or 'completion')")

        if self._llm is None:
            try:
                llm_cls = AzureChatOpenAI if model_type == "chat" else AzureOpenAI
                self._llm = llm_cls(
                    api_key=self.config.api_key,
                    azure_endpoint=self.config.azure_endpoint,
                    api_version=self.config.api_version,
                    azure_deployment=self.config.azure_deployment,
                    model=self.config.model,
                    temperature=self.config.temperature,
                    max_tokens=self.config.max_tokens,
                    timeout=self.config.timeout,
                    max_retries=self.config.max_retries,
                )
                logger.info("Azure OpenAI LLM instance created successfully")
            except Exception as e:
                logger.error(f"Failed to create Azure OpenAI LLM: {e}")
                raise

        return self._llm

    def get_model_name(self) -> str:
        """
        Get the Azure OpenAI model name.

        Returns:
            str: The model name configured for this provider.

        Examples:
            >>> config = AzureOpenAIConfig(
            ...     api_key="key",
            ...     azure_endpoint="https://example.openai.azure.com/",
            ...     azure_deployment="gpt-4",
            ...     model="gpt-4"
            ... )
            >>> provider = AzureOpenAIProvider(config)
            >>> provider.get_model_name()
            'gpt-4'
        """
        return self.config.model

807 

808 

class OllamaProvider:
    """
    Ollama provider implementation.

    Manages connection and interaction with Ollama instances for running
    open-source language models locally or remotely.

    Attributes:
        config: Ollama configuration object.

    Examples:
        >>> config = OllamaConfig(
        ...     base_url="http://localhost:11434",
        ...     model="llama2"
        ... )
        >>> provider = OllamaProvider(config)
        >>> provider.get_model_name()
        'llama2'

    Note:
        Requires Ollama to be running and accessible at the configured base_url.
    """

    def __init__(self, config: "OllamaConfig"):
        """
        Initialize Ollama provider.

        Args:
            config: Ollama configuration with server URL and model settings.

        Examples:
            >>> config = OllamaConfig(model="llama2")
            >>> provider = OllamaProvider(config)
        """
        self.config = config
        self._llm = None  # lazily created on first get_llm() call
        logger.info(f"Initializing Ollama provider with model: {config.model}")

    # String annotations keep the signature lazy: ChatOllama/OllamaLLM are
    # None when the optional LLM-chat dependencies are not installed.
    def get_llm(self, model_type: str = "chat") -> "Union[ChatOllama, OllamaLLM]":
        """
        Get Ollama LLM instance with lazy initialization.

        Creates and caches the Ollama model instance on first call.
        Subsequent calls return the cached instance.

        Args:
            model_type: LLM inference model type, either 'chat' or 'completion'.

        Returns:
            ChatOllama or OllamaLLM: Configured Ollama model.

        Raises:
            ValueError: If model_type is not 'chat' or 'completion'.
            Exception: If LLM initialization fails (e.g., Ollama not running).

        Examples:
            >>> config = OllamaConfig(model="llama2")
            >>> provider = OllamaProvider(config)
            >>> # llm = provider.get_llm()  # Returns ChatOllama instance
        """
        # BUG FIX: previously an unrecognized model_type fell through both
        # branches and silently returned None. Validate up front (outside the
        # try block so the ValueError is not re-logged as a creation failure).
        if model_type not in ("chat", "completion"):
            raise ValueError(f"Unsupported model_type: {model_type!r} (expected 'chat' or 'completion')")

        if self._llm is None:
            try:
                # Only forward num_ctx when the user actually configured it.
                model_kwargs = {}
                if self.config.num_ctx is not None:
                    model_kwargs["num_ctx"] = self.config.num_ctx

                llm_cls = ChatOllama if model_type == "chat" else OllamaLLM
                self._llm = llm_cls(
                    base_url=self.config.base_url,
                    model=self.config.model,
                    temperature=self.config.temperature,
                    timeout=self.config.timeout,
                    **model_kwargs,
                )
                logger.info("Ollama LLM instance created successfully")
            except Exception as e:
                logger.error(f"Failed to create Ollama LLM: {e}")
                raise

        return self._llm

    def get_model_name(self) -> str:
        """Get the model name.

        Returns:
            str: The model name
        """
        return self.config.model

893 

894 

class OpenAIProvider:
    """
    OpenAI provider implementation (non-Azure).

    Manages connection and interaction with OpenAI API or OpenAI-compatible endpoints.

    Attributes:
        config: OpenAI configuration object.

    Examples:
        >>> config = OpenAIConfig(
        ...     api_key="sk-...",
        ...     model="gpt-4"
        ... )
        >>> provider = OpenAIProvider(config)
        >>> provider.get_model_name()
        'gpt-4'

    Note:
        The LLM instance is lazily initialized on first access for
        improved startup performance.
    """

    def __init__(self, config: OpenAIConfig):
        """
        Initialize OpenAI provider.

        Args:
            config: OpenAI configuration with API key and settings.

        Examples:
            >>> config = OpenAIConfig(
            ...     api_key="sk-...",
            ...     model="gpt-4"
            ... )
            >>> provider = OpenAIProvider(config)
        """
        self.config = config
        self._llm = None
        logger.info(f"Initializing OpenAI provider with model: {config.model}")

    def get_llm(self, model_type="chat") -> Union[ChatOpenAI, OpenAI]:
        """
        Get OpenAI LLM instance with lazy initialization.

        Creates and caches the OpenAI model instance on first call.
        Subsequent calls return the cached instance (regardless of the
        ``model_type`` argument passed on later calls).

        Args:
            model_type: LLM inference model type: 'chat' or 'completion'.

        Returns:
            Union[ChatOpenAI, OpenAI]: Configured OpenAI model.

        Raises:
            ValueError: If ``model_type`` is not 'chat' or 'completion'.
            Exception: If LLM initialization fails (e.g., invalid credentials).

        Examples:
            >>> config = OpenAIConfig(
            ...     api_key="sk-...",
            ...     model="gpt-4"
            ... )
            >>> provider = OpenAIProvider(config)
            >>> # llm = provider.get_llm() # Returns ChatOpenAI instance
        """
        if self._llm is None:
            try:
                kwargs = {
                    "api_key": self.config.api_key,
                    "model": self.config.model,
                    "temperature": self.config.temperature,
                    "max_tokens": self.config.max_tokens,
                    "timeout": self.config.timeout,
                    "max_retries": self.config.max_retries,
                }

                # base_url is optional; only set for OpenAI-compatible endpoints.
                if self.config.base_url:
                    kwargs["base_url"] = self.config.base_url

                # add default headers if present
                if self.config.default_headers is not None:
                    kwargs["default_headers"] = self.config.default_headers

                if model_type == "chat":
                    self._llm = ChatOpenAI(**kwargs)
                elif model_type == "completion":
                    self._llm = OpenAI(**kwargs)
                else:
                    # Previously an unrecognized model_type fell through and
                    # this method silently returned None; fail fast instead.
                    raise ValueError(f"Unsupported model_type: {model_type!r}. Expected 'chat' or 'completion'.")

                logger.info("OpenAI LLM instance created successfully")
            except Exception as e:
                logger.error(f"Failed to create OpenAI LLM: {e}")
                raise

        return self._llm

    def get_model_name(self) -> str:
        """
        Get the OpenAI model name.

        Returns:
            str: The model name configured for this provider.

        Examples:
            >>> config = OpenAIConfig(
            ...     api_key="sk-...",
            ...     model="gpt-4"
            ... )
            >>> provider = OpenAIProvider(config)
            >>> provider.get_model_name()
            'gpt-4'
        """
        return self.config.model

1007 

1008 

class AnthropicProvider:
    """
    Anthropic Claude provider implementation.

    Manages connection and interaction with Anthropic's Claude API.

    Attributes:
        config: Anthropic configuration object.

    Examples:
        >>> config = AnthropicConfig( # doctest: +SKIP
        ...     api_key="sk-ant-...",
        ...     model="claude-3-5-sonnet-20241022"
        ... )
        >>> provider = AnthropicProvider(config) # doctest: +SKIP
        >>> provider.get_model_name() # doctest: +SKIP
        'claude-3-5-sonnet-20241022'

    Note:
        Requires langchain-anthropic package to be installed.
    """

    def __init__(self, config: AnthropicConfig):
        """
        Initialize Anthropic provider.

        Args:
            config: Anthropic configuration with API key and settings.

        Raises:
            ImportError: If langchain-anthropic is not installed.

        Examples:
            >>> config = AnthropicConfig( # doctest: +SKIP
            ...     api_key="sk-ant-...",
            ...     model="claude-3-5-sonnet-20241022"
            ... )
            >>> provider = AnthropicProvider(config) # doctest: +SKIP
        """
        # Guard early: the langchain-anthropic import is optional at module level.
        if not _ANTHROPIC_AVAILABLE:
            raise ImportError("Anthropic provider requires langchain-anthropic package. Install it with: pip install langchain-anthropic")

        self.config = config
        self._llm = None
        logger.info(f"Initializing Anthropic provider with model: {config.model}")

    def get_llm(self, model_type: str = "chat") -> Union[ChatAnthropic, AnthropicLLM]:
        """
        Get Anthropic LLM instance with lazy initialization.

        Creates and caches the Anthropic model instance on first call.
        Subsequent calls return the cached instance (regardless of the
        ``model_type`` argument passed on later calls).

        Args:
            model_type: LLM inference model type: 'chat' or 'completion'.

        Returns:
            Union[ChatAnthropic, AnthropicLLM]: Configured Anthropic model.

        Raises:
            ValueError: If ``model_type`` is not 'chat' or 'completion'.
            Exception: If LLM initialization fails (e.g., invalid credentials).

        Examples:
            >>> config = AnthropicConfig( # doctest: +SKIP
            ...     api_key="sk-ant-...",
            ...     model="claude-3-5-sonnet-20241022"
            ... )
            >>> provider = AnthropicProvider(config) # doctest: +SKIP
            >>> # llm = provider.get_llm() # Returns ChatAnthropic instance
        """
        if self._llm is None:
            try:
                if model_type == "chat":
                    self._llm = ChatAnthropic(
                        api_key=self.config.api_key,
                        model=self.config.model,
                        temperature=self.config.temperature,
                        max_tokens=self.config.max_tokens,
                        timeout=self.config.timeout,
                        max_retries=self.config.max_retries,
                    )
                elif model_type == "completion":
                    self._llm = AnthropicLLM(
                        api_key=self.config.api_key,
                        model=self.config.model,
                        temperature=self.config.temperature,
                        max_tokens=self.config.max_tokens,
                        timeout=self.config.timeout,
                        max_retries=self.config.max_retries,
                    )
                else:
                    # Previously an unrecognized model_type fell through and
                    # this method silently returned None; fail fast instead.
                    raise ValueError(f"Unsupported model_type: {model_type!r}. Expected 'chat' or 'completion'.")
                logger.info("Anthropic LLM instance created successfully")
            except Exception as e:
                logger.error(f"Failed to create Anthropic LLM: {e}")
                raise

        return self._llm

    def get_model_name(self) -> str:
        """
        Get the Anthropic model name.

        Returns:
            str: The model name configured for this provider.

        Examples:
            >>> config = AnthropicConfig( # doctest: +SKIP
            ...     api_key="sk-ant-...",
            ...     model="claude-3-5-sonnet-20241022"
            ... )
            >>> provider = AnthropicProvider(config) # doctest: +SKIP
            >>> provider.get_model_name() # doctest: +SKIP
            'claude-3-5-sonnet-20241022'
        """
        return self.config.model

1123 

1124 

class AWSBedrockProvider:
    """
    AWS Bedrock provider implementation.

    Manages connection and interaction with AWS Bedrock LLM services.

    Attributes:
        config: AWS Bedrock configuration object.

    Examples:
        >>> config = AWSBedrockConfig( # doctest: +SKIP
        ...     model_id="anthropic.claude-v2",
        ...     region_name="us-east-1"
        ... )
        >>> provider = AWSBedrockProvider(config) # doctest: +SKIP
        >>> provider.get_model_name() # doctest: +SKIP
        'anthropic.claude-v2'

    Note:
        Requires langchain-aws package and boto3 to be installed.
        Uses AWS default credential chain if credentials not explicitly provided.
    """

    def __init__(self, config: AWSBedrockConfig):
        """
        Initialize AWS Bedrock provider.

        Args:
            config: AWS Bedrock configuration with model ID and settings.

        Raises:
            ImportError: If langchain-aws is not installed.

        Examples:
            >>> config = AWSBedrockConfig( # doctest: +SKIP
            ...     model_id="anthropic.claude-v2",
            ...     region_name="us-east-1"
            ... )
            >>> provider = AWSBedrockProvider(config) # doctest: +SKIP
        """
        # Guard early: the langchain-aws import is optional at module level.
        if not _BEDROCK_AVAILABLE:
            raise ImportError("AWS Bedrock provider requires langchain-aws package. Install it with: pip install langchain-aws boto3")

        self.config = config
        self._llm = None
        logger.info(f"Initializing AWS Bedrock provider with model: {config.model_id}")

    def get_llm(self, model_type: str = "chat") -> Union[ChatBedrock, BedrockLLM]:
        """
        Get AWS Bedrock LLM instance with lazy initialization.

        Creates and caches the Bedrock model instance on first call.
        Subsequent calls return the cached instance (regardless of the
        ``model_type`` argument passed on later calls).

        Args:
            model_type: LLM inference model type: 'chat' or 'completion'.

        Returns:
            Union[ChatBedrock, BedrockLLM]: Configured AWS Bedrock model.

        Raises:
            ValueError: If ``model_type`` is not 'chat' or 'completion'.
            Exception: If LLM initialization fails (e.g., invalid credentials, permissions).

        Examples:
            >>> config = AWSBedrockConfig( # doctest: +SKIP
            ...     model_id="anthropic.claude-v2",
            ...     region_name="us-east-1"
            ... )
            >>> provider = AWSBedrockProvider(config) # doctest: +SKIP
            >>> # llm = provider.get_llm() # Returns ChatBedrock instance
        """
        if self._llm is None:
            try:
                # Build credentials dict if provided; when omitted, boto3's
                # default credential chain is used.
                credentials_kwargs = {}
                if self.config.aws_access_key_id:
                    credentials_kwargs["aws_access_key_id"] = self.config.aws_access_key_id
                if self.config.aws_secret_access_key:
                    credentials_kwargs["aws_secret_access_key"] = self.config.aws_secret_access_key
                if self.config.aws_session_token:
                    credentials_kwargs["aws_session_token"] = self.config.aws_session_token

                if model_type == "chat":
                    self._llm = ChatBedrock(
                        model_id=self.config.model_id,
                        region_name=self.config.region_name,
                        model_kwargs={
                            "temperature": self.config.temperature,
                            "max_tokens": self.config.max_tokens,
                        },
                        **credentials_kwargs,
                    )
                elif model_type == "completion":
                    self._llm = BedrockLLM(
                        model_id=self.config.model_id,
                        region_name=self.config.region_name,
                        model_kwargs={
                            "temperature": self.config.temperature,
                            "max_tokens": self.config.max_tokens,
                        },
                        **credentials_kwargs,
                    )
                else:
                    # Previously an unrecognized model_type fell through and
                    # this method silently returned None; fail fast instead.
                    raise ValueError(f"Unsupported model_type: {model_type!r}. Expected 'chat' or 'completion'.")
                logger.info("AWS Bedrock LLM instance created successfully")
            except Exception as e:
                logger.error(f"Failed to create AWS Bedrock LLM: {e}")
                raise

        return self._llm

    def get_model_name(self) -> str:
        """
        Get the AWS Bedrock model ID.

        Returns:
            str: The model ID configured for this provider.

        Examples:
            >>> config = AWSBedrockConfig( # doctest: +SKIP
            ...     model_id="anthropic.claude-v2",
            ...     region_name="us-east-1"
            ... )
            >>> provider = AWSBedrockProvider(config) # doctest: +SKIP
            >>> provider.get_model_name() # doctest: +SKIP
            'anthropic.claude-v2'
        """
        return self.config.model_id

1251 

1252 

class WatsonxProvider:
    """
    IBM watsonx.ai provider implementation.

    Manages connection and interaction with IBM watsonx.ai services.

    Attributes:
        config: IBM watsonx.ai configuration object.

    Examples:
        >>> config = WatsonxConfig( # doctest: +SKIP
        ...     api_key="key",
        ...     url="https://us-south.ml.cloud.ibm.com",
        ...     project_id="project-id",
        ...     model_id="ibm/granite-13b-chat-v2"
        ... )
        >>> provider = WatsonxProvider(config) # doctest: +SKIP
        >>> provider.get_model_name() # doctest: +SKIP
        'ibm/granite-13b-chat-v2'

    Note:
        Requires langchain-ibm package to be installed.
    """

    def __init__(self, config: WatsonxConfig):
        """
        Initialize IBM watsonx.ai provider.

        Args:
            config: IBM watsonx.ai configuration with credentials and settings.

        Raises:
            ImportError: If langchain-ibm is not installed.

        Examples:
            >>> config = WatsonxConfig( # doctest: +SKIP
            ...     api_key="key",
            ...     url="https://us-south.ml.cloud.ibm.com",
            ...     project_id="project-id",
            ...     model_id="ibm/granite-13b-chat-v2"
            ... )
            >>> provider = WatsonxProvider(config) # doctest: +SKIP
        """
        # Guard early: the langchain-ibm import is optional at module level.
        if not _WATSONX_AVAILABLE:
            raise ImportError("IBM watsonx.ai provider requires langchain-ibm package. Install it with: pip install langchain-ibm")
        self.config = config
        # NOTE: this provider caches on `self.llm` (public) unlike the other
        # providers' `self._llm`; kept as-is for backward compatibility.
        self.llm = None
        logger.info(f"Initializing IBM watsonx.ai provider with model {config.model_id}")

    def get_llm(self, model_type="chat") -> Union[WatsonxLLM, ChatWatsonx]:
        """
        Get IBM watsonx.ai LLM instance with lazy initialization.

        Creates and caches the watsonx model instance on first call.
        Subsequent calls return the cached instance (regardless of the
        ``model_type`` argument passed on later calls).

        Args:
            model_type: LLM inference model type: 'chat' or 'completion'.

        Returns:
            Union[WatsonxLLM, ChatWatsonx]: Configured IBM watsonx.ai model.

        Raises:
            ValueError: If ``model_type`` is not 'chat' or 'completion'.
            Exception: If LLM initialization fails (e.g., invalid credentials).

        Examples:
            >>> config = WatsonxConfig( # doctest: +SKIP
            ...     api_key="key",
            ...     url="https://us-south.ml.cloud.ibm.com",
            ...     project_id="project-id",
            ...     model_id="ibm/granite-13b-chat-v2"
            ... )
            >>> provider = WatsonxProvider(config) # doctest: +SKIP
            >>> #llm = provider.get_llm() # Returns WatsonxLLM instance
        """
        if self.llm is None:
            try:
                # Build generation parameters dict
                params = {
                    "decoding_method": self.config.decoding_method,
                    "temperature": self.config.temperature,
                    "max_new_tokens": self.config.max_new_tokens,
                    "min_new_tokens": self.config.min_new_tokens,
                }

                if self.config.top_k is not None:
                    params["top_k"] = self.config.top_k
                if self.config.top_p is not None:
                    params["top_p"] = self.config.top_p
                if model_type == "completion":
                    # Initialize WatsonxLLM
                    self.llm = WatsonxLLM(
                        apikey=self.config.api_key,
                        url=self.config.url,
                        project_id=self.config.project_id,
                        model_id=self.config.model_id,
                        params=params,
                    )
                elif model_type == "chat":
                    # Initialize Chat WatsonxLLM
                    self.llm = ChatWatsonx(
                        apikey=self.config.api_key,
                        url=self.config.url,
                        project_id=self.config.project_id,
                        model_id=self.config.model_id,
                        params=params,
                    )
                else:
                    # Previously an unrecognized model_type fell through and
                    # this method silently returned None; fail fast instead.
                    raise ValueError(f"Unsupported model_type: {model_type!r}. Expected 'chat' or 'completion'.")
                logger.info("IBM watsonx.ai LLM instance created successfully")
            except Exception as e:
                logger.error(f"Failed to create IBM watsonx.ai LLM: {e}")
                raise
        return self.llm

    def get_model_name(self) -> str:
        """
        Get the IBM watsonx.ai model ID.

        Returns:
            str: The model ID configured for this provider.

        Examples:
            >>> config = WatsonxConfig( # doctest: +SKIP
            ...     api_key="key",
            ...     url="https://us-south.ml.cloud.ibm.com",
            ...     project_id="project-id",
            ...     model_id="ibm/granite-13b-chat-v2"
            ... )
            >>> provider = WatsonxProvider(config) # doctest: +SKIP
            >>> provider.get_model_name() # doctest: +SKIP
            'ibm/granite-13b-chat-v2'
        """
        return self.config.model_id

1385 

1386 

class GatewayProvider:
    """
    Gateway provider implementation for using models configured in LLM Settings.

    Routes LLM requests through the gateway's configured providers, allowing
    users to use models set up via the Admin UI's LLM Settings without needing
    to configure credentials in environment variables or API requests.

    Attributes:
        config: Gateway configuration with model ID.
        llm: Lazily initialized LLM instance.

    Examples:
        >>> config = GatewayConfig(model="gpt-4o") # doctest: +SKIP
        >>> provider = GatewayProvider(config) # doctest: +SKIP
        >>> provider.get_model_name() # doctest: +SKIP
        'gpt-4o'

    Note:
        Requires models to be configured via Admin UI -> Settings -> LLM Settings.
    """

    def __init__(self, config: GatewayConfig):
        """
        Initialize Gateway provider.

        Args:
            config: Gateway configuration with model ID and optional settings.

        Examples:
            >>> config = GatewayConfig(model="gpt-4o") # doctest: +SKIP
            >>> provider = GatewayProvider(config) # doctest: +SKIP
        """
        self.config = config
        # Cached LLM instance; populated on first get_llm() call.
        self.llm = None
        # Resolved model_id from the database, set by get_llm().
        self._model_name: Optional[str] = None
        self._underlying_provider = None
        logger.info(f"Initializing Gateway provider with model: {config.model}")

    def get_llm(self, model_type: str = "chat") -> Union[BaseChatModel, Any]:
        """
        Get LLM instance by looking up model from gateway's LLM Settings.

        Fetches the model configuration from the database, decrypts API keys,
        and creates the appropriate LangChain LLM instance based on provider type.

        Args:
            model_type: Type of model to return ('chat' or 'completion'). Defaults to 'chat'.

        Returns:
            Union[BaseChatModel, Any]: Configured LangChain chat or completion model instance.

        Raises:
            ValueError: If model not found or provider not enabled.
            ImportError: If required provider package not installed.

        Examples:
            >>> config = GatewayConfig(model="gpt-4o") # doctest: +SKIP
            >>> provider = GatewayProvider(config) # doctest: +SKIP
            >>> llm = provider.get_llm() # doctest: +SKIP

        Note:
            A single LLM instance is lazily initialized and cached; the
            model_type of the FIRST call determines the cached instance,
            and later calls return it regardless of their model_type.
        """
        # Return the cached instance if one was already built.
        if self.llm is not None:
            return self.llm

        # Import here to avoid circular imports
        # First-Party
        from mcpgateway.db import LLMModel, LLMProvider, SessionLocal  # pylint: disable=import-outside-toplevel
        from mcpgateway.services.llm_provider_service import decrypt_provider_config_for_runtime  # pylint: disable=import-outside-toplevel
        from mcpgateway.utils.services_auth import decode_auth  # pylint: disable=import-outside-toplevel

        model_id = self.config.model

        with SessionLocal() as db:
            # Try to find model by UUID first, then by model_id
            model = db.query(LLMModel).filter(LLMModel.id == model_id).first()
            if not model:
                model = db.query(LLMModel).filter(LLMModel.model_id == model_id).first()

            if not model:
                raise ValueError(f"Model '{model_id}' not found in LLM Settings. Configure it via Admin UI -> Settings -> LLM Settings.")

            if not model.enabled:
                raise ValueError(f"Model '{model.model_id}' is disabled. Enable it in LLM Settings.")

            # Get the provider
            provider = db.query(LLMProvider).filter(LLMProvider.id == model.provider_id).first()
            if not provider:
                raise ValueError(f"Provider not found for model '{model.model_id}'")

            if not provider.enabled:
                raise ValueError(f"Provider '{provider.name}' is disabled. Enable it in LLM Settings.")

            # Get decrypted API key. api_key may stay None when the provider
            # record has no stored key (e.g. local Ollama).
            api_key = None
            if provider.api_key:
                auth_data = decode_auth(provider.api_key)
                if isinstance(auth_data, dict):
                    api_key = auth_data.get("api_key")
                else:
                    api_key = auth_data

            # Store model name for get_model_name()
            self._model_name = model.model_id

            # Get temperature - use config override or provider default
            temperature = self.config.temperature if self.config.temperature is not None else (provider.default_temperature or 0.7)
            max_tokens = self.config.max_tokens or model.max_output_tokens

            # Create appropriate LLM based on provider type
            provider_type = provider.provider_type.lower()
            config = decrypt_provider_config_for_runtime(provider.config)

            # Common kwargs shared by the OpenAI-style branches below.
            kwargs: Dict[str, Any] = {
                "temperature": temperature,
                "timeout": self.config.timeout,
            }

            if provider_type == "openai":
                kwargs.update(
                    {
                        "api_key": api_key,
                        "model": model.model_id,
                        "max_tokens": max_tokens,
                    }
                )
                if provider.api_base:
                    kwargs["base_url"] = provider.api_base

                # Handle default headers: provider config wins over the
                # request-level GatewayConfig value.
                if config.get("default_headers"):
                    kwargs["default_headers"] = config["default_headers"]
                elif hasattr(self.config, "default_headers") and self.config.default_headers:  # type: ignore
                    kwargs["default_headers"] = self.config.default_headers

                if model_type == "chat":
                    self.llm = ChatOpenAI(**kwargs)
                else:
                    self.llm = OpenAI(**kwargs)

            elif provider_type == "azure_openai":
                if not provider.api_base:
                    raise ValueError("Azure OpenAI requires base_url (azure_endpoint) to be configured")

                # Deployment name defaults to the model_id when not set.
                azure_deployment = config.get("azure_deployment", model.model_id)
                api_version = config.get("api_version", "2024-05-01-preview")
                max_retries = config.get("max_retries", 2)

                kwargs.update(
                    {
                        "api_key": api_key,
                        "azure_endpoint": provider.api_base,
                        "azure_deployment": azure_deployment,
                        "api_version": api_version,
                        "model": model.model_id,
                        "max_tokens": int(max_tokens) if max_tokens is not None else None,
                        "max_retries": max_retries,
                    }
                )

                if model_type == "chat":
                    self.llm = AzureChatOpenAI(**kwargs)
                else:
                    self.llm = AzureOpenAI(**kwargs)

            elif provider_type == "anthropic":
                if not _ANTHROPIC_AVAILABLE:
                    raise ImportError("Anthropic provider requires langchain-anthropic. Install with: pip install langchain-anthropic")

                # Anthropic uses 'model_name' instead of 'model'
                anthropic_kwargs = {
                    "api_key": api_key,
                    "model_name": model.model_id,
                    "max_tokens": max_tokens or 4096,
                    "temperature": temperature,
                    "timeout": self.config.timeout,
                    "default_request_timeout": self.config.timeout,
                }

                if model_type == "chat":
                    self.llm = ChatAnthropic(**anthropic_kwargs)
                else:
                    # Generic Anthropic completion model if needed, though mostly chat used now
                    if AnthropicLLM:
                        llm_kwargs = anthropic_kwargs.copy()
                        # AnthropicLLM expects 'model' rather than 'model_name'.
                        llm_kwargs["model"] = llm_kwargs.pop("model_name")
                        self.llm = AnthropicLLM(**llm_kwargs)
                    else:
                        raise ImportError("Anthropic completion model (AnthropicLLM) not available")

            elif provider_type == "bedrock":
                if not _BEDROCK_AVAILABLE:
                    raise ImportError("AWS Bedrock provider requires langchain-aws. Install with: pip install langchain-aws boto3")

                region_name = config.get("region_name", "us-east-1")
                # Without explicit credentials, boto3's default chain applies.
                credentials_kwargs = {}
                if config.get("aws_access_key_id"):
                    credentials_kwargs["aws_access_key_id"] = config["aws_access_key_id"]
                if config.get("aws_secret_access_key"):
                    credentials_kwargs["aws_secret_access_key"] = config["aws_secret_access_key"]
                if config.get("aws_session_token"):
                    credentials_kwargs["aws_session_token"] = config["aws_session_token"]

                model_kwargs = {
                    "temperature": temperature,
                    "max_tokens": max_tokens or 4096,
                }

                if model_type == "chat":
                    self.llm = ChatBedrock(
                        model_id=model.model_id,
                        region_name=region_name,
                        model_kwargs=model_kwargs,
                        **credentials_kwargs,
                    )
                else:
                    self.llm = BedrockLLM(
                        model_id=model.model_id,
                        region_name=region_name,
                        model_kwargs=model_kwargs,
                        **credentials_kwargs,
                    )

            elif provider_type == "ollama":
                base_url = provider.api_base or "http://localhost:11434"
                num_ctx = config.get("num_ctx")

                # Explicitly construct kwargs to avoid generic unpacking issues with Pydantic models
                ollama_kwargs = {
                    "base_url": base_url,
                    "model": model.model_id,
                    "temperature": temperature,
                    "timeout": self.config.timeout,
                }
                if num_ctx:
                    ollama_kwargs["num_ctx"] = num_ctx

                if model_type == "chat":
                    self.llm = ChatOllama(**ollama_kwargs)
                else:
                    self.llm = OllamaLLM(**ollama_kwargs)

            elif provider_type == "watsonx":
                if not _WATSONX_AVAILABLE:
                    raise ImportError("IBM watsonx.ai provider requires langchain-ibm. Install with: pip install langchain-ibm")

                project_id = config.get("project_id")
                if not project_id:
                    raise ValueError("IBM watsonx.ai requires project_id in config")

                url = provider.api_base or "https://us-south.ml.cloud.ibm.com"

                params = {
                    "temperature": temperature,
                    "max_new_tokens": max_tokens or 1024,
                    "min_new_tokens": config.get("min_new_tokens", 1),
                    "decoding_method": config.get("decoding_method", "sample"),
                    "top_k": config.get("top_k", 50),
                    "top_p": config.get("top_p", 1.0),
                }

                if model_type == "chat":
                    self.llm = ChatWatsonx(
                        apikey=api_key,
                        url=url,
                        project_id=project_id,
                        model_id=model.model_id,
                        params=params,
                    )
                else:
                    self.llm = WatsonxLLM(
                        apikey=api_key,
                        url=url,
                        project_id=project_id,
                        model_id=model.model_id,
                        params=params,
                    )

            elif provider_type == "openai_compatible":
                if not provider.api_base:
                    raise ValueError("OpenAI-compatible provider requires base_url to be configured")

                kwargs.update(
                    {
                        "api_key": api_key or "no-key-required",
                        "model": model.model_id,
                        "base_url": provider.api_base,
                        "max_tokens": max_tokens,
                    }
                )

                if model_type == "chat":
                    self.llm = ChatOpenAI(**kwargs)
                else:
                    self.llm = OpenAI(**kwargs)

            else:
                raise ValueError(f"Unsupported LLM provider: {provider_type}")

        logger.info(f"Gateway provider created LLM instance for model: {model.model_id} via {provider_type}")
        return self.llm

    def get_model_name(self) -> str:
        """
        Get the model name.

        Returns:
            str: The model name/ID. Falls back to the configured model value
            when get_llm() has not yet resolved a model from the database.

        Examples:
            >>> config = GatewayConfig(model="gpt-4o") # doctest: +SKIP
            >>> provider = GatewayProvider(config) # doctest: +SKIP
            >>> provider.get_model_name() # doctest: +SKIP
            'gpt-4o'
        """
        return self._model_name or self.config.model

1706 

1707 

class LLMProviderFactory:
    """
    Factory for creating LLM providers.

    Implements the Factory pattern to instantiate the appropriate LLM provider
    based on configuration, abstracting away provider-specific initialization.

    Examples:
        >>> config = LLMConfig(
        ...     provider="ollama",
        ...     config=OllamaConfig(model="llama2")
        ... )
        >>> provider = LLMProviderFactory.create(config)
        >>> provider.get_model_name()
        'llama2'

    Note:
        This factory supports dynamic provider registration and ensures
        type safety through the LLMConfig discriminated union.
    """

    @staticmethod
    def create(llm_config: LLMConfig) -> Union[AzureOpenAIProvider, OpenAIProvider, AnthropicProvider, AWSBedrockProvider, OllamaProvider, WatsonxProvider, GatewayProvider]:
        """
        Create an LLM provider based on configuration.

        Args:
            llm_config: LLM configuration specifying provider type and settings.

        Returns:
            Union[AzureOpenAIProvider, OpenAIProvider, AnthropicProvider, AWSBedrockProvider, OllamaProvider, WatsonxProvider, GatewayProvider]: Instantiated provider.

        Raises:
            ValueError: If provider type is not supported.
            ImportError: If required provider package is not installed.

        Examples:
            >>> # Create Azure OpenAI provider
            >>> config = LLMConfig(
            ...     provider="azure_openai",
            ...     config=AzureOpenAIConfig(
            ...         api_key="key",
            ...         azure_endpoint="https://example.com/",
            ...         azure_deployment="gpt-4"
            ...     )
            ... )
            >>> provider = LLMProviderFactory.create(config)
            >>> isinstance(provider, AzureOpenAIProvider)
            True

            >>> # Create OpenAI provider
            >>> config = LLMConfig(
            ...     provider="openai",
            ...     config=OpenAIConfig(
            ...         api_key="sk-...",
            ...         model="gpt-4"
            ...     )
            ... )
            >>> provider = LLMProviderFactory.create(config)
            >>> isinstance(provider, OpenAIProvider)
            True

            >>> # Create Ollama provider
            >>> config = LLMConfig(
            ...     provider="ollama",
            ...     config=OllamaConfig(model="llama2")
            ... )
            >>> provider = LLMProviderFactory.create(config)
            >>> isinstance(provider, OllamaProvider)
            True
        """
        # Registry mapping the discriminator string to its provider class.
        registry = {
            "azure_openai": AzureOpenAIProvider,
            "openai": OpenAIProvider,
            "anthropic": AnthropicProvider,
            "aws_bedrock": AWSBedrockProvider,
            "ollama": OllamaProvider,
            "watsonx": WatsonxProvider,
            "gateway": GatewayProvider,
        }

        if llm_config.provider not in registry:
            raise ValueError(f"Unsupported LLM provider: {llm_config.provider}. Supported providers: {list(registry.keys())}")

        logger.info(f"Creating LLM provider: {llm_config.provider}")
        return registry[llm_config.provider](llm_config.config)

1796 

1797 

1798# ==================== CHAT HISTORY MANAGER ==================== 

1799 

1800 

class ChatHistoryManager:
    """
    Centralized chat history management with Redis and in-memory fallback.

    Provides a unified interface for storing and retrieving chat histories across
    multiple workers using Redis, with automatic fallback to in-memory storage
    when Redis is not available.

    This class eliminates duplication between router and service layers by
    providing a single source of truth for all chat history operations.

    Attributes:
        redis_client: Optional Redis async client for distributed storage.
        max_messages: Maximum number of messages to retain per user.
        ttl: Time-to-live for Redis entries in seconds.
        _memory_store: In-memory dict fallback when Redis unavailable.

    Examples:
        >>> import asyncio
        >>> manager = ChatHistoryManager(redis_client=None, max_messages=50)
        >>> asyncio.run(manager.save_history("user123", [{"role": "user", "content": "Hello"}]))
        >>> asyncio.run(manager.get_history("user123"))
        [{'role': 'user', 'content': 'Hello'}]

    Note:
        Thread-safe for Redis operations. In-memory mode suitable for
        single-worker deployments only.
    """

    def __init__(self, redis_client: Optional[Any] = None, max_messages: int = 50, ttl: int = 3600):
        """
        Initialize chat history manager.

        Args:
            redis_client: Optional Redis async client. If None, uses in-memory storage.
            max_messages: Maximum messages to retain per user (default: 50).
            ttl: Time-to-live for Redis entries in seconds (default: 3600).

        Examples:
            >>> manager = ChatHistoryManager(redis_client=None, max_messages=100)
            >>> manager.max_messages
            100
            >>> manager.ttl
            3600
        """
        self.redis_client = redis_client
        self.max_messages = max_messages
        self.ttl = ttl
        # Fallback store (user_id -> message list); used only when no Redis client is given.
        self._memory_store: Dict[str, List[Dict[str, str]]] = {}

        if redis_client:
            logger.info("ChatHistoryManager initialized with Redis backend")
        else:
            logger.info("ChatHistoryManager initialized with in-memory backend")

    def _history_key(self, user_id: str) -> str:
        """
        Generate Redis key for user's chat history.

        Args:
            user_id: User identifier.

        Returns:
            str: Redis key string.

        Examples:
            >>> manager = ChatHistoryManager()
            >>> manager._history_key("user123")
            'chat_history:user123'
        """
        return f"chat_history:{user_id}"

    async def get_history(self, user_id: str) -> List[Dict[str, str]]:
        """
        Retrieve chat history for a user.

        Fetches history from Redis if available, otherwise from in-memory store.

        Args:
            user_id: User identifier.

        Returns:
            List[Dict[str, str]]: List of message dictionaries with 'role' and 'content' keys.
                Returns empty list if no history exists. The returned list is safe to
                mutate; it is never the internal storage object itself.

        Examples:
            >>> import asyncio
            >>> manager = ChatHistoryManager()
            >>> asyncio.run(manager.get_history("unknown-user"))
            []

        Note:
            Automatically handles JSON deserialization errors by returning empty list.
        """
        if self.redis_client:
            try:
                data = await self.redis_client.get(self._history_key(user_id))
                if not data:
                    return []
                return orjson.loads(data)
            except orjson.JSONDecodeError:
                logger.warning(f"Failed to decode chat history for user {user_id}")
                return []
            except Exception as e:
                logger.error(f"Error retrieving chat history from Redis for user {user_id}: {e}")
                return []
        # Return a shallow copy so callers cannot mutate the internal store in place.
        # (Previously the stored list itself was returned, leaking mutable state.)
        return list(self._memory_store.get(user_id, []))

    async def save_history(self, user_id: str, history: List[Dict[str, str]]) -> None:
        """
        Save chat history for a user.

        Stores history in Redis (with TTL) if available, otherwise in memory.
        Automatically trims history to max_messages before saving.

        Args:
            user_id: User identifier.
            history: List of message dictionaries to save.

        Examples:
            >>> import asyncio
            >>> manager = ChatHistoryManager(max_messages=50)
            >>> asyncio.run(manager.save_history("user123", [{"role": "user", "content": "Hello"}]))

        Note:
            History is automatically trimmed to max_messages limit before storage.
            Redis failures are logged and swallowed (best-effort persistence).
        """
        # Trim history before saving
        trimmed = self._trim_messages(history)

        if self.redis_client:
            try:
                await self.redis_client.set(self._history_key(user_id), orjson.dumps(trimmed), ex=self.ttl)
            except Exception as e:
                logger.error(f"Error saving chat history to Redis for user {user_id}: {e}")
        else:
            self._memory_store[user_id] = trimmed

    async def append_message(self, user_id: str, role: str, content: str) -> None:
        """
        Append a single message to user's chat history.

        Convenience method that fetches current history, appends the message,
        trims if needed, and saves back.

        Args:
            user_id: User identifier.
            role: Message role ('user' or 'assistant').
            content: Message content text.

        Examples:
            >>> import asyncio
            >>> manager = ChatHistoryManager()
            >>> asyncio.run(manager.append_message("user123", "user", "Hello!"))

        Note:
            This method performs a read-modify-write operation which may
            not be atomic in distributed environments.
        """
        history = await self.get_history(user_id)
        history.append({"role": role, "content": content})
        await self.save_history(user_id, history)

    async def clear_history(self, user_id: str) -> None:
        """
        Clear all chat history for a user.

        Deletes history from Redis or memory store.

        Args:
            user_id: User identifier.

        Examples:
            >>> import asyncio
            >>> manager = ChatHistoryManager()
            >>> asyncio.run(manager.clear_history("user123"))

        Note:
            This operation cannot be undone. Redis failures are logged, not raised.
        """
        if self.redis_client:
            try:
                await self.redis_client.delete(self._history_key(user_id))
            except Exception as e:
                logger.error(f"Error clearing chat history from Redis for user {user_id}: {e}")
        else:
            self._memory_store.pop(user_id, None)

    def _trim_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Trim message list to max_messages limit.

        Keeps the most recent messages up to max_messages count.

        Args:
            messages: List of message dictionaries.

        Returns:
            List[Dict[str, str]]: Trimmed message list.

        Examples:
            >>> manager = ChatHistoryManager(max_messages=2)
            >>> messages = [
            ...     {"role": "user", "content": "1"},
            ...     {"role": "assistant", "content": "2"},
            ...     {"role": "user", "content": "3"}
            ... ]
            >>> trimmed = manager._trim_messages(messages)
            >>> len(trimmed)
            2
            >>> trimmed[0]["content"]
            '2'
        """
        if len(messages) > self.max_messages:
            return messages[-self.max_messages :]
        return messages

    async def get_langchain_messages(self, user_id: str) -> List[BaseMessage]:
        """
        Get chat history as LangChain message objects.

        Converts stored history dictionaries to LangChain HumanMessage and
        AIMessage objects for use with LangChain agents. Messages with any
        other role (or no role) are silently skipped.

        Args:
            user_id: User identifier.

        Returns:
            List[BaseMessage]: List of LangChain message objects.

        Examples:
            >>> import asyncio
            >>> manager = ChatHistoryManager()
            >>> asyncio.run(manager.get_langchain_messages("unknown-user"))
            []

        Note:
            Returns empty list if LangChain is not available or history is empty.
        """
        if not _LLMCHAT_AVAILABLE:
            return []

        history = await self.get_history(user_id)
        lc_messages = []

        for msg in history:
            role = msg.get("role")
            content = msg.get("content", "")

            if role == "user":
                lc_messages.append(HumanMessage(content=content))
            elif role == "assistant":
                lc_messages.append(AIMessage(content=content))

        return lc_messages

2062 

2063 

2064# ==================== MCP CLIENT ==================== 

2065 

2066 

class MCPClient:
    """
    Manages MCP server connections and tool loading.

    Provides a high-level interface for connecting to MCP servers, retrieving
    available tools, and managing connection health. Supports multiple transport
    protocols including HTTP, SSE, and stdio.

    Attributes:
        config: MCP server configuration.

    Examples:
        >>> import asyncio
        >>> config = MCPServerConfig(
        ...     url="https://mcp-server.example.com/mcp",
        ...     transport="streamable_http"
        ... )
        >>> client = MCPClient(config)
        >>> client.is_connected
        False
        >>> # asyncio.run(client.connect())
        >>> # tools = asyncio.run(client.get_tools())

    Note:
        All methods are async and should be called using asyncio or within
        an async context.
    """

    def __init__(self, config: MCPServerConfig):
        """
        Initialize MCP client.

        Args:
            config: MCP server configuration with connection parameters.

        Examples:
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> client.config.transport
            'streamable_http'
        """
        self.config = config
        self._client: Optional[MultiServerMCPClient] = None
        # None means "never loaded"; an empty list is a valid cached result.
        self._tools: Optional[List[BaseTool]] = None
        self._connected = False
        logger.info(f"MCP client initialized with transport: {config.transport}")

    async def connect(self) -> None:
        """
        Connect to the MCP server.

        Establishes connection to the configured MCP server using the specified
        transport protocol. Subsequent calls are no-ops if already connected.

        Raises:
            ConnectionError: If connection to MCP server fails, or if the
                optional llmchat dependencies are not installed.

        Examples:
            >>> import asyncio
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> # asyncio.run(client.connect())
            >>> # client.is_connected -> True

        Note:
            Connection is idempotent - calling multiple times is safe.
        """
        if self._connected:
            logger.warning("MCP client already connected")
            return

        # Fail fast with a clear error when optional deps are missing.
        # Previously this only logged and then crashed calling None(...),
        # surfacing a confusing "'NoneType' object is not callable" error.
        if not MultiServerMCPClient:
            logger.error("Some dependencies are missing. Install those with: pip install '.[llmchat]'")
            raise ConnectionError("Failed to connect to MCP server: optional llmchat dependencies are not installed. Install with: pip install '.[llmchat]'")

        try:
            logger.info(f"Connecting to MCP server via {self.config.transport}...")

            # Build server configuration for MultiServerMCPClient
            server_config = {
                "transport": self.config.transport,
            }

            if self.config.transport in ["streamable_http", "sse"]:
                server_config["url"] = self.config.url
                if self.config.headers:
                    server_config["headers"] = self.config.headers
            elif self.config.transport == "stdio":
                server_config["command"] = self.config.command
                if self.config.args:
                    server_config["args"] = self.config.args

            # Create MultiServerMCPClient with single server
            self._client = MultiServerMCPClient({"default": server_config})
            self._connected = True
            logger.info("Successfully connected to MCP server")

        except Exception as e:
            logger.error(f"Failed to connect to MCP server: {e}")
            self._connected = False
            raise ConnectionError(f"Failed to connect to MCP server: {e}") from e

    async def disconnect(self) -> None:
        """
        Disconnect from the MCP server.

        Cleanly closes the connection and releases resources. Safe to call
        even if not connected.

        Raises:
            Exception: If cleanup operations fail.

        Examples:
            >>> import asyncio
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> # asyncio.run(client.connect())
            >>> # asyncio.run(client.disconnect())
            >>> # client.is_connected -> False

        Note:
            Clears cached tools upon disconnection.
        """
        if not self._connected:
            logger.warning("MCP client not connected")
            return

        try:
            if self._client:
                # MultiServerMCPClient manages connections internally
                self._client = None

            self._connected = False
            self._tools = None
            logger.info("Disconnected from MCP server")

        except Exception as e:
            logger.error(f"Error during disconnect: {e}")
            raise

    async def get_tools(self, force_reload: bool = False) -> List[BaseTool]:
        """
        Get tools from the MCP server.

        Retrieves available tools from the connected MCP server. Results are
        cached unless force_reload is True.

        Args:
            force_reload: Force reload tools even if cached (default: False).

        Returns:
            List[BaseTool]: List of available tools from the server.

        Raises:
            ConnectionError: If not connected to MCP server.
            Exception: If tool loading fails.

        Examples:
            >>> import asyncio
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> # asyncio.run(client.connect())
            >>> # tools = asyncio.run(client.get_tools())
            >>> # len(tools) >= 0 -> True

        Note:
            Tools are cached after first successful load for performance.
            An empty tool list is a valid cached result and is not re-fetched.
        """
        if not self._connected or not self._client:
            raise ConnectionError("Not connected to MCP server. Call connect() first.")

        # Use `is not None` so a legitimately empty tool list is also cached;
        # the previous truthiness check re-fetched on every call for empty servers.
        if self._tools is not None and not force_reload:
            logger.debug(f"Returning {len(self._tools)} cached tools")
            return self._tools

        try:
            logger.info("Loading tools from MCP server...")
            self._tools = await self._client.get_tools()
            logger.info(f"Successfully loaded {len(self._tools)} tools")
            return self._tools

        except Exception as e:
            logger.error(f"Failed to load tools: {e}")
            raise

    @property
    def is_connected(self) -> bool:
        """
        Check if client is connected.

        Returns:
            bool: True if connected to MCP server, False otherwise.

        Examples:
            >>> config = MCPServerConfig(
            ...     url="https://example.com/mcp",
            ...     transport="streamable_http"
            ... )
            >>> client = MCPClient(config)
            >>> client.is_connected
            False
        """
        return self._connected

2281 

2282 

2283# ==================== MCP CHAT SERVICE ==================== 

2284 

2285 

2286class MCPChatService: 

2287 """ 

2288 Main chat service for MCP client backend. 

2289 Orchestrates chat sessions with LLM and MCP server integration. 

2290 

2291 Provides a high-level interface for managing conversational AI sessions 

2292 that combine LLM capabilities with MCP server tools. Handles conversation 

2293 history management, tool execution, and streaming responses. 

2294 

2295 This service integrates: 

2296 - LLM providers (Azure OpenAI, OpenAI, Anthropic, AWS Bedrock, Ollama) 

2297 - MCP server tools 

2298 - Centralized chat history management (Redis or in-memory) 

2299 - Streaming and non-streaming response modes 

2300 

2301 Attributes: 

2302 config: Complete MCP client configuration. 

2303 user_id: Optional user identifier for history management. 

2304 

2305 Examples: 

2306 >>> import asyncio 

2307 >>> config = MCPClientConfig( 

2308 ... mcp_server=MCPServerConfig( 

2309 ... url="https://example.com/mcp", 

2310 ... transport="streamable_http" 

2311 ... ), 

2312 ... llm=LLMConfig( 

2313 ... provider="ollama", 

2314 ... config=OllamaConfig(model="llama2") 

2315 ... ) 

2316 ... ) 

2317 >>> service = MCPChatService(config) 

2318 >>> service.is_initialized 

2319 False 

2320 >>> # asyncio.run(service.initialize()) 

2321 

2322 Note: 

2323 Must call initialize() before using chat methods. 

2324 """ 

2325 

2326 def __init__(self, config: MCPClientConfig, user_id: Optional[str] = None, redis_client: Optional[Any] = None): 

2327 """ 

2328 Initialize MCP chat service. 

2329 

2330 Args: 

2331 config: Complete MCP client configuration. 

2332 user_id: Optional user identifier for chat history management. 

2333 redis_client: Optional Redis client for distributed history storage. 

2334 

2335 Examples: 

2336 >>> config = MCPClientConfig( 

2337 ... mcp_server=MCPServerConfig( 

2338 ... url="https://example.com/mcp", 

2339 ... transport="streamable_http" 

2340 ... ), 

2341 ... llm=LLMConfig( 

2342 ... provider="ollama", 

2343 ... config=OllamaConfig(model="llama2") 

2344 ... ) 

2345 ... ) 

2346 >>> service = MCPChatService(config, user_id="user123") 

2347 >>> service.user_id 

2348 'user123' 

2349 """ 

2350 self.config = config 

2351 self.user_id = user_id 

2352 self.mcp_client = MCPClient(config.mcp_server) 

2353 self.llm_provider = LLMProviderFactory.create(config.llm) 

2354 

2355 # Initialize centralized chat history manager 

2356 self.history_manager = ChatHistoryManager(redis_client=redis_client, max_messages=config.chat_history_max_messages, ttl=settings.llmchat_chat_history_ttl) 

2357 

2358 self._agent = None 

2359 self._initialized = False 

2360 self._tools: List[BaseTool] = [] 

2361 

2362 logger.info(f"MCPChatService initialized for user: {user_id or 'anonymous'}") 

2363 

2364 async def initialize(self) -> None: 

2365 """ 

2366 Initialize the chat service. 

2367 

2368 Connects to MCP server, loads tools, initializes LLM, and creates the 

2369 conversational agent. Must be called before using chat functionality. 

2370 

2371 Raises: 

2372 ConnectionError: If MCP server connection fails. 

2373 Exception: If initialization fails. 

2374 

2375 Examples: 

2376 >>> import asyncio 

2377 >>> config = MCPClientConfig( 

2378 ... mcp_server=MCPServerConfig( 

2379 ... url="https://example.com/mcp", 

2380 ... transport="streamable_http" 

2381 ... ), 

2382 ... llm=LLMConfig( 

2383 ... provider="ollama", 

2384 ... config=OllamaConfig(model="llama2") 

2385 ... ) 

2386 ... ) 

2387 >>> service = MCPChatService(config) 

2388 >>> # asyncio.run(service.initialize()) 

2389 >>> # service.is_initialized -> True 

2390 

2391 Note: 

2392 Automatically loads tools from MCP server and creates agent. 

2393 """ 

2394 if self._initialized: 

2395 logger.warning("Chat service already initialized") 

2396 return 

2397 

2398 try: 

2399 logger.info("Initializing chat service...") 

2400 

2401 # Connect to MCP server and load tools 

2402 await self.mcp_client.connect() 

2403 self._tools = await self.mcp_client.get_tools() 

2404 

2405 # Create LLM instance 

2406 llm = self.llm_provider.get_llm() 

2407 

2408 # Create ReAct agent with tools 

2409 self._agent = create_react_agent(llm, self._tools) 

2410 

2411 self._initialized = True 

2412 logger.info(f"Chat service initialized successfully with {len(self._tools)} tools") 

2413 

2414 except Exception as e: 

2415 logger.error(f"Failed to initialize chat service: {e}") 

2416 self._initialized = False 

2417 raise 

2418 

2419 async def chat(self, message: str) -> str: 

2420 """ 

2421 Send a message and get a complete response. 

2422 

2423 Processes the user's message through the LLM with tool access, 

2424 manages conversation history, and returns the complete response. 

2425 

2426 Args: 

2427 message: User's message text. 

2428 

2429 Returns: 

2430 str: Complete AI response text. 

2431 

2432 Raises: 

2433 RuntimeError: If service not initialized. 

2434 ValueError: If message is empty. 

2435 Exception: If processing fails. 

2436 

2437 Examples: 

2438 >>> import asyncio 

2439 >>> # Assuming service is initialized 

2440 >>> # response = asyncio.run(service.chat("Hello!")) 

2441 >>> # isinstance(response, str) 

2442 True 

2443 

2444 Note: 

2445 Automatically saves conversation history after response. 

2446 """ 

2447 if not self._initialized or not self._agent: 

2448 raise RuntimeError("Chat service not initialized. Call initialize() first.") 

2449 

2450 if not message or not message.strip(): 

2451 raise ValueError("Message cannot be empty") 

2452 

2453 try: 

2454 logger.debug("Processing chat message...") 

2455 

2456 # Get conversation history from manager 

2457 lc_messages = await self.history_manager.get_langchain_messages(self.user_id) if self.user_id else [] 

2458 

2459 # Add user message 

2460 user_message = HumanMessage(content=message) 

2461 lc_messages.append(user_message) 

2462 

2463 # Invoke agent 

2464 response = await self._agent.ainvoke({"messages": lc_messages}) 

2465 

2466 # Extract AI response 

2467 ai_message = response["messages"][-1] 

2468 response_text = ai_message.content if hasattr(ai_message, "content") else str(ai_message) 

2469 

2470 # Save history if user_id provided 

2471 if self.user_id: 

2472 await self.history_manager.append_message(self.user_id, "user", message) 

2473 await self.history_manager.append_message(self.user_id, "assistant", response_text) 

2474 

2475 logger.debug("Chat message processed successfully") 

2476 return response_text 

2477 

2478 except Exception as e: 

2479 logger.error(f"Error processing chat message: {e}") 

2480 raise 

2481 

2482 async def chat_with_metadata(self, message: str) -> Dict[str, Any]: 

2483 """ 

2484 Send a message and get response with metadata. 

2485 

2486 Similar to chat() but collects all events and returns detailed 

2487 information about tool usage and timing. 

2488 

2489 Args: 

2490 message: User's message text. 

2491 

2492 Returns: 

2493 Dict[str, Any]: Dictionary containing: 

2494 - text (str): Complete response text 

2495 - tool_used (bool): Whether any tools were invoked 

2496 - tools (List[str]): Names of tools that were used 

2497 - tool_invocations (List[dict]): Detailed tool invocation data 

2498 - elapsed_ms (int): Processing time in milliseconds 

2499 

2500 Raises: 

2501 RuntimeError: If service not initialized. 

2502 ValueError: If message is empty. 

2503 

2504 Examples: 

2505 >>> import asyncio 

2506 >>> # Assuming service is initialized 

2507 >>> # result = asyncio.run(service.chat_with_metadata("What's 2+2?")) 

2508 >>> # 'text' in result and 'elapsed_ms' in result 

2509 True 

2510 

2511 Note: 

2512 This method collects all events and returns them as a single response. 

2513 """ 

2514 text = "" 

2515 tool_invocations: list[dict[str, Any]] = [] 

2516 final: dict[str, Any] = {} 

2517 

2518 async for ev in self.chat_events(message): 

2519 t = ev.get("type") 

2520 if t == "token": 

2521 text += ev.get("content", "") 

2522 elif t in ("tool_start", "tool_end", "tool_error"): 

2523 tool_invocations.append(ev) 

2524 elif t == "final": 

2525 final = ev 

2526 

2527 return { 

2528 "text": text, 

2529 "tool_used": final.get("tool_used", False), 

2530 "tools": final.get("tools", []), 

2531 "tool_invocations": tool_invocations, 

2532 "elapsed_ms": final.get("elapsed_ms"), 

2533 } 

2534 

2535 async def chat_stream(self, message: str) -> AsyncGenerator[str, None]: 

2536 """ 

2537 Send a message and stream the response. 

2538 

2539 Yields response chunks as they're generated, enabling real-time display 

2540 of the AI's response. 

2541 

2542 Args: 

2543 message: User's message text. 

2544 

2545 Yields: 

2546 str: Chunks of AI response text. 

2547 

2548 Raises: 

2549 RuntimeError: If service not initialized. 

2550 Exception: If streaming fails. 

2551 

2552 Examples: 

2553 >>> import asyncio 

2554 >>> async def stream_example(): 

2555 ... # Assuming service is initialized 

2556 ... chunks = [] 

2557 ... async for chunk in service.chat_stream("Hello"): 

2558 ... chunks.append(chunk) 

2559 ... return ''.join(chunks) 

2560 >>> # full_response = asyncio.run(stream_example()) 

2561 

2562 Note: 

2563 Falls back to non-streaming if enable_streaming is False in config. 

2564 """ 

2565 if not self._initialized or not self._agent: 

2566 raise RuntimeError("Chat service not initialized. Call initialize() first.") 

2567 

2568 if not self.config.enable_streaming: 

2569 # Fall back to non-streaming 

2570 response = await self.chat(message) 

2571 yield response 

2572 return 

2573 

2574 try: 

2575 logger.debug("Processing streaming chat message...") 

2576 

2577 # Get conversation history 

2578 lc_messages = await self.history_manager.get_langchain_messages(self.user_id) if self.user_id else [] 

2579 

2580 # Add user message 

2581 user_message = HumanMessage(content=message) 

2582 lc_messages.append(user_message) 

2583 

2584 # Stream agent response 

2585 full_response = "" 

2586 async for event in self._agent.astream_events({"messages": lc_messages}, version="v2"): 

2587 kind = event["event"] 

2588 

2589 # Stream LLM tokens 

2590 if kind == "on_chat_model_stream": 

2591 chunk = event.get("data", {}).get("chunk") 

2592 if chunk and hasattr(chunk, "content"): 

2593 content = chunk.content 

2594 if content: 

2595 full_response += content 

2596 yield content 

2597 

2598 # Save history 

2599 if self.user_id and full_response: 

2600 await self.history_manager.append_message(self.user_id, "user", message) 

2601 await self.history_manager.append_message(self.user_id, "assistant", full_response) 

2602 

2603 logger.debug("Streaming chat message processed successfully") 

2604 

2605 except Exception as e: 

2606 logger.error(f"Error processing streaming chat message: {e}") 

2607 raise 

2608 

2609 async def chat_events(self, message: str) -> AsyncGenerator[Dict[str, Any], None]: 

2610 """ 

2611 Stream structured events during chat processing. 

2612 

2613 Provides granular visibility into the chat processing pipeline by yielding 

2614 structured events for tokens, tool invocations, errors, and final results. 

2615 

2616 Args: 

2617 message: User's message text. 

2618 

2619 Yields: 

2620 dict: Event dictionaries with type-specific fields: 

2621 - token: {"type": "token", "content": str} 

2622 - tool_start: {"type": "tool_start", "id": str, "name": str, 

2623 "input": Any, "start": str} 

2624 - tool_end: {"type": "tool_end", "id": str, "name": str, 

2625 "output": Any, "end": str} 

2626 - tool_error: {"type": "tool_error", "id": str, "error": str, 

2627 "time": str} 

2628 - final: {"type": "final", "content": str, "tool_used": bool, 

2629 "tools": List[str], "elapsed_ms": int} 

2630 

2631 Raises: 

2632 RuntimeError: If service not initialized. 

2633 ValueError: If message is empty or whitespace only. 

2634 

2635 Examples: 

2636 >>> import asyncio 

2637 >>> async def event_example(): 

2638 ... # Assuming service is initialized 

2639 ... events = [] 

2640 ... async for event in service.chat_events("Hello"): 

2641 ... events.append(event['type']) 

2642 ... return events 

2643 >>> # event_types = asyncio.run(event_example()) 

2644 >>> # 'final' in event_types -> True 

2645 

2646 Note: 

2647 This is the most detailed chat method, suitable for building 

2648 interactive UIs or detailed logging systems. 

2649 """ 

2650 if not self._initialized or not self._agent: 

2651 raise RuntimeError("Chat service not initialized. Call initialize() first.") 

2652 

2653 # Validate message 

2654 if not message or not message.strip(): 

2655 raise ValueError("Message cannot be empty") 

2656 

2657 # Get conversation history 

2658 lc_messages = await self.history_manager.get_langchain_messages(self.user_id) if self.user_id else [] 

2659 

2660 # Append user message 

2661 user_message = HumanMessage(content=message) 

2662 lc_messages.append(user_message) 

2663 

2664 full_response = "" 

2665 start_ts = time.time() 

2666 tool_runs: dict[str, dict[str, Any]] = {} 

2667 # Buffer for out-of-order on_tool_end events (end arrives before start) 

2668 pending_tool_ends: dict[str, dict[str, Any]] = {} 

2669 pending_ttl_seconds = 30.0 # Max time to hold pending end events 

2670 pending_max_size = 100 # Max number of pending end events to buffer 

2671 # Track dropped run_ids for aggregated error (TTL-expired or buffer-full) 

2672 dropped_tool_ends: set[str] = set() 

2673 dropped_max_size = 200 # Max dropped IDs to track (prevents unbounded growth) 

2674 dropped_overflow_count = 0 # Count of drops that couldn't be tracked due to full buffer 

2675 

2676 def _extract_output(raw_output: Any) -> Any: 

2677 """Extract output value from various LangChain output formats. 

2678 

2679 Args: 

2680 raw_output: The raw output from a tool execution. 

2681 

2682 Returns: 

2683 The extracted output value in a serializable format. 

2684 """ 

2685 if hasattr(raw_output, "content"): 

2686 return raw_output.content 

2687 if hasattr(raw_output, "dict") and callable(raw_output.dict): 

2688 return raw_output.dict() 

2689 if not isinstance(raw_output, (str, int, float, bool, list, dict, type(None))): 

2690 return str(raw_output) 

2691 return raw_output 

2692 

2693 def _cleanup_expired_pending(current_ts: float) -> None: 

2694 """Remove expired entries from pending_tool_ends buffer and track them. 

2695 

2696 Args: 

2697 current_ts: Current timestamp in seconds since epoch. 

2698 """ 

2699 nonlocal dropped_overflow_count 

2700 expired = [rid for rid, data in pending_tool_ends.items() if current_ts - data.get("buffered_at", 0) > pending_ttl_seconds] 

2701 for rid in expired: 

2702 logger.warning(f"Pending on_tool_end for run_id {rid} expired after {pending_ttl_seconds}s (orphan event)") 

2703 if len(dropped_tool_ends) < dropped_max_size: 

2704 dropped_tool_ends.add(rid) 

2705 else: 

2706 dropped_overflow_count += 1 

2707 logger.warning(f"Dropped tool ends tracking full ({dropped_max_size}), cannot track expired run_id {rid} (overflow count: {dropped_overflow_count})") 

2708 del pending_tool_ends[rid] 

2709 

2710 try: 

2711 async for event in self._agent.astream_events({"messages": lc_messages}, version="v2"): 

2712 kind = event.get("event") 

2713 now_iso = datetime.now(timezone.utc).isoformat() 

2714 now_ts = time.time() 

2715 

2716 # Periodically cleanup expired pending ends 

2717 _cleanup_expired_pending(now_ts) 

2718 

2719 try: 

2720 if kind == "on_tool_start": 

2721 run_id = str(event.get("run_id") or uuid4()) 

2722 name = event.get("name") or event.get("data", {}).get("name") or event.get("data", {}).get("tool") 

2723 input_data = event.get("data", {}).get("input") 

2724 

2725 # Filter out common metadata keys injected by LangChain/LangGraph 

2726 if isinstance(input_data, dict): 

2727 input_data = {k: v for k, v in input_data.items() if k not in ["runtime", "config", "run_manager", "callbacks"]} 

2728 

2729 tool_runs[run_id] = {"name": name, "start": now_iso, "input": input_data} 

2730 

2731 # Register run for cancellation tracking with gateway-level Cancellation service 

2732 async def _noop_cancel_cb(reason: Optional[str]) -> None: 

2733 """ 

2734 No-op cancel callback used when a run is started. 

2735 

2736 Args: 

2737 reason: Optional textual reason for cancellation. 

2738 

2739 Returns: 

2740 None 

2741 """ 

2742 # Default no-op; kept for potential future intra-process cancellation 

2743 return None 

2744 

2745 # Register with cancellation service only if feature is enabled 

2746 if settings.mcpgateway_tool_cancellation_enabled: 

2747 try: 

2748 await cancellation_service.register_run(run_id, name=name, cancel_callback=_noop_cancel_cb) 

2749 except Exception: 

2750 logger.exception("Failed to register run %s with CancellationService", run_id) 

2751 

2752 yield {"type": "tool_start", "id": run_id, "tool": name, "input": input_data, "start": now_iso} 

2753 

2754 # NOTE: Do NOT clear from dropped_tool_ends here. If an end was dropped (TTL/buffer-full) 

2755 # before this start arrived, that end is permanently lost. Since tools only end once, 

2756 # we won't receive another end event, so this should still be reported as an orphan. 

2757 

2758 # Check if we have a buffered end event for this run_id (out-of-order reconciliation) 

2759 if run_id in pending_tool_ends: 

2760 buffered = pending_tool_ends.pop(run_id) 

2761 tool_runs[run_id]["end"] = buffered["end_time"] 

2762 tool_runs[run_id]["output"] = buffered["output"] 

2763 logger.info(f"Reconciled out-of-order on_tool_end for run_id {run_id}") 

2764 

2765 if tool_runs[run_id].get("output") == "": 

2766 error = "Tool execution failed: Please check if the tool is accessible" 

2767 yield {"type": "tool_error", "id": run_id, "tool": name, "error": error, "time": buffered["end_time"]} 

2768 

2769 yield {"type": "tool_end", "id": run_id, "tool": name, "output": tool_runs[run_id].get("output"), "end": buffered["end_time"]} 

2770 

2771 elif kind == "on_tool_end": 

2772 run_id = str(event.get("run_id") or uuid4()) 

2773 output = event.get("data", {}).get("output") 

2774 extracted_output = _extract_output(output) 

2775 

2776 if run_id in tool_runs: 

2777 # Normal case: start already received 

2778 tool_runs[run_id]["end"] = now_iso 

2779 tool_runs[run_id]["output"] = extracted_output 

2780 

2781 if tool_runs[run_id].get("output") == "": 

2782 error = "Tool execution failed: Please check if the tool is accessible" 

2783 yield {"type": "tool_error", "id": run_id, "tool": tool_runs.get(run_id, {}).get("name"), "error": error, "time": now_iso} 

2784 

2785 yield {"type": "tool_end", "id": run_id, "tool": tool_runs.get(run_id, {}).get("name"), "output": tool_runs[run_id].get("output"), "end": now_iso} 

2786 else: 

2787 # Out-of-order: buffer the end event for later reconciliation 

2788 if len(pending_tool_ends) < pending_max_size: 

2789 pending_tool_ends[run_id] = {"output": extracted_output, "end_time": now_iso, "buffered_at": now_ts} 

2790 logger.debug(f"Buffered out-of-order on_tool_end for run_id {run_id}, awaiting on_tool_start") 

2791 else: 

2792 logger.warning(f"Pending tool ends buffer full ({pending_max_size}), dropping on_tool_end for run_id {run_id}") 

2793 if len(dropped_tool_ends) < dropped_max_size: 

2794 dropped_tool_ends.add(run_id) 

2795 else: 

2796 dropped_overflow_count += 1 

2797 logger.warning(f"Dropped tool ends tracking full ({dropped_max_size}), cannot track run_id {run_id} (overflow count: {dropped_overflow_count})") 

2798 

2799 # Unregister run from cancellation service when finished (only if feature is enabled) 

2800 if settings.mcpgateway_tool_cancellation_enabled: 

2801 try: 

2802 await cancellation_service.unregister_run(run_id) 

2803 except Exception: 

2804 logger.exception("Failed to unregister run %s", run_id) 

2805 

2806 elif kind == "on_tool_error": 

2807 run_id = str(event.get("run_id") or uuid4()) 

2808 error = str(event.get("data", {}).get("error", "Unknown error")) 

2809 

2810 # Clear any buffered end for this run to avoid emitting both error and end 

2811 if run_id in pending_tool_ends: 

2812 del pending_tool_ends[run_id] 

2813 logger.debug(f"Cleared buffered on_tool_end for run_id {run_id} due to tool error") 

2814 

2815 # Clear from dropped set if this run was previously dropped (prevents false orphan) 

2816 dropped_tool_ends.discard(run_id) 

2817 

2818 yield {"type": "tool_error", "id": run_id, "tool": tool_runs.get(run_id, {}).get("name"), "error": error, "time": now_iso} 

2819 

2820 # Unregister run on error (only if feature is enabled) 

2821 if settings.mcpgateway_tool_cancellation_enabled: 

2822 try: 

2823 await cancellation_service.unregister_run(run_id) 

2824 except Exception: 

2825 logger.exception("Failed to unregister run %s after error", run_id) 

2826 

2827 elif kind == "on_chat_model_stream": 

2828 chunk = event.get("data", {}).get("chunk") 

2829 if chunk and hasattr(chunk, "content"): 

2830 content = chunk.content 

2831 if content: 

2832 full_response += content 

2833 yield {"type": "token", "content": content} 

2834 

2835 except Exception as event_error: 

2836 logger.warning(f"Error processing event {kind}: {event_error}") 

2837 continue 

2838 

2839 # Emit aggregated error for any orphan/dropped tool ends 

2840 # De-duplicate IDs (in case same ID was buffered and dropped in edge cases) 

2841 all_orphan_ids = sorted(set(pending_tool_ends.keys()) | dropped_tool_ends) 

2842 if all_orphan_ids or dropped_overflow_count > 0: 

2843 buffered_count = len(pending_tool_ends) 

2844 dropped_count = len(dropped_tool_ends) 

2845 total_unique = len(all_orphan_ids) 

2846 total_affected = total_unique + dropped_overflow_count 

2847 logger.warning( 

2848 f"Stream completed with {total_affected} orphan tool end(s): {buffered_count} buffered, {dropped_count} dropped (tracked), {dropped_overflow_count} dropped (untracked overflow)" 

2849 ) 

2850 # Log full list at debug level for observability 

2851 if all_orphan_ids: 

2852 logger.debug(f"Full orphan run_id list: {', '.join(all_orphan_ids)}") 

2853 now_iso = datetime.now(timezone.utc).isoformat() 

2854 error_parts = [] 

2855 if buffered_count > 0: 

2856 error_parts.append(f"{buffered_count} buffered") 

2857 if dropped_count > 0: 

2858 error_parts.append(f"{dropped_count} dropped (TTL expired or buffer full)") 

2859 if dropped_overflow_count > 0: 

2860 error_parts.append(f"{dropped_overflow_count} additional dropped (tracking overflow)") 

2861 error_msg = f"Tool execution incomplete: {total_affected} tool end(s) received without matching start ({', '.join(error_parts)})" 

2862 # Truncate to first 10 IDs in error message to avoid excessive payload 

2863 if all_orphan_ids: 

2864 max_display_ids = 10 

2865 display_ids = all_orphan_ids[:max_display_ids] 

2866 remaining = total_unique - len(display_ids) 

2867 if remaining > 0: 

2868 error_msg += f". Run IDs (first {max_display_ids} of {total_unique}): {', '.join(display_ids)} (+{remaining} more)" 

2869 else: 

2870 error_msg += f". Run IDs: {', '.join(display_ids)}" 

2871 yield { 

2872 "type": "tool_error", 

2873 "id": str(uuid4()), 

2874 "tool": None, 

2875 "error": error_msg, 

2876 "time": now_iso, 

2877 } 

2878 pending_tool_ends.clear() 

2879 dropped_tool_ends.clear() 

2880 

2881 # Calculate elapsed time 

2882 elapsed_ms = int((time.time() - start_ts) * 1000) 

2883 

2884 # Determine tool usage 

2885 tools_used = list({tr["name"] for tr in tool_runs.values() if tr.get("name")}) 

2886 

2887 # Yield final event 

2888 yield {"type": "final", "content": full_response, "tool_used": len(tools_used) > 0, "tools": tools_used, "elapsed_ms": elapsed_ms} 

2889 

2890 # Save history 

2891 if self.user_id and full_response: 

2892 await self.history_manager.append_message(self.user_id, "user", message) 

2893 await self.history_manager.append_message(self.user_id, "assistant", full_response) 

2894 

2895 except Exception as e: 

2896 logger.error(f"Error in chat_events: {e}") 

2897 raise RuntimeError(f"Chat processing error: {e}") from e 

2898 

2899 async def get_conversation_history(self) -> List[Dict[str, str]]: 

2900 """ 

2901 Get conversation history for the current user. 

2902 

2903 Returns: 

2904 List[Dict[str, str]]: Conversation messages with keys: 

2905 - role (str): "user" or "assistant" 

2906 - content (str): Message text 

2907 

2908 Examples: 

2909 >>> import asyncio 

2910 >>> # Assuming service is initialized with user_id 

2911 >>> # history = asyncio.run(service.get_conversation_history()) 

2912 >>> # all('role' in msg and 'content' in msg for msg in history) 

2913 True 

2914 

2915 Note: 

2916 Returns empty list if no user_id set or no history exists. 

2917 """ 

2918 if not self.user_id: 

2919 return [] 

2920 

2921 return await self.history_manager.get_history(self.user_id) 

2922 

2923 async def clear_history(self) -> None: 

2924 """ 

2925 Clear conversation history for the current user. 

2926 

2927 Removes all messages from the conversation history. Useful for starting 

2928 fresh conversations or managing memory usage. 

2929 

2930 Examples: 

2931 >>> import asyncio 

2932 >>> # Assuming service is initialized with user_id 

2933 >>> # asyncio.run(service.clear_history()) 

2934 >>> # history = asyncio.run(service.get_conversation_history()) 

2935 >>> # len(history) -> 0 

2936 

2937 Note: 

2938 This action cannot be undone. No-op if no user_id set. 

2939 """ 

2940 if not self.user_id: 

2941 return 

2942 

2943 await self.history_manager.clear_history(self.user_id) 

2944 logger.info(f"Conversation history cleared for user {self.user_id}") 

2945 

2946 async def shutdown(self) -> None: 

2947 """ 

2948 Shutdown the chat service and cleanup resources. 

2949 

2950 Performs graceful shutdown by disconnecting from MCP server, clearing 

2951 agent and history, and resetting initialization state. 

2952 

2953 Raises: 

2954 Exception: If cleanup operations fail. 

2955 

2956 Examples: 

2957 >>> import asyncio 

2958 >>> config = MCPClientConfig( 

2959 ... mcp_server=MCPServerConfig( 

2960 ... url="https://example.com/mcp", 

2961 ... transport="streamable_http" 

2962 ... ), 

2963 ... llm=LLMConfig( 

2964 ... provider="ollama", 

2965 ... config=OllamaConfig(model="llama2") 

2966 ... ) 

2967 ... ) 

2968 >>> service = MCPChatService(config) 

2969 >>> # asyncio.run(service.initialize()) 

2970 >>> # asyncio.run(service.shutdown()) 

2971 >>> # service.is_initialized -> False 

2972 

2973 Note: 

2974 Should be called when service is no longer needed to properly 

2975 release resources and connections. 

2976 """ 

2977 logger.info("Shutting down chat service...") 

2978 

2979 try: 

2980 # Disconnect from MCP server 

2981 if self.mcp_client.is_connected: 

2982 await self.mcp_client.disconnect() 

2983 

2984 # Clear state 

2985 self._agent = None 

2986 self._initialized = False 

2987 self._tools = [] 

2988 

2989 logger.info("Chat service shutdown complete") 

2990 

2991 except Exception as e: 

2992 logger.error(f"Error during shutdown: {e}") 

2993 raise 

2994 

2995 @property 

2996 def is_initialized(self) -> bool: 

2997 """ 

2998 Check if service is initialized. 

2999 

3000 Returns: 

3001 bool: True if service is initialized and ready, False otherwise. 

3002 

3003 Examples: 

3004 >>> config = MCPClientConfig( 

3005 ... mcp_server=MCPServerConfig( 

3006 ... url="https://example.com/mcp", 

3007 ... transport="streamable_http" 

3008 ... ), 

3009 ... llm=LLMConfig( 

3010 ... provider="ollama", 

3011 ... config=OllamaConfig(model="llama2") 

3012 ... ) 

3013 ... ) 

3014 >>> service = MCPChatService(config) 

3015 >>> service.is_initialized 

3016 False 

3017 

3018 Note: 

3019 Service must be initialized before calling chat methods. 

3020 """ 

3021 return self._initialized 

3022 

3023 async def reload_tools(self) -> int: 

3024 """ 

3025 Reload tools from MCP server. 

3026 

3027 Forces a reload of tools from the MCP server and recreates the agent 

3028 with the updated tool set. Useful when MCP server tools have changed. 

3029 

3030 Returns: 

3031 int: Number of tools successfully loaded. 

3032 

3033 Raises: 

3034 RuntimeError: If service not initialized. 

3035 Exception: If tool reloading or agent recreation fails. 

3036 

3037 Examples: 

3038 >>> import asyncio 

3039 >>> # Assuming service is initialized 

3040 >>> # tool_count = asyncio.run(service.reload_tools()) 

3041 >>> # tool_count >= 0 -> True 

3042 

3043 Note: 

3044 This operation recreates the agent, so it may briefly interrupt 

3045 ongoing conversations. Conversation history is preserved. 

3046 """ 

3047 if not self._initialized: 

3048 raise RuntimeError("Chat service not initialized") 

3049 

3050 try: 

3051 logger.info("Reloading tools from MCP server...") 

3052 tools = await self.mcp_client.get_tools(force_reload=True) 

3053 

3054 # Recreate agent with new tools 

3055 llm = self.llm_provider.get_llm() 

3056 self._agent = create_react_agent(llm, tools) 

3057 self._tools = tools 

3058 

3059 logger.info(f"Reloaded {len(tools)} tools successfully") 

3060 return len(tools) 

3061 

3062 except Exception as e: 

3063 logger.error(f"Failed to reload tools: {e}") 

3064 raise