# -----------------------------------------------------------------------------------------
# (C) Copyright IBM Corp. 2024.
# https://opensource.org/licenses/BSD-3-Clause
# -----------------------------------------------------------------------------------------
from __future__ import annotations
from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, Generator, Literal, overload, Any
import warnings
from ibm_watsonx_ai.wml_resource import WMLResource
from ibm_watsonx_ai.messages.messages import Messages
from ibm_watsonx_ai.wml_client_error import WMLClientError
from ibm_watsonx_ai.foundation_models.utils.utils import StrEnum
if TYPE_CHECKING:
from ibm_watsonx_ai import APIClient
[docs]
class FoundationModelsManager(WMLResource):
def __init__(self, client: APIClient):
WMLResource.__init__(self, __name__, client)
self._client = client
@cached_property
def TextModels(self):
return StrEnum("TextModels", self._get_model_dict("base"))
@cached_property
def EmbeddingModels(self):
return StrEnum("EmbeddingModels", self._get_model_dict("embedding"))
@cached_property
def PromptTunableModels(self):
return StrEnum("PromptTunableModels", self._get_model_dict("prompt_tuning"))
def _get_spec(
self,
url: str,
operation_name: str,
error_msg_id: str,
model_id: str | None = None,
limit: int | None = 50,
filters: str | None = None,
asynchronous: bool = False,
get_all: bool = False,
tech_preview: bool = False,
) -> dict | Generator | None:
params = self._client._params(skip_userfs=True, skip_space_project_chk=True)
if filters:
params.update({"filters": filters})
if tech_preview:
params.update({"tech_preview": True})
try:
if model_id:
result = self._get_with_or_without_limit(
url,
limit=None,
op_name=operation_name,
query_params=params,
_all=True,
_async=False,
skip_space_project_chk=True,
)
if isinstance(model_id, Enum):
model_id = model_id.value
model_res = [
res for res in result["resources"] if res["model_id"] == model_id
]
if len(model_res) > 0:
return model_res[0]
else:
return None
else:
return self._get_with_or_without_limit(
url=url,
limit=limit,
op_name=operation_name,
query_params=params,
_async=asynchronous,
_all=get_all,
skip_space_project_chk=True,
)
except WMLClientError as e:
raise WMLClientError(
Messages.get_message(
self._client.credentials.url,
message_id=error_msg_id,
),
e,
)
[docs]
def get_model_specs(
self,
model_id: str | None = None,
limit: int | None = None,
asynchronous: bool = False,
get_all: bool = False,
**kwargs: Any,
) -> dict | Generator | None:
"""
Operations to retrieve the list of deployed foundation models specifications.
:param model_id: Id of the model, defaults to None (all models specs are returned).
:type model_id: str or ModelTypes, optional
:param limit: limit number of fetched records
:type limit: int, optional
:param asynchronous: if True, it will work as a generator
:type asynchronous: bool, optional
:param get_all: if True, it will get all entries in 'limited' chunks
:type get_all: bool, optional
:return: list of deployed foundation model specs
:rtype: dict or generator
**Example**
.. code-block:: python
# GET ALL MODEL SPECS
client.foundation_models.get_model_specs()
# GET MODEL SPECS BY MODEL_ID
client.foundation_models.get_model_specs(model_id="google/flan-ul2")
"""
return self._get_spec(
url=self._client.service_instance._href_definitions.get_fm_specifications_href(),
operation_name="Get available foundation models",
error_msg_id="fm_prompt_tuning_no_model_specs",
filters="function_text_generation,!lifecycle_withdrawn:and",
model_id=model_id,
limit=limit,
asynchronous=asynchronous,
get_all=get_all,
tech_preview=kwargs.get("tech_preview", False),
)
[docs]
def get_custom_model_specs(
self,
model_id: str | None = None,
limit: int | None = None,
asynchronous: bool = False,
get_all: bool = False,
) -> dict | Generator | None:
"""Get details on available custom model(s) as dict or as generator (``asynchronous``).
If ``asynchronous`` or ``get_all`` is set, then ``model_id`` is ignored.
:param model_id: Id of the model, defaults to None (all models specs are returned).
:type model_id: str, optional
:param limit: limit number of fetched records
:type limit: int, optional
:param asynchronous: if True, it will work as a generator
:type asynchronous: bool, optional
:param get_all: if True, it will get all entries in 'limited' chunks
:type get_all: bool, optional
:return: details of supported custom models, None if for given model_id non is found
:rtype: dict or generator
**Example**
.. code-block:: python
client.foundation_models.get_custom_models_spec()
client.foundation_models.get_custom_models_spec()
client.foundation_models.get_custom_models_spec(model_id='mistralai/Mistral-7B-Instruct-v0.2')
client.foundation_models.get_custom_models_spec(limit=20)
client.foundation_models.get_custom_models_spec(limit=20, get_all=True)
for spec in client.foundation_models.get_custom_model_specs(limit=20, asynchronous=True, get_all=True):
print(spec, end="")
"""
warnings.warn(
"Model needs to be first stored via client.repository.store_model(model_id, meta_props=metadata)"
" and deployed via client.deployments.create(asset_id, metadata) to be used."
)
return self._get_spec(
url=self._client.service_instance._href_definitions.get_fm_custom_foundation_models_href(),
operation_name="Get custom model specs",
error_msg_id="custom_models_no_model_specs",
model_id=model_id,
limit=limit,
asynchronous=asynchronous,
get_all=get_all,
)
[docs]
def get_embeddings_model_specs(
self,
model_id: str | None = None,
limit: int | None = None,
asynchronous: bool = False,
get_all: bool = False,
) -> dict | Generator | None:
"""
Operation to retrieve the embeddings model specs.
:param model_id: Id of the model, defaults to None (all models specs are returned).
:type model_id: str, optional
:param limit: limit number of fetched records
:type limit: int, optional
:param asynchronous: if True, it will work as a generator
:type asynchronous: bool, optional
:param get_all: if True, it will get all entries in 'limited' chunks
:type get_all: bool, optional
:return: embeddings model specs
:rtype: dict or generator
**Example**
.. code-block:: python
client.foundation_models.get_embeddings_model_specs()
client.foundation_models.get_embeddings_model_specs('ibm/slate-125m-english-rtrvr')
"""
return self._get_spec(
url=self._client.service_instance._href_definitions.get_fm_specifications_href(),
operation_name="Get available embedding models",
error_msg_id="fm_prompt_tuning_no_model_specs",
model_id=model_id,
filters="function_embedding",
limit=limit,
asynchronous=asynchronous,
get_all=get_all,
)
@overload
def get_model_specs_with_prompt_tuning_support(
self,
model_id: str | None = ...,
limit: int | None = ...,
asynchronous: Literal[False] = False,
get_all: bool = ...,
) -> dict | None: ...
@overload
def get_model_specs_with_prompt_tuning_support(
self,
model_id: str | None,
limit: int | None,
asynchronous: Literal[True],
get_all: bool,
) -> Generator: ...
[docs]
def get_model_specs_with_prompt_tuning_support(
self,
model_id: str | None = None,
limit: int | None = None,
asynchronous: bool = False,
get_all: bool = False,
) -> dict | Generator | None:
"""
Operations to query the details of the deployed foundation models with prompt tuning support.
:param model_id: Id of the model, defaults to None (all models specs are returned).
:type model_id: str, optional
:param limit: limit number of fetched records
:type limit: int, optional
:param asynchronous: if True, it will work as a generator
:type asynchronous: bool, optional
:param get_all: if True, it will get all entries in 'limited' chunks
:type get_all: bool, optional
:return: list of deployed foundation model specs with prompt tuning support
:rtype: dict or generator
**Example**
.. code-block:: python
client.foundation_models.get_model_specs_with_prompt_tuning_support()
client.foundation_models.get_model_specs_with_prompt_tuning_support('google/flan-t5-xl')
"""
return self._get_spec(
url=self._client.service_instance._href_definitions.get_fm_specifications_href(),
operation_name="Get available foundation models",
error_msg_id="fm_prompt_tuning_no_model_specs",
model_id=model_id,
filters="function_prompt_tune_trainable",
limit=limit,
asynchronous=asynchronous,
get_all=get_all,
)
[docs]
def get_model_lifecycle(self, model_id: str, **kwargs: Any) -> list | None:
"""
Operation to retrieve the list of model lifecycle data.
:param model_id: the type of model to use
:type model_id: str
:return: list of deployed foundation model lifecycle data
:rtype: list
**Example**
.. code-block:: python
client.foundation_models.get_model_lifecycle(
model_id="ibm/granite-13b-instruct-v2"
)
"""
model_spec = self.get_model_specs(
model_id, tech_preview=kwargs.get("tech_preview", False)
)
return model_spec.get("lifecycle") if model_spec is not None else None
def _get_model_dict(
self, model_type: Literal["base", "embedding", "prompt_tuning"]
) -> dict:
"""
Operation to retrieve the dictionary of models to Enum
:param model_type: Type of model function
:type model_type: Literal["base", "embedding", "prompt_tuning"]
:return: dict of models to Enum
:rtype: dict
"""
function_dict = {
"base": self.get_model_specs,
"embedding": self.get_embeddings_model_specs,
"prompt_tuning": self.get_model_specs_with_prompt_tuning_support,
}
model_specs_dict = {}
for model_spec in function_dict[model_type]()["resources"]:
if "model_id" in model_spec:
model_specs_dict[
model_spec["model_id"].split("/")[-1].replace("-", "_").upper()
] = model_spec["model_id"]
return model_specs_dict