# Source code for ibm_watson_machine_learning.foundation_models.extensions.langchain.llm

#  -----------------------------------------------------------------------------------------
#  (C) Copyright IBM Corp. 2023-2024.
#  https://opensource.org/licenses/BSD-3-Clause
#  -----------------------------------------------------------------------------------------
import logging
from typing import Any, List, Mapping, Optional

try:
    from langchain.llms.base import LLM
    from langchain.llms.utils import enforce_stop_tokens
except ImportError:
    raise ImportError("Could not import langchain: Please install langchain extension.")
from ibm_watson_machine_learning.foundation_models import Model, ModelInference
from ibm_watson_machine_learning.foundation_models.utils.utils import _raise_watsonxllm_deprecation_warning

logger = logging.getLogger(__name__)


class WatsonxLLM(LLM):
    """
    `LangChain CustomLLM <https://python.langchain.com/docs/modules/model_io/models/llms/custom_llm>`_ wrapper for watsonx foundation models.

    :param model: foundation model inference object instance
    :type model: Model or ModelInference

    **Supported chain types:**

    * `LLMChain`
    * `TransformChain`
    * `SequentialChain`
    * `SimpleSequentialChain`
    * `ConversationChain` (including `ConversationBufferMemory`)
    * `LLMMathChain` (``bigscience/mt0-xxl``, ``eleutherai/gpt-neox-20b``, ``ibm/mpt-7b-instruct2``, ``bigcode/starcoder``, ``meta-llama/llama-2-70b-chat``, ``ibm/granite-13b-instruct-v1`` models only)

    **Instantiate the WatsonxLLM interface**

    .. code-block:: python

        from ibm_watson_machine_learning.foundation_models import Model
        from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
        from ibm_watson_machine_learning.foundation_models.extensions.langchain import WatsonxLLM

        generate_params = {
            GenParams.MAX_NEW_TOKENS: 25
        }

        model = Model(
            model_id="google/flan-ul2",
            credentials={
                "apikey": "***",
                "url": "https://us-south.ml.cloud.ibm.com"
            },
            params=generate_params,
            project_id="*****"
        )

        custom_llm = WatsonxLLM(model=model)
    """

    # Executed once, when this class body runs (i.e. when the module is imported),
    # not on every instantiation.
    _raise_watsonxllm_deprecation_warning()

    # Wrapped inference object; set by __init__ and used by every call below.
    model: Model | ModelInference = None
    # Value reported through the ``_llm_type`` property.
    llm_type: str = "IBM watsonx.ai"

    def __init__(self, model: Model | ModelInference) -> None:
        """Wrap ``model`` so it can be used as a LangChain LLM.

        :param model: watsonx.ai inference object whose ``generate_text`` serves requests
        :type model: Model or ModelInference
        """
        super().__init__()
        self.model = model

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters (delegated to the wrapped model)."""
        return self.model.get_identifying_params()

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return self.llm_type

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        """Call the IBM watsonx.ai inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words. The generated text is cut at the
                first occurrence of any of them. ``None`` or an empty list
                leaves the text untouched.
            **kwargs: Extra keyword arguments supplied through LangChain
                (e.g. ``llm.invoke(prompt, **kwargs)``); forwarded unchanged to
                ``model.generate_text``.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                model = Model(
                    model_id="google/flan-ul2",
                    credentials={
                        "apikey": "***",
                        "url": "https://us-south.ml.cloud.ibm.com"
                    },
                    project_id="*****"
                )
                llm = WatsonxLLM(model=model)
                response = llm("What is a molecule")
        """
        # LangChain invokes ``_call(prompt, stop=stop, **kwargs)`` when the signature
        # has no ``run_manager`` parameter, so kwargs must be accepted here or any
        # user-supplied kwarg raises TypeError.
        text = self.model.generate_text(prompt=prompt, **kwargs)
        logger.info("Output of watsonx.ai call: %s", text)
        # Truthiness check rather than ``is not None``: enforce_stop_tokens joins the
        # stop words into a regex alternation, and an empty list yields an empty
        # pattern that matches at position 0, truncating the output to "".
        if stop:
            text = enforce_stop_tokens(text, stop)
        return text