Fine-tuning Granite with LoRA¶
LoRA is an efficient method of fine-tuning transformers and LLMs that reduces the number of trainable parameters, which speeds up training and lowers its resource usage while retaining comparable output quality. Instead of updating all model weights in a neural network during a full fine-tuning, LoRA freezes the original pre-trained weights and adds small, trainable low-rank matrices that approximate the changes needed for the new task.
One additional benefit of LoRA is that because the original pre-trained weights are kept frozen, the generated LoRA weight matrices are lightweight, portable, and easy to store.
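To make the idea concrete, here is a minimal sketch of a LoRA-style layer in PyTorch. This is purely illustrative (it is not how LLaMA Factory implements LoRA internally), and the rank and scaling values are arbitrary assumptions:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """A frozen pre-trained linear layer plus a small trainable low-rank update."""
    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # freeze the original weights
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = alpha / r

    def forward(self, x):
        # original output plus the scaled low-rank correction B(Ax)
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

layer = LoRALinear(nn.Linear(1024, 1024))
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable}")  # only the two small matrices are trainable

Only the two low-rank matrices receive gradient updates, which is why the resulting adapter is so small compared with the base model.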
In this step-by-step tutorial, you'll use LLaMA Factory, a platform for training and fine-tuning LLMs that lets users tune LLMs on domain-specific datasets, evaluate model performance for specific tasks, and serve models. It has both a user-friendly web UI and a CLI, and it supports over 100 LLMs. The platform accepts datasets in Alpaca and ShareGPT formats. LLaMA Factory isn't the only way to fine-tune LLMs; the PEFT library for parameter-efficient fine-tuning is another option for updating large models, and it supports quantized LoRA (QLoRA) to compress the fine-tuned model even further. In this tutorial, you'll use a non-quantized version of Granite 3.3.
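If you'd rather work with PEFT directly, a configuration along the following lines is a reasonable starting point. This is a hedged sketch rather than part of this tutorial's workflow; the rank, alpha, and dropout values are illustrative assumptions, and target_modules="all-linear" requires a recent PEFT release:

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model_id = "ibm-granite/granite-3.3-2b-instruct"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Illustrative hyperparameters; tune r, lora_alpha, and lora_dropout for your task.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules="all-linear",  # attach adapters to every linear layer
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()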
Although LLaMA Factory can run without extensive compute resources, it does require a GPU and a substantial amount of memory. In this tutorial you'll use LLaMA Factory on watsonx, which provides GPU resources and storage for the generated adapter.
Watson Studio Config¶
a. Log in to watsonx.ai using your IBM Cloud account.
b. Create a watsonx.ai project.
>Take note of your project ID in project > Manage > General > Project ID.
You'll need this ID for this tutorial.
c. Create a watsonx.ai Runtime service instance. For this tutorial you'll need to create a paid instance to access a GPU.
d. Generate a watsonx API Key.
e. Associate the watsonx.ai Runtime service to the project you created in watsonx.ai.
Cloud Object Storage¶
a. To create a Cloud Object Storage (COS) instance for your notebook, go to https://cloud.ibm.com/ and select "Create Instance".
b. This takes you to a Create dialog where you can select a pricing plan. For this tutorial, the Standard plan is adequate.
c. Then give your COS instance a name.
d. Once you've created your instance, go back to the project, select "New Asset", then select "Connect to a data source".
e. Select "Cloud Object Storage".
f. In the next dialog, select by name the instance that you created in steps a-d.
g. Select "Create".
Create a Jupyter Notebook¶
To create a notebook:
a. Select the Assets tab in your project environment.
b. Click New asset.
c. Select the Working with models option in the left panel.
d. Click Working with data and models using Python and R notebooks.
e. Enter a name for your notebook in the Name field. Choose Runtime 23.1 on Python (4 vCPU 16 GB RAM) to define the configuration.
f. Select Create.
Install Dependencies¶
Next, you'll install dependencies onto the runtime: first LLaMA Factory, to generate the low-rank adapters, and then pandas, to format the dataset in Alpaca format.
!pip install -q llamafactory 2>/dev/null
# pandas needed to format the dataset
!pip install -q --upgrade pandas 2>/dev/null
Check GPU environment¶
Next, you'll ensure that your watsonx environment provides a Torch-compatible GPU, which is required to use LLaMA Factory.
import torch

try:
    assert torch.cuda.is_available() is True
except AssertionError:
    print("No GPU found, please set up a GPU before using LLaMA Factory.")
If the above code snippet doesn't print "No GPU found" then you're good to go.
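Optionally, you can also print a few details about the GPU you were allocated to confirm it has enough memory for training:

# report the name and memory of the allocated GPU, if one is present
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {props.name}, VRAM: {props.total_memory / 1024**3:.1f} GiB")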
Next you'll import the libraries needed to manipulate data and to create the LLaMA Factory configuration file used for training.
# Import libraries
import pandas as pd
import json
import yaml
Download and Process MedReason Dataset¶
In this tutorial you'll use part of the MedReason dataset. MedReason is an in-depth medical reasoning dataset designed to enable explainable, real-world medical problem-solving in large language models (LLMs). While MedReason focuses on a model's reasoning and on validating the chains of thought it uses, the dataset is also useful here because it is too recent to have been included in the training data for Granite 3.3.
Granite 3.3 has been designed to learn through fine-tuning, which you'll run with LLaMA Factory. Granite models can be efficiently fine-tuned even with limited computing resources.
You'll load a selection of the MedReason dataset from GitHub:
# load the MedReason evaluation file used in this tutorial
train_dataset = pd.read_json("https://raw.githubusercontent.com/UCSC-VLAA/MedReason/refs/heads/main/eval_data/medbullets_op4.jsonl", lines=True)
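It's worth a quick look at the loaded DataFrame before reformatting it; the question, options, and answer columns are the ones you'll use below:

# inspect the shape and the columns used for formatting
print(train_dataset.shape)
print(train_dataset.columns.tolist())
print(train_dataset[["question", "options", "answer"]].head(2))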
LLaMA Factory requires the dataset to be preformatted in either the Alpaca or the ShareGPT format. You'll therefore reformat the question and answer fields of the original medical dataset into instruction, input, and output fields, as per the Alpaca format.
Alpaca is a JSON format to represent an instruction, user input, and system output like so:
{
  "instruction": "user instruction (required)",
  "input": "user input (optional)",
  "output": "model response (required)",
  "system": "system prompt (optional)"
}
Since MedReason isn't formatted in Alpaca, you'll create an Alpaca dataset in the next cell:
!mkdir -p data
# Format the MedReason dataset to Alpaca format
formatted_data = [
    {
        "instruction": row["question"] + str(row["options"]),
        "input": "",
        "output": row["answer"]
    }
    for _, row in train_dataset.iterrows()
]

# output formatted MedReason dataset
with open("data/med.json", "w", encoding="utf-8") as f:
    json.dump(formatted_data, f, indent=2, ensure_ascii=False)
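A quick sanity check that the records were written in the expected Alpaca shape:

# reload the file and preview the first formatted record
with open("data/med.json", "r", encoding="utf-8") as f:
    sample = json.load(f)
print(len(sample), "records")
print(json.dumps(sample[0], indent=2, ensure_ascii=False)[:500])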
LLaMA Factory uses a specific file to determine how to load datasets for training. This file must exist at the path data/dataset_info.json, so you must create a dataset_info.json file that points to the newly formatted medical dataset in order for the LLaMA Factory CLI to access it. For details on the dataset_info.json file, see the documentation. The LLaMA Factory repository ships with a number of ready-to-use datasets; however, since you're using a custom dataset, you must add it to this JSON file.
# "med" will be the identifier for the dataset
# which points to the local file that contains the dataset
dataset_info = {
"med": {
"file_name": "med.json",
}
}
# Create dataset_info.json with legal dataset so can reference with llama factory
with open("data/dataset_info.json", "w", encoding="utf-8") as f:
json.dump(dataset_info, f, indent=2, ensure_ascii=False)
Now that the Alpaca formatted JSON object has been saved into the environment, you're ready to start training.
LoRA fine-tune the model with LLaMA Factory via the command line¶
The next step is to set up the training configuration and then write it to a YAML file that LLaMA Factory uses to run training.
Now you'll run Supervised Fine-Tuning (SFT) on the subset of the MedReason dataset. LLaMA Factory supports several different types of training. Some of the most commonly used are:
- Pre-Training: the model undergoes initial training on an extensive dataset to learn fundamental language and concepts.
- Supervised Fine-Tuning (SFT): the model receives additional training with annotated data to enhance its precision for a particular function or a specific topic.
- Reward Modeling: the model learns how to achieve a specific incentive or reward that will inform its output.
- Proximal Policy Optimization (PPO) Training: a reinforcement learning (RL) technique in which the model is further honed through policy gradient techniques to boost its effectiveness in a specific setting.
There are many settings used in configuring LoRA but a few of the most commonly used are:
Learning Rate (LR): the learning rate determines how significantly model parameters are updated during each iteration of training. A higher LR can speed up convergence by allowing larger updates but risks overshooting the optimal solution or oscillating around it. A lower LR leads to slower but more stable convergence, reducing the risk of instability near the optimal solution.
loraplus_lr_ratio: this sets the ratio of the LoRA+ learning rates (the rate used for the adapter's B matrices relative to its A matrices). Generally it should be > 1, but the optimal choice of loraplus_lr_ratio is model and task dependent. As a rule of thumb, loraplus_lr_ratio should be larger when the task is more difficult and the model needs to update its features to learn well. In this case, it helps to make the learning rate slightly smaller (for example, by a factor of 2) than typical vanilla LoRA learning rates.
Effective Batch Size: correctly configuring your batch size is critical for balancing training stability against the VRAM limitations of the GPU you're using. The effective batch size is the product per_device_train_batch_size * gradient_accumulation_steps. A larger effective batch size generally leads to smoother, more stable training, but may require more VRAM than your GPU contains; a smaller effective batch size may introduce more variance. The short check below shows the value used in this tutorial.
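For the configuration written in the next cell, the effective batch size works out to 8:

# these literals match the training configuration below
per_device_train_batch_size = 4
gradient_accumulation_steps = 2
print("Effective batch size:", per_device_train_batch_size * gradient_accumulation_steps)  # 4 * 2 = 8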
# setup training configurations
loraconfig = dict(
    stage="sft",  # do supervised fine-tuning
    do_train=True,  # we're actually training
    model_name_or_path="ibm-granite/granite-3.3-2b-instruct",  # use IBM Granite 3.3 2b instruct model
    dataset="med",  # use medical dataset we created
    template="granite3",  # use granite3 prompt template
    finetuning_type="lora",  # use LoRA adapters to save memory
    lora_target="all",  # attach LoRA adapters to all linear layers
    loraplus_lr_ratio=16.0,  # use LoRA+ algorithm with lambda=16.0
    output_dir="granite3_lora",  # the path to save LoRA adapters
    per_device_train_batch_size=4,  # the batch size
    gradient_accumulation_steps=2,  # the gradient accumulation steps
    learning_rate=1e-4,  # the learning rate
    num_train_epochs=3.0,  # the epochs of training
    max_samples=500,  # use 500 examples in each dataset
    fp16=True,  # use float16 mixed precision training
    report_to="none",  # disable wandb logging
)

# create training config file to run with llama factory
with open("train_granite3_lora_med.yaml", "w", encoding="utf-8") as file:
    yaml.dump(loraconfig, file, indent=2)
The next cell will train the model and may take up to 10 minutes to run:
!llamafactory-cli train train_granite3_lora_med.yaml;
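Once training finishes, LLaMA Factory writes a trainer_log.jsonl file into the output directory alongside the adapter. You can load it with pandas to confirm the training loss decreased; the exact logged fields can vary between LLaMA Factory versions, so inspect the columns first:

# peek at the training log produced in the output_dir
log = pd.read_json("granite3_lora/trainer_log.jsonl", lines=True)
print(log.columns.tolist())  # see which fields were logged
print(log.tail())            # the last few logged steps, including the loss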
Next you'll create two methods to upload and download data from IBM Cloud Object Storage:
from ibm_botocore.client import Config
import ibm_boto3
def upload_file_cos(credentials, local_file_name, key):
    cos = ibm_boto3.client(service_name='s3',
                           ibm_api_key_id=credentials['IBM_API_KEY_ID'],
                           ibm_service_instance_id=credentials['IAM_SERVICE_ID'],
                           ibm_auth_endpoint=credentials['IBM_AUTH_ENDPOINT'],
                           config=Config(signature_version='oauth'),
                           endpoint_url=credentials['ENDPOINT'])
    try:
        res = cos.upload_file(Filename=local_file_name, Bucket=credentials['BUCKET'], Key=key)
    except Exception as e:
        print(Exception, e)
    else:
        print('File Uploaded')

def download_file_cos(credentials, local_file_name, key):
    cos = ibm_boto3.client(service_name='s3',
                           ibm_api_key_id=credentials['IBM_API_KEY_ID'],
                           ibm_service_instance_id=credentials['IAM_SERVICE_ID'],
                           ibm_auth_endpoint=credentials['IBM_AUTH_ENDPOINT'],
                           config=Config(signature_version='oauth'),
                           endpoint_url=credentials['ENDPOINT'])
    try:
        res = cos.download_file(Bucket=credentials['BUCKET'], Key=key, Filename=local_file_name)
    except Exception as e:
        print(Exception, e)
    else:
        print('File Downloaded')
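If you want to confirm what's already in the bucket, you can add a third helper in the same style. This is an optional sketch that reuses the same credentials dictionary and the standard list_objects_v2 call:

def list_bucket_cos(credentials):
    cos = ibm_boto3.client(service_name='s3',
                           ibm_api_key_id=credentials['IBM_API_KEY_ID'],
                           ibm_service_instance_id=credentials['IAM_SERVICE_ID'],
                           ibm_auth_endpoint=credentials['IBM_AUTH_ENDPOINT'],
                           config=Config(signature_version='oauth'),
                           endpoint_url=credentials['ENDPOINT'])
    # list the objects currently stored in the bucket
    response = cos.list_objects_v2(Bucket=credentials['BUCKET'])
    for obj in response.get('Contents', []):
        print(obj['Key'], obj['Size'])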
The next cell will contain the credentials for Cloud Object Storage (COS).
In your notebook, click the Code Snippets tab on the right. This brings up a menu with several options for generated code snippets. Select "Read Data":
This brings up a menu to select a data file. If you haven't uploaded anything to your COS instance yet, you'll need to upload something in order to generate credentials; that can be a classic dataset such as wine.csv.
After clicking "Select", you can generate the credentials snippet under the "Load as" option:
This will generate a cell like the one below with the correct assets:
# @hidden_cell
# The following code contains metadata for a file in your project storage.
# You might want to remove secret properties before you share your notebook.
storage_metadata = {
    'IAM_SERVICE_ID': '<>',
    'IBM_API_KEY_ID': '<>',
    'ENDPOINT': '<>',
    'IBM_AUTH_ENDPOINT': '<>',
    'BUCKET': '<>'
}
Now zip the folder containing the adapter and the information about the adapter itself:
!zip -r "granite3_lora.zip" "granite3_lora"
adding: granite3_lora/ (stored 0%)
adding: granite3_lora/adapter_config.json (deflated 56%)
adding: granite3_lora/trainer_log.jsonl (deflated 29%)
adding: granite3_lora/training_args.bin (deflated 51%)
adding: granite3_lora/added_tokens.json (deflated 58%)
adding: granite3_lora/special_tokens_map.json (deflated 78%)
adding: granite3_lora/vocab.json (deflated 57%)
adding: granite3_lora/trainer_state.json (deflated 59%)
adding: granite3_lora/tokenizer_config.json (deflated 81%)
adding: granite3_lora/merges.txt (deflated 51%)
adding: granite3_lora/adapter_model.safetensors (deflated 8%)
adding: granite3_lora/README.md (deflated 47%)
adding: granite3_lora/train_results.json (deflated 39%)
adding: granite3_lora/tokenizer.json (deflated 81%)
adding: granite3_lora/checkpoint-114/ (stored 0%)
adding: granite3_lora/checkpoint-114/adapter_config.json (deflated 56%)
adding: granite3_lora/checkpoint-114/training_args.bin (deflated 51%)
adding: granite3_lora/checkpoint-114/added_tokens.json (deflated 58%)
adding: granite3_lora/checkpoint-114/rng_state.pth (deflated 25%)
adding: granite3_lora/checkpoint-114/special_tokens_map.json (deflated 78%)
adding: granite3_lora/checkpoint-114/scaler.pt (deflated 60%)
adding: granite3_lora/checkpoint-114/vocab.json (deflated 57%)
adding: granite3_lora/checkpoint-114/trainer_state.json (deflated 56%)
adding: granite3_lora/checkpoint-114/tokenizer_config.json (deflated 81%)
adding: granite3_lora/checkpoint-114/optimizer.pt (deflated 8%)
adding: granite3_lora/checkpoint-114/merges.txt (deflated 51%)
adding: granite3_lora/checkpoint-114/adapter_model.safetensors (deflated 8%)
adding: granite3_lora/checkpoint-114/README.md (deflated 66%)
adding: granite3_lora/checkpoint-114/tokenizer.json (deflated 81%)
adding: granite3_lora/checkpoint-114/scheduler.pt (deflated 53%)
adding: granite3_lora/all_results.json (deflated 39%)
Check that you've created the zip correctly:
!ls
Now it's time to upload the results of the LoRA SFT so that you can use them elsewhere without needing to re-run LLaMA Factory training.
upload_file_cos(storage_metadata, "granite3_lora.zip", "granite3_lora.zip")
File Uploaded
Infer with the Granite Base Model¶
Now it's time to run inference. The inference will be backed by Hugging Face generation, which provides a model.generate() method for text generation using PyTorch.
You'll first ask the base model a multiple-choice medical question pulled from the MedReason dataset. Because the base model is a general-purpose model trained on large, diverse datasets, it's reasonable to expect that it may not answer this specialized question correctly. The response is streamed back token by token and printed in full once the stream completes.
# setup inference configurations
args = dict(
    model_name_or_path="ibm-granite/granite-3.3-2b-instruct",  # use IBM Granite 3.3 2b instruct model
    template="granite3",  # set to the same one used in training, template for constructing prompts
    infer_backend="huggingface"  # choices: [huggingface, vllm]
)

# create inference config file to run with llama factory
with open("inference_config.yaml", "w", encoding="utf-8") as file:
    yaml.dump(args, file, indent=2)
from llamafactory.chat import ChatModel
chat_model = ChatModel(args)
messages = []
# run inference chatbot
question = '''
A 1-year-old girl is brought to a neurologist due to increasing seizure frequency over the past 2 months.
She recently underwent a neurology evaluation which revealed hypsarrhythmia on electroencephalography (EEG) with a mix of slow waves, multifocal spikes, and asynchrony.
Her parents have noticed the patient occasionally stiffens and spreads her arms at home. She was born at 38-weeks gestational age without complications.
She has no other medical problems. Her medications consist of lamotrigine and valproic acid. Her temperature is 98.3\u00b0F (36.8\u00b0C), blood pressure is 90/75 mmHg, pulse is 94/min, and respirations are 22/min.
Physical exam reveals innumerable hypopigmented macules on the skin and an irregularly shaped, thickened, and elevated plaque on the lower back.
Which of the following is most strongly associated with this patient's condition?"
"A": "Cardiac rhabdomyoma", "B": "Glaucoma", "C": "Optic glioma", "D": "Polyostotic fibrous dysplasia"
'''
messages.append({"role": "user", "content": question})
response = ""
for new_text in chat_model.stream_chat(messages):
response += new_text
print(response)
messages.append({"role": "assistant", "content": response})
Here's the sample output from the base Granite 3.3 model:
- User:
A 1-year-old girl is brought to a neurologist due to increasing seizure frequency over the past 2 months.
She recently underwent a neurology evaluation which revealed hypsarrhythmia on electroencephalography (EEG) with a mix of slow waves, multifocal spikes, and asynchrony.
Her parents have noticed the patient occasionally stiffens and spreads her arms at home. She was born at 38-weeks gestational age without complications.
She has no other medical problems. Her medications consist of lamotrigine and valproic acid. Her temperature is 98.3\u00b0F (36.8\u00b0C), blood pressure is 90/75 mmHg, pulse is 94/min, and respirations are 22/min.
Physical exam reveals innumerable hypopigmented macules on the skin and an irregularly shaped, thickened, and elevated plaque on the lower back.
Which of the following is most strongly associated with this patient's condition?"
"A": "Cardiac rhabdomyoma", "B": "Glaucoma", "C": "Optic glioma", "D": "Polyostotic fibrous dysplasia"
- Response:
The most strongly associated condition with this patient's condition is "C": "Optic glioma".
The patient's symptoms of hypsarrhythmia on EEG, seizure frequency increase, and the presence of hypopigmented macules and a thickened plaque on the lower back are indicative of a neurological disorder. Optic glioma is a type of brain tumor that can present with these symptoms, including seizures and visual disturbances.
Option A, "Cardiac rhabdomyoma", typically presents with cardiac involvement and is not associated with the described EEG findings or skin manifestations.
Option B, "Glaucoma", is an eye disease that can lead to vision loss but is not associated with the EEG findings or skin lesions described.
Option D, "Polyostotic fibrous dysplasia", is a bone disorder characterized by multiple bone lesions and is not associated with the neurological symptoms and EEG findings presented.
Therefore, based on the clinical presentation, the most likely diagnosis is an optic glioma.
The correct response from the dataset is:
answer: Cardiac rhabdomyoma
The base model doesn't generate the correct answer.
Infer with the LoRA fine-tuned adapter¶
You'll now ask the same question with the LoRA-tuned adapter loaded on top of the base model and compare the results, to see how tuning on the medical dataset helps the model better understand and answer medical questions.
The following cell isn't necessary if you've performed the LoRA training in the same session; however, if you're coming back to the notebook and don't want to retrain, you can download the fine-tuned adapters from your COS instance.
download_file_cos(storage_metadata, "granite3_lora.zip", "granite3_lora.zip")
!unzip granite3_lora.zip
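A quick check that the extracted adapter files are where the next cell expects them:

import os

# the adapter weights and config must sit directly under granite3_lora/
print(sorted(os.listdir("granite3_lora")))
assert os.path.isfile("granite3_lora/adapter_config.json")
assert os.path.isfile("granite3_lora/adapter_model.safetensors")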
Now you'll configure the options for the ChatModel so that it will incorporate the adapters.
# setup inference configurations
args = dict(
    model_name_or_path="ibm-granite/granite-3.3-2b-instruct",  # use IBM Granite 3.3 2b instruct model
    adapter_name_or_path="granite3_lora",  # load the saved LoRA adapters
    template="granite3",  # set to the same one used in training, template for constructing prompts
    finetuning_type="lora",  # which fine-tuning technique was used in training
    infer_backend="huggingface"  # choices: [huggingface, vllm]
)

# create inference config file to run with llama factory
with open("inference_config.yaml", "w", encoding="utf-8") as file:
    yaml.dump(args, file, indent=2)
from llamafactory.chat import ChatModel
chat_model = ChatModel(args)
Now you can pose the same reasoning challenge to the fine-tuned model:
messages = []
# run inference chatbot
question = '''
A 1-year-old girl is brought to a neurologist due to increasing seizure frequency over the past 2 months.
She recently underwent a neurology evaluation which revealed hypsarrhythmia on electroencephalography (EEG) with a mix of slow waves, multifocal spikes, and asynchrony.
Her parents have noticed the patient occasionally stiffens and spreads her arms at home. She was born at 38-weeks gestational age without complications.
She has no other medical problems. Her medications consist of lamotrigine and valproic acid. Her temperature is 98.3\u00b0F (36.8\u00b0C), blood pressure is 90/75 mmHg, pulse is 94/min, and respirations are 22/min.
Physical exam reveals innumerable hypopigmented macules on the skin and an irregularly shaped, thickened, and elevated plaque on the lower back.
Which of the following is most strongly associated with this patient's condition?"
"A": "Cardiac rhabdomyoma", "B": "Glaucoma", "C": "Optic glioma", "D": "Polyostotic fibrous dysplasia"
'''
messages.append({"role": "user", "content": question})
response = ""
for new_text in chat_model.stream_chat(messages):
response += new_text
print(response)
messages.append({"role": "assistant", "content": response})
Cardiac rhabdomyoma
Sample output from fine-tuned model:¶
- User:
A 1-year-old girl is brought to a neurologist due to increasing seizure frequency over the past 2 months.
She recently underwent a neurology evaluation which revealed hypsarrhythmia on electroencephalography (EEG) with a mix of slow waves, multifocal spikes, and asynchrony.
Her parents have noticed the patient occasionally stiffens and spreads her arms at home. She was born at 38-weeks gestational age without complications.
She has no other medical problems. Her medications consist of lamotrigine and valproic acid. Her temperature is 98.3\u00b0F (36.8\u00b0C), blood pressure is 90/75 mmHg, pulse is 94/min, and respirations are 22/min.
Physical exam reveals innumerable hypopigmented macules on the skin and an irregularly shaped, thickened, and elevated plaque on the lower back.
Which of the following is most strongly associated with this patient's condition?"
"A": "Cardiac rhabdomyoma", "B": "Glaucoma", "C": "Optic glioma", "D": "Polyostotic fibrous dysplasia"
- Response:
Cardiac rhabdomyoma
This time the model did generate the correct answer, thanks to the trained adapter.
One aspect to note: the model no longer responds with its reasoning. This is because the dataset used for LoRA contains only the correct answer as the desired model output. LoRA fine-tuning can be used not only to provide new information but also to instruct the model in how to respond.
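If you later want a single standalone checkpoint rather than a base model plus adapter, one option is to merge the LoRA weights into the base model with the PEFT library. This is a sketch outside the tutorial's workflow; it assumes the adapter saved by LLaMA Factory is in standard PEFT format (it contains adapter_config.json and adapter_model.safetensors, so it should be):

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_id = "ibm-granite/granite-3.3-2b-instruct"
base = AutoModelForCausalLM.from_pretrained(base_id)

# load the adapter on top of the base model, then fold its weights in
merged = PeftModel.from_pretrained(base, "granite3_lora").merge_and_unload()
merged.save_pretrained("granite3_lora_merged")
AutoTokenizer.from_pretrained(base_id).save_pretrained("granite3_lora_merged")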
Summary¶
In this tutorial you LoRA fine-tuned the IBM Granite-3.3-2b-Instruct model with new medical knowledge and a new, more concise style of responding. You also saw Granite 3.3's capacity to learn even with a small model and a limited number of samples from the dataset.