# -----------------------------------------------------------------------------------------
# (C) Copyright IBM Corp. 2023-2024.
# https://opensource.org/licenses/BSD-3-Clause
# -----------------------------------------------------------------------------------------
from typing import List, Dict
from ibm_watson_machine_learning import APIClient
from ibm_watson_machine_learning.messages.messages import Messages
from ibm_watson_machine_learning.wml_resource import WMLResource
from ibm_watson_machine_learning.wml_client_error import WMLClientError
from ibm_watson_machine_learning.utils.autoai.errors import ContainerTypeNotSupported
from ibm_watson_machine_learning.helpers.connections import (DataConnection, ContainerLocation, S3Connection,
S3Location, FSLocation, AssetLocation)
from ibm_watson_machine_learning.utils.autoai.utils import is_ipython
from ibm_watson_machine_learning.foundation_models.utils import PromptTuningParams
import datetime
import numpy as np
class PromptTuner:
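    """Prompt tuning of a foundation model; instances are created via ``TuneExperiment.prompt_tuner()``."""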
id: str = None
_client: APIClient = None
_training_metadata: dict = None
def __init__(self,
name: str,
task_id: str,
*,
description: str = None,
base_model: str = None,
accumulate_steps: int = None,
batch_size: int = None,
init_method: str = None,
init_text: str = None,
learning_rate: float = None,
max_input_tokens: int = None,
max_output_tokens: int = None,
num_epochs: int = None,
verbalizer: str = None,
tuning_type: str = None,
auto_update_model: bool = True,
group_by_name: bool = None):
self.name = name
self.description = description if description else "Prompt tuning with SDK"
self.auto_update_model = auto_update_model
self.group_by_name = group_by_name
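        # note: wrap the plain base model id into the dict form expected by PromptTuningParams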
base_model = {'model_id': base_model}
self.prompt_tuning_params = PromptTuningParams(base_model=base_model,
accumulate_steps=accumulate_steps,
batch_size=batch_size,
init_method=init_method,
init_text=init_text,
learning_rate=learning_rate,
max_input_tokens=max_input_tokens,
max_output_tokens=max_output_tokens,
num_epochs=num_epochs,
task_id=task_id,
tuning_type=tuning_type,
verbalizer=verbalizer)
if not isinstance(self.name, str):
raise WMLClientError(f"'name' param expected string, but got {type(self.name)}: {self.name}")
if self.description and (not isinstance(self.description, str)):
raise WMLClientError(f"'description' param expected string, but got {type(self.description)}: "
f"{self.description}")
if self.auto_update_model and (not isinstance(self.auto_update_model, bool)):
raise WMLClientError(f"'auto_update_model' param expected bool, but got {type(self.auto_update_model)}: "
f"{self.auto_update_model}")
if self.group_by_name and (not isinstance(self.group_by_name, bool)):
raise WMLClientError(f"'group_by_name' param expected bool, but got {type(self.group_by_name)}: "
f"{self.group_by_name}")
def run(self,
training_data_references: List['DataConnection'],
training_results_reference: DataConnection = None,
background_mode=False
) -> dict:
"""Run a prompt tuning process of foundation model on top of the training data referenced by DataConnection.
:param training_data_references: data storage connection details to inform where training data is stored
:type training_data_references: list[DataConnection]
:param training_results_reference: data storage connection details to store pipeline training results
:type training_results_reference: DataConnection, optional
        :param background_mode: indicates whether the run() method should run in the background (async) or block until the run finishes (sync)
:type background_mode: bool, optional
:return: run details
:rtype: dict
**Example**
.. code-block:: python
from ibm_watson_machine_learning.experiment import TuneExperiment
from ibm_watson_machine_learning.helpers import DataConnection, S3Location
experiment = TuneExperiment(credentials, ...)
prompt_tuner = experiment.prompt_tuner(...)
prompt_tuner.run(
            training_data_references=[DataConnection(
                connection_asset_id=connection_id,
                location=S3Location(
                    bucket='prompt_tuning_data',
                    path='pt_train_data.json')
                )],
            background_mode=False)
"""
WMLResource._validate_type(training_data_references, "training_data_references", list, mandatory=True)
WMLResource._validate_type(training_results_reference, "training_results_reference", object, mandatory=False)
        if training_data_references:
            self._validate_source_data_connections(training_data_references)
training_results_reference = self._determine_result_reference(results_reference=training_results_reference,
data_references=training_data_references)
self._initialize_training_metadata(training_data_references, test_data_references=None,
training_results_reference=training_results_reference)
tuning_details = self._client.training.run(meta_props=self._training_metadata, asynchronous=background_mode)
self.id = self._client.training.get_id(tuning_details)
return self._client.training.get_details(self.id) # TODO improve the background_mode = False option
def _initialize_training_metadata(self,
training_data_references: List['DataConnection'],
test_data_references: List['DataConnection'] = None,
training_results_reference: DataConnection = None,
):
self._training_metadata = {
self._client.training.ConfigurationMetaNames.TAGS: self._get_tags(),
self._client.training.ConfigurationMetaNames.TRAINING_DATA_REFERENCES:
[connection._to_dict() for connection in training_data_references],
self._client.training.ConfigurationMetaNames.NAME: f"{self.name[:100]}",
self._client.training.ConfigurationMetaNames.PROMPT_TUNING: self.prompt_tuning_params.to_dict()
}
if test_data_references:
self._training_metadata[self._client.training.ConfigurationMetaNames.TEST_DATA_REFERENCES] = [
connection._to_dict() for connection in test_data_references]
if training_results_reference:
self._training_metadata[
self._client.training.ConfigurationMetaNames.TRAINING_RESULTS_REFERENCE] = \
training_results_reference._to_dict()
if self.description:
self._training_metadata[
self._client.training.ConfigurationMetaNames.DESCRIPTION] = f"{self.description}"
if self.auto_update_model is not None:
self._training_metadata[
self._client.training.ConfigurationMetaNames.AUTO_UPDATE_MODEL] = self.auto_update_model
def _validate_source_data_connections(self, source_data_connections):
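        # note: Container locations are not supported on CPD, and inline S3
        # credentials are stripped so they are not sent with the training request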
for data_connection in source_data_connections:
if isinstance(data_connection.location, ContainerLocation):
if self._client.ICP:
raise ContainerTypeNotSupported() # block Container type on CPD
elif isinstance(data_connection.connection, S3Connection):
# note: remove S3 inline credential from data asset before training
data_connection.connection = None
if hasattr(data_connection.location, 'bucket'):
delattr(data_connection.location, 'bucket')
# --- end note
if isinstance(data_connection.connection, S3Connection) and isinstance(data_connection.location,
AssetLocation):
# note: remove S3 inline credential from data asset before training
data_connection.connection = None
for s3_attr in ['bucket', 'path']:
if hasattr(data_connection.location, s3_attr):
delattr(data_connection.location, s3_attr)
# --- end note
return source_data_connections
def _determine_result_reference(self, results_reference, data_references, result_path="default_tuning_output"):
# note: if user did not provide results storage information, use default ones
if results_reference is None:
if self._client.ICP:
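                # note: on CPD, results are stored on the cluster file system under
                # the current space or project assets directory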
location = FSLocation(path="/{option}/{id}/assets/wx_prompt_tune")
if self._client.default_project_id is None:
location.path = location.path.format(option='spaces',
id=self._client.default_space_id)
else:
location.path = location.path.format(option='projects',
id=self._client.default_project_id)
results_reference = DataConnection(
connection=None,
location=location
)
else:
if isinstance(data_references[0].location, S3Location):
results_reference = DataConnection(
connection=data_references[0].connection,
location=S3Location(bucket=data_references[0].location.bucket,
path=".")
)
elif isinstance(data_references[0].location, AssetLocation):
connection_id = data_references[0].location._get_connection_id(self._client)
if connection_id is not None:
results_reference = DataConnection(
connection_asset_id=connection_id,
location=S3Location(
bucket=data_references[0].location._get_bucket(self._client),
path=result_path)
)
                    else:  # set Container output location when a default Data Asset is used as the training reference
results_reference = DataConnection(
location=ContainerLocation(path=result_path))
else:
results_reference = DataConnection(location=ContainerLocation(path=result_path))
# -- end note
# note: validate location types:
if self._client.ICP:
if not isinstance(results_reference.location,
FSLocation):
                raise TypeError('Unsupported results location type. Results reference can be stored only on FSLocation.')
else:
if not isinstance(results_reference.location,
(S3Location, ContainerLocation)):
raise TypeError('Unsupported results location type. Results reference can be stored'
' only on S3Location or ContainerLocation.')
# -- end note
return results_reference
def _get_tags(self):
tags = ['prompt_tuning']
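        # note: when 'group_by_name' is set, reuse the tags of an earlier run
        # with the same name so that related runs stay grouped together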
        if self.group_by_name:
for training in self._client.training.get_details(tag_value='prompt_tuning')['resources']:
if training['metadata'].get('name') == self.name:
# Find recent tags related to 'name'
tags = list(set(tags) | set(training['metadata'].get('tags')))
break
if tags != ['prompt_tuning']:
self._client.generate_ux_tag = False
return tags
@staticmethod
def _get_last_iteration_metrics_for_each_epoch(tuning_details):
last_iteration_metrics_for_each_epoch = []
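        # note: metrics arrive once per iteration; whenever the epoch number repeats,
        # the previously kept entry is replaced, so only the last iteration of each epoch survives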
for ind in range(len(tuning_details['entity']['status']['metrics'])):
if ind == 0:
last_iteration_metrics_for_each_epoch.append(tuning_details['entity']['status']['metrics'][0])
else:
if tuning_details['entity']['status']['metrics'][ind]['ml_metrics']['epoch'] == \
tuning_details['entity']['status']['metrics'][ind - 1]['ml_metrics']['epoch']:
last_iteration_metrics_for_each_epoch.pop()
last_iteration_metrics_for_each_epoch.append(tuning_details['entity']['status']['metrics'][ind])
else:
last_iteration_metrics_for_each_epoch.append(tuning_details['entity']['status']['metrics'][ind])
return last_iteration_metrics_for_each_epoch
@staticmethod
def _get_average_loss_score_for_each_epoch(tuning_details):
scores = []
temp_score = []
epoch = 0
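        # note: accumulate per-iteration loss values until the epoch counter advances,
        # then store the epoch average; metrics come in two payload shapes:
        # 'data' (loss under 'value') or 'ml_metrics' (loss under 'loss')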
if "data" in tuning_details['entity']['status']['metrics'][0]:
for ind, metric in enumerate(tuning_details['entity']['status']['metrics']):
if int(metric['data']['epoch']) == epoch:
temp_score.append(metric['data']['value'])
else:
epoch += 1
scores.append(np.average(temp_score))
temp_score = [metric['data']['value']]
scores.append(np.average(temp_score))
else:
for ind, metric in enumerate(tuning_details['entity']['status']['metrics']):
if int(metric['ml_metrics']['epoch']) == epoch:
temp_score.append(metric['ml_metrics']['loss'])
else:
epoch += 1
scores.append(np.average(temp_score))
temp_score = [metric['ml_metrics']['loss']]
scores.append(np.average(temp_score))
return scores
@staticmethod
def _get_first_and_last_iteration_metrics_for_each_epoch(tuning_details):
first_and_last_iteration_metrics_for_each_epoch = []
first_iteration = True
tuning_metrics = tuning_details['entity']['status']['metrics']
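        # note: collect [first, last] metric pairs for each epoch; plot_learning_curve
        # uses consecutive pairs to compute per-epoch elapsed time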
for ind in range(len(tuning_metrics)):
if ind == 0:
first_and_last_iteration_metrics_for_each_epoch.append(tuning_metrics[ind])
first_and_last_iteration_metrics_for_each_epoch.append(tuning_metrics[ind])
first_iteration = False
elif first_iteration:
first_and_last_iteration_metrics_for_each_epoch.append(tuning_metrics[ind])
first_iteration = False
else:
if tuning_metrics[ind].get("data", tuning_metrics[ind].get("ml_metrics"))['epoch'] == tuning_metrics[ind - 1].get("data", tuning_metrics[ind-1].get("ml_metrics"))['epoch']:
first_and_last_iteration_metrics_for_each_epoch.pop()
first_and_last_iteration_metrics_for_each_epoch.append(tuning_metrics[ind])
else:
first_and_last_iteration_metrics_for_each_epoch.append(tuning_metrics[ind])
first_iteration = True
return first_and_last_iteration_metrics_for_each_epoch
def get_params(self) -> dict:
"""Get configuration parameters of PromptTuner.
:return: PromptTuner parameters
:rtype: dict
**Example**
.. code-block:: python
from ibm_watson_machine_learning.experiment import TuneExperiment
experiment = TuneExperiment(credentials, ...)
prompt_tuner = experiment.prompt_tuner(...)
prompt_tuner.get_params()
# Result:
#
# {'base_model': {'name': 'google/flan-t5-xl'},
# 'task_id': 'summarization',
# 'name': 'Prompt Tuning of Flan T5 model',
# 'auto_update_model': False,
# 'group_by_name': False}
"""
params = self.prompt_tuning_params.to_dict()
params['name'] = self.name
params['description'] = self.description
params['auto_update_model'] = self.auto_update_model
params['group_by_name'] = self.group_by_name
return params
#####################
# Run operations #
#####################
def get_run_status(self):
"""Check status/state of initialized Prompt Tuning run if ran in background mode.
:return: Prompt tuning run status
:rtype: str
**Example**
.. code-block:: python
from ibm_watson_machine_learning.experiment import TuneExperiment
experiment = TuneExperiment(credentials, ...)
prompt_tuner = experiment.prompt_tuner(...)
prompt_tuner.run(...)
            prompt_tuner.get_run_status()
# Result:
# 'completed'
"""
if self.id is None:
raise WMLClientError(Messages.get_message(message_id="fm_prompt_tuning_not_scheduled"))
return self._client.training.get_status(training_uid=self.id).get('state')
def get_run_details(self, include_metrics: bool = False) -> dict:
"""Get prompt tuning run details.
        :param include_metrics: indicates whether metrics should be included in the training details output
:type include_metrics: bool, optional
:return: Prompt tuning details
:rtype: dict
**Example**
.. code-block:: python
from ibm_watson_machine_learning.experiment import TuneExperiment
experiment = TuneExperiment(credentials, ...)
prompt_tuner = experiment.prompt_tuner(...)
prompt_tuner.run(...)
prompt_tuner.get_run_details()
"""
if self.id is None:
raise WMLClientError(Messages.get_message(message_id="fm_prompt_tuning_not_scheduled"))
details = self._client.training.get_details(training_uid=self.id)
        if include_metrics:
            try:
                details["entity"]["status"]["metrics"] = self._get_metrics_data_from_property_or_file(details)
            except KeyError:
                pass
            return details
if details['entity']['status'].get('metrics', False):
del details['entity']['status']['metrics']
return details
def _get_metrics_data_from_property_or_file(self, details: Dict) -> Dict:
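        # note: full per-iteration metrics are persisted as a JSON file in the results
        # reference; the file path is read from the first metric entry's context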
path = details["entity"]["status"]["metrics"][0]["context"]["prompt_tuning"]["metrics_location"]
results_reference = details["entity"]['results_reference']
conn = DataConnection._from_dict(results_reference)
conn._wml_client = self._client
metrics_data = conn._download_json_file(path)
return metrics_data
def plot_learning_curve(self):
"""Plot learning curves.
        .. note::
Available only for Jupyter notebooks.
**Example**
.. code-block:: python
from ibm_watson_machine_learning.experiment import TuneExperiment
experiment = TuneExperiment(credentials, ...)
prompt_tuner = experiment.prompt_tuner(...)
prompt_tuner.run(...)
prompt_tuner.plot_learning_curve()
"""
if not is_ipython():
raise WMLClientError("Function `plot_learning_curve` is available only for Jupyter notebooks.")
from ibm_watson_machine_learning.utils.autoai.incremental import plot_learning_curve
import matplotlib.pyplot as plt
tuning_details = self.get_run_details(include_metrics=True)
if 'metrics' in tuning_details['entity']['status']:
# average loss score for each epoch
scores = self._get_average_loss_score_for_each_epoch(tuning_details=tuning_details)
# date_time from the first and last iteration on each epoch
if "data" in tuning_details['entity']['status']['metrics'][0]:
date_times = [datetime.datetime.strptime(m_obj["data"]['timestamp'], '%Y-%m-%dT%H:%M:%S.%f')
for m_obj in
self._get_first_and_last_iteration_metrics_for_each_epoch(tuning_details=tuning_details)]
else:
date_times = [datetime.datetime.strptime(m_obj['timestamp'], '%Y-%m-%dT%H:%M:%S.%f%z')
for m_obj in self._get_first_and_last_iteration_metrics_for_each_epoch(tuning_details=tuning_details)]
elapsed_time = []
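            # note: date_times holds [epoch start, epoch end] pairs, so stepping by 2
            # yields the wall-clock duration of each epoch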
for i in range(1, len(date_times), 2):
elapsed_time.append((date_times[i] - date_times[i - 1]).total_seconds())
fig, axes = plt.subplots(1, 3, figsize=(18, 4))
if scores:
plot_learning_curve(fig=fig, axes=axes, scores=scores, fit_times=elapsed_time,
xlabels={'first_xlabel': 'Epochs', 'second_xlabel': 'Epochs'},
titles={'first_plot': 'Loss function'})
else:
raise WMLClientError(Messages.get_message(message_id="fm_prompt_tuning_no_metrics"))
def summary(self, scoring: str = 'loss') -> 'DataFrame':
"""Print PromptTuner models details (prompt-tuned models).
:param scoring: scoring metric which user wants to use to sort pipelines by,
when not provided use loss one
:type scoring: string, optional
:return: computed models and metrics
:rtype: pandas.DataFrame
**Example**
.. code-block:: python
from ibm_watson_machine_learning.experiment import TuneExperiment
experiment = TuneExperiment(credentials, ...)
prompt_tuner = experiment.prompt_tuner(...)
prompt_tuner.run(...)
prompt_tuner.summary()
# Result:
# Enhancements Base model ... loss
# Model Name
# Prompt_tuned_M_1 [prompt_tuning] google/flan-t5-xl ... 0.449197
"""
if self.id is None:
raise WMLClientError(Messages.get_message(message_id="fm_prompt_tuning_not_scheduled"))
from pandas import DataFrame
details = self.get_run_details(include_metrics=True)
metrics = details['entity']['status'].get('metrics', [{}])[0]
is_ml_metrics = 'data' in metrics or 'ml_metrics' in metrics
if not is_ml_metrics:
raise WMLClientError(Messages.get_message(message_id="fm_prompt_tuning_no_metrics"))
columns = ['Model Name', 'Enhancements', 'Base model', 'Auto store', 'Epochs', scoring]
values = []
model_name = 'model_' + self.id
base_model_name = None
epochs = None
enhancements = []
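        # note: for the default 'loss' scoring report the last epoch's average loss;
        # for any other metric take the value from the final metrics entry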
if scoring == 'loss':
model_metrics = [self._get_average_loss_score_for_each_epoch(tuning_details=details)[-1]]
else:
if "data" in details['entity']['status']['metrics'][0]:
model_metrics = [details['entity']['status'].get('metrics', [{}])[-1].get('data', {})[scoring]]
else:
model_metrics = [details['entity']['status'].get('metrics', [{}])[-1].get('ml_metrics', {})[scoring]]
if 'prompt_tuning' in details['entity']:
enhancements = [details['entity']['prompt_tuning']['tuning_type']]
base_model_name = details['entity']['prompt_tuning']['base_model']['model_id']
epochs = details['entity']['prompt_tuning']['num_epochs']
values.append((
[model_name] +
[enhancements] +
[base_model_name] +
[details['entity']['auto_update_model']] +
[epochs] +
model_metrics
))
summary = DataFrame(data=values, columns=columns)
summary.set_index('Model Name', inplace=True)
return summary
def get_model_id(self):
"""Get model id.
**Example**
.. code-block:: python
from ibm_watson_machine_learning.experiment import TuneExperiment
experiment = TuneExperiment(credentials, ...)
prompt_tuner = experiment.prompt_tuner(...)
prompt_tuner.run(...)
prompt_tuner.get_model_id()
"""
run_details = self.get_run_details()
if run_details['entity']['auto_update_model']:
return run_details['entity']['model_id']
else:
raise WMLClientError(Messages.get_message(message_id="fm_prompt_tuning_no_model_id"))
def cancel_run(self, hard_delete=False) -> None:
"""Cancels or deletes a Prompt Tuning run.
:param hard_delete: When True then the completed or cancelled prompt tuning run is deleted,
if False then the current run is canceled. Default: False
:type hard_delete: bool, optional
"""
if self.id is None:
raise WMLClientError(Messages.get_message(message_id="fm_prompt_tuning_not_scheduled"))
self._client.training.cancel(training_uid=self.id, hard_delete=hard_delete)
def get_data_connections(self) -> List['DataConnection']:
"""Create DataConnection objects for further user usage
(eg. to handle data storage connection).
:return: list of DataConnections
:rtype: list['DataConnection']
**Example**
.. code-block:: python
from ibm_watson_machine_learning.experiment import TuneExperiment
experiment = TuneExperiment(credentials, ...)
prompt_tuner = experiment.prompt_tuner(...)
prompt_tuner.run(...)
data_connections = prompt_tuner.get_data_connections()
"""
training_data_references = self.get_run_details()['entity']['training_data_references']
data_connections = [
DataConnection._from_dict(_dict=data_connection) for data_connection in training_data_references]
for data_connection in data_connections:
data_connection.set_client(self._client)
data_connection._run_id = self.id
return data_connections