# -----------------------------------------------------------------------------------------
# (C) Copyright IBM Corp. 2023-2025.
# https://opensource.org/licenses/BSD-3-Clause
# -----------------------------------------------------------------------------------------
from __future__ import annotations
from typing import TYPE_CHECKING, cast
from ibm_watsonx_ai.foundation_models.base_tuner import BaseTuner
from ibm_watsonx_ai.messages.messages import Messages
from ibm_watsonx_ai.wml_resource import WMLResource
from ibm_watsonx_ai.wml_client_error import WMLClientError
from ibm_watsonx_ai.helpers.connections import (
DataConnection,
)
from ibm_watsonx_ai.utils.autoai.utils import is_ipython
from ibm_watsonx_ai.foundation_models.utils import PromptTuningParams
import datetime
if TYPE_CHECKING:
from ibm_watsonx_ai import APIClient
from pandas import DataFrame
[docs]
# Tuner that configures and launches a prompt tuning run of a foundation model.
# It extends BaseTuner and performs every remote call through `self._client.training`.
class PromptTuner(BaseTuner):
# ID of the training run; stays None until `run()` stores the value returned by `training.get_id`.
id: str | None = None
# API client used for all training calls. The class-level default is None;
# NOTE(review): presumably injected by the code that creates the tuner (not visible here) - confirm.
_client: APIClient = None # type: ignore[assignment]
# Training request payload; built by `_initialize_training_metadata()` and None before that.
_training_metadata: dict | None = None
def __init__(
    self,
    name: str,
    task_id: str,
    *,
    description: str | None = None,
    base_model: str | None = None,
    accumulate_steps: int | None = None,
    batch_size: int | None = None,
    init_method: str | None = None,
    init_text: str | None = None,
    learning_rate: float | None = None,
    max_input_tokens: int | None = None,
    max_output_tokens: int | None = None,
    num_epochs: int | None = None,
    verbalizer: str | None = None,
    tuning_type: str | None = None,
    auto_update_model: bool = True,
    group_by_name: bool | None = None,
):
    """Store the configuration of a prompt tuning run; no request is sent here.

    :param name: name of the tuning run
    :type name: str

    :param task_id: task identifier, forwarded to ``PromptTuningParams``
    :type task_id: str

    :param description: description of the run, when empty or not provided
        the text "Prompt tuning with SDK" is used
    :type description: str, optional

    :param base_model: ID of the base model, sent as ``{"model_id": base_model}``
    :type base_model: str, optional

    :param accumulate_steps: forwarded unchanged to ``PromptTuningParams``
    :type accumulate_steps: int, optional

    :param batch_size: forwarded unchanged to ``PromptTuningParams``
    :type batch_size: int, optional

    :param init_method: forwarded unchanged to ``PromptTuningParams``
    :type init_method: str, optional

    :param init_text: forwarded unchanged to ``PromptTuningParams``
    :type init_text: str, optional

    :param learning_rate: forwarded unchanged to ``PromptTuningParams``
    :type learning_rate: float, optional

    :param max_input_tokens: forwarded unchanged to ``PromptTuningParams``
    :type max_input_tokens: int, optional

    :param max_output_tokens: forwarded unchanged to ``PromptTuningParams``
    :type max_output_tokens: int, optional

    :param num_epochs: forwarded unchanged to ``PromptTuningParams``
    :type num_epochs: int, optional

    :param verbalizer: forwarded unchanged to ``PromptTuningParams``
    :type verbalizer: str, optional

    :param tuning_type: forwarded unchanged to ``PromptTuningParams``
    :type tuning_type: str, optional

    :param auto_update_model: value sent as the AUTO_UPDATE_MODEL property of the training
    :type auto_update_model: bool, optional

    :param group_by_name: when True, tags of an earlier prompt tuning run
        with the same name are reused
    :type group_by_name: bool, optional

    :raises WMLClientError: when `name`, `description`, `auto_update_model`
        or `group_by_name` has an unexpected type
    """
    # Validate before touching any state so a rejected call leaves nothing half-built.
    if not isinstance(name, str):
        raise WMLClientError(
            f"'name' param expected string, but got {type(name)}: {name}"
        )

    if description and not isinstance(description, str):
        raise WMLClientError(
            f"'description' param expected string, but got {type(description)}: "
            f"{description}"
        )

    # Compare against None instead of relying on truthiness: falsy non-bool
    # values such as 0, "" or [] must not slip through the type check.
    if auto_update_model is not None and not isinstance(auto_update_model, bool):
        raise WMLClientError(
            f"'auto_update_model' param expected bool, but got {type(auto_update_model)}: "
            f"{auto_update_model}"
        )

    if group_by_name is not None and not isinstance(group_by_name, bool):
        raise WMLClientError(
            f"'group_by_name' param expected bool, but got {type(group_by_name)}: "
            f"{group_by_name}"
        )

    BaseTuner.__init__(self, "prompt")

    self.name = name
    self.description = description if description else "Prompt tuning with SDK"
    self.auto_update_model = auto_update_model
    self.group_by_name = group_by_name

    # The base model is sent as a nested object keyed by "model_id".
    base_model_red: dict = {"model_id": base_model}

    self.prompt_tuning_params = PromptTuningParams(
        base_model=base_model_red,
        accumulate_steps=accumulate_steps,
        batch_size=batch_size,
        init_method=init_method,
        init_text=init_text,
        learning_rate=learning_rate,
        max_input_tokens=max_input_tokens,
        max_output_tokens=max_output_tokens,
        num_epochs=num_epochs,
        task_id=task_id,
        tuning_type=tuning_type,
        verbalizer=verbalizer,
    )
[docs]
def run(
    self,
    training_data_references: list[DataConnection],
    training_results_reference: DataConnection | None = None,
    background_mode: bool = False,
) -> dict:
    """Run a prompt tuning process of a foundation model on top of the training data referenced by DataConnection.

    :param training_data_references: data storage connection details to inform where the training data is stored
    :type training_data_references: list[DataConnection]

    :param training_results_reference: data storage connection details to store pipeline training results
    :type training_results_reference: DataConnection, optional

    :param background_mode: indicator if the run() method will run in the background, async or sync;
        the value is passed as ``asynchronous`` to ``training.run``
    :type background_mode: bool, optional

    :return: run details
    :rtype: dict

    **Example:**

    .. code-block:: python

        from ibm_watsonx_ai.experiment import TuneExperiment
        from ibm_watsonx_ai.helpers import DataConnection, S3Location

        experiment = TuneExperiment(credentials, ...)
        prompt_tuner = experiment.prompt_tuner(...)

        prompt_tuner.run(
            training_data_references=[
                DataConnection(
                    connection_asset_id=connection_id,
                    location=S3Location(
                        bucket='prompt_tuning_data',
                        path='pt_train_data.json',
                    ),
                )
            ],
            background_mode=False,
        )
    """
    WMLResource._validate_type(
        training_data_references, "training_data_references", list, mandatory=True
    )
    WMLResource._validate_type(
        training_results_reference,
        "training_results_reference",
        object,
        mandatory=False,
    )

    # An empty list carries nothing to validate.
    if training_data_references:
        self._validate_source_data_connections(training_data_references)

    training_results_reference = self._determine_result_reference(
        results_reference=training_results_reference,
        data_references=training_data_references,
    )

    self._initialize_training_metadata(
        training_data_references,
        test_data_references=None,
        training_results_reference=training_results_reference,
    )

    # The call above always populates the metadata; the cast only narrows the type.
    self._training_metadata = cast(dict, self._training_metadata)

    tuning_details = self._client.training.run(
        meta_props=self._training_metadata, asynchronous=background_mode
    )

    # Remember the run ID so that the status/details/cancel methods can be used afterwards.
    self.id = self._client.training.get_id(tuning_details)

    return self._client.training.get_details(
        self.id
    )  # TODO improve the background_mode = False option
def _initialize_training_metadata(
    self,
    training_data_references: list[DataConnection],
    test_data_references: list[DataConnection] | None = None,
    training_results_reference: DataConnection | None = None,
) -> None:
    """Assemble the training request payload and store it in ``self._training_metadata``.

    :param training_data_references: connections describing the training data
    :type training_data_references: list[DataConnection]

    :param test_data_references: connections describing the test data, skipped when empty
    :type test_data_references: list[DataConnection], optional

    :param training_results_reference: connection describing where results are stored,
        skipped when not provided
    :type training_results_reference: DataConnection, optional
    """
    meta_names = self._client.training.ConfigurationMetaNames

    # Mandatory part of the payload.
    payload = {
        meta_names.TAGS: self._get_tags(),
        meta_names.TRAINING_DATA_REFERENCES: [
            reference._to_dict() for reference in training_data_references
        ],
        # Only the first 100 characters of the name are sent.
        meta_names.NAME: f"{self.name[:100]}",
        meta_names.PROMPT_TUNING: self.prompt_tuning_params.to_dict(),
    }
    self._training_metadata = payload

    # Optional entries are added only when a value is available.
    if test_data_references:
        payload[meta_names.TEST_DATA_REFERENCES] = [
            reference._to_dict() for reference in test_data_references
        ]

    if training_results_reference:
        payload[meta_names.TRAINING_RESULTS_REFERENCE] = (
            training_results_reference._to_dict()
        )

    if self.description:
        payload[meta_names.DESCRIPTION] = f"{self.description}"

    if self.auto_update_model is not None:
        payload[meta_names.AUTO_UPDATE_MODEL] = self.auto_update_model
def _get_tags(self) -> list:
    """Return the tags to attach to the training request.

    The result always starts with ``"prompt_tuning"``. When ``group_by_name``
    is enabled, the tags of the first listed prompt tuning run with the same
    name are appended (without duplicates, in their original order).
    If any extra tag was picked up, ``generate_ux_tag`` is switched off
    on the client.

    :return: tags for the training request
    :rtype: list
    """
    default_tags = ["prompt_tuning"]
    tags = list(default_tags)

    if self.group_by_name:
        trainings = self._client.training.get_details(tag_value="prompt_tuning")[
            "resources"
        ]
        for training in trainings:
            metadata = training["metadata"]
            if metadata.get("name") == self.name:
                # Find recent tags related to 'name'. A run may carry no tags
                # at all, so fall back to an empty list instead of failing.
                previous_tags = metadata.get("tags") or []
                # dict.fromkeys removes duplicates while keeping a stable order.
                tags = list(dict.fromkeys(default_tags + list(previous_tags)))
                break

    if tags != default_tags:
        self._client.generate_ux_tag = False

    return tags
[docs]
def get_params(self) -> dict:
    """Get configuration parameters of PromptTuner.

    The result is the dictionary form of the prompt tuning parameters
    extended with the tuner-level settings.

    :return: PromptTuner parameters
    :rtype: dict

    **Example:**

    .. code-block:: python

        from ibm_watsonx_ai.experiment import TuneExperiment

        experiment = TuneExperiment(credentials, ...)
        prompt_tuner = experiment.prompt_tuner(...)

        prompt_tuner.get_params()

        # Result (illustrative values):
        #
        # {'base_model': {'model_id': 'google/flan-t5-xl'},
        #  'task_id': 'summarization',
        #  'name': 'Prompt Tuning of Flan T5 model',
        #  'description': 'Prompt tuning with SDK',
        #  'auto_update_model': False,
        #  'group_by_name': False}
    """
    params = self.prompt_tuning_params.to_dict()
    # Tuner-level settings are layered on top of the tuning parameters.
    params.update(
        {
            "name": self.name,
            "description": self.description,
            "auto_update_model": self.auto_update_model,
            "group_by_name": self.group_by_name,
        }
    )
    return params
#####################
# Run operations #
#####################
[docs]
def get_run_status(self) -> str:
    """Check the status/state of an initialized prompt tuning run if it was run in background mode.

    :return: status of the Prompt Tuning run
    :rtype: str

    :raises WMLClientError: when no run has been started yet (no run ID is stored)

    **Example:**

    .. code-block:: python

        from ibm_watsonx_ai.experiment import TuneExperiment

        experiment = TuneExperiment(credentials, ...)
        prompt_tuner = experiment.prompt_tuner(...)
        prompt_tuner.run(...)

        prompt_tuner.get_run_status()

        # Result:
        # 'completed'
    """
    run_id = self.id
    if run_id is None:
        raise WMLClientError(
            Messages.get_message(message_id="fm_prompt_tuning_not_scheduled")
        )

    status = self._client.training.get_status(training_id=run_id)
    return status.get("state")  # type: ignore[return-value]
[docs]
def get_run_details(self, include_metrics: bool = False) -> dict:
    """Get details of a prompt tuning run.

    :param include_metrics: indicates to include metrics in the training details output
    :type include_metrics: bool, optional

    :return: details of the prompt tuning
    :rtype: dict

    :raises WMLClientError: when no run has been started yet (no run ID is stored)

    **Example:**

    .. code-block:: python

        from ibm_watsonx_ai.experiment import TuneExperiment

        experiment = TuneExperiment(credentials, ...)
        prompt_tuner = experiment.prompt_tuner(...)
        prompt_tuner.run(...)

        prompt_tuner.get_run_details()
    """
    if self.id is None:
        raise WMLClientError(
            Messages.get_message(message_id="fm_prompt_tuning_not_scheduled")
        )

    details = self._client.training.get_details(training_id=self.id)

    if include_metrics:
        try:
            details["entity"]["status"]["metrics"] = (
                self._get_metrics_data_from_property_or_file(details)
            )
        except Exception:
            # Best effort: when metrics cannot be loaded the details are
            # returned as received. Only `Exception` is handled here, so
            # KeyboardInterrupt/SystemExit are no longer swallowed the way
            # a `return` inside `finally` would swallow them.
            pass
        return details

    # Metrics were not requested: drop them from the output if present.
    details["entity"]["status"].pop("metrics", None)
    return details
[docs]
def plot_learning_curve(self) -> None:
"""Plot learning curves.
The curves are drawn from the metrics of the current run
(fetched with ``get_run_details(include_metrics=True)``).
.. note ::
Available only for Jupyter notebooks.
:raises WMLClientError: when called outside of IPython/Jupyter, or when
no metrics are available for the run
**Example:**
.. code-block:: python
from ibm_watsonx_ai.experiment import TuneExperiment
experiment = TuneExperiment(credentials, ...)
prompt_tuner = experiment.prompt_tuner(...)
prompt_tuner.run(...)
prompt_tuner.plot_learning_curve()
"""
# Refuse to run outside of an IPython session.
if not is_ipython():
raise WMLClientError(
"Function `plot_learning_curve` is available only for Jupyter notebooks."
)
# Plotting dependencies are imported only when the method is actually used.
from ibm_watsonx_ai.utils.autoai.incremental import plot_learning_curve
import matplotlib.pyplot as plt
tuning_details = self.get_run_details(include_metrics=True)
if "metrics" in tuning_details["entity"]["status"]:
# average loss score for each epoch
scores = self._get_average_loss_score_for_each_epoch(
tuning_details=tuning_details, epoch=0
)
# date_time from the first and last iteration on each epoch
# Two metric layouts are handled: entries that nest values under "data"
# (timestamp parsed without a UTC offset) and entries that keep "timestamp"
# at the top level (parsed with a UTC offset, %z).
if "data" in tuning_details["entity"]["status"]["metrics"][0]:
date_times = [
datetime.datetime.strptime(
m_obj["data"]["timestamp"], "%Y-%m-%dT%H:%M:%S.%f"
)
for m_obj in self._get_first_and_last_iteration_metrics_for_each_epoch(
tuning_details=tuning_details
)
]
else:
date_times = [
datetime.datetime.strptime(
m_obj["timestamp"], "%Y-%m-%dT%H:%M:%S.%f%z"
)
for m_obj in self._get_first_and_last_iteration_metrics_for_each_epoch(
tuning_details=tuning_details
)
]
# Timestamps are consumed in consecutive pairs (index i - 1 and i for odd i);
# presumably each pair is the first and last iteration of one epoch - confirm
# against `_get_first_and_last_iteration_metrics_for_each_epoch`.
elapsed_time = []
for i in range(1, len(date_times), 2):
elapsed_time.append((date_times[i] - date_times[i - 1]).total_seconds())
# One row with three plots; the figure is created before `scores` is checked.
fig, axes = plt.subplots(1, 3, figsize=(18, 4))
if scores:
plot_learning_curve(
fig=fig,
axes=axes,
scores=scores,
fit_times=elapsed_time,
xlabels={"first_xlabel": "Epochs", "second_xlabel": "Epochs"},
titles={"first_plot": "Loss function"},
)
# NOTE(review): indentation is lost in this view; this `else` presumably pairs with
# the `"metrics" in ...` check above (error raised when the run has no metrics) - confirm.
else:
raise WMLClientError(
Messages.get_message(message_id="fm_prompt_tuning_no_metrics")
)
[docs]
def summary(self, scoring: str = "loss") -> DataFrame:
    """Return the details of PromptTuner models (prompt-tuned models) as a DataFrame.

    :param scoring: scoring metric shown in the last column,
        when not provided, uses loss one
    :type scoring: string, optional

    :return: computed models and metrics
    :rtype: pandas.DataFrame

    :raises WMLClientError: when no run has been started yet, or when the run
        has no metrics available

    **Example:**

    .. code-block:: python

        from ibm_watsonx_ai.experiment import TuneExperiment

        experiment = TuneExperiment(credentials, ...)
        prompt_tuner = experiment.prompt_tuner(...)
        prompt_tuner.run(...)

        prompt_tuner.summary()

        # Result:
        #                          Enhancements         Base model  ...         loss
        #        Model Name
        #  Prompt_tuned_M_1      [prompt_tuning]  google/flan-t5-xl  ...     0.449197
    """
    if self.id is None:
        raise WMLClientError(
            Messages.get_message(message_id="fm_prompt_tuning_not_scheduled")
        )

    from pandas import DataFrame

    details = self.get_run_details(include_metrics=True)
    entity = details["entity"]

    # `or [{}]` covers both a missing "metrics" entry and an empty list, so the
    # dedicated "no metrics" error below is raised instead of an IndexError.
    metrics_history = entity["status"].get("metrics") or [{}]
    first_metrics = metrics_history[0]

    # Metric values live either under "data" or under "ml_metrics".
    uses_data_key = "data" in first_metrics
    if not (uses_data_key or "ml_metrics" in first_metrics):
        raise WMLClientError(
            Messages.get_message(message_id="fm_prompt_tuning_no_metrics")
        )

    if scoring == "loss":
        # Last element of the per-epoch average loss scores.
        score = self._get_average_loss_score_for_each_epoch(
            tuning_details=details, epoch=0
        )[-1]
    else:
        # Any other metric is read from the most recent metrics entry.
        metrics_key = "data" if uses_data_key else "ml_metrics"
        score = metrics_history[-1].get(metrics_key, {})[scoring]

    enhancements = []
    base_model_name = None
    epochs = None
    if "prompt_tuning" in entity:
        tuning = entity["prompt_tuning"]
        enhancements = [tuning["tuning_type"]]
        base_model_name = tuning["base_model"]["model_id"]
        epochs = tuning["num_epochs"]

    columns = [
        "Model Name",
        "Enhancements",
        "Base model",
        "Auto store",
        "Epochs",
        scoring,
    ]
    row = [
        "model_" + self.id,
        enhancements,
        base_model_name,
        entity["auto_update_model"],
        epochs,
        score,
    ]

    return DataFrame(data=[row], columns=columns).set_index("Model Name")
[docs]
def get_model_id(self) -> str:
    """Get the model ID.

    :return: ID of the model read from the run details
    :rtype: str

    :raises WMLClientError: when the run details report ``auto_update_model``
        as disabled, so no model ID is available

    **Example:**

    .. code-block:: python

        from ibm_watsonx_ai.experiment import TuneExperiment

        experiment = TuneExperiment(credentials, ...)
        prompt_tuner = experiment.prompt_tuner(...)
        prompt_tuner.run(...)

        prompt_tuner.get_model_id()
    """
    entity = self.get_run_details()["entity"]

    # Without auto update there is no model ID to hand out.
    if not entity["auto_update_model"]:
        raise WMLClientError(
            Messages.get_message(message_id="fm_prompt_tuning_no_model_id")
        )

    return entity["model_id"]
[docs]
def cancel_run(self, hard_delete: bool = False) -> None:
    """Cancel or delete a Prompt Tuning run.

    :param hard_delete: if True, the completed or cancelled prompt tuning run is deleted,
                        if False, the current run is canceled. Default: False
    :type hard_delete: bool, optional

    :raises WMLClientError: when no run has been started yet (no run ID is stored)
    """
    run_id = self.id
    if run_id is None:
        raise WMLClientError(
            Messages.get_message(message_id="fm_prompt_tuning_not_scheduled")
        )

    training_api = self._client.training
    training_api.cancel(training_id=run_id, hard_delete=hard_delete)
[docs]
def get_data_connections(self) -> list[DataConnection]:
    """Create DataConnection objects for further usage
    (eg. to handle data storage connection).

    The connections are rebuilt from the ``training_data_references`` of the
    run details; each one gets the tuner's client and the run ID attached.

    :return: list of DataConnections
    :rtype: list['DataConnection']

    **Example:**

    .. code-block:: python

        from ibm_watsonx_ai.experiment import TuneExperiment

        experiment = TuneExperiment(credentials, ...)
        prompt_tuner = experiment.prompt_tuner(...)
        prompt_tuner.run(...)

        data_connections = prompt_tuner.get_data_connections()
    """
    entity = self.get_run_details()["entity"]

    # First rebuild every connection from its dictionary form ...
    data_connections = []
    for reference in entity["training_data_references"]:
        data_connections.append(DataConnection._from_dict(_dict=reference))

    # ... then bind each of them to this tuner's client and run.
    for connection in data_connections:
        connection.set_client(self._client)
        connection._run_id = self.id

    return data_connections