# -----------------------------------------------------------------------------------------
# (C) Copyright IBM Corp. 2023-2024.
# https://opensource.org/licenses/BSD-3-Clause
# -----------------------------------------------------------------------------------------
import logging
from typing import Any, List, Mapping, Optional, Union

try:
    from langchain.llms.base import LLM
    from langchain.llms.utils import enforce_stop_tokens
except ImportError:
    raise ImportError(
        "Could not import langchain. Please install the langchain extension."
    )

from ibm_watson_machine_learning.foundation_models import Model, ModelInference
from ibm_watson_machine_learning.foundation_models.utils.utils import (
    _raise_watsonxllm_deprecation_warning,
)

logger = logging.getLogger(__name__)

class WatsonxLLM(LLM):
"""
`LangChain CustomLLM <https://python.langchain.com/docs/modules/model_io/models/llms/custom_llm>`_ wrapper for watsonx foundation models.
:param model: foundation model inference object instance
:type model: Model
**Supported chain types:**
* `LLMChain`,
* `TransformChain`,
* `SequentialChain`,
* `SimpleSequentialChain`
* `ConversationChain` (including `ConversationBufferMemory`)
* `LLMMathChain` (``bigscience/mt0-xxl``, ``eleutherai/gpt-neox-20b``, ``ibm/mpt-7b-instruct2``, ``bigcode/starcoder``, ``meta-llama/llama-2-70b-chat``, ``ibm/granite-13b-instruct-v1`` models only)
**Instantiate the WatsonxLLM interface**
.. code-block:: python
from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
from ibm_watson_machine_learning.foundation_models.extensions.langchain import WatsonxLLM
generate_params = {
GenParams.MAX_NEW_TOKENS: 25
}
model = Model(
model_id="google/flan-ul2",
credentials={
"apikey": "***",
"url": "https://us-south.ml.cloud.ibm.com"
},
params=generate_params,
project_id="*****"
)
custom_llm = WatsonxLLM(model=model)
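
    **Use within an LLMChain** (a minimal sketch, assuming a ``langchain``
    version that ships ``LLMChain``; the prompt template below is illustrative):

    .. code-block:: python

        from langchain.chains import LLMChain
        from langchain.prompts import PromptTemplate

        # Hypothetical one-variable prompt; any template works here.
        prompt = PromptTemplate(
            input_variables=["topic"],
            template="Write a one-sentence summary of {topic}.",
        )
        chain = LLMChain(llm=custom_llm, prompt=prompt)
        response = chain.run(topic="photosynthesis")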
"""
    _raise_watsonxllm_deprecation_warning()

    model: Optional[Union[Model, ModelInference]] = None
    llm_type: str = "IBM watsonx.ai"

    def __init__(self, model: Union[Model, ModelInference]) -> None:
        super().__init__()
        self.model = model

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return self.model.get_identifying_params()

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return self.llm_type

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call the IBM watsonx.ai inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                model = Model(
                    model_id="google/flan-ul2",
                    credentials={
                        "apikey": "***",
                        "url": "https://us-south.ml.cloud.ibm.com"
                    },
                    project_id="*****"
                )
                llm = WatsonxLLM(model=model)
                response = llm("What is a molecule?")
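
                # Stop sequences are optional; the output is truncated at the
                # first occurrence of any stop string ("\n\n" is illustrative):
                response = llm("What is a molecule?", stop=["\n\n"])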
"""
        text = self.model.generate_text(prompt=prompt)
        logger.info("Output of watsonx.ai call: %s", text)
        if stop is not None:
            text = enforce_stop_tokens(text, stop)
        return text