Tune a custom model (Prompt Tuning)

Use custom training data to tune a model for text classification.

Note:

This example is written so that an end user can quickly try prompt tuning. To obtain better performance, experiment with the number of observations and the tuning hyperparameters.

See tune.py on GitHub.
import time
from pathlib import Path
from pprint import pprint

from dotenv import load_dotenv

from genai.client import Client
from genai.credentials import Credentials
from genai.schema import (
    DecodingMethod,
    FilePurpose,
    TextGenerationParameters,
    TuneAssetType,
    TuneParameters,
    TuneStatus,
)

# Load variables defined in a local .env file into the process environment.
load_dotenv()

# Number of rows sampled for the training and validation splits.
num_training_samples = 100
num_validation_samples = 20

# The generated JSONL files live in a ".data" directory next to this script.
data_root = Path(__file__).parent.resolve().joinpath(".data")
training_file = data_root.joinpath("fpb_train.jsonl")
validation_file = data_root.joinpath("fpb_validation.jsonl")


def heading(text: str) -> str:
    """Return ``text`` as an 80-column ``=`` banner wrapped in newlines."""
    banner = f" {text} ".center(80, "=")
    return f"\n{banner}\n"


def create_dataset():
    """Write the training and validation JSONL files unless both already exist.

    Loads the ``sentences_allagree`` configuration of ``financial_phrasebank``,
    draws a random sample of ``num_training_samples + num_validation_samples``
    rows, renames the columns to ``input`` / ``output`` and stores the first
    ``num_training_samples`` rows in ``training_file`` and the remainder in
    ``validation_file``.

    Raises:
        ImportError: If ``pandas`` or ``datasets`` cannot be imported; an
            installation hint is printed before the error is re-raised.
    """
    data_root.mkdir(parents=True, exist_ok=True)
    if all(path.exists() for path in (training_file, validation_file)):
        print("Dataset is already prepared")
        return

    # Imported lazily: these packages are only needed when the files are missing.
    try:
        import pandas as pd
        from datasets import load_dataset
    except ImportError:
        print("Please install datasets and pandas for downloading the dataset.")
        raise

    total_samples = num_training_samples + num_validation_samples
    phrasebank = load_dataset("financial_phrasebank", "sentences_allagree")
    sampled = pd.DataFrame(phrasebank["train"]).sample(n=total_samples)
    sampled = sampled.rename(columns={"sentence": "input", "label": "output"})
    # Labels are written out as strings.
    sampled["output"] = sampled["output"].astype(str)

    # Serialise both splits before touching the disk.
    json_options = {"orient": "records", "lines": True, "force_ascii": True}
    payloads = [
        (training_file, sampled.iloc[:num_training_samples].to_json(**json_options)),
        (validation_file, sampled.iloc[num_training_samples:].to_json(**json_options)),
    ]
    for destination, payload in payloads:
        with open(destination, "w") as fout:
            fout.write(payload)


def upload_files(client: Client, update=True):
    """Ensure both dataset files exist on the service and return their ids.

    Args:
        client: API client used to list, delete and create files.
        update: When true, a remote file with the same name is deleted and
            uploaded again; when false, an existing remote file is reused.

    Returns:
        A ``(training_file_id, validation_file_id)`` tuple.
    """
    # Map remote file name -> id for everything the two name searches return.
    ids_by_name = {}
    for local_path in (training_file, validation_file):
        for remote in client.file.list(search=local_path.name).results:
            ids_by_name[remote.file_name] = remote.id

    for local_path in (training_file, validation_file):
        name = local_path.name
        if name not in ids_by_name:
            print(f"File not present: Uploading {name}")
        elif update:
            print(f"File already present: Overwriting {name}")
            client.file.delete(ids_by_name[name])
        else:
            # Already uploaded and no refresh requested: reuse the remote copy.
            continue
        response = client.file.create(file_path=local_path, purpose=FilePurpose.TUNE)
        ids_by_name[name] = response.result.id

    return ids_by_name[training_file.name], ids_by_name[validation_file.name]


# Build the API client from credentials found in the environment.
client = Client(credentials=Credentials.from_env())

print(heading("Creating dataset"))
create_dataset()

print(heading("Uploading files"))
training_file_id, validation_file_id = upload_files(client, update=True)

# The verbalizer is the template applied to every record; {{input}} is the record's "input" field.
hyperparams = TuneParameters(num_epochs=2, verbalizer='classify { "0", "1", "2" } Input: {{input}} Output:')

print(heading("Tuning model"))
tune_result = client.tune.create(
    model_id="google/flan-t5-xl",
    name="classification-mpt-tune-api",
    tuning_type="prompt_tuning",
    task_id="classification",  # Another supported task is "summarization"
    parameters=hyperparams,
    training_file_ids=[training_file_id],
    # validation_file_ids=[validation_file_id], # TODO: Broken at the moment - this causes tune to fail
).result

# Poll until the tune reaches a terminal state. The status printed is the one
# just observed, and the loop exits as soon as a terminal status is fetched
# (no trailing sleep after the final retrieve).
while tune_result.status not in [TuneStatus.FAILED, TuneStatus.HALTED, TuneStatus.COMPLETED]:
    print(f"Waiting for tune to finish, current status: {tune_result.status}")
    time.sleep(10)
    tune_result = client.tune.retrieve(tune_result.id).result

if tune_result.status in [TuneStatus.FAILED, TuneStatus.HALTED]:
    print("Model tuning failed or halted")
    # `exit()` is an interactive helper injected by `site`; raise SystemExit directly.
    raise SystemExit(1)

print(heading("Classify with a tuned model"))
prompt = "Return on investment was 5.0 % , compared to a negative 4.1 % in 2009 ."
print("Prompt: ", prompt)
# Exactly one new token is requested for the generated answer.
gen_params = TextGenerationParameters(decoding_method=DecodingMethod.SAMPLE, max_new_tokens=1, min_new_tokens=1)
# Pass the generation parameters explicitly; previously they were built but never
# sent, so the service defaults were used instead.
gen_response = next(
    client.text.generation.create(model_id=tune_result.id, inputs=[prompt], parameters=gen_params)
)
print("Answer: ", gen_response.results[0].generated_text)

print(heading("Get list of tuned models"))
interesting_metadata_fields = ["name", "id", "model_id", "created_at", "status"]
tune_list = client.tune.list(limit=5, offset=0)
for tune in tune_list.results:
    # `model_dump(include=...)` is documented to take a set or dict, so convert the list.
    pprint(tune.model_dump(include=set(interesting_metadata_fields)))


print(heading("Retrieve metadata for a single tune"))
tune_detail = client.tune.retrieve(id=tune_result.id).result
pprint(tune_detail.model_dump(include={*interesting_metadata_fields, "parameters"}))

print(heading("Downloading tune model assets"))
# The asset is returned as bytes; decode it for printing.
logs = client.tune.read(id=tune_result.id, type=TuneAssetType.LOGS).decode("utf-8")
print(logs)

print(heading("Deleting a tuned model"))
client.tune.delete(tune_result.id)
print("OK")