Coverage for mcpgateway / toolops / utils / llm_util.py: 100%
85 statements
# -*- coding: utf-8 -*-
"""Location: ./mcpgateway/toolops/utils/llm_util.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Jay Bandlamudi

MCP Gateway - Main module for using and supporting MCP-CF LLM providers in toolops modules.

This module defines the utility functions to use MCP-CF supported LLM providers in toolops.
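
Example (illustrative usage sketch, not executed; assumes the relevant LLM provider
environment variables, e.g. LLM_PROVIDER and the matching credentials, are already
configured):

    >>> from mcpgateway.toolops.utils.llm_util import get_llm_instance, execute_prompt  # doctest: +SKIP
    >>> chat_llm_instance, llm_config = get_llm_instance(model_type="chat")  # doctest: +SKIP
    >>> print(execute_prompt("What is the capital city of India?"))  # doctest: +SKIP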
10"""
# Standard
import os

# Third-Party
from dotenv import load_dotenv
import orjson

# First-Party
from mcpgateway.services.logging_service import LoggingService
from mcpgateway.services.mcp_client_chat_service import (
    AnthropicConfig,
    AnthropicProvider,
    AWSBedrockConfig,
    AWSBedrockProvider,
    AzureOpenAIConfig,
    AzureOpenAIProvider,
    OllamaConfig,
    OllamaProvider,
    OpenAIConfig,
    OpenAIProvider,
    WatsonxConfig,
    WatsonxProvider,
)

logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

load_dotenv()

# Set the LLM temperature for toolops modules low to produce minimally variable model outputs.
TOOLOPS_TEMPERATURE = 0.1


def get_llm_instance(model_type="completion"):
    """
    Get an MCP-CF provider LLM instance based on the model type.

    Args:
        model_type: LLM instance type, such as a chat model or a token completion model; accepted values: 'completion', 'chat'

    Returns:
        llm_instance: LLM model instance used for inferencing the prompts/user inputs
        llm_config: LLM provider configuration provided in the environment variables

    Examples:
        >>> import os
        >>> from unittest.mock import patch, MagicMock
        >>> # Setup: Define the global variable used in the function for the test context
        >>> global TOOLOPS_TEMPERATURE
        >>> TOOLOPS_TEMPERATURE = 0.7

        >>> # Case 1: OpenAI Provider Configuration
        >>> # We patch os.environ to simulate specific provider settings
        >>> env_vars = {
        ...     "LLM_PROVIDER": "openai",
        ...     "OPENAI_API_KEY": "sk-mock-key",
        ...     "OPENAI_BASE_URL": "https://api.openai.com",
        ...     "OPENAI_MODEL": "gpt-4"
        ... }
        >>> with patch.dict(os.environ, env_vars):
        ...     # Assuming OpenAIProvider and OpenAIConfig are available in the module scope
        ...     # We simulate the function call. Note: This tests the Config creation logic.
        ...     llm_instance, llm_config = get_llm_instance("completion")
        ...     llm_config.__class__.__name__
        'OpenAIConfig'

        >>> # Case 2: Azure OpenAI Provider Configuration
        >>> env_vars = {
        ...     "LLM_PROVIDER": "azure_openai",
        ...     "AZURE_OPENAI_API_KEY": "az-mock-key",
        ...     "AZURE_OPENAI_ENDPOINT": "https://mock.azure.com",
        ...     "AZURE_OPENAI_MODEL": "gpt-35-turbo"
        ... }
        >>> with patch.dict(os.environ, env_vars):
        ...     llm_instance, llm_config = get_llm_instance("chat")
        ...     llm_config.__class__.__name__
        'AzureOpenAIConfig'

        >>> # Case 3: AWS Bedrock Provider Configuration
        >>> env_vars = {
        ...     "LLM_PROVIDER": "aws_bedrock",
        ...     "AWS_BEDROCK_MODEL_ID": "anthropic.claude-v2",
        ...     "AWS_BEDROCK_REGION": "us-east-1",
        ...     "AWS_ACCESS_KEY_ID": "mock-access",
        ...     "AWS_SECRET_ACCESS_KEY": "mock-secret"
        ... }
        >>> with patch.dict(os.environ, env_vars):
        ...     llm_instance, llm_config = get_llm_instance("chat")
        ...     llm_config.__class__.__name__
        'AWSBedrockConfig'

        >>> # Case 4: WatsonX Provider Configuration
        >>> env_vars = {
        ...     "LLM_PROVIDER": "watsonx",
        ...     "WATSONX_APIKEY": "wx-mock-key",
        ...     "WATSONX_URL": "https://us-south.ml.cloud.ibm.com",
        ...     "WATSONX_PROJECT_ID": "mock-project-id",
        ...     "WATSONX_MODEL_ID": "ibm/granite-13b"
        ... }
        >>> with patch.dict(os.environ, env_vars):
        ...     llm_instance, llm_config = get_llm_instance("completion")
        ...     llm_config.__class__.__name__
        'WatsonxConfig'
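
        >>> # Case 5: Anthropic Provider Configuration (illustrative addition mirroring the cases above; values are mocks)
        >>> env_vars = {
        ...     "LLM_PROVIDER": "anthropic",
        ...     "ANTHROPIC_API_KEY": "ant-mock-key",
        ...     "ANTHROPIC_MODEL": "claude-mock-model"
        ... }
        >>> with patch.dict(os.environ, env_vars):
        ...     llm_instance, llm_config = get_llm_instance("chat")
        ...     llm_config.__class__.__name__
        'AnthropicConfig'

        >>> # Case 6: Ollama Provider Configuration (illustrative addition mirroring the cases above; values are mocks)
        >>> env_vars = {
        ...     "LLM_PROVIDER": "ollama",
        ...     "OLLAMA_BASE_URL": "http://localhost:11434",
        ...     "OLLAMA_MODEL": "mock-local-model"
        ... }
        >>> with patch.dict(os.environ, env_vars):
        ...     llm_instance, llm_config = get_llm_instance("completion")
        ...     llm_config.__class__.__name__
        'OllamaConfig'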
    """
    llm_provider = os.getenv("LLM_PROVIDER", "")
    llm_instance, llm_config = None, None
    logger.info("Configuring LLM instance for ToolOps with LLM provider - " + llm_provider)
    try:
        provider_map = {
            "azure_openai": AzureOpenAIProvider,
            "openai": OpenAIProvider,
            "anthropic": AnthropicProvider,
            "aws_bedrock": AWSBedrockProvider,
            "ollama": OllamaProvider,
            "watsonx": WatsonxProvider,
        }
        provider_class = provider_map.get(llm_provider)

        # Get the LLM configs from environment variables
        if llm_provider == "openai":
            oai_api_key = os.getenv("OPENAI_API_KEY", "")
            oai_base_url = os.getenv("OPENAI_BASE_URL", "")
            oai_model = os.getenv("OPENAI_MODEL", "")
            # oai_temperature = float(os.getenv("OPENAI_TEMPERATURE", "0.7"))
            oai_temperature = TOOLOPS_TEMPERATURE
            oai_max_retries = int(os.getenv("OPENAI_MAX_RETRIES", "2"))
            oai_max_tokens = int(os.getenv("OPENAI_MAX_TOKENS", "600"))
            # Add default headers for the RITS LLM platform as required
            if isinstance(oai_base_url, str) and "rits.fmaas.res.ibm.com" in oai_base_url:
                default_headers = {"RITS_API_KEY": oai_api_key}
            else:
                default_headers = None
            llm_config = OpenAIConfig(
                api_key=oai_api_key,
                base_url=oai_base_url,
                temperature=oai_temperature,
                model=oai_model,
                max_tokens=oai_max_tokens,
                max_retries=oai_max_retries,
                default_headers=default_headers,
                timeout=None,
            )
        elif llm_provider == "azure_openai":
            az_api_key = os.getenv("AZURE_OPENAI_API_KEY", "")
            az_url = os.getenv("AZURE_OPENAI_ENDPOINT", "")
            az_api_version = os.getenv("AZURE_OPENAI_API_VERSION", "")
            az_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT", "")
            az_model = os.getenv("AZURE_OPENAI_MODEL", "")
            # az_temperature = float(os.getenv("AZURE_OPENAI_TEMPERATURE", "0.7"))
            az_temperature = TOOLOPS_TEMPERATURE
            az_max_retries = int(os.getenv("AZURE_OPENAI_MAX_RETRIES", "2"))
            az_max_tokens = int(os.getenv("AZURE_OPENAI_MAX_TOKENS", "600"))
            llm_config = AzureOpenAIConfig(
                api_key=az_api_key,
                azure_endpoint=az_url,
                api_version=az_api_version,
                azure_deployment=az_deployment,
                model=az_model,
                temperature=az_temperature,
                max_retries=az_max_retries,
                max_tokens=az_max_tokens,
                timeout=None,
            )
        elif llm_provider == "anthropic":
            ant_api_key = os.getenv("ANTHROPIC_API_KEY", "")
            ant_model = os.getenv("ANTHROPIC_MODEL", "")
            # ant_temperature = float(os.getenv("ANTHROPIC_TEMPERATURE", "0.7"))
            ant_temperature = TOOLOPS_TEMPERATURE
            ant_max_tokens = int(os.getenv("ANTHROPIC_MAX_TOKENS", "4096"))
            ant_max_retries = int(os.getenv("ANTHROPIC_MAX_RETRIES", "2"))
            llm_config = AnthropicConfig(api_key=ant_api_key, model=ant_model, temperature=ant_temperature, max_tokens=ant_max_tokens, max_retries=ant_max_retries, timeout=None)

        elif llm_provider == "aws_bedrock":
            aws_model = os.getenv("AWS_BEDROCK_MODEL_ID", "")
            aws_region = os.getenv("AWS_BEDROCK_REGION", "")
            # aws_temperature = float(os.getenv("AWS_BEDROCK_TEMPERATURE", "0.7"))
            aws_temperature = TOOLOPS_TEMPERATURE
            aws_max_tokens = int(os.getenv("AWS_BEDROCK_MAX_TOKENS", "4096"))
            aws_key_id = os.getenv("AWS_ACCESS_KEY_ID", "")
            aws_secret = os.getenv("AWS_SECRET_ACCESS_KEY", "")
            aws_session_token = os.getenv("AWS_SESSION_TOKEN", "")
            llm_config = AWSBedrockConfig(
                model_id=aws_model,
                region_name=aws_region,
                temperature=aws_temperature,
                max_tokens=aws_max_tokens,
                aws_access_key_id=aws_key_id,
                aws_secret_access_key=aws_secret,
                aws_session_token=aws_session_token,
            )
        elif llm_provider == "ollama":
            ollama_model = os.getenv("OLLAMA_MODEL", "")
            ollama_url = os.getenv("OLLAMA_BASE_URL", "")
            # ollama_temperature = float(os.getenv("OLLAMA_TEMPERATURE", "0.7"))
            ollama_temperature = TOOLOPS_TEMPERATURE
            llm_config = OllamaConfig(base_url=ollama_url, model=ollama_model, temperature=ollama_temperature, timeout=None, num_ctx=None)
        elif llm_provider == "watsonx":
            wx_api_key = os.getenv("WATSONX_APIKEY", "")
            wx_base_url = os.getenv("WATSONX_URL", "")
            wx_model = os.getenv("WATSONX_MODEL_ID", "")
            wx_project_id = os.getenv("WATSONX_PROJECT_ID", "")
            wx_temperature = TOOLOPS_TEMPERATURE
            wx_max_tokens = int(os.getenv("WATSONX_MAX_NEW_TOKENS", "1000"))
            wx_decoding_method = os.getenv("WATSONX_DECODING_METHOD", "greedy")
            llm_config = WatsonxConfig(
                api_key=wx_api_key,
                url=wx_base_url,
                project_id=wx_project_id,
                model_id=wx_model,
                temperature=wx_temperature,
                max_new_tokens=wx_max_tokens,
                decoding_method=wx_decoding_method,
            )
        else:
            return None, None

        llm_service = provider_class(llm_config)
        llm_instance = llm_service.get_llm(model_type=model_type)
        logger.info("Successfully configured LLM instance for ToolOps with LLM provider - " + llm_provider)
    except Exception as e:
        logger.error("Error in configuring LLM instance for ToolOps - " + str(e))
    return llm_instance, llm_config


def execute_prompt(prompt):
    """
    Run LLM inferencing on a prompt/user input.

    Args:
        prompt: user-specified prompt or input for LLM inferencing

    Returns:
        response: LLM output response for the given prompt
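
    Examples:
        Illustrative doctest with a mocked completion instance; the mock object and
        its return value are assumptions for demonstration, not real provider output:

        >>> from unittest.mock import MagicMock, patch
        >>> mock_llm = MagicMock()
        >>> mock_llm.invoke.return_value = "Paris<|eom_id|>"
        >>> with patch("mcpgateway.toolops.utils.llm_util.get_llm_instance", return_value=(mock_llm, None)):
        ...     execute_prompt("What is the capital of France?")
        'Paris'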
    """
    try:
        logger.info("Inferencing the configured LLM provider with the given prompt")

        completion_llm_instance, _ = get_llm_instance(model_type="completion")
        llm_response = completion_llm_instance.invoke(prompt, stop=["\n\n", "<|endoftext|>", "###STOP###"])
        response = llm_response.replace("<|eom_id|>", "").strip()
        # logger.info("Successful - Inferencing the configured LLM provider")
        return response
    except Exception as e:
        logger.error("Error in LLM inferencing - " + orjson.dumps({"Error": str(e)}).decode())
        return ""


# if __name__ == "__main__":
#     chat_llm_instance, _ = get_llm_instance(model_type="chat")
#     completion_llm_instance, _ = get_llm_instance(model_type="completion")
#     prompt = "What is the capital city of India?"
#     print("Prompt : ", prompt)
#     print("Text completion output : ")
#     print(execute_prompt(prompt))
#     response = chat_llm_instance.invoke(prompt)
#     print("Chat completion output : ")
#     print(response.content)