Few-shot steering of model behavior
Few-shot learning is a simple, yet surprisingly effective, input control method for steering a language model's behavior by including examples of desirable behavior in the prompt (Brown et al., 2020; Zhao et al., 2021). This notebook illustrates how few-shot learning is implemented in the toolkit (via the FewShot class). The toolkit supports few-shot steering in two modes (each sketched briefly below):
- Runtime Examples Mode: passes specific examples directly at generation time via runtime_kwargs.
- Pool Sampling Mode: defines (positive and negative) example pools at initialization, along with a sampler, and samples a specified number of examples from the pools at runtime.
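As a minimal sketch contrasting the two modes (both are demonstrated in full later in this notebook, with placeholder examples used here):

from aisteer360.algorithms.input_control.few_shot.control import FewShot

# Runtime Examples Mode: instantiate with no arguments; examples are supplied
# later through runtime_kwargs at generation time, e.g.
#   pipeline.generate(..., runtime_kwargs={"positive_examples": [...], "negative_examples": [...]})
few_shot_runtime = FewShot()

# Pool Sampling Mode: define example pools and a sampler up front; examples are
# drawn from the pools at generation time (runtime_kwargs stays empty).
few_shot_pool = FewShot(
    selector_name="random",
    positive_example_pool=[{"question": "...", "answer": "..."}],
    negative_example_pool=[{"question": "...", "answer": "..."}],
    k_positive=1,
    k_negative=1,
)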
In this demo, we'll show how FewShot can be used to steer a model to respond more concisely.
Method parameters
| parameter | type | description |
|---|---|---|
| selector_name | str \| None | Name of the example selector to use. If None, uses random selection. Must be "random" if provided. |
| template | str \| None | Custom template for the system prompt. Use {example_blocks} and {directive} as placeholders. |
| directive | str \| None | Directive statement at the beginning of the system prompt. |
| positive_example_pool | list[dict] \| None | Pool of positive examples to sample from at runtime. |
| negative_example_pool | list[dict] \| None | Pool of negative examples to sample from at runtime. |
| k_positive | int \| None | Number of positive examples to sample from the pool. Required if positive_example_pool is provided. |
| k_negative | int \| None | Number of negative examples to sample from the pool. Required if negative_example_pool is provided. |
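The template and directive parameters are not exercised in the demo below. As a rough sketch of how they might be set, based only on the placeholder names described in the table above (this call is not run in this notebook):

from aisteer360.algorithms.input_control.few_shot.control import FewShot

# Sketch only: a custom directive and system-prompt template, where {directive}
# and {example_blocks} are the placeholders described in the parameter table.
few_shot_custom = FewShot(
    directive="Answer with the shortest correct response possible.",
    template="{directive}\n\nFollow the style of these examples:\n{example_blocks}",
)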
The following authentication steps may be necessary to access any gated models (even after being granted access on Hugging Face). Uncomment the following if you need to log in to the Hugging Face Hub using a token stored in your .env file:
# !pip install python-dotenv
# from dotenv import load_dotenv
# import os
# load_dotenv()
# token = os.getenv("HUGGINGFACE_TOKEN")
# from huggingface_hub import login
# login(token=token)
Example: Steering for conciseness
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
from aisteer360.algorithms.input_control.few_shot.control import FewShot
from aisteer360.algorithms.core.steering_pipeline import SteeringPipeline
warnings.filterwarnings('ignore', category=UserWarning)
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
The following example illustrates how to steer a model's behavior to respond more concisely. We've defined some positive examples (those that represent the desired behavior) and negative examples (those that represent the undesired behavior) below.
# positive examples (concise answers)
positive_examples = [
{"question": "What's the capital of France?", "answer": "Paris"},
{"question": "How many miles is it to the moon?", "answer": "238,855"},
{"question": "What's the boiling point of water?", "answer": "100°C"},
{"question": "How many days in a leap year?", "answer": "366"},
{"question": "What's the speed of light?", "answer": "299,792,458 m/s"},
{"question": "What's 15% of 200?", "answer": "30"},
{"question": "How many continents are there?", "answer": "7"},
{"question": "What's the atomic number of gold?", "answer": "79"}
]
# negative examples (verbose answers)
negative_examples = [
{"question": "What's the capital of France?", "answer": "The capital of France is Paris."},
{"question": "How many miles is it to the moon?", "answer": "The Moon is an average of 238,855 miles (384,400 kilometers) away from Earth."},
{"question": "What's the boiling point of water?", "answer": "Water boils at 100 degrees Celsius or 212 degrees Fahrenheit at sea level."},
{"question": "How many days in a leap year?", "answer": "A leap year contains 366 days, which is one day more than a regular year."},
{"question": "What's the speed of light?", "answer": "The speed of light in vacuum is approximately 299,792,458 meters per second."},
{"question": "What's 15% of 200?", "answer": "Fifteen percent of 200 can be calculated by multiplying 200 by 0.15, which gives 30."},
{"question": "How many continents are there?", "answer": "There are seven continents on Earth: Africa, Antarctica, Asia, Europe, North America, Oceania, and South America."},
{"question": "What's the atomic number of gold?", "answer": "Gold has the atomic number 79 on the periodic table of elements."}
]
To analyze the model's conciseness, we define a prompt that asks a question admitting a concise answer:
PROMPT = "How many ounces are in a pint?"
Baseline model behavior
The baseline model behavior is generated by simply passing the templated and tokenized prompt into the base model's generate() method.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
chat = tokenizer.apply_chat_template(
[{"role": "user", "content": PROMPT}],
tokenize=False,
add_generation_prompt=True
)
inputs = tokenizer(chat, return_tensors="pt")
baseline_outputs = model.generate(
**inputs.to(model.device),
do_sample=False,
max_new_tokens=150
)
print("\nResponse (baseline):\n")
print(tokenizer.decode(baseline_outputs[0][len(inputs['input_ids'][0]):], skip_special_tokens=True))
Response (baseline): There are 16 fluid ounces in a pint.
Steering via runtime_kwargs
In this mode, examples are passed in at generate() time, which allows control over which examples are used for each generation. Note that the instantiation of FewShot in this mode does not require any arguments.
few_shot_runtime = FewShot()
Given the control, we define the steering pipeline (via SteeringPipeline) and steer it (which performs some lightweight initialization of FewShot):
few_shot_runtime_pipeline = SteeringPipeline(
model_name_or_path=MODEL_NAME,
controls=[few_shot_runtime],
device_map="auto"
)
few_shot_runtime_pipeline.steer()
Inference on the steered model can then be run as usual to generate the steered output. Note that the specific examples listed above are passed directly into generate via the runtime_kwargs argument.
input_ids = few_shot_runtime_pipeline.tokenizer.encode(PROMPT, return_tensors="pt")
output = few_shot_runtime_pipeline.generate(
input_ids=input_ids,
runtime_kwargs={
"positive_examples": positive_examples,
"negative_examples": negative_examples
},
max_new_tokens=50,
temperature=0.7,
return_full_sequence=False
)
print("\nResponse (FewShot w/ fixed examples):\n")
print(few_shot_runtime_pipeline.tokenizer.decode(output[0], skip_special_tokens=True))
Response (FewShot w/ fixed examples): 16
Steering via example pools and a selector
In some cases, passing in specific examples may not be necessary or even desirable, e.g., if you have a large pool of examples and are not sure which ones yield the desired behavior. To accommodate this, the user can instead specify example pools and a selector that determines how examples are sampled from the pools.
First, clear the memory from the previous mode:
import gc, torch
gc.collect()
torch.cuda.empty_cache()
positive_example_pool = [
{"question": "What's the capital of France?", "answer": "Paris"},
{"question": "How many miles is it to the moon?", "answer": "238,855"},
{"question": "What's the boiling point of water?", "answer": "100°C"},
{"question": "How many days in a leap year?", "answer": "366"},
{"question": "What's the speed of light?", "answer": "299,792,458 m/s"},
{"question": "What's 15% of 200?", "answer": "30"},
{"question": "How many continents are there?", "answer": "7"},
{"question": "What's the atomic number of gold?", "answer": "79"},
{"question": "What's the capital of Japan?", "answer": "Tokyo"},
{"question": "How many sides does a hexagon have?", "answer": "6"},
{"question": "What's 9 * 7?", "answer": "63"},
{"question": "What's the freezing point of water?", "answer": "0°C"},
{"question": "How many planets are in the Solar System?", "answer": "8"},
{"question": "What's the chemical symbol for sodium?", "answer": "Na"},
{"question": "What's the largest ocean on Earth?", "answer": "Pacific Ocean"},
{"question": "How many degrees are in a right angle?", "answer": "90"},
{"question": "What's the square root of 144?", "answer": "12"},
{"question": "Who's the author of '1984'?", "answer": "George Orwell"},
{"question": "What's the currency of the United Kingdom?", "answer": "Pound sterling"},
{"question": "What gas do plants primarily absorb during photosynthesis?", "answer": "Carbon dioxide"},
{"question": "How many letters are in the English alphabet?", "answer": "26"},
{"question": "What's the largest planet in our solar system?", "answer": "Jupiter"},
{"question": "What's the tallest mountain in the world?", "answer": "Mount Everest"},
{"question": "What's the primary language spoken in Brazil?", "answer": "Portuguese"},
{"question": "What is the Roman numeral for 50?", "answer": "L"},
{"question": "How many hours are in two days?", "answer": "48"},
{"question": "What's 3/4 as a percentage?", "answer": "75%"},
{"question": "What's the chemical formula for table salt?", "answer": "NaCl"},
{"question": "How many bits are in a byte?", "answer": "8"},
{"question": "What's the smallest prime number?", "answer": "2"},
{"question": "What's Pi rounded to two decimal places?", "answer": "3.14"},
{"question": "How many bones are in the adult human body?", "answer": "206"}
]
negative_example_pool = [
{"question": "What's the capital of France?", "answer": "The capital of France is Paris."},
{"question": "How many miles is it to the moon?", "answer": "The Moon is an average of 238,855 miles (384,400 kilometers) away from Earth."},
{"question": "What's the boiling point of water?", "answer": "Water boils at 100 degrees Celsius or 212 degrees Fahrenheit at sea level."},
{"question": "How many days in a leap year?", "answer": "A leap year contains 366 days, which is one day more than a regular year."},
{"question": "What's the speed of light?", "answer": "The speed of light in vacuum is approximately 299,792,458 meters per second."},
{"question": "What's 15% of 200?", "answer": "Fifteen percent of 200 can be calculated by multiplying 200 by 0.15, which gives 30."},
{"question": "How many continents are there?", "answer": "There are seven continents on Earth: Africa, Antarctica, Asia, Europe, North America, Oceania, and South America."},
{"question": "What's the atomic number of gold?", "answer": "Gold has the atomic number 79 on the periodic table of elements."},
{"question": "What's the capital of Japan?", "answer": "The capital of Japan is Tokyo."},
{"question": "How many sides does a hexagon have?", "answer": "A hexagon has six sides."},
{"question": "What's 9 * 7?", "answer": "Nine times seven equals 63."},
{"question": "What's the freezing point of water?", "answer": "Water freezes at 0 degrees Celsius, which is 32 degrees Fahrenheit."},
{"question": "How many planets are in the Solar System?", "answer": "There are eight planets in the Solar System."},
{"question": "What's the chemical symbol for sodium?", "answer": "Sodium is represented by the chemical symbol Na."},
{"question": "What's the largest ocean on Earth?", "answer": "The largest ocean on Earth is the Pacific Ocean."},
{"question": "How many degrees are in a right angle?", "answer": "A right angle measures 90 degrees."},
{"question": "What's the square root of 144?", "answer": "The square root of 144 is 12."},
{"question": "Who's the author of '1984'?", "answer": "The novel '1984' was written by George Orwell."},
{"question": "What's the currency of the United Kingdom?", "answer": "The currency used in the United Kingdom is the pound sterling."},
{"question": "What gas do plants primarily absorb during photosynthesis?", "answer": "Plants primarily absorb carbon dioxide during photosynthesis."},
{"question": "How many letters are in the English alphabet?", "answer": "The English alphabet contains 26 letters."},
{"question": "What's the largest planet in our solar system?", "answer": "The largest planet in our solar system is Jupiter."},
{"question": "What's the tallest mountain in the world?", "answer": "The tallest mountain in the world is Mount Everest."},
{"question": "What's the primary language spoken in Brazil?", "answer": "The primary language spoken in Brazil is Portuguese."},
{"question": "What is the Roman numeral for 50?", "answer": "The Roman numeral for 50 is L."},
{"question": "How many hours are in two days?", "answer": "There are 48 hours in two days."},
{"question": "What's 3/4 as a percentage?", "answer": "Three-quarters expressed as a percentage is 75%."},
{"question": "What's the chemical formula for table salt?", "answer": "The chemical formula for table salt is NaCl."},
{"question": "How many bits are in a byte?", "answer": "There are eight bits in a byte."},
{"question": "What's the smallest prime number?", "answer": "The smallest prime number is 2."},
{"question": "What's Pi rounded to two decimal places?", "answer": "Pi rounded to two decimal places is 3.14."},
{"question": "How many bones are in the adult human body?", "answer": "An adult human has 206 bones."}
]
As before, we define the steering pipeline and steer it. Under this mode, however, the example pools, the name of the selector, and the number of positive and negative examples to sample (using the specified selection strategy) are all passed in when the control is initialized.
few_shot_pool = FewShot(
selector_name="random",
positive_example_pool=positive_example_pool,
negative_example_pool=negative_example_pool,
k_positive=12,
k_negative=12
)
few_shot_pool_pipeline = SteeringPipeline(
model_name_or_path=MODEL_NAME,
controls=[few_shot_pool],
device_map="auto"
)
few_shot_pool_pipeline.steer()
Inference on the pipeline proceeds similarly, but now without any runtime_kwargs (as the specific examples are sampled within the control via the selector).
input_ids = few_shot_pool_pipeline.tokenizer.encode(PROMPT, return_tensors="pt")
output = few_shot_pool_pipeline.generate(
input_ids=input_ids,
runtime_kwargs={},
max_new_tokens=50,
temperature=0.7,
return_full_sequence=False
)
print("\nResponse (FewShot w/ sampled examples):\n")
print(few_shot_pool_pipeline.tokenizer.decode(output[0], skip_special_tokens=True))
Response (FewShot w/ sampled examples): 16
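Both steered responses are noticeably shorter than the baseline answer. To quantify the difference, one could simply compare the token counts of the decoded responses. A minimal sketch, reusing the tokenizer loaded earlier; baseline_text and steered_text are hypothetical variables holding the response strings printed above:

# Sketch only: compare response lengths in tokens.
baseline_text = "There are 16 fluid ounces in a pint."
steered_text = "16"

for label, text in [("baseline", baseline_text), ("steered", steered_text)]:
    n_tokens = len(tokenizer.encode(text, add_special_tokens=False))
    print(f"{label}: {n_tokens} tokens")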