Source code for ibm_watsonx_ai.foundation_models_manager

#  -----------------------------------------------------------------------------------------
#  (C) Copyright IBM Corp. 2024.
#  https://opensource.org/licenses/BSD-3-Clause
#  -----------------------------------------------------------------------------------------

from __future__ import annotations
from enum import Enum
from functools import cached_property

from typing import TYPE_CHECKING, Generator, Literal, overload, Any
import warnings

from ibm_watsonx_ai.wml_resource import WMLResource
from ibm_watsonx_ai.messages.messages import Messages
from ibm_watsonx_ai.wml_client_error import WMLClientError
from ibm_watsonx_ai.utils.utils import StrEnum

if TYPE_CHECKING:
    from ibm_watsonx_ai import APIClient


class FoundationModelsManager(WMLResource):
    def __init__(self, client: APIClient):
        WMLResource.__init__(self, __name__, client)
        self._client = client

    @cached_property
    def TextModels(self):
        return StrEnum("TextModels", self._get_model_dict("base"))

    @cached_property
    def ChatModels(self):
        return StrEnum("ChatModels", self._get_model_dict("text_chat"))

    @cached_property
    def EmbeddingModels(self):
        return StrEnum("EmbeddingModels", self._get_model_dict("embedding"))

    @cached_property
    def PromptTunableModels(self):
        return StrEnum("PromptTunableModels", self._get_model_dict("prompt_tuning"))

    @cached_property
    def RerankModels(self):
        return StrEnum("RerankModels", self._get_model_dict("rerank"))

    @cached_property
    def TimeSeriesModels(self):
        return StrEnum("TimeSeriesModels", self._get_model_dict("time_series_forecast"))

    def _get_spec(
        self,
        url: str,
        operation_name: str,
        error_msg_id: str,
        model_id: str | None = None,
        limit: int | None = 50,
        filters: str | None = None,
        asynchronous: bool = False,
        get_all: bool = False,
        tech_preview: bool = False,
    ) -> dict | Generator | None:
        params = self._client._params(skip_userfs=True, skip_space_project_chk=True)

        if filters:
            params.update({"filters": filters})

        if tech_preview:
            params.update({"tech_preview": True})

        try:
            if model_id:
                result = self._get_with_or_without_limit(
                    url,
                    limit=None,
                    op_name=operation_name,
                    query_params=params,
                    _all=True,
                    _async=False,
                    skip_space_project_chk=True,
                )
                if isinstance(model_id, Enum):
                    model_id = model_id.value

                model_res = [
                    res for res in result["resources"] if res["model_id"] == model_id
                ]
                if len(model_res) > 0:
                    return model_res[0]
                else:
                    return None
            else:
                return self._get_with_or_without_limit(
                    url=url,
                    limit=limit,
                    op_name=operation_name,
                    query_params=params,
                    _async=asynchronous,
                    _all=get_all,
                    skip_space_project_chk=True,
                )
        except WMLClientError as e:
            raise WMLClientError(
                Messages.get_message(
                    self._client.credentials.url,
                    message_id=error_msg_id,
                ),
                e,
            )

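    # Usage sketch (illustrative, not part of the library source): the cached
    # properties above expose the currently available model IDs as StrEnum
    # members, so callers can avoid hard-coding ID strings. Assuming an
    # authenticated APIClient named ``client`` and that "google/flan-ul2" is
    # available for the account, the member names follow the key derivation in
    # ``_get_model_dict`` at the end of this module:
    #
    #     model_id = client.foundation_models.TextModels.FLAN_UL2
    #     model_id.value                                   # "google/flan-ul2"
    #     chat_ids = [m.value for m in client.foundation_models.ChatModels]
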
    def get_model_specs(
        self,
        model_id: str | None = None,
        limit: int | None = None,
        asynchronous: bool = False,
        get_all: bool = False,
        **kwargs: Any,
    ) -> dict | Generator | None:
        """
        Retrieves a list of specifications for a deployed foundation model.

        :param model_id: ID of the model, defaults to None (all models specifications are returned)
        :type model_id: str or ModelTypes, optional

        :param limit: limit number of fetched records
        :type limit: int, optional

        :param asynchronous: if True, will work as a generator
        :type asynchronous: bool, optional

        :param get_all: if True, will get all entries in 'limited' chunks
        :type get_all: bool, optional

        :return: list of specifications for the deployed foundation model
        :rtype: dict or generator

        **Example:**

        .. code-block:: python

            # GET ALL MODEL SPECS
            client.foundation_models.get_model_specs()

            # GET MODEL SPECS BY MODEL_ID
            client.foundation_models.get_model_specs(model_id="google/flan-ul2")

        """
        return self._get_spec(
            url=self._client.service_instance._href_definitions.get_fm_specifications_href(),
            operation_name="Get available foundation models",
            error_msg_id="fm_prompt_tuning_no_model_specs",
            filters=(
                None
                if self._client.CPD_version < 5.0
                else "function_text_generation,!lifecycle_withdrawn:and"
            ),
            model_id=model_id,
            limit=limit,
            asynchronous=asynchronous,
            get_all=get_all,
            tech_preview=kwargs.get("tech_preview", False),
        )

    def get_chat_model_specs(
        self,
        model_id: str | None = None,
        limit: int | None = None,
        asynchronous: bool = False,
        get_all: bool = False,
    ) -> dict | Generator | None:
        """
        Retrieves the list of chat foundation model specifications.

        :param model_id: ID of the model, defaults to None (all model specs are returned)
        :type model_id: str or ModelTypes, optional

        :param limit: limit number of fetched records
        :type limit: int, optional

        :param asynchronous: if True, will work as a generator
        :type asynchronous: bool, optional

        :param get_all: if True, will get all entries in 'limited' chunks
        :type get_all: bool, optional

        :return: list of deployed foundation model specs
        :rtype: dict or generator

        **Example:**

        .. code-block:: python

            # GET CHAT MODEL SPECS
            client.foundation_models.get_chat_model_specs()

            # GET CHAT MODEL SPECS BY MODEL_ID
            client.foundation_models.get_chat_model_specs(model_id="ibm/granite-13b-chat-v2")

        """
        return self._get_spec(
            url=self._client.service_instance._href_definitions.get_fm_specifications_href(),
            operation_name="Get available chat models",
            error_msg_id="fm_prompt_tuning_no_model_specs",
            model_id=model_id,
            filters="function_text_chat",
            limit=limit,
            asynchronous=asynchronous,
            get_all=get_all,
        )

    def get_chat_function_calling_model_specs(
        self,
        model_id: str | None = None,
        limit: int | None = None,
        asynchronous: bool = False,
        get_all: bool = False,
    ) -> dict | Generator | None:
        """
        Retrieves the list of chat foundation model specifications with function calling support.

        :param model_id: ID of the model, defaults to None (all model specs are returned)
        :type model_id: str or ModelTypes, optional

        :param limit: limit number of fetched records
        :type limit: int, optional

        :param asynchronous: if True, will work as a generator
        :type asynchronous: bool, optional

        :param get_all: if True, will get all entries in 'limited' chunks
        :type get_all: bool, optional

        :return: list of deployed foundation model specs
        :rtype: dict or generator

        **Example:**

        .. code-block:: python

            # GET CHAT FUNCTION CALLING MODEL SPECS
            client.foundation_models.get_chat_function_calling_model_specs()

            # GET CHAT FUNCTION CALLING MODEL SPECS BY MODEL_ID
            client.foundation_models.get_chat_function_calling_model_specs(model_id="meta-llama/llama-3-1-70b-instruct")

        """
        return self._get_spec(
            url=self._client.service_instance._href_definitions.get_fm_specifications_href(),
            operation_name="Get available chat function calling models",
            error_msg_id="fm_prompt_tuning_no_model_specs",
            model_id=model_id,
            filters="task_function_calling",
            limit=limit,
            asynchronous=asynchronous,
            get_all=get_all,
        )

    def get_custom_model_specs(
        self,
        model_id: str | None = None,
        limit: int | None = None,
        asynchronous: bool = False,
        get_all: bool = False,
    ) -> dict | Generator | None:
        """Get details on available custom model(s) as a dictionary or as a generator (``asynchronous``).
        If ``asynchronous`` or ``get_all`` is set, then ``model_id`` is ignored.

        :param model_id: ID of the model, defaults to None (all models specifications are returned)
        :type model_id: str, optional

        :param limit: limit number of fetched records
        :type limit: int, optional

        :param asynchronous: if True, will work as a generator
        :type asynchronous: bool, optional

        :param get_all: if True, will get all entries in 'limited' chunks
        :type get_all: bool, optional

        :return: details of supported custom models, None if no supported custom models are found for the given model_id
        :rtype: dict or generator

        **Example:**

        .. code-block:: python

            client.foundation_models.get_custom_model_specs()

            client.foundation_models.get_custom_model_specs(model_id='mistralai/Mistral-7B-Instruct-v0.2')

            client.foundation_models.get_custom_model_specs(limit=20)

            client.foundation_models.get_custom_model_specs(limit=20, get_all=True)

            for spec in client.foundation_models.get_custom_model_specs(limit=20, asynchronous=True, get_all=True):
                print(spec, end="")

        """
        if self._client.CLOUD_PLATFORM_SPACES:
            raise WMLClientError(
                Messages.get_message(message_id="custom_models_cloud_scenario")
            )

        warnings.warn(
            "Model needs to be first stored via client.repository.store_model(model_id, meta_props=metadata)"
            " and deployed via client.deployments.create(asset_id, metadata) to be used."
        )

        return self._get_spec(
            url=self._client.service_instance._href_definitions.get_fm_custom_foundation_models_href(),
            operation_name="Get custom model specs",
            error_msg_id="custom_models_no_model_specs",
            model_id=model_id,
            limit=limit,
            asynchronous=asynchronous,
            get_all=get_all,
        )

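    # Deployment sketch for custom models (illustrative only): as the warning
    # above notes, a custom model must first be stored and then deployed before
    # it can be used. Here ``metadata`` and ``asset_id`` are placeholders for
    # the model metadata and the stored model's asset ID; they are assumptions
    # for illustration, not values defined in this module:
    #
    #     stored_model = client.repository.store_model(model_id, meta_props=metadata)
    #     deployment = client.deployments.create(asset_id, metadata)
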
    def get_embeddings_model_specs(
        self,
        model_id: str | None = None,
        limit: int | None = None,
        asynchronous: bool = False,
        get_all: bool = False,
    ) -> dict | Generator | None:
        """
        Retrieves the specifications of an embeddings model.

        :param model_id: ID of the model, defaults to None (all models specifications are returned)
        :type model_id: str, optional

        :param limit: limit number of fetched records
        :type limit: int, optional

        :param asynchronous: if True, will work as a generator
        :type asynchronous: bool, optional

        :param get_all: if True, will get all entries in 'limited' chunks
        :type get_all: bool, optional

        :return: specifications of the embeddings model
        :rtype: dict or generator

        **Example:**

        .. code-block:: python

            client.foundation_models.get_embeddings_model_specs()
            client.foundation_models.get_embeddings_model_specs('ibm/slate-125m-english-rtrvr')

        """
        return self._get_spec(
            url=self._client.service_instance._href_definitions.get_fm_specifications_href(),
            operation_name="Get available embedding models",
            error_msg_id="fm_prompt_tuning_no_model_specs",
            model_id=model_id,
            filters="function_embedding",
            limit=limit,
            asynchronous=asynchronous,
            get_all=get_all,
        )

    def get_time_series_model_specs(
        self,
        model_id: str | None = None,
        limit: int | None = None,
        asynchronous: bool = False,
        get_all: bool = False,
    ) -> dict | Generator | None:
        """
        Retrieves the specifications of a time series model.

        :param model_id: ID of the model, defaults to None (all models specifications are returned)
        :type model_id: str, optional

        :param limit: limit number of fetched records
        :type limit: int, optional

        :param asynchronous: if True, will work as a generator
        :type asynchronous: bool, optional

        :param get_all: if True, will get all entries in 'limited' chunks
        :type get_all: bool, optional

        :return: specifications of the time series model
        :rtype: dict or generator

        **Example:**

        .. code-block:: python

            client.foundation_models.get_time_series_model_specs()
            client.foundation_models.get_time_series_model_specs('ibm/granite-ttm-1536-96-r2')

        """
        return self._get_spec(
            url=self._client.service_instance._href_definitions.get_fm_specifications_href(),
            operation_name="Get available time series models",
            error_msg_id="fm_prompt_tuning_no_model_specs",
            model_id=model_id,
            filters="function_time_series_forecast",
            limit=limit,
            asynchronous=asynchronous,
            get_all=get_all,
        )

    def get_rerank_model_specs(
        self,
        model_id: str | None = None,
        limit: int | None = None,
        asynchronous: bool = False,
        get_all: bool = False,
    ) -> dict | Generator | None:
        """
        Retrieves the specifications of a rerank model.

        :param model_id: ID of the model, defaults to None (all models specifications are returned)
        :type model_id: str, optional

        :param limit: limit number of fetched records
        :type limit: int, optional

        :param asynchronous: if True, will work as a generator
        :type asynchronous: bool, optional

        :param get_all: if True, will get all entries in 'limited' chunks
        :type get_all: bool, optional

        :return: specifications of the rerank model
        :rtype: dict or generator

        **Example:**

        .. code-block:: python

            client.foundation_models.get_rerank_model_specs()
            client.foundation_models.get_rerank_model_specs('ibm/slate-125m-english-rtrvr-v2')

        """
        return self._get_spec(
            url=self._client.service_instance._href_definitions.get_fm_specifications_href(),
            operation_name="Get available rerank models",
            error_msg_id="fm_prompt_tuning_no_model_specs",
            model_id=model_id,
            filters="function_rerank",
            limit=limit,
            asynchronous=asynchronous,
            get_all=get_all,
        )

    @overload
    def get_model_specs_with_prompt_tuning_support(
        self,
        model_id: str | None = ...,
        limit: int | None = ...,
        asynchronous: Literal[False] = False,
        get_all: bool = ...,
    ) -> dict | None: ...

    @overload
    def get_model_specs_with_prompt_tuning_support(
        self,
        model_id: str | None,
        limit: int | None,
        asynchronous: Literal[True],
        get_all: bool,
    ) -> Generator: ...

    def get_model_specs_with_prompt_tuning_support(
        self,
        model_id: str | None = None,
        limit: int | None = None,
        asynchronous: bool = False,
        get_all: bool = False,
    ) -> dict | Generator | None:
        """
        Queries the details of deployed foundation models with prompt tuning support.

        :param model_id: ID of the model, defaults to None (all models specifications are returned)
        :type model_id: str, optional

        :param limit: limit number of fetched records
        :type limit: int, optional

        :param asynchronous: if True, will work as a generator
        :type asynchronous: bool, optional

        :param get_all: if True, will get all entries in 'limited' chunks
        :type get_all: bool, optional

        :return: list of specifications of a deployed foundation model with prompt tuning support
        :rtype: dict or generator

        **Example:**

        .. code-block:: python

            client.foundation_models.get_model_specs_with_prompt_tuning_support()
            client.foundation_models.get_model_specs_with_prompt_tuning_support('google/flan-t5-xl')

        """
        return self._get_spec(
            url=self._client.service_instance._href_definitions.get_fm_specifications_href(),
            operation_name="Get available foundation models",
            error_msg_id="fm_prompt_tuning_no_model_specs",
            model_id=model_id,
            filters="function_prompt_tune_trainable",
            limit=limit,
            asynchronous=asynchronous,
            get_all=get_all,
        )

    def get_model_specs_with_fine_tuning_support(
        self,
        model_id: str | None = None,
        limit: int | None = None,
        asynchronous: bool = False,
        get_all: bool = False,
    ) -> dict | Generator | None:
        """
        Queries the details of deployed foundation models with fine-tuning support.

        :param model_id: ID of the model, defaults to None (all model specs are returned)
        :type model_id: str, optional

        :param limit: limit number of fetched records
        :type limit: int, optional

        :param asynchronous: if True, will work as a generator
        :type asynchronous: bool, optional

        :param get_all: if True, will get all entries in 'limited' chunks
        :type get_all: bool, optional

        :return: list of deployed foundation model specs with fine-tuning support
        :rtype: dict or generator

        **Example:**

        .. code-block:: python

            client.foundation_models.get_model_specs_with_fine_tuning_support()
            client.foundation_models.get_model_specs_with_fine_tuning_support('bigscience/bloom')

        """
        return self._get_spec(
            url=self._client.service_instance._href_definitions.get_fm_specifications_href(),
            operation_name="Get available foundation models",
            error_msg_id="fm_fine_tuning_no_model_specs",
            model_id=model_id,
            filters="function_fine_tune_trainable",
            limit=limit,
            asynchronous=asynchronous,
            get_all=get_all,
        )

    def get_model_lifecycle(self, model_id: str, **kwargs: Any) -> list | None:
        """
        Retrieves a list of lifecycle data of a foundation model.

        :param model_id: ID of the model
        :type model_id: str

        :return: list of lifecycle data of a foundation model
        :rtype: list

        **Example:**

        .. code-block:: python

            client.foundation_models.get_model_lifecycle(
                model_id="ibm/granite-13b-instruct-v2"
            )

        """
        model_spec = self.get_model_specs(
            model_id, tech_preview=kwargs.get("tech_preview", False)
        )
        return model_spec.get("lifecycle") if model_spec is not None else None

    def _get_model_dict(
        self,
        model_type: Literal[
            "base",
            "embedding",
            "prompt_tuning",
            "text_chat",
            "rerank",
            "time_series_forecast",
        ],
    ) -> dict:
        """
        Retrieves a dictionary of model names to model IDs, used to build the model Enum.

        :param model_type: type of model function
        :type model_type: Literal["base", "embedding", "prompt_tuning", "text_chat", "rerank", "time_series_forecast"]

        :return: dictionary of model names to model IDs
        :rtype: dict
        """
        function_dict = {
            "base": self.get_model_specs,
            "embedding": self.get_embeddings_model_specs,
            "prompt_tuning": self.get_model_specs_with_prompt_tuning_support,
            "text_chat": self.get_chat_model_specs,
            "rerank": self.get_rerank_model_specs,
            "time_series_forecast": self.get_time_series_model_specs,
        }

        model_specs_dict = {}
        for model_spec in function_dict[model_type]()["resources"]:
            if "model_id" in model_spec:
                model_specs_dict[
                    model_spec["model_id"].split("/")[-1].replace("-", "_").upper()
                ] = model_spec["model_id"]
        return model_specs_dict
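
# Worked example of the key derivation in ``_get_model_dict`` above
# (illustrative; the model ID is only an example):
#
#     model_id = "ibm/granite-13b-chat-v2"
#     key = model_id.split("/")[-1].replace("-", "_").upper()
#     # key == "GRANITE_13B_CHAT_V2", so the resulting enum maps
#     # ChatModels.GRANITE_13B_CHAT_V2 -> "ibm/granite-13b-chat-v2"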