Reward-Augmented Decoding: Efficient Controlled Text Generation With a Unidirectional Reward Model
Authors: Haikang Deng, Colin Raffel
RAD is an output steering method that enables users to perform controlled text generation with a unidirectional reward model.
In this demo, we show how RAD can be used to reduce the toxicity of sentences generated by an LLM.
The following authentication steps may be necessary to access gated models (even after you have been granted access on Hugging Face). Uncomment the code below if you need to log in to the Hugging Face Hub using a token stored in your .env file:
# !pip install python-dotenv
# from dotenv import load_dotenv
# import os
# load_dotenv()
# token = os.getenv("HUGGINGFACE_TOKEN")
# from huggingface_hub import login
# login(token=token)
Example: Steering for reduced toxicity
from aisteer360.algorithms.core.steering_pipeline import SteeringPipeline
from aisteer360.algorithms.output_control.rad.control import RAD
import warnings
warnings.filterwarnings('ignore', category=UserWarning)
MODEL_NAME = "gpt2-large"
We initialize the RAD method with specified parameters. Below, beta represents the steering strength, with 0 corresponding to the original (unsteered) decoding.
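Conceptually, RAD rescores the model's top-k candidate tokens at each decoding step: the unidirectional reward model scores each candidate continuation (reusing cached activations for the shared prefix), and beta times the reward is added to the candidate's logit before sampling. The following is a minimal, illustrative sketch of this rule, not the library's implementation; reward_fn stands in for the trained reward model:
import torch

def rad_step(logits, prefix_ids, reward_fn, beta=10.0, k=20):
    # Take the model's top-k next-token candidates.
    topk_logits, topk_ids = torch.topk(logits, k)
    # Score each candidate continuation with the reward model.
    rewards = torch.tensor([reward_fn(prefix_ids + [t.item()]) for t in topk_ids])
    # Reward-augmented logits: beta=0 recovers ordinary sampling.
    probs = torch.softmax(topk_logits + beta * rewards, dim=-1)
    # Sample the next token from the adjusted distribution.
    return topk_ids[torch.multinomial(probs, 1)].item()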
RAD requires a trained reward model. In this demo, we use the toxicity reward model provided by the authors. Please pass the path to the reward model via reward_path. If you don't pass a path, the reward model will be automatically downloaded to './rad_saved_models/saved_models/gpt2_toxicity'. To train your own reward model, see https://github.com/r-three/RAD?tab=readme-ov-file for details.
Note: You can adjust beta as needed.
rad = RAD(
    beta=10,  # steering strength; beta=0 recovers standard decoding
)
If the reward model is already downloaded, please pass the path via reward_path:
# rad = RAD(
# beta=10,
# reward_path='./rad_saved_models/saved_models/gpt2_toxicity',
# )
We create the SteeringPipeline object as follows.
rad_pipeline = SteeringPipeline(
model_name_or_path=MODEL_NAME,
controls=[rad],
device="cuda",
hf_model_kwargs={"low_cpu_mem_usage": True},
)
rad_pipeline.steer()
Reward model not found in: None. Downloading from https://github.com/r-three/RAD...
Downloading... From: https://storage.googleapis.com/rad_release/saved_models.zip To: /dccstor/principled_ai/users/erikmiehling/AISteer360/tmp/rad_saved_models.zip
100%|███████████████████████████████████████████████████████| 925M/925M [00:04<00:00, 200MB/s]
Reward model downloaded. Please set reward_path='./tmp/rad_saved_models/saved_models/gpt2_toxicity' in the future.
Reward model is loaded.
Controlled text generation via RAD steering
Now, let's create a prompt for generation.
tokenizer = rad_pipeline.tokenizer
PROMPT = "I hate you. You are being like a drama queen who seeks attention all the time. I wish you"
enc = tokenizer(PROMPT, return_tensors="pt")
input_ids = enc.input_ids
attention_mask = enc.attention_mask
We define the generation hyperparameters as follows and generate the continuation with RAD steering.
gen_params = {
"max_new_tokens": 20,
"temperature": 0.6,
"top_p": 0.9,
"do_sample": True,
"repetition_penalty": 1.2,
}
steered_output_ids = rad_pipeline.generate(
input_ids=input_ids,
attention_mask=attention_mask,
runtime_kwargs={},
**gen_params,
)
steered_response = tokenizer.decode(steered_output_ids[0], skip_special_tokens=True)
print(f"Steered response (RAD, beta=10): {steered_response}")
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Steered response (RAD, beta=10): would just go away." "You're right," he said, "but it's not
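For reuse, the tokenize/generate/decode steps above can be wrapped in a small helper. This is a sketch built only from the pipeline API demonstrated in this notebook; the helper name generate_text is ours:
def generate_text(pipeline, prompt, **gen_params):
    # Tokenize the prompt, generate with the steering pipeline, and decode.
    enc = pipeline.tokenizer(prompt, return_tensors="pt")
    output_ids = pipeline.generate(
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,
        runtime_kwargs={},
        **gen_params,
    )
    return pipeline.tokenizer.decode(output_ids[0], skip_special_tokens=True)

# Equivalent to the cells above:
# print(generate_text(rad_pipeline, PROMPT, **gen_params))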
Comparison (Optional)
Users can also readily compare against the continuation generated without RAD steering by setting beta=0.
rad = RAD(
beta=0,
reward_path='./tmp/rad_saved_models/saved_models/gpt2_toxicity',
)
rad_pipeline = SteeringPipeline(
model_name_or_path=MODEL_NAME,
controls=[rad],
device="cuda",
hf_model_kwargs={"low_cpu_mem_usage": True},
)
rad_pipeline.steer()
original_output_ids = rad_pipeline.generate(
input_ids=input_ids,
attention_mask=attention_mask,
runtime_kwargs={},
**gen_params,
)
original_response = tokenizer.decode(original_output_ids[0], skip_special_tokens=True)
print(f"Steered response (RAD, beta=0): {original_response}")
Reward model found in: ./tmp/rad_saved_models/saved_models/gpt2_toxicity
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Reward model is loaded.
Original response (RAD, beta=0): would just shut up and go away." "You're right," said the woman, "
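The two runs above generalize to a sweep over the steering strength. The sketch below reuses only the constructor arguments shown earlier and, as above, builds and steers a fresh pipeline per value of beta (note that this reloads the model on each iteration, so it is memory- and time-intensive):
for beta in [0, 5, 10, 20]:
    pipeline = SteeringPipeline(
        model_name_or_path=MODEL_NAME,
        controls=[RAD(beta=beta, reward_path='./tmp/rad_saved_models/saved_models/gpt2_toxicity')],
        device="cuda",
        hf_model_kwargs={"low_cpu_mem_usage": True},
    )
    pipeline.steer()
    output_ids = pipeline.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        runtime_kwargs={},
        **gen_params,
    )
    print(f"beta={beta}: {pipeline.tokenizer.decode(output_ids[0], skip_special_tokens=True)}")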