TRL wrapper
Running TRL methods¶
The toolkit implements several TRL methods via a StructuralControl wrapper. The methods currently available are Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and Anchored Preference Optimization (APO). The toolkit also provides an additional preference optimization method, Self-Play Preference Optimization (SPPO), which is not available in the TRL library but follows the design of the TRL methods very closely.
Each method is configured through a Control object and run through a SteeringPipeline. This notebook demonstrates how the above methods can be used for training language models.
from __future__ import annotations
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
SFT with LoRA¶
To run SFT, we need to import SteeringPipeline as well as the SFT control class. The method is implemented as a wrapper around TRL's SFTTrainer class and as such is used similarly.
from aisteer360.algorithms.core.steering_pipeline import SteeringPipeline
from aisteer360.algorithms.structural_control.wrappers.trl.sfttrainer.control import SFT
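Since the control wraps SFTTrainer, it can help to see what the equivalent direct TRL run looks like. The sketch below is for comparison only; the dataset name and output directory are placeholders, not part of this toolkit's workflow.
# For comparison only: a minimal sketch of the roughly equivalent direct TRL run.
# The SFT control below forwards its training keyword arguments to these TRL classes.
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

ds = load_dataset("trl-lib/Capybara", split="train[:100]")  # stand-in SFT dataset
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=SFTConfig(output_dir="sft-direct-trl"),  # placeholder output directory
    train_dataset=ds,
)
trainer.train()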
The example shows supervised fine-tuning of a small model with a 500-record sample of a Hugging Face preference dataset. We load the tokenizer and preprocess the dataset to convert it to a standard format for SFT.
def preprocess(example):
    # combine the prompt with the preferred ("chosen") response into a single SFT training text
    text = f"Question: {example['prompt']}\n\nAnswer: {example['chosen'][-1]['content']}"
    tok_data = tokenizer(text, truncation=True, padding='max_length', max_length=1024, return_tensors="pt")
    return {'input_ids': tok_data['input_ids'][0], 'attention_mask': tok_data['attention_mask'][0]}
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
dataset = load_dataset(
'HuggingFaceH4/ultrafeedback_binarized',
split='train_prefs',
)
subset_size = 500
dataset = dataset.select(list(range(subset_size)))
train_dataset = dataset.map(preprocess, remove_columns=dataset.column_names)
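As a quick sanity check (illustrative, not required), each mapped record should now be a fixed-length token sequence:
# Illustrative check: every record is padded/truncated to max_length (1024) tokens
sample = train_dataset[0]
print(len(sample['input_ids']), len(sample['attention_mask']))  # 1024 1024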
Next, the SFT control is instantiated by providing the train_dataset as well as the output_dir for saving the steered model. We also set use_peft to True (the default is False) and set peft_type to enable LoRA. Finally, we override some of the default training arguments. Note that the SFT control is based on TRL's SFTConfig class and inherits its default training arguments; some of these parameters can be overridden, as shown below. Please refer to aisteer360.algorithms.structural_control.wrappers.trl.args.py and aisteer360.algorithms.structural_control.wrappers.trl.sfttrainer.args.py for the list of these parameters and their default values. The parameters used for LoRA training are similarly based on the LoraConfig class, and their default values can be overridden as below.
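To see those defaults without reading the source files, the underlying TRL and PEFT config classes can also be instantiated directly (an optional, illustrative check; "tmp" is just a placeholder directory):
# Optional: inspect the TRL/PEFT defaults that the wrapper builds on
from trl import SFTConfig
from peft import LoraConfig

print(SFTConfig(output_dir="tmp").learning_rate)  # TRL's default learning rate
print(LoraConfig().r, LoraConfig().lora_alpha)    # default LoRA rank and alpha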
from peft import PeftType
# control
sft = SFT(
train_dataset=train_dataset,
use_peft=True,
peft_type=PeftType.LORA,
**{
"per_device_train_batch_size": 4,
"num_train_epochs": 3,
"learning_rate": 2e-5,
"output_dir": "Qwen2.5-0.5B-SFT-LoRA-Steer",
"logging_steps": 100,
"save_strategy": "no",
"lora_alpha": 16,
},
)
We then create the SteeringPipeline, providing the model_name_or_path, setting controls to [sft], and invoking steer().
# steering pipeline
sft_pipeline = SteeringPipeline(
model_name_or_path=model_name,
controls=[sft],
device_map="cuda:0" if torch.cuda.is_available() else "cpu",
hf_model_kwargs={"dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32},
)
sft_pipeline.steer()
data_collator is None
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.
| Step | Training Loss |
|------|---------------|
| 100  | 0.765000      |
| 200  | 0.773000      |
| 300  | 0.657700      |
dataset = load_dataset(
'HuggingFaceH4/ultrafeedback_binarized',
split='test_prefs',
)
enc = tokenizer(f"Question:{dataset[0]['prompt']} \n Answer:", return_tensors="pt", padding=True).to(sft_pipeline.model.device)
print(f"Question:{dataset[0]['prompt']}")
steered_response = sft_pipeline.generate_text(
input_ids=enc["input_ids"],
attention_mask=enc["attention_mask"],
max_new_tokens=20
)
print("output (SFT):")
print(steered_response)
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Question:In this task, you are given a second sentence. Your task is to generate the first sentence on the same topic but incoherent and inconsistent with the second sentence. Q: Additionally , some groups may contain other specialists , such as a heavy weapons or language expert . A: Each squad member is specially trained as a weapons expert , medic , combat engineer or communications expert , respectively . **** Q: However , the General Accounting Office identified 125 countries that received U.S. training and assistance for their police forces during fiscal year 1990 at a cost of at least $117 million . A: No government agency is in charge of calculating the cost . **** Q: But his frozen body was found in the ice in Charlotte ( Rochester ) early the next spring by Silas Hudson . A: output (SFT): [' The man who had been found dead in the freezing ice in Charlotte ( Rochester ) was named Silas']
# Releasing memory resources
import gc
del sft_pipeline.model, sft_pipeline
gc.collect()
torch.cuda.empty_cache()
We load the LoRA adapter, merge it into the base model, and save the combined model.
from transformers import AutoModelForCausalLM
from peft import PeftModel, PeftConfig
lora_adapter_path = "Qwen2.5-0.5B-SFT-LoRA-Steer"
print('# Load PEFT config')
config = PeftConfig.from_pretrained(lora_adapter_path)
print('# Load base model')
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(lora_adapter_path)
print('# Get PeftModel')
peft_model = PeftModel.from_pretrained(base_model, lora_adapter_path, adapter_name="sft_lora")
peft_model.set_adapter("sft_lora")  # set adapter as active
print("# Merge adapter into model")
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("Qwen2.5-0.5B-SFT-LoRA-Steer-Merged")
tokenizer.save_pretrained("Qwen2.5-0.5B-SFT-LoRA-Steer-Merged")
# Load PEFT config # Load base model # Get PeftModel # Merge adapter into model
('Qwen2.5-0.5B-SFT-LoRA-Steer-Merged/tokenizer_config.json', 'Qwen2.5-0.5B-SFT-LoRA-Steer-Merged/special_tokens_map.json', 'Qwen2.5-0.5B-SFT-LoRA-Steer-Merged/chat_template.jinja', 'Qwen2.5-0.5B-SFT-LoRA-Steer-Merged/vocab.json', 'Qwen2.5-0.5B-SFT-LoRA-Steer-Merged/merges.txt', 'Qwen2.5-0.5B-SFT-LoRA-Steer-Merged/added_tokens.json', 'Qwen2.5-0.5B-SFT-LoRA-Steer-Merged/tokenizer.json')
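As a quick check (illustrative), the merged directory now loads as a plain causal LM, with no PEFT wrapper required:
# Illustrative check: the merged model loads standalone, without PEFT
reloaded = AutoModelForCausalLM.from_pretrained("Qwen2.5-0.5B-SFT-LoRA-Steer-Merged")
print(type(reloaded).__name__)  # e.g. Qwen2ForCausalLM, not PeftModel
del reloaded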
del base_model, tokenizer, peft_model, merged_model
gc.collect()
torch.cuda.empty_cache()
DPO¶
We next further steer the SFT LoRA model from above using DPO.
In the example below, we use a preference dataset that is already in the conversational format required by DPO, so no preprocessing is needed.
model_name = "Qwen2.5-0.5B-SFT-LoRA-Steer-Merged"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
tokenizer.truncation_side = 'left'
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
subset_size = 500
dataset = dataset.select(list(range(subset_size)))
example = dataset[0]
example
{'chosen': [{'content': 'Use the pygame library to write a version of the classic game Snake, with a unique twist', 'role': 'user'}, {'content': "Sure, I'd be happy to help you write a version of the classic game Snake using the pygame library! Here's a basic outline of how we can approach this:\n\n1. First, we'll need to set up the game display and create a game object that we can use to handle the game's state.\n2. Next, we'll create the game's grid, which will be used to represent the game board. We'll need to define the size of the grid and the spaces within it.\n3. After that, we'll create the snake object, which will be used to represent the player's movement. We'll need to define the size of the snake and the speed at which it moves.\n4. We'll also need to create a food object, which will be used to represent the food that the player must collect to score points. We'll need to define the location of the food and the speed at which it moves.\n5. Once we have these objects set up, we can start handling the game's logic. We'll need to define the rules for the player's movements, such as how the player can move the snake and how the snake will grow as the player collects more food.\n6. We'll also need to add collisions detection to the game, so that the snake and the food can collide with each other.\n7. Finally, we'll need to add a user interface to the game, such as a menu and a scoreboard.\n\nNow, as for the unique twist, we could add a few features to make the game more interesting. For example, we could add power-ups that give the player special abilities, such as the ability to grow the snake faster or to change its direction. We could also add obstacles, such as walls or pits, that the player must avoid.\n\nHere's some sample code to get us started:\n```\nimport pygame\n\n# Initialize pygame\npygame.init()\n\n# Set up the game display\nwidth = 800\nheight = 600\nscreen = pygame.display.set_mode((width, height))\n\n# Define the colors\nWHITE = (255, 255, 255)\nBLACK = (0, 0, 0)\nGREEN = (0, 255, 0)\n\n# Define the game objects\nsnake = pygame.Rect(50, 50, 10, 10)\nfood = pygame.Rect(100, 100, 10, 10)\n\n# Define the game logic\ndef update_snake():\n # Get the mouse position\n mouse_pos = pygame.mouse.get_pos()\n\n # Move the snake based on the mouse position\n if mouse_pos[0] > snake.x:\n snake.x += 10\n elif mouse_pos[0] < snake.x:\n snake.x -= 10\n elif mouse_pos[1] > snake.y:\n snake.y += 10\n elif mouse_pos[1] < snake.y:\n snake.y -= 10\n\n # Update the snake's size\n if snake.x == food.x and snake.y == food.y:\n snake.width += 10\n snake.height += 10\n\n# Define the game loop\ndef game_loop():\n # Clear the screen\n screen.fill(BLACK)\n\n # Update the game objects\n update_snake()\n\n # Draw the game objects\n screen.fill(WHITE)\n screen.draw.rect(snake, GREEN)\n screen.draw.rect(food, BLACK)\n\n # Update the display\n pygame.display.flip()\n\n# Start the game loop\ngame_loop()\n\n# Run the game\nwhile True:\n for event in pygame.event.get():\n if event.type == pygame.QUIT:\n pygame.quit()\n break\n\n pygame.time.Clock().tick(60)\n```\nThis code sets up a basic game display, defines the snake and food objects, and updates the game state based on the player's movements. 
We've also added a simple user interface and some basic collision detection.\n\nAs for the unique twist, we could add a few features to make the game", 'role': 'assistant'}], 'rejected': [{'content': 'Use the pygame library to write a version of the classic game Snake, with a unique twist', 'role': 'user'}, {'content': 'Sure, here\'s an example of how to write a version of Snake game with a unique twist using the Pygame library:\n```python\nimport pygame\n\nclass SnakeGame:\n def __init__(self, game_width, game_height):\n pygame.init()\n screen = pygame.display.set_mode((game_width, game_height))\n pygame.display.set_caption("Snake Game")\n self.speed = 5 # Speed of the snake\n self.food_speed = 1 # Speed of the food\n self.direction = 0 # Initial direction of the snake\n self.snakelen = 0 # Length of the snake\n self.food = pygame.image.load("snake_food.png")\n self.head = pygame.image.load("snake_head.png")\n self.tail = pygame.image.load("snake_tail.png")\n self.game Quint()\n def Quint(self):\n for i in range(50):\n pygame.draw.line(screen, (180, 100, 220), (0, 0, 200), 2)\n pygame.draw.line(screen, (180, 100, 220), (50, 0, 200), 2)\n pygame.draw.line(screen, (180, 100, 220), (50, 50, 200), 2)\n pygame.draw.line(screen, (180, 100, 220), (50, 50, 300), 2)\n pygame.draw.line(screen, (180, 100, 220), (50, 80, 200), 2)\n pygame.draw.line(screen, (180, 100, 220), (50, 80, 300), 2)\n pygame.draw.line(screen, (180, 100, 220), (0, 50, 200), 2)\n pygame.draw.line(screen, (180, 100, 220), (0, 50, 300), 2)\n pygame.draw.line(screen, (180, 100, 220), (0, 80, 200), 2)\n pygame.draw.line(screen, (180, 100, 220), (0, 80, 300), 2)\n pygame.display.flip()\n self.game.run()\n def run(self):\n while True:\n for event in pygame.event. pygame.KEYDOWN:\n if event.key == pygame.K_LEFT:\n self.direction = -1\n if event.key == pygame.K_RIGHT:\n self.direction = 1\n self.snakelen += 1\n if self.snakelen == 0:\n self.snakelen = 10\n if self.snakelen > 20:\n self.snakelen = 20\n self.gameQuint()\n self.foodCrossing()\n self.headRun()\n pygame.display.update()\ngame = SnakeGame(800, 600)\ngame.run()\n```\nIn this game, the snake moves with a constant speed, but the direction of the snake can be controlled by the user using the left and right arrow keys. The snake grows in length every 10 segments, and when it reaches a certain length, it resets to 10 segments. The food moves fast and randomly crosses the screen, and the snake can eat it by colliding with it. The snake\'s head and tail move independently of each other. The game ends when the snake dies or reaches the end of the screen.', 'role': 'assistant'}], 'score_chosen': 6.0, 'score_rejected': 4.0}
To use DPO, we import the corresponding DPO control:
from aisteer360.algorithms.structural_control.wrappers.trl.dpotrainer.control import DPO
DPO steering is run the same way as SFT above: the DPO control is created, and the steering pipeline is invoked with the model name and the control set to dpo.
# control
dpo = DPO(
train_dataset=dataset,
**{
"per_device_train_batch_size": 4,
"num_train_epochs": 3,
"learning_rate": 2e-5,
"output_dir": "Qwen2.5-0.5B-DPO-Steer",
"logging_steps": 100,
"save_strategy": "no",
},
)
# steering pipeline
dpo_pipeline = SteeringPipeline(
model_name_or_path=model_name,
controls=[dpo],
device_map="auto" if torch.cuda.is_available() else "cpu",
hf_model_kwargs={"dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32},
)
dpo_pipeline.steer()
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.
| Step | Training Loss |
|------|---------------|
| 100  | 0.854200      |
| 200  | 0.248300      |
| 300  | 0.036500      |
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="test")
question = 'QUESTION' + dataset[0]['chosen'][-2]['content'].rsplit('QUESTION', 1)[-1]  # keep only the text after the last 'QUESTION' marker
print(question)
messages = [
{"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
{"role": "user", "content": question}
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
enc = tokenizer(text, return_tensors="pt", padding=True, padding_side="left").to(dpo_pipeline.model.device)
steered_response = dpo_pipeline.generate_text(
input_ids=enc["input_ids"],
attention_mask=enc["attention_mask"],
max_new_tokens=100,
do_sample=True
)
print("output (DPO):")
print(steered_response)
QUESTIONAs an HR manager, you want to test a potential employee's ability to solve puzzles to determine their suitability for a job. Write a Python script that generates a list of questions that require logical reasoning to answer. Your list should include questions related to mathematical puzzles, language puzzles, logic puzzles, lateral thinking puzzles, and pattern recognition puzzles. Use the following code as a starting point: questions = { "Mathematical puzzles": ["If the value of x+y = 20 and x-y = 10, what is the value of x and y?", "If a pizza has a radius of 8 inches and is cut into 6 equal slices, what is the area of each slice?"], "Language puzzles": ["What word starts with 'e' and ends with 'e' but only contains one letter?", "I am taken from a mine, and shut up in a wooden case, from which I am never released, and yet I am used by almost every person. What am I?"], "Logic puzzles": ["You have 3 boxes. One contains only apples, one contains only oranges, and one contains both apples and oranges. The boxes have been incorrectly labeled such that no label identifies the actual contents of the box it labels. Opening just one box, and without looking in the box, you take out one piece of fruit. From that one piece of fruit, how can you immediately label all of the boxes correctly?"], "Lateral thinking puzzles": ["A man lives on the 10th floor of a building. Every day he takes the elevator to the ground floor to go to work or to go shopping. When he returns he takes the elevator to the 7th floor and walks up the stairs to reach his apartment on the 10th floor. Why does he do this?"], "Pattern recognition puzzles": ["What is the next number in the sequence: 1, 3, 6, 10, 15, ___", "What is the missing number in the sequence: 2, 5, 10, 17, ___, 37"] } for category in questions: print(f"{category}:") for question in questions[category]: print(f"- {question}") Run the script and use the list of questions to conduct a comprehensive interview with the applicant. Their ability to solve puzzles will help you determine their suitability for the job. output (DPO): ['Here\'s the Python script based on your request:\n```python\nimport random\n\n# Generate a list of math, lanugage, logic, and lterature puzzle questions\nquestions = [\n {"category": "Mathematical puzzles", "question": "If the value of x + y = 20 and x - y = 10, what is the value of x and y?"},\n {"category": "Language puzzles", "question": "What word starts with \'']
# Releasing memory resources
import gc
del dpo_pipeline.model, dpo_pipeline
gc.collect()
torch.cuda.empty_cache()
APO¶
Now, we demonstrate how to run APO on the previously steered SFT LoRA model. APO is run in the same manner as DPO above.
model_name = "Qwen2.5-0.5B-SFT-LoRA-Steer-Merged"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
tokenizer.truncation_side = 'left'
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
subset_size = 500
dataset = dataset.select(list(range(subset_size)))
We import the APO control and configure the SteeringPipeline with it.
from aisteer360.algorithms.structural_control.wrappers.trl.apotrainer.control import APO
# control
apo = APO(
train_dataset=dataset,
**{
"per_device_train_batch_size": 4,
"num_train_epochs": 3,
"learning_rate": 2e-5,
"output_dir": "Qwen2.5-0.5B-APO-Steer",
"logging_steps": 100,
"save_strategy": "no",
},
)
# steering pipeline
apo_pipeline = SteeringPipeline(
model_name_or_path=model_name,
controls=[apo],
device_map="auto" if torch.cuda.is_available() else "cpu",
hf_model_kwargs={"dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32},
)
apo_pipeline.steer()
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.
| Step | Training Loss |
|------|---------------|
| 100  | 0.963500      |
| 200  | 0.500100      |
| 300  | 0.277500      |
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="test")
question = 'QUESTION' + dataset[0]['chosen'][-2]['content'].rsplit('QUESTION', 1)[-1]  # keep only the text after the last 'QUESTION' marker
print(question)
messages = [
{"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
{"role": "user", "content": question}
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
enc = tokenizer(text, return_tensors="pt", padding=True, padding_side="left").to(apo_pipeline.model.device)
steered_response = apo_pipeline.generate_text(
input_ids=enc["input_ids"],
attention_mask=enc["attention_mask"],
max_new_tokens=100,
do_sample=True
)
print("output (APO):")
print(steered_response)
QUESTIONAs an HR manager, you want to test a potential employee's ability to solve puzzles to determine their suitability for a job. Write a Python script that generates a list of questions that require logical reasoning to answer. Your list should include questions related to mathematical puzzles, language puzzles, logic puzzles, lateral thinking puzzles, and pattern recognition puzzles. Use the following code as a starting point: questions = { "Mathematical puzzles": ["If the value of x+y = 20 and x-y = 10, what is the value of x and y?", "If a pizza has a radius of 8 inches and is cut into 6 equal slices, what is the area of each slice?"], "Language puzzles": ["What word starts with 'e' and ends with 'e' but only contains one letter?", "I am taken from a mine, and shut up in a wooden case, from which I am never released, and yet I am used by almost every person. What am I?"], "Logic puzzles": ["You have 3 boxes. One contains only apples, one contains only oranges, and one contains both apples and oranges. The boxes have been incorrectly labeled such that no label identifies the actual contents of the box it labels. Opening just one box, and without looking in the box, you take out one piece of fruit. From that one piece of fruit, how can you immediately label all of the boxes correctly?"], "Lateral thinking puzzles": ["A man lives on the 10th floor of a building. Every day he takes the elevator to the ground floor to go to work or to go shopping. When he returns he takes the elevator to the 7th floor and walks up the stairs to reach his apartment on the 10th floor. Why does he do this?"], "Pattern recognition puzzles": ["What is the next number in the sequence: 1, 3, 6, 10, 15, ___", "What is the missing number in the sequence: 2, 5, 10, 17, ___, 37"] } for category in questions: print(f"{category}:") for question in questions[category]: print(f"- {question}") Run the script and use the list of questions to conduct a comprehensive interview with the applicant. Their ability to solve puzzles will help you determine their suitability for the job. output (APO): ['Sure! Here\'s the Python script that uses the provided questions to generate puzzle-related questions:\n\n```python\nimport random\n\n# List of questions categorized by type\nquestions = {\n "Mathematical puzzles": ["If the value of x+y = 20 and x-y = 10, what is the value of x and y?", "If a pizza has a radius of 8 inches and is cut into 6 equal slices, what is the area of each slice?"],\n ']
# Releasing memory resources
del apo_pipeline.model, apo_pipeline
gc.collect()
torch.cuda.empty_cache()
SPPO¶
To run SPPO, a few extra classes need to be imported, and multiple iterations of steering can be performed. The example below is based on the SPPO paper, and the iteration code is adapted from scripts in the SPPO GitHub repository.
The example applies 3 iterations of SPPO to a Mistral model using Hugging Face prompt datasets.
from aisteer360.algorithms.structural_control.wrappers.trl.sppotrainer.control import SPPO
from aisteer360.algorithms.structural_control.wrappers.trl.sppotrainer.utils import (
set_seed,
apply_template,
ranking,
from_ranks,
prepare_score,
apply_chat_template,
process_dataset,
prepare_dataset_from_prompts
)
def run_SPPO(to_be_steered_model_path_or_name, data, sppo_temp_dir, start_iter_num=1, end_iter_num=1,
             maxlen=2048, num_prompts=5, additional_train_datasets=None):
    # steered model is stored at this path at each iteration
    checkpoints_path = f"{sppo_temp_dir}/checkpoints/SPPO-FINAL"

    # control
    sppo = SPPO(
        train_dataset=data,
        eval_dataset=None,
        **{
            "per_device_train_batch_size": 4,
            "num_train_epochs": 1,
            "learning_rate": 5.0e-7,
            "output_dir": checkpoints_path,
            "save_strategy": "no",
            "beta": 0.001,
            "optim": "rmsprop",
            "loss_type": "sppo",
            "max_prompt_length": 128,
            "max_length": 512
        },
    )

    # steering pipeline
    sppo_pipeline = SteeringPipeline(
        model_name_or_path=to_be_steered_model_path_or_name,
        controls=[sppo],
        device_map="auto" if torch.cuda.is_available() else "cpu",
        hf_model_kwargs={"dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32},
    )
    sppo_pipeline.steer(num_prompts=num_prompts, start_iter_num=start_iter_num, end_iter_num=end_iter_num,
                        additional_train_datasets=additional_train_datasets, sppo_temp_dir=sppo_temp_dir, maxlen=maxlen)
    return sppo_pipeline
# Based on https://github.com/uclaml/SPPO/blob/main/run_sppo_mistral.sh
start_iter_num = 1
end_iter_num = 3
num_prompts = 2 # number of responses to generate for each prompt (default is 5)
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"  # alternatively: "Qwen/Qwen2.5-0.5B-Instruct"
prompt_datasets = ["UCLA-AGI/data-mistral-7b-instruct-sppo-iter1",
                   "UCLA-AGI/data-mistral-7b-instruct-sppo-iter2",
                   "UCLA-AGI/data-mistral-7b-instruct-sppo-iter3"]  # prompt datasets to be used
m_name = BASE_MODEL.split("/")[-1]
sppo_temp_dir = m_name+"_SPPO"
# We use just 10 records of each dataset for the demonstration
subset_size = 10
dataset = load_dataset(prompt_datasets[start_iter_num-1], split="train")
data = dataset.select(list(range(subset_size)))
del dataset
additional_train_datasets = []
for dset in range(start_iter_num, end_iter_num):
dataset = load_dataset(prompt_datasets[dset], split="train")
addl_data = dataset.select(list(range(subset_size)))
additional_train_datasets.append(addl_data)
del dataset
if start_iter_num == 1:
to_be_steered_model_path_or_name = BASE_MODEL
else:
to_be_steered_model_path_or_name = f"{sppo_temp_dir}/checkpoints/SPPO-Iter{start_iter_num-1}"
sppo_pipeline = run_SPPO(to_be_steered_model_path_or_name, data=data, sppo_temp_dir=sppo_temp_dir, start_iter_num=start_iter_num, end_iter_num=end_iter_num,
additional_train_datasets=additional_train_datasets, num_prompts=num_prompts)
Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation. Setting `pad_token_id` to `eos_token_id`:2 for open-end generation. WARNING:root:No ranker config provided, no ranker loaded, please load ranker first through load_ranker() WARNING:root:No fuser config provided, no fuser loaded, please load fuser first through load_fuser() /dccstor/aimetacognition/users/moninder/AISteer360/.venv/lib/python3.11/site-packages/dataclasses_json/core.py:201: RuntimeWarning: 'NoneType' object value of non-optional type load_checkpoint detected when decoding RankerConfig. warnings.warn( /dccstor/aimetacognition/users/moninder/AISteer360/.venv/lib/python3.11/site-packages/dataclasses_json/core.py:201: RuntimeWarning: 'NoneType' object value of non-optional type device detected when decoding RankerConfig. warnings.warn( /dccstor/aimetacognition/users/moninder/AISteer360/.venv/lib/python3.11/site-packages/transformers/convert_slow_tokenizer.py:564: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text. warnings.warn(
Successfully loaded ranker from /dccstor/modelaudit/users/moninder/cache/hub/llm-blender/PairRM
Ranking candidates: 100%|██████████████████████████████████████████████████████████| 10/10 [00:02<00:00, 4.24it/s]
Generating train split: 0 examples [00:00, ? examples/s]
Saved file to Mistral-7B-Instruct-v0.2_SPPO/synthetic_data_SPPO-Iter1_score/train.parquet Saved file to Mistral-7B-Instruct-v0.2_SPPO/synthetic_data_SPPO-Iter1_score/test.parquet probs calculated
Generating train split: 0 examples [00:00, ? examples/s]
Formatting comparisons with prompt template: 0%| | 0/10 [00:00<?, ? examples/s]
Map: 0%| | 0/10 [00:00<?, ? examples/s]
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 2}. Could not estimate the number of tokens of the input, floating-point operations will not be computed
| Step | Training Loss |
|------|---------------|
| 1    | 62423.399500  |
Successfully loaded ranker from /dccstor/modelaudit/users/moninder/cache/hub/llm-blender/PairRM
Ranking candidates: 100%|██████████████████████████████████████████████████████████| 10/10 [00:01<00:00, 7.62it/s]
Generating train split: 0 examples [00:00, ? examples/s]
Saved file to Mistral-7B-Instruct-v0.2_SPPO/synthetic_data_SPPO-Iter2_score/train.parquet Saved file to Mistral-7B-Instruct-v0.2_SPPO/synthetic_data_SPPO-Iter2_score/test.parquet probs calculated
Generating train split: 0 examples [00:00, ? examples/s]
Formatting comparisons with prompt template: 0%| | 0/10 [00:00<?, ? examples/s]
Map: 0%| | 0/10 [00:00<?, ? examples/s]
| Step | Training Loss |
|------|---------------|
| 1    | 60369.968500  |
Successfully loaded ranker from /dccstor/modelaudit/users/moninder/cache/hub/llm-blender/PairRM
Ranking candidates: 100%|██████████████████████████████████████████████████████████| 10/10 [00:01<00:00, 6.21it/s]
Generating train split: 0 examples [00:00, ? examples/s]
Saved file to Mistral-7B-Instruct-v0.2_SPPO/synthetic_data_SPPO-Iter3_score/train.parquet Saved file to Mistral-7B-Instruct-v0.2_SPPO/synthetic_data_SPPO-Iter3_score/test.parquet probs calculated
Generating train split: 0 examples [00:00, ? examples/s]
Formatting comparisons with prompt template: 0%| | 0/10 [00:00<?, ? examples/s]
Map: 0%| | 0/10 [00:00<?, ? examples/s]
| Step | Training Loss |
|------|---------------|
| 1    | 38179.032200  |
dataset = load_dataset(f"UCLA-AGI/data-mistral-7b-instruct-sppo-iter1", split="train")
subset_size = 10
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
tokenizer.pad_token = tokenizer.eos_token
prompt = apply_template(dataset[subset_size]["prompt"], tokenizer)
print(prompt)
enc = tokenizer(prompt, return_tensors="pt").to(sppo_pipeline.model.device)
steered_response = sppo_pipeline.generate_text(
input_ids=enc["input_ids"],
attention_mask=enc["attention_mask"],
max_new_tokens=100,
do_sample=True
)
print("output (SPPO):")
print(steered_response)
<s> [INST] Can you write me a 2 minute speech to welcome people to my dim sum themed 36th birthday party? [/INST] output (SPPO): ["Ladies and Gentlemen, esteemed guests, welcome to my 36th birthday party! I'm so glad you could all join me tonight to celebrate this milestone in my life. I wanted to throw a party that reflected my love for one of my favorite foods - dim sum!\n\nDim sum, for those who may not be familiar, is a traditional Chinese culinary art that involves serving small bite-sized portions of food in steamer baskets or on"]
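As in the previous sections, we release the model resources when finished:
# Releasing memory resources
del sppo_pipeline.model, sppo_pipeline
gc.collect()
torch.cuda.empty_cache()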