Coverage for mcpgateway / toolops / utils / llm_util.py: 100%
85 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/toolops/utils/llm_util.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Jay Bandlamudi
7ContextForge - Main module for using and supporting MCP-CF LLM providers in toolops modules.
9This module defines the utility functions to use MCP-CF supported LLM providers in toolops.
10"""
12# Standard
13import os
15# Third-Party
16from dotenv import load_dotenv
17import orjson
19# First-Party
20from mcpgateway.services.logging_service import LoggingService
21from mcpgateway.services.mcp_client_chat_service import (
22 AnthropicConfig,
23 AnthropicProvider,
24 AWSBedrockConfig,
25 AWSBedrockProvider,
26 AzureOpenAIConfig,
27 AzureOpenAIProvider,
28 OllamaConfig,
29 OllamaProvider,
30 OpenAIConfig,
31 OpenAIProvider,
32 WatsonxConfig,
33 WatsonxProvider,
34)
# Module-level logger for the ToolOps LLM utilities, obtained from the gateway's LoggingService.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

# Load environment variables from a .env file (python-dotenv) at import time so that the
# os.getenv() lookups in get_llm_instance() can pick up provider settings defined there.
load_dotenv()

# set LLM temperature for toolops modules as low to produce minimally variable model outputs.
# This value is used for every provider instead of any provider-specific *_TEMPERATURE variable.
TOOLOPS_TEMPERATURE = 0.1


def _getenv_int(name, default, legacy_name=None):
    """
    Read an integer setting from the environment.

    Args:
        name: Preferred environment variable name.
        default: Value returned when neither ``name`` nor ``legacy_name`` is set.
        legacy_name: Optional older (e.g. misspelled) variable name that is still honoured,
            for backward compatibility, when ``name`` is not set.

    Returns:
        int: The parsed integer value, or ``default``.

    Raises:
        ValueError: If the variable is set but does not hold a valid integer.

    Examples:
        >>> import os
        >>> from unittest.mock import patch
        >>> with patch.dict(os.environ, {"TOOLOPS_TEST_INT": "7"}):
        ...     _getenv_int("TOOLOPS_TEST_INT", 3)
        7
        >>> with patch.dict(os.environ, {"TOOLOPS_TEST_LEGACY_INT": "9"}):
        ...     _getenv_int("TOOLOPS_TEST_UNSET_INT", 3, legacy_name="TOOLOPS_TEST_LEGACY_INT")
        9
        >>> _getenv_int("TOOLOPS_TEST_UNSET_INT", 3)
        3
    """
    value = os.getenv(name)
    if value is None and legacy_name is not None:
        value = os.getenv(legacy_name)
    return default if value is None else int(value)


def get_llm_instance(model_type="completion"):
    """
    Method to get MCP-CF provider type llm instance based on model type.

    The provider is selected by the ``LLM_PROVIDER`` environment variable (case-insensitive;
    one of ``openai``, ``azure_openai``, ``anthropic``, ``aws_bedrock``, ``ollama``,
    ``watsonx``). Its settings are read from provider-specific environment variables.
    The sampling temperature is always ``TOOLOPS_TEMPERATURE``.

    Provider environment variables:
        - openai: OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL, OPENAI_MAX_RETRIES (default 2),
          OPENAI_MAX_TOKENS (default 600; the legacy name OPENAI_MAX_TOEKNS is still honoured)
        - azure_openai: AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_VERSION,
          AZURE_OPENAI_DEPLOYMENT, AZURE_OPENAI_MODEL, AZURE_OPENAI_MAX_RETRIES (default 2),
          AZURE_OPENAI_MAX_TOKENS (default 600; the legacy name AZURE_OPENAI_MAX_TOEKNS is still honoured)
        - anthropic: ANTHROPIC_API_KEY, ANTHROPIC_MODEL, ANTHROPIC_MAX_TOKENS (default 4096),
          ANTHROPIC_MAX_RETRIES (default 2)
        - aws_bedrock: AWS_BEDROCK_MODEL_ID, AWS_BEDROCK_REGION, AWS_BEDROCK_MAX_TOKENS (default 4096),
          AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN
        - ollama: OLLAMA_MODEL, OLLAMA_BASE_URL
        - watsonx: WATSONX_APIKEY, WATSONX_URL, WATSONX_MODEL_ID, WATSONX_PROJECT_ID,
          WATSONX_MAX_NEW_TOKENS (default 1000), WATSONX_DECODING_METHOD (default "greedy")

    Args:
        model_type : LLM instance type such as chat model or token completion model, accepted values: 'completion', 'chat'

    Returns:
        llm_instance : LLM model instance used for inferencing the prompts/user inputs, or None when
            LLM_PROVIDER is missing/unsupported or configuration failed (the error is logged)
        llm_config: LLM provider configuration built from the environment variables, or None when
            LLM_PROVIDER is missing/unsupported or the configuration could not be built; it may be
            set while llm_instance is None if the failure happened after the config was built

    Examples:
        >>> import os
        >>> from unittest.mock import patch

        >>> # Case 1: OpenAI Provider Configuration
        >>> # We patch os.environ to simulate specific provider settings
        >>> env_vars = {
        ...     "LLM_PROVIDER": "openai",
        ...     "OPENAI_API_KEY": "sk-mock-key",
        ...     "OPENAI_BASE_URL": "https://api.openai.com",
        ...     "OPENAI_MODEL": "gpt-4"
        ... }
        >>> with patch.dict(os.environ, env_vars):
        ...     llm_instance, llm_config = get_llm_instance("completion")
        ...     llm_config.__class__.__name__
        'OpenAIConfig'

        >>> # Case 1b: correctly spelled OPENAI_MAX_TOKENS is honoured
        >>> with patch.dict(os.environ, dict(env_vars, OPENAI_MAX_TOKENS="123")):
        ...     _, llm_config = get_llm_instance("completion")
        ...     llm_config.max_tokens
        123

        >>> # Case 2: Azure OpenAI Provider Configuration
        >>> env_vars = {
        ...     "LLM_PROVIDER": "azure_openai",
        ...     "AZURE_OPENAI_API_KEY": "az-mock-key",
        ...     "AZURE_OPENAI_ENDPOINT": "https://mock.azure.com",
        ...     "AZURE_OPENAI_MODEL": "gpt-35-turbo"
        ... }
        >>> with patch.dict(os.environ, env_vars):
        ...     llm_instance, llm_config = get_llm_instance("chat")
        ...     llm_config.__class__.__name__
        'AzureOpenAIConfig'

        >>> # Case 3: AWS Bedrock Provider Configuration
        >>> env_vars = {
        ...     "LLM_PROVIDER": "aws_bedrock",
        ...     "AWS_BEDROCK_MODEL_ID": "anthropic.claude-v2",
        ...     "AWS_BEDROCK_REGION": "us-east-1",
        ...     "AWS_ACCESS_KEY_ID": "mock-access",
        ...     "AWS_SECRET_ACCESS_KEY": "mock-secret"
        ... }
        >>> with patch.dict(os.environ, env_vars):
        ...     llm_instance, llm_config = get_llm_instance("chat")
        ...     llm_config.__class__.__name__
        'AWSBedrockConfig'

        >>> # Case 4: WatsonX Provider Configuration
        >>> env_vars = {
        ...     "LLM_PROVIDER": "watsonx",
        ...     "WATSONX_APIKEY": "wx-mock-key",
        ...     "WATSONX_URL": "https://us-south.ml.cloud.ibm.com",
        ...     "WATSONX_PROJECT_ID": "mock-project-id",
        ...     "WATSONX_MODEL_ID": "ibm/granite-13b"
        ... }
        >>> with patch.dict(os.environ, env_vars):
        ...     llm_instance, llm_config = get_llm_instance("completion")
        ...     llm_config.__class__.__name__
        'WatsonxConfig'

        >>> # Case 5: unsupported provider yields no instance and no config
        >>> with patch.dict(os.environ, {"LLM_PROVIDER": "unknown-provider"}):
        ...     get_llm_instance("chat")
        (None, None)
    """
    # Normalise so values such as "OpenAI" or " watsonx " still select the right provider.
    llm_provider = os.getenv("LLM_PROVIDER", "").strip().lower()
    llm_instance, llm_config = None, None
    logger.info("Configuring LLM instance for ToolOps , and LLM provider - " + llm_provider)
    try:
        provider_map = {
            "azure_openai": AzureOpenAIProvider,
            "openai": OpenAIProvider,
            "anthropic": AnthropicProvider,
            "aws_bedrock": AWSBedrockProvider,
            "ollama": OllamaProvider,
            "watsonx": WatsonxProvider,
        }
        provider_class = provider_map.get(llm_provider)

        # Build the provider config from environment variables. The temperature is pinned to
        # TOOLOPS_TEMPERATURE (not read from *_TEMPERATURE variables) to keep outputs stable.
        if llm_provider == "openai":
            oai_api_key = os.getenv("OPENAI_API_KEY", "")
            oai_base_url = os.getenv("OPENAI_BASE_URL", "")
            # The RITS LLM platform expects the API key in a RITS_API_KEY header.
            default_headers = {"RITS_API_KEY": oai_api_key} if "rits.fmaas.res.ibm.com" in oai_base_url else None
            llm_config = OpenAIConfig(
                api_key=oai_api_key,
                base_url=oai_base_url,
                temperature=TOOLOPS_TEMPERATURE,
                model=os.getenv("OPENAI_MODEL", ""),
                # OPENAI_MAX_TOEKNS is the historical misspelling; kept for backward compatibility.
                max_tokens=_getenv_int("OPENAI_MAX_TOKENS", 600, legacy_name="OPENAI_MAX_TOEKNS"),
                max_retries=_getenv_int("OPENAI_MAX_RETRIES", 2),
                default_headers=default_headers,
                timeout=None,
            )
        elif llm_provider == "azure_openai":
            llm_config = AzureOpenAIConfig(
                api_key=os.getenv("AZURE_OPENAI_API_KEY", ""),
                azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT", ""),
                api_version=os.getenv("AZURE_OPENAI_API_VERSION", ""),
                azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT", ""),
                model=os.getenv("AZURE_OPENAI_MODEL", ""),
                temperature=TOOLOPS_TEMPERATURE,
                max_retries=_getenv_int("AZURE_OPENAI_MAX_RETRIES", 2),
                # AZURE_OPENAI_MAX_TOEKNS is the historical misspelling; kept for backward compatibility.
                max_tokens=_getenv_int("AZURE_OPENAI_MAX_TOKENS", 600, legacy_name="AZURE_OPENAI_MAX_TOEKNS"),
                timeout=None,
            )
        elif llm_provider == "anthropic":
            llm_config = AnthropicConfig(
                api_key=os.getenv("ANTHROPIC_API_KEY", ""),
                model=os.getenv("ANTHROPIC_MODEL", ""),
                temperature=TOOLOPS_TEMPERATURE,
                max_tokens=_getenv_int("ANTHROPIC_MAX_TOKENS", 4096),
                max_retries=_getenv_int("ANTHROPIC_MAX_RETRIES", 2),
                timeout=None,
            )
        elif llm_provider == "aws_bedrock":
            llm_config = AWSBedrockConfig(
                model_id=os.getenv("AWS_BEDROCK_MODEL_ID", ""),
                region_name=os.getenv("AWS_BEDROCK_REGION", ""),
                temperature=TOOLOPS_TEMPERATURE,
                max_tokens=_getenv_int("AWS_BEDROCK_MAX_TOKENS", 4096),
                aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID", ""),
                aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY", ""),
                aws_session_token=os.getenv("AWS_SESSION_TOKEN", ""),
            )
        elif llm_provider == "ollama":
            llm_config = OllamaConfig(
                base_url=os.getenv("OLLAMA_BASE_URL", ""),
                model=os.getenv("OLLAMA_MODEL", ""),
                temperature=TOOLOPS_TEMPERATURE,
                timeout=None,
                num_ctx=None,
            )
        elif llm_provider == "watsonx":
            llm_config = WatsonxConfig(
                api_key=os.getenv("WATSONX_APIKEY", ""),
                url=os.getenv("WATSONX_URL", ""),
                project_id=os.getenv("WATSONX_PROJECT_ID", ""),
                model_id=os.getenv("WATSONX_MODEL_ID", ""),
                temperature=TOOLOPS_TEMPERATURE,
                max_new_tokens=_getenv_int("WATSONX_MAX_NEW_TOKENS", 1000),
                decoding_method=os.getenv("WATSONX_DECODING_METHOD", "greedy"),
            )
        else:
            # Previously this case returned silently; surface it so misconfiguration is visible.
            logger.warning("Unsupported or missing LLM_PROVIDER for ToolOps - %r", llm_provider)
            return None, None

        llm_service = provider_class(llm_config)
        llm_instance = llm_service.get_llm(model_type=model_type)
        logger.info("Successfully configured LLM instance for ToolOps , and LLM provider - " + llm_provider)
    except Exception as e:
        # Best-effort: callers receive None for the instance instead of an exception,
        # but the failure is logged at error level so it is not lost among info logs.
        logger.error("Error in configuring LLM instance for ToolOps -" + str(e))
    return llm_instance, llm_config


def execute_prompt(prompt):
    """
    Method for LLM inferencing using a prompt/user input.

    A completion-type LLM instance is obtained from ``get_llm_instance`` (the provider is chosen
    by the ``LLM_PROVIDER`` environment variable). Generation is stopped at a double newline,
    ``<|endoftext|>`` or ``###STOP###``. Any ``<|eom_id|>`` marker is removed from the output.

    Args:
        prompt: User-specified prompt or inputs for LLM inferencing

    Returns:
        response: LLM output response for the given prompt, stripped of surrounding whitespace,
            or an empty string if no LLM instance could be configured or inferencing failed
            (the error is logged)
    """
    try:
        logger.info("Inferencing ToolOps LLM with the given prompt")
        completion_llm_instance, _ = get_llm_instance(model_type="completion")
        if completion_llm_instance is None:
            # get_llm_instance() returns None when the provider is missing/unsupported or its
            # configuration failed; report that directly rather than via a NoneType AttributeError.
            logger.error("Error in configuring LLM for ToolOps - no LLM instance available for LLM_PROVIDER " + repr(os.getenv("LLM_PROVIDER", "")))
            return ""
        llm_response = completion_llm_instance.invoke(prompt, stop=["\n\n", "<|endoftext|>", "###STOP###"])
        # Completion models return plain text. Chat models return a message object whose
        # text is in ``.content`` (some providers may only offer chat models).
        if not isinstance(llm_response, str):
            llm_response = getattr(llm_response, "content", llm_response)
        response = str(llm_response).replace("<|eom_id|>", "").strip()
        return response
    except Exception as e:
        logger.error("Error in ToolOps LLM inferencing - " + orjson.dumps({"Error": str(e)}).decode())
        return ""


260# if __name__ == "__main__":
261# chat_llm_instance, _ = get_llm_instance(model_type="chat")
262# completion_llm_instance, _ = get_llm_instance(model_type="completion")
263# prompt = "what is India capital city?"
264# print("Prompt : ", prompt)
265# print("Text completion output : ")
266# print(execute_prompt(prompt))
267# response = chat_llm_instance.invoke(prompt)
268# print("Chat completion output : ")
269# print(response.content)