Working with TuneExperiment and PromptTuner
===========================================

The :ref:`TuneExperiment class` is responsible for creating experiments and scheduling tunings.
All experiment results are stored automatically in your chosen Cloud Object Storage (COS) for SaaS,
or in the cluster's file system for Cloud Pak for Data. The TuneExperiment feature can then fetch
the results and provide them directly to you for further use.

Configure PromptTuner
---------------------

To initialize a TuneExperiment object, you need authentication credentials (for examples, see :doc:`setup`) and the ``project_id`` or the ``space_id``.

.. hint::
    You can copy the project_id from the Project's Manage tab (Project -> Manage -> General -> Details).

.. code-block:: python

    from ibm_watsonx_ai.foundation_models.utils.enums import ModelTypes
    from ibm_watsonx_ai.experiment import TuneExperiment

    experiment = TuneExperiment(
        credentials,
        project_id="7ac03029-8bdd-4d5f-a561-2c4fd1e40705"
    )

    prompt_tuner = experiment.prompt_tuner(
        name="prompt tuning name",
        task_id=experiment.Tasks.CLASSIFICATION,
        base_model=ModelTypes.FLAN_T5_XL,
        accumulate_steps=32,
        batch_size=16,
        learning_rate=0.2,
        max_input_tokens=256,
        max_output_tokens=2,
        num_epochs=6,
        tuning_type=experiment.PromptTuningTypes.PT,
        verbalizer="Extract the satisfaction from the comment. Return simple '1' for satisfied customer or '0' for unsatisfied. Input: {{input}} Output: ",
        auto_update_model=True
    )

Get configuration parameters
----------------------------

To see the current configuration parameters, call the ``get_params()`` method.

.. code-block:: python

    config_parameters = prompt_tuner.get_params()
    print(config_parameters)

    {
        'base_model': {'model_id': 'google/flan-t5-xl'},
        'accumulate_steps': 32,
        'batch_size': 16,
        'learning_rate': 0.2,
        'max_input_tokens': 256,
        'max_output_tokens': 2,
        'num_epochs': 6,
        'task_id': 'classification',
        'tuning_type': 'prompt_tuning',
        'verbalizer': "Extract the satisfaction from the comment. Return simple '1' for satisfied customer or '0' for unsatisfied. Input: {{input}} Output: ",
        'name': 'prompt tuning name',
        'description': 'Prompt tuning with SDK',
        'auto_update_model': True
    }

Run prompt tuning
-----------------

To schedule a tuning experiment, call the ``run()`` method, which triggers a training process.
The ``run()`` method can be synchronous (``background_mode=False``) or asynchronous (``background_mode=True``).
If you don't want to wait for the training to end, invoke the asynchronous version: it immediately returns only the run details.

.. code-block:: python

    from ibm_watsonx_ai.helpers import DataConnection, ContainerLocation, S3Location

    tuning_details = prompt_tuner.run(
        training_data_references=[DataConnection(
            connection_asset_id=connection_id,
            location=S3Location(
                bucket='prompt_tuning_data',
                path='pt_train_data.json'
            )
        )],
        background_mode=False)

    # OR

    tuning_details = prompt_tuner.run(
        training_data_references=[DataConnection(
            data_asset_id='5d99c11a-2060-4ef6-83d5-dc593c6455e2')
        ],
        background_mode=True)

    # OR

    tuning_details = prompt_tuner.run(
        training_data_references=[DataConnection(
            location=ContainerLocation("path_to_file.json"))
        ],
        background_mode=True)

Get run status, get run details
-------------------------------

If you use the ``run()`` method asynchronously, you can monitor the run details and status by using the following two methods:
.. code-block:: python

    status = prompt_tuner.get_run_status()
    print(status)

    'running'

    # OR

    'completed'

    run_details = prompt_tuner.get_run_details()
    print(run_details)

    {
        'metadata': {
            'created_at': '2023-10-12T12:01:40.662Z',
            'description': 'Prompt tuning with SDK',
            'id': 'b3bc33b3-cb3f-49e7-9fb3-88c6c4d4f8d7',
            'modified_at': '2023-10-12T12:09:42.810Z',
            'name': 'prompt tuning name',
            'project_id': 'efa68764-5ec2-410a-bad9-982c502fbf4e',
            'tags': ['prompt_tuning',
                     'wx_prompt_tune.3c06a0db-3cb9-478c-9421-eaf05276a1b7']
        },
        'entity': {
            'auto_update_model': True,
            'description': 'Prompt tuning with SDK',
            'model_id': 'd854752e-76a7-4c6d-b7db-5f84dd11e827',
            'name': 'prompt tuning name',
            'project_id': 'efa68764-5ec2-410a-bad9-982c502fbf4e',
            'prompt_tuning': {
                'accumulate_steps': 32,
                'base_model': {'model_id': 'google/flan-t5-xl'},
                'batch_size': 16,
                'init_method': 'random',
                'learning_rate': 0.2,
                'max_input_tokens': 256,
                'max_output_tokens': 2,
                'num_epochs': 6,
                'num_virtual_tokens': 100,
                'task_id': 'classification',
                'tuning_type': 'prompt_tuning',
                'verbalizer': "Extract the satisfaction from the comment. Return simple '1' for satisfied customer or '0' for unsatisfied. Input: {{input}} Output: "
            },
            'results_reference': {
                'connection': {},
                'location': {
                    'path': 'default_tuning_output',
                    'training': 'default_tuning_output/b3bc33b3-cb3f-49e7-9fb3-88c6c4d4f8d7',
                    'training_status': 'default_tuning_output/b3bc33b3-cb3f-49e7-9fb3-88c6c4d4f8d7/training-status.json',
                    'model_request_path': 'default_tuning_output/b3bc33b3-cb3f-49e7-9fb3-88c6c4d4f8d7/assets/b3bc33b3-cb3f-49e7-9fb3-88c6c4d4f8d7/resources/wml_model/request.json',
                    'assets_path': 'default_tuning_output/b3bc33b3-cb3f-49e7-9fb3-88c6c4d4f8d7/assets'
                },
                'type': 'container'
            },
            'status': {
                'completed_at': '2023-10-12T12:09:42.769Z',
                'state': 'completed'
            },
            'tags': ['prompt_tuning'],
            'training_data_references': [{
                'connection': {},
                'location': {
                    'href': '/v2/assets/90258b10-5590-4d4c-be75-5eeeccf09076',
                    'id': '90258b10-5590-4d4c-be75-5eeeccf09076'
                },
                'type': 'data_asset'
            }]
        }
    }

Get data connections
--------------------

The ``data_connections`` list contains all the training connections that you referenced while calling the ``run()`` method.

.. code-block:: python

    data_connections = prompt_tuner.get_data_connections()

    # Get data in binary format
    binary_data = data_connections[0].read(binary=True)

Summary
-------

You can see details of the models in the form of a summary table. The output type is a ``pandas.DataFrame`` with model names, enhancements, the base model, an auto update option, the number of epochs used, and the last loss function value.

.. code-block:: python

    results = prompt_tuner.summary()
    print(results)

    #                      Enhancements         Base model  ...      loss
    # Model Name
    # Prompt_tuned_M_1  [prompt_tuning]  google/flan-t5-xl  ...  0.449197

Plot learning curves
--------------------

.. note::
    Available only for Jupyter notebooks.

To see graphically how the tuning was performed, you can view learning curve graphs.

.. code-block:: python

    prompt_tuner.plot_learning_curve()

.. image:: _static/learning_curves.png
   :width: 600
   :align: center

Get the model identifier
------------------------
.. note::
    The model identifier is available only if the tuning was scheduled first and the ``auto_update_model`` parameter was set to ``True`` (the default value).

To get the ``model_id``, call the ``get_model_id()`` method.

.. code-block:: python

    model_id = prompt_tuner.get_model_id()
    print(model_id)

    'd854752e-76a7-4c6d-b7db-5f84dd11e827'

The ``model_id`` obtained in this way can be used to create deployments and then create a ModelInference instance. For more information, see the next section: :ref:`Tuned Model Inference`.
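As a rough preview of that workflow, the sketch below deploys the tuned model and queries it through ``ModelInference``. This is an illustrative outline, not the authoritative recipe: the space ID, deployment name, and prompt text are placeholders, and depending on your setup you may first need to promote the model asset to a deployment space. Follow the :ref:`Tuned Model Inference` section for the exact steps.

.. code-block:: python

    # Hypothetical sketch -- space ID, deployment name, and prompt are placeholders.
    from ibm_watsonx_ai import APIClient
    from ibm_watsonx_ai.foundation_models import ModelInference

    client = APIClient(credentials)
    client.set.default_space("PASTE_YOUR_SPACE_ID_HERE")

    # Create an online deployment for the tuned model.
    meta_props = {
        client.deployments.ConfigurationMetaNames.NAME: "prompt tuning deployment",
        client.deployments.ConfigurationMetaNames.ONLINE: {},
    }
    deployment_details = client.deployments.create(model_id, meta_props)
    deployment_id = client.deployments.get_id(deployment_details)

    # Query the deployed tuned model.
    tuned_model = ModelInference(
        deployment_id=deployment_id,
        api_client=client,
    )
    response = tuned_model.generate_text(
        prompt="Extract the satisfaction from the comment. Return simple '1' for satisfied customer or '0' for unsatisfied. Input: The service was slow and rude. Output: "
    )
    print(response)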