Fine-tune and deploy a custom model
Use custom training data to tune a model for text generation.
Note:
This example is written so that an end user can quickly try fine-tuning. To obtain better performance, you would need to experiment with the number of training observations and with the tuning hyperparameters.
import sys
import time
from pathlib import Path

from dotenv import load_dotenv

from genai.client import Client
from genai.credentials import Credentials
from genai.schema import (
    DecodingMethod,
    DeploymentStatus,
    FilePurpose,
    TextGenerationParameters,
    TuneParameters,
    TuneStatus,
)
# Load variables from a .env file into the process environment
# (they are read later by Credentials.from_env()).
load_dotenv()

# Number of dataset rows placed in each split.
num_training_samples, num_validation_samples = 50, 20

# The JSONL splits are stored in a ".data" directory next to this script.
data_root = Path(__file__).parent.resolve().joinpath(".data")
training_file = data_root.joinpath("fpb_train.jsonl")
validation_file = data_root.joinpath("fpb_validation.jsonl")
def heading(text: str) -> str:
    """Return *text* padded with single spaces and centered in an 80-column '=' banner.

    The banner is surrounded by newlines so it stands apart when printed.
    """
    banner = f" {text} ".center(80, "=")
    return f"\n{banner}\n"
def create_dataset():
    """Download the TOFU dataset and write the training/validation JSONL files.

    The dataset's ``question``/``answer`` columns are renamed to ``input``/``output``
    (``output`` is cast to ``str``). The first ``num_training_samples`` rows are
    written to ``training_file`` and the last ``num_validation_samples`` rows to
    ``validation_file``, both as JSON Lines records.

    Nothing is downloaded when both files already exist.

    Raises:
        ImportError: if ``pandas`` or ``datasets`` is not installed.
    """
    data_root.mkdir(parents=True, exist_ok=True)
    # Both splits are uploaded later, so a partially written dataset (e.g. from an
    # interrupted earlier run) must be regenerated rather than treated as prepared.
    if training_file.exists() and validation_file.exists():
        print("Dataset is already prepared")
        return

    # Optional dependencies, only needed when the dataset has to be downloaded.
    try:
        import pandas as pd
        from datasets import load_dataset
    except ImportError:
        print("Please install datasets and pandas for downloading the dataset.")
        raise

    data = load_dataset("locuslab/TOFU")
    df = pd.DataFrame(data["train"]).rename(columns={"question": "input", "answer": "output"})
    df["output"] = df["output"].astype(str)

    train_df = df.iloc[:num_training_samples]
    # Slice from an explicit, non-negative start index: df.iloc[-0:] would select
    # every row when num_validation_samples is 0.
    validation_df = df.iloc[max(len(df) - num_validation_samples, 0):]

    # force_ascii=True escapes non-ASCII characters, so the written text is ASCII.
    train_jsonl = train_df.to_json(orient="records", lines=True, force_ascii=True)
    validation_jsonl = validation_df.to_json(orient="records", lines=True, force_ascii=True)

    with open(training_file, "w", encoding="utf-8") as fout:
        fout.write(train_jsonl)
    with open(validation_file, "w", encoding="utf-8") as fout:
        fout.write(validation_jsonl)
def upload_files(client: Client, update=True):
    """Make sure the training and validation files are uploaded for tuning.

    Already-uploaded files are looked up by name on the server. A file that is
    already present is deleted and uploaded again when ``update`` is true, and its
    existing id is reused otherwise. A file that is not present is uploaded.

    Returns:
        Tuple ``(training_file_id, validation_file_id)``.
    """
    existing = [
        *client.file.list(search=training_file.name).results,
        *client.file.list(search=validation_file.name).results,
    ]

    # Map file name -> server id; a later entry with the same name wins.
    name_to_id = {}
    for info in existing:
        name_to_id[info.file_name] = info.id

    for path in (training_file, validation_file):
        name = path.name
        if name not in name_to_id:
            print(f"File not present: Uploading {name}")
        elif update:
            print(f"File already present: Overwriting {name}")
            client.file.delete(name_to_id[name])
        else:
            # Already uploaded and updating is disabled: keep the existing id.
            continue
        uploaded = client.file.create(file_path=path, purpose=FilePurpose.TUNE)
        name_to_id[name] = uploaded.result.id

    return name_to_id[training_file.name], name_to_id[validation_file.name]
# Make sure you have a .env file under the genai root containing:
#   GENAI_KEY=<your-genai-key>
#   GENAI_API=<genai-api-endpoint>  (optional; defaults to "https://bam-api.res.ibm.com")
# load_dotenv() near the top of this script loads that file into the environment,
# from which Credentials.from_env() reads the credentials.
client = Client(credentials=Credentials.from_env())
print(heading("Creating dataset"))
# Prepare the local JSONL dataset files (skipped when already prepared).
create_dataset()
print(heading("Uploading files"))
# update=True: files already present on the server are deleted and uploaded again.
training_file_id, validation_file_id = upload_files(client, update=True)
# Epoch count and training batch size appear both as schema fields and as
# schema-less "advanced" fields below. Presumably they describe the same quantity
# (TODO confirm against the API), so both are derived from one value to keep them
# from drifting apart when edited.
num_epochs = 4
batch_size = 4

hyperparams = TuneParameters(
    num_epochs=num_epochs,
    verbalizer="### Input: {{input}} ### Response: {{output}}",
    batch_size=batch_size,
    learning_rate=0.4,
    # Advanced parameters are not defined in the schema
    # but can be passed to the API
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=batch_size,
    num_train_epochs=num_epochs,
)
print(heading("Tuning model"))
tune_result = client.tune.create(
    model_id="meta-llama/llama-3-8b-instruct",
    name="generation-fine-tune-example",
    tuning_type="fine_tuning",
    task_id="generation",
    parameters=hyperparams,
    training_file_ids=[training_file_id],
    # validation_file_ids=[validation_file_id],  # TODO: Broken at the moment - this causes tune to fail
).result

# Poll every 10 seconds until the tune reaches a terminal state. The status printed
# is the most recently observed one, and no extra sleep happens once a terminal
# state has been seen.
while tune_result.status not in [TuneStatus.FAILED, TuneStatus.HALTED, TuneStatus.COMPLETED]:
    print(f"Waiting for tune to finish, current status: {tune_result.status}")
    time.sleep(10)
    tune_result = client.tune.retrieve(tune_result.id).result

if tune_result.status in [TuneStatus.FAILED, TuneStatus.HALTED]:
    print("Model tuning failed or halted")
    # sys.exit rather than the site-provided exit(), which is meant for interactive use.
    sys.exit(1)
print(heading("Deploying fine-tuned model"))
deployment = client.deployment.create(tune_id=tune_result.id).result

# Poll every 10 seconds until the deployment reaches a terminal state, reporting the
# most recently observed status and stopping as soon as a terminal state is seen.
while deployment.status not in [DeploymentStatus.READY, DeploymentStatus.FAILED, DeploymentStatus.EXPIRED]:
    print(f"Waiting for deployment to finish, current status: {deployment.status}")
    time.sleep(10)
    deployment = client.deployment.retrieve(id=deployment.id).result

if deployment.status in [DeploymentStatus.FAILED, DeploymentStatus.EXPIRED]:
    print(f"Model deployment failed or expired, status: {deployment.status}")
    # sys.exit rather than the site-provided exit(), which is meant for interactive use.
    sys.exit(1)
print(heading("Generate text with fine-tuned model"))
prompt = "What are some books you would reccomend to read?"
print("Prompt: ", prompt)
gen_params = TextGenerationParameters(decoding_method=DecodingMethod.SAMPLE)
try:
    # Pass the sampling parameters explicitly; otherwise gen_params would be unused.
    gen_response = next(
        client.text.generation.create(model_id=tune_result.id, inputs=[prompt], parameters=gen_params)
    )
    print("Answer: ", gen_response.results[0].generated_text)
finally:
    # Release the deployment and tuned model even if generation raises.
    print(heading("Deleting deployment and tuned model"))
    client.deployment.delete(id=deployment.id)
    client.tune.delete(id=tune_result.id)