Commonsense MCQA¶
This notebook provides a walkthrough of building a benchmark for steering improved performance on the CommonsenseQA problem set. The benchmark compares three steering pipelines: the unsteered behavior (baseline model), few-shot steering, and steering via a DPO-trained LoRA adapter.
For convenience, change the current directory to the notebook's directory if necessary:
import os
os.chdir("./notebooks/benchmarks/commonsense_mcqa/")
from aisteer360.evaluation.use_cases.commonsense_mcqa.use_case import CommonsenseMCQA
from aisteer360.evaluation.metrics.custom.commonsense_mcqa.mcqa_accuracy import MCQAAccuracy
from aisteer360.evaluation.metrics.custom.commonsense_mcqa.mcqa_positional_bias import MCQAPositionalBias
commonsense_mcqa = CommonsenseMCQA(
    evaluation_data="./data/evaluation_qa.jsonl",
    evaluation_metrics=[
        MCQAAccuracy(),
        MCQAPositionalBias(),
    ],
    num_shuffling_runs=20,
    num_samples=500  # optional
)
Two custom metrics have been created for the use case: MCQAAccuracy, which measures the accuracy statistics of each question (across trials), and MCQAPositionalBias, which measures the positional bias (via deviation from the uniform distribution across runs). To facilitate computation of these statistics, the use case accepts a keyword argument num_shuffling_runs dictating how many times each question should be presented to the (steered) model under a randomized ordering of the choices. The num_samples parameter dictates how many entries from evaluation_data are used during benchmarking.
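To make the positional-bias notion concrete, the sketch below shows one way a deviation-from-uniform score could be computed for a single question. The function name and exact normalization are illustrative assumptions, not the library's implementation of MCQAPositionalBias.
import numpy as np

def positional_bias_sketch(chosen_positions, num_choices=5):
    # empirical distribution over answer positions chosen across shuffling runs
    counts = np.bincount(chosen_positions, minlength=num_choices)
    empirical = counts / counts.sum()
    # mean absolute deviation from the uniform distribution
    uniform = np.full(num_choices, 1.0 / num_choices)
    return float(np.abs(empirical - uniform).mean())

# a model that always picks the first option, regardless of shuffling, is maximally biased
print(positional_bias_sketch([0] * 20))             # 0.32
print(positional_bias_sketch([0, 1, 2, 3, 4] * 4))  # 0.0 (no positional preference)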
Defining the controls¶
The benchmark aims to compare two controls using common steering data.
import json
steering_data_path = "data/steer_qa.jsonl"
with open(steering_data_path, "r") as f:
    steering_data = [json.loads(line) for line in f]
steering_data[0]
{'id': '01beaf20-82aa-40b0-8b08-ee08b94e6666', 'question': 'The spirit ascended to the after life, so what was it leaving?', 'answer_chosen': 'human being', 'answer_rejected': 'cemetary'}
The steering data consists of triples (question, answer_chosen, answer_rejected) extracted from the CommonsenseQA dataset, where answer_chosen is the ground-truth answer and answer_rejected is a randomly selected incorrect answer. Both controls (FewShot and DPO with LoRA) are based on the same steering data.
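The steering file was prepared ahead of time, but for reference, a minimal sketch of how such a triple could be derived from a raw CommonsenseQA record is shown below. It assumes the HuggingFace commonsense_qa schema (question, answerKey, and parallel choices.label/choices.text lists); the actual preprocessing may differ.
import random

def to_preference_triple(record):
    # map the answer key (e.g. "A") to its answer text
    labels = record["choices"]["label"]
    texts = record["choices"]["text"]
    chosen = texts[labels.index(record["answerKey"])]
    # sample one of the remaining (incorrect) answers as the rejected response
    rejected = random.choice([t for t in texts if t != chosen])
    return {
        "question": record["question"],
        "answer_chosen": chosen,
        "answer_rejected": rejected,
    }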
Defining the few-shot control¶
The FewShot control requires specification of example pools. As shown below, each positive example is given by the pair (question, answer_chosen) whereas each negative example is given by the pair (question, answer_rejected).
positive_pool = []
negative_pool = []
for row in steering_data:
    positive_pool.append({
        "question": row["question"],
        "answer": row["answer_chosen"]
    })
    negative_pool.append({
        "question": row["question"],
        "answer": row["answer_rejected"]
    })
These pools are then passed to the FewShot class upon instantiation, along with the name of the example selector (how examples are drawn from the pools; defaults to random) and the counts for how many positive and negative examples the selector should draw from each pool.
from aisteer360.algorithms.input_control.few_shot.control import FewShot
few_shot = FewShot(
    selector_name="random",
    positive_example_pool=positive_pool,
    negative_example_pool=negative_pool,
    k_positive=25,
    k_negative=25
)
Defining the DPO (with LoRA) control¶
from datasets import Dataset
from peft import PeftType
from aisteer360.algorithms.structural_control.wrappers.trl.dpotrainer.control import DPO
train_examples = []
for row in steering_data:
    train_examples.append({
        "prompt": row["question"],
        "chosen": row["answer_chosen"],
        "rejected": row["answer_rejected"]
    })

train_ds = Dataset.from_list(train_examples)
# instantiate dpo control
dpo_lora = DPO(
    train_dataset=train_ds,
    use_peft=True,
    peft_type=PeftType.LORA,
    **{
        "per_device_train_batch_size": 4,
        "num_train_epochs": 2,
        "learning_rate": 2e-5,
        "output_dir": "trl_models/Qwen2.5-0.5B-DPO-Lora-Steer",
        "logging_steps": 100,
        "save_strategy": "no",
    },
)
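As an optional sanity check, the preference dataset built above can be inspected to confirm it exposes the prompt/chosen/rejected columns that TRL's DPOTrainer expects:
print(train_ds)     # Dataset with features ['prompt', 'chosen', 'rejected']
print(train_ds[0])  # first preference triple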
Instantiating (and running) the benchmark¶
Given the controls, the benchmark can now be run on any set of control pipelines, where each pipeline is a sequence of controls. In the following benchmark, we compare the unsteered baseline behavior (no control) with few-shot steering and DPO (with LoRA).
import transformers
from aisteer360.evaluation.benchmark import Benchmark
transformers.logging.set_verbosity_error()
benchmark = Benchmark(
    use_case=commonsense_mcqa,
    base_model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct",
    steering_pipelines={
        "baseline": [],  # no steering
        "few_shot": [few_shot],
        "dpo_lora": [dpo_lora],
    },
    gen_kwargs={
        "max_new_tokens": 300,
        "do_sample": True,
        "temperature": 0.7,
    },
    device_map="auto"
)
# run and plot/export
profiles = benchmark.run()
Running pipeline: baseline...
done. Running pipeline: few_shot... done. Running pipeline: dpo_lora...
Extracting prompt in train dataset: 100%|██████████| 4871/4871 [00:00<00:00, 17093.58 examples/s]
Applying chat template to train dataset: 100%|██████████| 4871/4871 [00:00<00:00, 19933.53 examples/s]
Tokenizing train dataset: 100%|██████████| 4871/4871 [00:01<00:00, 3318.00 examples/s]
{'loss': 0.6844, 'grad_norm': 1.3753544092178345, 'learning_rate': 1.9187192118226602e-05, 'rewards/chosen': 0.05939213186502457, 'rewards/rejected': 0.04117845371365547, 'rewards/accuracies': 0.6549999713897705, 'rewards/margins': 0.018213678151369095, 'logps/chosen': -42.31070327758789, 'logps/rejected': -45.085453033447266, 'logits/chosen': 0.8083691596984863, 'logits/rejected': 0.904447078704834, 'epoch': 0.08210180623973727}
...
{'loss': 0.3171, 'grad_norm': 11.305817604064941, 'learning_rate': 3.0377668308702795e-07, 'rewards/chosen': 1.6640433073043823, 'rewards/rejected': -0.42858800292015076, 'rewards/accuracies': 0.8700000047683716, 'rewards/margins': 2.0926313400268555, 'logps/chosen': -25.379392623901367, 'logps/rejected': -49.155235290527344, 'logits/chosen': -2.116882562637329, 'logits/rejected': -2.2833025455474854, 'epoch': 1.9704433497536946}
{'train_runtime': 473.6727, 'train_samples_per_second': 20.567, 'train_steps_per_second': 5.143, 'train_loss': 0.3921107689931083, 'epoch': 2.0}
done.
benchmark.export(profiles, save_dir="./profiles/")
Inspecting the profiles¶
Each control pipeline in the benchmark yields an evaluation profile. Each evaluation profile contains metric values as computed by the metrics passed to the use case, in this case MCQAAccuracy and MCQAPositionalBias.
import json
print(json.dumps(profiles['baseline']['evaluations'], indent=2))
{ "MCQAAccuracy": { "trial_mean": 0.5058, "trial_std": 0.4999913590612424, "question_mean": 0.446, "question_std": 0.4975732692947166 }, "MCQAPositionalBias": { "mean": 0.10532592592592592, "std": 0.031184296794208543 } }
print(json.dumps(profiles['few_shot']['evaluations'], indent=2))
{ "MCQAAccuracy": { "trial_mean": 0.7007, "trial_std": 0.4579743268441779, "question_mean": 0.716, "question_std": 0.45138841700470494 }, "MCQAPositionalBias": { "mean": 0.014360000000000001, "std": 0.03518742260266924 } }
print(json.dumps(profiles['dpo_lora']['evaluations'], indent=2))
{ "MCQAAccuracy": { "trial_mean": 0.5352, "trial_std": 0.4987843608051961, "question_mean": 0.482, "question_std": 0.5001763216161011 }, "MCQAPositionalBias": { "mean": 0.07606851211072664, "std": 0.006019957412877974 } }
We can see that FewShot (using 25 positive and 25 negative examples) yields the largest improvement over the baseline, on both accuracy and positional bias. The DPO (with LoRA) control yields only a marginal improvement over the baseline, likely because of the small (~5k example) steering dataset.
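For a quick visual comparison, the same numbers can be read straight out of the profiles and plotted. A minimal matplotlib sketch (not part of the benchmark API) might look like:
import matplotlib.pyplot as plt

pipelines = ["baseline", "few_shot", "dpo_lora"]
accuracy = [profiles[p]["evaluations"]["MCQAAccuracy"]["trial_mean"] for p in pipelines]
bias = [profiles[p]["evaluations"]["MCQAPositionalBias"]["mean"] for p in pipelines]

fig, (ax_acc, ax_bias) = plt.subplots(1, 2, figsize=(8, 3))
ax_acc.bar(pipelines, accuracy)
ax_acc.set_title("Trial-level accuracy")
ax_bias.bar(pipelines, bias)
ax_bias.set_title("Positional bias (lower is better)")
fig.tight_layout()
plt.show()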