SASA
(Image from Ko et al., 2025)
Large language models can become strong self-detoxifiers
Authors: Ching-Yun Ko, Pin-Yu Chen, Payel Das, Youssef Mroueh, Soham Dan, Georgios Kollias, Subhajit Chaudhury, Tejaswini Pedapati, Luca Daniel
SASA is an output steering method that lets users perform controlled decoding toward any desired value attribute.
SASA leverages an LLM's contextual representations to learn linear subspaces from labeled data, e.g. characterizing toxic vs. non-toxic output in analytical form. When completing a response token by token, SASA dynamically tracks the margin of the current output and adjusts the autoregressive sampling strategy to steer generation away from the toxic subspace.
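The sampling adjustment described above can be illustrated with a toy example. This is a minimal sketch, assuming the adjustment takes an exponential-tilting form in which each candidate token's probability is multiplied by exp(beta * margin) and then renormalized; the function name `reweight` and all numbers are hypothetical and are not taken from the library.

```python
import math

def reweight(probs, margins, beta):
    """Tilt next-token probabilities by exp(beta * margin) and renormalize.

    probs   -- base-model probabilities for the candidate tokens (all > 0)
    margins -- hypothetical signed margins; positive means the non-toxic side
    beta    -- steering strength; 0 leaves probs unchanged
    """
    if beta < 0:
        raise ValueError("beta must be non-negative")
    # Work in log space and subtract the max for numerical stability.
    scores = [math.log(p) + beta * m for p, m in zip(probs, margins)]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [w / total for w in weights]

# Toy numbers (illustrative only): three candidate tokens.
base_probs = [0.5, 0.3, 0.2]
toy_margins = [-1.0, 0.5, 1.0]  # the first token would push the text toward the toxic side

unsteered = reweight(base_probs, toy_margins, beta=0.0)
steered = reweight(base_probs, toy_margins, beta=2.0)
print(unsteered)
print(steered)
```

In this sketch, beta = 0 returns the base distribution, while larger beta values shift probability mass toward tokens with a larger margin.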
In this demo, we show how SASA can be used to reduce the toxicity of sentences generated by an LLM.
Method parameters
parameter | type | description |
---|---|---|
beta | float | Scaling coefficient for value redistribution. Must be non-negative. |
wv_path | Optional[str] | Path to a saved steering-vector tensor. Must end with .pt if provided. |
gen_wv_data_path | Optional[str] | Path to the value dataset, e.g. sentences with labeled toxicity. |
gen_wv_length | Optional[int] | Maximum number of samples used for preparing SASA steering if wv_path does not exist. |
gen_wv_batch_size | Optional[int] | Batch size used for preparing SASA steering if wv_path does not exist. Must be non-negative if wv_path is None. |
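The constraints in the table can be written down as a small validation sketch. The class `SASAParamsSketch` below is a hypothetical stand-in, not the library's own argument class, and the default for gen_wv_batch_size is an assumption made only for illustration.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SASAParamsSketch:
    """Hypothetical stand-in mirroring the parameter table; not the library's class."""
    beta: float = 0.0
    wv_path: Optional[str] = None
    gen_wv_data_path: Optional[str] = None
    gen_wv_length: Optional[int] = -1      # -1 means "use all samples"
    gen_wv_batch_size: Optional[int] = 8   # assumed default, for illustration only

    def __post_init__(self):
        if self.beta < 0:
            raise ValueError("beta must be non-negative")
        if self.wv_path is not None and not self.wv_path.endswith(".pt"):
            raise ValueError("wv_path must end with .pt")
        # The batch size only matters when the subspace has to be prepared.
        if self.wv_path is None:
            if self.gen_wv_batch_size is None or self.gen_wv_batch_size < 0:
                raise ValueError("gen_wv_batch_size must be non-negative when wv_path is None")

params = SASAParamsSketch(beta=10, wv_path="./steer_wv.pt")
print(params)
```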
The following authentication steps may be necessary to access any gated models (even after being granted access by Hugging Face). Uncomment the following if you need to log in to the Hugging Face Hub:
!pip install python-dotenv
from dotenv import load_dotenv
import os
load_dotenv()
token = os.getenv("HUGGINGFACE_TOKEN")
from huggingface_hub import login
login(token=token)
Example: Steering for reduced toxicity
from transformers import AutoModelForCausalLM, AutoTokenizer
from aisteer360.algorithms.core.steering_pipeline import SteeringPipeline
from aisteer360.algorithms.output_control.sasa.control import SASA
import warnings
warnings.filterwarnings('ignore', category=UserWarning)
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
Creating the control
We initialize the SASA method with specified parameters.
Below, beta represents the steering strength, with 0 being the original decoding. In the reference paper, the authors tried beta values as large as 500.
SASA requires constructing the value subspace prior to steering. To prepare the subspace, users should specify the sample budget gen_wv_length for this step. Setting gen_wv_length = 1000 constructs the subspace from only 1k samples. By default (gen_wv_length = -1), the algorithm uses all available samples.
gen_wv_batch_size is the batch size used during this step; adjust it according to your computational resources.
Note: You can adjust beta, gen_wv_length, and gen_wv_batch_size as needed.
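To make the roles of gen_wv_length and gen_wv_batch_size concrete, here is a minimal sketch of a sample budget followed by batching. The helper names `select_samples` and `iter_batches` are hypothetical, and taking the first N samples is an assumption; the library may select samples differently.

```python
def select_samples(samples, gen_wv_length=-1):
    """Return at most gen_wv_length samples; -1 keeps everything."""
    if gen_wv_length == -1:
        return list(samples)
    return list(samples[:gen_wv_length])

def iter_batches(samples, batch_size):
    """Yield consecutive slices of at most batch_size samples."""
    for start in range(0, len(samples), batch_size):
        yield samples[start:start + batch_size]

# Toy stand-in for the labeled value dataset.
sentences = [f"sentence {i}" for i in range(2500)]

subset = select_samples(sentences, gen_wv_length=1000)
n_batches = sum(1 for _ in iter_batches(subset, batch_size=8))
print(len(subset), n_batches)
```

A smaller gen_wv_length shortens the preparation step at the cost of fitting the subspace on fewer samples, while gen_wv_batch_size mainly trades memory for throughput.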
sasa = SASA(
    beta=10,
    gen_wv_length=100,
    gen_wv_batch_size=8,
    gen_wv_data_path="../../Jigsaw_data"
)
Downloading data
By default, the toxicity subspace is constructed using the Jigsaw dataset from Kaggle. To use jigsaw_unintended_bias, you can either download it manually from Kaggle (https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification/data) or uncomment and run the download cell below, which uses the Kaggle API (https://www.kaggle.com/docs/api). Either way, all files should be extracted to one folder, e.g. './tmp/Jigsaw_data/all_data.csv'.
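Before preparing the subspace, it can help to confirm that the data file is where the control will look for it. The following is a minimal sketch using only the standard library; the helper name `describe_value_data` is hypothetical, and the file name all_data.csv follows the example path above.

```python
import csv
from pathlib import Path

def describe_value_data(folder, filename="all_data.csv"):
    """Return the CSV header row if folder/filename exists, else None."""
    path = Path(folder) / filename
    if not path.is_file():
        return None
    with path.open(newline="", encoding="utf-8") as handle:
        # An empty file yields an empty header list rather than an error.
        return next(csv.reader(handle), [])

header = describe_value_data("./tmp/Jigsaw_data")
print("all_data.csv not found" if header is None else header)
```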
Automated download instructions (optional)
To access your Kaggle token (for downloading data using the API tool), first sign in at kaggle.com. Then:
- Click your profile photo -> "Your Profile" -> "Settings"
- Scroll to API and click "Create New Token"
- Your browser immediately downloads kaggle.json
Place kaggle.json in the Kaggle configuration directory under your home directory (typically ~/.config/kaggle/) and execute the following script.
Note: If you encounter a 403 error (permission error), please ensure that you have clicked "Join the competition" under the "Data" tab on the dataset homepage.
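A quick way to check that the token is in place before running the download is sketched below. The helper `find_kaggle_json` is hypothetical; it assumes the Kaggle client looks in the directory named by the KAGGLE_CONFIG_DIR environment variable, in ~/.config/kaggle/, or in the legacy ~/.kaggle/ directory.

```python
import os
from pathlib import Path

def find_kaggle_json(extra_dirs=()):
    """Return the first existing kaggle.json among candidate directories, else None."""
    candidates = [Path(d) for d in extra_dirs]
    env_dir = os.environ.get("KAGGLE_CONFIG_DIR")
    if env_dir:
        candidates.append(Path(env_dir))
    # Assumed default locations; adjust if your setup differs.
    candidates += [Path.home() / ".config" / "kaggle", Path.home() / ".kaggle"]
    for directory in candidates:
        token = directory / "kaggle.json"
        if token.is_file():
            return token
    return None

token_path = find_kaggle_json()
print(token_path if token_path else "kaggle.json not found")
```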
import sys
!{sys.executable} -m ensurepip --upgrade
!{sys.executable} -m pip install --upgrade pip setuptools wheel
!{sys.executable} -m pip install kaggle
import os, glob, zipfile, shutil, pandas as pd
from pathlib import Path
from kaggle.api.kaggle_api_extended import KaggleApi
DATA_DIR = Path("tmp/Jigsaw_data")
DATA_DIR.mkdir(parents=True, exist_ok=True)
api = KaggleApi()
api.authenticate()
api.competition_download_files(
    "jigsaw-unintended-bias-in-toxicity-classification",
    path=str(DATA_DIR),
    force=True,
    quiet=False
)
zip_path = glob.glob(str(DATA_DIR / "*.zip"))[0]
with zipfile.ZipFile(zip_path) as z:
    z.extractall(DATA_DIR)
train = pd.read_csv(DATA_DIR / "train.csv")
test = pd.read_csv(DATA_DIR / "test.csv")
label_paths = [
    p for p in (
        DATA_DIR / "test_public_expanded.csv",
        DATA_DIR / "test_private_expanded.csv",
        DATA_DIR / "test_labels.csv"
    ) if p.exists()
]
if label_paths:
    lbl = pd.concat([pd.read_csv(p) for p in label_paths])
    test = test.merge(lbl[["id", "toxicity"]], on="id", how="left")
out_csv = DATA_DIR / "all_data.csv"
pd.concat([train, test]).to_csv(out_csv, index=False)
# cleanup
os.remove(zip_path)
for p in DATA_DIR.iterdir():
    if p.resolve() != out_csv.resolve():
        (p.unlink() if p.is_file() else shutil.rmtree(p))
If a value subspace has already been computed, users can skip these parameters and pass the path to the saved subspace via wv_path.
sasa = SASA(
    beta=10,
    wv_path='./steer_wv.pt',
)
Creating the steering pipeline
We create a SteeringPipeline with the SASA control.
sasa_pipeline = SteeringPipeline(
    model_name_or_path=MODEL_NAME,
    controls=[sasa],
    device_map="cuda",
    hf_model_kwargs={"low_cpu_mem_usage": True},
)
Next we steer the pipeline (under the single SASA control). Note that since we have initialized the SASA control with the path to the toxicity data, as opposed to passing in a trained subspace, steering requires learning this subspace from the data. This is a resource-heavy step (GPU required).
sasa_pipeline.steer()
Loading SASA steer (wv)......
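To give some intuition for what learning a linear subspace from labeled data involves, the sketch below fits a separating direction on toy two-dimensional vectors. It is a deliberately simplified stand-in (difference of class means with a midpoint threshold), not the closed-form construction used by SASA; the helper names and all numbers are hypothetical, and in practice the vectors would be the LLM's contextual representations of labeled sentences.

```python
def mean(vectors):
    """Component-wise mean of a list of equal-length vectors."""
    n = len(vectors)
    return [sum(v[i] for v in vectors) / n for i in range(len(vectors[0]))]

def fit_direction(non_toxic, toxic):
    """Return (w, b) such that w.h + b > 0 leans non-toxic and < 0 leans toxic."""
    mu_pos, mu_neg = mean(non_toxic), mean(toxic)
    w = [p - q for p, q in zip(mu_pos, mu_neg)]
    midpoint = [(p + q) / 2 for p, q in zip(mu_pos, mu_neg)]
    b = -sum(wi * mi for wi, mi in zip(w, midpoint))
    return w, b

def margin(h, w, b):
    """Signed score of representation h relative to the decision boundary."""
    return sum(wi * hi for wi, hi in zip(w, h)) + b

# Toy "representations" (illustrative only).
non_toxic = [[1.0, 0.2], [0.8, 0.0]]
toxic = [[-1.0, 0.1], [-0.9, -0.1]]

w, b = fit_direction(non_toxic, toxic)
print(margin([1.0, 0.2], w, b), margin([-1.0, 0.1], w, b))
```

Once such a direction is saved, it can be reused across runs, which is why passing wv_path avoids repeating the preparation step.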
After steering, inference can now be run on the pipeline for a given prompt. We define a prompt that attempts to induce toxic behavior in the model.
PROMPT = "I hate you. You are being like a drama queen who seeks attention all the time. I wish you"
Passing this prompt into the unsteered (base) model demonstrates a toxic continuation.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
inputs = tokenizer(PROMPT, return_tensors="pt")
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask
gen_params = {
    "max_new_tokens": 30,
    "temperature": 0.6,
    "top_p": 0.9,
    "do_sample": True,
    "repetition_penalty": 1.05,
}
baseline_outputs = model.generate(
    **inputs.to(model.device),
    **gen_params
)
# decode only the newly generated tokens (slice off the prompt)
print("\nResponse (baseline):\n")
print(tokenizer.decode(baseline_outputs[0][len(inputs['input_ids'][0]):], skip_special_tokens=True))
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Response (baseline): would just be normal for once. You may feel angry or upset with your partner, but calling them names and telling them how you feel is not the
Compare this with the response of the base model when steered using SASA (via the steering pipeline).
steered_output_ids = sasa_pipeline.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    runtime_kwargs={},
    **gen_params,
)
print("\nResponse (SASA):\n")
print(tokenizer.decode(steered_output_ids[0], skip_special_tokens=True))
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results. `sdpa` attention does not support `output_attentions=True` or `head_mask`. Please set your attention to `eager` if you want any of these features.
Response (SASA): would just calm down and talk to me normally. I don’t want to fight with you. I want us to get along. I want to have
Lastly, note that the beta parameter dictates the strength of the steering, and can thus be adjusted to control the degree of toxicity suppression in the generated response (importantly without having to relearn the subspace).
sasa = SASA(
    beta=0,
    wv_path='./steer_wv.pt',  # we just saved the subspace in the preparation steps above
)
sasa_pipeline = SteeringPipeline(
    model_name_or_path=MODEL_NAME,
    controls=[sasa],
    device_map="cpu",
    hf_model_kwargs={"low_cpu_mem_usage": True},
)
sasa_pipeline.steer()
original_output_ids = sasa_pipeline.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    runtime_kwargs={},
    **gen_params,
)
original_response = tokenizer.decode(original_output_ids[0], skip_special_tokens=True)
# beta=0 corresponds to the original (unsteered) decoding
print(f"Response (beta=0): {original_response}")