# -----------------------------------------------------------------------------------------
# (C) Copyright IBM Corp. 2023-2024.
# https://opensource.org/licenses/BSD-3-Clause
# -----------------------------------------------------------------------------------------
from __future__ import annotations
import warnings
import asyncio
from enum import Enum
import types
import copy
import functools
from string import Formatter
from typing import (
TYPE_CHECKING,
Any,
Sequence,
cast,
Mapping,
Type,
Generator,
Callable,
)
from dataclasses import dataclass, KW_ONLY, asdict
from json import loads as json_loads
from warnings import warn, simplefilter, catch_warnings
from pprint import pprint
import threading
from ibm_watsonx_ai._wrappers import requests
from ibm_watsonx_ai.helpers import DataConnection
from ibm_watsonx_ai.messages.messages import Messages
from ibm_watsonx_ai.wml_client_error import (
WMLClientError,
InvalidMultipleArguments,
InvalidValue,
)
from ibm_watsonx_ai.utils import next_resource_generator, get_user_agent_header
from ibm_watsonx_ai.utils.autoai.utils import load_file_from_file_system_nonautoai
from ibm_watsonx_ai.utils.autoai.enums import DataConnectionTypes
from ibm_watsonx_ai.lifecycle import SpecStates
if TYPE_CHECKING:
from ibm_watsonx_ai import APIClient
from types import TracebackType, FrameType
from asyncio import AbstractEventLoop
from concurrent.futures import Future
@dataclass
class PromptTuningParams:
base_model: dict
_: KW_ONLY
accumulate_steps: int | None = None
batch_size: int | None = None
init_method: str | None = None
init_text: str | None = None
learning_rate: float | None = None
max_input_tokens: int | None = None
max_output_tokens: int | None = None
num_epochs: int | None = None
task_id: str | None = None
tuning_type: str | None = None
verbalizer: str | None = None
def to_dict(self) -> dict:
return {key: value for key, value in asdict(self).items() if value is not None}
@dataclass
class FineTuningParams:
base_model: dict
task_id: str
_: KW_ONLY
num_epochs: int | None = None
learning_rate: float | None = None
batch_size: int | None = None
max_seq_length: int | None = None
accumulate_steps: int | None = None
verbalizer: str | None = None
response_template: str | None = None
gpu: dict | None = None
def to_dict(self) -> dict:
return {key: value for key, value in asdict(self).items() if value is not None}
def _get_foundation_models_spec(
    url: str, operation_name: str, additional_params: dict | None = None
) -> dict:
    """Fetch a foundation-model spec listing from ``url`` and return the JSON body.

    :param url: full endpoint URL to query
    :param operation_name: human-readable operation name used in error messages
    :param additional_params: extra query parameters merged on top of the
        pinned API ``version``
    :raises WMLClientError: on 404 (no foundation models available on this
        environment) or any other non-200 response
    """
    query_params = {"version": "2023-09-30", **(additional_params or {})}
    response = requests.get(
        url, params=query_params, headers={"User-Agent": get_user_agent_header()}
    )
    if response.status_code == 404:
        raise WMLClientError(
            Messages.get_message(message_id="fm_prompt_tuning_no_foundation_models"),
            logg_messages=False,
        )
    if response.status_code != 200:
        raise WMLClientError(f"{operation_name} failed. Reason: {response.text}")
    return response.json()
def get_model_specs(url: str, model_id: str | None = None) -> dict:
    """
    Retrieve the list of deployed foundation models specifications.

    **Deprecated:** From ``ibm_watsonx_ai`` 1.0, the `get_model_specs()` function is deprecated, use the `client.foundation_models.get_model_specs()` function instead.

    :param url: URL of the environment
    :type url: str

    :param model_id: ID of the model, defaults to None (all models specs are returned).
    :type model_id: Optional[str, ModelTypes], optional

    :return: list of deployed foundation model specs
    :rtype: dict

    **Example:**

    .. code-block:: python

        from ibm_watsonx_ai.foundation_models import get_model_specs

        # GET ALL MODEL SPECS
        get_model_specs(
            url="https://us-south.ml.cloud.ibm.com"
        )

        # GET MODEL SPECS BY MODEL_ID
        get_model_specs(
            url="https://us-south.ml.cloud.ibm.com",
            model_id="google/flan-ul2"
        )
    """
    warn(
        "`get_model_specs()` function is deprecated from 1.0, please use `client.foundation_models.get_model_specs()` function instead.",
        category=DeprecationWarning,
    )
    try:
        # Prefer the v1 endpoint; older CPD environments only expose v1-beta,
        # so fall back to it when the v1 call fails.
        try:
            specs = _get_foundation_models_spec(
                f"{url}/ml/v1/foundation_model_specs",
                "Get available foundation models",
            )
        except WMLClientError:  # Remove on CPD 5.0 release
            specs = _get_foundation_models_spec(
                f"{url}/ml/v1-beta/foundation_model_specs",
                "Get available foundation models",
            )
        if model_id:
            if isinstance(model_id, Enum):
                model_id = model_id.value
            # Single-model lookup; raises IndexError when the id is unknown
            # (matches the historical behavior of this function).
            return [
                res for res in specs["resources"] if res["model_id"] == model_id
            ][0]
        return specs
    except WMLClientError as e:
        raise WMLClientError(
            Messages.get_message(url, message_id="fm_prompt_tuning_no_model_specs"),
            e.reason,
        )
def get_custom_model_specs(
    credentials: dict | None = None,
    api_client: APIClient | None = None,
    model_id: str | None = None,
    limit: int = 100,
    asynchronous: bool = False,
    get_all: bool = False,
    verify: str | bool | None = None,
) -> dict | Generator:
    """Get details on available custom model(s) as dict or as generator (``asynchronous``).
    If ``asynchronous`` or ``get_all`` is set, then ``model_id`` is ignored.

    **Deprecated:** From ``ibm_watsonx_ai`` 1.0, the `get_custom_model_specs()` function is deprecated, use `client.foundation_models.get_custom_model_specs()` function instead.

    :param credentials: credentials to watsonx.ai instance
    :type credentials: dict, optional

    :param api_client: API client to connect to service
    :type api_client: APIClient, optional

    :param model_id: Id of the model, defaults to None (all models specs are returned).
    :type model_id: str, optional

    :param limit: limit number of fetched records. Possible values: 1 ≤ value ≤ 200, default value: 100
    :type limit: int, optional

    :param asynchronous: if True, it will work as a generator
    :type asynchronous: bool, optional

    :param get_all: if True, it will get all entries in 'limited' chunks
    :type get_all: bool, optional

    :param verify: You can pass one of following as verify:

        * the path to a CA_BUNDLE file
        * the path of directory with certificates of trusted CAs
        * `True` - default path to truststore will be taken
        * `False` - no verification will be made

    :type verify: bool or str, optional

    :return: details of supported custom models
    :rtype: dict

    **Example:**

    .. code-block:: python

        from ibm_watsonx_ai.foundation_models import get_custom_model_specs

        get_custom_model_specs(api_client=client)
        get_custom_model_specs(credentials=credentials)
        get_custom_model_specs(api_client=client, model_id='mistralai/Mistral-7B-Instruct-v0.2')
        get_custom_model_specs(api_client=client, limit=20)
        get_custom_model_specs(api_client=client, limit=20, get_all=True)
        for spec in get_custom_model_specs(api_client=client, limit=20, asynchronous=True, get_all=True):
            print(spec, end="")
    """
    warn(
        "`get_custom_model_specs()` function is deprecated from 1.0, please use `client.foundation_models.get_custom_model_specs()` function instead.",
        category=DeprecationWarning,
    )
    warnings.warn(
        "Model needs to be first stored via client.repository.store_model(model_id, meta_props=metadata)"
        " and deployed via client.deployments.create(asset_id, metadata) to be used."
    )
    # At least one of `credentials` / `api_client` must be provided;
    # `credentials` takes precedence when both are given.
    if credentials is None and api_client is None:
        raise InvalidMultipleArguments(
            params_names_list=["credentials", "api_client"],
            reason="None of the arguments were provided.",
        )
    elif credentials:
        from ibm_watsonx_ai import APIClient

        client = APIClient(credentials, verify=verify)
    else:
        client = api_client  # type: ignore[assignment]
    url = client.credentials.url
    params = client._params(skip_for_create=True, skip_userfs=True)
    # Page size accepted by the service is 1..200 inclusive.
    if limit < 1 or limit > 200:
        raise InvalidValue(
            value_name="limit",
            reason=f"The given value {limit} is not in the range <1, 200>",
        )
    else:
        params.update({"limit": limit})
    url = cast(str, url)
    if asynchronous or get_all:
        # Paged retrieval path; `model_id` filtering is intentionally skipped
        # here (see docstring).
        resource_generator = next_resource_generator(
            client,
            url=url,
            href="ml/v4/custom_foundation_models",
            params=params,
            _all=get_all,
        )
        if asynchronous:
            return resource_generator
        resources = []
        for entry in resource_generator:
            resources.extend(entry["resources"])
        return {"resources": resources}
    response = requests.get(
        f"{url}/ml/v4/custom_foundation_models",
        params=params,
        headers=client._get_headers(),
    )
    if response.status_code == 200:
        if model_id:
            resources = [
                res
                for res in response.json()["resources"]
                if res["model_id"] == model_id
            ]
            # Unknown model id yields an empty dict rather than an error.
            return resources[0] if resources else {}
        else:
            return response.json()
    elif response.status_code == 404:
        raise WMLClientError(
            Messages.get_message(url, message_id="custom_models_no_model_specs"), url
        )
    else:
        msg = f"Getting failed. Reason: {response.text}"
        raise WMLClientError(msg)
def get_model_lifecycle(url: str, model_id: str) -> list | None:
    """
    Retrieve the list of model lifecycle data.

    **Deprecated:** From ``ibm_watsonx_ai`` 1.0, the `get_model_lifecycle()` function is deprecated, use `client.foundation_models.get_model_lifecycle()` function instead.

    :param url: URL of environment
    :type url: str

    :param model_id: the type of model to use
    :type model_id: str

    :return: list of deployed foundation model lifecycle data, or None when
        ``model_id`` is not present in the environment's specs
    :rtype: list

    **Example:**

    .. code-block:: python

        from ibm_watsonx_ai.foundation_models import get_model_lifecycle

        get_model_lifecycle(
            url="https://us-south.ml.cloud.ibm.com",
            model_id="ibm/granite-13b-instruct-v2"
        )
    """
    warn(
        "`get_model_lifecycle()` function is deprecated from 1.0, please use `client.foundation_models.get_model_lifecycle()` function instead.",
        category=DeprecationWarning,
    )
    # Reuse the full specs listing and pick out this model's lifecycle entry.
    model_specs = get_model_specs(url=url)
    model_spec = next(
        (
            model_metadata
            for model_metadata in model_specs.get("resources", [])
            if model_metadata.get("model_id") == model_id
        ),
        None,
    )
    return model_spec.get("lifecycle") if model_spec is not None else None
def _check_model_state(
    client: APIClient,
    model_id: str,
    tech_preview: bool = False,
    model_specs: dict | None = None,
) -> None:
    """Emit a ``LifecycleWarning`` when ``model_id`` is deprecated or constricted.

    Lifecycle data is taken from ``model_specs`` when provided; otherwise it is
    fetched from the service (GA API when available, legacy helper otherwise).
    At most one warning is emitted; the deprecated state takes precedence over
    the constricted state.

    :param client: API client used to fetch lifecycle data when needed
    :param model_id: id of the model to check
    :param tech_preview: forwarded to the GA lifecycle lookup
    :param model_specs: pre-fetched specs listing to search instead of calling
        the service
    """
    # Template used when start/withdrawal dates are known.
    default_warning_template = (
        "Model '{model_id}' is in {state} state from {start_date} until {withdrawn_start_date}. "
        "IDs of alternative models: {alternative_model_ids}. "
        "Further details: https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-model-lifecycle.html?context=wx&audience=wdp"
    )
    # Template for CPD-style entries (marked by "since_version"), which carry
    # no date information.
    deprecated_or_constricted_warning_template_cpd = (
        "Model '{model_id}' is in {state} state. "
        "IDs of alternative models: {alternative_model_ids}. "
        "Further details: https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-model-lifecycle.html?context=wx&audience=wdp"
    )
    if model_specs is not None:
        lifecycle = None
        for model_spec in model_specs["resources"]:
            if model_spec["model_id"] == model_id:
                lifecycle = model_spec.get("lifecycle")
                break
    elif client._use_fm_ga_api:
        lifecycle = client.foundation_models.get_model_lifecycle(
            model_id, tech_preview=tech_preview
        )
    else:
        # Legacy path: suppress the deprecation warning raised by the old
        # module-level helper.
        with catch_warnings():
            simplefilter("ignore", category=DeprecationWarning)
            lifecycle = get_model_lifecycle(client.credentials.url, model_id)  # type: ignore[arg-type]
    modes_list = [ids.get("id") for ids in (lifecycle or [])]

    # Check states in priority order; `break` ensures a single warning.
    for state in (SpecStates.DEPRECATED, SpecStates.CONSTRICTED):
        if not lifecycle or state.value not in modes_list:
            continue
        model_lifecycle = next(
            (el for el in lifecycle if el.get("id") == state.value),
            None,
        )
        model_lifecycle = cast(dict, model_lifecycle)
        if model_lifecycle.get("since_version"):
            warnings.warn(
                deprecated_or_constricted_warning_template_cpd.format(
                    model_id=model_id,
                    state=(model_lifecycle.get("label") or state.value),
                    alternative_model_ids=", ".join(
                        model_lifecycle.get("alternative_model_ids", ["None"])
                    ),
                ),
                category=LifecycleWarning,
            )
        else:
            warnings.warn(
                default_warning_template.format(
                    model_id=model_id,
                    state=(model_lifecycle.get("label") or state.value),
                    start_date=model_lifecycle.get("start_date"),
                    withdrawn_start_date=next(
                        (
                            el.get("start_date")
                            for el in lifecycle
                            if el.get("id") == SpecStates.WITHDRAWN.value
                        ),
                        None,
                    ),
                    alternative_model_ids=", ".join(
                        model_lifecycle.get("alternative_model_ids", ["None"])
                    ),
                ),
                category=LifecycleWarning,
            )
        break
def get_model_specs_with_prompt_tuning_support(url: str) -> dict:
    """
    Query the details of the deployed foundation models with prompt tuning support.

    **Deprecated:** From ``ibm_watsonx_ai`` 1.0, the `get_model_specs_with_prompt_tuning_support()` function is deprecated, use the `client.foundation_models.get_model_specs_with_prompt_tuning_support()` function instead.

    :param url: URL of environment
    :type url: str

    :return: list of deployed foundation model specs with prompt tuning support
    :rtype: dict

    **Example:**

    .. code-block:: python

        from ibm_watsonx_ai.foundation_models import get_model_specs_with_prompt_tuning_support

        get_model_specs_with_prompt_tuning_support(
            url="https://us-south.ml.cloud.ibm.com"
        )
    """
    warn(
        "`get_model_specs_with_prompt_tuning_support()` function is deprecated from 1.0, please use `client.foundation_models.get_model_specs_with_prompt_tuning_support()` function instead.",
        category=DeprecationWarning,
    )
    try:
        # Prefer the v1 endpoint; fall back to v1-beta on older CPD releases.
        try:
            return _get_foundation_models_spec(
                url=f"{url}/ml/v1/foundation_model_specs",
                operation_name="Get available foundation models",
                additional_params={"filters": "function_prompt_tune_trainable"},
            )
        except WMLClientError:  # Remove on CPD 5.0 release
            return _get_foundation_models_spec(
                url=f"{url}/ml/v1-beta/foundation_model_specs",
                operation_name="Get available foundation models",
                additional_params={"filters": "function_prompt_tune_trainable"},
            )
    except WMLClientError as e:
        raise WMLClientError(
            Messages.get_message(url, message_id="fm_prompt_tuning_no_model_specs"),
            e.reason,
        )
def get_supported_tasks(url: str) -> dict:
    """
    Retrieves a list of tasks that are supported by the foundation models.

    :param url: URL of the environment
    :type url: str

    :return: list of tasks that are supported by the foundation models
    :rtype: dict

    **Example:**

    .. code-block:: python

        from ibm_watsonx_ai.foundation_models import get_supported_tasks

        get_supported_tasks(
            url="https://us-south.ml.cloud.ibm.com"
        )
    """
    try:
        # Prefer the v1 endpoint; fall back to v1-beta on older CPD releases.
        try:
            return _get_foundation_models_spec(
                f"{url}/ml/v1/foundation_model_tasks",
                "Get tasks that are supported by the foundation models.",
            )
        except WMLClientError:  # Remove on CPD 5.0 release
            return _get_foundation_models_spec(
                f"{url}/ml/v1-beta/foundation_model_tasks",
                "Get tasks that are supported by the foundation models.",
            )
    except WMLClientError as e:
        raise WMLClientError(
            Messages.get_message(url, message_id="fm_prompt_tuning_no_supported_tasks"),
            e.reason,
        )
def get_all_supported_tasks_dict(
    url: str = "https://us-south.ml.cloud.ibm.com",
) -> dict:
    """Map normalized task labels (UPPER_SNAKE_CASE) to task ids for ``url``."""
    task_specs = get_supported_tasks(url).get("resources", [])
    return {
        spec["label"].replace("-", "_").replace(" ", "_").upper(): spec["task_id"]
        for spec in task_specs
    }
def load_request_json(
    run_id: str,
    api_client: APIClient,
    run_params: dict[str, Any] | None = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """Download and parse the model request JSON stored for training ``run_id``.

    :param run_id: id of the training run whose model request is loaded
    :param api_client: API client to connect to service
    :param run_params: training run details; fetched from the service when None
    :param kwargs: accepts the deprecated ``wml_client`` alias for ``api_client``
    :raises WMLClientError: when no client is provided, when the run has no
        ``model_request_path`` yet, or when the environment is unsupported
    :return: parsed model request JSON
    """
    # note: backward compatibility
    if (wml_client := kwargs.get("wml_client")) is None and api_client is None:
        raise WMLClientError("No `api_client` provided")
    elif wml_client is not None:
        if api_client is None:
            api_client = wml_client
        warnings.warn(
            (
                "`wml_client` parameter is deprecated and will be removed in future. "
                "Instead, please use `api_client`."
            ),
            category=DeprecationWarning,
        )
    # --- end note
    if run_params is None:
        run_params = api_client.training.get_details(run_id)
    # Path to the stored model request inside the run's results reference.
    model_request_path = (
        run_params["entity"]
        .get("results_reference", {})
        .get("location", {})
        .get("model_request_path")
    )
    if model_request_path is None:
        raise WMLClientError(
            "Missing model_request_path in run_params. Verify if the training run has been completed."
        )
    if api_client.CLOUD_PLATFORM_SPACES:
        results_reference = DataConnection._from_dict(
            run_params["entity"]["results_reference"]
        )
        # Connection-asset references address files by name; other reference
        # types use a path.
        if run_params["entity"]["results_reference"]["type"] == DataConnectionTypes.CA:
            results_reference.location.file_name = model_request_path  # type: ignore
        else:
            results_reference.location.path = model_request_path  # type: ignore[union-attr]
        results_reference.set_client(api_client)
        request_json_bytes = results_reference.read(raw=True, binary=True)
        # download from cos
    elif api_client.CPD_version >= 4.8:
        # CPD: the request JSON lives on the file system; strip everything up
        # to and including the "assets" path segment to get the asset URL.
        asset_parts = model_request_path.split("/")
        model_request_asset_url = "/".join(
            asset_parts[asset_parts.index("assets") + 1 :]
        )
        request_json_bytes = load_file_from_file_system_nonautoai(
            api_client=api_client, file_path=model_request_asset_url
        ).read()
    else:
        raise WMLClientError("Unsupported environment for this action")
    request_json_bytes = cast(bytes, request_json_bytes)
    return json_loads(request_json_bytes.decode())
def is_training_prompt_tuning(
    training_id: str | None = None, api_client: APIClient | None = None, **kwargs: Any
) -> bool:
    """Returns True if training_id is connected to prompt tuning"""
    # note: backward compatibility
    wml_client = kwargs.get("wml_client")
    if wml_client is None and api_client is None:
        raise WMLClientError("No `api_client` provided")
    if wml_client is not None:
        if api_client is None:
            api_client = wml_client
        warnings.warn(
            (
                "`wml_client` parameter is deprecated and will be removed in future. "
                "Instead, please use `api_client`."
            ),
            category=DeprecationWarning,
        )
    # --- end note
    if training_id is None:
        return False
    details = api_client.training.get_details(training_id=training_id)  # type: ignore[union-attr]
    return bool(details["entity"].get("prompt_tuning"))
class TemplateFormatter(Formatter):
    """``string.Formatter`` that rejects keyword arguments the template never uses."""

    def check_unused_args(
        self, used_args: set[(int | str)], args: Sequence, kwargs: Mapping[str, Any]
    ) -> None:
        """Raise ``KeyError`` listing any kwargs not consumed by the format string."""
        unused = {name for name in kwargs if name not in used_args}
        if unused:
            raise KeyError(unused)
# Dedicated warning categories so callers can filter each kind independently
# with the standard `warnings` machinery (e.g. warnings.simplefilter).
class HAPDetectionWarning(UserWarning): ...
class PIIDetectionWarning(UserWarning): ...
class LifecycleWarning(UserWarning): ...
class WatsonxLLMDeprecationWarning(UserWarning): ...
def _raise_watsonxllm_deprecation_warning() -> None:
    """Warn that WatsonxLLM moved to langchain-ibm, attributing the caller's site."""
    message = (
        "ibm_watsonx_ai.foundation_models.extensions.langchain.WatsonxLLM"
        " is deprecated and will not be supported in the future. "
        "Please import from langchain-ibm instead.\n"
        "To install langchain-ibm run `pip install -U langchain-ibm`."
    )
    # stacklevel=2 points the warning at the code that invoked this helper.
    warnings.warn(message, category=WatsonxLLMDeprecationWarning, stacklevel=2)
def get_embedding_model_specs(url: str) -> dict:
    """Retrieve the specs of the available embedding models.

    **Deprecated:** From ``ibm_watsonx_ai`` 1.0, the `get_embedding_model_specs()` function is deprecated, use the `client.foundation_models.get_embeddings_model_specs()` function instead.

    :param url: URL of the environment
    :type url: str

    :return: specs of the available embedding models
    :rtype: dict
    """
    warn(
        "`get_embedding_model_specs()` function is deprecated from 1.0, please use `client.foundation_models.get_embeddings_model_specs()` function instead.",
        category=DeprecationWarning,
    )
    return _get_foundation_models_spec(
        url=f"{url}/ml/v1/foundation_model_specs",
        operation_name="Get available embedding models",
        additional_params={"filters": "function_embedding"},
    )
def _copy_function(func: Callable) -> Callable:
    """Return an independent copy of ``func``.

    The copy shares the code object and globals but gets its own deep-copied
    keyword-only defaults, so mutating them on one copy does not leak into
    the other; ``functools.update_wrapper`` carries over the metadata.
    """
    duplicate = types.FunctionType(
        func.__code__,
        func.__globals__,
        name=func.__name__,
        argdefs=func.__defaults__,
        closure=func.__closure__,
    )
    duplicate = functools.update_wrapper(duplicate, func)  # type: ignore[assignment]
    duplicate.__kwdefaults__ = copy.deepcopy(func.__kwdefaults__)
    return duplicate
def _is_fine_tuning_endpoint_available(api_client: APIClient) -> bool:
    """Probe whether the fine-tunings endpoint responds with HTTP 200.

    Best-effort check: any failure (network error, missing endpoint definition,
    non-200 status) is reported as "not available" rather than raised.
    """
    try:
        url = api_client.service_instance._href_definitions.get_fine_tunings_href()
        response_fine_tuning_api = api_client._session.get(
            url=f"{url}?limit=1",
            params=api_client._params(),
            headers=api_client._get_headers(),
        )
        return response_fine_tuning_api.status_code == 200
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; the probe stays deliberately best-effort.
        return False