API reference

aisteer360

AI Steerability 360 toolkit.

The AI Steerability 360 toolkit (AISteer360) enables systematic control over language model behavior through four model control surfaces: input, structural, state, and output. Methods can be composed into composite model operations (via steering pipelines). Benchmarks enable comparison of steering pipelines on common use cases.

algorithms

Contains all steering logic and control implementations across input, structural, state, and output control methods.

core

Core functionality for steering pipelines, steering utilities, and argument parsing.

base_args

Base argument validation for steering method configuration.

T = TypeVar('T', bound='BaseArgs') module-attribute
steering_pipeline

Core steering pipeline for composing and applying multiple LLM control methods.

SteeringPipeline dataclass

Main steering pipeline for applying various control methods to Hugging Face causal language models.

Enables application of structural, state, input, and output controls in a coordinated manner. Controls are applied in a fixed bottom-up order during steering, then used together during generation.

Workflow:

  1. Instantiate with a base model checkpoint and/or control objects
  2. Call steer() once to apply all controls in order (structural → state → input → output)
  3. Use generate() or generate_text() for inference with steering applied
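
A minimal usage sketch of this workflow is shown below (the checkpoint name and generation settings are placeholders; any control from the algorithms package can be passed via controls):

from aisteer360.algorithms.core.steering_pipeline import SteeringPipeline

pipeline = SteeringPipeline(
    model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder checkpoint
    controls=[],  # up to one control per category; omitted categories become no-ops
)
pipeline.steer()  # applies controls bottom-up: structural -> state -> input -> output

prompt_ids = pipeline.tokenizer.encode("How is the weather today?")
print(pipeline.generate_text(input_ids=prompt_ids, max_new_tokens=64))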

Parameters:

  • model_name_or_path (str | Path, default: None): HuggingFace model hub name or local directory. Required when lazy_init=False. Ignored when lazy_init=True and the structural control returns a model.
  • controls (Sequence[StructuralControl | StateControl | InputControl | OutputControl], default: ()): Controls for the steering pipeline, max one control per category. Omitted categories fall back to no-op controls (see control base classes).
  • tokenizer_name_or_path (str, default: None): Tokenizer location. Defaults to model_name_or_path.
  • device_map (str | dict[str, int], default: 'auto'): Device map (passed to transformers.AutoModelForCausalLM.from_pretrained). Cannot be used together with the device parameter.
  • device (torch.device | str, default: None): Device (passed to the model's .to() method). When specified, device_map must remain at its default value of "auto".
  • hf_model_kwargs (dict, default: {}): Extra keyword arguments passed to transformers.AutoModelForCausalLM.from_pretrained.
  • lazy_init (bool, default: False): If True, defers loading the base model until steer() time. Useful when a StructuralControl will itself load or create the final weights (e.g., MergeKit). When False, the model is loaded during SteeringPipeline construction.

Raises:

  • RuntimeError: If generate() is called before steer()
  • ValueError: If multiple controls are provided for the same category or required arguments are missing

Note:

  • Maximum one control per category; omitted categories use no-op defaults
  • Controls with a tokenizer attribute will have it auto-injected if not already set
Source code in aisteer360/algorithms/core/steering_pipeline.py
@dataclass(slots=True)
class SteeringPipeline:
    """Main steering pipeline for applying various control methods to Hugging Face causal language models.

    Enables application of structural, state, input, and output controls in a coordinated manner.
    Controls are applied in a fixed bottom-up order during steering, then used together during generation.

    Workflow:

    1. Instantiate with a base model checkpoint and/or control objects
    2. Call `steer()` once to apply all controls in order (structural → state → input → output)
    3. Use `generate()` or `generate_text()` for inference with steering applied

    Args:
        model_name_or_path (str or pathlib.Path, optional): HuggingFace model hub name or local directory.
            Required when `lazy_init=False`. Ignored when `lazy_init=True` and the structural
            control returns a model.
        controls (Sequence[StructuralControl | StateControl | InputControl | OutputControl], optional):
            Controls for the steering pipeline, max one control per category. Omitted categories
            fall back to no-op controls (see control base classes).
        tokenizer_name_or_path (str, optional): Tokenizer location. Defaults to `model_name_or_path`.
        device_map (str or dict[str, int], optional): Device map (passed to
            `transformers.AutoModelForCausalLM.from_pretrained`). Defaults to `"auto"`.
            Cannot be used together with `device` parameter.
        device (torch.device, str, optional): Device (passed to model's `.to()` method).
            When specified, `device_map` must remain at its default value of `"auto"`.
        hf_model_kwargs (dict, optional): Extra keyword arguments passed to
            `transformers.AutoModelForCausalLM.from_pretrained`.
        lazy_init (bool, optional): If `True`, defers loading the base model until `steer()` time.
            Useful when a `StructuralControl` will itself load or create the final weights
            (e.g., MergeKit). When `False`, the model is loaded during `SteeringPipeline`
            construction. Defaults to `False`.

    Raises:
        RuntimeError: If `generate()` is called before `steer()`
        ValueError: If multiple controls provided for same category or required arguments missing

    Note:

    - Maximum one control per category; omitted categories use no-op defaults
    - Controls with a `tokenizer` attribute will have it auto-injected if not already set
    """

    # construction args
    model_name_or_path: str | Path | None = None
    controls: Sequence[StructuralControl | StateControl | InputControl | OutputControl] = ()
    tokenizer_name_or_path: str | None = None
    device_map: str | dict[str, int] | int | torch.device | None = "auto"
    device: torch.device | str | None = None
    hf_model_kwargs: dict = field(default_factory=dict)
    lazy_init: bool = False

    # lazy‑filled fields
    model: PreTrainedModel | None = field(init=False, default=None)
    tokenizer: AutoTokenizer | None = field(init=False, default=None)

    structural_control: StructuralControl = field(init=False)
    state_control: StateControl = field(init=False)
    input_control: InputControl = field(init=False)
    output_control: OutputControl = field(init=False)

    _is_steered: bool = field(default=False, init=False, repr=False)

    def __post_init__(self) -> None:

        # sort/validate the supplied steering methods
        controls_merged = merge_controls(self.controls)
        self.structural_control = controls_merged["structural_control"]
        self.state_control = controls_merged["state_control"]
        self.input_control = controls_merged["input_control"]
        self.output_control = controls_merged["output_control"]

        # load HF artifacts
        if not self.lazy_init:
            if self.model_name_or_path is None:
                raise ValueError("`model_name_or_path` must be provided when lazy_init=False")

            if self.device is not None and self.device_map != "auto":
                raise ValueError("Cannot specify both `device` and `device_map`.")

            if self.device is not None:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name_or_path,
                    **self.hf_model_kwargs,
                )
                self.model = self.model.to(self.device)
                self.device = self.model.device
            else:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name_or_path,
                    device_map=self.device_map,
                    **self.hf_model_kwargs,
                )
                self.device = self.model.device

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.tokenizer_name_or_path or self.model_name_or_path,
                trust_remote_code=True,
            )
            self.tokenizer = ensure_pad_token(self.tokenizer)
        else:
            if isinstance(self.tokenizer_name_or_path, str | Path):
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.tokenizer_name_or_path,
                    trust_remote_code=True
                )
                self.tokenizer = ensure_pad_token(self.tokenizer)

        # late‑inject tokenizer into controls that accept it
        controls_iter = (self.structural_control, self.state_control, self.input_control, self.output_control)
        for control in controls_iter:
            if hasattr(control, "tokenizer") and getattr(control, "tokenizer") is None:
                setattr(control, "tokenizer", self.tokenizer)

    def steer(self, **steer_kwargs) -> None:
        """Apply all steering controls to the model in place.

        Executes each control's steer() method in a fixed bottom-up order: structural -> state -> input -> output.
        This ensures that higher-level controls always see the final configured model from lower levels.

        If any control's steer() method returns a PreTrainedModel instance, it replaces the current model for subsequent
        controls.

        Args:
            **steer_kwargs: Keyword arguments passed to all control steer() methods

        Raises:
            RuntimeError: If no model is available after steering (repeated calls to steer() are no-ops)
        """
        if self._is_steered:
            return

        # steer each control (bottom-up order)
        for control in (self.structural_control, self.state_control, self.input_control, self.output_control):
            steer_fn = getattr(control, "steer", None)
            if callable(steer_fn):
                maybe_new_model = steer_fn(self.model, tokenizer=self.tokenizer, **steer_kwargs)
                if isinstance(maybe_new_model, PreTrainedModel):
                    self.model = maybe_new_model

        # safety checks
        if self.model is None:
            raise RuntimeError(
                "No model is available after steering. Either provide a base model (lazy_init=False) or ensure a "
                "`StructuralControl` returns one."
            )

        if self.tokenizer is None:
            repo = getattr(self.model, "name_or_path", None)
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    repo or Path(getattr(self.structural_control.args, "out_path", "")),
                    trust_remote_code=True,
                )
                self.tokenizer = ensure_pad_token(self.tokenizer)

            except Exception as exception:
                raise RuntimeError("Failed to resolve tokenizer post‑steer.") from exception

        # for control in (self.input_control, self.structural_control, self.state_control, self.output_control):
        #     if hasattr(control, "tokenizer") and getattr(control, "tokenizer") is None:
        #         setattr(control, "tokenizer", self.tokenizer)

        # return steered steerer
        self._is_steered = True

    def generate(
            self,
            input_ids: list[int] | torch.LongTensor,
            attention_mask: torch.Tensor | None = None,
            runtime_kwargs: dict | None = None,
            **gen_kwargs
    ) -> torch.Tensor:
        """Generate text with all steering controls applied.

        Applies controls in sequence during generation:

        1. Input control adapts the prompt
        2. State control registers hooks for state control (e.g., activation steering)
        3. Output control handles the actual generation

        Args:
            input_ids: Token IDs as list or tensor (shape: [seq_len] or [batch, seq_len])
            attention_mask: Optional attention mask matching input_ids shape
            runtime_kwargs: Per-generation parameters for controls (e.g., {"substrings": [...]})
            **gen_kwargs: Generation parameters passed to `model.generate()`

        Returns:
            Generated token IDs (shape: [batch, generated_len])

        Raises:
            RuntimeError: If steer() has not yet been called
        """
        if not self._is_steered:
            raise RuntimeError("Must call `.steer()` before `.generate()`.")

        runtime_kwargs = runtime_kwargs or {}

        return_full_sequence = bool(gen_kwargs.pop("return_full_sequence", False))

        # input control
        steered_input_ids = self.input_control.get_prompt_adapter()(input_ids, runtime_kwargs)
        if isinstance(steered_input_ids, list):
            steered_input_ids = torch.tensor(steered_input_ids, dtype=torch.long)
        if steered_input_ids.ndim == 1:
            steered_input_ids = steered_input_ids.unsqueeze(0)
        steered_input_ids = steered_input_ids.to(self.model.device)

        # attention_mask (reshape and move to device)
        if attention_mask is not None:
            if isinstance(attention_mask, list):
                attention_mask = torch.as_tensor(attention_mask, dtype=torch.long)
            if attention_mask.ndim == 1:
                attention_mask = attention_mask.unsqueeze(0)
            attention_mask = attention_mask.to(
                dtype=steered_input_ids.dtype, device=steered_input_ids.device
            )

        # state control
        hooks = self.state_control.get_hooks(steered_input_ids, runtime_kwargs, **gen_kwargs)
        self.state_control.set_hooks(hooks)
        self.state_control._model_ref = self.model

        # output control
        self.state_control.reset()
        with self.state_control:  # hooks live only for duration of decoding
            output_ids = self.output_control.generate(
                input_ids=steered_input_ids,
                attention_mask=attention_mask,
                runtime_kwargs=runtime_kwargs,
                model=self.model,
                **gen_kwargs
            )

        if not return_full_sequence:
            output_ids = output_ids[:, steered_input_ids.size(1):]

        return output_ids

    def generate_text(self, *args, **kwargs) -> str | list[str]:
        """Generate text and decode to string(s).

        Convenience wrapper that calls generate() and decodes the output tokens.

        Args:
            *args: Arguments passed to generate()
            **kwargs: Keyword arguments passed to generate()

        Returns:
            Decoded text string (single prompt) or list of strings (batch)
        """
        ids = self.generate(*args, **kwargs)
        if ids.ndim == 1:
            return self.tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        return self.tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
controls = () class-attribute instance-attribute
device = None class-attribute instance-attribute
device_map = 'auto' class-attribute instance-attribute
hf_model_kwargs = field(default_factory=dict) class-attribute instance-attribute
input_control = field(init=False) class-attribute instance-attribute
lazy_init = False class-attribute instance-attribute
model = field(init=False, default=None) class-attribute instance-attribute
model_name_or_path = None class-attribute instance-attribute
output_control = field(init=False) class-attribute instance-attribute
state_control = field(init=False) class-attribute instance-attribute
structural_control = field(init=False) class-attribute instance-attribute
tokenizer = field(init=False, default=None) class-attribute instance-attribute
tokenizer_name_or_path = None class-attribute instance-attribute
generate(input_ids, attention_mask=None, runtime_kwargs=None, **gen_kwargs)

Generate text with all steering controls applied.

Applies controls in sequence during generation:

  1. Input control adapts the prompt
  2. State control registers hooks for state control (e.g., activation steering)
  3. Output control handles the actual generation

Parameters:

  • input_ids (list[int] | LongTensor, required): Token IDs as list or tensor (shape: [seq_len] or [batch, seq_len])
  • attention_mask (Tensor | None, default: None): Optional attention mask matching input_ids shape
  • runtime_kwargs (dict | None, default: None): Per-generation parameters for controls (e.g., {"substrings": [...]})
  • **gen_kwargs: Generation parameters passed to model.generate()

Returns:

  • Tensor: Generated token IDs (shape: [batch, generated_len])

Raises:

  • RuntimeError: If steer() has not yet been called
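
A brief sketch of a direct generate() call (the prompt and generation settings are illustrative; runtime_kwargs is forwarded to whichever controls consume it):

# assumes `pipeline` is a SteeringPipeline on which steer() has already been called
input_ids = pipeline.tokenizer("Summarize the plot of Hamlet.", return_tensors="pt").input_ids

output_ids = pipeline.generate(
    input_ids=input_ids,            # [1, seq_len] tensor; a flat list[int] also works
    runtime_kwargs={},              # per-generation parameters for controls, if any
    max_new_tokens=128,             # forwarded to model.generate()
    return_full_sequence=False,     # default: return only the newly generated tokens
)
print(pipeline.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])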

Source code in aisteer360/algorithms/core/steering_pipeline.py
def generate(
        self,
        input_ids: list[int] | torch.LongTensor,
        attention_mask: torch.Tensor | None = None,
        runtime_kwargs: dict | None = None,
        **gen_kwargs
) -> torch.Tensor:
    """Generate text with all steering controls applied.

    Applies controls in sequence during generation:

    1. Input control adapts the prompt
    2. State control registers hooks for state control (e.g., activation steering)
    3. Output control handles the actual generation

    Args:
        input_ids: Token IDs as list or tensor (shape: [seq_len] or [batch, seq_len])
        attention_mask: Optional attention mask matching input_ids shape
        runtime_kwargs: Per-generation parameters for controls (e.g., {"substrings": [...]})
        **gen_kwargs: Generation parameters passed to `model.generate()`

    Returns:
        Generated token IDs (shape: [batch, generated_len])

    Raises:
        RuntimeError: If steer() has not yet been called
    """
    if not self._is_steered:
        raise RuntimeError("Must call `.steer()` before `.generate()`.")

    runtime_kwargs = runtime_kwargs or {}

    return_full_sequence = bool(gen_kwargs.pop("return_full_sequence", False))

    # input control
    steered_input_ids = self.input_control.get_prompt_adapter()(input_ids, runtime_kwargs)
    if isinstance(steered_input_ids, list):
        steered_input_ids = torch.tensor(steered_input_ids, dtype=torch.long)
    if steered_input_ids.ndim == 1:
        steered_input_ids = steered_input_ids.unsqueeze(0)
    steered_input_ids = steered_input_ids.to(self.model.device)

    # attention_mask (reshape and move to device)
    if attention_mask is not None:
        if isinstance(attention_mask, list):
            attention_mask = torch.as_tensor(attention_mask, dtype=torch.long)
        if attention_mask.ndim == 1:
            attention_mask = attention_mask.unsqueeze(0)
        attention_mask = attention_mask.to(
            dtype=steered_input_ids.dtype, device=steered_input_ids.device
        )

    # state control
    hooks = self.state_control.get_hooks(steered_input_ids, runtime_kwargs, **gen_kwargs)
    self.state_control.set_hooks(hooks)
    self.state_control._model_ref = self.model

    # output control
    self.state_control.reset()
    with self.state_control:  # hooks live only for duration of decoding
        output_ids = self.output_control.generate(
            input_ids=steered_input_ids,
            attention_mask=attention_mask,
            runtime_kwargs=runtime_kwargs,
            model=self.model,
            **gen_kwargs
        )

    if not return_full_sequence:
        output_ids = output_ids[:, steered_input_ids.size(1):]

    return output_ids
generate_text(*args, **kwargs)

Generate text and decode to string(s).

Convenience wrapper that calls generate() and decodes the output tokens.

Parameters:

  • *args: Arguments passed to generate()
  • **kwargs: Keyword arguments passed to generate()

Returns:

  • str | list[str]: Decoded text string (single prompt) or list of strings (batch)

Source code in aisteer360/algorithms/core/steering_pipeline.py
def generate_text(self, *args, **kwargs) -> str | list[str]:
    """Generate text and decode to string(s).

    Convenience wrapper that calls generate() and decodes the output tokens.

    Args:
        *args: Arguments passed to generate()
        **kwargs: Keyword arguments passed to generate()

    Returns:
        Decoded text string (single prompt) or list of strings (batch)
    """
    ids = self.generate(*args, **kwargs)
    if ids.ndim == 1:
        return self.tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return self.tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
steer(**steer_kwargs)

Apply all steering controls to the model in place.

Executes each control's steer() method in a fixed bottom-up order: structural -> state -> input -> output. This ensures that higher-level controls always see the final configured model from lower levels.

If any control's steer() method returns a PreTrainedModel instance, it replaces the current model for subsequent controls.

Parameters:

  • **steer_kwargs: Keyword arguments passed to all control steer() methods

Raises:

  • RuntimeError: If no model is available after steering (repeated calls to steer() are no-ops)

Source code in aisteer360/algorithms/core/steering_pipeline.py
def steer(self, **steer_kwargs) -> None:
    """Apply all steering controls to the model in place.

    Executes each control's steer() method in a fixed bottom-up order: structural -> state -> input -> output.
    This ensures that higher-level controls always see the final configured model from lower levels.

    If any control's steer() method returns a PreTrainedModel instance, it replaces the current model for subsequent
    controls.

    Args:
        **steer_kwargs: Keyword arguments passed to all control steer() methods

    Raises:
        RuntimeError: If no model is available after steering (repeated calls to steer() are no-ops)
    """
    if self._is_steered:
        return

    # steer each control (bottom-up order)
    for control in (self.structural_control, self.state_control, self.input_control, self.output_control):
        steer_fn = getattr(control, "steer", None)
        if callable(steer_fn):
            maybe_new_model = steer_fn(self.model, tokenizer=self.tokenizer, **steer_kwargs)
            if isinstance(maybe_new_model, PreTrainedModel):
                self.model = maybe_new_model

    # safety checks
    if self.model is None:
        raise RuntimeError(
            "No model is available after steering. Either provide a base model (lazy_init=False) or ensure a "
            "`StructuralControl` returns one."
        )

    if self.tokenizer is None:
        repo = getattr(self.model, "name_or_path", None)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                repo or Path(getattr(self.structural_control.args, "out_path", "")),
                trust_remote_code=True,
            )
            self.tokenizer = ensure_pad_token(self.tokenizer)

        except Exception as exception:
            raise RuntimeError("Failed to resolve tokenizer post‑steer.") from exception

    # for control in (self.input_control, self.structural_control, self.state_control, self.output_control):
    #     if hasattr(control, "tokenizer") and getattr(control, "tokenizer") is None:
    #         setattr(control, "tokenizer", self.tokenizer)

    # return steered steerer
    self._is_steered = True
steering_utils

Helper functions for steering.

ensure_pad_token(tokenizer)

Set pad token to eos token if not already defined.

Parameters:

  • tokenizer (PreTrainedTokenizerBase, required): HuggingFace tokenizer instance

Returns:

  • PreTrainedTokenizerBase: The same tokenizer with pad_token configured

Source code in aisteer360/algorithms/core/steering_utils.py
def ensure_pad_token(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBase:
    """Set pad token to eos token if not already defined.

    Args:
       tokenizer: HuggingFace tokenizer instance

    Returns:
       The same tokenizer with pad_token configured
    """
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer
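
For example (the checkpoint is a placeholder; GPT-2 ships without a pad token):

from transformers import AutoTokenizer

from aisteer360.algorithms.core.steering_utils import ensure_pad_token

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer = ensure_pad_token(tokenizer)  # pad token now falls back to eos
assert tokenizer.pad_token_id == tokenizer.eos_token_id
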
merge_controls(supplied)

Sort supplied controls by category and ensure at most one per category.

Parameters:

  • supplied (Iterable[StructuralControl | StateControl | InputControl | OutputControl], required): List of control instances to organize

Returns:

  • dict[str, object]: Dict mapping field names to control instances (with default no-ops for unspecified categories)

Raises:

  • ValueError: If multiple controls of the same category are supplied
  • TypeError: If an unrecognized control type is supplied

Source code in aisteer360/algorithms/core/steering_utils.py
def merge_controls(
        supplied: Iterable[StructuralControl | StateControl | InputControl | OutputControl]
) -> dict[str, object]:
    """Sort supplied controls by category and ensure at most one per category.

    Args:
       supplied: List of control instances to organize

    Returns:
       Dict mapping field names to control instances (with default no-ops for unspecified categories)

    Raises:
       ValueError: If multiple controls of the same category are supplied
       TypeError: If an unrecognized control type is supplied
    """
    bucket: dict[type, list] = defaultdict(list)
    for control in supplied:
        for category in _CATEGORY_TO_DEFAULT:
            if isinstance(control, category):
                bucket[category].append(control)
                break
        else:
            raise TypeError(f"Unknown control type: {type(control)}")

    # todo (future): allow for user to compose multiple methods of the same category in a specified order;
    #  will require necessary validation logic to ensure no conflicts.
    for category, controls in bucket.items():
        if len(controls) > 1:
            names = [type(control).__name__ for control in controls]
            raise ValueError(f"Multiple {category.__name__}s supplied: {names}")

    out: dict[str, object] = {}
    for category, default_instance in _CATEGORY_TO_DEFAULT.items():
        instance = bucket.get(category, [default_instance])[0]
        out_key = (
            "input_control"
            if category is InputControl
            else "structural_control"
            if category is StructuralControl
            else "state_control"
            if category is StateControl
            else "output_control"
        )
        out[out_key] = instance
    return out
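
A small sketch of the validation behavior, using the no-op input control as a stand-in:

from aisteer360.algorithms.core.steering_utils import merge_controls
from aisteer360.algorithms.input_control.base import NoInputControl

merged = merge_controls([NoInputControl()])
print(sorted(merged))  # ['input_control', 'output_control', 'state_control', 'structural_control']

try:
    merge_controls([NoInputControl(), NoInputControl()])  # two controls of the same category
except ValueError as error:
    print(error)  # Multiple InputControls supplied: [...]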

input_control

base

Input control base classes.

This module provides the abstract base class for methods that modify prompts before they reach the model.

Two base classes are provided:

  • InputControl: Base class for all input control methods.
  • NoInputControl: Identity (null) control; used when no input control is defined in steering pipeline.

Input controls implement steering through prompt transformation σ(x), enabling behavior modification without altering model parameters or architecture. These methods transform inputs before they reach the model, resulting in generations following y ~ p_θ(σ(x)).

Examples of input controls:

  • Few-shot learning (prepending examples)
  • Prompt templates and formatting
  • Soft prompts and prompt tuning
  • Chain-of-thought prompting
  • Iterative prompt refinement

See Also:

  • aisteer360.algorithms.input_control: Implementations of input control methods
  • aisteer360.core.steering_pipeline: Integration with steering pipeline
InputControl

Bases: ABC

Abstract base class for input control steering methods.

Transforms prompts before model processing through a prompt adapter function that modifies input token sequences.

Methods:

  • get_prompt_adapter: Return transformation function (required)
  • steer: One-time preparation (optional)
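
As a minimal illustration (not part of the library), a hypothetical input control that appends a fixed instruction to every prompt could be written as follows; it relies on the tokenizer being injected by the steering pipeline via steer():

import torch

from aisteer360.algorithms.input_control.base import InputControl


class AppendInstruction(InputControl):
    """Hypothetical control: appends a fixed instruction to every prompt."""

    tokenizer = None
    instruction = " Respond concisely."

    def steer(self, model=None, tokenizer=None, **kwargs) -> None:
        self.tokenizer = tokenizer  # injected by SteeringPipeline

    def get_prompt_adapter(self, runtime_kwargs: dict | None = None):
        suffix_ids = self.tokenizer.encode(self.instruction, add_special_tokens=False)

        def adapter(input_ids, runtime_kwargs):
            if isinstance(input_ids, torch.Tensor):
                rows = input_ids.tolist() if input_ids.ndim > 1 else [input_ids.tolist()]
                appended = torch.tensor([row + suffix_ids for row in rows],
                                        dtype=input_ids.dtype, device=input_ids.device)
                return appended if input_ids.ndim > 1 else appended.squeeze(0)
            return list(input_ids) + suffix_ids

        return adapter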

Source code in aisteer360/algorithms/input_control/base.py
class InputControl(ABC):
    """Abstract base class for input control steering methods.

    Transforms prompts before model processing through a prompt adapter function that modifies input token sequences.

    Methods:
        get_prompt_adapter(runtime_kwargs) -> Callable: Return transformation function (required)
        steer(model, tokenizer, **kwargs) -> None: One-time preparation (optional)
    """

    Args: Type[BaseArgs] | None = None

    enabled: bool = True

    def __init__(self, *args, **kwargs) -> None:
        if self.Args is None:  # null control
            if args or kwargs:
                raise TypeError(f"{type(self).__name__} accepts no constructor arguments.")
            return

        self.args: BaseArgs = self.Args.validate(*args, **kwargs)

        # move fields to attributes
        for field in fields(self.args):
            setattr(self, field.name, getattr(self.args, field.name))

    @abstractmethod
    def get_prompt_adapter(
        self,
        runtime_kwargs: dict | None = None
    ) -> Callable[[list[int] | torch.Tensor, dict[str, Any]], list[int] | torch.Tensor]:
        """Receives (input_ids, runtime_kwargs) and returns modified input_ids.."""
        pass

    def steer(self,
              model=None,
              tokenizer=None,
              **kwargs) -> None:
        """Optional steering/preparation."""
        pass
args = self.Args.validate(*args, **kwargs) instance-attribute
enabled = True class-attribute instance-attribute
get_prompt_adapter(runtime_kwargs=None) abstractmethod

Receives (input_ids, runtime_kwargs) and returns modified input_ids.

Source code in aisteer360/algorithms/input_control/base.py
@abstractmethod
def get_prompt_adapter(
    self,
    runtime_kwargs: dict | None = None
) -> Callable[[list[int] | torch.Tensor, dict[str, Any]], list[int] | torch.Tensor]:
    """Receives (input_ids, runtime_kwargs) and returns modified input_ids.."""
    pass
steer(model=None, tokenizer=None, **kwargs)

Optional steering/preparation.

Source code in aisteer360/algorithms/input_control/base.py
def steer(self,
          model=None,
          tokenizer=None,
          **kwargs) -> None:
    """Optional steering/preparation."""
    pass
NoInputControl

Bases: InputControl

Identity input control.

Used as the default when no input control is needed. Returns input_ids.

Source code in aisteer360/algorithms/input_control/base.py
class NoInputControl(InputControl):
    """Identity input control.

    Used as the default when no input control is needed. Returns input_ids.
    """
    enabled: bool = False
    tokenizer: PreTrainedTokenizerBase | None = None

    def get_prompt_adapter(
            self,
            runtime_kwargs: dict | None = None
    ):
        """Null adapter operation; returns identity map."""
        if self.tokenizer is None:
            return lambda ids, _: ids

        def adapter(input_ids: list[int] | torch.Tensor, runtime_kwargs) -> list[int] | torch.Tensor:
            return input_ids

        return adapter

    def steer(
            self,
            model=None,
            tokenizer: PreTrainedTokenizerBase | None = None,
            **kwargs
    ) -> None:
        """Null steer operation; attaches tokenizer."""
        self.tokenizer = tokenizer
args = self.Args.validate(*args, **kwargs) instance-attribute
enabled = False class-attribute instance-attribute
tokenizer = None class-attribute instance-attribute
get_prompt_adapter(runtime_kwargs=None)

Null adapter operation; returns identity map.

Source code in aisteer360/algorithms/input_control/base.py
def get_prompt_adapter(
        self,
        runtime_kwargs: dict | None = None
):
    """Null adapter operation; returns identity map."""
    if self.tokenizer is None:
        return lambda ids, _: ids

    def adapter(input_ids: list[int] | torch.Tensor, runtime_kwargs) -> list[int] | torch.Tensor:
        return input_ids

    return adapter
steer(model=None, tokenizer=None, **kwargs)

Null steer operation; attaches tokenizer.

Source code in aisteer360/algorithms/input_control/base.py
def steer(
        self,
        model=None,
        tokenizer: PreTrainedTokenizerBase | None = None,
        **kwargs
) -> None:
    """Null steer operation; attaches tokenizer."""
    self.tokenizer = tokenizer
few_shot
args
control

Few-shot learning control for prompt adaptation.

FewShot

Bases: InputControl

Implementation of few-shot learning control for prompt adaptation.

FewShot enables selective behavioral steering by prepending specific examples to user prompts, guiding model responses through demonstration.

The method operates in two modes:

  1. Pool-based sampling: Maintains pools of positive and negative examples from which k examples are dynamically selected using configurable sampling strategies (random, semantic similarity, etc.).

  2. Runtime injection: Accepts examples directly at inference time through runtime_kwargs, enabling context-specific demonstrations without predefined pools. Useful for dynamic or user-provided examples.

The selected examples are formatted into a system prompt with clear positive/negative labels and prepended to the user query using the model's chat template, allowing the model to learn the desired behavior pattern from the demonstrations.

Parameters:

  • directive (str, optional): Instruction text that precedes the examples, explaining the task or desired behavior. Defaults to None.
  • positive_example_pool (Sequence[dict], optional): Pool of positive examples demonstrating desired behavior. Each dict can contain multiple key-value pairs. Defaults to None.
  • negative_example_pool (Sequence[dict], optional): Pool of negative examples showing undesired behavior to avoid. Each dict can contain multiple key-value pairs. Defaults to None.
  • k_positive (int, optional): Number of positive examples to sample from the pool per query. Defaults to None.
  • k_negative (int, optional): Number of negative examples to sample from the pool per query. Defaults to None.
  • selector_name (str, optional): Name of the selection strategy ('random', 'semantic', etc.). Determines how examples are chosen from pools. Defaults to 'random'.
  • template (str, optional): Custom template for formatting the system prompt. Should contain {directive} and {example_blocks} placeholders. Defaults to built-in template.

Runtime keyword arguments:

  • positive_examples (list[dict], optional): Positive examples to use for this specific query (overrides pool-based selection).
  • negative_examples (list[dict], optional): Negative examples to use for this specific query (overrides pool-based selection).

Notes:

  • Requires a tokenizer with chat_template support for optimal formatting
  • Examples are automatically labeled as "### Positive example" or "### Negative example"
  • When both pools and runtime examples are available, runtime examples take precedence
  • If no examples are provided, the original input is returned unchanged
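
A brief usage sketch (the checkpoint and example content are illustrative); the first call samples from the pool, while the second injects examples at inference time via runtime_kwargs:

from aisteer360.algorithms.core.steering_pipeline import SteeringPipeline
from aisteer360.algorithms.input_control.few_shot.control import FewShot

few_shot = FewShot(
    directive="Classify the sentiment of the review",
    positive_example_pool=[
        {"review": "Loved every minute of it.", "sentiment": "positive"},
        {"review": "A complete waste of time.", "sentiment": "negative"},
    ],
    k_positive=2,
)

pipeline = SteeringPipeline(
    model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder checkpoint
    controls=[few_shot],
)
pipeline.steer()

prompt_ids = pipeline.tokenizer.encode("Review: The plot dragged but the acting was superb.")

# pool-based sampling
print(pipeline.generate_text(input_ids=prompt_ids, max_new_tokens=32))

# runtime injection (overrides the pools for this call)
print(pipeline.generate_text(
    input_ids=prompt_ids,
    runtime_kwargs={"positive_examples": [{"review": "Great soundtrack.", "sentiment": "positive"}]},
    max_new_tokens=32,
))
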
Source code in aisteer360/algorithms/input_control/few_shot/control.py
class FewShot(InputControl):
    """
    Implementation of few-shot learning control for prompt adaptation.

    FewShot enables selective behavioral steering by prepending specific examples to user prompts, guiding model
    responses through demonstration.

    The method operates in two modes:

    1. **Pool-based sampling**: Maintains pools of positive and negative examples from which k examples are dynamically
        selected using configurable sampling strategies (random, semantic similarity, etc.).

    2. **Runtime injection**: Accepts examples directly at inference time through runtime_kwargs, enabling
        context-specific demonstrations without predefined pools. Useful for dynamic or user-provided examples.

    The selected examples are formatted into a system prompt with clear positive/negative labels and prepended to the
    user query using the model's chat template, allowing the model to learn the desired behavior pattern from the
    demonstrations.

    Args:
        directive (str, optional): Instruction text that precedes the examples, explaining the task or desired behavior.
            Defaults to None.
        positive_example_pool (Sequence[dict], optional): Pool of positive examples demonstrating desired behavior.
            Each dict can contain multiple key-value pairs. Defaults to None.
        negative_example_pool (Sequence[dict], optional): Pool of negative examples showing undesired behavior to avoid.
            Each dict can contain multiple key-value pairs. Defaults to None.
        k_positive (int, optional): Number of positive examples to sample from the pool per query.
            Defaults to None.
        k_negative (int, optional): Number of negative examples to sample from the pool per query.
            Defaults to None.
        selector_name (str, optional): Name of the selection strategy ('random', 'semantic', etc.).
            Determines how examples are chosen from pools. Defaults to 'random'.
        template (str, optional): Custom template for formatting the system prompt. Should contain
            {directive} and {example_blocks} placeholders. Defaults to built-in template.

    Runtime keyword arguments:

    - `positive_examples` (`list[dict]`, `optional`): Positive examples to use for this specific query (overrides pool-based
    selection).
    - `negative_examples` (`list[dict]`, `optional`): Negative examples to use for this specific query (overrides pool-based
    selection).

    Notes:

    - Requires a tokenizer with chat_template support for optimal formatting
    - Examples are automatically labeled as "### Positive example" or "### Negative example"
    - When both pools and runtime examples are available, runtime examples take precedence
    - If no examples are provided, the original input is returned unchanged
    """

    Args = FewShotArgs

    # default templates
    _SYSTEM_PROMPT_TEMPLATE = "{directive}: \n{example_blocks}\n\n"
    _POSITIVE_EXAMPLE_TEMPLATE = "### Positive example (behavior to follow)\n{content}\n"
    _NEGATIVE_EXAMPLE_TEMPLATE = "### Negative example (behavior to avoid)\n{content}\n"

    # placeholders
    tokenizer: PreTrainedTokenizer | None = None
    selector_name: str | None = None
    directive: str | None = None
    positive_example_pool: Sequence[dict] | None = None
    negative_example_pool: Sequence[dict] | None = None
    k_positive: int | None = None
    k_negative: int | None = None
    selector: Selector | None = None

    def steer(
            self,
            model=None,
            tokenizer: PreTrainedTokenizer | None = None,
            **kwargs
    ) -> None:
        self.tokenizer = tokenizer

        # initialize selector if using pool mode
        if self.positive_example_pool is not None or self.negative_example_pool is not None:
            if self.selector_name:
                selector_cls = SELECTOR_REGISTRY.get(self.selector_name, RandomSelector)
                self.selector = selector_cls()
            else:
                self.selector = RandomSelector()

    def get_prompt_adapter(self) -> Callable[[list[int] | torch.Tensor, dict[str, Any]], list[int] | torch.Tensor]:
        """Return a prompt adapter function that adds few-shot examples to the model's system prompt. Creates and
        returns a closure that modifies input token sequences by prepending few-shot examples.

        The returned adapter function performs the following steps:

        1. Determines operational mode (runtime examples take precedence over pools)
        2. Decodes input tokens to retrieve the original user message
        3. Selects or retrieves appropriate examples based on mode
        4. Formats examples with positive/negative labels
        5. Constructs a system prompt containing the examples
        6. Applies the model's chat template (if available) to combine system prompt and user message
        7. Re-encodes the adapted text to tokens

        Returns:
            A prompt adapter function.

        Raises:
            RuntimeError: If tokenizer is not set (requires calling `steer()` first)

        Warnings:
            UserWarning: Issued when:

                - No examples available from either pools or runtime_kwargs
                - No examples remain after selection/sampling
                - Tokenizer lacks chat_template support (falls back to direct prepending)
        """

        if self.tokenizer is None:
            raise RuntimeError("FewShot needs a tokenizer; call .steer() first.")

        def adapter(input_ids: list[int] | torch.Tensor, runtime_kwargs: dict[str, Any]) -> list[int] | torch.Tensor:

            # infer mode from arguments
            using_runtime_examples = (runtime_kwargs and ("positive_examples" in runtime_kwargs or
                                                          "negative_examples" in runtime_kwargs))
            using_pool_mode = self.positive_example_pool is not None or self.negative_example_pool is not None

            if not using_runtime_examples and not using_pool_mode:
                warnings.warn(
                    "FewShot: No examples provided via runtime_kwargs or example pools. "
                    "Returning original input unchanged.",
                    UserWarning
                )
                return input_ids

            # decode to retrieve user message
            if isinstance(input_ids, torch.Tensor):
                input_ids_list = input_ids.tolist()[0]
            else:
                input_ids_list = input_ids

            original_text = self.tokenizer.decode(input_ids_list, skip_special_tokens=True)

            # get examples based on mode
            if using_runtime_examples:
                examples = self._gather_runtime_examples(runtime_kwargs)
            else:
                examples = self._sample_from_pools()

            if not examples:
                warnings.warn(
                    "FewShot: No examples available after selection. Returning original input unchanged.",
                    UserWarning
                )
                return input_ids

            examples_text = self._format_examples(examples)

            # apply chat template
            if hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template:
                messages = [
                    {"role": "system", "content": examples_text},
                    {"role": "user", "content": original_text}
                ]
                adapted_text = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            else:
                warnings.warn(
                    "No chat template found for tokenizer. Prepending few-shot examples directly to user query.",
                    UserWarning
                )
                adapted_text = examples_text + original_text

            # encode the adapted text
            adapted_tokens = self.tokenizer.encode(
                adapted_text,
                add_special_tokens=False,
                return_tensors="pt" if isinstance(input_ids, torch.Tensor) else None
            )

            if isinstance(input_ids, torch.Tensor):
                return adapted_tokens.squeeze(0) if adapted_tokens.dim() > 1 else adapted_tokens
            else:
                return adapted_tokens

        return adapter

    def _sample_from_pools(self) -> list[dict[str, Any]]:
        """Sample examples from the pools."""
        all_examples = []

        if self.positive_example_pool and self.k_positive and self.k_positive > 0:
            positive_samples = self.selector.sample(
                self.positive_example_pool,
                self.k_positive
            )
            for example in positive_samples:
                all_examples.append({**example, "_label": "positive"})

        if self.negative_example_pool and self.k_negative and self.k_negative > 0:
            negative_samples = self.selector.sample(
                self.negative_example_pool,
                self.k_negative
            )
            for example in negative_samples:
                all_examples.append({**example, "_label": "negative"})

        return all_examples

    def _format_examples(self, examples: list[dict[str, Any]]) -> str:
        """Format examples for system prompt."""
        if not examples:
            return ""

        example_blocks = []
        for example in examples:
            is_positive = example.get("_label", "positive") == "positive"
            content = self._format_example_content(example)

            if is_positive:
                example_blocks.append(self._POSITIVE_EXAMPLE_TEMPLATE.format(content=content))
            else:
                example_blocks.append(self._NEGATIVE_EXAMPLE_TEMPLATE.format(content=content))

        template = getattr(self, 'template', None) or self._SYSTEM_PROMPT_TEMPLATE
        formatted_blocks = "\n".join(example_blocks)

        return template.format(directive=self.directive or "", example_blocks=formatted_blocks)

    @staticmethod
    def _gather_runtime_examples(runtime_kwargs: dict[str, Any]) -> list[dict[str, Any]]:
        """Gather examples from runtime_kwargs."""
        examples = []
        if "positive_examples" in runtime_kwargs:
            for example in runtime_kwargs["positive_examples"]:
                examples.append({**example, "_label": "positive"})
        if "negative_examples" in runtime_kwargs:
            for example in runtime_kwargs["negative_examples"]:
                examples.append({**example, "_label": "negative"})
        return examples

    @staticmethod
    def _format_example_content(example: dict[str, Any]) -> str:
        segments = []
        for key, value in example.items():
            if key == "_label":
                continue
            formatted_key = key.replace("_", " ").title()
            segments.append(f"{formatted_key}: {value}")

        return "\n".join(segments)
args = self.Args.validate(*args, **kwargs) instance-attribute
directive = None class-attribute instance-attribute
enabled = True class-attribute instance-attribute
k_negative = None class-attribute instance-attribute
k_positive = None class-attribute instance-attribute
negative_example_pool = None class-attribute instance-attribute
positive_example_pool = None class-attribute instance-attribute
selector = None class-attribute instance-attribute
selector_name = None class-attribute instance-attribute
tokenizer = None class-attribute instance-attribute
get_prompt_adapter()

Return a prompt adapter function that adds few-shot examples to the model's system prompt. Creates and returns a closure that modifies input token sequences by prepending few-shot examples.

The returned adapter function performs the following steps:

  1. Determines operational mode (runtime examples take precedence over pools)
  2. Decodes input tokens to retrieve the original user message
  3. Selects or retrieves appropriate examples based on mode
  4. Formats examples with positive/negative labels
  5. Constructs a system prompt containing the examples
  6. Applies the model's chat template (if available) to combine system prompt and user message
  7. Re-encodes the adapted text to tokens

Returns:

  • Callable[[list[int] | Tensor, dict[str, Any]], list[int] | Tensor]: A prompt adapter function.

Raises:

  • RuntimeError: If tokenizer is not set (requires calling steer() first)

Warns:

  • UserWarning: Issued when:
      • No examples available from either pools or runtime_kwargs
      • No examples remain after selection/sampling
      • Tokenizer lacks chat_template support (falls back to direct prepending)
Source code in aisteer360/algorithms/input_control/few_shot/control.py
def get_prompt_adapter(self) -> Callable[[list[int] | torch.Tensor, dict[str, Any]], list[int] | torch.Tensor]:
    """Return a prompt adapter function that adds few-shot examples to the model's system prompt. Creates and
    returns a closure that modifies input token sequences by prepending few-shot examples.

    The returned adapter function performs the following steps:

    1. Determines operational mode (runtime examples take precedence over pools)
    2. Decodes input tokens to retrieve the original user message
    3. Selects or retrieves appropriate examples based on mode
    4. Formats examples with positive/negative labels
    5. Constructs a system prompt containing the examples
    6. Applies the model's chat template (if available) to combine system prompt and user message
    7. Re-encodes the adapted text to tokens

    Returns:
        A prompt adapter function.

    Raises:
        RuntimeError: If tokenizer is not set (requires calling `steer()` first)

    Warnings:
        UserWarning: Issued when:

            - No examples available from either pools or runtime_kwargs
            - No examples remain after selection/sampling
            - Tokenizer lacks chat_template support (falls back to direct prepending)
    """

    if self.tokenizer is None:
        raise RuntimeError("FewShot needs a tokenizer; call .steer() first.")

    def adapter(input_ids: list[int] | torch.Tensor, runtime_kwargs: dict[str, Any]) -> list[int] | torch.Tensor:

        # infer mode from arguments
        using_runtime_examples = (runtime_kwargs and ("positive_examples" in runtime_kwargs or
                                                      "negative_examples" in runtime_kwargs))
        using_pool_mode = self.positive_example_pool is not None or self.negative_example_pool is not None

        if not using_runtime_examples and not using_pool_mode:
            warnings.warn(
                "FewShot: No examples provided via runtime_kwargs or example pools. "
                "Returning original input unchanged.",
                UserWarning
            )
            return input_ids

        # decode to retrieve user message
        if isinstance(input_ids, torch.Tensor):
            input_ids_list = input_ids.tolist()[0]
        else:
            input_ids_list = input_ids

        original_text = self.tokenizer.decode(input_ids_list, skip_special_tokens=True)

        # get examples based on mode
        if using_runtime_examples:
            examples = self._gather_runtime_examples(runtime_kwargs)
        else:
            examples = self._sample_from_pools()

        if not examples:
            warnings.warn(
                "FewShot: No examples available after selection. Returning original input unchanged.",
                UserWarning
            )
            return input_ids

        examples_text = self._format_examples(examples)

        # apply chat template
        if hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template:
            messages = [
                {"role": "system", "content": examples_text},
                {"role": "user", "content": original_text}
            ]
            adapted_text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        else:
            warnings.warn(
                "No chat template found for tokenizer. Prepending few-shot examples directly to user query.",
                UserWarning
            )
            adapted_text = examples_text + original_text

        # encode the adapted text
        adapted_tokens = self.tokenizer.encode(
            adapted_text,
            add_special_tokens=False,
            return_tensors="pt" if isinstance(input_ids, torch.Tensor) else None
        )

        if isinstance(input_ids, torch.Tensor):
            return adapted_tokens.squeeze(0) if adapted_tokens.dim() > 1 else adapted_tokens
        else:
            return adapted_tokens

    return adapter
steer(model=None, tokenizer=None, **kwargs)

Optional steering/preparation.

Source code in aisteer360/algorithms/input_control/few_shot/control.py
def steer(
        self,
        model=None,
        tokenizer: PreTrainedTokenizer | None = None,
        **kwargs
) -> None:
    self.tokenizer = tokenizer

    # initialize selector if using pool mode
    if self.positive_example_pool is not None or self.negative_example_pool is not None:
        if self.selector_name:
            selector_cls = SELECTOR_REGISTRY.get(self.selector_name, RandomSelector)
            self.selector = selector_cls()
        else:
            self.selector = RandomSelector()
selectors

Example selectors for few-shot learning prompt adaptation.

This module provides different strategies for selecting examples from pools during few-shot prompting. Selectors determine which examples are passed as demonstrations to the model.

Available selectors:

  • RandomSelector: Randomly samples examples from the pool
SELECTOR_REGISTRY = {'random': RandomSelector} module-attribute
base

Base interface for few-shot example selection strategies.

Selector

Bases: ABC

Base class for example selector.

Source code in aisteer360/algorithms/input_control/few_shot/selectors/base.py
class Selector(ABC):
    """
    Base class for example selector.
    """

    @abstractmethod
    def sample(
        self,
        pool: Sequence[dict],
        k: int,
        **kwargs: Any
    ) -> list[dict]:
        """Return k items chosen from pool."""
        raise NotImplementedError
sample(pool, k, **kwargs) abstractmethod

Return k items chosen from pool.

Source code in aisteer360/algorithms/input_control/few_shot/selectors/base.py
@abstractmethod
def sample(
    self,
    pool: Sequence[dict],
    k: int,
    **kwargs: Any
) -> list[dict]:
    """Return k items chosen from pool."""
    raise NotImplementedError
random_selector
RandomSelector

Bases: Selector

Selects examples uniformly at random from a pool for few-shot prompting.

Source code in aisteer360/algorithms/input_control/few_shot/selectors/random_selector.py
class RandomSelector(Selector):
    """Selects examples uniformly at random from a pool for few-shot prompting."""

    def sample(self, pool: Sequence[dict], k: int, **_) -> list[dict]:
        """Select k examples uniformly at random from the pool.

        Args:
            pool: Available examples to select from
            k: Number of examples to select
            **_: Ignored (for compatibility with other selectors)

        Returns:
            List of randomly selected examples (up to min(k, len(pool)))
        """
        return random.sample(pool, min(k, len(pool)))
sample(pool, k, **_)

Select k examples uniformly at random from the pool.

Parameters:

Name Type Description Default
pool Sequence[dict]

Available examples to select from

required
k int

Number of examples to select

required
**_

Ignored (for compatibility with other selectors)

{}

Returns:

Type Description
list[dict]

List of randomly selected examples (up to min(k, len(pool)))

Source code in aisteer360/algorithms/input_control/few_shot/selectors/random_selector.py
def sample(self, pool: Sequence[dict], k: int, **_) -> list[dict]:
    """Select k examples uniformly at random from the pool.

    Args:
        pool: Available examples to select from
        k: Number of examples to select
        **_: Ignored (for compatibility with other selectors)

    Returns:
        List of randomly selected examples (up to min(k, len(pool)))
    """
    return random.sample(pool, min(k, len(pool)))
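A sketch of a custom selector built on the Selector interface above. Only the sample() contract is taken from the docs; registering the class in SELECTOR_REGISTRY (so that FewShot can resolve it by name) assumes the registry is exposed at aisteer360.algorithms.input_control.few_shot.selectors:

from typing import Any, Sequence

from aisteer360.algorithms.input_control.few_shot.selectors.base import Selector


class ShortestFirstSelector(Selector):
    """Selects the k examples with the shortest serialized length (illustrative)."""

    def sample(self, pool: Sequence[dict], k: int, **kwargs: Any) -> list[dict]:
        ranked = sorted(pool, key=lambda example: len(str(example)))  # shortest examples first
        return ranked[: min(k, len(pool))]


# Hypothetical registration so selector_name="shortest" resolves at steer() time.
from aisteer360.algorithms.input_control.few_shot.selectors import SELECTOR_REGISTRY

SELECTOR_REGISTRY["shortest"] = ShortestFirstSelector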

output_control

base

Output control base classes.

This module provides the abstract base classes for methods that intervene during text generation (e.g., via modifying logits, constraining the output space, or implementing alternative decoding strategies).

Two base classes are provided:

  • OutputControl: Base class for all output control methods.
  • NoOutputControl: Identity (null) control; used when no output control is defined in steering pipeline.

Output controls implement steering through decoding algorithms and constraints, modifying the sampling process to produce generations y ~ᵈ p_θ(x), where ~ᵈ indicates the modified generation process.

Examples of output controls:

  • Constrained beam search
  • Reward-augmented decoding
  • Grammar-constrained generation
  • Token filtering and masking
  • Classifier-guided generation
  • Best-of-N sampling

See Also:

  • aisteer360.algorithms.output_control: Implementations of output control methods
  • aisteer360.core.steering_pipeline: Integration with steering pipeline
NoOutputControl

Bases: OutputControl

Identity output control.

Used as the default when no output control is needed. Calls the (unsteered) model's generate.

Source code in aisteer360/algorithms/output_control/base.py
class NoOutputControl(OutputControl):
    """Identity output control.

    Used as the default when no output control is needed. Calls (unsteered) model's generate.
    """
    enabled: bool = False

    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        runtime_kwargs: dict | None,  # only for API compliance as runtime_kwargs are not used in HF models.
        model: PreTrainedModel,
        **gen_kwargs,
    ) -> torch.Tensor:
        """Null generate operation; applies model's generate."""
        return model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
args = self.Args.validate(*args, **kwargs) instance-attribute
enabled = False class-attribute instance-attribute
generate(input_ids, attention_mask, runtime_kwargs, model, **gen_kwargs)

Null generate operation; applies model's generate.

Source code in aisteer360/algorithms/output_control/base.py
def generate(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    runtime_kwargs: dict | None,  # only for API compliance as runtime_kwargs are not used in HF models.
    model: PreTrainedModel,
    **gen_kwargs,
) -> torch.Tensor:
    """Null generate operation; applies model's generate."""
    return model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
steer(model, tokenizer=None, **kwargs)

Optional steering/preparation.

Source code in aisteer360/algorithms/output_control/base.py
def steer(self,
          model: PreTrainedModel,
          tokenizer=None,
          **kwargs) -> None:
    """Optional steering/preparation."""
    pass
OutputControl

Bases: ABC

Abstract base class for output control steering methods.

Overrides the generation process with custom logic.

Methods:

Name Description
generate

Custom generation (required)

steer

One-time preparation (optional)

Source code in aisteer360/algorithms/output_control/base.py
class OutputControl(ABC):
    """Abstract base class for output control steering methods.

    Overrides the generation process with custom logic.

    Methods:
        generate(input_ids, attention_mask, runtime_kwargs, model, **gen_kwargs) -> Tensor: Custom generation (required)
        steer(model, tokenizer, **kwargs) -> None: One-time preparation (optional)
    """

    Args: Type[BaseArgs] | None = None

    enabled: bool = True

    def __init__(self, *args, **kwargs) -> None:
        if self.Args is None:  # null control
            if args or kwargs:
                raise TypeError(f"{type(self).__name__} accepts no constructor arguments.")
            return

        self.args: BaseArgs = self.Args.validate(*args, **kwargs)

        # move fields to attributes
        for field in fields(self.args):
            setattr(self, field.name, getattr(self.args, field.name))

    @abstractmethod
    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        runtime_kwargs: dict | None,
        model: PreTrainedModel,
        **gen_kwargs,
    ) -> torch.Tensor:
        """Custom generation logic."""
        pass

    def steer(self,
              model: PreTrainedModel,
              tokenizer=None,
              **kwargs) -> None:
        """Optional steering/preparation."""
        pass
args = self.Args.validate(*args, **kwargs) instance-attribute
enabled = True class-attribute instance-attribute
generate(input_ids, attention_mask, runtime_kwargs, model, **gen_kwargs) abstractmethod

Custom generation logic.

Source code in aisteer360/algorithms/output_control/base.py
@abstractmethod
def generate(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    runtime_kwargs: dict | None,
    model: PreTrainedModel,
    **gen_kwargs,
) -> torch.Tensor:
    """Custom generation logic."""
    pass
steer(model, tokenizer=None, **kwargs)

Optional steering/preparation.

Source code in aisteer360/algorithms/output_control/base.py
def steer(self,
          model: PreTrainedModel,
          tokenizer=None,
          **kwargs) -> None:
    """Optional steering/preparation."""
    pass
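A sketch of a custom output control built directly on the base class above. It defines no Args dataclass (Args stays None, so the constructor takes no arguments) and simply forces greedy decoding by overriding the sampling-related gen_kwargs:

import torch
from transformers import PreTrainedModel

from aisteer360.algorithms.output_control.base import OutputControl


class GreedyOutputControl(OutputControl):
    """Forces deterministic (greedy) decoding regardless of caller-supplied gen_kwargs."""

    Args = None  # no constructor arguments to validate

    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        runtime_kwargs: dict | None,
        model: PreTrainedModel,
        **gen_kwargs,
    ) -> torch.Tensor:
        gen_kwargs.update({"do_sample": False, "num_beams": 1})  # override sampling settings
        return model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)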
deal
args
control
DeAL

Bases: OutputControl

Implementation of DeAL (Decoding-time Alignment) from Huang et al., 2024.

DeAL performs controlled text generation through iterative lookahead search and reward-guided beam selection. Unlike training-time alignment methods, DeAL operates purely at inference time to steer language model outputs toward desired behaviors.

The algorithm works in three phases:

  1. Lookahead Generation: Generate multiple candidate continuations using beam search from the current context.

  2. Reward-based Scoring: Evaluate each candidate continuation using a provided reward function that measures alignment with the desired objective (e.g., helpfulness, safety).

  3. Iterative Refinement: Select the top-k highest-scoring beams and repeat the process until termination conditions are met (EOS token, max length, or max iterations reached).

This approach allows for flexible alignment with various objectives without requiring model retraining or fine-tuning.

Parameters:

Name Type Description Default
reward_func Callable

Function that scores generated continuations. Should accept (prompt: str, continuations: list[str], reward_params: dict) and return list[float].

required
lookahead int

Number of tokens to generate in each lookahead step. Defaults to 4.

required
init_beams int

Number of initial beams to generate at each iteration. Defaults to 8.

required
topk int

Number of top-scoring beams to retain for the next iteration. Defaults to 4.

required
max_iterations int

Maximum number of search iterations before termination. Defaults to 10.

required

Reference:

  • "DeAL: Decoding-time Alignment for Large Language Models" James Y. Huang, Sailik Sengupta, Daniele Bonadiman, Yi-an Lai, Arshit Gupta, Nikolaos Pappas, Saab Mansour, Katrin Kirchhoff, Dan Roth https://arxiv.org/abs/2402.06147
Source code in aisteer360/algorithms/output_control/deal/control.py
class DeAL(OutputControl):
    """
    Implementation of DeAL (Decoding-time Alignment) from Huang et al., 2024.

    DeAL performs controlled text generation through iterative lookahead search and reward-guided beam selection. Unlike
    training-time alignment methods, DeAL operates purely at inference time to steer language model outputs toward
    desired behaviors.

    The algorithm works in three phases:

    1. **Lookahead Generation**: Generate multiple candidate continuations using beam search from the current context.

    2. **Reward-based Scoring**: Evaluate each candidate continuation using a provided reward function that measures
    alignment with the desired objective (e.g., helpfulness, safety).

    3. **Iterative Refinement**: Select the top-k highest-scoring beams and repeat the process until termination
    conditions are met (EOS token, max length, or max iterations reached).

    This approach allows for flexible alignment with various objectives without requiring model retraining or
    fine-tuning.

    Args:
        reward_func (Callable): Function that scores generated continuations. Should accept
            (prompt: str, continuations: list[str], reward_params: dict) and return list[float].
        lookahead (int): Number of tokens to generate in each lookahead step. Defaults to 4.
        init_beams (int): Number of initial beams to generate at each iteration. Defaults to 8.
        topk (int): Number of top-scoring beams to retain for the next iteration. Defaults to 4.
        max_iterations (int): Maximum number of search iterations before termination. Defaults to 10.

    Reference:

    - "DeAL: Decoding-time Alignment for Large Language Models"
    James Y. Huang, Sailik Sengupta, Daniele Bonadiman, Yi-an Lai, Arshit Gupta, Nikolaos Pappas, Saab Mansour,
    Katrin Kirchhoff, Dan Roth
    https://arxiv.org/abs/2402.06147
    """

    Args = DeALArgs

    # placeholders
    model: PreTrainedModel | None = None
    tokenizer: PreTrainedTokenizer | None = None
    base_generate: Callable | None = None

    def steer(
            self,
            model: PreTrainedModel,
            tokenizer: PreTrainedTokenizer | None = None,
            **_
    ) -> PreTrainedModel:
        """Lightweight preparation; attaches model, tokenizer, and generate to instance."""
        self.model = model
        self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
        self.base_generate = model.generate
        return model

    def _lookahead_generation(
        self,
        input_ids: torch.Tensor,
        reward_func: Callable[[str, list[str], dict], list[float]],
        reward_params: dict,
        base_generate: Callable,
        input_length: int,
        **gen_kwargs,
    ) -> tuple[list[float], torch.Tensor]:
        """Generate and score candidate continuations for one lookahead iteration.

        Generates multiple beam candidates using the base model's generation method, then evaluates each continuation
        with the reward function to guide selection.

        Args:
            input_ids (torch.Tensor): Current context tokens to continue from.
                Shape can vary based on number of active beams.
            reward_func (Callable[[str, list[str], dict], list[float]]): Function to score continuations.
                Receives (original_prompt, continuation_texts, params).
            reward_params (dict): Parameters passed to reward function, including algorithm
                settings (lookahead, init_beams, topk, max_iterations).
            base_generate (Callable): Generation function used to produce candidate continuations.
            input_length (int): Length of original input prompt, used to extract only the newly generated portion for
                scoring.
            **gen_kwargs: Generation parameters forwarded to base_generate (including num_beams, max_new_tokens, etc.)

        Returns:
            tuple[list[float], torch.Tensor]: Tuple containing:
                - Reward scores for each generated beam (list of floats)
                - Full token sequences including input and continuations (tensor)

        Raises:
            RuntimeError: If reward function returns wrong number of scores (must match number of generated beams).

        Note:

        - Continuations are decoded to text for reward evaluation
        - Special tokens are skipped when extracting continuation text
        - Stores original prompt in self.prompt for reward function access
        """
        lookaheads = base_generate(input_ids=input_ids, **gen_kwargs)
        continuations: list[str] = self.tokenizer.batch_decode(
            lookaheads[:, input_length:], skip_special_tokens=True
        )
        scores = reward_func(self.prompt, continuations, reward_params)
        if len(scores) != lookaheads.size(0):
            raise RuntimeError(f"Reward function returned {len(scores)} scores for {lookaheads.size(0)} beams.")
        return scores, lookaheads

    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        runtime_kwargs: dict | None,
        model: PreTrainedModel,
        **gen_kwargs,
    ) -> torch.Tensor:
        """Execute guided generation with iterative lookahead search and reward-based selection. Returns the
        highest-scoring generation.

        The generation process is as follows:

        1. Generate `init_beams` candidate continuations of `lookahead` tokens each
        2. Score all candidates using the provided reward function
        3. Select top-k highest scoring beams
        4. Check termination conditions (EOS, max length, max iterations)
        5. If not terminated, continue from the selected beams
        6. Return the highest-scoring complete generation

        Args:
            input_ids (torch.Tensor): Input token IDs of shape [1, seq_len].
                Currently only supports single prompts (batch size must be 1).
            attention_mask (torch.Tensor): Attention mask matching input_ids shape.
                Automatically recomputed during iteration based on padding tokens.
            runtime_kwargs (dict | None): Runtime parameters including:

                - "base_generate" (`Callable`, optional): Override the model's generate function
                - "reward_params" (`dict`, optional): Additional parameters passed to reward_func
            model (PreTrainedModel): The language model used for generation.
                Must match the model provided during steer().
            **gen_kwargs: Generation parameters passed to the underlying model.generate().
                Note: `max_new_tokens` is extracted and used as global limit; `num_beams` and `num_return_sequences` are
                overridden by DeAL parameters.

        Returns:
            torch.Tensor: Generated token IDs of shape [1, output_len] or [output_len].
                Contains the highest-scoring complete generation found during search.

        Raises:
            ValueError: If base_generate is not callable
            NotImplementedError: If input has batch size > 1 (multiple prompts not supported)
            RuntimeError: If reward function returns incorrect number of scores
        """
        runtime_kwargs = runtime_kwargs or {}

        reward_func = self.reward_func
        base_generate = runtime_kwargs.get("base_generate", self.base_generate)

        if not callable(base_generate):
            raise ValueError("'base_generate' must be callable; supplied or cached from steer().")

        # assert (
        #     self.model is not None and self.tokenizer is not None
        # ), "DeAL.steer() must run before generate()."

        if input_ids.dim() != 2 or input_ids.size(0) != 1:
            raise NotImplementedError("Current DeAL implementation handles one prompt at a time.")

        # record call‑specific objects
        self.prompt: str = self.tokenizer.decode(
            input_ids[0], skip_special_tokens=True
        )
        input_length = input_ids.size(1)

        reward_params = {
            **runtime_kwargs.get("reward_params", {}),
            "lookahead": self.lookahead,
            "init_beams": self.init_beams,
            "topk": self.topk,
            "max_iterations": self.max_iterations,
        }

        original_max_tokens: Optional[int] = gen_kwargs.pop("max_new_tokens", None)

        # search loop
        best_beam: torch.Tensor | None = None
        best_score = float("-inf")
        current_input_ids = input_ids
        iteration = 0

        while iteration < self.max_iterations:
            iteration += 1

            attention_mask = (current_input_ids != self.tokenizer.pad_token_id).long()
            gen_args = copy.deepcopy(gen_kwargs)
            gen_args.update(
                {
                    "max_new_tokens": self.lookahead,
                    "num_beams": self.init_beams,
                    "num_return_sequences": self.init_beams,
                    "attention_mask": attention_mask,
                }
            )

            # rollout + scoring
            scores, beams = self._lookahead_generation(
                current_input_ids,
                reward_func=reward_func,
                reward_params=reward_params,
                base_generate=base_generate,
                input_length=input_length,
                **gen_args,
            )

            # select top-k
            score_tensor = torch.tensor(scores, device=beams.device)
            topk = min(self.topk, score_tensor.numel())
            top_idx = torch.topk(score_tensor, topk).indices
            beams = beams[top_idx]
            scores = score_tensor[top_idx].tolist()

            # termination mask
            finished_flags = []
            for beam in beams:
                eos_hit = beam[...,-1] == self.tokenizer.eos_token_id
                len_hit = (
                        original_max_tokens is not None
                        and beam.size(0) - input_length >= original_max_tokens
                )
                finished_flags.append(bool(eos_hit or len_hit))

            # update best-so-far
            best_local = int(torch.argmax(torch.tensor(scores)))
            if scores[best_local] > best_score:
                best_score = scores[best_local]
                best_beam = beams[best_local]

            if all(finished_flags):
                break

            # prune unfinished beams for next round
            current_input_ids = beams[
                [i for i, f in enumerate(finished_flags) if not f]
            ]

        final_ids = best_beam if best_beam is not None else beams[0]
        return final_ids.unsqueeze(0) if final_ids.dim() == 1 else final_ids
args = self.Args.validate(*args, **kwargs) instance-attribute
base_generate = None class-attribute instance-attribute
enabled = True class-attribute instance-attribute
model = None class-attribute instance-attribute
tokenizer = None class-attribute instance-attribute
generate(input_ids, attention_mask, runtime_kwargs, model, **gen_kwargs)

Execute guided generation with iterative lookahead search and reward-based selection. Returns the highest-scoring generation.

The generation process is as follows:

  1. Generate init_beams candidate continuations of lookahead tokens each
  2. Score all candidates using the provided reward function
  3. Select top-k highest scoring beams
  4. Check termination conditions (EOS, max length, max iterations)
  5. If not terminated, continue from the selected beams
  6. Return the highest-scoring complete generation

Parameters:

Name Type Description Default
input_ids Tensor

Input token IDs of shape [1, seq_len]. Currently only supports single prompts (batch size must be 1).

required
attention_mask Tensor

Attention mask matching input_ids shape. Automatically recomputed during iteration based on padding tokens.

required
runtime_kwargs dict | None

Runtime parameters including:

  • "base_generate" (Callable, optional): Override the model's generate function
  • "reward_params" (dict, optional): Additional parameters passed to reward_func
required
model PreTrainedModel

The language model used for generation. Must match the model provided during steer().

required
**gen_kwargs

Generation parameters passed to the underlying model.generate(). Note: max_new_tokens is extracted and used as global limit; num_beams and num_return_sequences are overridden by DeAL parameters.

{}

Returns:

Type Description
Tensor

torch.Tensor: Generated token IDs of shape [1, output_len] or [output_len]. Contains the highest-scoring complete generation found during search.

Raises:

Type Description
ValueError

If base_generate is not callable

NotImplementedError

If input has batch size > 1 (multiple prompts not supported)

RuntimeError

If reward function returns incorrect number of scores

Source code in aisteer360/algorithms/output_control/deal/control.py
def generate(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    runtime_kwargs: dict | None,
    model: PreTrainedModel,
    **gen_kwargs,
) -> torch.Tensor:
    """Execute guided generation with iterative lookahead search and reward-based selection. Returns the
    highest-scoring generation.

    The generation process is as follows:

    1. Generate `init_beams` candidate continuations of `lookahead` tokens each
    2. Score all candidates using the provided reward function
    3. Select top-k highest scoring beams
    4. Check termination conditions (EOS, max length, max iterations)
    5. If not terminated, continue from the selected beams
    6. Return the highest-scoring complete generation

    Args:
        input_ids (torch.Tensor): Input token IDs of shape [1, seq_len].
            Currently only supports single prompts (batch size must be 1).
        attention_mask (torch.Tensor): Attention mask matching input_ids shape.
            Automatically recomputed during iteration based on padding tokens.
        runtime_kwargs (dict | None): Runtime parameters including:

            - "base_generate" (`Callable`, optional): Override the model's generate function
            - "reward_params" (`dict`, optional): Additional parameters passed to reward_func
        model (PreTrainedModel): The language model used for generation.
            Must match the model provided during steer().
        **gen_kwargs: Generation parameters passed to the underlying model.generate().
            Note: `max_new_tokens` is extracted and used as global limit; `num_beams` and `num_return_sequences` are
            overridden by DeAL parameters.

    Returns:
        torch.Tensor: Generated token IDs of shape [1, output_len] or [output_len].
            Contains the highest-scoring complete generation found during search.

    Raises:
        ValueError: If base_generate is not callable
        NotImplementedError: If input has batch size > 1 (multiple prompts not supported)
        RuntimeError: If reward function returns incorrect number of scores
    """
    runtime_kwargs = runtime_kwargs or {}

    reward_func = self.reward_func
    base_generate = runtime_kwargs.get("base_generate", self.base_generate)

    if not callable(base_generate):
        raise ValueError("'base_generate' must be callable; supplied or cached from steer().")

    # assert (
    #     self.model is not None and self.tokenizer is not None
    # ), "DeAL.steer() must run before generate()."

    if input_ids.dim() != 2 or input_ids.size(0) != 1:
        raise NotImplementedError("Current DeAL implementation handles one prompt at a time.")

    # record call‑specific objects
    self.prompt: str = self.tokenizer.decode(
        input_ids[0], skip_special_tokens=True
    )
    input_length = input_ids.size(1)

    reward_params = {
        **runtime_kwargs.get("reward_params", {}),
        "lookahead": self.lookahead,
        "init_beams": self.init_beams,
        "topk": self.topk,
        "max_iterations": self.max_iterations,
    }

    original_max_tokens: Optional[int] = gen_kwargs.pop("max_new_tokens", None)

    # search loop
    best_beam: torch.Tensor | None = None
    best_score = float("-inf")
    current_input_ids = input_ids
    iteration = 0

    while iteration < self.max_iterations:
        iteration += 1

        attention_mask = (current_input_ids != self.tokenizer.pad_token_id).long()
        gen_args = copy.deepcopy(gen_kwargs)
        gen_args.update(
            {
                "max_new_tokens": self.lookahead,
                "num_beams": self.init_beams,
                "num_return_sequences": self.init_beams,
                "attention_mask": attention_mask,
            }
        )

        # rollout + scoring
        scores, beams = self._lookahead_generation(
            current_input_ids,
            reward_func=reward_func,
            reward_params=reward_params,
            base_generate=base_generate,
            input_length=input_length,
            **gen_args,
        )

        # select top-k
        score_tensor = torch.tensor(scores, device=beams.device)
        topk = min(self.topk, score_tensor.numel())
        top_idx = torch.topk(score_tensor, topk).indices
        beams = beams[top_idx]
        scores = score_tensor[top_idx].tolist()

        # termination mask
        finished_flags = []
        for beam in beams:
            eos_hit = beam[...,-1] == self.tokenizer.eos_token_id
            len_hit = (
                    original_max_tokens is not None
                    and beam.size(0) - input_length >= original_max_tokens
            )
            finished_flags.append(bool(eos_hit or len_hit))

        # update best-so-far
        best_local = int(torch.argmax(torch.tensor(scores)))
        if scores[best_local] > best_score:
            best_score = scores[best_local]
            best_beam = beams[best_local]

        if all(finished_flags):
            break

        # prune unfinished beams for next round
        current_input_ids = beams[
            [i for i, f in enumerate(finished_flags) if not f]
        ]

    final_ids = best_beam if best_beam is not None else beams[0]
    return final_ids.unsqueeze(0) if final_ids.dim() == 1 else final_ids
steer(model, tokenizer=None, **_)

Lightweight preparation; attaches model, tokenizer, and generate to instance.

Source code in aisteer360/algorithms/output_control/deal/control.py
def steer(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer | None = None,
        **_
) -> PreTrainedModel:
    """Lightweight preparation; attaches model, tokenizer, and generate to instance."""
    self.model = model
    self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
    self.base_generate = model.generate
    return model
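A usage sketch for DeAL. The reward function follows the signature documented above ((prompt, continuations, reward_params) -> list[float]); the keyword-counting reward is purely illustrative, the import paths follow the "Source code" annotations, and the generate_text() call assumes generation kwargs are forwarded to the control:

from aisteer360.algorithms.core.steering_pipeline import SteeringPipeline
from aisteer360.algorithms.output_control.deal.control import DeAL  # assumed import path


def politeness_reward(prompt: str, continuations: list[str], reward_params: dict) -> list[float]:
    # Toy reward: count polite markers in each candidate continuation.
    markers = ("please", "thank", "sorry")
    return [float(sum(text.lower().count(marker) for marker in markers)) for text in continuations]


deal = DeAL(
    reward_func=politeness_reward,
    lookahead=4,
    init_beams=8,
    topk=4,
    max_iterations=10,
)

pipeline = SteeringPipeline(model_name_or_path="gpt2", controls=[deal])
pipeline.steer()

# max_new_tokens acts as the global length limit inside DeAL.generate().
print(pipeline.generate_text("Write a short reply to a customer complaint.", max_new_tokens=64))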
rad
args
control
GPT2RewardModel

Bases: Module

GPT-2 based reward model for scoring text toxicity or other attributes.

Modified GPT-2 architecture where the language modeling head is replaced with a classification head. Used to score text sequences for desired attributes during RAD-guided generation.

Parameters:

Name Type Description Default
reward_model_name str

Base GPT-2 model variant to use. Defaults to "gpt2".

'gpt2'
out_features int

Number of output classes/attributes. Defaults to 1.

1
Source code in aisteer360/algorithms/output_control/rad/control.py
class GPT2RewardModel(nn.Module):
    """GPT-2 based reward model for scoring text toxicity or other attributes.

    Modified GPT-2 architecture where the language modeling head is replaced with a classification head. Used to score
    text sequences for desired attributes during RAD-guided generation.

    Args:
        reward_model_name (str): Base GPT-2 model variant to use. Defaults to "gpt2".
        out_features (int): Number of output classes/attributes. Defaults to 1.
    """
    def __init__(self, reward_model_name="gpt2", out_features=1, cache_dir='./'):
        super(GPT2RewardModel, self).__init__()
        model = GPT2LMHeadModel.from_pretrained(reward_model_name, cache_dir=cache_dir)
        model.lm_head = nn.Linear(in_features=model.lm_head.in_features, out_features=out_features, bias=True)
        self.model = model
        self.pad_token_id = model.config.eos_token_id
        self.out_features = out_features

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
    ):
        """Forward pass through reward model.

        Processes input through GPT-2 backbone and returns scores from the classification head at the last valid token
        position for each sequence.

        Args:
            input_ids: Token IDs of shape [batch_size, seq_len].
            past_key_values: Cached key-value pairs for efficient generation.
            attention_mask: Attention mask for padding.
            token_type_ids: Token type IDs (unused for GPT-2).
            position_ids: Position embeddings.
            head_mask: Attention head mask.

        Returns:
            torch.Tensor: Classification scores of shape [batch_size, out_features].
                Extracted from the last non-padding position of each sequence.
        """
        outputs = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
        )
        logits = outputs['logits']
        # find the last valid token's ids
        sequence_lengths = (torch.ne(input_ids, self.pad_token_id).sum(-1) - 1).to(logits.device)
        # use the last valid token's representation: (batch, max_length, out_features) => (batch, out_features)
        scores = logits[torch.arange(input_ids.shape[0], device=logits.device), sequence_lengths]

        return scores
model = model instance-attribute
out_features = out_features instance-attribute
pad_token_id = model.config.eos_token_id instance-attribute
forward(input_ids=None, past_key_values=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None)

Forward pass through reward model.

Processes input through GPT-2 backbone and returns scores from the classification head at the last valid token position for each sequence.

Parameters:

Name Type Description Default
input_ids Optional[Tensor]

Token IDs of shape [batch_size, seq_len].

None
past_key_values Optional[Tuple[FloatTensor]]

Cached key-value pairs for efficient generation.

None
attention_mask Optional[Tensor]

Attention mask for padding.

None
token_type_ids Optional[Tensor]

Token type IDs (unused for GPT-2).

None
position_ids Optional[Tensor]

Position embeddings.

None
head_mask Optional[Tensor]

Attention head mask.

None

Returns:

Type Description

torch.Tensor: Classification scores of shape [batch_size, out_features]. Extracted from the last non-padding position of each sequence.

Source code in aisteer360/algorithms/output_control/rad/control.py
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
):
    """Forward pass through reward model.

    Processes input through GPT-2 backbone and returns scores from the classification head at the last valid token
    position for each sequence.

    Args:
        input_ids: Token IDs of shape [batch_size, seq_len].
        past_key_values: Cached key-value pairs for efficient generation.
        attention_mask: Attention mask for padding.
        token_type_ids: Token type IDs (unused for GPT-2).
        position_ids: Position embeddings.
        head_mask: Attention head mask.

    Returns:
        torch.Tensor: Classification scores of shape [batch_size, out_features].
            Extracted from the last non-padding position of each sequence.
    """
    outputs = self.model(
        input_ids=input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
    )
    logits = outputs['logits']
    # find the last valid token's ids
    sequence_lengths = (torch.ne(input_ids, self.pad_token_id).sum(-1) - 1).to(logits.device)
    # use the last valid token's representation: (batch, max_length, out_features) => (batch, out_features)
    scores = logits[torch.arange(input_ids.shape[0], device=logits.device), sequence_lengths]

    return scores
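A sketch of scoring text with GPT2RewardModel directly, mirroring how RAD.steer() prepares the reward tokenizer (GPT-2 tokenizer with the pad token set to EOS). Loading the pretrained reward weights from reward_path is omitted here, so the randomly initialized head produces meaningless scores; the example only shows the tensor flow:

import torch
from transformers import AutoTokenizer

from aisteer360.algorithms.output_control.rad.control import GPT2RewardModel  # assumed import path

rm_tokenizer = AutoTokenizer.from_pretrained("gpt2")
rm_tokenizer.pad_token = rm_tokenizer.eos_token

rm = GPT2RewardModel(reward_model_name="gpt2", out_features=7)  # 7 heads, matching RAD's toxicity model
rm.eval()

batch = rm_tokenizer(["a perfectly harmless sentence"], return_tensors="pt", padding=True)
with torch.no_grad():
    scores = rm(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

print(scores.shape)  # torch.Size([1, 7]): one score per classification head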
RAD

Bases: OutputControl

Implementation of RAD (Reward-Augmented Decoding) from Deng and Raffel, 2023. Integrated from the official implementation of RAD (https://github.com/r-three/RAD?tab=readme-ov-file).

RAD works in two phases:

  1. Reward model training: Train a reward model with a labeled dataset containing texts and labels. For details about this step, please see https://github.com/r-three/RAD?tab=readme-ov-file. We skip this step in this implementation and reuse the open-source toxicity reward model trained by the authors, downloaded via gdown from https://storage.googleapis.com/rad_release/saved_models.zip

  2. Controlled decoding: At every decoding step the candidate-token logits are shifted by beta * reward, where the reward is given by a trained reward model.

Parameters:

Name Type Description Default
beta float

Steering intensity. Defaults to 0.0.

required
reward_path str

Path to the trained reward model. See https://github.com/r-three/RAD for details. Defaults to None.

required

Reference:

  • "Reward-Augmented Decoding: Efficient Controlled Text Generation With a Unidirectional Reward Model" Haikang Deng, Colin Raffel https://arxiv.org/abs/2310.09520
Source code in aisteer360/algorithms/output_control/rad/control.py
class RAD(OutputControl):
    """
    Implementation of RAD (Reward-Augmented Decoding) from Deng and Raffel, 2023.
    Integrated from the official implementation of RAD ([https://github.com/r-three/RAD?tab=readme-ov-file](https://github.com/r-three/RAD?tab=readme-ov-file)).

    RAD works in two phases:

    1. **Reward model training**: Train a reward model with a labeled dataset containing texts and labels.
    For details about this step, please see [https://github.com/r-three/RAD?tab=readme-ov-file](https://github.com/r-three/RAD?tab=readme-ov-file). We skip this
    step in this implementation and reuse the open-source toxicity reward model trained by the authors, downloaded via
    gdown from [https://storage.googleapis.com/rad_release/saved_models.zip](https://storage.googleapis.com/rad_release/saved_models.zip)

    2. **Controlled decoding**: At every decoding step the candidate-token logits are shifted by **beta * reward**,
    where the *reward* is given by a trained reward model.

    Args:
        beta (float): Steering intensity. Defaults to 0.0.
        reward_path (str, optional): Path to the trained reward model. See [https://github.com/r-three/RAD](https://github.com/r-three/RAD) for details. Defaults to None.

    Reference:

    - "Reward-Augmented Decoding: Efficient Controlled Text Generation With a Unidirectional Reward Model" Haikang Deng,
     Colin Raffel
     [https://arxiv.org/abs/2310.09520](https://arxiv.org/abs/2310.09520)
    """
    Args = RADArgs

    # placeholders (filled by steer)
    model: PreTrainedModel | None = None
    tokenizer: PreTrainedTokenizer | None = None
    base_generate: Callable | None = None

    beta: float

    def steer(
            self,
            model: PreTrainedModel,
            tokenizer: PreTrainedTokenizer | None = None,
            **__,
    ) -> PreTrainedModel:
        """Initialize RAD by loading and configuring the reward model.

        Sets up the toxicity reward model used for steering during generation. Automatically downloads the model
        from the RAD repository if not found locally.

        Args:
            model (PreTrainedModel): The base language model to be steered.
            tokenizer (PreTrainedTokenizer | None): Tokenizer for the base model.
                If None, attempts to retrieve from model attributes.
            **__: Additional arguments (unused).

        Returns:
            PreTrainedModel: The input model, unchanged.

        Note:

        - Downloads ~500MB reward model on first use if not cached
        - Reward model is GPT2-based with 7 toxicity classification heads
        - Model weights are loaded onto the same device as the base model
        """
        self.model = model
        self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
        self.base_generate = model.generate
        self.device = next(model.parameters()).device

        # load reward model from rad
        self.rm_tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir=self.reward_path)
        self.rm_tokenizer.pad_token = self.rm_tokenizer.eos_token
        self.rm_tokenizer.padding_side = 'right'
        self.rm_tokenizer.max_length = 1024
        import os
        if (self.reward_path is None) or not os.path.exists(os.path.join(self.reward_path, "pytorch_model.bin")):
            print(f"Reward model not found in: {self.reward_path}. Downloading from https://github.com/r-three/RAD......")
            import zipfile
            try:
                import gdown
            except ImportError:
                import subprocess
                import sys
                subprocess.check_call([sys.executable, "-m", "pip", "install", "gdown"])
                import gdown
            gdown.download('https://storage.googleapis.com/rad_release/saved_models.zip', output='./tmp/rad_saved_models.zip', quiet=False)
            with zipfile.ZipFile("./tmp/rad_saved_models.zip","r") as f:
                f.extractall('./tmp/rad_saved_models')
            print("Reward model downloaded. Please set reward_path='./tmp/rad_saved_models/saved_models/gpt2_toxicity' in the future.")
        else:
            print(f"Reward model found in: {self.reward_path}")
        if self.reward_path is None:
            self.reward_path = './tmp/rad_saved_models/saved_models/gpt2_toxicity'
        state_dict = torch.load(os.path.join(self.reward_path, "pytorch_model.bin"), map_location="cpu")
        self.rm = GPT2RewardModel(reward_model_name="gpt2", out_features=7, cache_dir=self.reward_path)
        self.rm.load_state_dict(state_dict, strict=False)
        self.rm = self.rm.to(self.device)
        print("Reward model is loaded.")

        return model

    @torch.no_grad()
    def generate(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            runtime_kwargs: dict | None,
            model: PreTrainedModel,
            **gen_kwargs,
    ) -> torch.Tensor:
        """Execute RAD-guided generation with reward-augmented logits processing.

        Performs controlled generation by shifting token logits at each decoding step based on reward model scores.
        Returns generated text steered toward desired behavior.

        At each decoding step:

        1. Generate top-k candidate next tokens
        2. Score each candidate continuation with the reward model
        3. Adjust logits by beta * reward_score
        4. Sample from adjusted distribution

        Args:
            input_ids (torch.Tensor): Input token IDs of shape [batch_size, seq_len].
            attention_mask (torch.Tensor): Attention mask matching input_ids shape.
            runtime_kwargs (dict | None): Runtime parameters (currently unused).
            model (PreTrainedModel): The language model used for generation.
                Must match the model provided during steer().
            **gen_kwargs: Generation parameters passed to model.generate():

                - "temperature" (`float`, optional): Sampling temperature. Defaults to 1.0.
                - "top_k" (`int`, optional): Top-k filtering. Defaults to 0 (disabled).
                - "top_p" (`float`, optional): Nucleus sampling threshold. Defaults to 1.0.
                - "repetition_penalty" (`float`, optional): Penalty for repeated tokens. Defaults to 1.0.
                - Other standard generation arguments (max_length, pad_token_id, etc.)

        Returns:
            torch.Tensor: Generated token IDs with same batch dimension as input.

        Note:

        - Requires reward model to be loaded during steer() phase
        - When both top_k and top_p are specified, top_k takes precedence for RAD processing
        - Reward scores are clamped to [0, 1] and inverted (1 - score) for toxicity reduction
        - Non-top-k tokens are set to -inf to ensure selection from reward-adjusted candidates
        """

        runtime_kwargs = runtime_kwargs or {}
        beta = self.beta

        processors = LogitsProcessorList()
        temperature = gen_kwargs.get("temperature", 1.0)
        if temperature and temperature != 1.0:
            processors.append(TemperatureLogitsWarper(temperature))

        top_k = gen_kwargs.get("top_k", 0)
        if top_k and top_k > 0:
            processors.append(TopKLogitsWarper(top_k))
            rad_topk = top_k
            rad_topp = 1

        top_p = gen_kwargs.get("top_p", 1.0)
        if top_p and top_p < 1.0:
            processors.append(TopPLogitsWarper(top_p))
            rad_topp = top_p
            rad_topk = None

        repetition_penalty = gen_kwargs.get("repetition_penalty", 1.0)
        if repetition_penalty and repetition_penalty != 1.0:
            processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))

        processors.append(
            RewardAugmentedLogitsProcessorNoPkv(
                        self.tokenizer,
                        self.rm_tokenizer,
                        self.rm,
                        topk=rad_topk,
                        topp=rad_topp,
                        method="linear",
                        beta=beta,
                        inverse=True,
                    )
        )

        # generate candidates
        output = self.base_generate(input_ids=input_ids, attention_mask=attention_mask, logits_processor=processors, **gen_kwargs)
        return output
args = self.Args.validate(*args, **kwargs) instance-attribute
base_generate = None class-attribute instance-attribute
beta instance-attribute
enabled = True class-attribute instance-attribute
model = None class-attribute instance-attribute
tokenizer = None class-attribute instance-attribute
generate(input_ids, attention_mask, runtime_kwargs, model, **gen_kwargs)

Execute RAD-guided generation with reward-augmented logits processing.

Performs controlled generation by shifting token logits at each decoding step based on reward model scores. Returns generated text steered toward desired behavior.

At each decoding step:

  1. Generate top-k candidate next tokens
  2. Score each candidate continuation with the reward model
  3. Adjust logits by beta * reward_score
  4. Sample from adjusted distribution

Parameters:

Name Type Description Default
input_ids Tensor

Input token IDs of shape [batch_size, seq_len].

required
attention_mask Tensor

Attention mask matching input_ids shape.

required
runtime_kwargs dict | None

Runtime parameters (currently unused).

required
model PreTrainedModel

The language model used for generation. Must match the model provided during steer().

required
**gen_kwargs

Generation parameters passed to model.generate():

  • "temperature" (float, optional): Sampling temperature. Defaults to 1.0.
  • "top_k" (int, optional): Top-k filtering. Defaults to 0 (disabled).
  • "top_p" (float, optional): Nucleus sampling threshold. Defaults to 1.0.
  • "repetition_penalty" (float, optional): Penalty for repeated tokens. Defaults to 1.0.
  • Other standard generation arguments (max_length, pad_token_id, etc.)
{}

Returns:

Type Description
Tensor

torch.Tensor: Generated token IDs with same batch dimension as input.

Note:

  • Requires reward model to be loaded during steer() phase
  • When both top_k and top_p are specified, top_k takes precedence for RAD processing
  • Reward scores are clamped to [0, 1] and inverted (1 - score) for toxicity reduction
  • Non-top-k tokens are set to -inf to ensure selection from reward-adjusted candidates
Source code in aisteer360/algorithms/output_control/rad/control.py
@torch.no_grad()
def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        runtime_kwargs: dict | None,
        model: PreTrainedModel,
        **gen_kwargs,
) -> torch.Tensor:
    """Execute RAD-guided generation with reward-augmented logits processing.

    Performs controlled generation by shifting token logits at each decoding step based on reward model scores.
    Returns generated text steered toward desired behavior.

    At each decoding step:

    1. Generate top-k candidate next tokens
    2. Score each candidate continuation with the reward model
    3. Adjust logits by beta * reward_score
    4. Sample from adjusted distribution

    Args:
        input_ids (torch.Tensor): Input token IDs of shape [batch_size, seq_len].
        attention_mask (torch.Tensor): Attention mask matching input_ids shape.
        runtime_kwargs (dict | None): Runtime parameters (currently unused).
        model (PreTrainedModel): The language model used for generation.
            Must match the model provided during steer().
        **gen_kwargs: Generation parameters passed to model.generate():

            - "temperature" (`float`, optional): Sampling temperature. Defaults to 1.0.
            - "top_k" (`int`, optional): Top-k filtering. Defaults to 0 (disabled).
            - "top_p" (`float`, optional): Nucleus sampling threshold. Defaults to 1.0.
            - "repetition_penalty" (`float`, optional): Penalty for repeated tokens. Defaults to 1.0.
            - Other standard generation arguments (max_length, pad_token_id, etc.)

    Returns:
        torch.Tensor: Generated token IDs with same batch dimension as input.

    Note:

    - Requires reward model to be loaded during steer() phase
    - When both top_k and top_p are specified, top_k takes precedence for RAD processing
    - Reward scores are clamped to [0, 1] and inverted (1 - score) for toxicity reduction
    - Non-top-k tokens are set to -inf to ensure selection from reward-adjusted candidates
    """

    runtime_kwargs = runtime_kwargs or {}
    beta = self.beta

    processors = LogitsProcessorList()
    temperature = gen_kwargs.get("temperature", 1.0)
    if temperature and temperature != 1.0:
        processors.append(TemperatureLogitsWarper(temperature))

    top_k = gen_kwargs.get("top_k", 0)
    if top_k and top_k > 0:
        processors.append(TopKLogitsWarper(top_k))
        rad_topk = top_k
        rad_topp = 1

    top_p = gen_kwargs.get("top_p", 1.0)
    if top_p and top_p < 1.0:
        processors.append(TopPLogitsWarper(top_p))
        rad_topp = top_p
        rad_topk = None

    repetition_penalty = gen_kwargs.get("repetition_penalty", 1.0)
    if repetition_penalty and repetition_penalty != 1.0:
        processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))

    processors.append(
        RewardAugmentedLogitsProcessorNoPkv(
                    self.tokenizer,
                    self.rm_tokenizer,
                    self.rm,
                    topk=rad_topk,
                    topp=rad_topp,
                    method="linear",
                    beta=beta,
                    inverse=True,
                )
    )

    # generate candidates
    output = self.base_generate(input_ids=input_ids, attention_mask=attention_mask, logits_processor=processors, **gen_kwargs)
    return output
steer(model, tokenizer=None, **__)

Initialize RAD by loading and configuring the reward model.

Sets up the toxicity reward model used for steering during generation. Automatically downloads the model from the RAD repository if not found locally.

Parameters:

Name Type Description Default
model PreTrainedModel

The base language model to be steered.

required
tokenizer PreTrainedTokenizer | None

Tokenizer for the base model. If None, attempts to retrieve from model attributes.

None
**__

Additional arguments (unused).

{}

Returns:

Name Type Description
PreTrainedModel PreTrainedModel

The input model, unchanged.

Note:

  • Downloads ~500MB reward model on first use if not cached
  • Reward model is GPT2-based with 7 toxicity classification heads
  • Model weights are loaded onto the same device as the base model
Source code in aisteer360/algorithms/output_control/rad/control.py
def steer(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer | None = None,
        **__,
) -> PreTrainedModel:
    """Initialize RAD by loading and configuring the reward model.

    Sets up the toxicity reward model used for steering during generation. Automatically downloads the model
    from the RAD repository if not found locally.

    Args:
        model (PreTrainedModel): The base language model to be steered.
        tokenizer (PreTrainedTokenizer | None): Tokenizer for the base model.
            If None, attempts to retrieve from model attributes.
        **__: Additional arguments (unused).

    Returns:
        PreTrainedModel: The input model, unchanged.

    Note:

    - Downloads ~500MB reward model on first use if not cached
    - Reward model is GPT2-based with 7 toxicity classification heads
    - Model weights are loaded onto the same device as the base model
    """
    self.model = model
    self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
    self.base_generate = model.generate
    self.device = next(model.parameters()).device

    # load reward model from rad
    self.rm_tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir=self.reward_path)
    self.rm_tokenizer.pad_token = self.rm_tokenizer.eos_token
    self.rm_tokenizer.padding_side = 'right'
    self.rm_tokenizer.max_length = 1024
    import os
    if (self.reward_path is None) or not os.path.exists(os.path.join(self.reward_path, "pytorch_model.bin")):
        print(f"Reward model not found in: {self.reward_path}. Downloading from https://github.com/r-three/RAD......")
        import zipfile
        try:
            import gdown
        except ImportError:
            import subprocess
            import sys
            subprocess.check_call([sys.executable, "-m", "pip", "install", "gdown"])
            import gdown
        gdown.download('https://storage.googleapis.com/rad_release/saved_models.zip', output='./tmp/rad_saved_models.zip', quiet=False)
        with zipfile.ZipFile("./tmp/rad_saved_models.zip","r") as f:
            f.extractall('./tmp/rad_saved_models')
        print("Reward model downloaded. Please set reward_path='./tmp/rad_saved_models/saved_models/gpt2_toxicity' in the future.")
    else:
        print(f"Reward model found in: {self.reward_path}")
    if self.reward_path is None:
        self.reward_path = './tmp/rad_saved_models/saved_models/gpt2_toxicity'
    state_dict = torch.load(os.path.join(self.reward_path, "pytorch_model.bin"), map_location="cpu")
    self.rm = GPT2RewardModel(reward_model_name="gpt2", out_features=7, cache_dir=self.reward_path)
    self.rm.load_state_dict(state_dict, strict=False)
    self.rm = self.rm.to(self.device)
    print("Reward model is loaded.")

    return model
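
The snippet below is a minimal usage sketch for this method, not part of the library source. The import path mirrors the "Source code in" location above, the reward_path keyword is inferred from the attributes referenced in steer(), and any remaining constructor arguments are assumed to take their defaults.

from transformers import AutoModelForCausalLM, AutoTokenizer

from aisteer360.algorithms.output_control.rad.control import RAD  # assumed import path

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# On the first run the ~500MB reward checkpoint is downloaded and extracted into ./tmp/;
# on later runs it is loaded directly from this path (see the print statements in steer()).
rad = RAD(reward_path="./tmp/rad_saved_models/saved_models/gpt2_toxicity")
model = rad.steer(model, tokenizer)
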
RewardAugmentedLogitsProcessorNoPkv

Bases: LogitsProcessor

Logits processor that adjusts token probabilities based on reward model scores.

Implements the core RAD algorithm by evaluating candidate tokens with a reward model and shifting their logits proportionally to the reward scores. Designed to work with transformers' generate() method as part of a LogitsProcessorList.

Parameters:

Name Type Description Default
lm_tokenizer

Tokenizer for the language model being steered.

required
rm_tokenizer

Tokenizer for the reward model (typically GPT-2).

required
reward_model

Trained reward model that scores text for desired attributes.

required
topk int

Number of candidate tokens to evaluate. Defaults to 20.

20
topp float

Nucleus sampling threshold if using top-p instead of top-k. Defaults to 1.

1
method str

Reward application method. Currently only "linear" supported. Defaults to "linear".

'linear'
beta float

Scaling factor for reward scores. Higher values = stronger steering. Defaults to 30.

30
inverse bool

Whether to invert reward scores (1 - score). Used for toxicity reduction. Defaults to False.

False
Source code in aisteer360/algorithms/output_control/rad/control.py
class RewardAugmentedLogitsProcessorNoPkv(LogitsProcessor):
    """Logits processor that adjusts token probabilities based on reward model scores.

    Implements the core RAD algorithm by evaluating candidate tokens with a reward model and shifting their logits
    proportionally to the reward scores. Designed to work with transformers' generate() method as part of a
    `LogitsProcessorList`.

    Args:
        lm_tokenizer: Tokenizer for the language model being steered.
        rm_tokenizer: Tokenizer for the reward model (typically GPT-2).
        reward_model: Trained reward model that scores text for desired attributes.
        topk (int): Number of candidate tokens to evaluate. Defaults to 20.
        topp (float): Nucleus sampling threshold if using top-p instead of top-k. Defaults to 1.
        method (str): Reward application method. Currently only "linear" supported. Defaults to "linear".
        beta (float): Scaling factor for reward scores. Higher values = stronger steering. Defaults to 30.
        inverse (bool): Whether to invert reward scores (1 - score). Used for toxicity reduction. Defaults to False.
    """
    def __init__(self, lm_tokenizer, rm_tokenizer, reward_model, topk=20, topp=1,
                 method="linear", beta=30, inverse=False):
        self._lm_tokenizer = lm_tokenizer
        self._rm_tokenizer = rm_tokenizer
        self._reward_model = reward_model
        self._device = next(self._reward_model.parameters()).device
        self._reward_model.eval()
        self._topk = topk
        self._topp = topp
        self._method = method
        self._beta = beta
        self._inverse = inverse

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Apply reward-based adjustments to token logits.

        For each position in the batch, evaluates top-k candidate tokens by constructing full text sequences, scoring
        them with the reward model, and adjusting logits.

        Args:
            input_ids (torch.LongTensor): Current token sequence of shape [batch_size, seq_len].
            scores (torch.FloatTensor): Raw logits for next token of shape [batch_size, vocab_size].

        Returns:
            torch.FloatTensor: Adjusted logits with reward-based modifications.
                Non-candidate tokens are set to -inf to ensure sampling from evaluated tokens only.

        Note:
            - Dynamically switches between top-k and top-p candidate selection
            - Constructs full text for each candidate to enable proper reward model evaluation
            - Memory usage scales with batch_size * topk for candidate evaluation
        """
        if self._topp < 1:
            ## top p modification, batch=1
            sorted_logits, sorted_indices = torch.sort(scores, descending=False)
            cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
            sorted_indices_to_keep = cumulative_probs > (1 - self._topp)
            indices_to_keep = sorted_indices_to_keep.scatter(1, sorted_indices, sorted_indices_to_keep)
            topk_ids = torch.nonzero(indices_to_keep)[:,1].unsqueeze(0)
            self._topk = topk_ids.shape[1]
            del sorted_logits, sorted_indices, cumulative_probs, sorted_indices_to_keep, indices_to_keep
            torch.cuda.empty_cache()  # Ensure immediate deallocation
        else:
            _, topk_ids = torch.topk(scores, self._topk, dim=-1)                                    # (batch, topk,)
        input_ids_enflated = input_ids.unsqueeze(1).expand((-1, self._topk, -1))                # (batch, topk, seq_len)
        candidate_input_ids = torch.cat((input_ids_enflated, topk_ids.unsqueeze(-1)), dim=-1)   # (batch, topk, seq_len+1)
        candidate_input_ids_unroll = candidate_input_ids.reshape((
            candidate_input_ids.shape[0]*candidate_input_ids.shape[1], -1))         # (batch*topk, seq_len+1)
        candidate_input_texts = self._lm_tokenizer.batch_decode(candidate_input_ids_unroll, skip_special_tokens=True)

        # return reward scores
        reward_scores = self.get_reward(candidate_input_texts).reshape((input_ids.shape[0], -1))

        # apply function (topk_scores, logits)
        for score, id, rs in zip(scores, topk_ids, reward_scores):

            score[id] = self.apply_function(score[id], rs)
            inverse_id = torch.tensor(np.setdiff1d(range(len(score.cpu().numpy())), id.cpu().numpy()), device=self._device)
            score[inverse_id] = -float("Inf")  # set all other scores to -inf
        return scores

    def get_reward(self, candidate_texts):
        """Score candidate text sequences with the reward model.

        Args:
            candidate_texts: List of text strings to evaluate.

        Returns:
            torch.Tensor: Reward scores for each candidate, extracted from first output head.
        """
        with torch.inference_mode():
            # tokenizer should be configured in RAD
            input_ids = self._rm_tokenizer.batch_encode_plus(
                candidate_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self._rm_tokenizer.max_length,
            ).to(self._device)

            reward = self._reward_model(**input_ids)
            return reward[:,0]

    def apply_function(self, original_score, reward_score):
        """Apply reward adjustment to original logits.

        Args:
            original_score: Original logit values for candidate tokens.
            reward_score: Reward model scores for candidates.

        Returns:
            torch.Tensor: Adjusted logits computed as original + (beta * reward).

        Raises:
            ValueError: If method is not "linear".

        Note:

        - Reward scores are clamped to [0, 1] before application.
        """
        reward_score = torch.clamp(reward_score, min=0, max=1)
        if self._inverse:
            reward_score = 1-reward_score
        if self._method == "linear":
            return original_score + (reward_score*self._beta).to(original_score.dtype)
        else:
            raise ValueError(f"method {self._method} not supported")
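
A minimal sketch of wiring this processor into transformers' generate(), not part of the library source. The import path is assumed from the "Source code in" location above; rad.rm and rad.rm_tokenizer are the attributes prepared by RAD.steer().

from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

from aisteer360.algorithms.output_control.rad.control import (  # assumed import path
    RAD,
    RewardAugmentedLogitsProcessorNoPkv,
)

lm = AutoModelForCausalLM.from_pretrained("gpt2")
lm_tokenizer = AutoTokenizer.from_pretrained("gpt2")

rad = RAD(reward_path="./tmp/rad_saved_models/saved_models/gpt2_toxicity")
rad.steer(lm, lm_tokenizer)  # loads the GPT-2 reward model and its tokenizer

processor = RewardAugmentedLogitsProcessorNoPkv(
    lm_tokenizer=lm_tokenizer,
    rm_tokenizer=rad.rm_tokenizer,
    reward_model=rad.rm,
    topk=20,
    beta=30,
    inverse=True,  # steer away from toxicity (reward = 1 - score)
)

inputs = lm_tokenizer("The weather today is", return_tensors="pt")
out = lm.generate(
    **inputs,
    do_sample=True,
    max_new_tokens=30,
    logits_processor=LogitsProcessorList([processor]),
)
print(lm_tokenizer.decode(out[0], skip_special_tokens=True))
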
apply_function(original_score, reward_score)

Apply reward adjustment to original logits.

Parameters:

Name Type Description Default
original_score

Original logit values for candidate tokens.

required
reward_score

Reward model scores for candidates.

required

Returns:

Type Description

torch.Tensor: Adjusted logits computed as original + (beta * reward).

Raises:

Type Description
ValueError

If method is not "linear".

Note:

  • Reward scores are clamped to [0, 1] before application.
Source code in aisteer360/algorithms/output_control/rad/control.py
def apply_function(self, original_score, reward_score):
    """Apply reward adjustment to original logits.

    Args:
        original_score: Original logit values for candidate tokens.
        reward_score: Reward model scores for candidates.

    Returns:
        torch.Tensor: Adjusted logits computed as original + (beta * reward).

    Raises:
        ValueError: If method is not "linear".

    Note:

    - Reward scores are clamped to [0, 1] before application.
    """
    reward_score = torch.clamp(reward_score, min=0, max=1)
    if self._inverse:
        reward_score = 1-reward_score
    if self._method == "linear":
        return original_score + (reward_score*self._beta).to(original_score.dtype)
    else:
        raise ValueError(f"method {self._method} not supported")
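
For concreteness, a small illustrative calculation of the linear rule (the values are arbitrary):

import torch

original = torch.tensor([2.0, 1.5, -0.3])  # logits of three candidate tokens
reward = torch.tensor([0.9, 0.1, 1.4])     # raw reward-model scores (toxicity when inverse=True)

reward = reward.clamp(0, 1)                # -> [0.9, 0.1, 1.0]
reward = 1 - reward                        # inverse=True (toxicity reduction) -> [0.1, 0.9, 0.0]
adjusted = original + 30 * reward          # beta = 30 -> [5.0, 28.5, -0.3]
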
get_reward(candidate_texts)

Score candidate text sequences with the reward model.

Parameters:

Name Type Description Default
candidate_texts

List of text strings to evaluate.

required

Returns:

Type Description

torch.Tensor: Reward scores for each candidate, extracted from first output head.

Source code in aisteer360/algorithms/output_control/rad/control.py
def get_reward(self, candidate_texts):
    """Score candidate text sequences with the reward model.

    Args:
        candidate_texts: List of text strings to evaluate.

    Returns:
        torch.Tensor: Reward scores for each candidate, extracted from first output head.
    """
    with torch.inference_mode():
        # tokenizer should be configured in RAD
        input_ids = self._rm_tokenizer.batch_encode_plus(
            candidate_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self._rm_tokenizer.max_length,
        ).to(self._device)

        reward = self._reward_model(**input_ids)
        return reward[:,0]
sasa
args
control
SASA

Bases: OutputControl

Implementation of SASA (Self-disciplined autoregressive sampling) from Ko et al., 2024.

SASA works in two phases:

  1. Subspace learning: From a labelled toxic / non-toxic corpus, it fits a linear classifier in the model’s own sentence-embedding space; the weight vector defines a toxicity subspace.

  2. Controlled decoding: At every decoding step the candidate-token logits are shifted by beta * margin, where margin is the classifier distance of the updated context from the toxic side of the subspace. Sampling from the soft-max of these adjusted logits (optionally with nucleus sampling) nudges generation away from toxic regions while staying close to the original distribution.

Parameters:

Name Type Description Default
beta float

Scaling coefficient for value redistribution. Defaults to 0.0.

required
wv_path str

Path to a saved steering-vector tensor. Defaults to None.

required
gen_wv_data_path str

Path to the value dataset, e.g. sentences with labeled toxicity. Defaults to "Jigsaw_data/".

required
gen_wv_length int

The maximum number of samples used for preparing SASA steering if wv_path does not exist. Defaults to -1 (use all).

required
gen_wv_batch_size int

The batch size used for preparing SASA steering if wv_path does not exist. Defaults to 4.

required

Reference:

  • "Large Language Models can Become Strong Self-Detoxifiers" Ching-Yun Ko, Pin-Yu Chen, Payel Das, Youssef Mroueh, Soham Dan, Georgios Kollias, Subhajit Chaudhury, Tejaswini Pedapati, Luca Daniel https://arxiv.org/abs/2410.03818
Source code in aisteer360/algorithms/output_control/sasa/control.py
class SASA(OutputControl):
    """
    Implementation of SASA (Self-disciplined autoregressive sampling) from Ko et al., 2024.

    SASA works in two phases:

    1. **Subspace learning**: From a labelled toxic / non-toxic corpus, it fits a linear classifier in the model’s
    own sentence-embedding space; the weight vector defines a toxicity subspace.

    2. **Controlled decoding**: At every decoding step the candidate-token logits are shifted by **beta * margin**,
    where *margin* is the classifier distance of the updated context from the toxic side of the subspace.  Sampling
    from the soft-max of these adjusted logits (optionally with nucleus sampling) nudges generation away from
    toxic regions while staying close to the original distribution.

    Args:
        beta (float): Scaling coefficient for value redistribution. Defaults to 0.0.
        wv_path (str, optional): Path to a saved steering-vector tensor. Defaults to None.
        gen_wv_data_path (str, optional): Path to the value dataset, e.g. sentences with labeled toxicity. Defaults to "Jigsaw_data/".
        gen_wv_length (int, optional): The maximum number of samples used for preparing SASA steering if wv_path does not exist. Defaults to -1 (use all).
        gen_wv_batch_size (int, optional): The batch size used for preparing SASA steering if wv_path does not exist. Defaults to 4.

    Reference:

    - "Large Language Models can Become Strong Self-Detoxifiers"
      Ching-Yun Ko, Pin-Yu Chen, Payel Das, Youssef Mroueh, Soham Dan, Georgios Kollias, Subhajit Chaudhury,
      Tejaswini Pedapati, Luca Daniel
      [https://arxiv.org/abs/2410.03818](https://arxiv.org/abs/2410.03818)
    """
    Args = SASAArgs

    # placeholders (filled by steer)
    model: PreTrainedModel | None = None
    tokenizer: PreTrainedTokenizer | None = None
    base_generate: Callable | None = None

    beta: float
    wv: torch.Tensor | None

    def steer(
            self,
            model: PreTrainedModel,
            tokenizer: PreTrainedTokenizer | None = None,
            **__,
    ) -> PreTrainedModel:
        """Initialize SASA by loading or generating the toxicity steering vector.

        Sets up the linear classifier in the model's embedding space that defines the toxicity subspace. Either loads a
        pre-computed steering vector or generates one from labeled data.

        Args:
            model (PreTrainedModel): The base language model to be steered.
            tokenizer (PreTrainedTokenizer | None): Tokenizer for the base model.
                If None, attempts to retrieve from model attributes.
            **__: Additional arguments (unused).

        Returns:
            PreTrainedModel: The input model (unchanged).

        Raises:
            FileNotFoundError: If gen_wv_data_path doesn't contain required Jigsaw dataset

        Note:

        - If wv_path is provided, loads pre-computed steering vector
        - Otherwise generates steering vector from Jigsaw toxicity dataset
        - Steering vector generation uses closed-form Bayes optimal classifier
        - Saves generated steering vector to 'steer_wv.pt' for future use
        """
        self.model = model
        self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
        if self.tokenizer.pad_token_id is None:
            print("pad_token is absent. Setting it to eos_token or '<pad>'.")
            if self.tokenizer.eos_token_id is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            else:  # edge case
                self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
        if self.model.generation_config.pad_token_id is None:
            self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id
            self.model.config.pad_token_id = self.tokenizer.pad_token_id  # keep config in sync with the tokenizer's pad token
        self.base_generate = model.generate
        self.device = next(model.parameters()).device
        if getattr(self, "wv_path", None):
            print("Loading SASA steer (wv)......")
            self.wv = torch.load(self.wv_path, map_location="cpu")
        else:
            print("Creating SASA steer (wv)......")
            self._setup_wv()
            # self.wv =  {k: v.cpu() for k, v in self.wv.item().items()}
            torch.save(self.wv, 'steer_wv.pt')
        self.wv = {key: value.to(self.device) for key, value in self.wv.items()}
        return model

    def _setup_wv(self):
        """Generate steering vector from labeled toxicity data.

        Loads the Jigsaw toxicity dataset and learns a linear classifier in the model's embedding space using a
        closed-form Bayes optimal solution. The resulting weight vector defines the toxicity subspace used during
        generation.

        Process:
        1. Load toxic and non-toxic sentences from Jigsaw dataset
        2. Generate sentence embeddings using the model's last hidden states
        3. Compute mean vectors and covariance matrix for both classes
        4. Apply SVD for dimensionality reduction and numerical stability
        5. Compute Bayes optimal linear classifier in reduced space
        6. Project back to original space and normalize

        Raises:
            FileNotFoundError: If Jigsaw dataset not found at gen_wv_data_path

        Note:

        - Uses pooled representation from last non-padding token
        - Handles NaN embeddings by filtering them out
        - Saves computed steering vector to 'steer_wv.pt'
        - Batch processing to manage memory usage
        """

        def batcher(sentences):
            """Generate sentence embeddings using the model's hidden states.

            Args:
                sentences: List of text strings to embed.

            Returns:
                torch.Tensor: Pooled embeddings from last hidden layer, shape [batch_size, hidden_dim].
                    Uses representation from last non-padding token position.
            """
            batch = self.tokenizer.batch_encode_plus(
                sentences,
                return_tensors='pt', truncation=True, max_length=1024, padding=True,
            )
            for k in batch:
                batch[k] = batch[k].to(self.device)
            batch.pop('token_type_ids', None)

            with torch.no_grad():
                outputs = self.model(**batch, output_hidden_states=True, return_dict=True)
                last_hidden = outputs.hidden_states[-1]

            pooled_result = last_hidden[range(len(last_hidden)), batch['attention_mask'].sum(-1) - 1]
            return pooled_result.cpu()

        # Load dataset
        import os

        os.makedirs(self.gen_wv_data_path, exist_ok=True)
        if self.gen_wv_data is not None:
            print(f"Data found in: {self.gen_wv_data}")
            pos = self.gen_wv_data['pos']
            neg = self.gen_wv_data['neg']
        elif os.path.exists(os.path.join(self.gen_wv_data_path, "all_data.csv")):
            print(f"Dataset found in: {self.gen_wv_data_path}")
            dataset = pd.read_csv(os.path.join(self.gen_wv_data_path, "all_data.csv"))
            pos = [row for i, row in dataset['comment_text'].items() if isinstance(row, str) and dataset['toxicity'][i] == 0]
            neg = [row for i, row in dataset['comment_text'].items() if isinstance(row, str) and dataset['toxicity'][i] > 0]
        else:
            raise FileNotFoundError(
                f"""
                    Jigsaw dataset not found at: {self.gen_wv_data_path}
                    To use jigsaw_unintended_bias you have to download it manually from Kaggle: https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification/data
                    You can manually download the data from it's homepage or use the Kaggle CLI tool (follow the instructions here: https://www.kaggle.com/docs/api)
                    Please extract all files in one folder and then load the dataset with:
                    dataset = pd.read_csv('/Jigsaw_data/all_data.csv')
                    """
            )

        num = len(pos) + len(neg)
        print(f"There are overall {len(pos)} positive sentences and {len(neg)} negative sentences.")
        if self.gen_wv_length > 0 and self.gen_wv_length < num:
            num_pos = int(self.gen_wv_length / num * len(pos))
            num_neg = self.gen_wv_length - num_pos
            pos = pos[:num_pos]
            neg = neg[:num_neg]
        print(f"Generating wv via {len(pos)} positive sentences and {len(neg)} negative sentences.")

        sorted_pos = sorted(pos, key=lambda z: -len(z))
        sorted_neg = sorted(neg, key=lambda z: -len(z))

        # Gather embeddings
        embeddings_pos = []
        embeddings_neg = []
        for ii in tqdm(range(0, len(sorted_pos), self.gen_wv_batch_size), desc="Embedding POS"):
            batch = sorted_pos[ii:ii + self.gen_wv_batch_size]
            embeddings_pos.append(torch.tensor(batcher(batch)))
        for ii in tqdm(range(0, len(sorted_neg), self.gen_wv_batch_size), desc="Embedding NEG"):
            batch = sorted_neg[ii:ii + self.gen_wv_batch_size]
            embeddings_neg.append(torch.tensor(batcher(batch)))

        X1_train = torch.vstack(embeddings_pos)
        X2_train = torch.vstack(embeddings_neg)
        X1_train = X1_train[~torch.isnan(X1_train).any(dim=1)]
        X2_train = X2_train[~torch.isnan(X2_train).any(dim=1)]

        # Obtain closed-form Bayes optimal classifier
        mu_1 = torch.mean(X1_train, axis=0)
        cov = torch.cov(X1_train.T) * (X1_train.shape[0] - 1)
        mu_2 = torch.mean(X2_train, axis=0)
        cov += torch.cov(X2_train.T) * (X2_train.shape[0] - 1)
        cov = cov / (X1_train.shape[0] + X2_train.shape[0] - 2)

        torch.cuda.empty_cache()

        F, D, _ = torch.svd(cov, some=True)
        F = F[:, D > 1e-6].float()
        D = D[D > 1e-6].float()
        D_inv = torch.diag(D ** (-1))

        mu = torch.matmul(F.t(), (mu_1 - mu_2) / 2)
        mu_mu = (mu_1 + mu_2) / 2
        w_0 = torch.matmul(D_inv, mu)
        wv = torch.matmul(F, w_0)
        wv = wv / torch.norm(wv)

        self.wv = {'wv': wv, 'mu_mu': mu_mu}

    @staticmethod
    def repeat_kv_cache(cache, repeats: int):
        """Repeat KV cache entries for parallel candidate evaluation.

        Duplicates cache entries to enable efficient parallel processing of multiple candidate tokens without
        recomputing shared context.

        Args:
            cache: KV cache in various formats (DynamicCache, tuple, or custom).
            repeats (int): Number of times to repeat each cache entry.

        Returns:
            Repeated cache in same format as input.

        Raises:
            TypeError: If cache type is not supported.
        """
        if hasattr(cache, "batch_repeat_interleave"):
            cache.batch_repeat_interleave(repeats)
            return cache

        elif hasattr(cache, "to_legacy_cache"):
            raw = cache.to_legacy_cache()
            repeated = tuple(
                tuple(t.repeat(repeats, 1, 1, 1) for t in layer)
                for layer in raw
            )
            return DynamicCache.from_legacy_cache(repeated)

        elif hasattr(cache, "key_cache") and hasattr(cache, "value_cache"):
            for i in range(len(cache.key_cache)):
                cache.key_cache[i] = cache.key_cache[i].repeat_interleave(repeats, dim=0)
                cache.value_cache[i] = cache.value_cache[i].repeat_interleave(repeats, dim=0)
            return cache

        elif isinstance(cache, tuple):
            return tuple(
                tuple(t.repeat_interleave(repeats, dim=0) for t in layer)
                for layer in cache
            )

        else:
            raise TypeError(f"Unsupported cache type: {type(cache).__name__}")

    @staticmethod
    def select_kv_cache(cache, select_idx: torch.Tensor):
        """Select specific entries from KV cache based on indices.

        Extracts cache entries corresponding to selected beam paths, used after evaluating multiple candidates to
        continue with the chosen token.

        Args:
            cache: KV cache in various formats.
            select_idx (torch.Tensor): 1D tensor of indices to select.

        Returns:
            Selected cache entries in same format as input.

        Raises:
            ValueError: If select_idx is not 1D.
            TypeError: If cache type is not supported.
        """
        if not torch.is_tensor(select_idx):
            select_idx = torch.as_tensor(select_idx)
        if select_idx.dtype != torch.long:
            select_idx = select_idx.long()
        if select_idx.dim() != 1:
            raise ValueError(f"select_idx must be 1D, got shape {tuple(select_idx.shape)}")

        if hasattr(cache, "batch_select"):
            cache.batch_select(select_idx)
            return cache

        elif hasattr(cache, "batch_gather"):
            cache.batch_gather(select_idx)
            return cache

        elif hasattr(cache, "to_legacy_cache"):
            raw = cache.to_legacy_cache()
            selected = tuple(
                tuple(t[select_idx, :, :, :] for t in layer)
                for layer in raw
            )
            return DynamicCache.from_legacy_cache(selected)

        elif hasattr(cache, 'key_cache') and hasattr(cache, 'value_cache'):
            for i in range(len(cache.key_cache)):
                if cache.key_cache[i] is not None:
                    select_idx_device = select_idx.to(cache.key_cache[i].device)
                    cache.key_cache[i] = cache.key_cache[i].index_select(dim=0, index=select_idx_device)
                if cache.value_cache[i] is not None:
                    select_idx_device = select_idx.to(cache.value_cache[i].device)
                    cache.value_cache[i] = cache.value_cache[i].index_select(dim=0, index=select_idx_device)
            return cache

        elif isinstance(cache, tuple):
            return tuple(
                tuple(t.index_select(dim=0, index=select_idx.to(t.device)) for t in layer)
                for layer in cache
            )

        else:
            raise TypeError(f"Unsupported cache type: {type(cache).__name__}")

    @torch.no_grad()
    def generate(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            runtime_kwargs: dict | None,
            model: PreTrainedModel,
            **gen_kwargs,
    ) -> torch.Tensor:
        """Execute SASA-guided generation with margin-based logit adjustment.

        Performs controlled generation by computing the distance from toxic subspace at each decoding step and adjusting
        token logits based on this margin. Returns text steered away from toxic regions while maintaining coherence.

        At each decoding step:

        1. Generate embeddings for all valid candidate tokens
        2. Compute margin (distance from toxic subspace) for each candidate
        3. Adjust logits by beta * softmax(margins)
        4. Sample from adjusted distribution

        Args:
            input_ids (torch.Tensor): Input token IDs of shape [batch_size, seq_len].
            attention_mask (torch.Tensor): Attention mask matching input_ids shape.
            runtime_kwargs (dict | None): Runtime parameters (unused).
            model (PreTrainedModel): The language model used for generation.
                Must match the model provided during steer().
            **gen_kwargs: Generation parameters passed to model internals:

                - "generation_config" (`GenerationConfig`, optional): Generation configuration object
                - "logits_processor" (`LogitsProcessorList`, optional): Custom logit processors
                - "stopping_criteria" (`StoppingCriteriaList`, optional): Custom stopping criteria
                - "max_new_tokens" (`int`, optional): Maximum tokens to generate
                - Standard generation arguments (temperature, top_p, etc.)

        Returns:
            torch.Tensor: Generated token IDs including the input prompt.

        Note:

        - Computes full forward passes for all valid candidate tokens at each step
        - Uses custom KV cache manipulation for efficient candidate evaluation
        - Margins computed relative to learned toxic/non-toxic boundary
        - SASA is memory intensive; scales with vocabulary size at each generation step
        """

        runtime_kwargs = runtime_kwargs or {}
        beta = self.beta
        wv = self.wv

        # # If vanilla decoding, allow opt-out
        # if not runtime_kwargs.get("sasa_enabled", True):
        #     return self.base_generate(input_ids=input_ids, **gen_kwargs)

        inputs: torch.Tensor = input_ids

        generation_config: Optional[GenerationConfig] = gen_kwargs.pop("generation_config", None)
        logits_processor: Optional[LogitsProcessorList] = gen_kwargs.pop("logits_processor", None)
        stopping_criteria: Optional[StoppingCriteriaList] = gen_kwargs.pop("stopping_criteria", None)
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = gen_kwargs.pop(
            "prefix_allowed_tokens_fn", None)

        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
        if generation_config is None:
            generation_config = self.model.generation_config if hasattr(self.model,
                                                                        "generation_config") else GenerationConfig()
        else:
            generation_config = copy.deepcopy(generation_config)

        generation_config, model_kwargs = self.model._prepare_generation_config(
            generation_config,
            use_model_defaults=True,
            **gen_kwargs
        )
        generation_config.validate()

        # Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

        # Define model inputs
        # input_ids has to be defined
        # all model-specific keyword inputs are removed from `model_kwargs`
        input_ids, _, model_kwargs = self.model._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        batch_size = input_ids.shape[0]  # todo: unused?
        device = input_ids.device
        self.model._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

        # Prepare `max_length` depending on other stopping criteria.
        input_ids_seq_length = input_ids.shape[-1]
        if generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length

        # Prepare logits processor, stopping criteria
        logits_processor = self.model._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
            model_kwargs=model_kwargs,
        )
        stopping_criteria = self.model._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria, **gen_kwargs
        )

        # Expand input_ids with `num_return_sequences` additional sequences per batch
        input_ids, model_kwargs = self.model._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_return_sequences,
            is_encoder_decoder=False,
            **model_kwargs,
        )

        # Run sample
        # init values
        scores = ()
        mv = None

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        this_peer_finished = False  # used by synced_gpus only

        model_kwargs["cache_position"] = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
        model_kwargs["attention_mask"] = attention_mask
        # model_kwargs = self.model._get_initial_cache_position(input_ids, model_kwargs)

        # auto-regressive generation
        while True:
            if mv is None:  # when generating the first token
                # prepare model inputs
                model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)

                # forward pass to get next token
                outputs = self.model(
                    **model_inputs,
                    return_dict=True,
                    output_attentions=True,
                    output_hidden_states=True,
                )
            else:
                selected_index = (indices[:, -1] == next_tokens).nonzero(as_tuple=True)
                assert len(selected_index) == 1 and len(selected_index[0]) == 1
                outputs.logits = outputs.logits[selected_index, :, :]
                outputs.hidden_states = tuple(
                    [outputs.hidden_states[i][selected_index, :, :] for i in range(len(outputs.hidden_states))]
                )
                outputs.past_key_values = self.select_kv_cache(outputs.past_key_values, selected_index)

            next_token_logits = outputs.logits[:, -1, :]
            next_token_scores = logits_processor(input_ids, next_token_logits)

            model_kwargs = self.model._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=False
            )

            # prepare the value margins
            with torch.no_grad():
                prev_hidden_states = outputs['hidden_states'][-1][:, -1, :].clone()
                indices = torch.nonzero(next_token_scores > -torch.inf)
                num = indices.shape[0]
                input_ids_temp = torch.cat([input_ids.repeat(num, 1), indices[:, -1].unsqueeze(1)], dim=-1)
                model_kwargs_temp = model_kwargs.copy()

                is_gemma = hasattr(self.model, 'config') and 'gemma' in str(type(self.model)).lower()
                if is_gemma:
                    if hasattr(model_kwargs['past_key_values'], 'get_seq_length'):
                        cache_length = model_kwargs['past_key_values'].get_seq_length()
                    else:
                        # fallback: use cache_position to infer length
                        cache_length = model_kwargs_temp['cache_position'][0].item()
                    # Trim attention_mask to match cache length for gemma
                    model_kwargs_temp['attention_mask'] = model_kwargs_temp['attention_mask'][:, :cache_length]

                    original_cache_pos = model_kwargs_temp['cache_position']
                    new_token_position = original_cache_pos[-1] + 1
                    model_kwargs_temp['cache_position'] = torch.tensor([new_token_position],
                                                                    dtype=original_cache_pos.dtype,
                                                                    device=original_cache_pos.device
                                                                    )
                model_kwargs_temp['attention_mask'] = model_kwargs_temp['attention_mask'].repeat(num, 1)
                model_kwargs_temp['past_key_values'] = self.repeat_kv_cache(model_kwargs['past_key_values'], num)

                model_inputs = self.model.prepare_inputs_for_generation(input_ids_temp, **model_kwargs_temp)
                outputs = self.model(**model_inputs, return_dict=True, output_attentions=True,
                                     output_hidden_states=True, )

                if wv is not None:
                    if isinstance(wv, dict) and len(wv) == 2:
                        mv = (wv['wv'] * (outputs['hidden_states'][-1][:, -1, :] - wv['mu_mu'])).sum(axis=1)
                    else:
                        mv = (wv * (outputs['hidden_states'][-1][:, -1, :] - prev_hidden_states)).sum(axis=1)

            # re-distribute weights
            if wv is not None and mv is not None:
                redistribute = next_token_scores[next_token_scores > -torch.inf] + (beta * mv.softmax(dim=-1)).to(
                    dtype=next_token_scores.dtype)
                next_token_scores[next_token_scores > -torch.inf] = redistribute

            probs = nn.functional.softmax(next_token_scores, dim=-1)
            assert probs.sum() > 0
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # store scores
            scores += (next_token_scores,)

            # finished sentences should have their next token be a padding token
            if generation_config.eos_token_id is not None:
                if generation_config.pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + generation_config.pad_token_id * (
                        1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0

            if this_peer_finished:
                break

        return input_ids
args = self.Args.validate(*args, **kwargs) instance-attribute
base_generate = None class-attribute instance-attribute
beta instance-attribute
enabled = True class-attribute instance-attribute
model = None class-attribute instance-attribute
tokenizer = None class-attribute instance-attribute
wv instance-attribute
generate(input_ids, attention_mask, runtime_kwargs, model, **gen_kwargs)

Execute SASA-guided generation with margin-based logit adjustment.

Performs controlled generation by computing the distance from toxic subspace at each decoding step and adjusting token logits based on this margin. Returns text steered away from toxic regions while maintaining coherence.

At each decoding step:

  1. Generate embeddings for all valid candidate tokens
  2. Compute margin (distance from toxic subspace) for each candidate
  3. Adjust logits by beta * softmax(margins)
  4. Sample from adjusted distribution

Parameters:

Name Type Description Default
input_ids Tensor

Input token IDs of shape [batch_size, seq_len].

required
attention_mask Tensor

Attention mask matching input_ids shape.

required
runtime_kwargs dict | None

Runtime parameters (unused).

required
model PreTrainedModel

The language model used for generation. Must match the model provided during steer().

required
**gen_kwargs

Generation parameters passed to model internals:

  • "generation_config" (GenerationConfig, optional): Generation configuration object
  • "logits_processor" (LogitsProcessorList, optional): Custom logit processors
  • "stopping_criteria" (StoppingCriteriaList, optional): Custom stopping criteria
  • "max_new_tokens" (int, optional): Maximum tokens to generate
  • Standard generation arguments (temperature, top_p, etc.)
{}

Returns:

Type Description
Tensor

torch.Tensor: Generated token IDs including the input prompt.

Note:

  • Computes full forward passes for all valid candidate tokens at each step
  • Uses custom KV cache manipulation for efficient candidate evaluation
  • Margins computed relative to learned toxic/non-toxic boundary
  • SASA is memory intensive; scales with vocabulary size at each generation step
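
An end-to-end usage sketch, not part of the library source: the import path follows the "Source code in" location below, the constructor keywords follow the Args documented above, and wv_path="steer_wv.pt" assumes a steering vector saved by a previous run (see the steer() notes).

from transformers import AutoModelForCausalLM, AutoTokenizer

from aisteer360.algorithms.output_control.sasa.control import SASA  # assumed import path

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

sasa = SASA(beta=10.0, wv_path="steer_wv.pt")  # reuse a previously saved steering vector
model = sasa.steer(model, tokenizer)

batch = tokenizer("The weather today is", return_tensors="pt")
output_ids = sasa.generate(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    runtime_kwargs=None,
    model=model,
    do_sample=True,
    top_k=30,           # keeps the per-step candidate set (and memory use) small
    max_new_tokens=30,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
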
Source code in aisteer360/algorithms/output_control/sasa/control.py
@torch.no_grad()
def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        runtime_kwargs: dict | None,
        model: PreTrainedModel,
        **gen_kwargs,
) -> torch.Tensor:
    """Execute SASA-guided generation with margin-based logit adjustment.

    Performs controlled generation by computing the distance from toxic subspace at each decoding step and adjusting
    token logits based on this margin. Returns text steered away from toxic regions while maintaining coherence.

    At each decoding step:

    1. Generate embeddings for all valid candidate tokens
    2. Compute margin (distance from toxic subspace) for each candidate
    3. Adjust logits by beta * softmax(margins)
    4. Sample from adjusted distribution

    Args:
        input_ids (torch.Tensor): Input token IDs of shape [batch_size, seq_len].
        attention_mask (torch.Tensor): Attention mask matching input_ids shape.
        runtime_kwargs (dict | None): Runtime parameters (unused).
        model (PreTrainedModel): The language model used for generation.
            Must match the model provided during steer().
        **gen_kwargs: Generation parameters passed to model internals:

            - "generation_config" (`GenerationConfig`, optional): Generation configuration object
            - "logits_processor" (`LogitsProcessorList`, optional): Custom logit processors
            - "stopping_criteria" (`StoppingCriteriaList`, optional): Custom stopping criteria
            - "max_new_tokens" (`int`, optional): Maximum tokens to generate
            - Standard generation arguments (temperature, top_p, etc.)

    Returns:
        torch.Tensor: Generated token IDs including the input prompt.

    Note:

    - Computes full forward passes for all valid candidate tokens at each step
    - Uses custom KV cache manipulation for efficient candidate evaluation
    - Margins computed relative to learned toxic/non-toxic boundary
    - SASA is memory intensive; scales with vocabulary size at each generation step
    """

    runtime_kwargs = runtime_kwargs or {}
    beta = self.beta
    wv = self.wv

    # # If vanilla decoding, allow opt-out
    # if not runtime_kwargs.get("sasa_enabled", True):
    #     return self.base_generate(input_ids=input_ids, **gen_kwargs)

    inputs: torch.Tensor = input_ids

    generation_config: Optional[GenerationConfig] = gen_kwargs.pop("generation_config", None)
    logits_processor: Optional[LogitsProcessorList] = gen_kwargs.pop("logits_processor", None)
    stopping_criteria: Optional[StoppingCriteriaList] = gen_kwargs.pop("stopping_criteria", None)
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = gen_kwargs.pop(
        "prefix_allowed_tokens_fn", None)

    # priority: `generation_config` argument > `model.generation_config` (the default generation config)
    if generation_config is None:
        generation_config = self.model.generation_config if hasattr(self.model,
                                                                    "generation_config") else GenerationConfig()
    else:
        generation_config = copy.deepcopy(generation_config)

    generation_config, model_kwargs = self.model._prepare_generation_config(
        generation_config,
        use_model_defaults=True,
        **gen_kwargs
    )
    generation_config.validate()

    # Set generation parameters if not already defined
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

    # Define model inputs
    # input_ids has to be defined
    # all model-specific keyword inputs are removed from `model_kwargs`
    input_ids, _, model_kwargs = self.model._prepare_model_inputs(
        inputs, generation_config.bos_token_id, model_kwargs
    )
    batch_size = input_ids.shape[0]  # todo: unused?
    device = input_ids.device
    self.model._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

    # Prepare `max_length` depending on other stopping criteria.
    input_ids_seq_length = input_ids.shape[-1]
    if generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length

    # Prepare logits processor, stopping criteria
    logits_processor = self.model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
        model_kwargs=model_kwargs,
    )
    stopping_criteria = self.model._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria, **gen_kwargs
    )

    # Expand input_ids with `num_return_sequences` additional sequences per batch
    input_ids, model_kwargs = self.model._expand_inputs_for_generation(
        input_ids=input_ids,
        expand_size=generation_config.num_return_sequences,
        is_encoder_decoder=False,
        **model_kwargs,
    )

    # Run sample
    # init values
    scores = ()
    mv = None

    # keep track of which sequences are already finished
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
    this_peer_finished = False  # used by synced_gpus only

    model_kwargs["cache_position"] = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
    model_kwargs["attention_mask"] = attention_mask
    # model_kwargs = self.model._get_initial_cache_position(input_ids, model_kwargs)

    # auto-regressive generation
    while True:
        if mv is None:  # when generating the first token
            # prepare model inputs
            model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self.model(
                **model_inputs,
                return_dict=True,
                output_attentions=True,
                output_hidden_states=True,
            )
        else:
            selected_index = (indices[:, -1] == next_tokens).nonzero(as_tuple=True)
            assert len(selected_index) == 1 and len(selected_index[0]) == 1
            outputs.logits = outputs.logits[selected_index, :, :]
            outputs.hidden_states = tuple(
                [outputs.hidden_states[i][selected_index, :, :] for i in range(len(outputs.hidden_states))]
            )
            outputs.past_key_values = self.select_kv_cache(outputs.past_key_values, selected_index)

        next_token_logits = outputs.logits[:, -1, :]
        next_token_scores = logits_processor(input_ids, next_token_logits)

        model_kwargs = self.model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=False
        )

        # prepare the value margins
        with torch.no_grad():
            prev_hidden_states = outputs['hidden_states'][-1][:, -1, :].clone()
            indices = torch.nonzero(next_token_scores > -torch.inf)
            num = indices.shape[0]
            input_ids_temp = torch.cat([input_ids.repeat(num, 1), indices[:, -1].unsqueeze(1)], dim=-1)
            model_kwargs_temp = model_kwargs.copy()

            is_gemma = hasattr(self.model, 'config') and 'gemma' in str(type(self.model)).lower()
            if is_gemma:
                if hasattr(model_kwargs['past_key_values'], 'get_seq_length'):
                    cache_length = model_kwargs['past_key_values'].get_seq_length()
                else:
                    # fallback: use cache_position to infer length
                    cache_length = model_kwargs_temp['cache_position'][0].item()
                # Trim attention_mask to match cache length for gemma
                model_kwargs_temp['attention_mask'] = model_kwargs_temp['attention_mask'][:, :cache_length]

                original_cache_pos = model_kwargs_temp['cache_position']
                new_token_position = original_cache_pos[-1] + 1
                model_kwargs_temp['cache_position'] = torch.tensor([new_token_position],
                                                                dtype=original_cache_pos.dtype,
                                                                device=original_cache_pos.device
                                                                )
            model_kwargs_temp['attention_mask'] = model_kwargs_temp['attention_mask'].repeat(num, 1)
            model_kwargs_temp['past_key_values'] = self.repeat_kv_cache(model_kwargs['past_key_values'], num)

            model_inputs = self.model.prepare_inputs_for_generation(input_ids_temp, **model_kwargs_temp)
            outputs = self.model(**model_inputs, return_dict=True, output_attentions=True,
                                 output_hidden_states=True, )

            if wv is not None:
                if isinstance(wv, dict) and len(wv) == 2:
                    mv = (wv['wv'] * (outputs['hidden_states'][-1][:, -1, :] - wv['mu_mu'])).sum(axis=1)
                else:
                    mv = (wv * (outputs['hidden_states'][-1][:, -1, :] - prev_hidden_states)).sum(axis=1)

        # re-distribute weights
        if wv is not None and mv is not None:
            redistribute = next_token_scores[next_token_scores > -torch.inf] + (beta * mv.softmax(dim=-1)).to(
                dtype=next_token_scores.dtype)
            next_token_scores[next_token_scores > -torch.inf] = redistribute

        probs = nn.functional.softmax(next_token_scores, dim=-1)
        assert probs.sum() > 0
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # store scores
        scores += (next_token_scores,)

        # finished sentences should have their next token be a padding token
        if generation_config.eos_token_id is not None:
            if generation_config.pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + generation_config.pad_token_id * (
                    1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0

        if this_peer_finished:
            break

    return input_ids
repeat_kv_cache(cache, repeats) staticmethod

Repeat KV cache entries for parallel candidate evaluation.

Duplicates cache entries to enable efficient parallel processing of multiple candidate tokens without recomputing shared context.

Parameters:

Name Type Description Default
cache

KV cache in various formats (DynamicCache, tuple, or custom).

required
repeats int

Number of times to repeat each cache entry.

required

Returns:

Type Description

Repeated cache in same format as input.

Raises:

Type Description
TypeError

If cache type is not supported.

Source code in aisteer360/algorithms/output_control/sasa/control.py
@staticmethod
def repeat_kv_cache(cache, repeats: int):
    """Repeat KV cache entries for parallel candidate evaluation.

    Duplicates cache entries to enable efficient parallel processing of multiple candidate tokens without
    recomputing shared context.

    Args:
        cache: KV cache in various formats (DynamicCache, tuple, or custom).
        repeats (int): Number of times to repeat each cache entry.

    Returns:
        Repeated cache in same format as input.

    Raises:
        TypeError: If cache type is not supported.
    """
    if hasattr(cache, "batch_repeat_interleave"):
        cache.batch_repeat_interleave(repeats)
        return cache

    elif hasattr(cache, "to_legacy_cache"):
        raw = cache.to_legacy_cache()
        repeated = tuple(
            tuple(t.repeat(repeats, 1, 1, 1) for t in layer)
            for layer in raw
        )
        return DynamicCache.from_legacy_cache(repeated)

    elif hasattr(cache, "key_cache") and hasattr(cache, "value_cache"):
        for i in range(len(cache.key_cache)):
            cache.key_cache[i] = cache.key_cache[i].repeat_interleave(repeats, dim=0)
            cache.value_cache[i] = cache.value_cache[i].repeat_interleave(repeats, dim=0)
        return cache

    elif isinstance(cache, tuple):
        return tuple(
            tuple(t.repeat_interleave(repeats, dim=0) for t in layer)
            for layer in cache
        )

    else:
        raise TypeError(f"Unsupported cache type: {type(cache).__name__}")
select_kv_cache(cache, select_idx) staticmethod

Select specific entries from KV cache based on indices.

Extracts cache entries corresponding to selected beam paths, used after evaluating multiple candidates to continue with the chosen token.

Parameters:

Name Type Description Default
cache

KV cache in various formats.

required
select_idx Tensor

1D tensor of indices to select.

required

Returns:

Type Description

Selected cache entries in same format as input.

Raises:

Type Description
ValueError

If select_idx is not 1D.

TypeError

If cache type is not supported.

Source code in aisteer360/algorithms/output_control/sasa/control.py
@staticmethod
def select_kv_cache(cache, select_idx: torch.Tensor):
    """Select specific entries from KV cache based on indices.

    Extracts cache entries corresponding to selected beam paths, used after evaluating multiple candidates to
    continue with the chosen token.

    Args:
        cache: KV cache in various formats.
        select_idx (torch.Tensor): 1D tensor of indices to select.

    Returns:
        Selected cache entries in same format as input.

    Raises:
        ValueError: If select_idx is not 1D.
        TypeError: If cache type is not supported.
    """
    if not torch.is_tensor(select_idx):
        select_idx = torch.as_tensor(select_idx)
    if select_idx.dtype != torch.long:
        select_idx = select_idx.long()
    if select_idx.dim() != 1:
        raise ValueError(f"select_idx must be 1D, got shape {tuple(select_idx.shape)}")

    if hasattr(cache, "batch_select"):
        cache.batch_select(select_idx)
        return cache

    elif hasattr(cache, "batch_gather"):
        cache.batch_gather(select_idx)
        return cache

    elif hasattr(cache, "to_legacy_cache"):
        raw = cache.to_legacy_cache()
        selected = tuple(
            tuple(t[select_idx, :, :, :] for t in layer)
            for layer in raw
        )
        return DynamicCache.from_legacy_cache(selected)

    elif hasattr(cache, 'key_cache') and hasattr(cache, 'value_cache'):
        for i in range(len(cache.key_cache)):
            if cache.key_cache[i] is not None:
                select_idx_device = select_idx.to(cache.key_cache[i].device)
                cache.key_cache[i] = cache.key_cache[i].index_select(dim=0, index=select_idx_device)
            if cache.value_cache[i] is not None:
                select_idx_device = select_idx.to(cache.value_cache[i].device)
                cache.value_cache[i] = cache.value_cache[i].index_select(dim=0, index=select_idx_device)
        return cache

    elif isinstance(cache, tuple):
        return tuple(
            tuple(t.index_select(dim=0, index=select_idx.to(t.device)) for t in layer)
            for layer in cache
        )

    else:
        raise TypeError(f"Unsupported cache type: {type(cache).__name__}")
steer(model, tokenizer=None, **__)

Initialize SASA by loading or generating the toxicity steering vector.

Sets up the linear classifier in the model's embedding space that defines the toxicity subspace. Either loads a pre-computed steering vector or generates one from labeled data.

Parameters:

Name Type Description Default
model PreTrainedModel

The base language model to be steered.

required
tokenizer PreTrainedTokenizer | None

Tokenizer for the base model. If None, attempts to retrieve from model attributes.

None
**__

Additional arguments (unused).

{}

Returns:

Name Type Description
PreTrainedModel PreTrainedModel

The input model (unchanged).

Raises:

Type Description
FileNotFoundError

If gen_wv_data_path doesn't contain the required Jigsaw dataset.

Note:

  • If wv_path is provided, loads pre-computed steering vector
  • Otherwise generates steering vector from Jigsaw toxicity dataset
  • Steering vector generation uses closed-form Bayes optimal classifier
  • Saves generated steering vector to 'steer_wv.pt' for future use
Source code in aisteer360/algorithms/output_control/sasa/control.py
def steer(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer | None = None,
        **__,
) -> PreTrainedModel:
    """Initialize SASA by loading or generating the toxicity steering vector.

    Sets up the linear classifier in the model's embedding space that defines the toxicity subspace. Either loads a
    pre-computed steering vector or generates one from labeled data.

    Args:
        model (PreTrainedModel): The base language model to be steered.
        tokenizer (PreTrainedTokenizer | None): Tokenizer for the base model.
            If None, attempts to retrieve from model attributes.
        **__: Additional arguments (unused).

    Returns:
        PreTrainedModel: The input model (unchanged).

    Raises:
        FileNotFoundError: If gen_wv_data_path doesn't contain the required Jigsaw dataset.

    Note:

    - If wv_path is provided, loads pre-computed steering vector
    - Otherwise generates steering vector from Jigsaw toxicity dataset
    - Steering vector generation uses closed-form Bayes optimal classifier
    - Saves generated steering vector to 'steer_wv.pt' for future use
    """
    self.model = model
    self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
    if self.tokenizer.pad_token_id is None:
        print("pad_token is absent. Setting it to eos_token or '<pad>'.")
        if self.tokenizer.eos_token_id is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        else:  # edge case
            self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
    if self.model.generation_config.pad_token_id is None:
        self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
    self.base_generate = model.generate
    self.device = next(model.parameters()).device
    if getattr(self, "wv_path", None):
        print("Loading SASA steer (wv)......")
        self.wv = torch.load(self.wv_path, map_location="cpu")
    else:
        print("Creating SASA steer (wv)......")
        self._setup_wv()
        # self.wv =  {k: v.cpu() for k, v in self.wv.item().items()}
        torch.save(self.wv, 'steer_wv.pt')
    self.wv = {key: value.to(self.device) for key, value in self.wv.items()}
    return model
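
As a usage sketch only (not a verified recipe): the import paths follow the source references in this document, the model name is illustrative, wv_path points at a previously saved vector as described in the Note above, and SASA's remaining constructor defaults are assumed to be acceptable.

from aisteer360.algorithms.core.steering_pipeline import SteeringPipeline
from aisteer360.algorithms.output_control.sasa.control import SASA  # import path assumed from the source reference

sasa = SASA(wv_path="steer_wv.pt")  # reuse a pre-computed steering vector instead of regenerating it
pipeline = SteeringPipeline(model_name_or_path="gpt2", controls=[sasa])
pipeline.steer()  # triggers SASA.steer(), which loads the vector and moves it to the model device
print(pipeline.generate_text("The movie was", max_new_tokens=20))  # generate_text usage assumed
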
thinking_intervention
args
control
ThinkingIntervention

Bases: OutputControl

Implementation of Thinking Intervention from Wu et al., 2025.

ThinkingIntervention enables controlled text generation by injecting structured thinking processes into the model's reasoning chain. The method modifies the input prompt to include explicit thinking steps enclosed in special tags, allowing the model to engage in guided reasoning before producing the final output.

The algorithm works in three phases:

  1. Prompt Modification: Transform the original prompt by applying an intervention function that injects thinking instructions, reasoning templates, or structured prompts to guide the model's internal reasoning process.

  2. Guided Generation: Generate text using the modified prompt, where the model first produces thinking content within special tags (e.g., <think>...</think>) before generating the actual response.

  3. Output Extraction: Parse the generated text to extract only the content after the thinking tags.

Parameters:

Name Type Description Default
intervention Callable[[str, dict], str]

Function that modifies the input prompt to include thinking instructions. Takes the original prompt string and parameter dict, returns the modified prompt string.

required
Reference

"Effectively Controlling Reasoning Models through Thinking Intervention" Tong Wu, Chong Xiang, Jiachen T. Wang, G. Edward Suh, Prateek Mittal https://arxiv.org/abs/2503.24370

Source code in aisteer360/algorithms/output_control/thinking_intervention/control.py
class ThinkingIntervention(OutputControl):
    """
    Implementation of Thinking Intervention from Wu et al., 2025.

    `ThinkingIntervention` enables controlled text generation by injecting structured thinking processes into the model's
    reasoning chain. The method modifies the input prompt to include explicit thinking steps enclosed in special tags,
    allowing the model to engage in guided reasoning before producing the final output.

    The algorithm works in three phases:

    1. **Prompt Modification**: Transform the original prompt by applying an intervention function that injects thinking
    instructions, reasoning templates, or structured prompts to guide the model's internal reasoning process.

    2. **Guided Generation**: Generate text using the modified prompt, where the model first produces thinking content
    within special tags (e.g., <think>...</think>) before generating the actual response.

    3. **Output Extraction**: Parse the generated text to extract only the content after the thinking tags.

    Args:
        intervention (Callable[[str, dict], str]): Function that modifies the input prompt to include thinking
            instructions. Takes the original prompt string and parameter dict, returns the modified prompt string.

    Reference:
        "Effectively Controlling Reasoning Models through Thinking Intervention"
        Tong Wu, Chong Xiang, Jiachen T. Wang, G. Edward Suh, Prateek Mittal
        https://arxiv.org/abs/2503.24370
    """

    Args = ThinkingInterventionArgs

    model: PreTrainedModel | None = None
    tokenizer: PreTrainedTokenizer | None = None
    base_generate: Callable | None = None

    def steer(
            self,
            model: PreTrainedModel,
            tokenizer: PreTrainedTokenizer | None = None,
            **_
    ) -> PreTrainedModel:
        self.model = model
        self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
        self.base_generate = model.generate
        return model

    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        runtime_kwargs: dict | None,
        model: PreTrainedModel,
        **gen_kwargs,
    ) -> torch.Tensor:
        runtime_kwargs = runtime_kwargs or {}
        self.tag_ids = self.tokenizer("</think>", add_special_tokens=False).input_ids
        # Paper says interventions are best at the beginning
        intervention = self.intervention
        input_params = {**runtime_kwargs.get('params', {})}

        base_generate = runtime_kwargs.get("base_generate", self.base_generate)

        original_prompt_ids = input_ids[0]
        original_input_len = original_prompt_ids.size(0)

        prompt_str = self.tokenizer.decode(
            original_prompt_ids, skip_special_tokens=True
        )
        modified_prompt_str = intervention(prompt_str, input_params)

        new_input = self.tokenizer(modified_prompt_str, return_tensors="pt").to(self.model.device)

        gen_kwargs["return_dict_in_generate"] = False
        output_ids = base_generate(**new_input, **gen_kwargs)[0]
        keep_prefix = output_ids[: original_input_len]

        decoded   = self.tokenizer.decode(output_ids, skip_special_tokens=False)
        remainder_txt = decoded.rsplit("</think>", 1)[-1].lstrip()

        remainder = (
            self.tokenizer(
                remainder_txt,
                add_special_tokens=False,
                return_tensors="pt"
            )["input_ids"]
            .to(output_ids.device)
            .squeeze(0)
        )

        final_ids = torch.cat([keep_prefix, remainder], dim=0)
        return final_ids.unsqueeze(0) if final_ids.dim() == 1 else final_ids
args = self.Args.validate(*args, **kwargs) instance-attribute
base_generate = None class-attribute instance-attribute
enabled = True class-attribute instance-attribute
model = None class-attribute instance-attribute
tokenizer = None class-attribute instance-attribute
generate(input_ids, attention_mask, runtime_kwargs, model, **gen_kwargs)

Custom generation logic.

Source code in aisteer360/algorithms/output_control/thinking_intervention/control.py
def generate(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    runtime_kwargs: dict | None,
    model: PreTrainedModel,
    **gen_kwargs,
) -> torch.Tensor:
    runtime_kwargs = runtime_kwargs or {}
    self.tag_ids = self.tokenizer("</think>", add_special_tokens=False).input_ids
    # Paper says interventions are best at the beginning
    intervention = self.intervention
    input_params = {**runtime_kwargs.get('params', {})}

    base_generate = runtime_kwargs.get("base_generate", self.base_generate)

    original_prompt_ids = input_ids[0]
    original_input_len = original_prompt_ids.size(0)

    prompt_str = self.tokenizer.decode(
        original_prompt_ids, skip_special_tokens=True
    )
    modified_prompt_str = intervention(prompt_str, input_params)

    new_input = self.tokenizer(modified_prompt_str, return_tensors="pt").to(self.model.device)

    gen_kwargs["return_dict_in_generate"] = False
    output_ids = base_generate(**new_input, **gen_kwargs)[0]
    keep_prefix = output_ids[: original_input_len]

    decoded   = self.tokenizer.decode(output_ids, skip_special_tokens=False)
    remainder_txt = decoded.rsplit("</think>", 1)[-1].lstrip()

    remainder = (
        self.tokenizer(
            remainder_txt,
            add_special_tokens=False,
            return_tensors="pt"
        )["input_ids"]
        .to(output_ids.device)
        .squeeze(0)
    )

    final_ids = torch.cat([keep_prefix, remainder], dim=0)
    return final_ids.unsqueeze(0) if final_ids.dim() == 1 else final_ids
steer(model, tokenizer=None, **_)

Optional steering/preparation.

Source code in aisteer360/algorithms/output_control/thinking_intervention/control.py
def steer(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer | None = None,
        **_
) -> PreTrainedModel:
    self.model = model
    self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
    self.base_generate = model.generate
    return model

state_control

base

State control base classes.

This module provides the abstract base class for methods that register hooks into the model (e.g., to modify intermediate representations during inference); it does not change model weights.

Two base classes are provided:

  • StateControl: Base class for all state control methods.
  • NoStateControl: Identity (null) control; used when no state control is defined in steering pipeline.

State controls implement steering through runtime intervention in the model's forward pass, modifying internal states (activations, attention patterns) to produce generations following y ~ p_θᵃ(x), where "p_θᵃ" is the model with state controls.

Examples of state controls:

  • Activation steering (e.g., adding direction vectors)
  • Attention head manipulation and pruning
  • Layer-wise activation editing
  • Dynamic routing between components
  • Representation engineering techniques

The base class provides automatic hook management through context managers (ensuring cleanup and avoiding memory leaks).

See Also:

  • aisteer360.algorithms.state_control: Implementations of state control methods
  • aisteer360.core.steering_pipeline: Integration with steering pipeline
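
To make the hook contract concrete, here is a hedged sketch of the dictionary a concrete get_hooks() implementation returns, together with a toy pre-hook that nudges one layer's hidden states along a fixed direction. The module path "model.layers.0", the hidden size, and the scaling factor are illustrative assumptions; the hook signature follows register_forward_pre_hook(..., with_kwargs=True) as used by register_hooks().

import torch
import torch.nn as nn

steer_direction = torch.zeros(768)  # illustrative hidden size
steer_direction[0] = 1.0

def add_direction_pre_hook(module: nn.Module, args: tuple, kwargs: dict):
    """Pre-forward hook: shift the incoming hidden states by a fixed vector."""
    hidden_states = args[0] if args else kwargs["hidden_states"]
    shift = steer_direction.to(device=hidden_states.device, dtype=hidden_states.dtype)
    hidden_states = hidden_states + 0.5 * shift
    if args:
        return (hidden_states, *args[1:]), kwargs
    kwargs["hidden_states"] = hidden_states
    return args, kwargs

# the structure StateControl.register_hooks() expects back from get_hooks()
hooks = {
    "pre": [{"module": "model.layers.0", "hook_func": add_direction_pre_hook}],
    "forward": [],
    "backward": [],
}
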
BackwardHook = Callable[[nn.Module, tuple, tuple], tuple] module-attribute
ForwardHook = Callable[[nn.Module, tuple, torch.Tensor], torch.Tensor] module-attribute
HookSpec = dict[str, str | PreHook | ForwardHook | BackwardHook] module-attribute
PreHook = Callable[[nn.Module, tuple], tuple | torch.Tensor] module-attribute
NoStateControl

Bases: StateControl

Identity state control.

Used as the default when no state control is needed. Returns empty hook dictionaries and skips registration.

Source code in aisteer360/algorithms/state_control/base.py
class NoStateControl(StateControl):
    """Identity state control.

    Used as the default when no state control is needed. Returns empty hook dictionaries and skips registration.
    """
    enabled: bool = False

    def get_hooks(self, *_, **__) -> dict[str, list[HookSpec]]:
        """Return empty hooks."""
        return {"pre": [], "forward": [], "backward": []}

    def steer(self,
              model: PreTrainedModel,
              tokenizer=None,
              **kwargs) -> None:
        """Null steering operation."""
        pass

    def register_hooks(self, *_):
        """Null registration operation."""
        pass

    def remove_hooks(self, *_):
        """Null removal operation."""
        pass

    def set_hooks(self, hooks: dict[str, list[HookSpec]]):
        """Null set operation."""
        pass

    def reset(self):
        """Null reset operation."""
        pass
args = self.Args.validate(*args, **kwargs) instance-attribute
enabled = False class-attribute instance-attribute
hooks = {'pre': [], 'forward': [], 'backward': []} instance-attribute
registered = [] instance-attribute
get_hooks(*_, **__)

Return empty hooks.

Source code in aisteer360/algorithms/state_control/base.py
def get_hooks(self, *_, **__) -> dict[str, list[HookSpec]]:
    """Return empty hooks."""
    return {"pre": [], "forward": [], "backward": []}
register_hooks(*_)

Null registration operation.

Source code in aisteer360/algorithms/state_control/base.py
def register_hooks(self, *_):
    """Null registration operation."""
    pass
remove_hooks(*_)

Null removal operation.

Source code in aisteer360/algorithms/state_control/base.py
def remove_hooks(self, *_):
    """Null removal operation."""
    pass
reset()

Null reset operation.

Source code in aisteer360/algorithms/state_control/base.py
def reset(self):
    """Null reset operation."""
    pass
set_hooks(hooks)

Null set operation.

Source code in aisteer360/algorithms/state_control/base.py
def set_hooks(self, hooks: dict[str, list[HookSpec]]):
    """Null set operation."""
    pass
steer(model, tokenizer=None, **kwargs)

Null steering operation.

Source code in aisteer360/algorithms/state_control/base.py
def steer(self,
          model: PreTrainedModel,
          tokenizer=None,
          **kwargs) -> None:
    """Null steering operation."""
    pass
StateControl

Bases: ABC

Abstract base class for state control steering methods.

Modifies internal model states during forward passes via hooks.

Methods:

Name Description
get_hooks

Create hook specs (required)

steer

One-time preparation (optional)

reset

Reset logic (optional)

register_hooks

Attach hooks to model (provided)

remove_hooks

Remove all registered hooks (provided)

Source code in aisteer360/algorithms/state_control/base.py
class StateControl(ABC):
    """Abstract base class for state control steering methods.

    Modifies internal model states during forward passes via hooks.

    Methods:
        get_hooks(input_ids, runtime_kwargs, **kwargs) -> dict: Create hook specs (required)
        steer(model, tokenizer, **kwargs) -> None: One-time preparation (optional)
        reset() -> None: Reset logic (optional)
        register_hooks(model) -> None: Attach hooks to model (provided)
        remove_hooks() -> None: Remove all registered hooks (provided)
    """

    Args: Type[BaseArgs] | None = None

    enabled: bool = True
    _model_ref: PreTrainedModel | None = None

    def __init__(self, *args, **kwargs) -> None:
        if self.Args is None:  # null control
            if args or kwargs:
                raise TypeError(f"{type(self).__name__} accepts no constructor arguments.")
            return

        self.args: BaseArgs = self.Args.validate(*args, **kwargs)

        # move fields to attributes
        for field in fields(self.args):
            setattr(self, field.name, getattr(self.args, field.name))

        self.hooks: dict[str, list[HookSpec]] = {"pre": [], "forward": [], "backward": []}
        self.registered: list[torch.utils.hooks.RemovableHandle] = []

    @abstractmethod
    def get_hooks(
        self,
        input_ids: torch.Tensor,
        runtime_kwargs: dict | None,
        **kwargs,
    ) -> dict[str, list[HookSpec]]:
        """Create hook specifications for the current generation."""
        pass

    def steer(self,
              model: PreTrainedModel,
              tokenizer: PreTrainedTokenizerBase = None,
              **kwargs) -> None:
        """Optional steering/preparation."""
        pass

    def register_hooks(self, model: PreTrainedModel) -> None:
        """Attach hooks to model."""
        for phase in ("pre", "forward", "backward"):
            for spec in self.hooks[phase]:
                module = model.get_submodule(spec["module"])
                if phase == "pre":
                    handle = module.register_forward_pre_hook(spec["hook_func"], with_kwargs=True)
                elif phase == "forward":
                    handle = module.register_forward_hook(spec["hook_func"], with_kwargs=True)
                else:
                    handle = module.register_full_backward_hook(spec["hook_func"])
                self.registered.append(handle)

    def remove_hooks(self) -> None:
        """Remove all registered hooks from the model."""
        for handle in self.registered:
            handle.remove()
        self.registered.clear()

    def set_hooks(self, hooks: dict[str, list[HookSpec]]):
        """Update the hook specifications to be registered."""
        self.hooks = hooks

    def __enter__(self):
        """Context manager entry: register hooks to model.

        Raises:
            RuntimeError: If model reference not set by pipeline
        """
        if self._model_ref is None:
            raise RuntimeError("Model reference not set before entering context.")
        self.register_hooks(self._model_ref)

        return self

    def __exit__(self, exc_type, exc, tb):
        """Context manager exit: clean up all hooks."""
        self.remove_hooks()

    def reset(self):
        """Optional reset call for state control."""
        pass


args = self.Args.validate(*args, **kwargs) instance-attribute
enabled = True class-attribute instance-attribute
hooks = {'pre': [], 'forward': [], 'backward': []} instance-attribute
registered = [] instance-attribute
get_hooks(input_ids, runtime_kwargs, **kwargs) abstractmethod

Create hook specifications for the current generation.

Source code in aisteer360/algorithms/state_control/base.py
@abstractmethod
def get_hooks(
    self,
    input_ids: torch.Tensor,
    runtime_kwargs: dict | None,
    **kwargs,
) -> dict[str, list[HookSpec]]:
    """Create hook specifications for the current generation."""
    pass
register_hooks(model)

Attach hooks to model.

Source code in aisteer360/algorithms/state_control/base.py
def register_hooks(self, model: PreTrainedModel) -> None:
    """Attach hooks to model."""
    for phase in ("pre", "forward", "backward"):
        for spec in self.hooks[phase]:
            module = model.get_submodule(spec["module"])
            if phase == "pre":
                handle = module.register_forward_pre_hook(spec["hook_func"], with_kwargs=True)
            elif phase == "forward":
                handle = module.register_forward_hook(spec["hook_func"], with_kwargs=True)
            else:
                handle = module.register_full_backward_hook(spec["hook_func"])
            self.registered.append(handle)
remove_hooks()

Remove all registered hooks from the model.

Source code in aisteer360/algorithms/state_control/base.py
def remove_hooks(self) -> None:
    """Remove all registered hooks from the model."""
    for handle in self.registered:
        handle.remove()
    self.registered.clear()
reset()

Optional reset call for state control.

Source code in aisteer360/algorithms/state_control/base.py
def reset(self):
    """Optional reset call for state control"""
    pass
set_hooks(hooks)

Update the hook specifications to be registered.

Source code in aisteer360/algorithms/state_control/base.py
def set_hooks(self, hooks: dict[str, list[HookSpec]]):
    """Update the hook specifications to be registered."""
    self.hooks = hooks
steer(model, tokenizer=None, **kwargs)

Optional steering/preparation.

Source code in aisteer360/algorithms/state_control/base.py
def steer(self,
          model: PreTrainedModel,
          tokenizer: PreTrainedTokenizerBase = None,
          **kwargs) -> None:
    """Optional steering/preparation."""
    pass
cast
args
control
CAST

Bases: StateControl

Implementation of CAST (Conditional Activation Steering) from Lee et al., 2024.

CAST enables selective control of LLM behavior by conditionally applying activation steering based on input context, allowing fine-grained control without affecting responses to non-targeted content.

The method operates in two phases:

  1. Condition Detection: Analyzes hidden state activation patterns at specified layers during inference to detect if the input matches target conditions. This is done by projecting hidden states onto a condition subspace and computing similarity scores against a threshold.

  2. Conditional Behavior Modification: When conditions are met, applies steering vectors to hidden states at designated behavior layers. This selectively modifies the model's internal representations to produce desired behavioral changes while preserving normal functionality for non-matching inputs.

Parameters:

Name Type Description Default
condition_vector SteeringVector

Steering vector defining the condition subspace for detecting target input patterns. Defaults to None.

required
behavior_vector SteeringVector

Steering vector applied to modify behavior when conditions are met. Defaults to None.

required
condition_layer_ids list[int]

Layer indices where condition detection occurs. Defaults to None.

required
behavior_layer_ids list[int]

Layer indices where behavior modification is applied. Defaults to None.

required
condition_vector_threshold float

Similarity threshold for condition detection. Higher values require stronger pattern matches. Defaults to 0.5.

required
behavior_vector_strength float

Scaling factor for the behavior steering vector. Controls the intensity of behavioral modification. Defaults to 1.0.

required
condition_comparator_threshold_is str

Comparison mode for threshold ('larger' or 'smaller'). Determines if condition is met when similarity is above or below threshold. Defaults to 'larger'.

required
condition_threshold_comparison_mode str

How to aggregate hidden states for comparison ('mean' or 'last'). Defaults to 'mean'.

required
apply_behavior_on_first_call bool

Whether to apply behavior steering on the first forward pass. Defaults to True.

required
use_ooi_preventive_normalization bool

Apply out-of-distribution preventive normalization to maintain hidden state magnitudes. Defaults to False.

required
use_explained_variance bool

Scale steering vectors by their explained variance for adaptive layer-wise control. Defaults to False.

required

Reference:

  • "Programming Refusal with Conditional Activation Steering" Bruce W. Lee, Inkit Padhi, Karthikeyan Natesan Ramamurthy, Erik Miehling, Pierre Dognin, Manish Nagireddy, Amit Dhurandhar https://arxiv.org/abs/2409.05907
Source code in aisteer360/algorithms/state_control/cast/control.py
class CAST(StateControl):
    """
    Implementation of CAST (Conditional Activation Steering) from Lee et al., 2024.

    CAST enables selective control of LLM behavior by conditionally applying activation steering based on input context,
    allowing fine-grained control without affecting responses to non-targeted content.

    The method operates in two phases:

    1. **Condition Detection**: Analyzes hidden state activation patterns at specified layers during inference to detect
        if the input matches target conditions. This is done by projecting hidden states onto a condition subspace and
        computing similarity scores against a threshold.

    2. **Conditional Behavior Modification**: When conditions are met, applies steering vectors to hidden states at
        designated behavior layers. This selectively modifies the model's internal representations to produce desired
        behavioral changes while preserving normal functionality for non-matching inputs.

    Args:
        condition_vector (SteeringVector, optional): Steering vector defining the condition subspace for detecting
            target input patterns. Defaults to None.
        behavior_vector (SteeringVector, optional): Steering vector applied to modify behavior when conditions are met.
            Defaults to None.
        condition_layer_ids (list[int], optional): Layer indices where condition detection occurs. Defaults to None.
        behavior_layer_ids (list[int], optional): Layer indices where behavior modification is applied. Defaults to None.
        condition_vector_threshold (float, optional): Similarity threshold for condition detection. Higher values
            require stronger pattern matches. Defaults to 0.5.
        behavior_vector_strength (float, optional): Scaling factor for the behavior steering vector. Controls the
            intensity of behavioral modification. Defaults to 1.0.
        condition_comparator_threshold_is (str, optional): Comparison mode for threshold ('larger' or 'smaller').
            Determines if condition is met when similarity is above or below threshold. Defaults to 'larger'.
        condition_threshold_comparison_mode (str, optional): How to aggregate hidden states for comparison ('mean'
            or 'last'). Defaults to 'mean'.
        apply_behavior_on_first_call (bool, optional): Whether to apply behavior steering on the first forward pass.
            Defaults to True.
        use_ooi_preventive_normalization (bool, optional): Apply out-of-distribution preventive normalization to
            maintain hidden state magnitudes. Defaults to False.
        use_explained_variance (bool, optional): Scale steering vectors by their explained variance for adaptive
            layer-wise control. Defaults to False.

    Reference:

    - "Programming Refusal with Conditional Activation Steering"
    Bruce W. Lee, Inkit Padhi, Karthikeyan Natesan Ramamurthy, Erik Miehling, Pierre Dognin, Manish Nagireddy, Amit Dhurandhar
    [https://arxiv.org/abs/2409.05907](https://arxiv.org/abs/2409.05907)
    """

    Args = CASTArgs

    # placeholders
    model: PreTrainedModel | None = None
    tokenizer: PreTrainedTokenizer | None = None
    device: torch.device | str | None = None

    # layers list reference
    _layers: list | None = None
    _layers_names: list | None = None
    _layer_states: dict[int, LayerArgs] | None = None

    # Boolean lists for condition and behavior layers
    _condition_layers: dict[int, bool] | None = None
    _behavior_layers: dict[int, bool] | None = None

    # Logic flags
    _condition_met: dict[int, bool] = defaultdict(bool)
    _forward_calls: dict[int, int] = defaultdict(int)

    # condition similarity record
    _condition_similarities: dict = defaultdict(lambda: defaultdict(float))

    def reset(self):
        """Reset internal state tracking between generation calls.

        Clears condition detection flags, forward call counters, and similarity scores.
        """
        self._condition_met = defaultdict(bool)
        self._forward_calls = defaultdict(int)
        self._condition_similarities = defaultdict(lambda: defaultdict(float))

    def steer(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer | None = None,
        **__
    ) -> PreTrainedModel:
        """Initialization by configuring condition detection and behavior modification layers.

        Sets up steering vectors, condition projectors, and layer-specific parameters for conditional activation
        steering. Pre-computes projection matrices and behavior vectors.

        Args:
            model (PreTrainedModel): The base language model to be steered.
            tokenizer (PreTrainedTokenizer | None): Tokenizer (currently unused but maintained
                for API consistency). If None, attempts to retrieve from model attributes.
            **__: Additional arguments (unused).

        Returns:
            PreTrainedModel: The input model, unchanged.
        """
        self.model = model
        self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
        self.device = next(model.parameters()).device

        self._setup(self.model)

        return model

    def get_hooks(
            self,
            input_ids: torch.Tensor,
            runtime_kwargs: dict | None,
            **__,
    ) -> dict[str, list]:
        """Create pre-forward hooks for conditional activation steering.

        Generates hook specifications for all model layers that will conditionally detect patterns and apply behavior
        modifications during the forward pass.

        Args:
            input_ids (torch.Tensor): Input token IDs (unused but required by interface).
            runtime_kwargs (dict | None): Runtime parameters (currently unused).
            **__: Additional arguments (unused).

        Returns:
            dict[str, list]: Hook specifications with "pre", "forward", "backward" keys.
                Only "pre" hooks are populated with CAST steering logic.
        """

        hooks: dict[str, list] = {"pre": [], "forward": [], "backward": []}
        for layer_id, layer_name in enumerate(self._layers_names):
            hooks["pre"].append(
                {
                    "module": layer_name,  # "model.layers.0"
                    "hook_func": partial(
                        self._cast_pre_hook,
                        layer_id=layer_id,
                    ),
                }
            )

        return hooks

    def get_model_layer_list(self, model: PreTrainedModel) -> tuple[list, list]:
        """Extract the list of transformer layers from the model.

        Args:
            model (PreTrainedModel): Model to extract layers from.

        Returns:
            tuple[list, list]: The layer modules and their corresponding module-name prefixes.
        """
        layers = []
        layers_names = []

        model_layers = None
        model_layers_prefix = ''

        if hasattr(model, "model"):  # mistral-, llama-, gemma-like models
            model_layers = model.model.layers
            model_layers_prefix = "model.layers"
        elif hasattr(model, "transformer"):  # gpt2-like models
            model_layers = model.transformer.h
            model_layers_prefix = "transformer.h"
        else:
            raise ValueError(f"Don't know how to get layer list from model for {type(model)=}")

        for idx, layer in enumerate(model_layers):
            layers.append(layer)
            layers_names.append(f"{model_layers_prefix}.{idx}")

        return layers, layers_names

    def _setup(self, model: PreTrainedModel):
        """Configure all CAST internals for the given model.

        Pre-computes steering vectors, condition projectors, and layer configurations to minimize runtime overhead during generation.

        Process:

        1. Identifies condition and behavior layers from configuration
        2. Computes condition projection matrices for detection layers
        3. Prepares scaled behavior vectors for modification layers
        4. Stores layer-specific parameters in _layer_states

        Args:
            model (PreTrainedModel): Model to configure CAST for.
        """
        self._layers, self._layers_names = self.get_model_layer_list(model)
        num_layers = len(self._layers)

        # Creating dicts for condition and behavior layers
        condition_layers = [False] * num_layers
        behavior_layers = [False] * num_layers

        if self.condition_vector is not None and self.condition_layer_ids is not None:
            for layer_id in self.condition_layer_ids:
                condition_layers[layer_id] = True

        if self.behavior_vector is not None:
            for layer_id in self.behavior_layer_ids:
                behavior_layers[layer_id] = True

        self._condition_layers = {i: v for i, v in enumerate(condition_layers)}
        self._behavior_layers = {i: v for i, v in enumerate(behavior_layers)}

        # Precompute behavior vectors and condition projectors
        condition_layer_ids_set = set(self.condition_layer_ids) if self.condition_layer_ids is not None else set()
        behavior_layer_ids_set = set(self.behavior_layer_ids)

        self._layer_states = {}

        for layer_id in range(num_layers):
            # layer = self._layers[layer_id]
            behavior_tensor = None
            if self.behavior_vector is not None:
                if layer_id in behavior_layer_ids_set:
                    if self.use_explained_variance:
                        behavior_direction = self._use_explained_variance_func(self.behavior_vector, layer_id)
                    else:
                        behavior_direction = self.behavior_vector.directions[layer_id]

                    behavior_tensor = torch.tensor(self.behavior_vector_strength * behavior_direction, dtype=self.model.dtype).to(self.model.device)

            condition_projector = None
            if self.condition_vector is not None and layer_id in condition_layer_ids_set:
                if self.use_explained_variance:
                    condition_direction = self._use_explained_variance_func(self.condition_vector, layer_id)
                else:
                    condition_direction = self.condition_vector.directions[layer_id]

                condition_tensor = torch.tensor(condition_direction, dtype=self.model.dtype).to(self.model.device)
                condition_projector = torch.ger(condition_tensor, condition_tensor) / torch.dot(condition_tensor, condition_tensor)

            layer_control_params = LayerControlParams()

            layer_args = LayerArgs(
                behavior_vector=behavior_tensor,
                condition_projector=condition_projector,
                threshold=self.condition_vector_threshold,
                use_ooi_preventive_normalization=self.use_ooi_preventive_normalization,
                apply_behavior_on_first_call=self.apply_behavior_on_first_call,
                condition_comparator_threshold_is=self.condition_comparator_threshold_is,
                condition_threshold_comparison_mode=self.condition_threshold_comparison_mode,
                params=layer_control_params,
            )

            self._layer_states[layer_id] = layer_args

    def _use_explained_variance_func(self, vector: SteeringVector, layer_id: int) -> np.ndarray:
        """Scale steering vector by its explained variance for adaptive control.

        This method scales the steering vector based on its explained variance,
        potentially adjusting its impact on different layers of the model.

        Args:
            vector (SteeringVector): Steering vector containing directions and variances.
            layer_id (int): Layer index to retrieve variance scaling for.

        Returns:
            np.ndarray: Direction vector scaled by explained variance.
        """

        if hasattr(vector, 'explained_variances'):
            variance_scale = vector.explained_variances.get(layer_id, 1)
            direction = vector.directions[layer_id] * variance_scale
        else:
            direction = vector.directions[layer_id]

        return direction

    def _cast_pre_hook(
        self,
        module,
        input_args: Tuple,
        input_kwargs: dict,
        layer_id: int,
    ):
        """Apply conditional activation steering as a pre-forward hook.

        Detect conditions and applies behavior modifications during the model's forward pass. Processes each layer
        independently based on its configuration.

        Process:

        1. Extract hidden states from arguments
        2. If condition layer: detect if input matches target pattern
        3. If behavior layer and conditions met: apply steering vector
        4. Optionally apply OOI normalization to prevent distribution shift

        Args:
            module: The layer module being hooked.
            input_args: Positional arguments to the forward pass.
            input_kwargs: Keyword arguments to the forward pass.
            layer_id (int): Index of the current layer.

        Returns:
            Tuple of potentially modified (input_args, input_kwargs).

        Raises:
            RuntimeError: If hidden states cannot be located.
        """
        hidden_states = input_args[0] if input_args else input_kwargs.get("hidden_states")
        if hidden_states is None:
            raise RuntimeError("CAST: could not locate hidden states")

        self._forward_calls[layer_id] += 1
        batch_size, seq_length, hidden_dim = hidden_states.shape

        if self._condition_layers is None:
            # CASE 1 -> no steering
            is_condition_layer = False
            is_behavior_layer = False
        else:
            # CASE 2 -> steering
            is_condition_layer = self._condition_layers[layer_id]
            is_behavior_layer = self._behavior_layers[layer_id]

        original_norm = hidden_states.norm(dim=-1, keepdim=True)

        if is_condition_layer:
            self._process_single_condition(hidden_states[0], layer_id)

        if is_behavior_layer:
            self._apply_single_behavior(hidden_states, layer_id)

        if self.use_ooi_preventive_normalization and is_behavior_layer:
            hidden_states = self._apply_ooi_normalization(hidden_states, original_norm)

        if input_args:
            input_list = list(input_args)
            input_list[0] = hidden_states
            my_input_args = tuple(input_list)
        else:
            my_input_args = input_args
            input_kwargs["hidden_states"] = hidden_states

        return my_input_args, input_kwargs

    def _process_single_condition(self, hidden_state, layer_id: int):
        """Detect if input matches target condition pattern.

        Projects hidden states onto condition subspace and compares similarity against threshold to determine if
        steering should be activated.

        Process:

        1. Aggregate hidden states (mean or last token based on config)
        2. Project onto condition subspace using precomputed projector
        3. Compute cosine similarity between original and projected
        4. Compare against threshold with specified comparator

        Args:
            hidden_state: Hidden state tensor to analyze [seq_len, hidden_dim].
            layer_id (int): Current layer index.
        """
        layer_args = self._layer_states[layer_id]

        if not self._condition_met[0] and self._forward_calls[layer_id] == 1:
            if layer_args.condition_threshold_comparison_mode == "mean":
                hidden_state = hidden_state.mean(dim=0)
            elif layer_args.condition_threshold_comparison_mode == "last":
                hidden_state = hidden_state[-1, :]

            projected_hidden_state = torch.tanh(torch.matmul(layer_args.condition_projector, hidden_state))
            condition_similarity = self._compute_similarity(hidden_state, projected_hidden_state)
            self._condition_similarities[0][layer_id] = condition_similarity

            if layer_args.condition_comparator_threshold_is == "smaller":
                condition_met = (condition_similarity >= layer_args.threshold)
            elif layer_args.condition_comparator_threshold_is == "larger":
                condition_met = (condition_similarity < layer_args.threshold)
            else:
                raise ValueError(f"invalid {layer_args.condition_comparator_threshold_is}")

            self._condition_met[0] = condition_met

            print(f"layer {layer_id}:  similarity: {condition_similarity} "
                  f"threshold: {layer_args.threshold} "
                  f"condition comparator threshold '{layer_args.condition_comparator_threshold_is}' -- "
                  f"Condition Met: {condition_met}")

    def _apply_single_behavior(self, hidden_states, layer_id: int):
        """Apply behavior steering vector when conditions are met.

        Modifies hidden states by adding scaled steering vectors to shift model behavior toward desired outputs.

        Args:
            hidden_states: Hidden states to modify [batch, seq_len, hidden_dim].
            layer_id (int): Current layer index.
        """
        layer_args = self._layer_states[layer_id]

        should_apply = not any(self._condition_layers.values()) or self._condition_met[0]

        # print(f"Should Apply Behavior: {should_apply}")

        if should_apply:
            control = layer_args.behavior_vector.to(dtype=hidden_states.dtype)
            if self._forward_calls[layer_id] == 1:
                if layer_args.apply_behavior_on_first_call:
                    hidden_states[0] = layer_args.params.operator(hidden_states[0], control)
                else:
                    print("apply_behavior_on_first_call is False, skipping behavior vector application")
            else:
                hidden_states[0] = layer_args.params.operator(hidden_states[0], control)
                # print(f"{layer_id=}: Applying behavior vector to all tokens")

    def _compute_similarity(self, x: torch.Tensor, y: torch.Tensor) -> float:
        """
        Compute the cosine similarity between two tensors.

        Args:
            x: First tensor.
            y: Second tensor.

        Returns:
            The cosine similarity as a float.
        """
        cossim = torch.dot(x.flatten(), y.flatten()) / (torch.norm(x) * torch.norm(y))
        return float(cossim.item())

    def _apply_ooi_normalization(self, hidden_states, original_norm):
        """Apply out-of-distribution preventive normalization.

        Prevents hidden states from drifting too far from original distribution by rescaling to maintain norm magnitudes after steering.

        Args:
            hidden_states: Modified hidden states to normalize.
            original_norm: Original norm before modifications.

        Returns:
            torch.Tensor: Normalized hidden states.

        Raises:
            ValueError: If NaN or Inf detected in hidden states.
        """
        new_norm = hidden_states.norm(dim=-1, keepdim=True)
        max_ratio = (new_norm / original_norm).max().item()
        has_nan_inf = torch.isnan(hidden_states).any() or torch.isinf(hidden_states).any()

        if has_nan_inf:
            # NaN propagates, so raise instead of just applying the norm as in the original code.
            raise ValueError(f"NaN: {torch.isnan(hidden_states).any()} or Inf: {torch.isinf(hidden_states).any()} detected in hidden_states")

        if max_ratio > 1 or has_nan_inf:
            print(f"Applying OOI preventive normalization. Max_ratio was {max_ratio}")
            hidden_states = hidden_states * (original_norm / new_norm)
        else:
            print(f"No OOI preventive normalization. Max_ratio was {max_ratio}")

        return hidden_states
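
For intuition, the condition check in _process_single_condition reduces to a rank-1 projection followed by a cosine similarity. The following standalone replica mirrors the tensor operations above with illustrative dimensions:

import torch

hidden = torch.randn(16)  # aggregated hidden state (mean or last token), illustrative size
v = torch.randn(16)       # condition direction for this layer

projector = torch.ger(v, v) / torch.dot(v, v)  # P = v v^T / (v^T v), as precomputed in _setup
projected = torch.tanh(projector @ hidden)     # tanh(P h), as in the pre-hook
similarity = torch.dot(hidden, projected) / (hidden.norm() * projected.norm())
# `similarity` is compared against condition_vector_threshold to decide whether the
# behavior vector is applied at the behavior layers.
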
args = self.Args.validate(*args, **kwargs) instance-attribute
device = None class-attribute instance-attribute
enabled = True class-attribute instance-attribute
hooks = {'pre': [], 'forward': [], 'backward': []} instance-attribute
model = None class-attribute instance-attribute
registered = [] instance-attribute
tokenizer = None class-attribute instance-attribute
get_hooks(input_ids, runtime_kwargs, **__)

Create pre-forward hooks for conditional activation steering.

Generates hook specifications for all model layers that will conditionally detect patterns and apply behavior modifications during the forward pass.

Parameters:

Name Type Description Default
input_ids Tensor

Input token IDs (unused but required by interface).

required
runtime_kwargs dict | None

Runtime parameters (currently unused).

required
**__

Additional arguments (unused).

{}

Returns:

Type Description
dict[str, list]

dict[str, list]: Hook specifications with "pre", "forward", "backward" keys. Only "pre" hooks are populated with CAST steering logic.

Source code in aisteer360/algorithms/state_control/cast/control.py
def get_hooks(
        self,
        input_ids: torch.Tensor,
        runtime_kwargs: dict | None,
        **__,
) -> dict[str, list]:
    """Create pre-forward hooks for conditional activation steering.

    Generates hook specifications for all model layers that will conditionally detect patterns and apply behavior
    modifications during the forward pass.

    Args:
        input_ids (torch.Tensor): Input token IDs (unused but required by interface).
        runtime_kwargs (dict | None): Runtime parameters (currently unused).
        **__: Additional arguments (unused).

    Returns:
        dict[str, list]: Hook specifications with "pre", "forward", "backward" keys.
            Only "pre" hooks are populated with CAST steering logic.
    """

    hooks: dict[str, list] = {"pre": [], "forward": [], "backward": []}
    for layer_id, layer_name in enumerate(self._layers_names):
        hooks["pre"].append(
            {
                "module": layer_name,  # "model.layers.0"
                "hook_func": partial(
                    self._cast_pre_hook,
                    layer_id=layer_id,
                ),
            }
        )

    return hooks
get_model_layer_list(model)

Extract the list of transformer layers from the model.

Parameters:

Name Type Description Default
model PreTrainedModel

Model to extract layers from.

required

Returns:

tuple[list, list]: The layer modules and their corresponding module-name prefixes.
Source code in aisteer360/algorithms/state_control/cast/control.py
def get_model_layer_list(self, model: PreTrainedModel) -> tuple[list, list]:
    """Extract the list of transformer layers from the model.

    Args:
        model (PreTrainedModel): Model to extract layers from.

    Returns:
        tuple[list, list]: The layer modules and their corresponding module-name prefixes.
    """
    layers = []
    layers_names = []

    model_layers = None
    model_layers_prefix = ''

    if hasattr(model, "model"):  # mistral-, llama-, gemma-like models
        model_layers = model.model.layers
        model_layers_prefix = "model.layers"
    elif hasattr(model, "transformer"):  # gpt2-like models
        model_layers = model.transformer.h
        model_layers_prefix = "transformer.h"
    else:
        raise ValueError(f"Don't know how to get layer list from model for {type(model)=}")

    for idx, layer in enumerate(model_layers):
        layers.append(layer)
        layers_names.append(f"{model_layers_prefix}.{idx}")

    return layers, layers_names
register_hooks(model)

Attach hooks to model.

Source code in aisteer360/algorithms/state_control/base.py
def register_hooks(self, model: PreTrainedModel) -> None:
    """Attach hooks to model."""
    for phase in ("pre", "forward", "backward"):
        for spec in self.hooks[phase]:
            module = model.get_submodule(spec["module"])
            if phase == "pre":
                handle = module.register_forward_pre_hook(spec["hook_func"], with_kwargs=True)
            elif phase == "forward":
                handle = module.register_forward_hook(spec["hook_func"], with_kwargs=True)
            else:
                handle = module.register_full_backward_hook(spec["hook_func"])
            self.registered.append(handle)
remove_hooks()

Remove all registered hooks from the model.

Source code in aisteer360/algorithms/state_control/base.py
def remove_hooks(self) -> None:
    """Remove all registered hooks from the model."""
    for handle in self.registered:
        handle.remove()
    self.registered.clear()
reset()

Reset internal state tracking between generation calls.

Clears condition detection flags, forward call counters, and similarity scores.

Source code in aisteer360/algorithms/state_control/cast/control.py
def reset(self):
    """Reset internal state tracking between generation calls.

    Clears condition detection flags, forward call counters, and similarity scores.
    """
    self._condition_met = defaultdict(bool)
    self._forward_calls = defaultdict(int)
    self._condition_similarities = defaultdict(lambda: defaultdict(float))
set_hooks(hooks)

Update the hook specifications to be registered.

Source code in aisteer360/algorithms/state_control/base.py
def set_hooks(self, hooks: dict[str, list[HookSpec]]):
    """Update the hook specifications to be registered."""
    self.hooks = hooks
steer(model, tokenizer=None, **__)

Initialization by configuring condition detection and behavior modification layers.

Sets up steering vectors, condition projectors, and layer-specific parameters for conditional activation steering. Pre-computes projection matrices and behavior vectors.

Parameters:

Name Type Description Default
model PreTrainedModel

The base language model to be steered.

required
tokenizer PreTrainedTokenizer | None

Tokenizer (currently unused but maintained for API consistency). If None, attempts to retrieve from model attributes.

None
**__

Additional arguments (unused).

{}

Returns:

Name Type Description
PreTrainedModel PreTrainedModel

The input model, unchanged.

Source code in aisteer360/algorithms/state_control/cast/control.py
def steer(
    self,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer | None = None,
    **__
) -> PreTrainedModel:
    """Initialization by configuring condition detection and behavior modification layers.

    Sets up steering vectors, condition projectors, and layer-specific parameters for conditional activation
    steering. Pre-computes projection matrices and behavior vectors.

    Args:
        model (PreTrainedModel): The base language model to be steered.
        tokenizer (PreTrainedTokenizer | None): Tokenizer (currently unused but maintained
            for API consistency). If None, attempts to retrieve from model attributes.
        **__: Additional arguments (unused).

    Returns:
        PreTrainedModel: The input model, unchanged.
    """
    self.model = model
    self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
    self.device = next(model.parameters()).device

    self._setup(self.model)

    return model
pasta
args
control
PASTA

Bases: StateControl

Implementation of PASTA (Post-hoc Attention STeering Approach) from Zhang et al., 2023.

PASTA performs controlled text generation by dynamically modifying attention patterns during inference to amplify or suppress the influence of specific text spans. This allows for fine-grained steering of model behavior without requiring model retraining or parameter updates.

The algorithm works by:

  1. Substring Identification: Locate target substrings within the input prompt using tokenizer offset mapping to determine precise token ranges.

  2. Attention Modification: Inject scaling factors into the attention mask of specified layers and heads to increase or decrease attention weights for the identified token ranges.

  3. Dynamic Steering: Apply different scaling strategies (include, exclude, or generation-focused) to control how the model attends to relevant spans during text generation.

This approach enables real-time control over model focus and can be used for tasks like concept amplification, bias mitigation, or content filtering without architectural changes.

Parameters:

Name Type Description Default
alpha float

Scaling factor for attention modification. Positive values increase attention, negative values decrease attention. Defaults to 1.0.

required
head_config dict | list

Configuration specifying which layers/heads to modify. If dict, maps layer indices to lists of head indices. If list, applies to all heads in specified layers.

required
scale_position str

Strategy for applying attention scaling. Options:

  • "include": Scale attention TO the target substrings
  • "exclude": Scale attention AWAY FROM the target substrings
  • "generation": Scale attention during generation phase

Defaults to "include".

required

Reference: - "PASTA: Tell Your Model Where to Attend: Post-hoc Attention Steering for LLMs" Qingru Zhang, Chandan Singh, Liyuan Liu, Xiaodong Liu, Bin Yu, Jianfeng Gao, Tuo Zhao https://arxiv.org/abs/2311.02262

Source code in aisteer360/algorithms/state_control/pasta/control.py
class PASTA(StateControl):
    """
    Implementation of PASTA (Post-hoc Attention STeering Approach) from Zhang et al., 2023.

    PASTA performs controlled text generation by dynamically modifying attention patterns during inference to amplify or
    suppress the influence of specific text spans. This allows for fine-grained steering of model behavior without
    requiring model retraining or parameter updates.

    The algorithm works by:

    1. **Substring Identification**: Locate target substrings within the input prompt using tokenizer offset mapping to
    determine precise token ranges.

    2. **Attention Modification**: Inject scaling factors into the attention mask of specified layers and heads to
    increase or decrease attention weights for the identified token ranges.

    3. **Dynamic Steering**: Apply different scaling strategies (include, exclude, or generation-focused) to control how
    the model attends to relevant spans during text generation.

    This approach enables real-time control over model focus and can be used for tasks like concept amplification, bias
    mitigation, or content filtering without architectural changes.

    Args:
        alpha (float): Scaling factor for attention modification. Positive values increase attention, negative values
            decrease attention. Defaults to 1.0.
        head_config (dict | list): Configuration specifying which layers/heads to modify. If dict, maps layer indices
            to lists of head indices. If list, applies to all heads in specified layers.
        scale_position (str): Strategy for applying attention scaling. Options:

            - "include": Scale attention TO the target substrings
            - "exclude": Scale attention AWAY FROM the target substrings
            - "generation": Scale attention during generation phase

            Defaults to "include".

    Reference:
    - "PASTA: Tell Your Model Where to Attend: Post-hoc Attention Steering for LLMs"
    Qingru Zhang, Chandan Singh, Liyuan Liu, Xiaodong Liu, Bin Yu, Jianfeng Gao, Tuo Zhao
    [https://arxiv.org/abs/2311.02262](https://arxiv.org/abs/2311.02262)
    """

    Args = PASTAArgs

    # placeholders
    model: PreTrainedModel | None = None
    tokenizer: PreTrainedTokenizer | None = None
    device: torch.device | str | None = None

    _head_map: dict[int, list[int]] | None = None
    _layers: list[int] | None = None
    _scale_constant: torch.Tensor | None = None

    def steer(
        self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer | None = None, **__
    ) -> PreTrainedModel:
        """Initialize PASTA by configuring attention head mappings and model references.

        Sets up the layer and head configurations that will be modified during generation.
        Validates head configurations against model architecture.

        Args:
            model (PreTrainedModel): The base language model to be steered.
            tokenizer (PreTrainedTokenizer | None): Tokenizer for substring identification.
                If None, attempts to retrieve from model attributes.
            **__: Additional arguments (unused).

        Returns:
            PreTrainedModel: The input model (unchanged).
        """
        self.model = model
        self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
        self.device = next(model.parameters()).device
        self._setup_head_config(self.head_config)
        return model

    def get_hooks(
        self,
        input_ids: torch.Tensor,
        runtime_kwargs: dict | None,
        **__,
    ) -> dict[str, list]:
        """Create attention modification hooks for specified substrings.

        Identifies token ranges corresponding to target substrings and prepares hooks that will modify attention weights
        during the forward pass.

        Args:
            input_ids (torch.Tensor): Input token IDs of shape [batch_size, seq_len].
            runtime_kwargs (dict | None): Must contain "substrings" key with target text spans:

                - str: Single substring applied to all batch items
                - list[str]: List of substrings applied to all batch items
                - list[list[str]]: Per-batch substring groups
            **__: Additional arguments (unused).

        Returns:
            dict[str, list]: Hook specifications with "pre", "forward", "backward" keys. Only "pre" hooks are populated for attention modification.

        Raises:
            ValueError: If "substrings" not in runtime_kwargs or batch size mismatch.
        """
        if not runtime_kwargs or "substrings" not in runtime_kwargs:
            raise ValueError("PASTA requires 'substrings' inside runtime_kwargs")

        substrings = runtime_kwargs["substrings"]
        batch_size = input_ids.size(0)

        # normalize substrings to shape (batch, group, str)
        if isinstance(substrings, str):
            substrings = [[substrings]] * batch_size
        elif substrings and isinstance(substrings[0], str):
            substrings = [substrings] * batch_size
        elif len(substrings) != batch_size:
            raise ValueError(
                f"Need {batch_size} substring groups (one per prompt); got {len(substrings)}"
            )

        # decode and get offsets
        prompts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)

        # Have to encode & decode substrings along with prompts, since we observed prompts getting changed due to
        # tokenization (e.g. spaces removed); and we need to replicate the same effect in the substrings to ensure they
        # actually match
        for idx, substring in enumerate(substrings):
            try:
                substrings[idx] = self.tokenizer.batch_decode(
                    self.tokenizer(substring, return_tensors="pt", padding=True)['input_ids'],
                    skip_special_tokens=True
                )
            except:
                breakpoint()

        if self.tokenizer.padding_side != "left":
            self.tokenizer.padding_side = "left"

        tokenized: BatchEncoding = self.tokenizer(
            prompts,
            return_tensors="pt",
            return_offsets_mapping=True,
            add_special_tokens=False,
            padding=True,
        ).to(self.device)

        offset_mapping = tokenized.pop("offset_mapping")
        input_len = tokenized["input_ids"].size(-1)

        token_ranges = self._token_ranges_from_batch(
            prompts, substrings, offset_mapping
        )

        if self._scale_constant is None:
            self._scale_constant = torch.tensor(
                [self.alpha],
                device=self.device,
                dtype=tokenized.input_ids.dtype,
            ).log()

        hooks: dict[str, list] = {"pre": [], "forward": [], "backward": []}
        for layer in self._layers:
            hooks["pre"].append(
                {
                    "module": f"model.layers.{layer}.self_attn",
                    "hook_func": partial(
                        self._attention_pre_hook,
                        head_idx=self._head_map[layer],
                        token_ranges=token_ranges,
                        input_len=input_len,
                    ),
                }
            )

        return hooks

    def _setup_head_config(self, head_config):
        """Parse and validate attention head configuration.

        Converts various configuration formats into internal layer-head mappings and validates against model architecture.

        Args:
            head_config: Configuration specifying which layers/heads to modify:

                - dict: Maps layer indices to lists of head indices
                - list: Layer indices (applies to all heads in those layers)

        Raises:
            ValueError: If configuration format invalid or heads out of range.
        """
        if isinstance(head_config, dict):
            self._head_map = {int(l): list(h) for l, h in head_config.items()}
            self._layers = sorted(self._head_map.keys())
        elif isinstance(head_config, list):
            self._layers = [int(l) for l in head_config]
            self._head_map = {
                l: list(range(self.model.config.num_attention_heads))
                for l in self._layers
            }
        else:
            raise ValueError(f"Invalid head configuration: {head_config!r}")

        num_heads = self.model.config.num_attention_heads
        for layer, heads in self._head_map.items():
            for head in heads:
                if not 0 <= head < num_heads:
                    raise ValueError(
                        f"Head {head} out of range for layer {layer} (0–{num_heads-1})"
                    )

    @staticmethod
    def _find_token_range(
        string: str,
        substring: str,
        offset_mapping: Sequence[tuple[int, int]],
        occurrence: int = 0,
    ) -> tuple[int, int]:
        """Map a substring to its token index range using offset mapping.

        Locates the character positions of a substring and converts them to token indices using the tokenizer's offset mapping.

        Args:
            string: Full text to search within.
            substring: Target substring to locate.
            offset_mapping: List of (start_char, end_char) tuples for each token.
            occurrence: Which occurrence to find if substring appears multiple times.
                Defaults to 0 (first occurrence).

        Returns:
            tuple[int, int]: Start (inclusive) and end (exclusive) token indices.

        Raises:
            ValueError: If substring cannot be mapped to token range.
        """
        if substring not in string:
            print(f"'{substring}' not found in input {string}")
            return 0, 0

        char_index = -1
        for _ in range(occurrence + 1):
            char_index = string.index(substring, char_index + 1)
        char_start = char_index
        char_end = char_start + len(substring)

        token_start = token_end = None
        for token_idx, (start_char, end_char) in enumerate(offset_mapping):
            if token_start is None and start_char <= char_start < end_char:
                token_start = token_idx
            if token_end is None and start_char < char_end <= end_char:
                token_end = token_idx

        if token_start is None or token_end is None:
            raise ValueError("Could not map substring to token range")

        return token_start, token_end + 1

    def _token_ranges_from_batch(
        self,
        texts: Sequence[str],
        groups: Sequence[Sequence[str]],
        offsets_mapping: Sequence[Sequence[tuple[int, int]]],
        occurrence: int = 0,
    ) -> list[torch.Tensor]:
        """Convert batch of substring groups to token ranges.

        Maps multiple substrings across batch items to their corresponding token index ranges for attention modification.

        Args:
            texts: Decoded text for each batch item.
            groups: Groups of substrings for each batch item.
            offsets_mapping: Token offset mappings for each batch item.
            occurrence: Which occurrence to find for repeated substrings.

        Returns:
            list[torch.Tensor]: Token range tensors for each batch item.
                Each tensor has shape [num_substrings, 2] with [start, end] pairs.
        """
        token_ranges: list[torch.Tensor] = []

        for text, substrings, offsets in zip(texts, groups, offsets_mapping):
            substring_ranges = [
                torch.tensor(
                    self._find_token_range(text, substring, offsets, occurrence)
                )
                for substring in substrings
            ]
            token_ranges.append(torch.stack(substring_ranges))

        return token_ranges

    def _attention_pre_hook(
        self,
        module,
        input_args: tuple,
        input_kwargs: dict,
        head_idx: list[int],
        token_ranges: list[torch.Tensor],
        input_len: int,
    ):
        """Modify attention mask to steer focus toward/away from target tokens.

        Pre-forward hook that adjusts attention weights by adding scaling factors to the attention mask for specified token ranges and attention heads.

        Args:
            module: The attention module being hooked.
            input_args: Positional arguments to the forward pass.
            input_kwargs: Keyword arguments to the forward pass.
            head_idx: List of attention head indices to modify.
            token_ranges: Token index ranges to apply scaling to.
            input_len: Length of input sequence (for generation positioning).

        Returns:
            Tuple of potentially modified (input_args, input_kwargs).

        Raises:
            RuntimeError: If hidden states cannot be located.
            ValueError: If scale_position is invalid.
        """
        hidden_states = (
            input_args[0] if input_args else input_kwargs.get("hidden_states")
        )
        if hidden_states is None:
            raise RuntimeError("PASTA: could not locate hidden states")

        attention_mask = input_kwargs.get("attention_mask")
        if attention_mask is None:  # build it
            batch_size, sequence_len, _ = hidden_states.size()
            num_heads = self.model.config.num_attention_heads
            causal = torch.triu(
                hidden_states.new_full((sequence_len, sequence_len), float("-inf")),
                diagonal=1,
            )
            attention_mask = causal[None, None]  # (1,1,q,k)
            attention_mask = attention_mask.expand(
                batch_size, num_heads, -1, -1
            ).contiguous()
            input_kwargs["attention_mask"] = attention_mask

        attention_mask = attention_mask.to(hidden_states.dtype).contiguous()
        if attention_mask.size(1) == 1:
            attention_mask = attention_mask.expand(
                -1,
                self.model.config.num_attention_heads,
                -1,
                -1,
            ).contiguous()

        batch_size = attention_mask.size(0)
        for batch_index in range(batch_size):
            for start_idx, end_idx in token_ranges[batch_index].tolist():
                if start_idx == end_idx:
                    continue
                if self.scale_position == "include":
                    attention_mask[
                        batch_index, head_idx, :, start_idx:end_idx
                    ] += self._scale_constant
                elif self.scale_position == "exclude":
                    attention_mask[
                        batch_index, head_idx, :, :start_idx
                    ] += self._scale_constant
                    attention_mask[
                        batch_index, head_idx, :, end_idx:input_len
                    ] += self._scale_constant
                elif self.scale_position == "generation":
                    attention_mask[
                        batch_index, head_idx, :, :input_len
                    ] += self._scale_constant

                else:
                    raise ValueError(f"Unknown scale_position '{self.scale_position}'")

        if self.scale_position == "include":
            attention_mask[:, head_idx, :, :input_len] -= self._scale_constant

        input_kwargs["attention_mask"] = attention_mask
        return input_args, input_kwargs
args = self.Args.validate(*args, **kwargs) instance-attribute
device = None class-attribute instance-attribute
enabled = True class-attribute instance-attribute
hooks = {'pre': [], 'forward': [], 'backward': []} instance-attribute
model = None class-attribute instance-attribute
registered = [] instance-attribute
tokenizer = None class-attribute instance-attribute
get_hooks(input_ids, runtime_kwargs, **__)

Create attention modification hooks for specified substrings.

Identifies token ranges corresponding to target substrings and prepares hooks that will modify attention weights during the forward pass.

Parameters:

Name Type Description Default
input_ids Tensor

Input token IDs of shape [batch_size, seq_len].

required
runtime_kwargs dict | None

Must contain "substrings" key with target text spans:

  • str: Single substring applied to all batch items
  • list[str]: List of substrings applied to all batch items
  • list[list[str]]: Per-batch substring groups
required
**__

Additional arguments (unused).

{}

Returns:

Type Description
dict[str, list]

dict[str, list]: Hook specifications with "pre", "forward", "backward" keys. Only "pre" hooks are populated for attention modification.

Raises:

Type Description
ValueError

If "substrings" not in runtime_kwargs or batch size mismatch.

Source code in aisteer360/algorithms/state_control/pasta/control.py
def get_hooks(
    self,
    input_ids: torch.Tensor,
    runtime_kwargs: dict | None,
    **__,
) -> dict[str, list]:
    """Create attention modification hooks for specified substrings.

    Identifies token ranges corresponding to target substrings and prepares hooks that will modify attention weights
    during the forward pass.

    Args:
        input_ids (torch.Tensor): Input token IDs of shape [batch_size, seq_len].
        runtime_kwargs (dict | None): Must contain "substrings" key with target text spans:

            - str: Single substring applied to all batch items
            - list[str]: List of substrings applied to all batch items
            - list[list[str]]: Per-batch substring groups
        **__: Additional arguments (unused).

    Returns:
        dict[str, list]: Hook specifications with "pre", "forward", "backward" keys. Only "pre" hooks are populated for attention modification.

    Raises:
        ValueError: If "substrings" not in runtime_kwargs or batch size mismatch.
    """
    if not runtime_kwargs or "substrings" not in runtime_kwargs:
        raise ValueError("PASTA requires 'substrings' inside runtime_kwargs")

    substrings = runtime_kwargs["substrings"]
    batch_size = input_ids.size(0)

    # normalize substrings to shape (batch, group, str)
    if isinstance(substrings, str):
        substrings = [[substrings]] * batch_size
    elif substrings and isinstance(substrings[0], str):
        substrings = [substrings] * batch_size
    elif len(substrings) != batch_size:
        raise ValueError(
            f"Need {batch_size} substring groups (one per prompt); got {len(substrings)}"
        )

    # decode and get offsets
    prompts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)

    # Have to encode & decode substrings along with prompts, since we observed prompts getting changed due to
    # tokenization (e.g. spaces removed); and we need to replicate the same effect in the substrings to ensure they
    # actually match
    for idx, substring in enumerate(substrings):
        try:
            substrings[idx] = self.tokenizer.batch_decode(
                self.tokenizer(substring, return_tensors="pt", padding=True)['input_ids'],
                skip_special_tokens=True
            )
        except:
            breakpoint()

    if self.tokenizer.padding_side != "left":
        self.tokenizer.padding_side = "left"

    tokenized: BatchEncoding = self.tokenizer(
        prompts,
        return_tensors="pt",
        return_offsets_mapping=True,
        add_special_tokens=False,
        padding=True,
    ).to(self.device)

    offset_mapping = tokenized.pop("offset_mapping")
    input_len = tokenized["input_ids"].size(-1)

    token_ranges = self._token_ranges_from_batch(
        prompts, substrings, offset_mapping
    )

    if self._scale_constant is None:
        self._scale_constant = torch.tensor(
            [self.alpha],
            device=self.device,
            dtype=tokenized.input_ids.dtype,
        ).log()

    hooks: dict[str, list] = {"pre": [], "forward": [], "backward": []}
    for layer in self._layers:
        hooks["pre"].append(
            {
                "module": f"model.layers.{layer}.self_attn",
                "hook_func": partial(
                    self._attention_pre_hook,
                    head_idx=self._head_map[layer],
                    token_ranges=token_ranges,
                    input_len=input_len,
                ),
            }
        )

    return hooks
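The accepted substring formats can also be exercised directly through the low-level hook lifecycle. The following is a hedged sketch with a placeholder model: it constructs PASTA, steers it, builds hooks from runtime_kwargs, and wraps a plain model.generate() call.

# Hedged sketch of the low-level hook lifecycle; model name, layer/head choices and
# alpha are illustrative placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from aisteer360.algorithms.state_control.pasta.control import PASTA

model_name = "meta-llama/Llama-3.2-1B-Instruct"  # placeholder llama-like model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # padding is used when tokenizing prompts/substrings

pasta = PASTA(alpha=4.0, head_config={8: [0, 1, 2]}, scale_position="include")
pasta.steer(model, tokenizer)

prompts = ["Answer briefly. Focus on safety. How should I store kitchen knives?"]
input_ids = tokenizer(prompts, return_tensors="pt").input_ids

# "substrings" may be a str, a list[str], or a list[list[str]] (one group per prompt)
hooks = pasta.get_hooks(input_ids, runtime_kwargs={"substrings": ["Focus on safety."]})
pasta.set_hooks(hooks)
pasta.register_hooks(model)
with torch.no_grad():
    generated = model.generate(input_ids, max_new_tokens=32)
pasta.remove_hooks()

print(tokenizer.decode(generated[0], skip_special_tokens=True))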
register_hooks(model)

Attach hooks to model.

Source code in aisteer360/algorithms/state_control/base.py
def register_hooks(self, model: PreTrainedModel) -> None:
    """Attach hooks to model."""
    for phase in ("pre", "forward", "backward"):
        for spec in self.hooks[phase]:
            module = model.get_submodule(spec["module"])
            if phase == "pre":
                handle = module.register_forward_pre_hook(spec["hook_func"], with_kwargs=True)
            elif phase == "forward":
                handle = module.register_forward_hook(spec["hook_func"], with_kwargs=True)
            else:
                handle = module.register_full_backward_hook(spec["hook_func"])
            self.registered.append(handle)
remove_hooks()

Remove all registered hooks from the model.

Source code in aisteer360/algorithms/state_control/base.py
def remove_hooks(self) -> None:
    """Remove all registered hooks from the model."""
    for handle in self.registered:
        handle.remove()
    self.registered.clear()
reset()

Optional reset call for state controls; the base implementation is a no-op.

Source code in aisteer360/algorithms/state_control/base.py
def reset(self):
    """Optional reset call for state control"""
    pass
set_hooks(hooks)

Update the hook specifications to be registered.

Source code in aisteer360/algorithms/state_control/base.py
def set_hooks(self, hooks: dict[str, list[HookSpec]]):
    """Update the hook specifications to be registered."""
    self.hooks = hooks
steer(model, tokenizer=None, **__)

Initialize PASTA by configuring attention head mappings and model references.

Sets up the layer and head configurations that will be modified during generation. Validates head configurations against model architecture.

Parameters:

Name Type Description Default
model PreTrainedModel

The base language model to be steered.

required
tokenizer PreTrainedTokenizer | None

Tokenizer for substring identification. If None, attempts to retrieve from model attributes.

None
**__

Additional arguments (unused).

{}

Returns:

Name Type Description
PreTrainedModel PreTrainedModel

The input model (unchanged).

Source code in aisteer360/algorithms/state_control/pasta/control.py
def steer(
    self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer | None = None, **__
) -> PreTrainedModel:
    """Initialize PASTA by configuring attention head mappings and model references.

    Sets up the layer and head configurations that will be modified during generation.
    Validates head configurations against model architecture.

    Args:
        model (PreTrainedModel): The base language model to be steered.
        tokenizer (PreTrainedTokenizer | None): Tokenizer for substring identification.
            If None, attempts to retrieve from model attributes.
        **__: Additional arguments (unused).

    Returns:
        PreTrainedModel: The input model (unchanged).
    """
    self.model = model
    self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
    self.device = next(model.parameters()).device
    self._setup_head_config(self.head_config)
    return model

structural_control

base

Structural control base classes.

This module provides the abstract base class for methods that create persistent changes to the model, either through weight updates or architectural changes.

Two base classes are provided:

  • StructuralControl: Base class for all structural control methods.
  • NoStructuralControl: Identity (null) control; used when no structural control is defined in steering pipeline.

Structural controls implement steering through model weight or architecture modifications, transforming base parameters θ to θ', resulting in generations following y ~ p_θ'(x).

Examples of structural controls:

  • Fine-tuning (full or parameter-efficient like LoRA)
  • Model merging (e.g., via MergeKit)
  • Direct Preference Optimization (DPO)
  • Adapter layers and modules
  • Weight interpolation and averaging

See Also:

  • aisteer360.algorithms.structural_control: Implementations of structural control methods
  • aisteer360.core.steering_pipeline: Integration with steering pipeline
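To show the intended extension point, here is a minimal sketch (not part of the library) of a custom structural control. The class name and behavior are hypothetical; setting Args = None means the constructor accepts no arguments (see the base class below).

# Hypothetical example of a custom structural control; not part of the library.
from transformers import PreTrainedModel, PreTrainedTokenizer

from aisteer360.algorithms.structural_control.base import StructuralControl


class FreezeEmbeddings(StructuralControl):
    """Illustrative control that freezes the model's input embeddings."""

    Args = None  # no constructor arguments to validate

    def steer(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer = None, **kwargs) -> PreTrainedModel:
        model.get_input_embeddings().weight.requires_grad_(False)
        return model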
NoStructuralControl

Bases: StructuralControl

Identity structural control.

Used as the default when no structural control is needed. Passes the model through unchanged.

Source code in aisteer360/algorithms/structural_control/base.py
class NoStructuralControl(StructuralControl):
    """Identity structural control.

    Used as the default when no structural control is needed. Passes the model through unchanged.
    """
    enabled: bool = False

    def steer(self, model: PreTrainedModel, **__) -> PreTrainedModel:
        """Null steer operation; returns model."""
        return model
args = self.Args.validate(*args, **kwargs) instance-attribute
enabled = False class-attribute instance-attribute
steer(model, **__)

Null steer operation; returns model.

Source code in aisteer360/algorithms/structural_control/base.py
def steer(self, model: PreTrainedModel, **__) -> PreTrainedModel:
    """Null steer operation; returns model."""
    return model
StructuralControl

Bases: ABC

Abstract base class for structural control steering methods.

Modifies model parameters or architecture persistently, returning a new model instance with transformed weights.

Methods:

Name Description
steer

Training logic (required)

Source code in aisteer360/algorithms/structural_control/base.py
class StructuralControl(ABC):
    """Abstract base class for structural control steering methods.

    Modifies model parameters or architecture persistently, returning a new model instance with transformed weights.

    Methods:
        steer(model, tokenizer, **kwargs) -> PreTrainedModel: Training logic (required)
    """

    Args: Type[BaseArgs] | None = None

    enabled: bool = True

    def __init__(self, *args, **kwargs) -> None:
        if self.Args is None:  # null control
            if args or kwargs:
                raise TypeError(f"{type(self).__name__} accepts no constructor arguments.")
            return

        self.args: BaseArgs = self.Args.validate(*args, **kwargs)

        # move fields to attributes
        for field in fields(self.args):
            setattr(self, field.name, getattr(self.args, field.name))

    @abstractmethod
    def steer(
            self,
            model: PreTrainedModel,
            tokenizer: PreTrainedTokenizer = None,
            **kwargs
    ) -> PreTrainedModel:
        """Required steering/preparation."""
        pass
args = self.Args.validate(*args, **kwargs) instance-attribute
enabled = True class-attribute instance-attribute
steer(model, tokenizer=None, **kwargs) abstractmethod

Required steering/preparation.

Source code in aisteer360/algorithms/structural_control/base.py
@abstractmethod
def steer(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer = None,
        **kwargs
) -> PreTrainedModel:
    """Required steering/preparation."""
    pass
wrappers
mergekit
args
control
MergeKit

Bases: StructuralControl

Wrapper for merging models via MergeKit https://github.com/arcee-ai/mergekit.

MergeKit combines multiple language models using various merge strategies like linear interpolation, SLERP, and TIES. This wrapper integrates MergeKit's functionality to enable structural control through model composition.

The process involves loading a merge configuration (from YAML or dict), executing the merge operation, and optionally loading the resulting merged model. Supports caching to avoid redundant operations.

Parameters:

Name Type Description Default
config_path str

Path to YAML merge configuration file. Defaults to None.

required
config_dict dict

Dictionary merge configuration. Defaults to None.

required
out_path str

Output directory for merged model.

required
load_merged bool

Whether to load merged model after merging. Defaults to True.

required
force_remerge bool

Force remerge even if output exists. Defaults to False.

required
allow_cuda bool

Use CUDA acceleration if available. Defaults to True.

required
device_map str | dict

Device mapping for model loading. Defaults to None.

required
trust_remote_code bool

Trust remote code when loading. Defaults to False.

required
dtype str

PyTorch dtype for loading. Defaults to "float16".

required

Reference:

  • "Arcee's MergeKit: A Toolkit for Merging Large Language Models" Charles Goddard, Shamane Siriwardhana, Malikeh Ehghaghi, Luke Meyers, Vladimir Karpukhin, Brian Benedict, Mark McQuade, Jacob Solawetz https://aclanthology.org/2024.emnlp-industry.36
Source code in aisteer360/algorithms/structural_control/wrappers/mergekit/control.py
class MergeKit(StructuralControl):
    """
    Wrapper for merging models via MergeKit [https://github.com/arcee-ai/mergekit](https://github.com/arcee-ai/mergekit).

    MergeKit combines multiple language models using various merge strategies like linear interpolation, SLERP, and
    TIES. This wrapper integrates MergeKit's functionality to enable structural control through model composition.

    The process involves loading a merge configuration (from YAML or dict), executing the merge operation, and
    optionally loading the resulting merged model. Supports caching to avoid redundant operations.

    Args:
        config_path (str, optional): Path to YAML merge configuration file. Defaults to None.
        config_dict (dict, optional): Dictionary merge configuration. Defaults to None.
        out_path (str): Output directory for merged model.
        load_merged (bool): Whether to load merged model after merging. Defaults to True.
        force_remerge (bool): Force remerge even if output exists. Defaults to False.
        allow_cuda (bool): Use CUDA acceleration if available. Defaults to True.
        device_map (str | dict, optional): Device mapping for model loading. Defaults to None.
        trust_remote_code (bool): Trust remote code when loading. Defaults to False.
        dtype (str): PyTorch dtype for loading. Defaults to "float16".

    Reference:

    - "Arcee's MergeKit: A Toolkit for Merging Large Language Models"
      Charles Goddard, Shamane Siriwardhana, Malikeh Ehghaghi, Luke Meyers, Vladimir Karpukhin, Brian Benedict,
      Mark McQuade, Jacob Solawetz
      [https://aclanthology.org/2024.emnlp-industry.36](https://aclanthology.org/2024.emnlp-industry.36)
    """

    Args = MergeKitArgs

    def steer(
            self,
            model: PreTrainedModel,
            tokenizer: PreTrainedTokenizer = None,
            **_
    ):
        """Execute model merging via MergeKit and optionally return the merged model.

        Performs structural steering by merging multiple models according to a configuration file or dictionary.
        Supports caching to avoid redundant merge operations and can either return the merged model or the original
        model based on configuration.

        The method follows this logic:

        1. Load merge configuration from YAML file or dictionary
        2. Check if merged model already exists (skip if `force_remerge=False`)
        3. Execute merge if needed using MergeKit
        4. Optionally load and return the merged model

        Args:
            model (PreTrainedModel): The base model (potentially unused depending on the method).
            tokenizer (PreTrainedTokenizer, optional): Base tokenizer (currently unused).
            **_: Additional arguments (ignored).

        Returns:
            PreTrainedModel: Either the merged model (if `load_merged=True`) or the original model. When returning
            merged model, attempts to attach a new tokenizer if one was created during merging.

        Note:

        - If out_path exists and `force_remerge=False`, skips merging and loads cached result
        - Merged model saved to `out_path` directory with full weights and config
        - If `load_merged=False`, performs merge but returns original model
        """
        args: MergeKitArgs = self.args

        if args.config_path:
            config = mk_config.MergeConfiguration.from_yaml(args.config_path)
        else:
            config = mk_config.MergeConfiguration(**args.config_dict)

        # find merged weights
        out_path = Path(args.out_path)
        if out_path.exists() and not args.force_remerge:
            if args.load_merged:
                merged = AutoModelForCausalLM.from_pretrained(
                    pretrained_model_name_or_path=str(out_path),
                    device_map=args.device_map,
                    trust_remote_code=args.trust_remote_code,
                    torch_dtype=getattr(torch, args.dtype)
                )
                return merged
            return model

        # merge
        # with FileLock(str(out_path) + ".lock"):
        mk_merge.run_merge(
            merge_config=config,
            out_path=str(out_path),
            options=mk_merge.MergeOptions(
                use_cuda=args.allow_cuda,
                trust_remote_code=args.trust_remote_code,
            )
        )

        # load merged checkpoint (and check if merge returned new tokenizer)
        if args.load_merged:
            merged = AutoModelForCausalLM.from_pretrained(
                out_path,
                torch_dtype=getattr(torch, args.dtype),
                device_map=args.device_map,
                trust_remote_code=args.trust_remote_code,
            )
            try:
                merged.tokenizer = AutoTokenizer.from_pretrained(
                    out_path,
                    trust_remote_code=args.trust_remote_code
                )
            except Exception:
                pass
            return merged

        return model
args = self.Args.validate(*args, **kwargs) instance-attribute
enabled = True class-attribute instance-attribute
steer(model, tokenizer=None, **_)

Execute model merging via MergeKit and optionally return the merged model.

Performs structural steering by merging multiple models according to a configuration file or dictionary. Supports caching to avoid redundant merge operations and can either return the merged model or the original model based on configuration.

The method follows this logic:

  1. Load merge configuration from YAML file or dictionary
  2. Check if merged model already exists (skip if force_remerge=False)
  3. Execute merge if needed using MergeKit
  4. Optionally load and return the merged model

Parameters:

Name Type Description Default
model PreTrainedModel

The base model (potentially unused depending on the method).

required
tokenizer PreTrainedTokenizer

Base tokenizer (currently unused).

None
**_

Additional arguments (ignored).

{}

Returns:

Name Type Description
PreTrainedModel

Either the merged model (if load_merged=True) or the original model. When returning

merged model, attempts to attach a new tokenizer if one was created during merging.

Note:

  • If out_path exists and force_remerge=False, skips merging and loads cached result
  • Merged model saved to out_path directory with full weights and config
  • If load_merged=False, performs merge but returns original model
Source code in aisteer360/algorithms/structural_control/wrappers/mergekit/control.py
def steer(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer = None,
        **_
):
    """Execute model merging via MergeKit and optionally return the merged model.

    Performs structural steering by merging multiple models according to a configuration file or dictionary.
    Supports caching to avoid redundant merge operations and can either return the merged model or the original
    model based on configuration.

    The method follows this logic:

    1. Load merge configuration from YAML file or dictionary
    2. Check if merged model already exists (skip if `force_remerge=False`)
    3. Execute merge if needed using MergeKit
    4. Optionally load and return the merged model

    Args:
        model (PreTrainedModel): The base model (potentially unused depending on the method).
        tokenizer (PreTrainedTokenizer, optional): Base tokenizer (currently unused).
        **_: Additional arguments (ignored).

    Returns:
        PreTrainedModel: Either the merged model (if `load_merged=True`) or the original model. When returning
        merged model, attempts to attach a new tokenizer if one was created during merging.

    Note:

    - If out_path exists and `force_remerge=False`, skips merging and loads cached result
    - Merged model saved to `out_path` directory with full weights and config
    - If `load_merged=False`, performs merge but returns original model
    """
    args: MergeKitArgs = self.args

    if args.config_path:
        config = mk_config.MergeConfiguration.from_yaml(args.config_path)
    else:
        config = mk_config.MergeConfiguration(**args.config_dict)

    # find merged weights
    out_path = Path(args.out_path)
    if out_path.exists() and not args.force_remerge:
        if args.load_merged:
            merged = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path=str(out_path),
                device_map=args.device_map,
                trust_remote_code=args.trust_remote_code,
                torch_dtype=getattr(torch, args.dtype)
            )
            return merged
        return model

    # merge
    # with FileLock(str(out_path) + ".lock"):
    mk_merge.run_merge(
        merge_config=config,
        out_path=str(out_path),
        options=mk_merge.MergeOptions(
            use_cuda=args.allow_cuda,
            trust_remote_code=args.trust_remote_code,
        )
    )

    # load merged checkpoint (and check if merge returned new tokenizer)
    if args.load_merged:
        merged = AutoModelForCausalLM.from_pretrained(
            out_path,
            torch_dtype=getattr(torch, args.dtype),
            device_map=args.device_map,
            trust_remote_code=args.trust_remote_code,
        )
        try:
            merged.tokenizer = AutoTokenizer.from_pretrained(
                out_path,
                trust_remote_code=args.trust_remote_code
            )
        except Exception:
            pass
        return merged

    return model
trl

The TRL wrapper implements a variety of methods from Hugging Face's TRL library.

The current functionality spans the following methods:

  • SFT (Supervised Fine-Tuning): Standard supervised learning to fine-tune language models on demonstration data
  • DPO (Direct Preference Optimization): Trains models directly on preference data without requiring a separate reward model
  • APO (Anchored Preference Optimization): A variant of DPO that uses an anchor model to improve training stability and performance
  • SPPO (Self-Play Preference Optimization): Iterative preference optimization using self-generated synthetic data to reduce dependency on external preference datasets

For further details, refer to the TRL documentation and the SPPO repository.
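As a usage illustration for the SFT wrapper (a hedged sketch): the keyword names mirror the SFTTrainerMixin attributes documented below (train_dataset, training_args, and so on); whether SFTArgs accepts exactly these constructor keywords is an assumption, and the dataset and model names are placeholders.

# Hedged sketch: constructor keywords mirror the SFTTrainerMixin attributes; dataset
# and model names are placeholders.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from aisteer360.algorithms.structural_control.wrappers.trl.sfttrainer.control import SFT

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder base model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

sft = SFT(
    train_dataset=load_dataset("trl-lib/Capybara", split="train[:1%]"),  # placeholder dataset
    training_args={"output_dir": "sft-out", "num_train_epochs": 1},      # forwarded to trl.SFTConfig
)

steered_model = sft.steer(model, tokenizer)  # runs SFTTrainer.train(), then freezes the weights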

apotrainer
args
control
APO

Bases: DPOTrainerMixin

Source code in aisteer360/algorithms/structural_control/wrappers/trl/apotrainer/control.py
class APO(DPOTrainerMixin):
    """

    """
    Args = APOArgs
Config instance-attribute
args = self.Args.validate(*args, **kwargs) instance-attribute
enabled = True class-attribute instance-attribute
eval_dataset = None class-attribute instance-attribute
lora_kwargs = None class-attribute instance-attribute
model = None class-attribute instance-attribute
peft_type = None class-attribute instance-attribute
ref_model = None class-attribute instance-attribute
tokenizer = None class-attribute instance-attribute
train_dataset = None class-attribute instance-attribute
training_args = None class-attribute instance-attribute
use_peft = False class-attribute instance-attribute
steer(model, tokenizer=None, ref_model=None, **_)

Required steering/preparation.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/dpotrainer/base_mixin.py
def steer(self, model: PreTrainedModel, tokenizer=None, ref_model=None, **_) -> PreTrainedModel:
    self.model = model
    self.tokenizer = tokenizer
    self.ref_model = ref_model
    if self.train_dataset is not None:
        training_args = DPOConfig(
            **{
                **self.training_args
            }
        )

        peft_config = None
        if self.use_peft and self.peft_type == PeftType.LORA:
            peft_config = LoraConfig(
                **{
                    **self.lora_kwargs
                }
            )
            self.ref_model = None

        trainer = DPOTrainer(
            model=self.model,
            ref_model=self.ref_model,
            args=training_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            processing_class=self.tokenizer,
            peft_config=peft_config
        )
        trainer.train()

        self.model = trainer.model

        if training_args.output_dir:
            trainer.save_model(training_args.output_dir)

    self.model.eval()
    for p in self.model.parameters():
        p.requires_grad_(False)
    return self.model
args
base_mixin
TRLMixin

Bases: StructuralControl

Source code in aisteer360/algorithms/structural_control/wrappers/trl/base_mixin.py
class TRLMixin(StructuralControl):
    """

    """

    pass
args = self.Args.validate(*args, **kwargs) instance-attribute
enabled = True class-attribute instance-attribute
steer(model, tokenizer=None, **kwargs) abstractmethod

Required steering/preparation.

Source code in aisteer360/algorithms/structural_control/base.py
@abstractmethod
def steer(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer = None,
        **kwargs
) -> PreTrainedModel:
    """Required steering/preparation."""
    pass
dpotrainer
args
base_mixin
DPOTrainerMixin

Bases: StructuralControl

Source code in aisteer360/algorithms/structural_control/wrappers/trl/dpotrainer/base_mixin.py
class DPOTrainerMixin(StructuralControl):
    """

    """
    Config: Type  # set by subclass
    model: PreTrainedModel | None = None
    tokenizer: PreTrainedTokenizer | None = None
    ref_model: PreTrainedModel | None = None

    train_dataset = None
    eval_dataset = None
    training_args: dict = None
    use_peft: bool = False
    peft_type = None
    lora_kwargs: dict = None

    def steer(self, model: PreTrainedModel, tokenizer=None, ref_model=None, **_) -> PreTrainedModel:
        self.model = model
        self.tokenizer = tokenizer
        self.ref_model = ref_model
        if self.train_dataset is not None:
            training_args = DPOConfig(
                **{
                    **self.training_args
                }
            )

            peft_config = None
            if self.use_peft and self.peft_type == PeftType.LORA:
                peft_config = LoraConfig(
                    **{
                        **self.lora_kwargs
                    }
                )
                self.ref_model = None

            trainer = DPOTrainer(
                model=self.model,
                ref_model=self.ref_model,
                args=training_args,
                train_dataset=self.train_dataset,
                eval_dataset=self.eval_dataset,
                processing_class=self.tokenizer,
                peft_config=peft_config
            )
            trainer.train()

            self.model = trainer.model

            if training_args.output_dir:
                trainer.save_model(training_args.output_dir)

        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad_(False)
        return self.model
Config instance-attribute
args = self.Args.validate(*args, **kwargs) instance-attribute
enabled = True class-attribute instance-attribute
eval_dataset = None class-attribute instance-attribute
lora_kwargs = None class-attribute instance-attribute
model = None class-attribute instance-attribute
peft_type = None class-attribute instance-attribute
ref_model = None class-attribute instance-attribute
tokenizer = None class-attribute instance-attribute
train_dataset = None class-attribute instance-attribute
training_args = None class-attribute instance-attribute
use_peft = False class-attribute instance-attribute
steer(model, tokenizer=None, ref_model=None, **_)

Required steering/preparation.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/dpotrainer/base_mixin.py
def steer(self, model: PreTrainedModel, tokenizer=None, ref_model=None, **_) -> PreTrainedModel:
    self.model = model
    self.tokenizer = tokenizer
    self.ref_model = ref_model
    if self.train_dataset is not None:
        training_args = DPOConfig(
            **{
                **self.training_args
            }
        )

        peft_config = None
        if self.use_peft and self.peft_type == PeftType.LORA:
            peft_config = LoraConfig(
                **{
                    **self.lora_kwargs
                }
            )
            self.ref_model = None

        trainer = DPOTrainer(
            model=self.model,
            ref_model=self.ref_model,
            args=training_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            processing_class=self.tokenizer,
            peft_config=peft_config
        )
        trainer.train()

        self.model = trainer.model

        if training_args.output_dir:
            trainer.save_model(training_args.output_dir)

    self.model.eval()
    for p in self.model.parameters():
        p.requires_grad_(False)
    return self.model
control
DPO

Bases: DPOTrainerMixin

Source code in aisteer360/algorithms/structural_control/wrappers/trl/dpotrainer/control.py
class DPO(DPOTrainerMixin):
    """

    """
    Args = DPOArgs
Config instance-attribute
args = self.Args.validate(*args, **kwargs) instance-attribute
enabled = True class-attribute instance-attribute
eval_dataset = None class-attribute instance-attribute
lora_kwargs = None class-attribute instance-attribute
model = None class-attribute instance-attribute
peft_type = None class-attribute instance-attribute
ref_model = None class-attribute instance-attribute
tokenizer = None class-attribute instance-attribute
train_dataset = None class-attribute instance-attribute
training_args = None class-attribute instance-attribute
use_peft = False class-attribute instance-attribute
steer(model, tokenizer=None, ref_model=None, **_)

Required steering/preparation.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/dpotrainer/base_mixin.py
def steer(self, model: PreTrainedModel, tokenizer=None, ref_model=None, **_) -> PreTrainedModel:
    self.model = model
    self.tokenizer = tokenizer
    self.ref_model = ref_model
    if self.train_dataset is not None:
        training_args = DPOConfig(
            **{
                **self.training_args
            }
        )

        peft_config = None
        if self.use_peft and self.peft_type == PeftType.LORA:
            peft_config = LoraConfig(
                **{
                    **self.lora_kwargs
                }
            )
            self.ref_model = None

        trainer = DPOTrainer(
            model=self.model,
            ref_model=self.ref_model,
            args=training_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            processing_class=self.tokenizer,
            peft_config=peft_config
        )
        trainer.train()

        self.model = trainer.model

        if training_args.output_dir:
            trainer.save_model(training_args.output_dir)

    self.model.eval()
    for p in self.model.parameters():
        p.requires_grad_(False)
    return self.model
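As a usage illustration for the DPO wrapper (a hedged sketch): the dataset uses the standard prompt/chosen/rejected preference format expected by trl.DPOTrainer; whether DPOArgs accepts exactly these constructor keywords is an assumption motivated by the mixin attributes shown above.

# Hedged sketch: preference data in prompt/chosen/rejected format; constructor
# keywords mirror the DPOTrainerMixin attributes and are assumptions.
from datasets import Dataset

from aisteer360.algorithms.structural_control.wrappers.trl.dpotrainer.control import DPO

preference_data = Dataset.from_dict({
    "prompt": ["Write a short greeting."],
    "chosen": ["Hello! I hope your day is going well."],
    "rejected": ["no"],
})

dpo = DPO(
    train_dataset=preference_data,
    training_args={"output_dir": "dpo-out", "beta": 0.1, "num_train_epochs": 1},  # forwarded to trl.DPOConfig
)

# steered_model = dpo.steer(model, tokenizer)  # trains with DPOTrainer, then freezes the weights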
sfttrainer
args
base_mixin
SFTTrainerMixin

Bases: StructuralControl

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sfttrainer/base_mixin.py
class SFTTrainerMixin(StructuralControl):
    """

    """
    Config: Type  # set by subclass
    model: PreTrainedModel | None = None
    tokenizer: PreTrainedTokenizer | None = None
    data_collator:  Any | None = None

    train_dataset = None
    eval_dataset = None
    training_args: dict = None
    use_peft: bool = False
    peft_type = None
    lora_kwargs: dict = None

    def steer(self, model: PreTrainedModel, tokenizer=None, **_) -> PreTrainedModel:
        self.model = model
        self.tokenizer = tokenizer
        if self.train_dataset is not None:
            print("data_collator is ", self.data_collator)
            training_args = SFTConfig(
                **{
                    **self.training_args
                }
            )

            peft_config = None
            if self.use_peft and self.peft_type == PeftType.LORA:
                peft_config = LoraConfig(
                    **{
                        **self.lora_kwargs
                    }
                )

            trainer = SFTTrainer(
                model=self.model,
                args=training_args,
                train_dataset=self.train_dataset,
                eval_dataset=self.eval_dataset,
                data_collator=self.data_collator,
                processing_class=self.tokenizer,
                peft_config=peft_config
            )
            trainer.train()

            self.model = trainer.model

            if training_args.output_dir:
                trainer.save_model(training_args.output_dir)

        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad_(False)
        return self.model
Config instance-attribute
args = self.Args.validate(*args, **kwargs) instance-attribute
data_collator = None class-attribute instance-attribute
enabled = True class-attribute instance-attribute
eval_dataset = None class-attribute instance-attribute
lora_kwargs = None class-attribute instance-attribute
model = None class-attribute instance-attribute
peft_type = None class-attribute instance-attribute
tokenizer = None class-attribute instance-attribute
train_dataset = None class-attribute instance-attribute
training_args = None class-attribute instance-attribute
use_peft = False class-attribute instance-attribute
steer(model, tokenizer=None, **_)

Required steering/preparation.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sfttrainer/base_mixin.py, lines 26-64
def steer(self, model: PreTrainedModel, tokenizer=None, **_) -> PreTrainedModel:
    self.model = model
    self.tokenizer = tokenizer
    if self.train_dataset is not None:
        print("data_collator is ", self.data_collator)
        training_args = SFTConfig(
            **{
                **self.training_args
            }
        )

        peft_config = None
        if self.use_peft and self.peft_type == PeftType.LORA:
            peft_config = LoraConfig(
                **{
                    **self.lora_kwargs
                }
            )

        trainer = SFTTrainer(
            model=self.model,
            args=training_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            data_collator=self.data_collator,
            processing_class=self.tokenizer,
            peft_config=peft_config
        )
        trainer.train()

        self.model = trainer.model

        if training_args.output_dir:
            trainer.save_model(training_args.output_dir)

    self.model.eval()
    for p in self.model.parameters():
        p.requires_grad_(False)
    return self.model
control
SFT

Bases: SFTTrainerMixin

Structural control that applies supervised fine-tuning via TRL's SFTTrainer, optionally with a LoRA adapter.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sfttrainer/control.py, lines 9-13
class SFT(SFTTrainerMixin):
    """
    Structural control that applies supervised fine-tuning via TRL's SFTTrainer, optionally with a LoRA adapter.
    """
    Args = SFTArgs
Config instance-attribute
args = self.Args.validate(*args, **kwargs) instance-attribute
data_collator = None class-attribute instance-attribute
enabled = True class-attribute instance-attribute
eval_dataset = None class-attribute instance-attribute
lora_kwargs = None class-attribute instance-attribute
model = None class-attribute instance-attribute
peft_type = None class-attribute instance-attribute
tokenizer = None class-attribute instance-attribute
train_dataset = None class-attribute instance-attribute
training_args = None class-attribute instance-attribute
use_peft = False class-attribute instance-attribute
steer(model, tokenizer=None, **_)

Required steering/preparation.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sfttrainer/base_mixin.py, lines 26-64
def steer(self, model: PreTrainedModel, tokenizer=None, **_) -> PreTrainedModel:
    self.model = model
    self.tokenizer = tokenizer
    if self.train_dataset is not None:
        print("data_collator is ", self.data_collator)
        training_args = SFTConfig(
            **{
                **self.training_args
            }
        )

        peft_config = None
        if self.use_peft and self.peft_type == PeftType.LORA:
            peft_config = LoraConfig(
                **{
                    **self.lora_kwargs
                }
            )

        trainer = SFTTrainer(
            model=self.model,
            args=training_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            data_collator=self.data_collator,
            processing_class=self.tokenizer,
            peft_config=peft_config
        )
        trainer.train()

        self.model = trainer.model

        if training_args.output_dir:
            trainer.save_model(training_args.output_dir)

    self.model.eval()
    for p in self.model.parameters():
        p.requires_grad_(False)
    return self.model
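Illustrative usage (a minimal sketch, not taken from the package): composing the SFT control into a SteeringPipeline. The constructor keywords mirror the mixin attributes above and are assumed to be accepted by SFTArgs; the checkpoint name is a placeholder.

from datasets import Dataset

from aisteer360.algorithms.core.steering_pipeline import SteeringPipeline
from aisteer360.algorithms.structural_control.wrappers.trl.sfttrainer.control import SFT

train_data = Dataset.from_dict({"text": ["Instruction: say hi.\nResponse: hi!"]})

sft = SFT(
    train_dataset=train_data,  # keyword names assumed; see SFTArgs for the exact schema
    training_args={"output_dir": "sft_out", "max_steps": 1, "per_device_train_batch_size": 1},
)

pipeline = SteeringPipeline(
    model_name_or_path="your-org/your-base-model",  # placeholder checkpoint
    controls=[sft],
)
pipeline.steer()  # runs SFTTrainer once, then freezes the fine-tuned weights
# Inference then proceeds through pipeline.generate() / pipeline.generate_text().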
sppotrainer
args
base_mixin
SPPOTrainerMixin

Bases: StructuralControl

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/base_mixin.py, lines 16-97
class SPPOTrainerMixin(StructuralControl):
    """

    """
    Config: Type  # set by subclass
    model: PreTrainedModel | None = None
    tokenizer: PreTrainedTokenizer | None = None
    refModel: PreTrainedModel | None = None

    train_dataset = None
    eval_dataset = None
    training_args: dict = None
    use_peft: bool = False
    peft_type = None
    lora_kwargs: dict = None

    def steer(self, model: PreTrainedModel, tokenizer=None, refModel=None,
              maxlen=2048, num_prompts=5, start_iter_num=1, end_iter_num=1, additional_train_datasets=None,
              sppo_temp_dir="sppo_temp_dir", **_) -> PreTrainedModel:
        self.model = model
        self.tokenizer = tokenizer
        self.refModel = refModel
        if self.train_dataset is not None:
            training_args = DPOConfig(      #TrainingArguments(
                **{
                    **self.training_args
                }
            )

            peft_config = None
            if self.use_peft and self.peft_type == PeftType.LORA:
                peft_config = LoraConfig(
                    **{
                        **self.lora_kwargs
                    }
                )
                self.refModel = None

            checkpoints_path = ""
            steerer = None
            for i in range(start_iter_num, end_iter_num+1):

                checkpoints_path=f"{sppo_temp_dir}/checkpoints/SPPO-Iter{i}"  # steered model stored at each iteration

                if i == start_iter_num or additional_train_datasets is None:
                    dataset = self.train_dataset
                else:
                    dataset = additional_train_datasets[i-start_iter_num-1]
                processed_train = prepare_dataset_from_prompts(
                    self.model,
                    self.tokenizer,
                    dataset,
                    sppo_temp_dir=sppo_temp_dir,
                    iter_num=i,
                    maxlen = maxlen,
                    num_prompts=num_prompts
                )
                trainer = SPPOTrainer(
                    model=self.model,
                    ref_model=self.refModel,
                    args=training_args,
                    train_dataset=processed_train,
                    eval_dataset=self.eval_dataset,
                    processing_class=self.tokenizer,
                    peft_config = peft_config,
                    beta=training_args.beta,
                    max_length=training_args.max_length,
                    max_prompt_length=training_args.max_prompt_length,
                    loss_type=training_args.loss_type,
                )
                trainer.train()
                self.model = trainer.model

                trainer.save_model(checkpoints_path)
                if training_args.output_dir:
                    if i == end_iter_num:
                        trainer.save_model(training_args.output_dir)  ### this also needs to be changed

        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad_(False)
        return self.model
Config instance-attribute
args = self.Args.validate(*args, **kwargs) instance-attribute
enabled = True class-attribute instance-attribute
eval_dataset = None class-attribute instance-attribute
lora_kwargs = None class-attribute instance-attribute
model = None class-attribute instance-attribute
peft_type = None class-attribute instance-attribute
refModel = None class-attribute instance-attribute
tokenizer = None class-attribute instance-attribute
train_dataset = None class-attribute instance-attribute
training_args = None class-attribute instance-attribute
use_peft = False class-attribute instance-attribute
steer(model, tokenizer=None, refModel=None, maxlen=2048, num_prompts=5, start_iter_num=1, end_iter_num=1, additional_train_datasets=None, sppo_temp_dir='sppo_temp_dir', **_)

Required steering/preparation.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/base_mixin.py, lines 32-97
def steer(self, model: PreTrainedModel, tokenizer=None, refModel=None,
          maxlen=2048, num_prompts=5, start_iter_num=1, end_iter_num=1, additional_train_datasets=None,
          sppo_temp_dir="sppo_temp_dir", **_) -> PreTrainedModel:
    self.model = model
    self.tokenizer = tokenizer
    self.refModel = refModel
    if self.train_dataset is not None:
        training_args = DPOConfig(      #TrainingArguments(
            **{
                **self.training_args
            }
        )

        peft_config = None
        if self.use_peft and self.peft_type == PeftType.LORA:
            peft_config = LoraConfig(
                **{
                    **self.lora_kwargs
                }
            )
            self.refModel = None

        checkpoints_path = ""
        steerer = None
        for i in range(start_iter_num, end_iter_num+1):

            checkpoints_path=f"{sppo_temp_dir}/checkpoints/SPPO-Iter{i}"  # steered model stored at each iteration

            if i == start_iter_num or additional_train_datasets is None:
                dataset = self.train_dataset
            else:
                dataset = additional_train_datasets[i-start_iter_num-1]
            processed_train = prepare_dataset_from_prompts(
                self.model,
                self.tokenizer,
                dataset,
                sppo_temp_dir=sppo_temp_dir,
                iter_num=i,
                maxlen = maxlen,
                num_prompts=num_prompts
            )
            trainer = SPPOTrainer(
                model=self.model,
                ref_model=self.refModel,
                args=training_args,
                train_dataset=processed_train,
                eval_dataset=self.eval_dataset,
                processing_class=self.tokenizer,
                peft_config = peft_config,
                beta=training_args.beta,
                max_length=training_args.max_length,
                max_prompt_length=training_args.max_prompt_length,
                loss_type=training_args.loss_type,
            )
            trainer.train()
            self.model = trainer.model

            trainer.save_model(checkpoints_path)
            if training_args.output_dir:
                if i == end_iter_num:
                    trainer.save_model(training_args.output_dir)  ### this also needs to be changed

    self.model.eval()
    for p in self.model.parameters():
        p.requires_grad_(False)
    return self.model
control
SPPO

Bases: SPPOTrainerMixin

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/control.py, lines 9-13
class SPPO(SPPOTrainerMixin):
    """

    """
    Args = SPPOArgs
Config instance-attribute
args = self.Args.validate(*args, **kwargs) instance-attribute
enabled = True class-attribute instance-attribute
eval_dataset = None class-attribute instance-attribute
lora_kwargs = None class-attribute instance-attribute
model = None class-attribute instance-attribute
peft_type = None class-attribute instance-attribute
refModel = None class-attribute instance-attribute
tokenizer = None class-attribute instance-attribute
train_dataset = None class-attribute instance-attribute
training_args = None class-attribute instance-attribute
use_peft = False class-attribute instance-attribute
steer(model, tokenizer=None, refModel=None, maxlen=2048, num_prompts=5, start_iter_num=1, end_iter_num=1, additional_train_datasets=None, sppo_temp_dir='sppo_temp_dir', **_)

Required steering/preparation.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/base_mixin.py, lines 32-97
def steer(self, model: PreTrainedModel, tokenizer=None, refModel=None,
          maxlen=2048, num_prompts=5, start_iter_num=1, end_iter_num=1, additional_train_datasets=None,
          sppo_temp_dir="sppo_temp_dir", **_) -> PreTrainedModel:
    self.model = model
    self.tokenizer = tokenizer
    self.refModel = refModel
    if self.train_dataset is not None:
        training_args = DPOConfig(      #TrainingArguments(
            **{
                **self.training_args
            }
        )

        peft_config = None
        if self.use_peft and self.peft_type == PeftType.LORA:
            peft_config = LoraConfig(
                **{
                    **self.lora_kwargs
                }
            )
            self.refModel = None

        checkpoints_path = ""
        steerer = None
        for i in range(start_iter_num, end_iter_num+1):

            checkpoints_path=f"{sppo_temp_dir}/checkpoints/SPPO-Iter{i}"  # steered model stored at each iteration

            if i == start_iter_num or additional_train_datasets is None:
                dataset = self.train_dataset
            else:
                dataset = additional_train_datasets[i-start_iter_num-1]
            processed_train = prepare_dataset_from_prompts(
                self.model,
                self.tokenizer,
                dataset,
                sppo_temp_dir=sppo_temp_dir,
                iter_num=i,
                maxlen = maxlen,
                num_prompts=num_prompts
            )
            trainer = SPPOTrainer(
                model=self.model,
                ref_model=self.refModel,
                args=training_args,
                train_dataset=processed_train,
                eval_dataset=self.eval_dataset,
                processing_class=self.tokenizer,
                peft_config = peft_config,
                beta=training_args.beta,
                max_length=training_args.max_length,
                max_prompt_length=training_args.max_prompt_length,
                loss_type=training_args.loss_type,
            )
            trainer.train()
            self.model = trainer.model

            trainer.save_model(checkpoints_path)
            if training_args.output_dir:
                if i == end_iter_num:
                    trainer.save_model(training_args.output_dir)  ### this also needs to be changed

    self.model.eval()
    for p in self.model.parameters():
        p.requires_grad_(False)
    return self.model
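Illustrative usage (a minimal sketch, not taken from the package): driving two self-play iterations by calling steer() directly. Constructor keywords are assumed to follow SPPOArgs, the single prompt column is an assumption about the input expected by prepare_dataset_from_prompts, and the checkpoint name is a placeholder.

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from aisteer360.algorithms.structural_control.wrappers.trl.sppotrainer.control import SPPO

model = AutoModelForCausalLM.from_pretrained("your-org/your-base-model")  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained("your-org/your-base-model")     # placeholder checkpoint

prompts = Dataset.from_dict({"prompt": ["Explain self-play preference optimization."]})

sppo = SPPO(
    train_dataset=prompts,                                     # keyword names assumed; see SPPOArgs
    training_args={"output_dir": "sppo_out", "max_steps": 1},  # parsed into a DPOConfig in steer()
)

# Each iteration i generates candidate responses from the current model, builds a
# preference dataset (prepare_dataset_from_prompts), trains SPPOTrainer, and saves
# a checkpoint under sppo_temp_dir/checkpoints/SPPO-Iter{i}.
steered_model = sppo.steer(
    model,
    tokenizer=tokenizer,
    start_iter_num=1,
    end_iter_num=2,
    sppo_temp_dir="sppo_temp_dir",
)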
trainer
SPPOTrainer

Bases: Trainer

Initialize SPPOTrainer.

Parameters:

Name Type Description Default
model `transformers.PreTrainedModel`

The model to train, preferably an AutoModelForCausalLM.

None
ref_model `PreTrainedModelWrapper`

Hugging Face transformer model with a causal language modeling head. Used for implicit reward computation and loss. If no reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.

None
beta `float`, defaults to 0.1

The beta factor in DPO loss. In SPPO, eta=1/beta. Higher beta means less divergence from the initial policy. For the IPO loss, beta is the regularization parameter denoted by tau in the paper.

0.1
label_smoothing `float`, defaults to 0

The robust DPO label smoothing parameter from the cDPO report that should be between 0 and 0.5.

0
loss_type `str`, defaults to `"sigmoid"`

The type of loss to use. 'sppo' reproduces the SPPO algorithm. Other choices are explained as follows: "sigmoid" represents the default DPO loss, "hinge" is the loss from the SLiC paper, "ipo" is from the IPO paper, and "kto" is from the HALOs report.

'sigmoid'
args `transformers.TrainingArguments`

The arguments to use for training.

None
data_collator `transformers.DataCollator`

The data collator to use for training. If None is specified, the default data collator (DPODataCollatorWithPadding) will be used which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.

None
label_pad_token_id `int`, defaults to `-100`

The label pad token id. This argument is required if you want to use the default data collator.

-100
padding_value `int`, defaults to `0`

The padding value if it is different to the tokenizer's pad_token_id.

0
truncation_mode `str`, defaults to `keep_end`

The truncation mode to use, either keep_end or keep_start. This argument is required if you want to use the default data collator.

'keep_end'
train_dataset `datasets.Dataset`

The dataset to use for training.

None
eval_dataset `datasets.Dataset`

The dataset to use for evaluation.

None
processing_class `transformers.PreTrainedTokenizerBase`

The tokenizer to use for training. This argument is required if you want to use the default data collator.

None
model_init `Callable[[], transformers.PreTrainedModel]`

The model initializer to use for training. If None is specified, the default model initializer will be used.

None
callbacks `List[transformers.TrainerCallback]`

The callbacks to use for training.

None
optimizers `Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`

The optimizer and scheduler to use for training.

(None, None)
preprocess_logits_for_metrics `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`

The function to use to preprocess the logits before computing the metrics.

None
max_length `int`, defaults to `None`

The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.

None
max_prompt_length `int`, defaults to `None`

The maximum length of the prompt. This argument is required if you want to use the default data collator.

None
max_target_length `int`, defaults to `None`

The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder.

None
peft_config `Dict`, defaults to `None`

The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.

None
is_encoder_decoder `Optional[bool]`, `optional`, defaults to `None`

If no model is provided, we need to know if the model_init returns an encoder-decoder.

None
disable_dropout `bool`, defaults to `True`

Whether or not to disable dropouts in model and ref_model.

True
generate_during_eval `bool`, defaults to `False`

Whether to sample and log generations during evaluation step.

False
compute_metrics `Callable[[EvalPrediction], Dict]`, *optional*

The function to use to compute the metrics. Must take an EvalPrediction and return a dictionary mapping strings to metric values.

None
precompute_ref_log_probs `bool`, defaults to `False`

Flag to precompute reference model log probabilities for the training and evaluation datasets. This is useful if you want to train without the reference model and reduce the total GPU memory needed.

False
model_init_kwargs Optional[Dict]

Dict of optional kwargs to pass when instantiating the model from a string.

None
ref_model_init_kwargs Optional[Dict]

Dict of optional kwargs to pass when instantiating the ref model from a string.

None
model_adapter_name `str`, defaults to `None`

Name of the train target PEFT adapter, when using LoRA with multiple adapters.

None
ref_adapter_name `str`, defaults to `None`

Name of the reference PEFT adapter, when using LoRA with multiple adapters.

None
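Illustrative construction (a minimal sketch, not taken from the package): SPPOTrainer is built like a DPO-style TRL trainer from a paired dataset with prompt, chosen, and rejected columns; when ref_model is omitted and the model is not a PEFT model, a frozen reference copy is created internally. The checkpoint name below is a placeholder.

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

from aisteer360.algorithms.structural_control.wrappers.trl.sppotrainer.trainer import SPPOTrainer

model = AutoModelForCausalLM.from_pretrained("your-org/your-base-model")  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained("your-org/your-base-model")     # placeholder checkpoint
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # the default collator pads with the tokenizer's pad token

pairs = Dataset.from_dict({
    "prompt": ["Explain SPPO in one sentence."],
    "chosen": ["A concise, on-topic answer."],
    "rejected": ["An off-topic answer."],
})

trainer = SPPOTrainer(
    model=model,
    ref_model=None,     # a frozen reference copy is created internally when omitted
    beta=0.001,         # eta = 1/beta in SPPO
    loss_type="sppo",
    args=TrainingArguments(output_dir="sppo_trainer_out", max_steps=1, report_to="none"),
    train_dataset=pairs,
    processing_class=tokenizer,
    max_length=512,
    max_prompt_length=128,
)
trainer.train()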
Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/trainer.py, lines 58-1249
class SPPOTrainer(Trainer):
    r"""
    Initialize SPPOTrainer.

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForCausalLM`.
        ref_model (`PreTrainedModelWrapper`):
            Hugging Face transformer model with a causal language modeling head. Used for implicit reward computation and loss. If no
            reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
        beta (`float`, defaults to 0.1):
            The beta factor in DPO loss. In SPPO, eta=1/beta. Higher beta means less divergence from the initial policy. For the IPO loss, beta is the regularization parameter denoted by tau in the paper.
        label_smoothing (`float`, defaults to 0):
            The robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report that should be between 0 and 0.5.
        loss_type (`str`, defaults to `"sigmoid"`):
            The type of loss to use. 'sppo' reproduces the SPPO algorithm. Other choices are explained as follows: `"sigmoid"` represents the default DPO loss, `"hinge"` loss from the [SLiC](https://arxiv.org/abs/2305.10425) paper, `"ipo"` from the [IPO](https://arxiv.org/abs/2310.12036) paper, or `"kto"` from the HALOs [report](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf).
        args (`transformers.TrainingArguments`):
            The arguments to use for training.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
        label_pad_token_id (`int`, defaults to `-100`):
            The label pad token id. This argument is required if you want to use the default data collator.
        padding_value (`int`, defaults to `0`):
            The padding value if it is different to the tokenizer's pad_token_id.
        truncation_mode (`str`, defaults to `keep_end`):
            The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class (`transformers.PreTrainedTokenizerBase`):
            The tokenizer to use for training. This argument is required if you want to use the default data collator.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be used.
        callbacks (`List[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        max_length (`int`, defaults to `None`):
            The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
        max_prompt_length (`int`, defaults to `None`):
            The maximum length of the prompt. This argument is required if you want to use the default data collator.
        max_target_length (`int`, defaults to `None`):
            The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
        peft_config (`Dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
        is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
            If no model is provided, we need to know if the model_init returns an encoder-decoder.
        disable_dropout (`bool`, defaults to `True`):
            Whether or not to disable dropouts in `model` and `ref_model`.
        generate_during_eval (`bool`, defaults to `False`):
            Whether to sample and log generations during evaluation step.
        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
            The function to use to compute the metrics. Must take an `EvalPrediction` and return
            a dictionary mapping strings to metric values.
        precompute_ref_log_probs (`bool`, defaults to `False`):
            Flag to precompute reference model log probabilities for the training and evaluation datasets. This is useful if you want
            to train without the reference model and reduce the total GPU memory needed.
        model_init_kwargs (`Optional[Dict]`, *optional*):
            Dict of Optional kwargs to pass when instantiating the model from a string
        ref_model_init_kwargs (`Optional[Dict]`, *optional*):
            Dict of Optional kwargs to pass when instantiating the ref model from a string
        model_adapter_name (`str`, defaults to `None`):
            Name of the train target PEFT adapter, when using LoRA with multiple adapters.
        ref_adapter_name (`str`, defaults to `None`):
            Name of the reference PEFT adapter, when using LoRA with multiple adapters.
    """

    _tag_names = ["trl", "sppo"]

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module, str] = None,
        ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        beta: float = 0.1,
        label_smoothing: float = 0,
        loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid",
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        label_pad_token_id: int = -100,
        padding_value: int = 0,
        truncation_mode: str = "keep_end",
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        max_length: Optional[int] = None,
        max_prompt_length: Optional[int] = None,
        max_target_length: Optional[int] = None,
        peft_config: Optional[Dict] = None,
        is_encoder_decoder: Optional[bool] = None,
        disable_dropout: bool = True,
        generate_during_eval: bool = False,
        compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
        precompute_ref_log_probs: bool = False,
        model_init_kwargs: Optional[Dict] = None,
        ref_model_init_kwargs: Optional[Dict] = None,
        model_adapter_name: str = None,
        ref_adapter_name: str = None,
    ):
        if model_init_kwargs is None:
            model_init_kwargs = {}
        elif not isinstance(model, str):
            raise ValueError("You passed model_kwargs to the SPPOTrainer. But your model is already instantiated.")

        if ref_model_init_kwargs is None:
            ref_model_init_kwargs = {}
        elif not isinstance(ref_model, str):
            raise ValueError(
                "You passed ref_model_kwargs to the SPPOTrainer. But your ref_model is already instantiated."
            )

        if isinstance(model, str):
            warnings.warn(
                "You passed a model_id to the SPPOTrainer. This will automatically create an "
                "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
            )
            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

        if isinstance(ref_model, str):
            warnings.warn(
                "You passed a ref model_id to the SPPOTrainer. This will automatically create an "
                "`AutoModelForCausalLM`"
            )
            ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)

        # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
        # has been called in order to properly call autocast if needed.
        self._peft_has_been_casted_to_bf16 = False

        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            raise NotImplementedError
            # # if model is a peft model and we have a peft_config, we merge and unload it first
            # if isinstance(model, PeftModel):
            #     model = model.merge_and_unload()

            # if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
            #     _support_gc_kwargs = hasattr(
            #         args, "gradient_checkpointing_kwargs"
            #     ) and "gradient_checkpointing_kwargs" in list(
            #         inspect.signature(prepare_model_for_kbit_training).parameters
            #     )

            #     preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

            #     if _support_gc_kwargs:
            #         preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

            #     model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)
            # elif getattr(args, "gradient_checkpointing", False):
            #     # For backward compatibility with older versions of transformers
            #     if hasattr(model, "enable_input_require_grads"):
            #         model.enable_input_require_grads()
            #     else:

            #         def make_inputs_require_grad(module, input, output):
            #             output.requires_grad_(True)

            #         model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

            # # get peft model with the given config
            # model = get_peft_model(model, peft_config)
            # if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
            #     peft_module_casting_to_bf16(model)
            #     # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
            #     self._peft_has_been_casted_to_bf16 = True

        # For models that use gradient_checkpointing, we need to attach a hook that enables input
        # to explicitly have `requires_grad=True`, otherwise training will either silently
        # fail or completely fail.
        elif getattr(args, "gradient_checkpointing", False):
            # For backward compatibility with older versions of transformers
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        if generate_during_eval and not is_wandb_available():
            raise ValueError(
                "`generate_during_eval=True` requires Weights and Biases to be installed."
                " Please install `wandb` to resolve."
            )

        if model is not None:
            self.is_encoder_decoder = model.config.is_encoder_decoder
        elif is_encoder_decoder is None:
            raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
        else:
            self.is_encoder_decoder = is_encoder_decoder

        self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
        self.model_adapter_name = model_adapter_name
        self.ref_adapter_name = ref_adapter_name

        if ref_model:
            self.ref_model = ref_model
        elif self.is_peft_model or precompute_ref_log_probs:
            # The `model` with adapters turned off will be used as the reference model
            self.ref_model = None
        else:
            self.ref_model = create_reference_model(model)

        if processing_class is None:
            raise ValueError("processing_class must be specified to tokenize a SPPO dataset.")
        if max_length is None:
            warnings.warn(
                "`max_length` is not set in the SPPOTrainer's init"
                " it will default to `512` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_length = 512
        if max_prompt_length is None:
            warnings.warn(
                "`max_prompt_length` is not set in the SPPOTrainer's init"
                " it will default to `128` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_prompt_length = 128

        if max_target_length is None and self.is_encoder_decoder:
            warnings.warn(
                "When using an encoder decoder architecture, you should set `max_target_length` in the SPPOTrainer's init"
                " it will default to `128` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_target_length = 128

        if data_collator is None:
            data_collator = DPODataCollatorWithPadding(
                pad_token_id=processing_class.pad_token_id,
                label_pad_token_id=label_pad_token_id,
                is_encoder_decoder=self.is_encoder_decoder,
            )

            if args.remove_unused_columns:
                args.remove_unused_columns = False
                # warn users
                warnings.warn(
                    "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
                    " we have set it for you, but you should do it yourself in the future.",
                    UserWarning,
                )

            self.use_dpo_data_collator = True
        else:
            self.use_dpo_data_collator = False

        if disable_dropout:
            disable_dropout_in_model(model)
            if self.ref_model is not None:
                disable_dropout_in_model(self.ref_model)

        self.max_length = max_length
        self.generate_during_eval = generate_during_eval
        self.label_pad_token_id = label_pad_token_id
        self.padding_value = padding_value if padding_value is not None else processing_class.pad_token_id
        self.max_prompt_length = max_prompt_length
        self.truncation_mode = truncation_mode
        self.max_target_length = max_target_length
        self.processing_class = processing_class
        self.precompute_ref_log_probs = precompute_ref_log_probs

        # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
        # keep track of first called to avoid computation of future calls
        self._precomputed_train_ref_log_probs = False
        self._precomputed_eval_ref_log_probs = False

        if loss_type in ["hinge", "ipo", "kto_pair"] and label_smoothing > 0:
            warnings.warn(
                "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
            )

        self.beta = beta
        self.label_smoothing = label_smoothing
        self.loss_type = loss_type

        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # tokenize the dataset
        # print('=== before map', train_dataset.features)
        # chosen_probs = train_dataset['chosen_probs']
        # chosen_probs_win = train_dataset['chosen_probs_win']
        # chosen_probs_lose = train_dataset['chosen_probs_lose']
        # old_train_dataset = train_dataset
        train_dataset = train_dataset.map(self.tokenize_row)
        # print('=== before add', train_dataset.features)
        # import pandas as pd
        # mid_dataset = pd.DataFrame(train_dataset)
        # mid_dataset['chosen_probs'] = chosen_probs
        # mid_dataset['chosen_probs_win'] = chosen_probs_win
        # mid_dataset['chosen_probs_lose'] = chosen_probs_lose
        # train_dataset = Dataset.from_pandas(mid_dataset)
        # print('=== after add', train_dataset.features)
        if eval_dataset is not None:
            eval_dataset = eval_dataset.map(self.tokenize_row)
        #print('=========')
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        if not hasattr(self, "accelerator"):
            raise AttributeError(
                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
            )

        # Deepspeed Zero-3 does not support precompute_ref_log_probs
        if self.is_deepspeed_enabled:
            if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
                raise ValueError(
                    "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
                )

        if self.ref_model is None:
            if not (self.is_peft_model or self.precompute_ref_log_probs):
                raise ValueError(
                    "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
                )
        else:
            if self.is_deepspeed_enabled:
                self.ref_model = self._prepare_deepspeed(self.ref_model)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

    def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)

        if model is not None:
            if hasattr(model, "config"):
                hidden_size = (
                    max(model.config.hidden_sizes)
                    if getattr(model.config, "hidden_sizes", None)
                    else getattr(model.config, "hidden_size", None)
                )
                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
                    # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
                    # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
                    config_kwargs.update(
                        {
                            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                        }
                    )

        # If ZeRO-3 is used, we shard both the active and reference model.
        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
        if config_kwargs["zero_optimization"]["stage"] != 3:
            config_kwargs["zero_optimization"]["stage"] = 0
        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
        model.eval()
        return model

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Overrides `transformers.Trainer.get_train_dataloader` to precompute `ref_log_probs`.
        """

        if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
            dataloader_params = {
                "batch_size": self.args.per_device_train_batch_size,
                "collate_fn": self.data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
                "shuffle": False,
            }

            # prepare dataloader
            data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))

            reference_chosen_logps = []
            reference_rejected_logps = []
            for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
                reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch)
                reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics(
                    (reference_chosen_logp, reference_rejected_logp)
                )
                reference_chosen_logps.append(reference_chosen_logp.cpu())
                reference_rejected_logps.append(reference_rejected_logp.cpu())

            all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy()
            all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy()

            self.train_dataset = self.train_dataset.add_column(
                name="reference_chosen_logps", column=all_reference_chosen_logps
            )
            self.train_dataset = self.train_dataset.add_column(
                name="reference_rejected_logps", column=all_reference_rejected_logps
            )

            self._precomputed_train_ref_log_probs = True

        return super().get_train_dataloader()

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Overrides `transformers.Trainer.get_eval_dataloader` to precompute `ref_log_probs`.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
                by the `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
            dataloader_params = {
                "batch_size": self.args.per_device_eval_batch_size,
                "collate_fn": self.data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
                "shuffle": False,
            }

            # prepare dataloader
            data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))

            reference_chosen_logps = []
            reference_rejected_logps = []
            for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
                reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch)
                reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics(
                    (reference_chosen_logp, reference_rejected_logp)
                )
                reference_chosen_logps.append(reference_chosen_logp.cpu())
                reference_rejected_logps.append(reference_rejected_logp.cpu())

            all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy()
            all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy()

            eval_dataset = eval_dataset.add_column(name="reference_chosen_logps", column=all_reference_chosen_logps)
            eval_dataset = eval_dataset.add_column(
                name="reference_rejected_logps", column=all_reference_rejected_logps
            )

            # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
            if self.eval_dataset is not None:
                self.eval_dataset = eval_dataset
            self._precomputed_eval_ref_log_probs = True

        return super().get_eval_dataloader(eval_dataset=eval_dataset)

    def build_tokenized_answer(self, prompt, answer):
        """
        The Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
        It does, however, ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
        Reference:
            https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
        """

        full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
        prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]

        answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
        answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]

        # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
        full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])

        # Prepare input tokens for token by token comparison
        full_input_ids = np.array(full_tokenized["input_ids"])

        if len(full_input_ids) != len(full_concat_input_ids):
            raise ValueError("Prompt input ids and answer input ids should have the same length.")

        # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
        # can be merged together when tokenizing prompt+answer. This could result
        # on the last token from the prompt being different when tokenized on its own
        # vs when done as prompt+answer.
        response_token_ids_start_idx = len(prompt_input_ids)

        # If tokenized prompt is different than both prompt+answer, then it means the
        # last token has changed due to merging.
        if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
            response_token_ids_start_idx -= 1

        prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
        prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]

        if len(prompt_input_ids) != len(prompt_attention_mask):
            raise ValueError("Prompt input ids and attention mask should have the same length.")

        answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
        answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]

        return dict(
            prompt_input_ids=prompt_input_ids,
            prompt_attention_mask=prompt_attention_mask,
            input_ids=answer_input_ids,
            attention_mask=answer_attention_mask,
        )

    def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) -> Dict:
        """Tokenize a single row from a SPPO specific dataset.

        At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
        in case the prompt + chosen or prompt + rejected responses is/are too long. First
            we truncate the prompt; if we're still too long, we truncate the chosen/rejected.

        We also create the labels for the chosen/rejected responses, which are of length equal to
            the sum of the length of the prompt and the chosen/rejected response, with
            label_pad_token_id  for the prompt tokens.
        """
        batch = {}
        prompt = feature["prompt"]
        chosen = feature["chosen"]
        rejected = feature["rejected"]
        if not self.is_encoder_decoder:
            # Check issues below for more details
            #  1. https://github.com/huggingface/trl/issues/907
            #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
            #  3. https://github.com/LianjiaTech/BELLE/issues/337

            if not isinstance(prompt, str):
                raise ValueError(f"prompt should be an str but got {type(prompt)}")
            prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
            prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

            if not isinstance(chosen, str):
                raise ValueError(f"chosen should be an str but got {type(chosen)}")
            chosen_tokens = self.build_tokenized_answer(prompt, chosen)

            if not isinstance(rejected, str):
                raise ValueError(f"rejected should be an str but got {type(rejected)}")
            rejected_tokens = self.build_tokenized_answer(prompt, rejected)

            # Last prompt token might get merged by tokenizer and
            # it should not be included for generation if that happens
            prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])

            chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
            rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
            prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)

            for k, v in prompt_tokens.items():
                prompt_tokens[k] = v[:prompt_len_input_ids]

            # Make sure the chosen and rejected prompts differ by at most one token
            # and that their lengths differ by at most 1
            num_diff_tokens = sum(
                [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
            )
            num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
            if num_diff_tokens > 1 or num_diff_len > 1:
                raise ValueError(
                    "Chosen and rejected prompt_input_ids might only differ on the "
                    "last token due to tokenizer merge ops."
                )

            # add BOS token to head of prompt
            prompt_tokens["prompt_input_ids"] = [self.processing_class.bos_token_id] + prompt_tokens["prompt_input_ids"]
            chosen_tokens["prompt_input_ids"] = [self.processing_class.bos_token_id] + chosen_tokens["prompt_input_ids"]
            rejected_tokens["prompt_input_ids"] = [self.processing_class.bos_token_id] + rejected_tokens["prompt_input_ids"]

            prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
            chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
            rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]

            # add EOS token to end of answer
            chosen_tokens["input_ids"].append(self.processing_class.eos_token_id)
            chosen_tokens["attention_mask"].append(1)

            rejected_tokens["input_ids"].append(self.processing_class.eos_token_id)
            rejected_tokens["attention_mask"].append(1)

            longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

            # if combined sequence is too long, truncate the prompt
            for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                    if self.truncation_mode == "keep_start":
                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
                            answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
                    elif self.truncation_mode == "keep_end":
                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
                            answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
                    else:
                        raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

            # if that's still too long, truncate the response
            for answer_tokens in [chosen_tokens, rejected_tokens]:
                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                    for k in ["input_ids", "attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]

            # Create labels
            chosen_sequence_tokens = {
                k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
            }
            rejected_sequence_tokens = {
                k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
            }
            chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
            chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
                self.label_pad_token_id
            ] * len(chosen_tokens["prompt_input_ids"])
            rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
            rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
                self.label_pad_token_id
            ] * len(rejected_tokens["prompt_input_ids"])

            for k, toks in {
                "chosen_": chosen_sequence_tokens,
                "rejected_": rejected_sequence_tokens,
                "": prompt_tokens,
            }.items():
                for type_key, tokens in toks.items():
                    if type_key == "token_type_ids":
                        continue
                    batch[f"{k}{type_key}"] = tokens


        else:
            chosen_tokens = self.processing_class(
                chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True
            )
            rejected_tokens = self.processing_class(
                rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True
            )
            prompt_tokens = self.processing_class(
                prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
            )

            batch["chosen_labels"] = chosen_tokens["input_ids"]
            batch["rejected_labels"] = rejected_tokens["input_ids"]
            batch["prompt_input_ids"] = prompt_tokens["input_ids"]
            batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

            if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
                batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                    labels=batch["rejected_labels"]
                )
                batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                    labels=batch["chosen_labels"]
                )
        return batch

    @contextmanager
    def null_ref_context(self):
        """Context manager for handling null reference model (that is, peft adapter manipulation)."""
        with self.accelerator.unwrap_model(
            self.model
        ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
            if self.ref_adapter_name:
                self.model.set_adapter(self.ref_adapter_name)
            yield
            if self.ref_adapter_name:
                self.model.set_adapter(self.model_adapter_name or "default")

    def compute_reference_log_probs(self, padded_batch: Dict) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Computes log probabilities of the reference model for a single padded batch of a DPO-specific dataset."""
        compute_ref_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

        # compute reference logps
        with torch.no_grad(), compute_ref_context_manager():
            if self.ref_model is None:
                with self.null_ref_context():
                    (
                        reference_chosen_logps,
                        reference_rejected_logps,
                        _,
                        _,
                    ) = self.concatenated_forward(self.model, padded_batch)
            else:
                (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    _,
                    _,
                ) = self.concatenated_forward(self.ref_model, padded_batch)

        return reference_chosen_logps, reference_rejected_logps

    @staticmethod
    def concatenated_inputs(
        batch: Dict[str, Union[List, torch.LongTensor]],
        is_encoder_decoder: bool = False,
        label_pad_token_id: int = -100,
        padding_value: int = 0,
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.LongTensor]:
        """Concatenate the chosen and rejected inputs into a single tensor.

        Args:
            batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
            is_encoder_decoder: Whether the model is an encoder-decoder model.
            label_pad_token_id: The label pad token id.
            padding_value: The padding value to use for the concatenated inputs_ids.
            device: The device for the concatenated inputs.

        Returns:
            A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
        """
        concatenated_batch = {}

        if is_encoder_decoder:
            max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
        else:
            max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])

        for k in batch:
            if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
                if "labels" in k or is_encoder_decoder:
                    pad_value = label_pad_token_id
                elif k.endswith("_input_ids"):
                    pad_value = padding_value
                elif k.endswith("_attention_mask"):
                    pad_value = 0
                concatenated_key = k.replace("chosen", "concatenated")
                concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
        for k in batch:
            if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
                if "labels" in k or is_encoder_decoder:
                    pad_value = label_pad_token_id
                elif k.endswith("_input_ids"):
                    pad_value = padding_value
                elif k.endswith("_attention_mask"):
                    pad_value = 0
                concatenated_key = k.replace("rejected", "concatenated")
                concatenated_batch[concatenated_key] = torch.cat(
                    (
                        concatenated_batch[concatenated_key],
                        pad_to_length(batch[k], max_length, pad_value=pad_value),
                    ),
                    dim=0,
                ).to(device=device)

        if is_encoder_decoder:
            concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
            concatenated_batch["concatenated_attention_mask"] = (
                batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
            )

        return concatenated_batch

    def sppo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        chosen_probs: Union[torch.FloatTensor, None] = None,
        chosen_probs_win: Union[torch.FloatTensor, None] = None,
        chosen_probs_lose: Union[torch.FloatTensor, None] = None,
        reference_free: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the SPPO loss for a batch of policy and reference model log probabilities.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
            chosen_probs: Estimated probability (from the preference data) that the chosen response beats the rejected response; used by the "sppo_single" loss. Shape: (batch_size,)
            chosen_probs_win: Estimated win probability of the chosen response; used by the "sppo" loss. Shape: (batch_size,)
            chosen_probs_lose: Estimated win probability of the rejected response; used by the "sppo" loss. Shape: (batch_size,)
            reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.

        Returns:
            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
            The losses tensor contains the SPPO loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
        """
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        if reference_free:
            ref_logratios = 0
        else:
            ref_logratios = reference_chosen_logps - reference_rejected_logps

        pi_logratios = pi_logratios.to(self.accelerator.device)
        ref_logratios = ref_logratios.to(self.accelerator.device)
        logits = pi_logratios - ref_logratios

        # For sppo
        logits_w = policy_chosen_logps - reference_chosen_logps
        logits_l = policy_rejected_logps - reference_rejected_logps

        # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. In SPPO, beta=1/eta has a different meaning, and is usually chosen around 1e-3.
        # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
        # calculates a conservative SPPO loss.
        if self.loss_type == "sigmoid":
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        elif self.loss_type == "hinge":
            losses = torch.relu(1 - self.beta * logits)
        elif self.loss_type == "ipo":
            # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
            losses = (logits - 1 / (2 * self.beta)) ** 2
        elif self.loss_type == "sppo":
            loss_w = (logits_w - (1 / self.beta)*(chosen_probs_win - 0.5)) ** 2
            loss_l = (logits_l - (1 / self.beta)*(chosen_probs_lose - 0.5)) ** 2
            losses = (loss_w + loss_l)/2
        elif self.loss_type == "sppo_single":
            loss_w = (logits_w - (1 / self.beta)*(chosen_probs - 0.5)) ** 2
            loss_l = (logits_l + (1 / self.beta)*(chosen_probs - 0.5)) ** 2
            losses = (loss_w + loss_l)/2
        elif self.loss_type == "kto_pair":
            # eqn (7) of the HALOs paper
            chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
            rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)

            chosen_logratios = policy_chosen_logps - reference_chosen_logps
            rejected_logratios = policy_rejected_logps - reference_rejected_logps
            # As described in the KTO report, the KL term for chosen (rejected) is estimated using the rejected (chosen) half.
            losses = torch.cat(
                (
                    1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)),
                    1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)),
                ),
                0,
            )
        else:
            raise ValueError(
                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']"
            )

        chosen_rewards = (
            self.beta
            * (
                policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device)
            ).detach()
        )
        rejected_rewards = (
            self.beta
            * (
                policy_rejected_logps.to(self.accelerator.device)
                - reference_rejected_logps.to(self.accelerator.device)
            ).detach()
        )

        return losses, chosen_rewards, rejected_rewards

    @staticmethod
    def get_batch_logps(
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
            label_pad_token_id: The label pad token id.
            is_encoder_decoder: Whether the model is an encoder-decoder model.

        Returns:
            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
        """
        if logits.shape[:-1] != labels.shape:
            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

        if not is_encoder_decoder:
            labels = labels[:, 1:].clone()
            logits = logits[:, :-1, :]
        loss_mask = labels != label_pad_token_id

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == label_pad_token_id] = 0

        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

        if average_log_prob:
            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            return (per_token_logps * loss_mask).sum(-1)

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.
        """
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )
        len_chosen = batch["chosen_labels"].shape[0]

        model_kwargs = (
            {
                "labels": concatenated_batch["concatenated_labels"],
                "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None),
            }
            if self.is_encoder_decoder
            else {}
        )
        all_logits = model(
            concatenated_batch["concatenated_input_ids"],
            attention_mask=concatenated_batch["concatenated_attention_mask"],
            **model_kwargs,
        ).logits

        all_logps = self.get_batch_logps(
            all_logits,
            concatenated_batch["concatenated_labels"],
            average_log_prob=False,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        chosen_logps = all_logps[:len_chosen]
        rejected_logps = all_logps[len_chosen:]

        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the SPPO loss and other metrics for the given batch of inputs for train or test."""
        metrics = {}

        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
        ) = self.concatenated_forward(model, batch)

        chosen_probs = torch.tensor(batch["chosen_probs"], dtype=float, device=policy_chosen_logps.device)
        chosen_probs_win = torch.tensor(batch["chosen_probs_win"], dtype=float, device=policy_chosen_logps.device)
        chosen_probs_lose = torch.tensor(batch["chosen_probs_lose"], dtype=float, device=policy_chosen_logps.device)
        # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
        if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch:
            reference_chosen_logps = batch["reference_chosen_logps"]
            reference_rejected_logps = batch["reference_rejected_logps"]
        else:
            with torch.no_grad():
                if self.ref_model is None:
                    with self.null_ref_context():
                        (
                            reference_chosen_logps,
                            reference_rejected_logps,
                            _,
                            _,
                        ) = self.concatenated_forward(self.model, batch)
                else:
                    (
                        reference_chosen_logps,
                        reference_rejected_logps,
                        _,
                        _,
                    ) = self.concatenated_forward(self.ref_model, batch)

        losses, chosen_rewards, rejected_rewards = self.sppo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
            chosen_probs,
            chosen_probs_win,
            chosen_probs_lose,
            # rejected_probs,
        )
        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()

        return losses.mean(), metrics

    def compute_loss(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
        num_items_in_batch=None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        if not self.use_dpo_data_collator:
            warnings.warn(
                "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
            )

        compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

        with compute_loss_context_manager():
            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

        # force log the metrics
        self.store_metrics(metrics, train_eval="train")

        if return_outputs:
            return (loss, metrics)
        return loss

    def generate_from_model(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[List[str], List[str]]:
        """Generate samples from the model and reference model for the given batch of inputs."""

        # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
        # the torch cuda amp context manager as some hidden states are silently casted to full precision.
        generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast

        with generate_context_manager():
            policy_output = model.generate(
                input_ids=batch["prompt_input_ids"],
                attention_mask=batch["prompt_attention_mask"],
                max_length=self.max_length,
                do_sample=True,
                pad_token_id=self.processing_class.pad_token_id,
            )

            # if reference_output in batch use that otherwise use the reference model
            if "reference_output" in batch:
                reference_output = batch["reference_output"]
            else:
                if self.ref_model is None:
                    with self.null_ref_context():
                        reference_output = self.model.generate(
                            input_ids=batch["prompt_input_ids"],
                            attention_mask=batch["prompt_attention_mask"],
                            max_length=self.max_length,
                            do_sample=True,
                            pad_token_id=self.processing_class.pad_token_id,
                        )
                else:
                    reference_output = self.ref_model.generate(
                        input_ids=batch["prompt_input_ids"],
                        attention_mask=batch["prompt_attention_mask"],
                        max_length=self.max_length,
                        do_sample=True,
                        pad_token_id=self.processing_class.pad_token_id,
                    )

        policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
        policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)

        reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
        reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)

        return policy_output_decoded, reference_output_decoded

    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ):
        if not self.use_dpo_data_collator:
            warnings.warn(
                "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
            )
        if ignore_keys is None:
            if hasattr(model, "config"):
                ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

        with torch.no_grad(), prediction_context_manager():
            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

        # force log the metrics
        self.store_metrics(metrics, train_eval="eval")

        if prediction_loss_only:
            return (loss.detach(), None, None)

        # logits for the chosen and rejected samples from model
        logits_dict = {
            "eval_logits/chosen": metrics["eval_logits/chosen"],
            "eval_logits/rejected": metrics["eval_logits/rejected"],
        }
        logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
        logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

        return (loss.detach(), logits, labels)

    def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Overriding built-in evaluation loop to store metrics for each batch.
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.
        """

        # Sample and save to game log if requested (for one batch to save time)
        if self.generate_during_eval:
            # Generate random indices within the range of the total number of samples
            num_samples = len(dataloader.dataset)
            random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)

            # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
            random_batch_dataset = dataloader.dataset.select(random_indices)
            random_batch = self.data_collator(random_batch_dataset)
            random_batch = self._prepare_inputs(random_batch)

            policy_output_decoded, ref_output_decoded = self.generate_from_model(self.model, random_batch)

            self.log(
                {
                    "game_log": wandb.Table(
                        columns=["Prompt", "Policy", "Ref Model"],
                        rows=[
                            [prompt, pol[len(prompt) :], ref[len(prompt) :]]
                            for prompt, pol, ref in zip(
                                random_batch["prompt"], policy_output_decoded, ref_output_decoded
                            )
                        ],
                    )
                }
            )
            self.state.log_history.pop()

        # Base evaluation
        initial_output = super().evaluation_loop(
            dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
        )

        return initial_output

    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
        """
        Log `logs` on the various objects watching training, including stored metrics.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]
        return super().log(logs, start_time)
beta = beta instance-attribute
generate_during_eval = generate_during_eval instance-attribute
is_encoder_decoder = model.config.is_encoder_decoder instance-attribute
is_peft_model = is_peft_available() and isinstance(model, PeftModel) instance-attribute
label_pad_token_id = label_pad_token_id instance-attribute
label_smoothing = label_smoothing instance-attribute
loss_type = loss_type instance-attribute
max_length = max_length instance-attribute
max_prompt_length = max_prompt_length instance-attribute
max_target_length = max_target_length instance-attribute
model_adapter_name = model_adapter_name instance-attribute
padding_value = padding_value if padding_value is not None else processing_class.pad_token_id instance-attribute
precompute_ref_log_probs = precompute_ref_log_probs instance-attribute
processing_class = processing_class instance-attribute
ref_adapter_name = ref_adapter_name instance-attribute
ref_model = ref_model instance-attribute
truncation_mode = truncation_mode instance-attribute
use_dpo_data_collator = True instance-attribute
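
The attributes above mirror a DPOTrainer-style constructor interface. The following construction sketch is for orientation only; the class name SPPOTrainer, the import path, and the exact argument list are assumptions inferred from the instance attributes listed above, not confirmed by this page.

from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

# Assumed import path and class name (see the source file referenced below):
from aisteer360.algorithms.structural_control.wrappers.trl.sppotrainer.trainer import SPPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

train_dataset = ...  # a datasets.Dataset with "prompt", "chosen", "rejected" and the SPPO
                     # win-probability columns "chosen_probs", "chosen_probs_win", "chosen_probs_lose"

trainer = SPPOTrainer(
    model=model,
    ref_model=ref_model,
    args=TrainingArguments(output_dir="sppo-out", per_device_train_batch_size=2),
    beta=0.001,          # in SPPO, beta = 1/eta, typically around 1e-3 (see sppo_loss)
    loss_type="sppo",
    max_length=1024,
    max_prompt_length=512,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()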
build_tokenized_answer(prompt, answer)

Llama tokenizer does not satisfy enc(a + b) = enc(a) + enc(b). It does ensure enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]. Reference: https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/trainer.py
def build_tokenized_answer(self, prompt, answer):
    """
    Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
    It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
    Reference:
        https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
    """

    full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
    prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]

    answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
    answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]

    # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
    full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])

    # Prepare input tokens for token by token comparison
    full_input_ids = np.array(full_tokenized["input_ids"])

    if len(full_input_ids) != len(full_concat_input_ids):
        raise ValueError("Prompt input ids and answer input ids should have the same length.")

    # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
    # can be merged together when tokenizing prompt+answer. This could result
    # on the last token from the prompt being different when tokenized on its own
    # vs when done as prompt+answer.
    response_token_ids_start_idx = len(prompt_input_ids)

    # If tokenized prompt is different than both prompt+answer, then it means the
    # last token has changed due to merging.
    if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
        response_token_ids_start_idx -= 1

    prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
    prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]

    if len(prompt_input_ids) != len(prompt_attention_mask):
        raise ValueError("Prompt input ids and attention mask should have the same length.")

    answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
    answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]

    return dict(
        prompt_input_ids=prompt_input_ids,
        prompt_attention_mask=prompt_attention_mask,
        input_ids=answer_input_ids,
        attention_mask=answer_attention_mask,
    )
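
A minimal usage sketch, assuming trainer is an already-constructed instance of this trainer (the prompt and answer strings are illustrative):

tokenized = trainer.build_tokenized_answer(
    prompt="Question: What is 2 + 2?\nAnswer:",
    answer=" 4",
)

# Prompt and answer tokens are returned separately so that labels can later mask the prompt.
prompt_len = len(tokenized["prompt_input_ids"])
assert len(tokenized["prompt_attention_mask"]) == prompt_len
assert len(tokenized["input_ids"]) == len(tokenized["attention_mask"])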
compute_loss(model, inputs, return_outputs=False, num_items_in_batch=None)
Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/trainer.py
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: Dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
    num_items_in_batch=None
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
    if not self.use_dpo_data_collator:
        warnings.warn(
            "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
            "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
        )

    compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

    with compute_loss_context_manager():
        loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

    # force log the metrics
    self.store_metrics(metrics, train_eval="train")

    if return_outputs:
        return (loss, metrics)
    return loss
compute_reference_log_probs(padded_batch)

Computes log probabilities of the reference model for a single padded batch of a DPO-specific dataset.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/trainer.py
def compute_reference_log_probs(self, padded_batch: Dict) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """Computes log probabilities of the reference model for a single padded batch of a DPO-specific dataset."""
    compute_ref_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

    # compute reference logps
    with torch.no_grad(), compute_ref_context_manager():
        if self.ref_model is None:
            with self.null_ref_context():
                (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    _,
                    _,
                ) = self.concatenated_forward(self.model, padded_batch)
        else:
            (
                reference_chosen_logps,
                reference_rejected_logps,
                _,
                _,
            ) = self.concatenated_forward(self.ref_model, padded_batch)

    return reference_chosen_logps, reference_rejected_logps
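
A call sketch (padded_batch stands for a batch produced by the trainer's padding data collator; the name is illustrative):

ref_chosen_logps, ref_rejected_logps = trainer.compute_reference_log_probs(padded_batch)
# Both tensors have shape (batch_size,) and can be cached (see precompute_ref_log_probs)
# to avoid re-running the reference model every step.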
concatenated_forward(model, batch)

Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

We do this to avoid doing two forward passes, because it's faster for FSDP.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/trainer.py
def concatenated_forward(
    self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

    We do this to avoid doing two forward passes, because it's faster for FSDP.
    """
    concatenated_batch = self.concatenated_inputs(
        batch,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
        padding_value=self.padding_value,
        device=self.accelerator.device,
    )
    len_chosen = batch["chosen_labels"].shape[0]

    model_kwargs = (
        {
            "labels": concatenated_batch["concatenated_labels"],
            "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None),
        }
        if self.is_encoder_decoder
        else {}
    )
    all_logits = model(
        concatenated_batch["concatenated_input_ids"],
        attention_mask=concatenated_batch["concatenated_attention_mask"],
        **model_kwargs,
    ).logits

    all_logps = self.get_batch_logps(
        all_logits,
        concatenated_batch["concatenated_labels"],
        average_log_prob=False,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )

    chosen_logps = all_logps[:len_chosen]
    rejected_logps = all_logps[len_chosen:]

    chosen_logits = all_logits[:len_chosen]
    rejected_logits = all_logits[len_chosen:]

    return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
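
A call sketch, assuming trainer is a constructed instance and padded_batch is a collated preference batch (illustrative names):

(
    policy_chosen_logps,     # shape (batch_size,)
    policy_rejected_logps,   # shape (batch_size,)
    policy_chosen_logits,    # shape (batch_size, seq_len, vocab_size)
    policy_rejected_logits,
) = trainer.concatenated_forward(trainer.model, padded_batch)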
concatenated_inputs(batch, is_encoder_decoder=False, label_pad_token_id=-100, padding_value=0, device=None) staticmethod

Concatenate the chosen and rejected inputs into a single tensor.

Parameters:

Name Type Description Default
batch Dict[str, Union[List, LongTensor]]

A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).

required
is_encoder_decoder bool

Whether the model is an encoder-decoder model.

False
label_pad_token_id int

The label pad token id.

-100
padding_value int

The padding value to use for the concatenated inputs_ids.

0
device Optional[device]

The device for the concatenated inputs.

None

Returns:

Type Description
Dict[str, LongTensor]

A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/trainer.py
@staticmethod
def concatenated_inputs(
    batch: Dict[str, Union[List, torch.LongTensor]],
    is_encoder_decoder: bool = False,
    label_pad_token_id: int = -100,
    padding_value: int = 0,
    device: Optional[torch.device] = None,
) -> Dict[str, torch.LongTensor]:
    """Concatenate the chosen and rejected inputs into a single tensor.

    Args:
        batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
        is_encoder_decoder: Whether the model is an encoder-decoder model.
        label_pad_token_id: The label pad token id.
        padding_value: The padding value to use for the concatenated inputs_ids.
        device: The device for the concatenated inputs.

    Returns:
        A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
    """
    concatenated_batch = {}

    if is_encoder_decoder:
        max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
    else:
        max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])

    for k in batch:
        if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
            if "labels" in k or is_encoder_decoder:
                pad_value = label_pad_token_id
            elif k.endswith("_input_ids"):
                pad_value = padding_value
            elif k.endswith("_attention_mask"):
                pad_value = 0
            concatenated_key = k.replace("chosen", "concatenated")
            concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
    for k in batch:
        if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
            if "labels" in k or is_encoder_decoder:
                pad_value = label_pad_token_id
            elif k.endswith("_input_ids"):
                pad_value = padding_value
            elif k.endswith("_attention_mask"):
                pad_value = 0
            concatenated_key = k.replace("rejected", "concatenated")
            concatenated_batch[concatenated_key] = torch.cat(
                (
                    concatenated_batch[concatenated_key],
                    pad_to_length(batch[k], max_length, pad_value=pad_value),
                ),
                dim=0,
            ).to(device=device)

    if is_encoder_decoder:
        concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
        concatenated_batch["concatenated_attention_mask"] = (
            batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
        )

    return concatenated_batch
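
Because concatenated_inputs is a staticmethod, its padding behavior can be illustrated on toy tensors. The class name SPPOTrainer is an assumption for the wrapper class documented on this page:

import torch

batch = {
    "chosen_input_ids": torch.tensor([[5, 6, 7]]),
    "chosen_attention_mask": torch.tensor([[1, 1, 1]]),
    "chosen_labels": torch.tensor([[-100, 6, 7]]),
    "rejected_input_ids": torch.tensor([[5, 6, 8, 9, 10]]),
    "rejected_attention_mask": torch.tensor([[1, 1, 1, 1, 1]]),
    "rejected_labels": torch.tensor([[-100, 6, 8, 9, 10]]),
}

out = SPPOTrainer.concatenated_inputs(batch)
# out["concatenated_input_ids"] has shape (2, 5): the chosen row is right-padded with 0,
# the rejected row is unchanged; labels are padded with -100 and attention masks with 0.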
evaluation_loop(dataloader, description, prediction_loss_only=None, ignore_keys=None, metric_key_prefix='eval')

Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by Trainer.evaluate() and Trainer.predict().

Works both with or without labels.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/trainer.py
def evaluation_loop(
    self,
    dataloader: DataLoader,
    description: str,
    prediction_loss_only: Optional[bool] = None,
    ignore_keys: Optional[List[str]] = None,
    metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
    """
    Overriding built-in evaluation loop to store metrics for each batch.
    Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

    Works both with or without labels.
    """

    # Sample and save to game log if requested (for one batch to save time)
    if self.generate_during_eval:
        # Generate random indices within the range of the total number of samples
        num_samples = len(dataloader.dataset)
        random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)

        # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
        random_batch_dataset = dataloader.dataset.select(random_indices)
        random_batch = self.data_collator(random_batch_dataset)
        random_batch = self._prepare_inputs(random_batch)

        policy_output_decoded, ref_output_decoded = self.generate_from_model(self.model, random_batch)

        self.log(
            {
                "game_log": wandb.Table(
                    columns=["Prompt", "Policy", "Ref Model"],
                    rows=[
                        [prompt, pol[len(prompt) :], ref[len(prompt) :]]
                        for prompt, pol, ref in zip(
                            random_batch["prompt"], policy_output_decoded, ref_output_decoded
                        )
                    ],
                )
            }
        )
        self.state.log_history.pop()

    # Base evaluation
    initial_output = super().evaluation_loop(
        dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
    )

    return initial_output
generate_from_model(model, batch)

Generate samples from the model and reference model for the given batch of inputs.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/trainer.py
def generate_from_model(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[List[str], List[str]]:
    """Generate samples from the model and reference model for the given batch of inputs."""

    # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
    # the torch cuda amp context manager as some hidden states are silently casted to full precision.
    generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast

    with generate_context_manager():
        policy_output = model.generate(
            input_ids=batch["prompt_input_ids"],
            attention_mask=batch["prompt_attention_mask"],
            max_length=self.max_length,
            do_sample=True,
            pad_token_id=self.processing_class.pad_token_id,
        )

        # if reference_output in batch use that otherwise use the reference model
        if "reference_output" in batch:
            reference_output = batch["reference_output"]
        else:
            if self.ref_model is None:
                with self.null_ref_context():
                    reference_output = self.model.generate(
                        input_ids=batch["prompt_input_ids"],
                        attention_mask=batch["prompt_attention_mask"],
                        max_length=self.max_length,
                        do_sample=True,
                        pad_token_id=self.processing_class.pad_token_id,
                    )
            else:
                reference_output = self.ref_model.generate(
                    input_ids=batch["prompt_input_ids"],
                    attention_mask=batch["prompt_attention_mask"],
                    max_length=self.max_length,
                    do_sample=True,
                    pad_token_id=self.processing_class.pad_token_id,
                )

    policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
    policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)

    reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
    reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)

    return policy_output_decoded, reference_output_decoded
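
A usage sketch (trainer is a constructed instance; the prompt text is illustrative). The decoded outputs include the prompt, which is why evaluation_loop strips it by prompt length before logging:

enc = trainer.processing_class(
    ["Summarize the plot of Hamlet in one sentence."],
    return_tensors="pt",
    padding=True,
).to(trainer.accelerator.device)

batch = {
    "prompt_input_ids": enc["input_ids"],
    "prompt_attention_mask": enc["attention_mask"],
}
policy_texts, reference_texts = trainer.generate_from_model(trainer.model, batch)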
get_batch_logps(logits, labels, average_log_prob=False, label_pad_token_id=-100, is_encoder_decoder=False) staticmethod

Compute the log probabilities of the given labels under the given logits.

Parameters:

Name Type Description Default
logits FloatTensor

Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)

required
labels LongTensor

Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)

required
average_log_prob bool

If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.

False
label_pad_token_id int

The label pad token id.

-100
is_encoder_decoder bool

Whether the model is an encoder-decoder model.

False

Returns:

Type Description
FloatTensor

A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/trainer.py
@staticmethod
def get_batch_logps(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    average_log_prob: bool = False,
    label_pad_token_id: int = -100,
    is_encoder_decoder: bool = False,
) -> torch.FloatTensor:
    """Compute the log probabilities of the given labels under the given logits.

    Args:
        logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
        labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
        average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
        label_pad_token_id: The label pad token id.
        is_encoder_decoder: Whether the model is an encoder-decoder model.

    Returns:
        A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
    """
    if logits.shape[:-1] != labels.shape:
        raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

    if not is_encoder_decoder:
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]
    loss_mask = labels != label_pad_token_id

    # dummy token; we'll ignore the losses on these tokens later
    labels[labels == label_pad_token_id] = 0

    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    else:
        return (per_token_logps * loss_mask).sum(-1)
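
Since get_batch_logps is a staticmethod, its masking and reduction behavior can be checked on toy tensors (the class name SPPOTrainer is assumed, as above):

import torch

logits = torch.randn(2, 6, 32)          # (batch_size, seq_len, vocab_size)
labels = torch.randint(0, 32, (2, 6))   # (batch_size, seq_len)
labels[:, :3] = -100                    # positions with label_pad_token_id are ignored

sum_logps = SPPOTrainer.get_batch_logps(logits, labels)                         # shape (2,)
avg_logps = SPPOTrainer.get_batch_logps(logits, labels, average_log_prob=True)  # shape (2,)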
get_batch_loss_metrics(model, batch, train_eval='train')

Compute the SPPO loss and other metrics for the given batch of inputs for train or test.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/trainer.py
def get_batch_loss_metrics(
    self,
    model,
    batch: Dict[str, Union[List, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the SPPO loss and other metrics for the given batch of inputs for train or test."""
    metrics = {}

    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits,
        policy_rejected_logits,
    ) = self.concatenated_forward(model, batch)

    chosen_probs = torch.tensor(batch["chosen_probs"], dtype=float, device=policy_chosen_logps.device)
    chosen_probs_win = torch.tensor(batch["chosen_probs_win"], dtype=float, device=policy_chosen_logps.device)
    chosen_probs_lose = torch.tensor(batch["chosen_probs_lose"], dtype=float, device=policy_chosen_logps.device)
    # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
    if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch:
        reference_chosen_logps = batch["reference_chosen_logps"]
        reference_rejected_logps = batch["reference_rejected_logps"]
    else:
        with torch.no_grad():
            if self.ref_model is None:
                with self.null_ref_context():
                    (
                        reference_chosen_logps,
                        reference_rejected_logps,
                        _,
                        _,
                    ) = self.concatenated_forward(self.model, batch)
            else:
                (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    _,
                    _,
                ) = self.concatenated_forward(self.ref_model, batch)

    losses, chosen_rewards, rejected_rewards = self.sppo_loss(
        policy_chosen_logps,
        policy_rejected_logps,
        reference_chosen_logps,
        reference_rejected_logps,
        chosen_probs,
        chosen_probs_win,
        chosen_probs_lose,
        # rejected_probs,
    )
    reward_accuracies = (chosen_rewards > rejected_rewards).float()

    prefix = "eval_" if train_eval == "eval" else ""
    metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
    metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
    metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
    metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
    metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
    metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
    metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
    metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()

    return losses.mean(), metrics
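
A call sketch (batch is a collated preference batch that also carries the "chosen_probs", "chosen_probs_win", and "chosen_probs_lose" columns; names outside this page are illustrative):

loss, metrics = trainer.get_batch_loss_metrics(trainer.model, batch, train_eval="eval")
# Metric keys are prefixed with "eval_", e.g. "eval_rewards/margins" and "eval_logps/chosen".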
get_eval_dataloader(eval_dataset=None)

Returns the evaluation [~torch.utils.data.DataLoader].

Overrides transformers.Trainer.get_eval_dataloader to precompute ref_log_probs.

Parameters:

Name Type Description Default
eval_dataset `torch.utils.data.Dataset`, *optional*

If provided, will override self.eval_dataset. If it is a [~datasets.Dataset], columns not accepted by the model.forward() method are automatically removed. It must implement __len__.

None
Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/trainer.py
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
    """
    Returns the evaluation [`~torch.utils.data.DataLoader`].

    Overrides `transformers.Trainer.get_eval_dataloader` to precompute `ref_log_probs`.

    Args:
        eval_dataset (`torch.utils.data.Dataset`, *optional*):
            If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
            by the `model.forward()` method are automatically removed. It must implement `__len__`.
    """
    if eval_dataset is None and self.eval_dataset is None:
        raise ValueError("Trainer: evaluation requires an eval_dataset.")
    eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

    if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
        dataloader_params = {
            "batch_size": self.args.per_device_eval_batch_size,
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "shuffle": False,
        }

        # prepare dataloader
        data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))

        reference_chosen_logps = []
        reference_rejected_logps = []
        for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
            reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch)
            reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics(
                (reference_chosen_logp, reference_rejected_logp)
            )
            reference_chosen_logps.append(reference_chosen_logp.cpu())
            reference_rejected_logps.append(reference_rejected_logp.cpu())

        all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy()
        all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy()

        eval_dataset = eval_dataset.add_column(name="reference_chosen_logps", column=all_reference_chosen_logps)
        eval_dataset = eval_dataset.add_column(
            name="reference_rejected_logps", column=all_reference_rejected_logps
        )

        # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
        if self.eval_dataset is not None:
            self.eval_dataset = eval_dataset
        self._precomputed_eval_ref_log_probs = True

    return super().get_eval_dataloader(eval_dataset=eval_dataset)
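A minimal usage sketch of the caching behaviour above (the trainer instance is hypothetical and assumed to have been constructed with precompute_ref_log_probs=True):

eval_loader = trainer.get_eval_dataloader()   # first call: reference log-probs computed once
assert "reference_chosen_logps" in trainer.eval_dataset.column_names
assert "reference_rejected_logps" in trainer.eval_dataset.column_names
eval_loader = trainer.get_eval_dataloader()   # later calls: cached columns reused, no reference forward pass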
get_train_dataloader()

Returns the training torch.utils.data.DataLoader.

Overrides transformers.Trainer.get_train_dataloader to precompute ref_log_probs.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/trainer.py
def get_train_dataloader(self) -> DataLoader:
    """
    Returns the training [`~torch.utils.data.DataLoader`].

    Overrides `transformers.Trainer.get_train_dataloader` to precompute `ref_log_probs`.
    """

    if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
        dataloader_params = {
            "batch_size": self.args.per_device_train_batch_size,
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "shuffle": False,
        }

        # prepare dataloader
        data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))

        reference_chosen_logps = []
        reference_rejected_logps = []
        for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
            reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch)
            reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics(
                (reference_chosen_logp, reference_rejected_logp)
            )
            reference_chosen_logps.append(reference_chosen_logp.cpu())
            reference_rejected_logps.append(reference_rejected_logp.cpu())

        all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy()
        all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy()

        self.train_dataset = self.train_dataset.add_column(
            name="reference_chosen_logps", column=all_reference_chosen_logps
        )
        self.train_dataset = self.train_dataset.add_column(
            name="reference_rejected_logps", column=all_reference_rejected_logps
        )

        self._precomputed_train_ref_log_probs = True

    return super().get_train_dataloader()
log(logs, start_time=None)

Log logs on the various objects watching training, including stored metrics.

Parameters:

Name Type Description Default
logs `Dict[str, float]`

The values to log.

required
Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/trainer.py
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
    """
    Log `logs` on the various objects watching training, including stored metrics.

    Args:
        logs (`Dict[str, float]`):
            The values to log.
    """
    # logs either has 'loss' or 'eval_loss'
    train_eval = "train" if "loss" in logs else "eval"
    # Add averaged stored metrics to logs
    for key, metrics in self._stored_metrics[train_eval].items():
        logs[key] = torch.tensor(metrics).mean().item()
    del self._stored_metrics[train_eval]
    return super().log(logs, start_time)
null_ref_context()

Context manager for handling null reference model (that is, peft adapter manipulation).

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/trainer.py
@contextmanager
def null_ref_context(self):
    """Context manager for handling null reference model (that is, peft adapter manipulation)."""
    with self.accelerator.unwrap_model(
        self.model
    ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
        if self.ref_adapter_name:
            self.model.set_adapter(self.ref_adapter_name)
        yield
        if self.ref_adapter_name:
            self.model.set_adapter(self.model_adapter_name or "default")
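A hedged usage sketch (the trainer and model_inputs below are hypothetical): any forward pass inside the context sees the reference weights, either the base model with the policy adapters disabled or the adapter named ref_adapter_name.

with trainer.null_ref_context():
    # policy adapters are disabled (or the reference adapter is active) in here,
    # so this forward pass produces reference-model outputs
    reference_outputs = trainer.model(**model_inputs)
# on exit, adapters are re-enabled / the policy adapter (model_adapter_name or "default") is restored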
prediction_step(model, inputs, prediction_loss_only, ignore_keys=None)
Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/trainer.py
def prediction_step(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: Dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[List[str]] = None,
):
    if not self.use_dpo_data_collator:
        warnings.warn(
            "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
            "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
        )
    if ignore_keys is None:
        if hasattr(model, "config"):
            ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
        else:
            ignore_keys = []

    prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

    with torch.no_grad(), prediction_context_manager():
        loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

    # force log the metrics
    self.store_metrics(metrics, train_eval="eval")

    if prediction_loss_only:
        return (loss.detach(), None, None)

    # logits for the chosen and rejected samples from model
    logits_dict = {
        "eval_logits/chosen": metrics["eval_logits/chosen"],
        "eval_logits/rejected": metrics["eval_logits/rejected"],
    }
    logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
    logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
    labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

    return (loss.detach(), logits, labels)
sppo_loss(policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, chosen_probs=None, chosen_probs_win=None, chosen_probs_lose=None, reference_free=False)

Compute the SPPO loss for a batch of policy and reference model log probabilities.

Parameters:

Name Type Description Default
policy_chosen_logps FloatTensor

Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)

required
policy_rejected_logps FloatTensor

Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)

required
reference_chosen_logps FloatTensor

Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)

required
reference_rejected_logps FloatTensor

Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)

required
reference_free bool

If True, we ignore the provided reference model and implicitly use a reference model that assigns equal probability to all responses.

False

Returns:

Type Description
FloatTensor

A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).

FloatTensor

The losses tensor contains the SPPO loss for each example in the batch.

FloatTensor

The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/trainer.py
def sppo_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
    chosen_probs: Union[torch.FloatTensor, None] = None,
    chosen_probs_win: Union[torch.FloatTensor, None] = None,
    chosen_probs_lose: Union[torch.FloatTensor, None] = None,
    reference_free: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the SPPO loss for a batch of policy and reference model log probabilities.

    Args:
        policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
        policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
        reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
        reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
        reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.

    Returns:
        A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
        The losses tensor contains the SPPO loss for each example in the batch.
        The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
    """
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    if reference_free:
        ref_logratios = 0
    else:
        ref_logratios = reference_chosen_logps - reference_rejected_logps

    pi_logratios = pi_logratios.to(self.accelerator.device)
    ref_logratios = ref_logratios.to(self.accelerator.device)
    logits = pi_logratios - ref_logratios

    # For sppo
    logits_w = policy_chosen_logps - reference_chosen_logps
    logits_l = policy_rejected_logps - reference_rejected_logps

    # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. In SPPO, beta=1/eta has a different meaning, and is usually chosen around 1e-3.
    # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
    # calculates a conservative SPPO loss.
    if self.loss_type == "sigmoid":
        losses = (
            -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
            - F.logsigmoid(-self.beta * logits) * self.label_smoothing
        )
    elif self.loss_type == "hinge":
        losses = torch.relu(1 - self.beta * logits)
    elif self.loss_type == "ipo":
        # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
        losses = (logits - 1 / (2 * self.beta)) ** 2
    elif self.loss_type == "sppo":
        loss_w = (logits_w - (1 / self.beta)*(chosen_probs_win - 0.5)) ** 2
        loss_l = (logits_l - (1 / self.beta)*(chosen_probs_lose - 0.5)) ** 2
        losses = (loss_w + loss_l)/2
    elif self.loss_type == "sppo_single":
        loss_w = (logits_w - (1 / self.beta)*(chosen_probs - 0.5)) ** 2
        loss_l = (logits_l + (1 / self.beta)*(chosen_probs - 0.5)) ** 2
        losses = (loss_w + loss_l)/2
    elif self.loss_type == "kto_pair":
        # eqn (7) of the HALOs paper
        chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
        rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)

        chosen_logratios = policy_chosen_logps - reference_chosen_logps
        rejected_logratios = policy_rejected_logps - reference_rejected_logps
        # As described in the KTO report, the KL term for chosen (rejected) is estimated using the rejected (chosen) half.
        losses = torch.cat(
            (
                1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)),
                1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)),
            ),
            0,
        )
    else:
        raise ValueError(
            f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'sppo', 'sppo_single', 'kto_pair']"
        )

    chosen_rewards = (
        self.beta
        * (
            policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device)
        ).detach()
    )
    rejected_rewards = (
        self.beta
        * (
            policy_rejected_logps.to(self.accelerator.device)
            - reference_rejected_logps.to(self.accelerator.device)
        ).detach()
    )

    return losses, chosen_rewards, rejected_rewards
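A toy, self-contained sketch of the loss_type="sppo" branch above; all numbers are made up and only show how the per-example terms combine.

import torch

beta = 1e-3                                                   # beta = 1/eta, typically ~1e-3 for SPPO
policy_chosen_logps = torch.tensor([-12.0])
policy_rejected_logps = torch.tensor([-15.0])
reference_chosen_logps = torch.tensor([-13.0])
reference_rejected_logps = torch.tensor([-14.0])
chosen_probs_win = torch.tensor([0.8])                        # avg. win-probability of the chosen response (see prepare_score)
chosen_probs_lose = torch.tensor([0.3])                       # avg. win-probability of the rejected response

logits_w = policy_chosen_logps - reference_chosen_logps       # log-ratio for the winner
logits_l = policy_rejected_logps - reference_rejected_logps   # log-ratio for the loser
loss_w = (logits_w - (1 / beta) * (chosen_probs_win - 0.5)) ** 2
loss_l = (logits_l - (1 / beta) * (chosen_probs_lose - 0.5)) ** 2
losses = (loss_w + loss_l) / 2                                # matches the "sppo" branch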
store_metrics(metrics, train_eval='train')
Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/trainer.py
def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
    for key, value in metrics.items():
        self._stored_metrics[train_eval][key].append(value)
tokenize_row(feature, model=None)

Tokenize a single row from an SPPO-specific dataset.

At this stage, we don't convert to PyTorch tensors yet; we just handle truncation in case the combined prompt + chosen or prompt + rejected sequence is too long. First we truncate the prompt; if we're still too long, we truncate the chosen/rejected response.

We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length of the prompt and the chosen/rejected response, with label_pad_token_id for the prompt tokens.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/trainer.py
def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) -> Dict:
    """Tokenize a single row from a SPPO specific dataset.

    At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
    in case the prompt + chosen or prompt + rejected responses is/are too long. First
        we truncate the prompt; if we're still too long, we truncate the chosen/rejected.

    We also create the labels for the chosen/rejected responses, which are of length equal to
        the sum of the length of the prompt and the chosen/rejected response, with
        label_pad_token_id  for the prompt tokens.
    """
    batch = {}
    prompt = feature["prompt"]
    chosen = feature["chosen"]
    rejected = feature["rejected"]
    if not self.is_encoder_decoder:
        # Check issues below for more details
        #  1. https://github.com/huggingface/trl/issues/907
        #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
        #  3. https://github.com/LianjiaTech/BELLE/issues/337

        if not isinstance(prompt, str):
            raise ValueError(f"prompt should be an str but got {type(prompt)}")
        prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
        prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

        if not isinstance(chosen, str):
            raise ValueError(f"chosen should be an str but got {type(chosen)}")
        chosen_tokens = self.build_tokenized_answer(prompt, chosen)

        if not isinstance(rejected, str):
            raise ValueError(f"rejected should be an str but got {type(rejected)}")
        rejected_tokens = self.build_tokenized_answer(prompt, rejected)

        # Last prompt token might get merged by tokenizer and
        # it should not be included for generation if that happens
        prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])

        chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
        rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
        prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)

        for k, v in prompt_tokens.items():
            prompt_tokens[k] = v[:prompt_len_input_ids]

        # Make sure the chosen and rejected prompts differ by at most one token
        # and that their lengths differ by at most 1
        num_diff_tokens = sum(
            [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
        )
        num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
        if num_diff_tokens > 1 or num_diff_len > 1:
            raise ValueError(
                "Chosen and rejected prompt_input_ids might only differ on the "
                "last token due to tokenizer merge ops."
            )

        # add BOS token to head of prompt
        prompt_tokens["prompt_input_ids"] = [self.processing_class.bos_token_id] + prompt_tokens["prompt_input_ids"]
        chosen_tokens["prompt_input_ids"] = [self.processing_class.bos_token_id] + chosen_tokens["prompt_input_ids"]
        rejected_tokens["prompt_input_ids"] = [self.processing_class.bos_token_id] + rejected_tokens["prompt_input_ids"]

        prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
        chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
        rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]

        # add EOS token to end of answer
        chosen_tokens["input_ids"].append(self.processing_class.eos_token_id)
        chosen_tokens["attention_mask"].append(1)

        rejected_tokens["input_ids"].append(self.processing_class.eos_token_id)
        rejected_tokens["attention_mask"].append(1)

        longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

        # if combined sequence is too long, truncate the prompt
        for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
            if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                if self.truncation_mode == "keep_start":
                    for k in ["prompt_input_ids", "prompt_attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
                elif self.truncation_mode == "keep_end":
                    for k in ["prompt_input_ids", "prompt_attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
                else:
                    raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

        # if that's still too long, truncate the response
        for answer_tokens in [chosen_tokens, rejected_tokens]:
            if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                for k in ["input_ids", "attention_mask"]:
                    answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]

        # Create labels
        chosen_sequence_tokens = {
            k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
        }
        rejected_sequence_tokens = {
            k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
        }
        chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
        chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
            self.label_pad_token_id
        ] * len(chosen_tokens["prompt_input_ids"])
        rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
        rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
            self.label_pad_token_id
        ] * len(rejected_tokens["prompt_input_ids"])

        for k, toks in {
            "chosen_": chosen_sequence_tokens,
            "rejected_": rejected_sequence_tokens,
            "": prompt_tokens,
        }.items():
            for type_key, tokens in toks.items():
                if type_key == "token_type_ids":
                    continue
                batch[f"{k}{type_key}"] = tokens


    else:
        chosen_tokens = self.processing_class(
            chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True
        )
        rejected_tokens = self.processing_class(
            rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True
        )
        prompt_tokens = self.processing_class(
            prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
        )

        batch["chosen_labels"] = chosen_tokens["input_ids"]
        batch["rejected_labels"] = rejected_tokens["input_ids"]
        batch["prompt_input_ids"] = prompt_tokens["input_ids"]
        batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

        if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
            batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                labels=batch["rejected_labels"]
            )
            batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                labels=batch["chosen_labels"]
            )
    #print('batch=======', batch.keys())
    return batch
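An illustrative call (the trainer and feature values are hypothetical). For a decoder-only model the returned dict is flat, combining the "chosen_", "rejected_" and prompt fields built above:

row = trainer.tokenize_row({"prompt": "Hi", "chosen": "Hello!", "rejected": "Go away."})
# keys: prompt_input_ids, prompt_attention_mask,
#       chosen_input_ids, chosen_attention_mask, chosen_labels,
#       rejected_input_ids, rejected_attention_mask, rejected_labels
# where the *_labels lists mask the prompt positions with label_pad_token_id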
utils
apply_chat_template(example, tokenizer, skip_system_message)
Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/utils.py
def apply_chat_template(
    example,
    tokenizer,
    skip_system_message,
):
    if all(k in example.keys() for k in ("chosen", "rejected")):
        prompt_messages = example["chosen"][:-1]
        # Prepend a system message if the first message is not a system message
        if not skip_system_message:
            if example["chosen"][0]["role"] != "system":
                prompt_messages.insert(0, {"role": "system", "content": ""})
            # Now we extract the final turn to define chosen/rejected responses
            chosen_messages = example["chosen"][-1:]
            rejected_messages = example["rejected"][-1:]
            example["text_chosen"] = tokenizer.apply_chat_template(
                chosen_messages, tokenize=False, add_generate_prompt=True
            )
            example["text_rejected"] = tokenizer.apply_chat_template(
                rejected_messages, tokenize=False, add_generate_prompt=True
            )
            example["text_prompt"] = tokenizer.apply_chat_template(
                prompt_messages, tokenize=False, add_generate_prompt=True
            )
        else:
            prompt_messages = example["chosen"][:-1]
            chosen_messages = example["chosen"]
            rejected_messages = example["rejected"]
            example["text_prompt"] = tokenizer.apply_chat_template(
                prompt_messages, tokenize=False, add_generate_prompt=True
            )
            example["text_chosen"] = tokenizer.apply_chat_template(
                chosen_messages, tokenize=False, add_generate_prompt=True
            )[len(example["text_prompt"]) :]
            example["text_rejected"] = tokenizer.apply_chat_template(
                rejected_messages, tokenize=False, add_generate_prompt=True
            )[len(example["text_prompt"]) :]
    else:
        raise ValueError(
            f"Could not format example as dialogue for `sppo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
        )
    return example
apply_template(text, tokenizer)
Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/utils.py
def apply_template(text, tokenizer):
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": text}, {"role": "assistant", "content": "None"}],
        tokenize=False, add_generate_prompt=True
    ).split("None")[0]
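A hedged illustration of the placeholder trick above (the rendered string depends entirely on the tokenizer's chat template): the [user, assistant("None")] conversation is rendered and only the text before the literal "None" is kept, i.e. the generation prefix a model would continue from.

prefix = apply_template("What is 2+2?", tokenizer)  # tokenizer: any tokenizer with a chat template
# e.g. "<|user|>\nWhat is 2+2?\n<|assistant|>\n"    (exact form is template-dependent)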
from_ranks(data, pairs, sppo_temp_dir, iter_num)
Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/utils.py
def from_ranks(data, pairs, sppo_temp_dir, iter_num):
    scores = np.load(f"{sppo_temp_dir}/ranking/SPPO-Iter{iter_num}/ranking.npy")  # scores saved by ranking()
    scores = list(scores)

    probs = []
    rm_scores = []
    for idx, score in enumerate(scores):
        prb = np.zeros((pairs, pairs))
        for i in range(pairs):
            for j in range(pairs):
                prb[i][j] = 1 / (1 + np.exp(score[j] - score[i]))
        prb = prb.tolist()
        probs.append(prb)
        rm_scores.append(score)

    with open(f"{sppo_temp_dir}/generated/SPPO-Iter{iter_num}/probabilities.json", "w") as f:
        json.dump(probs, f)

    df = data.to_pandas()
    for i in range(pairs):
        with open(f"{sppo_temp_dir}/generated/SPPO-Iter{iter_num}/responses_{i}.json") as f:
            responses = json.load(f)
        fmt = [
            [
                {"content": data[j]["prompt"], "role": "user"},
                {"content": responses[j], "role": "assistant"},
            ]
            for j in range(len(data))
        ]
        df[f"generate_{i}"] = fmt
    if pairs < 5:  # original implementation assumes pairs == 5
        # remove extra generate_x columns if they exist
        cols_to_delete = []
        for ind in range(pairs, 5):
            if f"generate_{ind}" in df.columns:
                cols_to_delete.append(f"generate_{ind}")
        if cols_to_delete:
            df.drop(cols_to_delete, axis=1, inplace=True)
    df["probability"] = probs
    df["rm_scores"] = rm_scores
    df.to_parquet(f"{sppo_temp_dir}/generated/SPPO-Iter{iter_num}/train.parquet")
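A toy illustration of the pairwise-probability construction above: for PairRM scores s, prb[i][j] = 1 / (1 + exp(s[j] - s[i])) is a Bradley-Terry-style estimate of P(response i beats response j). The scores below are made up.

import numpy as np

score = np.array([1.0, 0.0, -1.0])   # made-up PairRM scores for pairs = 3 responses
pairs = len(score)
prb = np.zeros((pairs, pairs))
for i in range(pairs):
    for j in range(pairs):
        prb[i][j] = 1 / (1 + np.exp(score[j] - score[i]))
# prb[0][2] ≈ 0.88: the highest-scored response is strongly preferred over the lowest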
prepare_dataset_from_prompts(llm, tokenizer, data, sppo_temp_dir, iter_num=1, maxlen=2048, num_prompts=5)
Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/utils.py
def prepare_dataset_from_prompts(llm, tokenizer, data, sppo_temp_dir, iter_num=1, maxlen=2048, num_prompts=5):

    # Step 1: https://github.com/uclaml/SPPO/blob/main/scripts/generate.sh
    # Step 1(a): Generate data - https://github.com/uclaml/SPPO/blob/main/scripts/generate.py
    Path(f"{sppo_temp_dir}/generated/SPPO-Iter{iter_num}").mkdir(parents=True, exist_ok=True)
    prompts = [apply_template(data[idx]["prompt"], tokenizer) for idx in range(len(data))]

    for p in range(num_prompts):
        set_seed(p * 50)
        enc = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")
        generated_ids = llm.generate(**enc, do_sample=True, temperature=0.7, top_p=0.9, max_new_tokens=maxlen)

        generated_text = tokenizer.batch_decode(generated_ids[:, enc.input_ids.shape[1]:], skip_special_tokens=True)

        with open(f"{sppo_temp_dir}/generated/SPPO-Iter{iter_num}/responses_{p}.json", "w") as f:
            json.dump(generated_text, f)


    # Step 1(b): Rank data - https://github.com/uclaml/SPPO/blob/main/scripts/rank.py
    all_generated = []

    for i in range(num_prompts):
        file_path = f"{sppo_temp_dir}/generated/SPPO-Iter{iter_num}/responses_{i}.json"
        with open(file_path) as f:
            gen = json.load(f)
            all_generated.append(gen)

    candidates_texts = list(zip(*all_generated))
    assert len(data) == len(candidates_texts)

    os.makedirs(f"{sppo_temp_dir}/ranking/SPPO-Iter{iter_num}", exist_ok=True)

    ranking(sppo_temp_dir, iter_num, prompts, candidates_texts)

    # Step 1(c): Compute probs - https://github.com/uclaml/SPPO/blob/main/scripts/compute_prob.py
    from_ranks(data, num_prompts, sppo_temp_dir, iter_num)
    out_path = prepare_score(num_prompts, sppo_temp_dir, iter_num)
    print('probs calculated')

    # Step 2: https://github.com/uclaml/SPPO/blob/main/scripts/pipeline.sh
    train = Dataset.from_parquet(f"{out_path}/train.parquet")
    processed_train = process_dataset(train, tokenizer)
    return processed_train
prepare_score(pairs, sppo_temp_dir, iter_num)
Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/utils.py
def prepare_score(pairs, sppo_temp_dir, iter_num):
    # Load dataset and convert to DataFrame
    train = Dataset.from_parquet(f"{sppo_temp_dir}/generated/SPPO-Iter{iter_num}/train.parquet")
    train = pd.DataFrame(train)

    # Calculate metrics and probabilities
    metrics = train['rm_scores'].apply(lambda x: np.array(x[-pairs:]))
    metrics_prob = train['probability'].apply(lambda x: np.stack(x).sum(axis=1))
    maxmin = metrics.apply(lambda x: [x.argmax(), x.argmin()])

    # Reorganize the DataFrame for easy access
    cols = []
    for ind in range(pairs):
        cols.append(f"generate_{ind}")
    cols.append('probability')
    train_ordered = train[cols]

    # Determine chosen and rejected items based on maxmin indices
    chosen = [train_ordered.iloc[i, maxmin[i][0]] for i in range(len(train_ordered))]
    rejected = [train_ordered.iloc[i, maxmin[i][1]] for i in range(len(train_ordered))]

    # Calculate probabilities for chosen and rejected items
    chosen_probs = [train_ordered['probability'].iloc[i][maxmin[i][0]][maxmin[i][1]] for i in range(len(train_ordered))]
    chosen_probs_win = [metrics_prob[i][maxmin[i][0]] / len(metrics_prob.iloc[0]) for i in range(len(metrics_prob))]
    chosen_probs_lose = [metrics_prob[i][maxmin[i][1]] / len(metrics_prob.iloc[0]) for i in range(len(metrics_prob))]

    # Create a new DataFrame with the results
    train_new = pd.DataFrame({
        'chosen': chosen,
        'rejected': rejected,
        'chosen_probs': chosen_probs,
        'chosen_probs_win': chosen_probs_win,
        'chosen_probs_lose': chosen_probs_lose
    })

    # Determine output directory
    #output_dir = '-'.join(output_dir.split('-')[1:])
    OUTPATH = f'{sppo_temp_dir}/synthetic_data_SPPO-Iter{iter_num}_score'
    os.makedirs(OUTPATH, exist_ok=True)

    # Save train and test datasets to parquet files
    train_new.to_parquet(f'{OUTPATH}/train.parquet', index=False)
    print(f"Saved file to {OUTPATH}/train.parquet")

    # Temporary split to make the code run; not to be used for test/evaluation purposes
    test = train_new.sample(n=int(0.1*len(train_new)))
    test.to_parquet(f'{OUTPATH}/test.parquet', index=False)
    print(f"Saved file to {OUTPATH}/test.parquet")

    return OUTPATH
process_dataset(raw_dataset, tokenizer)
Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/utils.py
def process_dataset(raw_dataset, tokenizer):
    column_names = list(raw_dataset.features)
    column_names = [x for x in column_names if x not in ['chosen_probs', 'chosen_probs_win', 'chosen_probs_lose']]

    raw_dataset = raw_dataset.map(
        apply_chat_template,
        fn_kwargs={"tokenizer": tokenizer, "skip_system_message": True},
        remove_columns=column_names,
        desc="Formatting comparisons with prompt template",
    )


    raw_dataset = raw_dataset.rename_columns(
        {"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"}
    )

    return raw_dataset
ranking(sppo_temp_dir, iter_num, prompts, candidates)
Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/utils.py
def ranking(sppo_temp_dir, iter_num, prompts, candidates):
    blender = llm_blender.Blender()
    blender.loadranker("llm-blender/PairRM")
    ranks = blender.rank(prompts, candidates, return_scores=True, batch_size=1)
    np.save(f"{sppo_temp_dir}/ranking/SPPO-Iter{iter_num}/ranking.npy", ranks)
set_seed(seed=5775709)
Source code in aisteer360/algorithms/structural_control/wrappers/trl/sppotrainer/utils.py
def set_seed(seed=5775709):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

evaluation

benchmark

Benchmark

Benchmark framework for comparing steering pipelines on specific use cases.

Provides a standardized way to compare different steering control configurations against a baseline model on a given evaluation task. Handles the complete benchmark workflow: model loading, generation, and evaluation.

The benchmark runs each control pipeline configuration independently, allowing for fair comparison of controls on a common task.

Parameters:

Name Type Description Default
use_case UseCase

The evaluation task defining prompts, generation logic, and metrics. Must implement generate() and evaluate() methods.

required
base_model_name_or_path str | Path

HuggingFace model identifier or local path to the base model. Used for all pipeline configurations and baseline.

required
steering_pipelines dict[str, list[Any]]

Named configurations of steering pipelines. Keys are configuration names (e.g., "baseline", "with_activation_steering"). Values are pipelines, e.g., lists of controls (StructuralControl, StateControl, etc.). Empty list or None creates a baseline configuration without steering.

required
runtime_overrides dict[str, dict[str, Any]]

Runtime parameters for specific pipeline configurations. Outer keys match steering_pipelines keys, inner dicts contain runtime kwargs passed to controls during generation. Defaults to None.

None
hf_model_kwargs dict

Additional arguments passed to AutoModelForCausalLM.from_pretrained(). Defaults to {}.

None
gen_kwargs dict

Generation parameters passed to model.generate(). Defaults to {}.

None
device_map str

Device placement strategy for model loading. Defaults to "auto".

'auto'
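A minimal usage sketch; my_use_case and my_control are placeholders for whatever UseCase subclass and steering controls an experiment defines, and the model identifier is only an example.

from aisteer360.evaluation.benchmark import Benchmark

benchmark = Benchmark(
    use_case=my_use_case,                                  # any UseCase implementing generate()/evaluate()
    base_model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",  # example model id
    steering_pipelines={
        "baseline": [],                                    # empty list -> unsteered baseline
        "steered": [my_control],                           # e.g. a structural or state control instance
    },
    gen_kwargs={"max_new_tokens": 128},
)
profiles = benchmark.run()
benchmark.export(profiles, save_dir="results/")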
Source code in aisteer360/evaluation/benchmark.py
class Benchmark:
    """Benchmark framework for comparing steering pipelines on specific use cases.

    Provides a standardized way to compare different steering control configurations against a baseline model on a given
    evaluation task. Handles the complete benchmark workflow: model loading, generation, and evaluation.

    The benchmark runs each control pipeline configuration independently, allowing for fair comparison of controls on a
    common task.

    Args:
        use_case (UseCase): The evaluation task defining prompts, generation logic, and metrics.
            Must implement `generate()` and `evaluate()` methods.
        base_model_name_or_path (str | Path): HuggingFace model identifier or local path to the base model.
            Used for all pipeline configurations and baseline.
        steering_pipelines (dict[str, list[Any]]): Named configurations of steering pipelines.
            Keys are configuration names (e.g., "baseline", "with_activation_steering").
            Values are pipelines, e.g., lists of controls (StructuralControl, StateControl, etc.).
            Empty list or None creates a baseline configuration without steering.
        runtime_overrides (dict[str, dict[str, Any]], optional): Runtime parameters for specific pipeline
            configurations. Outer keys match `steering_pipelines` keys,
            inner dicts contain runtime kwargs passed to controls during generation.
            Defaults to None.
        hf_model_kwargs (dict, optional): Additional arguments passed to `AutoModelForCausalLM.from_pretrained()`.
            Defaults to {}.
        gen_kwargs (dict, optional): Generation parameters passed to model.generate().
            Defaults to {}.
        device_map (str, optional): Device placement strategy for model loading.
            Defaults to "auto".
        """
    def __init__(
            self,
            use_case: UseCase,
            base_model_name_or_path: str | Path,
            steering_pipelines: dict[str, list[Any]],
            runtime_overrides: dict[str, dict[str, Any]] | None = None,
            hf_model_kwargs: dict | None = None,
            gen_kwargs: dict | None = None,
            device_map: str = "auto"
    ) -> None:
        self.use_case = use_case
        self.base_model_name_or_path = base_model_name_or_path
        self.steering_pipelines = steering_pipelines
        self.runtime_overrides = runtime_overrides
        self.hf_model_kwargs = hf_model_kwargs or {}
        self.gen_kwargs = gen_kwargs or {}
        self.device_map = device_map

    def run(self) -> dict[str, Any]:
        """Run benchmark on all configured steering pipelines.

        Executes the benchmark by iterating through each pipeline configuration defined in `steering_pipelines`. For each
        configuration, calls `_run_pipeline()` to handle model setup, generation, and evaluation. Results from all
        pipelines are collected for comparison.

        Returns:
            Benchmark profiles for all pipeline configurations. Keys are pipeline names from `steering_pipelines`. Values are dicts containing:

                - "generations": Generated outputs from the model
                - "evaluations": Evaluation scores from the use case metrics
        """
        profiles = {}

        for steering_pipeline_name, steering_pipeline in self.steering_pipelines.items():

            print(f"Running pipeline: {steering_pipeline_name}...", flush=True)

            profile = self._run_pipeline(steering_pipeline)
            profiles[steering_pipeline_name] = profile

            print("done.")

        return profiles

    def _run_pipeline(self, steering_pipeline: list[Any]) -> dict[str, Any]:
        """Run steering pipeline."""

        model = None
        pipeline = None
        tokenizer = None

        try:

            if steering_pipeline:

                # todo: determine if lazy_init is needed; raise warnings/errors accordingly

                # build pipeline and steer
                pipeline = SteeringPipeline(
                    model_name_or_path=self.base_model_name_or_path,
                    controls=steering_pipeline,
                    device_map=self.device_map,
                    hf_model_kwargs=self.hf_model_kwargs,
                )

                # todo: check if steer_kwargs are necessary
                # steerer = steerer.steer(**steer_kwargs)
                pipeline.steer()

                tokenizer = pipeline.tokenizer
                model_or_pipeline = pipeline

            else:  # baseline

                model = AutoModelForCausalLM.from_pretrained(
                    self.base_model_name_or_path,
                    device_map=self.device_map,
                    **self.hf_model_kwargs
                )
                tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path)
                tokenizer = ensure_pad_token(tokenizer)
                model_or_pipeline = model

            # generate
            generations = self.use_case.generate(
                model_or_pipeline=model_or_pipeline,
                tokenizer=tokenizer,
                gen_kwargs=self.gen_kwargs,
                runtime_overrides=self.runtime_overrides
            )

            # evaluate
            scores = self.use_case.evaluate(generations)

            return {
                "generations": generations,
                "evaluations": scores
            }

        finally:  # cleanup

            if model is not None:
                del model

            if pipeline is not None:
                del pipeline

            if tokenizer is not None:
                del tokenizer

            gc.collect()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

    def export(self, profiles: dict[str, Any], save_dir: str):
        """Export benchmark results to disk.

        Saves the benchmark profiles to the specified directory. Creates the directory if it doesn't exist. Delegates
        the actual export logic to the use case's export method, which handles format-specific serialization.

        Args:
            profiles (dict[str, Any]): Benchmark results from `run()` method.
                Contains generations and evaluations for each pipeline configuration.
            save_dir (str): Directory path where results will be saved.
                Will be created if it doesn't exist.
        """
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)
        self.use_case.export(profiles, save_dir)
base_model_name_or_path = base_model_name_or_path instance-attribute
device_map = device_map instance-attribute
gen_kwargs = gen_kwargs or {} instance-attribute
hf_model_kwargs = hf_model_kwargs or {} instance-attribute
runtime_overrides = runtime_overrides instance-attribute
steering_pipelines = steering_pipelines instance-attribute
use_case = use_case instance-attribute
export(profiles, save_dir)

Export benchmark results to disk.

Saves the benchmark profiles to the specified directory. Creates the directory if it doesn't exist. Delegates the actual export logic to the use case's export method, which handles format-specific serialization.

Parameters:

Name Type Description Default
profiles dict[str, Any]

Benchmark results from run() method. Contains generations and evaluations for each pipeline configuration.

required
save_dir str

Directory path where results will be saved. Will be created if it doesn't exist.

required
Source code in aisteer360/evaluation/benchmark.py
def export(self, profiles: dict[str, Any], save_dir: str):
    """Export benchmark results to disk.

    Saves the benchmark profiles to the specified directory. Creates the directory if it doesn't exist. Delegates
    the actual export logic to the use case's export method, which handles format-specific serialization.

    Args:
        profiles (dict[str, Any]): Benchmark results from `run()` method.
            Contains generations and evaluations for each pipeline configuration.
        save_dir (str): Directory path where results will be saved.
            Will be created if it doesn't exist.
    """
    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)
    self.use_case.export(profiles, save_dir)
run()

Run benchmark on all configured steering pipelines.

Executes the benchmark by iterating through each pipeline configuration defined in steering_pipelines. For each configuration, calls _run_pipeline() to handle model setup, generation, and evaluation. Results from all pipelines are collected for comparison.

Returns:

Type Description
dict[str, Any]

Benchmark profiles for all pipeline configurations. Keys are pipeline names from steering_pipelines. Values are dicts containing:

  • "generations": Generated outputs from the model
  • "evaluations": Evaluation scores from the use case metrics
Source code in aisteer360/evaluation/benchmark.py
def run(self) -> dict[str, Any]:
    """Run benchmark on all configured steering pipelines.

    Executes the benchmark by iterating through each pipeline configuration defined in `steering_pipelines`. For each
    configuration, calls `_run_pipeline()` to handle model setup, generation, and evaluation. Results from all
    pipelines are collected for comparison.

    Returns:
        Benchmark profiles for all pipeline configurations. Keys are pipeline names from `steering_pipelines`. Values are dicts containing:

            - "generations": Generated outputs from the model
            - "evaluations": Evaluation scores from the use case metrics
    """
    profiles = {}

    for steering_pipeline_name, steering_pipeline in self.steering_pipelines.items():

        print(f"Running pipeline: {steering_pipeline_name}...", flush=True)

        profile = self._run_pipeline(steering_pipeline)
        profiles[steering_pipeline_name] = profile

        print("done.")

    return profiles

metrics

Base classes for evaluation metrics.

Contains two classes:

  • Metric: Base class for all evaluation metrics.
  • LLMJudgeMetric: Base class for LLM-as-a-judge metrics (subclasses Metric).
Factuality

Bases: LLMJudgeMetric

Judge factual correctness of a response to a prompt.

Source code in aisteer360/evaluation/metrics/generic/factuality.py
class Factuality(LLMJudgeMetric):
    """
    Judge factual correctness of a response to a prompt.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            prompt_template=_PROMPT,
            scale=(1, 5),
            **kwargs,
        )
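A hedged usage sketch (the judge model id is illustrative and the prompt/response pair is made up):

from aisteer360.evaluation.metrics.generic.factuality import Factuality

factuality = Factuality(model_or_id="Qwen/Qwen2.5-1.5B-Instruct", batch_size=4)
scores = factuality.compute(
    responses=["The Eiffel Tower is in Paris."],
    prompts=["Where is the Eiffel Tower?"],
)
print(scores["mean_score"])   # corpus mean
print(scores["scores"])       # one mean score per response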
base_prompt_template = prompt_template.strip() instance-attribute
batch_size = batch_size instance-attribute
device = device or ('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') instance-attribute
extras = extras instance-attribute
format_instructions = self.output_parser.get_format_instructions() instance-attribute
max_retries = max_retries instance-attribute
model = AutoModelForCausalLM.from_pretrained(model_or_id) instance-attribute
name = self.__class__.__name__ instance-attribute
num_return_sequences = int(gen_kwargs.pop('num_return_sequences', 1)) instance-attribute
pipeline = TextGenerationPipeline(model=(self.model), tokenizer=(self.tokenizer)) instance-attribute
scale = scale instance-attribute
tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_or_id) instance-attribute
use_chat = hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None instance-attribute
compute(responses, prompts=None, **kwargs)

Compute LLM judge scores for a list of responses.

Evaluates each response using the configured judge model and prompt template. Scores are averaged when multiple samples are generated per response (via num_return_sequences).

Parameters:

Name Type Description Default
responses list[str]

List of text responses to evaluate.

required
prompts list[str] | None

Optional list of prompts corresponding to each response. If provided, must be the same length as responses. These prompts can be referenced in the prompt_template using the {prompt} placeholder.

None
**kwargs Any

Additional keyword arguments (currently unused).

{}

Returns:

Type Description
dict[str, float | list[float]]

Score statistics containing:

  • "mean_score": Overall average score across all responses
  • "scores": List of mean scores for each response (averaged across samples)
  • "raw_scores": List of lists containing all individual scores for each response

Raises:

Type Description
AssertionError

If prompts is provided but has different length than responses.

Source code in aisteer360/evaluation/metrics/base_judge.py
@torch.inference_mode()
def compute(
    self,
    responses: list[str],
    prompts: list[str] | None = None,
    **kwargs: Any,
) -> dict[str, float | list[float]]:
    """Compute LLM judge scores for a list of responses.

    Evaluates each response using the configured judge model and prompt template. Scores are averaged when multiple
    samples are generated per response (via `num_return_sequences`).

    Args:
        responses (list[str]): List of text responses to evaluate.
        prompts (list[str] | None): Optional list of prompts corresponding to each response.
            If provided, must be the same length as responses. These prompts can be
            referenced in the prompt_template using the {prompt} placeholder.
        **kwargs: Additional keyword arguments (currently unused).

    Returns:
        Score statistics containing:

            - "mean_score": Overall average score across all responses
            - "scores": List of mean scores for each response (averaged across samples)
            - "raw_scores": List of lists containing all individual scores for each response

    Raises:
        AssertionError: If prompts is provided but has different length than responses.
    """

    if prompts is not None and len(prompts) != len(responses):
        raise AssertionError("`responses` and `prompts` must be the same length")

    # build prompts
    prompts_list: list[str] = []
    for i in range(len(responses)):
        fields: dict[str, str | float] = {
            "response": responses[i],
            "lower_bound": self.scale[0],
            "upper_bound": self.scale[1],
        }
        if prompts is not None:
            fields["prompt"] = prompts[i]

        prompt_core = self.base_prompt_template.format(**fields)
        prompt_formatted = self._wrap(prompt_core + "\n\n" + self.format_instructions)
        prompts_list.append(prompt_formatted)

    # generate
    prompt_scores: list[list[float]] = []
    for batch in self._batch_chunks(prompts_list, self.batch_size):
        outputs = self.pipeline(
            batch,
            num_return_sequences=self.num_return_sequences,
            return_full_text=False,
            clean_up_tokenization_spaces=True,
        )

        for prompt, generations in zip(batch, outputs):
            generations = generations if isinstance(generations, list) else [generations]
            assert len(generations) == self.num_return_sequences

            scores = []
            for generation in generations:
                reply_text = generation["generated_text"]
                try:
                    score = self.parse_fn(reply_text, self.scale)
                except Exception:
                    score = self._score_with_retries(prompt)
                scores.append(score)

            prompt_scores.append(scores)

    mean_per_prompt = [sum(prompt_score) / len(prompt_score) for prompt_score in prompt_scores]
    corpus_mean = sum(mean_per_prompt) / len(mean_per_prompt)

    return {
        "mean_score": corpus_mean,  # overall average
        "scores": mean_per_prompt,  # one number per original prompt
        "raw_scores": prompt_scores  # n_samples scores per prompt
    }
LLMJudgeMetric

Bases: Metric

Base class for LLM-as-a-judge evaluation metrics.

Leverages a language model to evaluate the quality of generated text responses according to customized (natural language) criteria. The judge model evaluates each response (optionally with respect to an associated prompt and context) and returns numerical scores within a specified range. When multiple samples are generated per prompt (via num_return_sequences), scores are averaged to improve reliability.

Subclasses should define their specific evaluation criteria by providing a prompt_template that instructs the judge model how to score responses. The template should use placeholders {response}, {lower_bound}, and {upper_bound} (and optionally {prompt} and {context}). Subclasses typically override __init__() to set their specific prompt template and scoring scale (e.g., see metrics.generic.relevance).

Parameters:

Name Type Description Default
model_or_id str | PreTrainedModel

HuggingFace model ID or loaded model instance to use as the judge. If string, the model will be loaded automatically.

required
prompt_template str

Template string for evaluation prompts. Should contain placeholders for {response}, {lower_bound}, {upper_bound}, and optionally {prompt}, {context}. The formatted prompt will be passed to the judge model.

required
tokenizer Any | None

Tokenizer for the judge model. If None, will be loaded from the model ID. Required if passing a PreTrainedModel instance.

None
device str | None

Device for model inference ('cuda', 'mps', 'cpu'). Defaults to GPU if available, otherwise CPU.

None
scale tuple[float, float]

Score range as (min, max) tuple. Scores outside this range will be clamped. Defaults to (1, 5).

(1, 5)
batch_size int

Number of prompts to process simultaneously. Defaults to 8.

8
max_retries int

Maximum retry attempts when score parsing fails. Defaults to 5.

5
gen_kwargs dict[str, Any] | None

Generation parameters passed to the model.

None
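A minimal subclass sketch (the criterion and template below are illustrative, not part of the toolkit): a custom judge only needs a prompt_template built from the documented placeholders, plus an optional scale. Because this template references {prompt}, compute() must be called with prompts=.

from aisteer360.evaluation.metrics.base_judge import LLMJudgeMetric

_PROMPT = """Rate the conciseness of the response on a scale from {lower_bound} to {upper_bound}.

Prompt: {prompt}
Response: {response}"""

class Conciseness(LLMJudgeMetric):
    """Judge how concise a response is."""

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            prompt_template=_PROMPT,
            scale=(1, 5),
            **kwargs,
        )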
Source code in aisteer360/evaluation/metrics/base_judge.py
class LLMJudgeMetric(Metric):
    """Base class for LLM-as-a-judge evaluation metrics.

    Leverages a language model to evaluate the quality of generated text responses according to customized (natural
    language) criteria. The judge model evaluates each response (optionally with respect to an associated prompt and
    context) and returns numerical scores within a specified range. When multiple samples are generated per prompt (via
    num_return_sequences), scores are averaged to improve reliability.

    Subclasses should define their specific evaluation criteria by providing a `prompt_template` that instructs the
    judge model how to score responses. The template should use placeholders {response}, {lower_bound}, and
    {upper_bound} (and optionally {prompt} and {context}). Subclasses typically override `__init__()` to set their
    specific prompt template and scoring scale (e.g., see `metrics.generic.relevance`).

    Args:
        model_or_id (str | PreTrainedModel): HuggingFace model ID or loaded model instance to use as the judge.
            If string, the model will be loaded automatically.
        prompt_template (str): Template string for evaluation prompts. Should contain placeholders for {response},
            {lower_bound}, {upper_bound}, and optionally {prompt}, {context}.
            The formatted prompt will be passed to the judge model.
        tokenizer (Any | None): Tokenizer for the judge model. If None, will be loaded from the model ID.
            Required if passing a PreTrainedModel instance.
        device (str | None): Device for model inference ('cuda', 'mps', 'cpu').
            Defaults to GPU if available, otherwise CPU.
        scale (tuple[float, float]): Score range as (min, max) tuple. Scores outside this range will be clamped.
            Defaults to (1, 5).
        batch_size (int): Number of prompts to process simultaneously. Defaults to 8.
        max_retries (int): Maximum retry attempts when score parsing fails. Defaults to 5.
        gen_kwargs (dict[str, Any] | None): Generation parameters passed to the model.
    """

    def __init__(
        self,
        model_or_id: str | PreTrainedModel,
        prompt_template: str,
        tokenizer: Any | None = None,
        device: str | None = None,
        scale: tuple[float, float] = (1, 5),
        batch_size: int = 8,
        max_retries: int = 5,
        gen_kwargs: dict[str, Any] | None = None,
    ):
        super().__init__()

        if isinstance(model_or_id, str):
            self.model = AutoModelForCausalLM.from_pretrained(model_or_id)
            self.tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_or_id)
        else:  # model
            self.model = model_or_id
            self.tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_or_id.config._name_or_path)

        self.use_chat = hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template is not None
        self.device = device or (
            "cuda" if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available()
            else "cpu"
        )
        self.model.to(self.device).eval()

        gen_kwargs = dict(gen_kwargs or {})
        gen_kwargs.setdefault("temperature", 0.0)
        gen_kwargs.setdefault("max_new_tokens", 30)
        gen_kwargs.setdefault("pad_token_id", self.tokenizer.eos_token_id)

        self.num_return_sequences: int = int(gen_kwargs.pop("num_return_sequences", 1))
        self.model.generation_config = GenerationConfig(**gen_kwargs)

        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        self.pipeline = TextGenerationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
        )

        self.scale = scale
        self.output_parser, self.parse_fn = build_structured_parser(scale)
        self.base_prompt_template = prompt_template.strip()
        self.format_instructions = self.output_parser.get_format_instructions()
        self.batch_size = batch_size
        self.max_retries = max_retries

    def _wrap(self, prompt: str) -> str:
        """Wrap prompt with appropriate formatting for the model.

        Applies the chat template (if the model supports it) with the prompt as a user message.
        Otherwise, returns the prompt unchanged.

        Args:
            prompt (str): The user prompt.

        Returns:
            str: The formatted prompt.
        """
        if self.use_chat:
            messages = [{"role": "user", "content": prompt}]
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        return prompt

    @staticmethod
    def _batch_chunks(seq: Sequence[Any], chunk_size: int) -> Iterable[Sequence[Any]]:
        """Split a sequence into chunks of specified size.

        Args:
            seq (Sequence[Any]): The sequence to split into chunks.
            chunk_size (int): Maximum size of each chunk.

        Yields:
            Sequence[Any]: Chunks of the input sequence, each with at most chunk_size elements.
        """
        for i in range(0, len(seq), chunk_size):
            yield seq[i: i + chunk_size]

    def _score_with_retries(self, prompt: str) -> float:
        """Generate replies until parsing succeeds or maximum retries reached.

        Attempts to generate a response and parse it (using `parse_fn`) as a score.
        If parsing fails, retries up to `max_retries` times.
        If all attempts fail, raises a warning and returns `float('nan')`.

        Args:
            prompt (str): The formatted prompt to send to the model.

        Returns:
            float: The parsed score from the model's response, or `float('nan')` if parsing fails.
        """
        for attempt in range(self.max_retries + 1):
            reply_text = self.pipeline(
                prompt,
                clean_up_tokenization_spaces=True,
                return_full_text=False
            )[0]["generated_text"]

            try:
                return self.parse_fn(reply_text, self.scale)
            except Exception:
                if attempt == self.max_retries:
                    warnings.warn(
                        f"Failed to parse score after {self.max_retries + 1} attempts. "
                        "Returning float('nan') instead."
                    )
                    return float('nan')

    @torch.inference_mode()
    def compute(
        self,
        responses: list[str],
        prompts: list[str] | None = None,
        **kwargs: Any,
    ) -> dict[str, float | list[float]]:
        """Compute LLM judge scores for a list of responses.

        Evaluates each response using the configured judge model and prompt template. Scores are averaged when multiple
        samples are generated per response (via `num_return_sequences`).

        Args:
            responses (list[str]): List of text responses to evaluate.
            prompts (list[str] | None): Optional list of prompts corresponding to each response.
                If provided, must be the same length as responses. These prompts can be
                referenced in the prompt_template using the {prompt} placeholder.
            **kwargs: Additional keyword arguments (currently unused).

        Returns:
            Score statistics containing:

                - "mean_score": Overall average score across all responses
                - "scores": List of mean scores for each response (averaged across samples)
                - "raw_scores": List of lists containing all individual scores for each response

        Raises:
            AssertionError: If prompts is provided but has different length than responses.
        """

        if prompts is not None and len(prompts) != len(responses):
            raise AssertionError("`responses` and `prompts` must be the same length")

        # build prompts
        prompts_list: list[str] = []
        for i in range(len(responses)):
            fields: dict[str, str | float] = {
                "response": responses[i],
                "lower_bound": self.scale[0],
                "upper_bound": self.scale[1],
            }
            if prompts is not None:
                fields["prompt"] = prompts[i]

            prompt_core = self.base_prompt_template.format(**fields)
            prompt_formatted = self._wrap(prompt_core + "\n\n" + self.format_instructions)
            prompts_list.append(prompt_formatted)

        # generate
        prompt_scores: list[list[float]] = []
        for batch in self._batch_chunks(prompts_list, self.batch_size):
            outputs = self.pipeline(
                batch,
                num_return_sequences=self.num_return_sequences,
                return_full_text=False,
                clean_up_tokenization_spaces=True,
            )

            for prompt, generations in zip(batch, outputs):
                generations = generations if isinstance(generations, list) else [generations]
                assert len(generations) == self.num_return_sequences

                scores = []
                for generation in generations:
                    reply_text = generation["generated_text"]
                    try:
                        score = self.parse_fn(reply_text, self.scale)
                    except Exception:
                        score = self._score_with_retries(prompt)
                    scores.append(score)

                prompt_scores.append(scores)

        mean_per_prompt = [sum(prompt_score) / len(prompt_score) for prompt_score in prompt_scores]
        corpus_mean = sum(mean_per_prompt) / len(mean_per_prompt)

        return {
            "mean_score": corpus_mean,  # overall average
            "scores": mean_per_prompt,  # one number per original prompt
            "raw_scores": prompt_scores  # n_samples scores per prompt
        }
base_prompt_template = prompt_template.strip() instance-attribute
batch_size = batch_size instance-attribute
device = device or ('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') instance-attribute
extras = extras instance-attribute
format_instructions = self.output_parser.get_format_instructions() instance-attribute
max_retries = max_retries instance-attribute
model = AutoModelForCausalLM.from_pretrained(model_or_id) instance-attribute
name = self.__class__.__name__ instance-attribute
num_return_sequences = int(gen_kwargs.pop('num_return_sequences', 1)) instance-attribute
pipeline = TextGenerationPipeline(model=(self.model), tokenizer=(self.tokenizer)) instance-attribute
scale = scale instance-attribute
tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_or_id) instance-attribute
use_chat = hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None instance-attribute
compute(responses, prompts=None, **kwargs)

Compute LLM judge scores for a list of responses.

Evaluates each response using the configured judge model and prompt template. Scores are averaged when multiple samples are generated per response (via num_return_sequences).

Parameters:

Name Type Description Default
responses list[str]

List of text responses to evaluate.

required
prompts list[str] | None

Optional list of prompts corresponding to each response. If provided, must be the same length as responses. These prompts can be referenced in the prompt_template using the {prompt} placeholder.

None
**kwargs Any

Additional keyword arguments (currently unused).

{}

Returns:

Type Description
dict[str, float | list[float]]

Score statistics containing:

  • "mean_score": Overall average score across all responses
  • "scores": List of mean scores for each response (averaged across samples)
  • "raw_scores": List of lists containing all individual scores for each response

Raises:

Type Description
AssertionError

If prompts is provided but has different length than responses.
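As a usage sketch (assuming a judge metric instance named judge, e.g. the hypothetical Conciseness metric constructed above), compute() takes parallel lists of responses and prompts and returns the statistics listed in the table:

responses = [
    "Paris is the capital of France.",
    "The capital of France is, in all likelihood, a large city somewhere in Europe.",
]
prompts = [
    "What is the capital of France?",
    "What is the capital of France?",
]

results = judge.compute(responses=responses, prompts=prompts)
print(results["mean_score"])   # corpus-level average
print(results["scores"])       # one mean score per response
print(results["raw_scores"])   # all sampled scores per response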

Source code in aisteer360/evaluation/metrics/base_judge.py
@torch.inference_mode()
def compute(
    self,
    responses: list[str],
    prompts: list[str] | None = None,
    **kwargs: Any,
) -> dict[str, float | list[float]]:
    """Compute LLM judge scores for a list of responses.

    Evaluates each response using the configured judge model and prompt template. Scores are averaged when multiple
    samples are generated per response (via `num_return_sequences`).

    Args:
        responses (list[str]): List of text responses to evaluate.
        prompts (list[str] | None): Optional list of prompts corresponding to each response.
            If provided, must be the same length as responses. These prompts can be
            referenced in the prompt_template using the {prompt} placeholder.
        **kwargs: Additional keyword arguments (currently unused).

    Returns:
        Score statistics containing:

            - "mean_score": Overall average score across all responses
            - "scores": List of mean scores for each response (averaged across samples)
            - "raw_scores": List of lists containing all individual scores for each response

    Raises:
        AssertionError: If prompts is provided but has different length than responses.
    """

    if prompts is not None and len(prompts) != len(responses):
        raise AssertionError("`responses` and `prompts` must be the same length")

    # build prompts
    prompts_list: list[str] = []
    for i in range(len(responses)):
        fields: dict[str, str | float] = {
            "response": responses[i],
            "lower_bound": self.scale[0],
            "upper_bound": self.scale[1],
        }
        if prompts is not None:
            fields["prompt"] = prompts[i]

        prompt_core = self.base_prompt_template.format(**fields)
        prompt_formatted = self._wrap(prompt_core + "\n\n" + self.format_instructions)
        prompts_list.append(prompt_formatted)

    # generate
    prompt_scores: list[list[float]] = []
    for batch in self._batch_chunks(prompts_list, self.batch_size):
        outputs = self.pipeline(
            batch,
            num_return_sequences=self.num_return_sequences,
            return_full_text=False,
            clean_up_tokenization_spaces=True,
        )

        for prompt, generations in zip(batch, outputs):
            generations = generations if isinstance(generations, list) else [generations]
            assert len(generations) == self.num_return_sequences

            scores = []
            for generation in generations:
                reply_text = generation["generated_text"]
                try:
                    score = self.parse_fn(reply_text, self.scale)
                except Exception:
                    score = self._score_with_retries(prompt)
                scores.append(score)

            prompt_scores.append(scores)

    mean_per_prompt = [sum(prompt_score) / len(prompt_score) for prompt_score in prompt_scores]
    corpus_mean = sum(mean_per_prompt) / len(mean_per_prompt)

    return {
        "mean_score": corpus_mean,  # overall average
        "scores": mean_per_prompt,  # one number per original prompt
        "raw_scores": prompt_scores  # n_samples scores per prompt
    }
Metric

Bases: ABC

Base class for evaluation metrics.

Provides a standardized interface for computing evaluation scores on model-generated responses. Subclasses should define their specific scoring logic in compute() and can accept additional configuration through constructor arguments stored in extras.
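A minimal sketch of a custom metric built on this interface; the word-count scoring below is purely illustrative.

from typing import Any

from aisteer360.evaluation.metrics.base import Metric


class MeanResponseLength(Metric):
    """Hypothetical metric: average whitespace-token length of responses."""

    def compute(
        self,
        responses: list[Any],
        prompts: list[str] | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        lengths = [len(str(response).split()) for response in responses]
        return {
            "mean_length": sum(lengths) / len(lengths) if lengths else 0.0,
            "lengths": lengths,
        }

# MeanResponseLength()(["a short reply", "a somewhat longer reply here"])  # __call__ delegates to compute()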

Source code in aisteer360/evaluation/metrics/base.py
class Metric(ABC):
    """
    Base class for evaluation metrics.

    Provides a standardized interface for computing evaluation scores on model-generated responses. Subclasses should
    define their specific scoring logic in `compute()` and can accept additional configuration through constructor
    arguments stored in `extras`.

    Args:
        **extras
            Required extras for the metric (e.g., LLM, tokenizer, etc.)
    """
    def __init__(self, **extras: Any) -> None:
        self.name: str = self.__class__.__name__
        self.extras: dict[str, Any] = extras

    @abstractmethod
    def compute(
        self,
        responses: list[Any],
        prompts: list[str] | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Base compute method."""
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        return self.compute(*args, **kwargs)
extras = extras instance-attribute
name = self.__class__.__name__ instance-attribute
compute(responses, prompts=None, **kwargs) abstractmethod

Base compute method.

Source code in aisteer360/evaluation/metrics/base.py
@abstractmethod
def compute(
    self,
    responses: list[Any],
    prompts: list[str] | None = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """Base compute method."""
    raise NotImplementedError
Perplexity

Bases: Metric

Compute token-level perplexity for a batch of sentences.

Perplexity is the exponentiated mean cross-entropy between the language model’s predicted distribution and the true next token. Lower is better.

Parameters:

Name Type Description Default
model_or_id str | Module

Hugging Face model ID or an already-instantiated causal language model.

required
tokenizer PreTrainedTokenizer | None

Tokenizer to use. Leave None when passing a model ID to automatically load the matching tokenizer. Defaults to None.

None
batch_size int

Number of sentences per forward pass. Higher is faster until GPU memory becomes the bottleneck. Defaults to 16.

16
add_bos bool

Whether to prepend the tokenizer’s BOS token so the first word in each sentence is also scored. Ignored if the tokenizer has no BOS token. Defaults to True.

True
max_length int | None

If set, truncate inputs to this length so they fit the model’s context window. None disables truncation. Defaults to None.

None
device str | None

"cuda" or "cpu". When None, automatically selects GPU if available. Defaults to None.

None

Attributes:

Name Type Description
add_bos bool

Whether a BOS token is prepended before scoring.

batch_size int

Number of sentences processed per forward pass.

device str

The device actually selected for computation ("cuda" or "cpu").

max_length int | None

Truncation length for inputs, or None for no truncation.

model PreTrainedModel

The loaded causal language model used to score tokens.

tokenizer PreTrainedTokenizer

Tokenizer used for encoding, padding, and BOS handling.
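A usage sketch; the model ID below is a placeholder for whichever causal language model should score the text.

from aisteer360.evaluation.metrics.generic.perplexity import Perplexity

perplexity = Perplexity("gpt2", batch_size=8, max_length=1024)  # placeholder model ID

results = perplexity.compute(responses=[
    "The quick brown fox jumps over the lazy dog.",
    "Colorless green ideas sleep furiously.",
])
print(results["mean_perplexity"])
print(results["perplexities"])  # per-sentence values; lower is better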

Source code in aisteer360/evaluation/metrics/generic/perplexity.py
class Perplexity(Metric):
    """Compute token-level perplexity for a batch of sentences.

    Perplexity is the exponentiated mean cross-entropy between the language model’s predicted distribution and the true
    next token. Lower is better.

    Args:
        model_or_id (str | torch.nn.Module): Hugging Face model ID or an already-instantiated causal language model.
        tokenizer (transformers.PreTrainedTokenizer | None, optional):
            Tokenizer to use.  Leave ``None`` when passing a model ID to automatically load the matching tokenizer.
            Defaults to ``None``.
        batch_size (int, optional): Number of sentences per forward pass. Higher is faster until GPU memory becomes the
            bottleneck. Defaults to ``16``.
        add_bos (bool, optional): Whether to prepend the tokenizer’s BOS token so the first word in each sentence is
            also scored. Ignored if the tokenizer has no BOS token. Defaults to ``True``.
        max_length (int | None, optional): If set, truncate inputs to this length so they fit the model’s context
            window. ``None`` disables truncation. Defaults to ``None``.
        device (str | None, optional): ``"cuda"`` or ``"cpu"``. When ``None``, automatically selects GPU if available.
            Defaults to ``None``.

    Attributes:
        add_bos (bool): Whether a BOS token is prepended before scoring.
        batch_size (int): Number of sentences processed per forward pass.
        device (str): The device actually selected for computation (``"cuda"`` or ``"cpu"``).
        max_length (int | None): Truncation length for inputs, or ``None`` for no truncation.
        model (transformers.PreTrainedModel): The loaded causal language model used to score tokens.
        tokenizer (transformers.PreTrainedTokenizer): Tokenizer used for encoding, padding, and BOS handling.
    """

    def __init__(
        self,
        model_or_id: str | torch.nn.Module,
        tokenizer: Any | None = None,
        batch_size: int = 16,
        add_bos: bool = True,
        max_length: int | None = None,
        device: str | None = None,
    ):
        super().__init__()

        if isinstance(model_or_id, str):
            self.model = AutoModelForCausalLM.from_pretrained(model_or_id)
            self.tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_or_id)
        else:  # model object
            self.model = model_or_id
            self.tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_or_id.config._name_or_path)

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device).eval()
        self.batch_size = batch_size
        self.add_bos = add_bos and (self.tokenizer.bos_token_id is not None)
        self.max_length = max_length

        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                # no EOS token to reuse; register a dedicated padding token instead
                self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    @torch.no_grad()
    def compute(
        self,
        responses: list[str],
        prompts: list[str] | None = None,
    ) -> dict[str, float]:
        """Compute perplexity for each response (and the mean across the batch).

        Args:
            responses (list[str]): Text sequences to score.
            prompts (list[str] | None, optional): Unused here; present for a uniform metric API.

        Returns:
            dict[str, float]: A dict with keys:

                - ``"mean_perplexity"``: mean perplexity over all inputs.
                - ``"perplexities"``: list of per-sample perplexities in input order.
        """
        perplexities: list[float] = []
        local_batch_size = self.batch_size

        for i in range(0, len(responses), local_batch_size):
            batch = responses[i : i + local_batch_size]

            encoding = self.tokenizer(
                batch,
                padding=True,
                truncation=self.max_length is not None,
                max_length=self.max_length,
                add_special_tokens=False,
                return_tensors="pt",
            ).to(self.device)
            input_ids = encoding["input_ids"]

            if self.add_bos:
                bos_tokens = torch.full(
                    (input_ids.size(0), 1),
                    self.tokenizer.bos_token_id,
                    device=self.device,
                )
                input_ids = torch.cat([bos_tokens, input_ids], dim=1)

            logits = self.model(input_ids).logits[:, :-1]
            labels = input_ids[:, 1:]

            loss_per_token = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                reduction="none",
            ).view(labels.size())

            mask = labels.ne(self.tokenizer.pad_token_id)
            seq_loss = (loss_per_token * mask).sum(1) / mask.sum(1)

            perplexities.extend(torch.exp(seq_loss).cpu().tolist())

        return {
            "mean_perplexity": sum(perplexities) / len(perplexities),
            "perplexities": perplexities,
        }
add_bos = add_bos and self.tokenizer.bos_token_id is not None instance-attribute
batch_size = batch_size instance-attribute
device = device or ('cuda' if torch.cuda.is_available() else 'cpu') instance-attribute
extras = extras instance-attribute
max_length = max_length instance-attribute
model = AutoModelForCausalLM.from_pretrained(model_or_id) instance-attribute
name = self.__class__.__name__ instance-attribute
tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_or_id) instance-attribute
compute(responses, prompts=None)

Compute perplexity for each response (and the mean across the batch).

Parameters:

Name Type Description Default
responses list[str]

Text sequences to score.

required
prompts list[str] | None

Unused here; present for a uniform metric API.

None

Returns:

Type Description
dict[str, float]

dict[str, float]: A dict with keys:

  • "mean_perplexity": mean perplexity over all inputs.
  • "perplexities": list of per-sample perplexities in input order.
Source code in aisteer360/evaluation/metrics/generic/perplexity.py
@torch.no_grad()
def compute(
    self,
    responses: list[str],
    prompts: list[str] | None = None,
) -> dict[str, float]:
    """Compute perplexity for each response (and the mean across the batch).

    Args:
        responses (list[str]): Text sequences to score.
        prompts (list[str] | None, optional): Unused here; present for a uniform metric API.

    Returns:
        dict[str, float]: A dict with keys:

            - ``"mean_perplexity"``: mean perplexity over all inputs.
            - ``"perplexities"``: list of per-sample perplexities in input order.
    """
    perplexities: list[float] = []
    local_batch_size = self.batch_size

    for i in range(0, len(responses), local_batch_size):
        batch = responses[i : i + local_batch_size]

        encoding = self.tokenizer(
            batch,
            padding=True,
            truncation=self.max_length is not None,
            max_length=self.max_length,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(self.device)
        input_ids = encoding["input_ids"]

        if self.add_bos:
            bos_tokens = torch.full(
                (input_ids.size(0), 1),
                self.tokenizer.bos_token_id,
                device=self.device,
            )
            input_ids = torch.cat([bos_tokens, input_ids], dim=1)

        logits = self.model(input_ids).logits[:, :-1]
        labels = input_ids[:, 1:]

        loss_per_token = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            reduction="none",
        ).view(labels.size())

        mask = labels.ne(self.tokenizer.pad_token_id)
        seq_loss = (loss_per_token * mask).sum(1) / mask.sum(1)

        perplexities.extend(torch.exp(seq_loss).cpu().tolist())

    return {
        "mean_perplexity": sum(perplexities) / len(perplexities),
        "perplexities": perplexities,
    }
Relevance

Bases: LLMJudgeMetric

Judge relevance of a response to a prompt.
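A usage sketch; the judge model ID is a placeholder, and a loaded PreTrainedModel plus tokenizer may be passed instead of a string.

from aisteer360.evaluation.metrics.generic.relevance import Relevance

relevance = Relevance(
    "Qwen/Qwen2.5-1.5B-Instruct",       # placeholder judge model ID
    gen_kwargs={"max_new_tokens": 30},
)
scores = relevance.compute(
    responses=["Quantum computers use qubits, which can be in superpositions of 0 and 1."],
    prompts=["Explain what a quantum computer is."],
)
print(scores["mean_score"])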

Source code in aisteer360/evaluation/metrics/generic/relevance.py
class Relevance(LLMJudgeMetric):
    """
    Judge relevance of a response to a prompt.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            prompt_template=_PROMPT,
            scale=(1, 5),
            **kwargs,
        )
base_prompt_template = prompt_template.strip() instance-attribute
batch_size = batch_size instance-attribute
device = device or ('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') instance-attribute
extras = extras instance-attribute
format_instructions = self.output_parser.get_format_instructions() instance-attribute
max_retries = max_retries instance-attribute
model = AutoModelForCausalLM.from_pretrained(model_or_id) instance-attribute
name = self.__class__.__name__ instance-attribute
num_return_sequences = int(gen_kwargs.pop('num_return_sequences', 1)) instance-attribute
pipeline = TextGenerationPipeline(model=(self.model), tokenizer=(self.tokenizer)) instance-attribute
scale = scale instance-attribute
tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_or_id) instance-attribute
use_chat = hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None instance-attribute
compute(responses, prompts=None, **kwargs)

Compute LLM judge scores for a list of responses.

Evaluates each response using the configured judge model and prompt template. Scores are averaged when multiple samples are generated per response (via num_return_sequences).

Parameters:

Name Type Description Default
responses list[str]

List of text responses to evaluate.

required
prompts list[str] | None

Optional list of prompts corresponding to each response. If provided, must be the same length as responses. These prompts can be referenced in the prompt_template using the {prompt} placeholder.

None
**kwargs Any

Additional keyword arguments (currently unused).

{}

Returns:

Type Description
dict[str, float | list[float]]

Score statistics containing:

  • "mean_score": Overall average score across all responses
  • "scores": List of mean scores for each response (averaged across samples)
  • "raw_scores": List of lists containing all individual scores for each response

Raises:

Type Description
AssertionError

If prompts is provided but has different length than responses.

Source code in aisteer360/evaluation/metrics/base_judge.py
@torch.inference_mode()
def compute(
    self,
    responses: list[str],
    prompts: list[str] | None = None,
    **kwargs: Any,
) -> dict[str, float | list[float]]:
    """Compute LLM judge scores for a list of responses.

    Evaluates each response using the configured judge model and prompt template. Scores are averaged when multiple
    samples are generated per response (via `num_return_sequences`).

    Args:
        responses (list[str]): List of text responses to evaluate.
        prompts (list[str] | None): Optional list of prompts corresponding to each response.
            If provided, must be the same length as responses. These prompts can be
            referenced in the prompt_template using the {prompt} placeholder.
        **kwargs: Additional keyword arguments (currently unused).

    Returns:
        Score statistics containing:

            - "mean_score": Overall average score across all responses
            - "scores": List of mean scores for each response (averaged across samples)
            - "raw_scores": List of lists containing all individual scores for each response

    Raises:
        AssertionError: If prompts is provided but has different length than responses.
    """

    if prompts is not None and len(prompts) != len(responses):
        raise AssertionError("`responses` and `prompts` must be the same length")

    # build prompts
    prompts_list: list[str] = []
    for i in range(len(responses)):
        fields: dict[str, str | float] = {
            "response": responses[i],
            "lower_bound": self.scale[0],
            "upper_bound": self.scale[1],
        }
        if prompts is not None:
            fields["prompt"] = prompts[i]

        prompt_core = self.base_prompt_template.format(**fields)
        prompt_formatted = self._wrap(prompt_core + "\n\n" + self.format_instructions)
        prompts_list.append(prompt_formatted)

    # generate
    prompt_scores: list[list[float]] = []
    for batch in self._batch_chunks(prompts_list, self.batch_size):
        outputs = self.pipeline(
            batch,
            num_return_sequences=self.num_return_sequences,
            return_full_text=False,
            clean_up_tokenization_spaces=True,
        )

        for prompt, generations in zip(batch, outputs):
            generations = generations if isinstance(generations, list) else [generations]
            assert len(generations) == self.num_return_sequences

            scores = []
            for generation in generations:
                reply_text = generation["generated_text"]
                try:
                    score = self.parse_fn(reply_text, self.scale)
                except Exception:
                    score = self._score_with_retries(prompt)
                scores.append(score)

            prompt_scores.append(scores)

    mean_per_prompt = [sum(prompt_score) / len(prompt_score) for prompt_score in prompt_scores]
    corpus_mean = sum(mean_per_prompt) / len(mean_per_prompt)

    return {
        "mean_score": corpus_mean,  # overall average
        "scores": mean_per_prompt,  # one number per original prompt
        "raw_scores": prompt_scores  # n_samples scores per prompt
    }
base
Metric

Bases: ABC

Base class for evaluation metrics.

Provides a standardized interface for computing evaluation scores on model-generated responses. Subclasses should define their specific scoring logic in compute() and can accept additional configuration through constructor arguments stored in extras.

Source code in aisteer360/evaluation/metrics/base.py
class Metric(ABC):
    """
    Base class for evaluation metrics.

    Provides a standardized interface for computing evaluation scores on model-generated responses. Subclasses should
    define their specific scoring logic in `compute()` and can accept additional configuration through constructor
    arguments stored in `extras`.

    Args:
        **extras
            Required extras for the metric (e.g., LLM, tokenizer, etc.)
    """
    def __init__(self, **extras: Any) -> None:
        self.name: str = self.__class__.__name__
        self.extras: dict[str, Any] = extras

    @abstractmethod
    def compute(
        self,
        responses: list[Any],
        prompts: list[str] | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Base compute method."""
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        return self.compute(*args, **kwargs)
extras = extras instance-attribute
name = self.__class__.__name__ instance-attribute
compute(responses, prompts=None, **kwargs) abstractmethod

Base compute method.

Source code in aisteer360/evaluation/metrics/base.py
@abstractmethod
def compute(
    self,
    responses: list[Any],
    prompts: list[str] | None = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """Base compute method."""
    raise NotImplementedError
base_judge
LLMJudgeMetric

Bases: Metric

Base class for LLM-as-a-judge evaluation metrics.

Leverages a language model to evaluate the quality of generated text responses according to customized (natural language) criteria. The judge model evaluates each response (optionally with respect to an associated prompt and context) and returns numerical scores within a specified range. When multiple samples are generated per prompt (via num_return_sequences), scores are averaged to improve reliability.

Subclasses should define their specific evaluation criteria by providing a prompt_template that instructs the judge model how to score responses. The template should use placeholders {response}, {lower_bound}, and {upper_bound} (and optionally {prompt} and {context}). Subclasses typically override __init__() to set their specific prompt template and scoring scale (e.g., see metrics.generic.relevance).

Parameters:

Name Type Description Default
model_or_id str | PreTrainedModel

HuggingFace model ID or loaded model instance to use as the judge. If string, the model will be loaded automatically.

required
prompt_template str

Template string for evaluation prompts. Should contain placeholders for {response}, {lower_bound}, {upper_bound}, and optionally {prompt}, {context}. The formatted prompt will be passed to the judge model.

required
tokenizer Any | None

Tokenizer for the judge model. If None, will be loaded from the model ID. Required if passing a PreTrainedModel instance.

None
device str | None

Device for model inference ('cuda', 'mps', 'cpu'). Defaults to GPU if available, otherwise CPU.

None
scale tuple[float, float]

Score range as (min, max) tuple. Scores outside this range will be clamped. Defaults to (1, 5).

(1, 5)
batch_size int

Number of prompts to process simultaneously. Defaults to 8.

8
max_retries int

Maximum retry attempts when score parsing fails. Defaults to 5.

5
gen_kwargs dict[str, Any] | None

Generation parameters passed to the model.

None
Source code in aisteer360/evaluation/metrics/base_judge.py
class LLMJudgeMetric(Metric):
    """Base class for LLM-as-a-judge evaluation metrics.

    Leverages a language model to evaluate the quality of generated text responses according to customized (natural
    language) criteria. The judge model evaluates each response (optionally with respect to an associated prompt and
    context) and returns numerical scores within a specified range. When multiple samples are generated per prompt (via
    num_return_sequences), scores are averaged to improve reliability.

    Subclasses should define their specific evaluation criteria by providing a `prompt_template` that instructs the
    judge model how to score responses. The template should use placeholders {response}, {lower_bound}, and
    {upper_bound} (and optionally {prompt} and {context}). Subclasses typically override `__init__()` to set their
    specific prompt template and scoring scale (e.g., see `metrics.generic.relevance`).

    Args:
        model_or_id (str | PreTrainedModel): HuggingFace model ID or loaded model instance to use as the judge.
            If string, the model will be loaded automatically.
        prompt_template (str): Template string for evaluation prompts. Should contain placeholders for {response},
            {lower_bound}, {upper_bound}, and optionally {prompt}, {context}.
            The formatted prompt will be passed to the judge model.
        tokenizer (Any | None): Tokenizer for the judge model. If None, will be loaded from the model ID.
            Required if passing a PreTrainedModel instance.
        device (str | None): Device for model inference ('cuda', 'mps', 'cpu').
            Defaults to GPU if available, otherwise CPU.
        scale (tuple[float, float]): Score range as (min, max) tuple. Scores outside this range will be clamped.
            Defaults to (1, 5).
        batch_size (int): Number of prompts to process simultaneously. Defaults to 8.
        max_retries (int): Maximum retry attempts when score parsing fails. Defaults to 5.
        gen_kwargs (dict[str, Any] | None): Generation parameters passed to the model.
    """

    def __init__(
        self,
        model_or_id: str | PreTrainedModel,
        prompt_template: str,
        tokenizer: Any | None = None,
        device: str | None = None,
        scale: tuple[float, float] = (1, 5),
        batch_size: int = 8,
        max_retries: int = 5,
        gen_kwargs: dict[str, Any] | None = None,
    ):
        super().__init__()

        if isinstance(model_or_id, str):
            self.model = AutoModelForCausalLM.from_pretrained(model_or_id)
            self.tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_or_id)
        else:  # model
            self.model = model_or_id
            self.tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_or_id.config._name_or_path)

        self.use_chat = hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template is not None
        self.device = device or (
            "cuda" if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available()
            else "cpu"
        )
        self.model.to(self.device).eval()

        gen_kwargs = dict(gen_kwargs or {})
        gen_kwargs.setdefault("temperature", 0.0)
        gen_kwargs.setdefault("max_new_tokens", 30)
        gen_kwargs.setdefault("pad_token_id", self.tokenizer.eos_token_id)

        self.num_return_sequences: int = int(gen_kwargs.pop("num_return_sequences", 1))
        self.model.generation_config = GenerationConfig(**gen_kwargs)

        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        self.pipeline = TextGenerationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
        )

        self.scale = scale
        self.output_parser, self.parse_fn = build_structured_parser(scale)
        self.base_prompt_template = prompt_template.strip()
        self.format_instructions = self.output_parser.get_format_instructions()
        self.batch_size = batch_size
        self.max_retries = max_retries

    def _wrap(self, prompt: str) -> str:
        """Wrap prompt with appropriate formatting for the model.

        Applies the chat template (if the model supports it) with the prompt as a user message.
        Otherwise, returns the prompt unchanged.

        Args:
            prompt (str): The user prompt.

        Returns:
            str: The formatted prompt.
        """
        if self.use_chat:
            messages = [{"role": "user", "content": prompt}]
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        return prompt

    @staticmethod
    def _batch_chunks(seq: Sequence[Any], chunk_size: int) -> Iterable[Sequence[Any]]:
        """Split a sequence into chunks of specified size.

        Args:
            seq (Sequence[Any]): The sequence to split into chunks.
            chunk_size (int): Maximum size of each chunk.

        Yields:
            Sequence[Any]: Chunks of the input sequence, each with at most chunk_size elements.
        """
        for i in range(0, len(seq), chunk_size):
            yield seq[i: i + chunk_size]

    def _score_with_retries(self, prompt: str) -> float:
        """Generate replies until parsing succeeds or maximum retries reached.

        Attempts to generate a response and parse it (using `parse_fn`) as a score.
        If parsing fails, retries up to `max_retries` times.
        If all attempts fail, raises a warning and returns `float('nan')`.

        Args:
            prompt (str): The formatted prompt to send to the model.

        Returns:
            float: The parsed score from the model's response, or `float('nan')` if parsing fails.
        """
        for attempt in range(self.max_retries + 1):
            reply_text = self.pipeline(
                prompt,
                clean_up_tokenization_spaces=True,
                return_full_text=False
            )[0]["generated_text"]

            try:
                return self.parse_fn(reply_text, self.scale)
            except Exception:
                if attempt == self.max_retries:
                    warnings.warn(
                        f"Failed to parse score after {self.max_retries + 1} attempts. "
                        "Returning float('nan') instead."
                    )
                    return float('nan')

    @torch.inference_mode()
    def compute(
        self,
        responses: list[str],
        prompts: list[str] | None = None,
        **kwargs: Any,
    ) -> dict[str, float | list[float]]:
        """Compute LLM judge scores for a list of responses.

        Evaluates each response using the configured judge model and prompt template. Scores are averaged when multiple
        samples are generated per response (via `num_return_sequences`).

        Args:
            responses (list[str]): List of text responses to evaluate.
            prompts (list[str] | None): Optional list of prompts corresponding to each response.
                If provided, must be the same length as responses. These prompts can be
                referenced in the prompt_template using the {prompt} placeholder.
            **kwargs: Additional keyword arguments (currently unused).

        Returns:
            Score statistics containing:

                - "mean_score": Overall average score across all responses
                - "scores": List of mean scores for each response (averaged across samples)
                - "raw_scores": List of lists containing all individual scores for each response

        Raises:
            AssertionError: If prompts is provided but has different length than responses.
        """

        if prompts is not None and len(prompts) != len(responses):
            raise AssertionError("`responses` and `prompts` must be the same length")

        # build prompts
        prompts_list: list[str] = []
        for i in range(len(responses)):
            fields: dict[str, str | float] = {
                "response": responses[i],
                "lower_bound": self.scale[0],
                "upper_bound": self.scale[1],
            }
            if prompts is not None:
                fields["prompt"] = prompts[i]

            prompt_core = self.base_prompt_template.format(**fields)
            prompt_formatted = self._wrap(prompt_core + "\n\n" + self.format_instructions)
            prompts_list.append(prompt_formatted)

        # generate
        prompt_scores: list[list[float]] = []
        for batch in self._batch_chunks(prompts_list, self.batch_size):
            outputs = self.pipeline(
                batch,
                num_return_sequences=self.num_return_sequences,
                return_full_text=False,
                clean_up_tokenization_spaces=True,
            )

            for prompt, generations in zip(batch, outputs):
                generations = generations if isinstance(generations, list) else [generations]
                assert len(generations) == self.num_return_sequences

                scores = []
                for generation in generations:
                    reply_text = generation["generated_text"]
                    try:
                        score = self.parse_fn(reply_text, self.scale)
                    except Exception:
                        score = self._score_with_retries(prompt)
                    scores.append(score)

                prompt_scores.append(scores)

        mean_per_prompt = [sum(prompt_score) / len(prompt_score) for prompt_score in prompt_scores]
        corpus_mean = sum(mean_per_prompt) / len(mean_per_prompt)

        return {
            "mean_score": corpus_mean,  # overall average
            "scores": mean_per_prompt,  # one number per original prompt
            "raw_scores": prompt_scores  # n_samples scores per prompt
        }
base_prompt_template = prompt_template.strip() instance-attribute
batch_size = batch_size instance-attribute
device = device or ('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') instance-attribute
extras = extras instance-attribute
format_instructions = self.output_parser.get_format_instructions() instance-attribute
max_retries = max_retries instance-attribute
model = AutoModelForCausalLM.from_pretrained(model_or_id) instance-attribute
name = self.__class__.__name__ instance-attribute
num_return_sequences = int(gen_kwargs.pop('num_return_sequences', 1)) instance-attribute
pipeline = TextGenerationPipeline(model=(self.model), tokenizer=(self.tokenizer)) instance-attribute
scale = scale instance-attribute
tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_or_id) instance-attribute
use_chat = hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None instance-attribute
compute(responses, prompts=None, **kwargs)

Compute LLM judge scores for a list of responses.

Evaluates each response using the configured judge model and prompt template. Scores are averaged when multiple samples are generated per response (via num_return_sequences).

Parameters:

Name Type Description Default
responses list[str]

List of text responses to evaluate.

required
prompts list[str] | None

Optional list of prompts corresponding to each response. If provided, must be the same length as responses. These prompts can be referenced in the prompt_template using the {prompt} placeholder.

None
**kwargs Any

Additional keyword arguments (currently unused).

{}

Returns:

Type Description
dict[str, float | list[float]]

Score statistics containing:

  • "mean_score": Overall average score across all responses
  • "scores": List of mean scores for each response (averaged across samples)
  • "raw_scores": List of lists containing all individual scores for each response

Raises:

Type Description
AssertionError

If prompts is provided but has different length than responses.

Source code in aisteer360/evaluation/metrics/base_judge.py
@torch.inference_mode()
def compute(
    self,
    responses: list[str],
    prompts: list[str] | None = None,
    **kwargs: Any,
) -> dict[str, float | list[float]]:
    """Compute LLM judge scores for a list of responses.

    Evaluates each response using the configured judge model and prompt template. Scores are averaged when multiple
    samples are generated per response (via `num_return_sequences`).

    Args:
        responses (list[str]): List of text responses to evaluate.
        prompts (list[str] | None): Optional list of prompts corresponding to each response.
            If provided, must be the same length as responses. These prompts can be
            referenced in the prompt_template using the {prompt} placeholder.
        **kwargs: Additional keyword arguments (currently unused).

    Returns:
        Score statistics containing:

            - "mean_score": Overall average score across all responses
            - "scores": List of mean scores for each response (averaged across samples)
            - "raw_scores": List of lists containing all individual scores for each response

    Raises:
        AssertionError: If prompts is provided but has different length than responses.
    """

    if prompts is not None and len(prompts) != len(responses):
        raise AssertionError("`responses` and `prompts` must be the same length")

    # build prompts
    prompts_list: list[str] = []
    for i in range(len(responses)):
        fields: dict[str, str | float] = {
            "response": responses[i],
            "lower_bound": self.scale[0],
            "upper_bound": self.scale[1],
        }
        if prompts is not None:
            fields["prompt"] = prompts[i]

        prompt_core = self.base_prompt_template.format(**fields)
        prompt_formatted = self._wrap(prompt_core + "\n\n" + self.format_instructions)
        prompts_list.append(prompt_formatted)

    # generate
    prompt_scores: list[list[float]] = []
    for batch in self._batch_chunks(prompts_list, self.batch_size):
        outputs = self.pipeline(
            batch,
            num_return_sequences=self.num_return_sequences,
            return_full_text=False,
            clean_up_tokenization_spaces=True,
        )

        for prompt, generations in zip(batch, outputs):
            generations = generations if isinstance(generations, list) else [generations]
            assert len(generations) == self.num_return_sequences

            scores = []
            for generation in generations:
                reply_text = generation["generated_text"]
                try:
                    score = self.parse_fn(reply_text, self.scale)
                except Exception:
                    score = self._score_with_retries(prompt)
                scores.append(score)

            prompt_scores.append(scores)

    mean_per_prompt = [sum(prompt_score) / len(prompt_score) for prompt_score in prompt_scores]
    corpus_mean = sum(mean_per_prompt) / len(mean_per_prompt)

    return {
        "mean_score": corpus_mean,  # overall average
        "scores": mean_per_prompt,  # one number per original prompt
        "raw_scores": prompt_scores  # n_samples scores per prompt
    }
build_structured_parser(scale)

Build a StructuredOutputParser and parsing function for rating predictions.

Constructs a StructuredOutputParser configured with a single ResponseSchema that expects a float score within the specified scale range. It also returns a parsing function that extracts and validates the score from text, ensuring the result is clamped between the provided bounds.

Parameters:

Name Type Description Default
scale tuple[float, float]

A (low, high) tuple specifying the valid inclusive range for the score.

required
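A brief sketch of how the returned pair is typically used; the judge reply below is fabricated for illustration.

from aisteer360.evaluation.metrics.base_judge import build_structured_parser

output_parser, parse_fn = build_structured_parser((1, 5))

# instructions appended to the judge prompt so the reply contains a parseable "score" field
format_instructions = output_parser.get_format_instructions()

reply = '```json\n{"score": 7}\n```'  # fabricated judge reply
score = parse_fn(reply, (1, 5))       # extracted, converted to float, clamped to the scale
print(score)                          # clamped to the upper bound (5)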
Source code in aisteer360/evaluation/metrics/base_judge.py
def build_structured_parser(scale):
    """
    Build a StructuredOutputParser and parsing function for rating predictions.

    Constructs a `StructuredOutputParser` configured with a single `ResponseSchema` that expects a float score within
    the specified scale range. It also returns a parsing function that extracts and validates the score from text,
    ensuring the result is clamped between the provided bounds.

    Args:
        scale (tuple[float, float]): A `(low, high)` tuple specifying the valid inclusive range for the score.

    Returns:
        A tuple with elements:

            - StructuredOutputParser: The parser configured with the score schema.
            - Callable[[str, tuple[float, float]], float]: A function that takes a raw text response and the
              `(low, high)` scale, extracts the score, converts it to a float, and clamps it within the valid range.
    """
    low, high = scale
    score_schema = ResponseSchema(
        name="score",
        description=f"A single float between {low} and {high} (inclusive) that rates the prediction."
    )
    output_parser = StructuredOutputParser.from_response_schemas([score_schema])

    def parse_fn(text: str, _: tuple[float, float]) -> float:
        """
        Parse and validate a score from text using the structured output parser.

        Returns:
            float: The extracted score, clamped to the inclusive `(low, high)` range.

        Raises:
            ValueError: If the score cannot be parsed from the text.
        """
        try:
            score = float(output_parser.parse(text)["score"])
        except OutputParserException as e:
            raise ValueError(f"Could not parse score: {e}")
        return max(low, min(high, score))

    return output_parser, parse_fn
custom

Custom metrics for specific evaluation use cases.

This module contains metrics tailored to particular use cases, organized by subdirectory. Unlike generic metrics that work across any use case, custom metrics are designed with specific evaluation contexts in mind (e.g., question answering, instruction following, etc.).

commonsense_mcqa

Evaluation metrics for the CommonsenseMCQA use case.

mcqa_accuracy
MCQAAccuracy

Bases: Metric

Exact-match accuracy for multiple-choice QA.
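A usage sketch with three trials per question over two questions; the data is fabricated for illustration.

from aisteer360.evaluation.metrics.custom.commonsense_mcqa.mcqa_accuracy import MCQAAccuracy

accuracy = MCQAAccuracy()
results = accuracy.compute(
    responses=["A", "B", "A", "C", "C", "C"],
    reference_answers=["A", "A", "A", "C", "C", "C"],
    question_ids=["q1", "q1", "q1", "q2", "q2", "q2"],
)
print(results["trial_mean"])     # fraction of individual attempts answered correctly
print(results["question_mean"])  # fraction of questions whose majority vote is correct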

Source code in aisteer360/evaluation/metrics/custom/commonsense_mcqa/mcqa_accuracy.py
class MCQAAccuracy(Metric):
    """
    Exact-match accuracy for multiple-choice QA.
    """

    def compute(
        self,
        responses: list[str],
        prompts: list[str] | None = None,
        reference_answers: list[str] | None = None,
        question_ids: list[str] | None = None,
        **kwargs
    ) -> dict[str, float]:
        """Computes trial-level and question-level accuracy metrics.

        Args:
            responses: List of predicted answer choices (e.g., 'A', 'B', 'C', 'D').
            prompts: List of question prompts (unused, for interface compatibility).
            reference_answers: List of correct answer choices.
            question_ids: Optional question IDs for grouping responses by question.
            **kwargs: Additional arguments (unused).

        Returns:
            Dictionary of accuracy score statistics with values:

                - "trial_mean": micro (attempt-level accuracy)
                - "trial_std": sample std-dev over trials
                - "question_mean": macro (majority-vote accuracy)
                - "question_std": sample std-dev over questions

        Raises:
            ValueError: If reference_answers is None or length mismatches occur.
        """

        if reference_answers is None:
            raise ValueError("MCQAAccuracy needs `reference_answers`.")
        if len(responses) != len(reference_answers):
            raise ValueError("`responses` and `reference_answers` must be the same length.")
        if question_ids is not None and len(responses) != len(question_ids):
            raise ValueError("`question_ids` must match length of `responses`.")

        # micro
        attempt_correct = [
            choice.strip().upper() == answer.strip().upper()
            for choice, answer in zip(responses, reference_answers) if choice is not None
        ]
        attempt_accuracy = sum(attempt_correct) / len(attempt_correct) if attempt_correct else 0.0
        attempt_accuracy_std = self._sample_std(attempt_correct, attempt_accuracy)

        # macro
        if question_ids is None:
            question_accuracy = attempt_accuracy
            question_accuracy_std = attempt_accuracy_std  # fall back to the trial-level std so the return value is always defined
        else:
            votes = defaultdict(list)
            for qid, is_correct in zip(question_ids, attempt_correct):
                votes[qid].append(is_correct)

            majority_outcomes = [int(sum(vote) > len(vote) / 2) for vote in votes.values()]
            question_accuracy = sum(majority_outcomes) / len(votes) if votes else 0.0
            question_accuracy_std = self._sample_std(majority_outcomes, question_accuracy)

        return {
            "trial_mean": attempt_accuracy,
            "trial_std": attempt_accuracy_std,
            "question_mean": question_accuracy,
            "question_std": question_accuracy_std,
        }

    @staticmethod
    def _sample_std(binary, mean):
        """Computes sample standard deviation for binary outcomes.

        Args:
            binary: List of binary values (0 or 1).
            mean: Pre-computed mean of the binary values.

        Returns:
            Sample standard deviation using Bessel's correction (n-1).
        """
        n = len(binary)
        if n < 2:
            return 0.0
        var = sum((x - mean) ** 2 for x in binary) / (n - 1)
        return sqrt(var)
extras = extras instance-attribute
name = self.__class__.__name__ instance-attribute
compute(responses, prompts=None, reference_answers=None, question_ids=None, **kwargs)

Computes trial-level and question-level accuracy metrics.

Parameters:

Name Type Description Default
responses list[str]

List of predicted answer choices (e.g., 'A', 'B', 'C', 'D').

required
prompts list[str] | None

List of question prompts (unused, for interface compatibility).

None
reference_answers list[str] | None

List of correct answer choices.

None
question_ids list[str] | None

Optional question IDs for grouping responses by question.

None
**kwargs

Additional arguments (unused).

{}

Returns:

Type Description
dict[str, float]

Dictionary of accuracy score statistics with values:

  • "trial_mean": micro (attempt-level accuracy)
  • "trial_std": sample std-dev over trials
  • "question_mean": macro (majority-vote accuracy)
  • "question_std": sample std-dev over questions

Raises:

Type Description
ValueError

If reference_answers is None or length mismatches occur.

Source code in aisteer360/evaluation/metrics/custom/commonsense_mcqa/mcqa_accuracy.py
def compute(
    self,
    responses: list[str],
    prompts: list[str] | None = None,
    reference_answers: list[str] | None = None,
    question_ids: list[str] | None = None,
    **kwargs
) -> dict[str, float]:
    """Computes trial-level and question-level accuracy metrics.

    Args:
        responses: List of predicted answer choices (e.g., 'A', 'B', 'C', 'D').
        prompts: List of question prompts (unused, for interface compatibility).
        reference_answers: List of correct answer choices.
        question_ids: Optional question IDs for grouping responses by question.
        **kwargs: Additional arguments (unused).

    Returns:
        Dictionary of accuracy score statistics with values:

            - "trial_mean": micro (attempt-level accuracy)
            - "trial_std": sample std-dev over trials
            - "question_mean": macro (majority-vote accuracy)
            - "question_std": sample std-dev over questions

    Raises:
        ValueError: If reference_answers is None or length mismatches occur.
    """

    if reference_answers is None:
        raise ValueError("MCQAAccuracy needs `reference_answers`.")
    if len(responses) != len(reference_answers):
        raise ValueError("`responses` and `reference_answers` must be the same length.")
    if question_ids is not None and len(responses) != len(question_ids):
        raise ValueError("`question_ids` must match length of `responses`.")

    # micro
    attempt_correct = [
        choice.strip().upper() == answer.strip().upper()
        for choice, answer in zip(responses, reference_answers) if choice is not None
    ]
    attempt_accuracy = sum(attempt_correct) / len(attempt_correct) if attempt_correct else 0.0
    attempt_accuracy_std = self._sample_std(attempt_correct, attempt_accuracy)

    # macro
    if question_ids is None:
        question_accuracy = attempt_accuracy
        question_accuracy_std = attempt_accuracy_std  # fall back to the trial-level std so the return value is always defined
    else:
        votes = defaultdict(list)
        for qid, is_correct in zip(question_ids, attempt_correct):
            votes[qid].append(is_correct)

        majority_outcomes = [int(sum(vote) > len(vote) / 2) for vote in votes.values()]
        question_accuracy = sum(majority_outcomes) / len(votes) if votes else 0.0
        question_accuracy_std = self._sample_std(majority_outcomes, question_accuracy)

    return {
        "trial_mean": attempt_accuracy,
        "trial_std": attempt_accuracy_std,
        "question_mean": question_accuracy,
        "question_std": question_accuracy_std,
    }
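A short usage sketch for MCQAAccuracy. The import path is assumed from the source path shown above and the data is illustrative; trial-level scores treat every attempt independently, while question-level scores take a majority vote per question_id.

from aisteer360.evaluation.metrics.custom.commonsense_mcqa.mcqa_accuracy import MCQAAccuracy  # assumed module path

metric = MCQAAccuracy()
scores = metric.compute(
    responses=["A", "b", "C", "A"],
    reference_answers=["A", "B", "D", "A"],
    question_ids=["q1", "q1", "q2", "q2"],
)
# trial_mean == 0.75 (3 of 4 attempts match, case-insensitively);
# question_mean == 0.5 (q1 has a strict majority of correct attempts, q2 does not).
print(scores)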
mcqa_calibration
MCQACalibration

Bases: Metric

Calibration metrics for multiple-choice QA.

Measures how well model confidence scores align with actual performance using Expected Calibration Error (ECE) and related metrics.

Source code in aisteer360/evaluation/metrics/custom/commonsense_mcqa/mcqa_calibration.py
class MCQACalibration(Metric):
    """
    Calibration metrics for multiple-choice QA.

    Measures how well model confidence scores align with actual performance using Expected Calibration Error (ECE)
    and related metrics.
    """

    def __init__(self, n_bins: int = 10):
        super().__init__()
        self.n_bins = n_bins

    def compute(
        self,
        responses: list[str],
        reference_answers: list[str] = None,
        confidence_scores: list[float] = None,
        question_ids: list[str] | None = None,
        **kwargs
    ) -> dict[str, float]:
        """Computes calibration metrics for model predictions.

        Args:
            responses: List of predicted answer choices (e.g., 'A', 'B', 'C', 'D').
            reference_answers: List of correct answer choices.
            confidence_scores: List of model confidence scores (0.0 to 1.0).
            question_ids: Optional question IDs (unused, for interface compatibility).
            **kwargs: Additional arguments (unused).

        Returns:
            Dictionary of calibration metrics with values:

                - "ece": Expected Calibration Error (lower is better, 0.0 is perfect)
                - "avg_confidence": Model's average confidence across all predictions
                - "overconfidence": avg_confidence - accuracy (positive means overconfident)

        Raises:
            ValueError: If reference_answers or confidence_scores is None.
        """

        if reference_answers is None:
            raise ValueError("MCQACalibration needs `reference_answers`.")
        if confidence_scores is None:
            raise ValueError("MCQACalibration needs `confidence_scores`.")

        # calculate ece
        valid_data = [
            (resp, ref, conf)
            for resp, ref, conf in zip(responses, reference_answers, confidence_scores)
            if conf is not None
        ]
        responses, answers, confidences = zip(*valid_data)
        confidences = np.array(confidences)
        accuracies = np.array([response == answer for response, answer in zip(responses, answers)], dtype=float)
        avg_confidence = float(np.mean(confidences))
        avg_accuracy = float(np.mean(accuracies))
        ece = self._calculate_ece(confidences, accuracies)

        return {
            "ece": ece,
            "avg_confidence": avg_confidence,
            "overconfidence": avg_confidence - avg_accuracy,
        }

    def _calculate_ece(self, confidences: np.ndarray, accuracies: np.ndarray) -> float:
        """Calculates Expected Calibration Error using binned confidence scores.

        ECE measures the difference between confidence and accuracy across confidence bins. For each bin, it computes
        the absolute difference between average confidence and average accuracy, weighted by the proportion of samples
        in that bin.

        Args:
            confidences: Array of confidence scores (0.0 to 1.0).
            accuracies: Array of binary accuracy values (0.0 or 1.0).

        Returns:
            Expected Calibration Error as a float between 0.0 and 1.0.
        """
        bin_boundaries = np.linspace(0, 1, self.n_bins + 1)
        ece = 0

        for i in range(self.n_bins):
            if i == self.n_bins - 1:
                in_bin = (confidences >= bin_boundaries[i]) & (confidences <= bin_boundaries[i + 1])
            else:
                in_bin = (confidences >= bin_boundaries[i]) & (confidences < bin_boundaries[i + 1])

            prop_in_bin = np.mean(in_bin)

            if prop_in_bin > 0:
                bin_accuracy = np.mean(accuracies[in_bin])
                bin_confidence = np.mean(confidences[in_bin])
                ece += prop_in_bin * abs(bin_confidence - bin_accuracy)

        return float(ece)
extras = extras instance-attribute
n_bins = n_bins instance-attribute
name = self.__class__.__name__ instance-attribute
compute(responses, reference_answers=None, confidence_scores=None, question_ids=None, **kwargs)

Computes calibration metrics for model predictions.

Parameters:

Name Type Description Default
responses list[str]

List of predicted answer choices (e.g., 'A', 'B', 'C', 'D').

required
reference_answers list[str]

List of correct answer choices.

None
confidence_scores list[float]

List of model confidence scores (0.0 to 1.0).

None
question_ids list[str] | None

Optional question IDs (unused, for interface compatibility).

None
**kwargs

Additional arguments (unused).

{}

Returns:

Type Description
dict[str, float]

Dictionary of calibration metrics with values:

  • "ece": Expected Calibration Error (lower is better, 0.0 is perfect)
  • "avg_confidence": Model's average confidence across all predictions
  • "overconfidence": avg_confidence - accuracy (positive means overconfident)

Raises:

Type Description
ValueError

If reference_answers or confidence_scores is None.

Source code in aisteer360/evaluation/metrics/custom/commonsense_mcqa/mcqa_calibration.py
def compute(
    self,
    responses: list[str],
    reference_answers: list[str] = None,
    confidence_scores: list[float] = None,
    question_ids: list[str] | None = None,
    **kwargs
) -> dict[str, float]:
    """Computes calibration metrics for model predictions.

    Args:
        responses: List of predicted answer choices (e.g., 'A', 'B', 'C', 'D').
        reference_answers: List of correct answer choices.
        confidence_scores: List of model confidence scores (0.0 to 1.0).
        question_ids: Optional question IDs (unused, for interface compatibility).
        **kwargs: Additional arguments (unused).

    Returns:
        Dictionary of calibration metrics with values:

            - "ece": Expected Calibration Error (lower is better, 0.0 is perfect)
            - "avg_confidence": Model's average confidence across all predictions
            - "overconfidence": avg_confidence - accuracy (positive means overconfident)

    Raises:
        ValueError: If reference_answers or confidence_scores is None.
    """

    if reference_answers is None:
        raise ValueError("MCQACalibration needs `reference_answers`.")
    if confidence_scores is None:
        raise ValueError("MCQACalibration needs `confidence_scores`.")

    # calculate ece
    valid_data = [
        (resp, ref, conf)
        for resp, ref, conf in zip(responses, reference_answers, confidence_scores)
        if conf is not None
    ]
    responses, answers, confidences = zip(*valid_data)
    confidences = np.array(confidences)
    accuracies = np.array([response == answer for response, answer in zip(responses, answers)], dtype=float)
    avg_confidence = float(np.mean(confidences))
    avg_accuracy = float(np.mean(accuracies))
    ece = self._calculate_ece(confidences, accuracies)

    return {
        "ece": ece,
        "avg_confidence": avg_confidence,
        "overconfidence": avg_confidence - avg_accuracy,
    }
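A short usage sketch for MCQACalibration. The import path is assumed from the source path shown above and the values are illustrative; ECE bins the predictions by confidence and averages |bin confidence - bin accuracy|, weighted by the fraction of samples in each bin.

from aisteer360.evaluation.metrics.custom.commonsense_mcqa.mcqa_calibration import MCQACalibration  # assumed module path

metric = MCQACalibration(n_bins=10)
scores = metric.compute(
    responses=["A", "B", "C", "D"],
    reference_answers=["A", "B", "A", "D"],
    confidence_scores=[0.9, 0.8, 0.7, 0.6],
)
# avg_confidence == 0.75 and accuracy == 0.75, so overconfidence == 0.0;
# "ece" reports the bin-weighted gap between confidence and accuracy.
print(scores)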
mcqa_positional_bias
MCQAPositionalBias

Bases: Metric

Positional bias metrics for multiple-choice QA.

Measures whether the model exhibits bias toward selecting certain answer positions.

Source code in aisteer360/evaluation/metrics/custom/commonsense_mcqa/mcqa_positional_bias.py
class MCQAPositionalBias(Metric):
    """
    Positional bias metrics for multiple-choice QA.

    Measures whether the model exhibits bias toward selecting certain answer positions.
    """

    def compute(
        self,
        responses: list[str],
        prompts: list[str] | None = None,
        question_ids: list[str] | None = None,
        **kwargs
    ) -> dict[str, float]:
        """Computes positional bias metrics for model predictions.

        Calculates how much the model's choice frequencies deviate from uniform distribution across answer positions.
        For K answer choices, each position should ideally be selected 1/K of the time.

        Args:
            responses: List of predicted answer choices (e.g., 'A', 'B', 'C', 'D').
            prompts: List of question prompts (unused, for interface compatibility).
            question_ids: Optional question IDs for computing per-question bias variance.
            **kwargs: Additional arguments (unused).

        Returns:
            Dictionary of positional bias metrics with values:

                - "mean": Overall positional bias (mean |f_i - 1/K| across positions)
                - "std": Sample standard deviation of bias computed per question

        Note:

        - If question_ids is None, per-question analysis is skipped and std will be 0.0.
        """

        valid_responses = [r for r in responses if r is not None]

        position_counts = Counter(valid_responses)
        total_responses = len(valid_responses)
        positions = sorted(position_counts.keys())
        position_frequencies = [position_counts.get(pos, 0) / total_responses for pos in positions]
        expected_frequency = 1 / len(positions)

        # positional bias per question
        bias_per_question = []
        responses_by_question = defaultdict(list)

        if question_ids is not None:  # per-question analysis is skipped (std stays 0.0) when no question ids are given
            for response, question_id in zip(responses, question_ids):
                if response is not None:
                    responses_by_question[question_id].append(response)

        for question_id, question_responses in responses_by_question.items():
            if not question_responses:
                continue
            counts_for_question = Counter(question_responses)
            total_for_question = len(question_responses)
            frequencies_for_question = [counts_for_question.get(pos, 0) / total_for_question for pos in positions]
            bias_for_question = np.mean([abs(freq - expected_frequency) for freq in frequencies_for_question])
            bias_per_question.append(bias_for_question)

        return {
            "mean": np.mean([abs(freq - expected_frequency) for freq in position_frequencies]),
            "std": np.std(bias_per_question, ddof=1) if len(bias_per_question) > 1 else 0.0
        }
extras = extras instance-attribute
name = self.__class__.__name__ instance-attribute
compute(responses, prompts=None, question_ids=None, **kwargs)

Computes positional bias metrics for model predictions.

Calculates how much the model's choice frequencies deviate from uniform distribution across answer positions. For K answer choices, each position should ideally be selected 1/K of the time.

Parameters:

Name Type Description Default
responses list[str]

List of predicted answer choices (e.g., 'A', 'B', 'C', 'D').

required
prompts list[str] | None

List of question prompts (unused, for interface compatibility).

None
question_ids list[str] | None

Optional question IDs for computing per-question bias variance.

None
**kwargs

Additional arguments (unused).

{}

Returns:

Type Description
dict[str, float]

Dictionary of positional bias metrics with values:

  • "mean": Overall positional bias (mean |f_i - 1/K| across positions)
  • "std": Sample standard deviation of bias computed per question

Note:

  • If question_ids is None, per-question analysis is skipped and std will be 0.0.
Source code in aisteer360/evaluation/metrics/custom/commonsense_mcqa/mcqa_positional_bias.py
def compute(
    self,
    responses: list[str],
    prompts: list[str] | None = None,
    question_ids: list[str] | None = None,
    **kwargs
) -> dict[str, float]:
    """Computes positional bias metrics for model predictions.

    Calculates how much the model's choice frequencies deviate from uniform distribution across answer positions.
    For K answer choices, each position should ideally be selected 1/K of the time.

    Args:
        responses: List of predicted answer choices (e.g., 'A', 'B', 'C', 'D').
        prompts: List of question prompts (unused, for interface compatibility).
        question_ids: Optional question IDs for computing per-question bias variance.
        **kwargs: Additional arguments (unused).

    Returns:
        Dictionary of positional bias metrics with values:

            - "mean": Overall positional bias (mean |f_i - 1/K| across positions)
            - "std": Sample standard deviation of bias computed per question

    Note:

    - If question_ids is None, per-question analysis is skipped and std will be 0.0.
    """

    valid_responses = [r for r in responses if r is not None]

    position_counts = Counter(valid_responses)
    total_responses = len(valid_responses)
    positions = sorted(position_counts.keys())
    position_frequencies = [position_counts.get(pos, 0) / total_responses for pos in positions]
    expected_frequency = 1 / len(positions)

    # positional bias per question
    bias_per_question = []
    responses_by_question = defaultdict(list)

    if question_ids is not None:  # per-question analysis is skipped (std stays 0.0) when no question ids are given
        for response, question_id in zip(responses, question_ids):
            if response is not None:
                responses_by_question[question_id].append(response)

    for question_id, question_responses in responses_by_question.items():
        if not question_responses:
            continue
        counts_for_question = Counter(question_responses)
        total_for_question = len(question_responses)
        frequencies_for_question = [counts_for_question.get(pos, 0) / total_for_question for pos in positions]
        bias_for_question = np.mean([abs(freq - expected_frequency) for freq in frequencies_for_question])
        bias_per_question.append(bias_for_question)

    return {
        "mean": np.mean([abs(freq - expected_frequency) for freq in position_frequencies]),
        "std": np.std(bias_per_question, ddof=1) if len(bias_per_question) > 1 else 0.0
    }
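A short usage sketch for MCQAPositionalBias. The import path is assumed from the source path shown above and the responses are illustrative; the set of positions is inferred from the observed answers, so a perfectly uniform selector yields a mean of 0.0.

from aisteer360.evaluation.metrics.custom.commonsense_mcqa.mcqa_positional_bias import MCQAPositionalBias  # assumed module path

metric = MCQAPositionalBias()
scores = metric.compute(
    responses=["A", "A", "B", "A", "C", "A"],
    question_ids=["q1", "q1", "q1", "q2", "q2", "q2"],
)
# Positions {A, B, C} are observed; "mean" averages |frequency - 1/3| across positions,
# and "std" is the sample standard deviation of the per-question bias values.
print(scores)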
instruction_following

Evaluation metrics for the InstructionFollowing use case.

helpers

We have omitted the documentation details on the IFEval functions (located in helpers/) from our API reference. For details please see the IFEval repo: https://github.com/google-research/google-research/tree/master/instruction_following_eval.

evaluation_main

Binary for evaluating instruction following. See README.md.

InputExample dataclass
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/evaluation_main.py
@dataclasses.dataclass
class InputExample:
    key: int
    instruction_id_list: list[str]
    prompt: str
    kwargs: list[Dict[str, Optional[Union[str, int]]]]
instruction_id_list instance-attribute
key instance-attribute
kwargs instance-attribute
prompt instance-attribute
OutputExample dataclass
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/evaluation_main.py
@dataclasses.dataclass
class OutputExample:
    instruction_id_list: list[str]
    prompt: str
    response: str
    follow_all_instructions: bool
    follow_instruction_list: list[bool]
follow_all_instructions instance-attribute
follow_instruction_list instance-attribute
instruction_id_list instance-attribute
prompt instance-attribute
response instance-attribute
main(argv)
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/evaluation_main.py
def main(argv):
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    inputs = read_prompt_list(_INPUT_DATA.value)
    prompt_to_response = read_prompt_to_response_dict(_INPUT_RESPONSE_DATA.value)

    # get instruction following results
    for func, output_file_name in [
        (test_instruction_following_strict, "eval_results_strict"),
        (test_instruction_following_loose, "eval_results_loose"),
    ]:
        logging.info("Generating %s...", output_file_name)
        outputs = []
        for inp in inputs:
            outputs.append(func(inp, prompt_to_response))
        follow_all_instructions = [o.follow_all_instructions for o in outputs]
        accuracy = sum(follow_all_instructions) / len(outputs)
        logging.info("Accuracy: %f", accuracy)

        output_file_name = os.path.join(_OUTPUT_DIR.value, output_file_name + ".jsonl")
        write_outputs(output_file_name, outputs)
        logging.info("Generated: %s", output_file_name)

        # Prints instruction following accuracy report.
        print("=" * 64)
        print(f"{output_file_name} Accuracy Scores:")
        print_report(outputs)
print_report(outputs)

Prints a report on accuracy scores.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/evaluation_main.py
def print_report(outputs):
    """Prints a report on accuracy scores."""

    prompt_total = 0
    prompt_correct = 0
    instruction_total = 0
    instruction_correct = 0

    tier0_total = collections.defaultdict(int)
    tier0_correct = collections.defaultdict(int)

    tier1_total = collections.defaultdict(int)
    tier1_correct = collections.defaultdict(int)

    for example in outputs:
        follow_instruction_list = example.follow_instruction_list
        instruction_id_list = example.instruction_id_list

        prompt_total += 1
        if all(follow_instruction_list):
            prompt_correct += 1

        instruction_total += len(instruction_id_list)
        instruction_correct += sum(follow_instruction_list)

        for instruction_id, followed_or_not in zip(
            instruction_id_list, follow_instruction_list
        ):
            instruction_id = instruction_id.split(":")[0]
            tier0_total[instruction_id] += 1
            if followed_or_not:
                tier0_correct[instruction_id] += 1

        for instruction_id, followed_or_not in zip(
            instruction_id_list, follow_instruction_list
        ):
            tier1_total[instruction_id] += 1
            if followed_or_not:
                tier1_correct[instruction_id] += 1

    # print(f"prompt-level: {prompt_correct / prompt_total}")
    # print(f"instruction-level: {instruction_correct / instruction_total}")
    # print()
    for instruction_id in sorted(tier0_total.keys()):
        accuracy = tier0_correct[instruction_id] / tier0_total[instruction_id]
    #   print(f"{instruction_id} {accuracy}")
    # print()
    for instruction_id in sorted(tier1_total.keys()):
        accuracy = tier1_correct[instruction_id] / tier1_total[instruction_id]
        # print(f"{instruction_id} {accuracy}")

    return prompt_correct / prompt_total, instruction_correct / instruction_total
read_prompt_list(input_jsonl)

Read inputs from jsonl.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/evaluation_main.py
def read_prompt_list(input_jsonl):
    """Read inputs from jsonl."""
    inputs = []
    if isinstance(input_jsonl, str):
        with open(input_jsonl, "r") as f:
            for l in f:
                example = json.loads(l)
                inputs.append(
                    InputExample(
                        key=example["key"],
                        instruction_id_list=example["instruction_id_list"],
                        prompt=example["prompt"],
                        kwargs=example["kwargs"],
                    )
                )
    else:
        for example in input_jsonl:
            inputs.append(
                InputExample(
                    key=example["key"],
                    instruction_id_list=example["instruction_id_list"],
                    prompt=example["prompt"],
                    kwargs=example["kwargs"],
                )
            )
    return inputs
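A sketch of the record schema read_prompt_list expects, based on the keys accessed above. The loader accepts either a path to a .jsonl file or an in-memory iterable of dicts; the import path, instruction id, and kwargs below are illustrative.

from aisteer360.evaluation.metrics.custom.instruction_following.helpers.evaluation_main import read_prompt_list  # assumed module path

records = [
    {
        "key": 1,
        "instruction_id_list": ["detectable_format:number_bullet_lists"],  # illustrative instruction id
        "prompt": "List two benefits of unit tests.",
        "kwargs": [{"num_bullets": 2}],
    }
]
inputs = read_prompt_list(records)  # also accepts a path to a .jsonl file
print(inputs[0].prompt)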
read_prompt_to_response_dict(input_jsonl)

Creates dictionary matching prompt and response.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/evaluation_main.py
def read_prompt_to_response_dict(input_jsonl):
    """Creates dictionary matching prompt and response."""
    return_dict = {}
    if isinstance(input_jsonl, str):
        with open(input_jsonl, "r") as f:
            for l in f:
                example = json.loads(l)
                return_dict[example["prompt"]] = example["response"]
    else:
        for example in input_jsonl:
            # print("here")
            # print(example)
            return_dict[example["prompt"]] = example["response"]
    return return_dict
test_instruction_following_loose(inp, prompt_to_response)

Tests response for an upper bound for following instructions.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/evaluation_main.py
def test_instruction_following_loose(
    inp,
    prompt_to_response,
):
    """Tests response for an upper bound for following instructions."""
    response = prompt_to_response[inp.prompt]
    r = response.split("\n")
    response_remove_first = "\n".join(r[1:]).strip()
    response_remove_last = "\n".join(r[:-1]).strip()
    response_remove_both = "\n".join(r[1:-1]).strip()
    revised_response = response.replace("*", "")
    revised_response_remove_first = response_remove_first.replace("*", "")
    revised_response_remove_last = response_remove_last.replace("*", "")
    revised_response_remove_both = response_remove_both.replace("*", "")
    all_responses = [
        response,
        revised_response,
        response_remove_first,
        response_remove_last,
        response_remove_both,
        revised_response_remove_first,
        revised_response_remove_last,
        revised_response_remove_both,
    ]
    instruction_list = inp.instruction_id_list
    is_following_list = []

    for index, instruction_id in enumerate(instruction_list):
        instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
        instruction = instruction_cls(instruction_id)

        instruction.build_description(**inp.kwargs[index])
        args = instruction.get_instruction_args()
        if args and "prompt" in args:
            instruction.build_description(prompt=inp.prompt)

        is_following = False
        for r in all_responses:
            if r.strip() and instruction.check_following(r):
                is_following = True
                break

        is_following_list.append(is_following)

    return OutputExample(
        instruction_id_list=inp.instruction_id_list,
        prompt=inp.prompt,
        response=response,
        follow_all_instructions=all(is_following_list),
        follow_instruction_list=is_following_list,
    )
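A programmatic sketch mirroring main() above, for running the loose check without the command-line flags. The import path is assumed from the file path shown and the data paths are placeholders; the strict variant is invoked the same way in main().

from aisteer360.evaluation.metrics.custom.instruction_following.helpers.evaluation_main import (  # assumed module path
    read_prompt_list,
    read_prompt_to_response_dict,
    test_instruction_following_loose,
)

inputs = read_prompt_list("input_data.jsonl")                                    # placeholder path
prompt_to_response = read_prompt_to_response_dict("input_response_data.jsonl")  # placeholder path

outputs = [test_instruction_following_loose(inp, prompt_to_response) for inp in inputs]
accuracy = sum(o.follow_all_instructions for o in outputs) / len(outputs)
print(f"loose prompt-level accuracy: {accuracy:.3f}")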
test_instruction_following_strict(inp, prompt_to_response)

Tests response to see if instructions are followed.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/evaluation_main.py
def test_instruction_following_strict(
    inp,
    prompt_to_response,
):
    """Tests response to see if instrutions are followed."""
    response = prompt_to_response[inp["prompt"]]
    instruction_list = inp["instruction_id_list"]
    is_following_list = []

    for index, instruction_id in enumerate(instruction_list):
        instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
        instruction = instruction_cls(instruction_id)

        instruction.build_description(**inp["kwargs"][index])
        args = instruction.get_instruction_args()
        if args and "prompt" in args:
            instruction.build_description(prompt=inp["prompt"])

        if response.strip() and instruction.check_following(response):
            is_following_list.append(True)
        else:
            is_following_list.append(False)

    return OutputExample(
        instruction_id_list=inp["instruction_id_list"],
        prompt=inp["prompt"],
        response=response,
        follow_all_instructions=all(is_following_list),
        follow_instruction_list=is_following_list,
    )
write_outputs(output_jsonl_filename, outputs)

Writes outputs to jsonl.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/evaluation_main.py
def write_outputs(output_jsonl_filename, outputs):
    """Writes outputs to jsonl."""
    assert outputs
    with open(output_jsonl_filename, "w") as f:
        for o in outputs:
            f.write(
                json.dumps(
                    {
                        attr_name: o.__getattribute__(attr_name)
                        for attr_name in [
                            name for name in dir(o) if not name.startswith("_")
                        ]
                    }
                )
            )
            f.write("\n")
instructions

Library of instructions.

BulletListChecker

Bases: Instruction

Checks the bullet list in the prompt.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class BulletListChecker(Instruction):
  """Checks the bullet list in the prompt."""

  def build_description(self, *, num_bullets = None):
    """Build the instruction description.

    Args:
      num_bullets: An integer specifying the exact number of bullet lists
        that is required to appear in the response.

    Returns:
      A string representing the instruction description.
    """
    self._num_bullets = num_bullets
    if self._num_bullets is None or self._num_bullets < 0:
      self._num_bullets = random.randint(1, _NUM_BULLETS)
    self._description_pattern = (
        "Your answer must contain exactly {num_bullets} bullet points. " +
        "Use the markdown bullet points such as:\n" +
        "* This is point 1. \n" +
        "* This is point 2")
    return self._description_pattern.format(
        num_bullets=self._num_bullets)

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return {"num_bullets": self._num_bullets}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["num_bullets"]

  def check_following(self, value):
    r"""Check if the number of bullet lists meets the requirement.

    Args:
      value: A string representing the response. The response is expected to
        contain some bullet lists that start with `\*`.

    Returns:
      True if the actual number of bullet lists in the response meets the
      requirement.
    """
    bullet_lists = re.findall(r"^\s*\*[^\*].*$", value, flags=re.MULTILINE)
    bullet_lists_2 = re.findall(r"^\s*-.*$", value, flags=re.MULTILINE)
    num_bullet_lists = len(bullet_lists) + len(bullet_lists_2)
    return num_bullet_lists == self._num_bullets
id = instruction_id instance-attribute
build_description(*, num_bullets=None)

Build the instruction description.

Parameters:

Name Type Description Default
num_bullets

An integer specifying the exact number of bullet lists that is required to appear in the response.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, *, num_bullets = None):
  """Build the instruction description.

  Args:
    num_bullets: An integer specifying the exact number of bullet lists
      that is required to appear in the response.

  Returns:
    A string representing the instruction description.
  """
  self._num_bullets = num_bullets
  if self._num_bullets is None or self._num_bullets < 0:
    self._num_bullets = random.randint(1, _NUM_BULLETS)
  self._description_pattern = (
      "Your answer must contain exactly {num_bullets} bullet points. " +
      "Use the markdown bullet points such as:\n" +
      "* This is point 1. \n" +
      "* This is point 2")
  return self._description_pattern.format(
      num_bullets=self._num_bullets)
check_following(value)

Check if the number of bullet lists meets the requirement.

Parameters:

Name Type Description Default
value

A string representing the response. The response is expected to contain some bullet lists that start with \*.

required

Returns:

Type Description

True if the actual number of bullet lists in the response meets the

requirement.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  r"""Check if the number of bullet lists meets the requirement.

  Args:
    value: A string representing the response. The response is expected to
      contain some bullet lists that start with `\*`.

  Returns:
    True if the actual number of bullet lists in the response meets the
    requirement.
  """
  bullet_lists = re.findall(r"^\s*\*[^\*].*$", value, flags=re.MULTILINE)
  bullet_lists_2 = re.findall(r"^\s*-.*$", value, flags=re.MULTILINE)
  num_bullet_lists = len(bullet_lists) + len(bullet_lists_2)
  return num_bullet_lists == self._num_bullets
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return {"num_bullets": self._num_bullets}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["num_bullets"]
CapitalLettersEnglishChecker

Bases: Instruction

Checks that the response is in English and is in all capital letters.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class CapitalLettersEnglishChecker(Instruction):
  """Checks that the response is in english and is in all capital letters."""

  def build_description(self):
    """Build the instruction description."""
    self._description_pattern = (
        "Your entire response should be in English, and in all capital letters."
    )
    return self._description_pattern

  def get_instruction_args(self):
    return None

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return []

  def check_following(self, value):
    """Checks that the response is in English and in all capital letters."""
    assert isinstance(value, str)

    try:
      return value.isupper() and langdetect.detect(value) == "en"
    except langdetect.LangDetectException as e:
      # Count as instruction is followed.
      logging.error(
          "Unable to detect language for text %s due to %s", value, e
      )  # refex: disable=pytotw.037
      return True
id = instruction_id instance-attribute
build_description()

Build the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self):
  """Build the instruction description."""
  self._description_pattern = (
      "Your entire response should be in English, and in all capital letters."
  )
  return self._description_pattern
check_following(value)

Checks that the response is in English and in all capital letters.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks that the response is in English and in all capital letters."""
  assert isinstance(value, str)

  try:
    return value.isupper() and langdetect.detect(value) == "en"
  except langdetect.LangDetectException as e:
    # Count as instruction is followed.
    logging.error(
        "Unable to detect language for text %s due to %s", value, e
    )  # refex: disable=pytotw.037
    return True
get_instruction_args()
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  return None
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return []
CapitalWordFrequencyChecker

Bases: Instruction

Checks frequency of words with all capital letters.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class CapitalWordFrequencyChecker(Instruction):
  """Checks frequency of words with all capital letters."""

  def build_description(
      self,
      capital_frequency = None,
      capital_relation = None,
  ):
    """Build the instruction description.

    Args:
      capital_frequency: An integer that represents the number of words that
        should be in all capital letters.
      capital_relation: A string that is 'at least' or 'at most' that refers to
        the frequency.

    Returns:
      A string representing the instruction description.
    """
    self._frequency = capital_frequency
    if self._frequency is None:
      self._frequency = random.randint(1, _ALL_CAPITAL_WORD_FREQUENCY)

    self._comparison_relation = capital_relation
    if capital_relation is None:
      self._comparison_relation = random.choice(_COMPARISON_RELATION)
    elif capital_relation not in _COMPARISON_RELATION:
      raise ValueError(
          "The supported relation for comparison must be in "
          f"{_COMPARISON_RELATION}, but {capital_relation} is given."
      )

    self._description_pattern = (
        "In your response, words with all capital letters should appear"
        " {relation} {frequency} times."
    )

    return self._description_pattern.format(
        frequency=self._frequency, relation=self._comparison_relation
    )

  def get_instruction_args(self):
    """Returns the keyword args of build description."""
    return {
        "capital_frequency": self._frequency,
        "capital_relation": self._comparison_relation,
    }

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["capital_frequency", "capital_relation"]

  def check_following(self, value):
    """Checks the frequency of words with all capital letters."""
    # Hyphenated words will count as one word
    words = instructions_util.nltk.word_tokenize(value)
    capital_words = [word for word in words if word.isupper()]

    capital_words = len(capital_words)

    if self._comparison_relation == _COMPARISON_RELATION[0]:
      return capital_words < self._frequency
    else:
      return capital_words >= self._frequency
id = instruction_id instance-attribute
build_description(capital_frequency=None, capital_relation=None)

Build the instruction description.

Parameters:

Name Type Description Default
capital_frequency

An integer that represents the number of words that should be in all capital letters.

None
capital_relation

A string that is 'at least' or 'at most' that refers to the frequency.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(
    self,
    capital_frequency = None,
    capital_relation = None,
):
  """Build the instruction description.

  Args:
    capital_frequency: An integer that represents the number of words that
      should be in all capital letters.
    capital_relation: A string that is 'at least' or 'at most' that refers to
      the frequency.

  Returns:
    A string representing the instruction description.
  """
  self._frequency = capital_frequency
  if self._frequency is None:
    self._frequency = random.randint(1, _ALL_CAPITAL_WORD_FREQUENCY)

  self._comparison_relation = capital_relation
  if capital_relation is None:
    self._comparison_relation = random.choice(_COMPARISON_RELATION)
  elif capital_relation not in _COMPARISON_RELATION:
    raise ValueError(
        "The supported relation for comparison must be in "
        f"{_COMPARISON_RELATION}, but {capital_relation} is given."
    )

  self._description_pattern = (
      "In your response, words with all capital letters should appear"
      " {relation} {frequency} times."
  )

  return self._description_pattern.format(
      frequency=self._frequency, relation=self._comparison_relation
  )
check_following(value)

Checks the frequency of words with all capital letters.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks the frequency of words with all capital letters."""
  # Hyphenated words will count as one word
  words = instructions_util.nltk.word_tokenize(value)
  capital_words = [word for word in words if word.isupper()]

  capital_words = len(capital_words)

  if self._comparison_relation == _COMPARISON_RELATION[0]:
    return capital_words < self._frequency
  else:
    return capital_words >= self._frequency
get_instruction_args()

Returns the keyword args of build description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyword args of build description."""
  return {
      "capital_frequency": self._frequency,
      "capital_relation": self._comparison_relation,
  }
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["capital_frequency", "capital_relation"]
CommaChecker

Bases: Instruction

Checks the response for no commas.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class CommaChecker(Instruction):
  """Checks the response for no commas."""

  def build_description(self):
    """Build the instruction description."""
    self._description_pattern = (
        "In your entire response, refrain from the use of any commas."
    )
    return self._description_pattern

  def get_instruction_args(self):
    return None

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return []

  def check_following(self, value):
    """Checks that the response does not contain commas."""
    return not re.search(r"\,", value)
id = instruction_id instance-attribute
build_description()

Build the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self):
  """Build the instruction description."""
  self._description_pattern = (
      "In your entire response, refrain from the use of any commas."
  )
  return self._description_pattern
check_following(value)

Checks that the response does not contain commas.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks that the response does not contain commas."""
  return not re.search(r"\,", value)
get_instruction_args()
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  return None
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return []
ConstrainedResponseChecker

Bases: Instruction

Checks the constrained response.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class ConstrainedResponseChecker(Instruction):
  """Checks the constrained response."""

  def build_description(self):
    """Build the instruction description."""
    # A sequence of string(s) representing the options of the expected response.
    self._constrained_responses = _CONSTRAINED_RESPONSE_OPTIONS
    self._description_pattern = (
        "Answer with one of the following options: {response_options}")
    return self._description_pattern.format(
        response_options=self._constrained_responses)

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return None

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return []

  def check_following(self, value):
    """Checks if the response matches the constrained options.

    Args:
      value: A string representing the response.

    Returns:
      True if the actual response contains one of the options in the constrained
      responses; otherwise False.
    """
    value = value.strip()
    for constrained_response in self._constrained_responses:
      if constrained_response in value:
        return True
    return False
id = instruction_id instance-attribute
build_description()

Build the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self):
  """Build the instruction description."""
  # A sequence of string(s) representing the options of the expected response.
  self._constrained_responses = _CONSTRAINED_RESPONSE_OPTIONS
  self._description_pattern = (
      "Answer with one of the following options: {response_options}")
  return self._description_pattern.format(
      response_options=self._constrained_responses)
check_following(value)

Checks if the response matches the constrained options.

Parameters:

Name Type Description Default
value

A string representing the response.

required

Returns:

Type Description

True if the actual response contains one of the options in the constrained

responses; otherwise False.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks if the response matches the constrained options.

  Args:
    value: A string representing the response.

  Returns:
    True if the actual response contains one of the options in the constrained
    responses; otherwise False.
  """
  value = value.strip()
  for constrained_response in self._constrained_responses:
    if constrained_response in value:
      return True
  return False
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return None
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return []
ConstrainedStartChecker

Bases: Instruction

Checks the response start.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class ConstrainedStartChecker(Instruction):
  """Checks the response start."""

  def build_description(self, *, starter = None):
    """Build the instruction description.

    Args:
      starter: A string representing the keyword that the response should start
        with.

    Returns:
      A string representing the instruction description.
    """
    self._starter = starter.strip() if isinstance(starter, str) else starter
    if self._starter is None:
      self._starter = random.choice(_STARTER_OPTIONS)
    self._description_pattern = (
        "During the conversation, when it is your turn, " +
        "please always start with {starter}")
    return self._description_pattern.format(starter=self._starter)

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return {"starter": self._starter}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["starter"]

  def check_following(self, value):
    """Checks if the response starts with the constrained keyword or phrase.

    Args:
      value: A string representing the response.

    Returns:
      True if the response starts with the given phrase or keyword that is
      contained in `instruction_args`; otherwise, False.
    """
    response_pattern = r"^\s*" + self._starter + r".*$"
    response_with_constrained_start = re.search(response_pattern, value,
                                                flags=re.MULTILINE)
    return True if response_with_constrained_start else False
id = instruction_id instance-attribute
build_description(*, starter=None)

Build the instruction description.

Parameters:

Name Type Description Default
starter

A string representing the keyword that the response should start with.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, *, starter = None):
  """Build the instruction description.

  Args:
    starter: A string representing the keyword that the response should start
      with.

  Returns:
    A string representing the instruction description.
  """
  self._starter = starter.strip() if isinstance(starter, str) else starter
  if self._starter is None:
    self._starter = random.choice(_STARTER_OPTIONS)
  self._description_pattern = (
      "During the conversation, when it is your turn, " +
      "please always start with {starter}")
  return self._description_pattern.format(starter=self._starter)
check_following(value)

Checks if the response starts with the constrained keyword or phrase.

Parameters:

Name Type Description Default
value

A string representing the response.

required

Returns:

Type Description

True if the response starts with the given phrase or keyword that is

contained in instruction_args; otherwise, False.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks if the response starts with the constrained keyword or phrase.

  Args:
    value: A string representing the response.

  Returns:
    True if the response starts with the given phrase or keyword that is
    contained in `instruction_args`; otherwise, False.
  """
  response_pattern = r"^\s*" + self._starter + r".*$"
  response_with_constrained_start = re.search(response_pattern, value,
                                              flags=re.MULTILINE)
  return True if response_with_constrained_start else False
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return {"starter": self._starter}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["starter"]
EndChecker

Bases: Instruction

Checks that the response ends with a given phrase.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class EndChecker(Instruction):
  """Checks that the prompt ends with a given phrase."""

  def build_description(self, *, end_phrase = None):
    """Build the instruction description.

    Args:
      end_phrase: A string representing the phrase the response should end with.

    Returns:
      A string representing the instruction description.
    """
    self._end_phrase = (
        end_phrase.strip() if isinstance(end_phrase, str) else end_phrase
    )
    if self._end_phrase is None:
      self._end_phrase = random.choice(_ENDING_OPTIONS)
    self._description_pattern = (
        "Finish your response with this exact phrase {ender}. "
        "No other words should follow this phrase.")
    return self._description_pattern.format(ender=self._end_phrase)

  def get_instruction_args(self):
    return {"end_phrase": self._end_phrase}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["end_phrase"]

  def check_following(self, value):
    """Checks if the response ends with the expected phrase."""
    value = value.strip().strip("\"").lower()
    self._end_phrase = self._end_phrase.strip().lower()
    return value.endswith(self._end_phrase)
id = instruction_id instance-attribute
build_description(*, end_phrase=None)

Build the instruction description.

Parameters:

Name Type Description Default
end_phrase

A string representing the phrase the response should end with.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, *, end_phrase = None):
  """Build the instruction description.

  Args:
    end_phrase: A string representing the phrase the response should end with.

  Returns:
    A string representing the instruction description.
  """
  self._end_phrase = (
      end_phrase.strip() if isinstance(end_phrase, str) else end_phrase
  )
  if self._end_phrase is None:
    self._end_phrase = random.choice(_ENDING_OPTIONS)
  self._description_pattern = (
      "Finish your response with this exact phrase {ender}. "
      "No other words should follow this phrase.")
  return self._description_pattern.format(ender=self._end_phrase)
check_following(value)

Checks if the response ends with the expected phrase.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks if the response ends with the expected phrase."""
  value = value.strip().strip("\"").lower()
  self._end_phrase = self._end_phrase.strip().lower()
  return value.endswith(self._end_phrase)
get_instruction_args()
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  return {"end_phrase": self._end_phrase}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["end_phrase"]
ForbiddenWords

Bases: Instruction

Checks that specified words are not used in response.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class ForbiddenWords(Instruction):
  """Checks that specified words are not used in response."""

  def build_description(self, forbidden_words = None
                        ):
    """Build the instruction description.

    Args:
      forbidden_words: A sequence of strings representing words that are not
        allowed in the response.

    Returns:
      A string representing the instruction description.
    """

    if not forbidden_words:
      self._forbidden_words = instructions_util.generate_keywords(
          num_keywords=_NUM_KEYWORDS)
    else:
      self._forbidden_words = list(set(forbidden_words))
    self._forbidden_words = sorted(self._forbidden_words)
    self._description_pattern = (
        "Do not include keywords {forbidden_words} in the response."
    )

    return self._description_pattern.format(
        forbidden_words=self._forbidden_words
    )

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return {"forbidden_words": self._forbidden_words}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["forbidden_words"]

  def check_following(self, value):
    """Check if the response does not contain the expected keywords."""
    for word in self._forbidden_words:
      if re.search(r"\b" + word + r"\b", value, flags=re.IGNORECASE):
        return False
    return True
id = instruction_id instance-attribute
build_description(forbidden_words=None)

Build the instruction description.

Parameters:

Name Type Description Default
forbidden_words

A sequence of strings representing words that are not allowed in the response.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, forbidden_words = None
                      ):
  """Build the instruction description.

  Args:
    forbidden_words: A sequence of strings representing words that are not
      allowed in the response.

  Returns:
    A string representing the instruction description.
  """

  if not forbidden_words:
    self._forbidden_words = instructions_util.generate_keywords(
        num_keywords=_NUM_KEYWORDS)
  else:
    self._forbidden_words = list(set(forbidden_words))
  self._forbidden_words = sorted(self._forbidden_words)
  self._description_pattern = (
      "Do not include keywords {forbidden_words} in the response."
  )

  return self._description_pattern.format(
      forbidden_words=self._forbidden_words
  )
check_following(value)

Check if the response does not contain the expected keywords.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Check if the response does not contain the expected keywords."""
  for word in self._forbidden_words:
    if re.search(r"\b" + word + r"\b", value, flags=re.IGNORECASE):
      return False
  return True
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return {"forbidden_words": self._forbidden_words}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["forbidden_words"]
HighlightSectionChecker

Bases: Instruction

Checks the highlighted section.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class HighlightSectionChecker(Instruction):
  """Checks the highlighted section."""

  def build_description(self, *, num_highlights = None):
    """Build the instruction description.

    Args:
      num_highlights: An integer specifying the minimum number of highlighted
        sections.

    Returns:
      A string representing the instruction description.
    """
    self._num_highlights = num_highlights
    if self._num_highlights is None or self._num_highlights < 0:
      self._num_highlights = random.randint(1, _NUM_HIGHLIGHTED_SECTIONS)

    self._description_pattern = (
        "Highlight at least {num_highlights} sections in your answer with " +
        "markdown, i.e. *highlighted section*.")

    return self._description_pattern.format(num_highlights=self._num_highlights)

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return {"num_highlights": self._num_highlights}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["num_highlights"]

  def check_following(self, value):
    """Checks if the number of highlighted sections meets the requirement.

    Args:
      value: a string representing the response. The response is expected to
        contain highlighted sections in the format of *highlighted*.

    Returns:
      True if the actual number of highlighted sections in the format of
      *highlighted sections* meets the minimum requirement; otherwise False.
    """
    num_highlights = 0
    highlights = re.findall(r"\*[^\n\*]*\*", value)
    double_highlights = re.findall(r"\*\*[^\n\*]*\*\*", value)
    for highlight in highlights:
      if highlight.strip("*").strip():
        num_highlights += 1
    for highlight in double_highlights:
      if highlight.removeprefix("**").removesuffix("**").strip():
        num_highlights += 1

    return num_highlights >= self._num_highlights
id = instruction_id instance-attribute
build_description(*, num_highlights=None)

Build the instruction description.

Parameters:

Name Type Description Default
num_highlights

An integer specifying the minimum number of highlighted sections.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, *, num_highlights = None):
  """Build the instruction description.

  Args:
    num_highlights: An integer specifying the minimum number of highlighted
      sections.

  Returns:
    A string representing the instruction description.
  """
  self._num_highlights = num_highlights
  if self._num_highlights is None or self._num_highlights < 0:
    self._num_highlights = random.randint(1, _NUM_HIGHLIGHTED_SECTIONS)

  self._description_pattern = (
      "Highlight at least {num_highlights} sections in your answer with " +
      "markdown, i.e. *highlighted section*.")

  return self._description_pattern.format(num_highlights=self._num_highlights)
check_following(value)

Checks if the number of highlighted sections meets the requirement.

Parameters:

Name Type Description Default
value

a string representing the response. The response is expected to contain highlighted sections in the format of highlighted.

required

Returns:

Type Description

True if the actual number of highlighted sections in the format of

highlighted sections meets the minimum requirement; otherwise False.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks if the number of highlighted sections meets the requirement.

  Args:
    value: a string representing the response. The response is expected to
      contain highlighted sections in the format of *highlighted*.

  Returns:
    True if the actual number of highlighted sections in the format of
    *highlighted sections* meets the minimum requirement; otherwise False.
  """
  num_highlights = 0
  highlights = re.findall(r"\*[^\n\*]*\*", value)
  double_highlights = re.findall(r"\*\*[^\n\*]*\*\*", value)
  for highlight in highlights:
    if highlight.strip("*").strip():
      num_highlights += 1
  for highlight in double_highlights:
    if highlight.removeprefix("**").removesuffix("**").strip():
      num_highlights += 1

  return num_highlights >= self._num_highlights
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return {"num_highlights": self._num_highlights}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["num_highlights"]
Instruction

An instruction template.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class Instruction:
  """An instruction template."""

  def __init__(self, instruction_id):
    self.id = instruction_id

  def build_description(self, **kwargs):
    raise NotImplementedError("`build_description` not implemented.")

  def get_instruction_args(self):
    raise NotImplementedError("`get_instruction_args` not implemented.")

  def get_instruction_args_keys(self):
    raise NotImplementedError("`get_instruction_args_keys` not implemented.")

  def check_following(self, value):
    raise NotImplementedError("`check_following` not implemented.")
id = instruction_id instance-attribute
build_description(**kwargs)
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, **kwargs):
  raise NotImplementedError("`build_description` not implemented.")
check_following(value)
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  raise NotImplementedError("`check_following` not implemented.")
get_instruction_args()
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  raise NotImplementedError("`get_instruction_args` not implemented.")
get_instruction_args_keys()
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  raise NotImplementedError("`get_instruction_args_keys` not implemented.")
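
A minimal sketch of a custom subclass (hypothetical checker, shown only to illustrate the four-method contract; the import path is inferred from the source reference above):

from aisteer360.evaluation.metrics.custom.instruction_following.helpers.instructions import Instruction

class ContainsDigitChecker(Instruction):
  """Hypothetical checker: the response must contain at least one digit."""

  def build_description(self, **kwargs):
    self._description_pattern = "Include at least one digit in your response."
    return self._description_pattern

  def get_instruction_args(self):
    return None

  def get_instruction_args_keys(self):
    return []

  def check_following(self, value):
    return any(character.isdigit() for character in value)

checker = ContainsDigitChecker("contains_digit")  # placeholder instruction id
print(checker.check_following("There are 3 options."))  # True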
JsonFormat

Bases: Instruction

Check the Json format.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class JsonFormat(Instruction):
  """Check the Json format."""

  def build_description(self):
    self._description_pattern = (
        "Entire output should be wrapped in JSON format. You can use markdown"
        " ticks such as ```."
    )
    return self._description_pattern

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return None

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return []

  def check_following(self, value):
    value = (
        value.strip()
        .removeprefix("```json")
        .removeprefix("```Json")
        .removeprefix("```JSON")
        .removeprefix("```")
        .removesuffix("```")
        .strip()
    )
    try:
      json.loads(value)
    except ValueError as _:
      return False
    return True
id = instruction_id instance-attribute
build_description()
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self):
  self._description_pattern = (
      "Entire output should be wrapped in JSON format. You can use markdown"
      " ticks such as ```."
  )
  return self._description_pattern
check_following(value)
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  value = (
      value.strip()
      .removeprefix("```json")
      .removeprefix("```Json")
      .removeprefix("```JSON")
      .removeprefix("```")
      .removesuffix("```")
      .strip()
  )
  try:
    json.loads(value)
  except ValueError as _:
    return False
  return True
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return None
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return []
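
Example (illustrative sketch; import path inferred from the source reference above, instruction id is a placeholder):

from aisteer360.evaluation.metrics.custom.instruction_following.helpers.instructions import JsonFormat

checker = JsonFormat("json_format")  # placeholder instruction id
checker.build_description()

print(checker.check_following('```json\n{"answer": 42}\n```'))  # True (markdown ticks are stripped)
print(checker.check_following("answer: 42"))                    # False (not valid JSON)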
KeySentenceChecker

Bases: Instruction

Check the existence of certain key sentences.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class KeySentenceChecker(Instruction):
  """Check the existence of certain key sentences."""

  def build_description(self, key_sentences = None,
                        num_sentences = None):
    """Build the instruction description.

    Args:
      key_sentences: A sequence of strings representing the key sentences that
        are expected in the response.
      num_sentences: The number of key sentences that are expected to be seen in
        the response.

    Returns:
      A string representing the instruction description.
    """

    if not key_sentences:
      # TODO(jeffrey) make a generate sentences function? wonderwords package
      self._key_sentences = set(["For now, this is fine."])
    else:
      self._key_sentences = key_sentences

    if not num_sentences:
      self._num_sentences = random.randint(1, len(self._key_sentences))
    else:
      self._num_sentences = num_sentences

    self._description_pattern = (
        "Include {num_sentences} of the following sentences {key_sentences}"
    )

    return self._description_pattern.format(
        num_sentences=self._num_sentences, key_sentences=self._key_sentences
    )

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return {"num_sentences": self._num_sentences,
            "key_sentences": list(self._key_sentences)}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["num_sentences", "key_sentences"]

  def check_following(self, value):
    """Checks if the response contains the expected key sentences."""
    count = 0
    sentences = instructions_util.split_into_sentences(value)
    for sentence in self._key_sentences:
      if sentence in sentences:
        count += 1

    return count == self._num_sentences
id = instruction_id instance-attribute
build_description(key_sentences=None, num_sentences=None)

Build the instruction description.

Parameters:

Name Type Description Default
key_sentences

A sequence of strings representing the key sentences that are expected in the response.

None
num_sentences

The number of key sentences that are expected to be seen in the response.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, key_sentences = None,
                      num_sentences = None):
  """Build the instruction description.

  Args:
    key_sentences: A sequence of strings representing the key sentences that
      are expected in the response.
    num_sentences: The number of key sentences that are expected to be seen in
      the response.

  Returns:
    A string representing the instruction description.
  """

  if not key_sentences:
    # TODO(jeffrey) make a generate sentences function? wonderwords package
    self._key_sentences = set(["For now, this is fine."])
  else:
    self._key_sentences = key_sentences

  if not num_sentences:
    self._num_sentences = random.randint(1, len(self._key_sentences))
  else:
    self._num_sentences = num_sentences

  self._description_pattern = (
      "Include {num_sentences} of the following sentences {key_sentences}"
  )

  return self._description_pattern.format(
      num_sentences=self._num_sentences, key_sentences=self._key_sentences
  )
check_following(value)

Checks if the response contains the expected key sentences.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks if the response contains the expected key sentences."""
  count = 0
  sentences = instructions_util.split_into_sentences(value)
  for sentence in self._key_sentences:
    if sentence in sentences:
      count += 1

  return count == self._num_sentences
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return {"num_sentences": self._num_sentences,
          "key_sentences": list(self._key_sentences)}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["num_sentences", "key_sentences"]
KeywordChecker

Bases: Instruction

Check the existence of certain keywords.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class KeywordChecker(Instruction):
  """Check the exisitence of certain keywords."""

  def build_description(self, *, keywords = None
                        ):
    """Build the instruction description.

    Args:
      keywords: A sequence of strings representing the keywords that are
        expected in the response.

    Returns:
      A string representing the instruction description.
    """

    if not keywords:
      self._keywords = instructions_util.generate_keywords(
          num_keywords=_NUM_KEYWORDS)
    else:
      self._keywords = keywords
    self._keywords = sorted(self._keywords)

    self._description_pattern = ("Include keywords {keywords} in the response.")

    return self._description_pattern.format(keywords=self._keywords)

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return {"keywords": self._keywords}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["keywords"]

  def check_following(self, value):
    """Check if the response contain the expected keywords."""
    for keyword in self._keywords:
      if not re.search(keyword, value, flags=re.IGNORECASE):
        return False
    return True
id = instruction_id instance-attribute
build_description(*, keywords=None)

Build the instruction description.

Parameters:

Name Type Description Default
keywords

A sequence of strings representing the keywords that are expected in the response.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, *, keywords = None
                      ):
  """Build the instruction description.

  Args:
    keywords: A sequence of strings representing the keywords that are
      expected in the response.

  Returns:
    A string representing the instruction description.
  """

  if not keywords:
    self._keywords = instructions_util.generate_keywords(
        num_keywords=_NUM_KEYWORDS)
  else:
    self._keywords = keywords
  self._keywords = sorted(self._keywords)

  self._description_pattern = ("Include keywords {keywords} in the response.")

  return self._description_pattern.format(keywords=self._keywords)
check_following(value)

Check if the response contains the expected keywords.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Check if the response contain the expected keywords."""
  for keyword in self._keywords:
    if not re.search(keyword, value, flags=re.IGNORECASE):
      return False
  return True
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return {"keywords": self._keywords}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["keywords"]
KeywordFrequencyChecker

Bases: Instruction

Check the keyword frequency.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class KeywordFrequencyChecker(Instruction):
  """Check the keyword frequency."""

  def build_description(self, *, keyword = None,
                        frequency = None,
                        relation = None):
    """Build the instruction description.

    Args:
      keyword: A string representing a keyword that is expected in the response.
      frequency: An integer specifying the number of times `keyword` is expected
        to appear in the response.
      relation: A string in (`less than`, `at least`), defining the relational
        operator for comparison.
        Two relational comparisons are supported for now:
        if 'less than', the actual number of occurrences < frequency;
        if 'at least', the actual number of occurrences >= frequency.

    Returns:
      A string representing the instruction description.
    """
    if not keyword:
      self._keyword = instructions_util.generate_keywords(num_keywords=1)[0]
    else:
      self._keyword = keyword.strip()

    self._frequency = frequency
    if self._frequency is None or self._frequency < 0:
      self._frequency = random.randint(1, _KEYWORD_FREQUENCY)

    if relation is None:
      self._comparison_relation = random.choice(_COMPARISON_RELATION)
    elif relation not in _COMPARISON_RELATION:
      raise ValueError("The supported relation for comparison must be in "
                       f"{_COMPARISON_RELATION}, but {relation} is given.")
    else:
      self._comparison_relation = relation

    self._description_pattern = (
        "In your response, the word {keyword} should appear {relation} " +
        "{frequency} times.")

    return self._description_pattern.format(
        keyword=self._keyword,
        relation=self._comparison_relation,
        frequency=self._frequency)

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return {"keyword": self._keyword,
            "frequency": self._frequency,
            "relation": self._comparison_relation}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["keyword", "frequency", "relation"]

  def check_following(self, value):
    """Checks if the response contain the keyword with required frequency."""
    actual_occurrences = len(re.findall(
        self._keyword, value, flags=re.IGNORECASE))

    if self._comparison_relation == _COMPARISON_RELATION[0]:
      return actual_occurrences < self._frequency
    elif self._comparison_relation == _COMPARISON_RELATION[1]:
      return actual_occurrences >= self._frequency
id = instruction_id instance-attribute
build_description(*, keyword=None, frequency=None, relation=None)

Build the instruction description.

Parameters:

Name Type Description Default
keyword

A string representing a keyword that is expected in the response.

None
frequency

An integer specifying the number of times keyword is expected to appear in the response.

None
relation

A string in (less than, at least), defining the relational operator for comparison. Two relational comparisons are supported for now: if 'less than', the actual number of occurrences < frequency; if 'at least', the actual number of occurrences >= frequency.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, *, keyword = None,
                      frequency = None,
                      relation = None):
  """Build the instruction description.

  Args:
    keyword: A string representing a keyword that is expected in the response.
    frequency: An integer specifying the number of times `keyword` is expected
      to appear in the response.
    relation: A string in (`less than`, `at least`), defining the relational
      operator for comparison.
      Two relational comparisons are supported for now:
      if 'less than', the actual number of occurrences < frequency;
      if 'at least', the actual number of occurrences >= frequency.

  Returns:
    A string representing the instruction description.
  """
  if not keyword:
    self._keyword = instructions_util.generate_keywords(num_keywords=1)[0]
  else:
    self._keyword = keyword.strip()

  self._frequency = frequency
  if self._frequency is None or self._frequency < 0:
    self._frequency = random.randint(1, _KEYWORD_FREQUENCY)

  if relation is None:
    self._comparison_relation = random.choice(_COMPARISON_RELATION)
  elif relation not in _COMPARISON_RELATION:
    raise ValueError("The supported relation for comparison must be in "
                     f"{_COMPARISON_RELATION}, but {relation} is given.")
  else:
    self._comparison_relation = relation

  self._description_pattern = (
      "In your response, the word {keyword} should appear {relation} " +
      "{frequency} times.")

  return self._description_pattern.format(
      keyword=self._keyword,
      relation=self._comparison_relation,
      frequency=self._frequency)
check_following(value)

Checks if the response contains the keyword with the required frequency.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks if the response contain the keyword with required frequency."""
  actual_occurrences = len(re.findall(
      self._keyword, value, flags=re.IGNORECASE))

  if self._comparison_relation == _COMPARISON_RELATION[0]:
    return actual_occurrences < self._frequency
  elif self._comparison_relation == _COMPARISON_RELATION[1]:
    return actual_occurrences >= self._frequency
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return {"keyword": self._keyword,
          "frequency": self._frequency,
          "relation": self._comparison_relation}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["keyword", "frequency", "relation"]
LetterFrequencyChecker

Bases: Instruction

Checks letter frequency.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class LetterFrequencyChecker(Instruction):
  """Checks letter frequency."""

  def build_description(self, *, letter = None,
                        let_frequency = None,
                        let_relation = None):
    """Build the instruction description.

    Args:
      letter: A string representing a letter that is expected in the response.
      let_frequency: An integer specifying the number of times `keyword` is
        expected to appear in the response.
      let_relation: A string in (`less than`, `at least`), defining the
        relational operator for comparison. Two relational comparisons are
        supported for now; if 'less than', the actual number of
        occurrences < frequency; if 'at least', the actual number of
        occurrences >= frequency.

    Returns:
      A string representing the instruction description.
    """
    if (
        not letter
        or len(letter) > 1
        or ord(letter.lower()) < 97
        or ord(letter.lower()) > 122
    ):
      self._letter = random.choice(list(string.ascii_letters))
    else:
      self._letter = letter.strip()
    self._letter = self._letter.lower()

    self._frequency = let_frequency
    if self._frequency is None or self._frequency < 0:
      self._frequency = random.randint(1, _LETTER_FREQUENCY)

    if let_relation is None:
      self._comparison_relation = random.choice(_COMPARISON_RELATION)
    elif let_relation not in _COMPARISON_RELATION:
      raise ValueError(
          "The supported relation for comparison must be in "
          f"{_COMPARISON_RELATION}, but {let_relation} is given."
      )
    else:
      self._comparison_relation = let_relation

    self._description_pattern = (
        "In your response, the letter {letter} should appear {let_relation}"
        " {let_frequency} times."
    )

    return self._description_pattern.format(
        letter=self._letter,
        let_frequency=self._frequency,
        let_relation=self._comparison_relation,
    )

  def get_instruction_args(self):
    """Returns the keyword args of build description."""
    return {"letter": self._letter,
            "let_frequency": self._frequency,
            "let_relation": self._comparison_relation}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["letter", "let_frequency", "let_relation"]

  def check_following(self, value):
    """Checks that the response contains the letter at the right frequency."""
    value = value.lower()
    letters = collections.Counter(value)

    if self._comparison_relation == _COMPARISON_RELATION[0]:
      return letters[self._letter] < self._frequency
    else:
      return letters[self._letter] >= self._frequency
id = instruction_id instance-attribute
build_description(*, letter=None, let_frequency=None, let_relation=None)

Build the instruction description.

Parameters:

Name Type Description Default
letter

A string representing a letter that is expected in the response.

None
let_frequency

An integer specifying the number of times keyword is expected to appear in the response.

None
let_relation

A string in (less than, at least), defining the relational operator for comparison. Two relational comparisons are supported for now; if 'less than', the actual number of occurrences < frequency; if 'at least', the actual number of occurrences >= frequency.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, *, letter = None,
                      let_frequency = None,
                      let_relation = None):
  """Build the instruction description.

  Args:
    letter: A string representing a letter that is expected in the response.
    let_frequency: An integer specifying the number of times `keyword` is
      expected to appear in the response.
    let_relation: A string in (`less than`, `at least`), defining the
      relational operator for comparison. Two relational comparisons are
      supported for now; if 'less than', the actual number of
      occurrences < frequency; if 'at least', the actual number of
      occurrences >= frequency.

  Returns:
    A string representing the instruction description.
  """
  if (
      not letter
      or len(letter) > 1
      or ord(letter.lower()) < 97
      or ord(letter.lower()) > 122
  ):
    self._letter = random.choice(list(string.ascii_letters))
  else:
    self._letter = letter.strip()
  self._letter = self._letter.lower()

  self._frequency = let_frequency
  if self._frequency is None or self._frequency < 0:
    self._frequency = random.randint(1, _LETTER_FREQUENCY)

  if let_relation is None:
    self._comparison_relation = random.choice(_COMPARISON_RELATION)
  elif let_relation not in _COMPARISON_RELATION:
    raise ValueError(
        "The supported relation for comparison must be in "
        f"{_COMPARISON_RELATION}, but {let_relation} is given."
    )
  else:
    self._comparison_relation = let_relation

  self._description_pattern = (
      "In your response, the letter {letter} should appear {let_relation}"
      " {let_frequency} times."
  )

  return self._description_pattern.format(
      letter=self._letter,
      let_frequency=self._frequency,
      let_relation=self._comparison_relation,
  )
check_following(value)

Checks that the response contains the letter at the right frequency.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks that the response contains the letter at the right frequency."""
  value = value.lower()
  letters = collections.Counter(value)

  if self._comparison_relation == _COMPARISON_RELATION[0]:
    return letters[self._letter] < self._frequency
  else:
    return letters[self._letter] >= self._frequency
get_instruction_args()

Returns the keyword args of build description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyword args of build description."""
  return {"letter": self._letter,
          "let_frequency": self._frequency,
          "let_relation": self._comparison_relation}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["letter", "let_frequency", "let_relation"]
LowercaseLettersEnglishChecker

Bases: Instruction

Checks that the response is in English and is in all lowercase letters.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class LowercaseLettersEnglishChecker(Instruction):
  """Checks that the response is in english and is in all lowercase letters."""

  def build_description(self):
    """Build the instruction description."""
    self._description_pattern = (
        "Your entire response should be in English, and in all lowercase"
        " letters. No capital letters are allowed."
    )
    return self._description_pattern

  def get_instruction_args(self):
    return None

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return []

  def check_following(self, value):
    """Checks that the response is in English and in all lowercase letters."""
    assert isinstance(value, str)

    try:
      return value.islower() and langdetect.detect(value) == "en"
    except langdetect.LangDetectException as e:
      # Count as instruction is followed.
      logging.error(
          "Unable to detect language for text %s due to %s", value, e
      )  # refex: disable=pytotw.037
      return True
id = instruction_id instance-attribute
build_description()

Build the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self):
  """Build the instruction description."""
  self._description_pattern = (
      "Your entire response should be in English, and in all lowercase"
      " letters. No capital letters are allowed."
  )
  return self._description_pattern
check_following(value)

Checks that the response is in English and in all lowercase letters.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks that the response is in English and in all lowercase letters."""
  assert isinstance(value, str)

  try:
    return value.islower() and langdetect.detect(value) == "en"
  except langdetect.LangDetectException as e:
    # Count as instruction is followed.
    logging.error(
        "Unable to detect language for text %s due to %s", value, e
    )  # refex: disable=pytotw.037
    return True
get_instruction_args()
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  return None
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return []
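
Example (illustrative sketch; import path inferred from the source reference above, instruction id is a placeholder; langdetect must be installed for the language check):

from aisteer360.evaluation.metrics.custom.instruction_following.helpers.instructions import LowercaseLettersEnglishChecker

checker = LowercaseLettersEnglishChecker("lowercase_english")  # placeholder instruction id
checker.build_description()

print(checker.check_following("everything here is lowercase english text."))  # True (if detected as English)
print(checker.check_following("This sentence has Capital Letters."))          # False (not all lowercase)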
NumberOfSentences

Bases: Instruction

Check the number of sentences.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class NumberOfSentences(Instruction):
  """Check the number of sentences."""

  def build_description(self, *, num_sentences = None,
                        relation = None):
    """Build the instruction description.

    Args:
      num_sentences: An integer specifying the number of sentences as a
        threshold.
      relation: A string in (`less than`, `at least`), defining the relational
        operator for comparison.
        Two relational comparisons are supported for now:
        if 'less than', the actual number of sentences < the threshold;
        if 'at least', the actual number of sentences >= the threshold.

    Returns:
      A string representing the instruction description.
    """
    # The number of sentences as a threshold for comparison.
    self._num_sentences_threshold = num_sentences
    if (self._num_sentences_threshold is None or
        self._num_sentences_threshold < 0):
      self._num_sentences_threshold = random.randint(1, _MAX_NUM_SENTENCES)

    if relation is None:
      self._comparison_relation = random.choice(_COMPARISON_RELATION)
    elif relation not in _COMPARISON_RELATION:
      raise ValueError("The supported relation for comparison must be in "
                       f"{_COMPARISON_RELATION}, but {relation} is given.")
    else:
      self._comparison_relation = relation

    self._description_pattern = (
        "Your response should contain {relation} {num_sentences} sentences.")
    return self._description_pattern.format(
        relation=self._comparison_relation,
        num_sentences=self._num_sentences_threshold)

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return {"num_sentences": self._num_sentences_threshold,
            "relation": self._comparison_relation}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["num_sentences", "relation"]

  def check_following(self, value):
    """Check if the number of sentences follows the instruction.

    Args:
      value: A string representing the response.

    Returns:
      True if the response follows the instruction.

    Raises:
        ValueError if the string in `instruction_args` is not in
        [`less_than`, `at_least`].
    """
    num_sentences = instructions_util.count_sentences(value)
    if self._comparison_relation == _COMPARISON_RELATION[0]:
      return num_sentences < self._num_sentences_threshold
    elif self._comparison_relation == _COMPARISON_RELATION[1]:
      return num_sentences >= self._num_sentences_threshold
id = instruction_id instance-attribute
build_description(*, num_sentences=None, relation=None)

Build the instruction description.

Parameters:

Name Type Description Default
num_sentences

An integer specifying the number of sentences as a threshold.

None
relation

A string in (less than, at least), defining the relational operator for comparison. Two relational comparisons are supported for now: if 'less than', the actual number of sentences < the threshold; if 'at least', the actual number of sentences >= the threshold.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, *, num_sentences = None,
                      relation = None):
  """Build the instruction description.

  Args:
    num_sentences: An integer specifying the number of sentences as a
      threshold.
    relation: A string in (`less than`, `at least`), defining the relational
      operator for comparison.
      Two relational comparisons are supported for now:
      if 'less than', the actual number of sentences < the threshold;
      if 'at least', the actual number of sentences >= the threshold.

  Returns:
    A string representing the instruction description.
  """
  # The number of sentences as a threshold for comparison.
  self._num_sentences_threshold = num_sentences
  if (self._num_sentences_threshold is None or
      self._num_sentences_threshold < 0):
    self._num_sentences_threshold = random.randint(1, _MAX_NUM_SENTENCES)

  if relation is None:
    self._comparison_relation = random.choice(_COMPARISON_RELATION)
  elif relation not in _COMPARISON_RELATION:
    raise ValueError("The supported relation for comparison must be in "
                     f"{_COMPARISON_RELATION}, but {relation} is given.")
  else:
    self._comparison_relation = relation

  self._description_pattern = (
      "Your response should contain {relation} {num_sentences} sentences.")
  return self._description_pattern.format(
      relation=self._comparison_relation,
      num_sentences=self._num_sentences_threshold)
check_following(value)

Check if the number of sentences follows the instruction.

Parameters:

Name Type Description Default
value

A string representing the response.

required

Returns:

Type Description

True if the response follows the instruction.

Raises

ValueError if the string in instruction_args is not in [less_than, at_least].

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Check if the number of sentences follows the instruction.

  Args:
    value: A string representing the response.

  Returns:
    True if the response follows the instruction.

  Raises:
      ValueError if the string in `instruction_args` is not in
      [`less_than`, `at_least`].
  """
  num_sentences = instructions_util.count_sentences(value)
  if self._comparison_relation == _COMPARISON_RELATION[0]:
    return num_sentences < self._num_sentences_threshold
  elif self._comparison_relation == _COMPARISON_RELATION[1]:
    return num_sentences >= self._num_sentences_threshold
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return {"num_sentences": self._num_sentences_threshold,
          "relation": self._comparison_relation}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["num_sentences", "relation"]
NumberOfWords

Bases: Instruction

Checks the number of words.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class NumberOfWords(Instruction):
  """Checks the number of words."""

  def build_description(self, *, num_words = None,
                        relation = None):
    """Build the instruction description.

    Args:
      num_words: An integer specifying the number of words contained in the
        response.
      relation: A string in (`less than`, `at least`), defining the relational
        operator for comparison.
        Two relational comparisons are supported for now:
        if 'less than', the actual number of words < num_words;
        if 'at least', the actual number of words >= num_words.

    Returns:
      A string representing the instruction description.
    """

    self._num_words = num_words
    if self._num_words is None or self._num_words < 0:
      self._num_words = random.randint(
          _NUM_WORDS_LOWER_LIMIT, _NUM_WORDS_UPPER_LIMIT
      )

    if relation is None:
      self._comparison_relation = random.choice(_COMPARISON_RELATION)
    elif relation not in _COMPARISON_RELATION:
      raise ValueError("The supported relation for comparison must be in "
                       f"{_COMPARISON_RELATION}, but {relation} is given.")
    else:
      self._comparison_relation = relation

    self._description_pattern = (
        "Answer with {relation} {num_words} words.")

    return self._description_pattern.format(
        relation=self._comparison_relation,
        num_words=self._num_words)

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return {"num_words": self._num_words,
            "relation": self._comparison_relation}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["num_words", "relation"]

  def check_following(self, value):
    """Checks if the response contains the expected number of words."""
    num_words = instructions_util.count_words(value)

    if self._comparison_relation == _COMPARISON_RELATION[0]:
      return num_words < self._num_words
    elif self._comparison_relation == _COMPARISON_RELATION[1]:
      return num_words >= self._num_words
id = instruction_id instance-attribute
build_description(*, num_words=None, relation=None)

Build the instruction description.

Parameters:

Name Type Description Default
num_words

An integer specifying the number of words contained in the response.

None
relation

A string in (less than, at least), defining the relational operator for comparison. Two relational comparisons are supported for now: if 'less than', the actual number of words < num_words; if 'at least', the actual number of words >= num_words.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, *, num_words = None,
                      relation = None):
  """Build the instruction description.

  Args:
    num_words: An integer specifying the number of words contained in the
      response.
    relation: A string in (`less than`, `at least`), defining the relational
      operator for comparison.
      Two relational comparisons are supported for now:
      if 'less than', the actual number of words < num_words;
      if 'at least', the actual number of words >= num_words.

  Returns:
    A string representing the instruction description.
  """

  self._num_words = num_words
  if self._num_words is None or self._num_words < 0:
    self._num_words = random.randint(
        _NUM_WORDS_LOWER_LIMIT, _NUM_WORDS_UPPER_LIMIT
    )

  if relation is None:
    self._comparison_relation = random.choice(_COMPARISON_RELATION)
  elif relation not in _COMPARISON_RELATION:
    raise ValueError("The supported relation for comparison must be in "
                     f"{_COMPARISON_RELATION}, but {relation} is given.")
  else:
    self._comparison_relation = relation

  self._description_pattern = (
      "Answer with {relation} {num_words} words.")

  return self._description_pattern.format(
      relation=self._comparison_relation,
      num_words=self._num_words)
check_following(value)

Checks if the response contains the expected number of words.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks if the response contains the expected number of words."""
  num_words = instructions_util.count_words(value)

  if self._comparison_relation == _COMPARISON_RELATION[0]:
    return num_words < self._num_words
  elif self._comparison_relation == _COMPARISON_RELATION[1]:
    return num_words >= self._num_words
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return {"num_words": self._num_words,
          "relation": self._comparison_relation}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["num_words", "relation"]
ParagraphChecker

Bases: Instruction

Checks the paragraphs.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class ParagraphChecker(Instruction):
  """Checks the paragraphs."""

  def build_description(self, *, num_paragraphs = None):
    """Build the instruction description.

    Args:
      num_paragraphs: An integer specifying the number of paragraphs.

    Returns:
      A string representing the instruction description.
    """
    self._num_paragraphs = num_paragraphs
    if self._num_paragraphs is None or self._num_paragraphs < 0:
      self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS)

    self._description_pattern = (
        "There should be {num_paragraphs} paragraphs. " +
        "Paragraphs are separated with the markdown divider: ***")

    return self._description_pattern.format(num_paragraphs=self._num_paragraphs)

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return {"num_paragraphs": self._num_paragraphs}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["num_paragraphs"]

  def check_following(self, value):
    """Checks the response contains required number of paragraphs.

    Args:
      value: A string representing the response. The response may contain
        paragraphs that are separated by the markdown divider: `***`.

    Returns:
      True if the actual number of paragraphs is the same as required;
      otherwise, False.
    """
    paragraphs = re.split(r"\s?\*\*\*\s?", value)
    num_paragraphs = len(paragraphs)

    for index, paragraph in enumerate(paragraphs):
      if not paragraph.strip():
        if index == 0 or index == len(paragraphs) - 1:
          num_paragraphs -= 1
        else:
          return False

    return num_paragraphs == self._num_paragraphs
id = instruction_id instance-attribute
build_description(*, num_paragraphs=None)

Build the instruction description.

Parameters:

Name Type Description Default
num_paragraphs

An integer specifying the number of paragraphs.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, *, num_paragraphs = None):
  """Build the instruction description.

  Args:
    num_paragraphs: An integer specifying the number of paragraphs.

  Returns:
    A string representing the instruction description.
  """
  self._num_paragraphs = num_paragraphs
  if self._num_paragraphs is None or self._num_paragraphs < 0:
    self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS)

  self._description_pattern = (
      "There should be {num_paragraphs} paragraphs. " +
      "Paragraphs are separated with the markdown divider: ***")

  return self._description_pattern.format(num_paragraphs=self._num_paragraphs)
check_following(value)

Checks that the response contains the required number of paragraphs.

Parameters:

Name Type Description Default
value

A string representing the response. The response may contain paragraphs that are separated by the markdown divider: ***.

required

Returns:

Type Description

True if the actual number of paragraphs is the same as required;

otherwise, False.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks the response contains required number of paragraphs.

  Args:
    value: A string representing the response. The response may contain
      paragraphs that are separated by the markdown divider: `***`.

  Returns:
    True if the actual number of paragraphs is the same as required;
    otherwise, False.
  """
  paragraphs = re.split(r"\s?\*\*\*\s?", value)
  num_paragraphs = len(paragraphs)

  for index, paragraph in enumerate(paragraphs):
    if not paragraph.strip():
      if index == 0 or index == len(paragraphs) - 1:
        num_paragraphs -= 1
      else:
        return False

  return num_paragraphs == self._num_paragraphs
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return {"num_paragraphs": self._num_paragraphs}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["num_paragraphs"]
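Example (illustrative, not part of the library docs): a minimal usage sketch of ParagraphChecker. The constructor is assumed to take an instruction_id string, as the id = instruction_id attribute above suggests; the ID used here is hypothetical.

from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

checker = instructions.ParagraphChecker("length_constraints:number_paragraphs")  # hypothetical ID
checker.build_description(num_paragraphs=3)
# -> "There should be 3 paragraphs. Paragraphs are separated with the markdown divider: ***"

response = "First paragraph. *** Second paragraph. *** Third paragraph."
assert checker.check_following(response)           # exactly 3 non-empty paragraphs
assert not checker.check_following("One *** Two")  # only 2 paragraphs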
ParagraphFirstWordCheck

Bases: Instruction

Check the paragraph and the first word of the nth paragraph.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class ParagraphFirstWordCheck(Instruction):
  """Check the paragraph and the first word of the nth paragraph."""

  def build_description(self, num_paragraphs = None,
                        nth_paragraph = None,
                        first_word = None):
    r"""Build the instruction description.

    Args:
      num_paragraphs: An integer indicating the number of paragraphs expected
        in the response. A paragraph is a subset of the string that is
        expected to be separated by '\n\n'.
      nth_paragraph: An integer indicating the paragraph number that we look at.
        Note that n starts from 1.
      first_word: A string that represent the first word of the bth paragraph.

    Returns:
      A string representing the instruction description.
    """
    self._num_paragraphs = num_paragraphs
    if self._num_paragraphs is None or self._num_paragraphs < 0:
      self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS)

    self._nth_paragraph = nth_paragraph
    if (
        self._nth_paragraph is None
        or self._nth_paragraph <= 0
        or self._nth_paragraph > self._num_paragraphs
    ):
      self._nth_paragraph = random.randint(1, self._num_paragraphs + 1)

    self._first_word = first_word
    if self._first_word is None:
      self._first_word = instructions_util.generate_keywords(num_keywords=1)[0]
    self._first_word = self._first_word.lower()

    self._description_pattern = (
        "There should be {num_paragraphs} paragraphs. " +
        "Paragraphs and only paragraphs are separated with each other by two " +
        "new lines as if it was '\\n\\n' in python. " +
        "Paragraph {nth_paragraph} must start with word {first_word}.")

    return self._description_pattern.format(
        num_paragraphs=self._num_paragraphs,
        nth_paragraph=self._nth_paragraph,
        first_word=self._first_word)

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return {"num_paragraphs": self._num_paragraphs,
            "nth_paragraph": self._nth_paragraph,
            "first_word": self._first_word}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["num_paragraphs", "nth_paragraph", "first_word"]

  def check_following(self, value):
    """Checks for required number of paragraphs and correct first word.

    Args:
      value: a string representing the response. The response may contain
        paragraphs that are separated by two new lines and the first word of
        the nth paragraph will have to match a specified word.

    Returns:
      True if the number of paragraphs is the same as required and the first
      word of the specified paragraph is the same as required. Otherwise, false.
    """

    paragraphs = re.split(r"\n\n", value)
    num_paragraphs = len(paragraphs)

    for paragraph in paragraphs:
      if not paragraph.strip():
        num_paragraphs -= 1

    # check that index doesn't go out of bounds
    if self._nth_paragraph <= num_paragraphs:
      paragraph = paragraphs[self._nth_paragraph - 1].strip()
      if not paragraph:
        return False
    else:
      return False

    first_word = ""
    punctuation = {".", ",", "?", "!", "'", '"'}

    # get first word and remove punctuation
    word = paragraph.split()[0].strip()
    # TODO(jeffrey): make more complex?
    word = word.lstrip("'")
    word = word.lstrip('"')

    for letter in word:
      if letter in punctuation:
        break
      first_word += letter.lower()

    return (
        num_paragraphs == self._num_paragraphs
        and first_word == self._first_word
    )
id = instruction_id instance-attribute
build_description(num_paragraphs=None, nth_paragraph=None, first_word=None)

Build the instruction description.

Parameters:

Name Type Description Default
num_paragraphs

An integer indicating the number of paragraphs expected in the response. A paragraph is a subset of the string that is expected to be separated by '\n\n'.

None
nth_paragraph

An integer indicating the paragraph number that we look at. Note that n starts from 1.

None
first_word

A string that represents the first word of the nth paragraph.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, num_paragraphs = None,
                      nth_paragraph = None,
                      first_word = None):
  r"""Build the instruction description.

  Args:
    num_paragraphs: An integer indicating the number of paragraphs expected
      in the response. A paragraph is a subset of the string that is
      expected to be separated by '\n\n'.
    nth_paragraph: An integer indicating the paragraph number that we look at.
      Note that n starts from 1.
    first_word: A string that represent the first word of the bth paragraph.

  Returns:
    A string representing the instruction description.
  """
  self._num_paragraphs = num_paragraphs
  if self._num_paragraphs is None or self._num_paragraphs < 0:
    self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS)

  self._nth_paragraph = nth_paragraph
  if (
      self._nth_paragraph is None
      or self._nth_paragraph <= 0
      or self._nth_paragraph > self._num_paragraphs
  ):
    self._nth_paragraph = random.randint(1, self._num_paragraphs + 1)

  self._first_word = first_word
  if self._first_word is None:
    self._first_word = instructions_util.generate_keywords(num_keywords=1)[0]
  self._first_word = self._first_word.lower()

  self._description_pattern = (
      "There should be {num_paragraphs} paragraphs. " +
      "Paragraphs and only paragraphs are separated with each other by two " +
      "new lines as if it was '\\n\\n' in python. " +
      "Paragraph {nth_paragraph} must start with word {first_word}.")

  return self._description_pattern.format(
      num_paragraphs=self._num_paragraphs,
      nth_paragraph=self._nth_paragraph,
      first_word=self._first_word)
check_following(value)

Checks for required number of paragraphs and correct first word.

Parameters:

Name Type Description Default
value

a string representing the response. The response may contain paragraphs that are separated by two new lines and the first word of the nth paragraph will have to match a specified word.

required

Returns:

Type Description

True if the number of paragraphs is the same as required and the first

word of the specified paragraph is the same as required. Otherwise, false.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks for required number of paragraphs and correct first word.

  Args:
    value: a string representing the response. The response may contain
      paragraphs that are separated by two new lines and the first word of
      the nth paragraph will have to match a specified word.

  Returns:
    True if the number of paragraphs is the same as required and the first
    word of the specified paragraph is the same as required. Otherwise, false.
  """

  paragraphs = re.split(r"\n\n", value)
  num_paragraphs = len(paragraphs)

  for paragraph in paragraphs:
    if not paragraph.strip():
      num_paragraphs -= 1

  # check that index doesn't go out of bounds
  if self._nth_paragraph <= num_paragraphs:
    paragraph = paragraphs[self._nth_paragraph - 1].strip()
    if not paragraph:
      return False
  else:
    return False

  first_word = ""
  punctuation = {".", ",", "?", "!", "'", '"'}

  # get first word and remove punctuation
  word = paragraph.split()[0].strip()
  # TODO(jeffrey): make more complex?
  word = word.lstrip("'")
  word = word.lstrip('"')

  for letter in word:
    if letter in punctuation:
      break
    first_word += letter.lower()

  return (
      num_paragraphs == self._num_paragraphs
      and first_word == self._first_word
  )
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return {"num_paragraphs": self._num_paragraphs,
          "nth_paragraph": self._nth_paragraph,
          "first_word": self._first_word}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["num_paragraphs", "nth_paragraph", "first_word"]
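Illustrative sketch of ParagraphFirstWordCheck (same constructor assumption as above; the instruction ID is hypothetical). It shows that the first-word comparison is case-insensitive and strips leading quotes and trailing punctuation.

from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

checker = instructions.ParagraphFirstWordCheck("length_constraints:nth_paragraph_first_word")  # hypothetical ID
checker.build_description(num_paragraphs=2, nth_paragraph=2, first_word="however")

assert checker.check_following("Intro paragraph.\n\nHowever, the rest follows.")
assert checker.check_following('Intro paragraph.\n\n"However," she wrote later.')  # quotes and punctuation stripped
assert not checker.check_following("Only a single paragraph here.")                # paragraph 2 does not exist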
PlaceholderChecker

Bases: Instruction

Check the placeholders in template writing.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class PlaceholderChecker(Instruction):
  """Check the placeholders in template writing."""

  def build_description(self, *, num_placeholders = None):
    """Build the instruction description.

    Args:
      num_placeholders: An integer denoting the minimum number of
        placeholders required in the response.

    Returns:
      A string representing the instruction description.
    """
    self._num_placeholders = num_placeholders
    if self._num_placeholders is None or self._num_placeholders < 0:
      self._num_placeholders = random.randint(1, _NUM_PLACEHOLDERS)
    self._description_pattern = (
        "The response must contain at least {num_placeholders} placeholders " +
        "represented by square brackets, such as [address].")
    return self._description_pattern.format(
        num_placeholders=self._num_placeholders)

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return {"num_placeholders": self._num_placeholders}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["num_placeholders"]

  def check_following(self, value):
    """Check if the number of placeholders follows the instruction.

    Args:
      value: A string representing the response.

    Returns:
      True if the actual number of placeholders in the response is greater than
      or equal to `num_placeholders`; otherwise, False.
    """
    placeholders = re.findall(r"\[.*?\]", value)
    num_placeholders = len(placeholders)
    return num_placeholders >= self._num_placeholders
id = instruction_id instance-attribute
build_description(*, num_placeholders=None)

Build the instruction description.

Parameters:

Name Type Description Default
num_placeholders

An integer denoting the minimum number of placeholders required in the response.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, *, num_placeholders = None):
  """Build the instruction description.

  Args:
    num_placeholders: An integer denoting the minimum number of
      placeholders required in the response.

  Returns:
    A string representing the instruction description.
  """
  self._num_placeholders = num_placeholders
  if self._num_placeholders is None or self._num_placeholders < 0:
    self._num_placeholders = random.randint(1, _NUM_PLACEHOLDERS)
  self._description_pattern = (
      "The response must contain at least {num_placeholders} placeholders " +
      "represented by square brackets, such as [address].")
  return self._description_pattern.format(
      num_placeholders=self._num_placeholders)
check_following(value)

Check if the number of placeholders follows the instruction.

Parameters:

Name Type Description Default
value

A string representing the response.

required

Returns:

Type Description

True if the actual number of placeholders in the response is greater than

or equal to num_placeholders; otherwise, False.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Check if the number of placeholders follows the instruction.

  Args:
    value: A string representing the response.

  Returns:
    True if the actual number of placeholders in the response is greater than
    or equal to `num_placeholders`; otherwise, False.
  """
  placeholders = re.findall(r"\[.*?\]", value)
  num_placeholders = len(placeholders)
  return num_placeholders >= self._num_placeholders
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return {"num_placeholders": self._num_placeholders}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["num_placeholders"]
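Illustrative sketch of PlaceholderChecker (hypothetical instruction ID, constructor assumption as above); placeholders are counted with the non-greedy pattern \[.*?\], so any bracketed span qualifies.

from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

checker = instructions.PlaceholderChecker("detectable_content:number_placeholders")  # hypothetical ID
checker.build_description(num_placeholders=2)

assert checker.check_following("Ship the order to [name] at [address].")       # 2 placeholders >= 2
assert not checker.check_following("Ship the order to [name] at the office.")  # only 1 placeholder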
PostscriptChecker

Bases: Instruction

Checks the postscript.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class PostscriptChecker(Instruction):
  """Checks the postscript."""

  def build_description(self, *, postscript_marker = None
                        ):
    """Build the instruction description.

    Args:
      postscript_marker: A string containing the keyword that marks the start
        of the postscript section.

    Returns:
      A string representing the instruction description.
    """
    self._postscript_marker = postscript_marker.strip() if isinstance(
        postscript_marker, str) else postscript_marker
    if self._postscript_marker is None:
      self._postscript_marker = random.choice(_POSTSCRIPT_MARKER)

    self._description_pattern = (
        "At the end of your response, please explicitly add a postscript " +
        "starting with {postscript}")

    return self._description_pattern.format(postscript=self._postscript_marker)

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return {"postscript_marker": self._postscript_marker}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["postscript_marker"]

  def check_following(self, value):
    """Checks if the response follows the postscript format.

    Args:
      value: a string representing the response. The response is expected to
        contain a postscript section.

    Returns:
      True if the response contains a postscript section starting with
      the keyword containing in the `instruction_args`; otherwise False.
    """
    value = value.lower()
    if self._postscript_marker == "P.P.S":
      postscript_pattern = r"\s*p\.\s?p\.\s?s.*$"
    elif self._postscript_marker == "P.S.":
      postscript_pattern = r"\s*p\.\s?s\..*$"
    else:
      postscript_pattern = r"\s*" + self._postscript_marker.lower() + r".*$"
    postscript = re.findall(postscript_pattern, value, flags=re.MULTILINE)
    return True if postscript else False
id = instruction_id instance-attribute
build_description(*, postscript_marker=None)

Build the instruction description.

Parameters:

Name Type Description Default
postscript_marker

A string containing the keyword that marks the start of the postscript section.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, *, postscript_marker = None
                      ):
  """Build the instruction description.

  Args:
    postscript_marker: A string containing the keyword that marks the start
      of the postscript section.

  Returns:
    A string representing the instruction description.
  """
  self._postscript_marker = postscript_marker.strip() if isinstance(
      postscript_marker, str) else postscript_marker
  if self._postscript_marker is None:
    self._postscript_marker = random.choice(_POSTSCRIPT_MARKER)

  self._description_pattern = (
      "At the end of your response, please explicitly add a postscript " +
      "starting with {postscript}")

  return self._description_pattern.format(postscript=self._postscript_marker)
check_following(value)

Checks if the response follows the postscript format.

Parameters:

Name Type Description Default
value

a string representing the response. The response is expected to contain a postscript section.

required

Returns:

Type Description

True if the response contains a postscript section starting with

the keyword contained in the instruction_args; otherwise, False.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks if the response follows the postscript format.

  Args:
    value: a string representing the response. The response is expected to
      contain a postscript section.

  Returns:
    True if the response contains a postscript section starting with
    the keyword containing in the `instruction_args`; otherwise False.
  """
  value = value.lower()
  if self._postscript_marker == "P.P.S":
    postscript_pattern = r"\s*p\.\s?p\.\s?s.*$"
  elif self._postscript_marker == "P.S.":
    postscript_pattern = r"\s*p\.\s?s\..*$"
  else:
    postscript_pattern = r"\s*" + self._postscript_marker.lower() + r".*$"
  postscript = re.findall(postscript_pattern, value, flags=re.MULTILINE)
  return True if postscript else False
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return {"postscript_marker": self._postscript_marker}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["postscript_marker"]
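Illustrative sketch of PostscriptChecker (hypothetical instruction ID). Note the special-cased, whitespace-tolerant regexes for the "P.S." and "P.P.S" markers.

from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

checker = instructions.PostscriptChecker("detectable_content:postscript")  # hypothetical ID
checker.build_description(postscript_marker="P.S.")

assert checker.check_following("Here is my answer.\nP.S. Thanks for reading.")
assert checker.check_following("Here is my answer.\np. s. thanks")  # tolerant of spacing and case
assert not checker.check_following("No postscript in this response.")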
QuotationChecker

Bases: Instruction

Checks response is wrapped with double quotation marks.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class QuotationChecker(Instruction):
  """Checks response is wrapped with double quotation marks."""

  def build_description(self):
    """Build the instruction description."""
    self._description_pattern = (
        "Wrap your entire response with double quotation marks."
    )
    return self._description_pattern

  def get_instruction_args(self):
    """Returns the keyword args of build description."""
    return None

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return []

  def check_following(self, value):
    """Checks if the response is wrapped with double quotation marks."""
    value = value.strip()
    return len(value) > 1 and value[0] == '"' and value[-1] == '"'
id = instruction_id instance-attribute
build_description()

Build the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self):
  """Build the instruction description."""
  self._description_pattern = (
      "Wrap your entire response with double quotation marks."
  )
  return self._description_pattern
check_following(value)

Checks if the response is wrapped with double quotation marks.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks if the response is wrapped with double quotation marks."""
  value = value.strip()
  return len(value) > 1 and value[0] == '"' and value[-1] == '"'
get_instruction_args()

Returns the keyword args of build description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyword args of build description."""
  return None
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return []
RepeatPromptThenAnswer

Bases: Instruction

Checks that the prompt is first repeated, then answered.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class RepeatPromptThenAnswer(Instruction):
  """Checks that Prompt is first repeated then answered."""

  def build_description(self, *, prompt_to_repeat = None):
    """Build the instruction description.

    Args:
      prompt_to_repeat: The prompt that is meant to be repeated.

    Returns:
      A string representing the instruction description.
    """
    if not prompt_to_repeat:
      raise ValueError("prompt_to_repeat must be set.")
    else:
      self._prompt_to_repeat = prompt_to_repeat
    self._description_pattern = (
        "First repeat the request word for word without change,"
        " then give your answer (1. do not say any words or characters"
        " before repeating the request; 2. the request you need to repeat"
        " does not include this sentence)"
    )
    return self._description_pattern

  def get_instruction_args(self):
    return {"prompt_to_repeat": self._prompt_to_repeat}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["prompt_to_repeat"]

  def check_following(self, value):
    if value.strip().lower().startswith(self._prompt_to_repeat.strip().lower()):
      return True
    return False
id = instruction_id instance-attribute
build_description(*, prompt_to_repeat=None)

Build the instruction description.

Parameters:

Name Type Description Default
prompt_to_repeat

The prompt that is meant to be repeated.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, *, prompt_to_repeat = None):
  """Build the instruction description.

  Args:
    prompt_to_repeat: The prompt that is meant to be repeated.

  Returns:
    A string representing the instruction description.
  """
  if not prompt_to_repeat:
    raise ValueError("prompt_to_repeat must be set.")
  else:
    self._prompt_to_repeat = prompt_to_repeat
  self._description_pattern = (
      "First repeat the request word for word without change,"
      " then give your answer (1. do not say any words or characters"
      " before repeating the request; 2. the request you need to repeat"
      " does not include this sentence)"
  )
  return self._description_pattern
check_following(value)
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  if value.strip().lower().startswith(self._prompt_to_repeat.strip().lower()):
    return True
  return False
get_instruction_args()
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  return {"prompt_to_repeat": self._prompt_to_repeat}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["prompt_to_repeat"]
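Illustrative sketch of RepeatPromptThenAnswer (hypothetical instruction ID); check_following is simply a case-insensitive prefix match on the stripped response.

from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

checker = instructions.RepeatPromptThenAnswer("combination:repeat_prompt")  # hypothetical ID
checker.build_description(prompt_to_repeat="What is the capital of France?")

assert checker.check_following("what is the capital of france? The capital is Paris.")
assert not checker.check_following("The capital is Paris. What is the capital of France?")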
RephraseChecker

Bases: Instruction

Checks the rephrase.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class RephraseChecker(Instruction):
  """Checks the repharse."""

  def build_description(self, *, original_message):
    """Build the instruction description.

    Args:
      original_message: A string representing the original message. The
        rephrased response should only change its words/sentences in between
        its two asterisks, for example, *change me*. Both original and rephrased
        messages should contain the changes in the form of *change me*.

    Returns:
      A string representing the instruction description.
    """
    if not self.is_change(original_message):
      raise ValueError(f"Message {original_message} does not contain changes "
                       "in the form of *change me*.")

    self._reference_without_change = original_message
    self._description = ("Rephrasing: Your rephrased response should only" +
                         "change the words/sentences in between two asterisks" +
                         "such as *change me*.")
    return self._description

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return {"original_message": self._reference_without_change}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["original_message"]

  def check_following(self, value):
    r"""Checks if the rephrasing follows the instruction.

    Args:
      value: A string representing the response, which is expected to rephras
        the string of `instruction_args`.

    Returns:
      True if `value` and `instruction_args` only differ by the words/sentences
      in between two asterisks such as *change me*; otherwise, False.
    """

    if not self.is_change(value):
      raise ValueError(f"value {value} does not contain "
                       "changes in the form of *change me*.")

    response_without_changes = self.strip_changes(value)
    reference_without_changes = self.strip_changes(
        self._reference_without_change)

    return response_without_changes == reference_without_changes

  def is_change(self, response):
    """Check if there is change in the response in the form of *change me*."""
    return re.search(r"\*.*\*", response)

  def strip_changes(self, response):
    """Strips off the changes."""
    return re.sub(r"\*.*\*", "", response)
id = instruction_id instance-attribute
build_description(*, original_message)

Build the instruction description.

Parameters:

Name Type Description Default
original_message

A string representing the original message. The rephrased response should only change its words/sentences in between its two asterisks, for example, *change me*. Both original and rephrased messages should contain the changes in the form of *change me*.

required

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, *, original_message):
  """Build the instruction description.

  Args:
    original_message: A string representing the original message. The
      rephrased response should only change its words/sentences in between
      its two asterisks, for example, *change me*. Both original and rephrased
      messages should contain the changes in the form of *change me*.

  Returns:
    A string representing the instruction description.
  """
  if not self.is_change(original_message):
    raise ValueError(f"Message {original_message} does not contain changes "
                     "in the form of *change me*.")

  self._reference_without_change = original_message
  self._description = ("Rephrasing: Your rephrased response should only" +
                       "change the words/sentences in between two asterisks" +
                       "such as *change me*.")
  return self._description
check_following(value)

Checks if the rephrasing follows the instruction.

Parameters:

Name Type Description Default
value

A string representing the response, which is expected to rephrase the string of instruction_args.

required

Returns:

Type Description

True if value and instruction_args only differ by the words/sentences

in between two asterisks such as *change me*; otherwise, False.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  r"""Checks if the rephrasing follows the instruction.

  Args:
    value: A string representing the response, which is expected to rephras
      the string of `instruction_args`.

  Returns:
    True if `value` and `instruction_args` only differ by the words/sentences
    in between two asterisks such as *change me*; otherwise, False.
  """

  if not self.is_change(value):
    raise ValueError(f"value {value} does not contain "
                     "changes in the form of *change me*.")

  response_without_changes = self.strip_changes(value)
  reference_without_changes = self.strip_changes(
      self._reference_without_change)

  return response_without_changes == reference_without_changes
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return {"original_message": self._reference_without_change}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["original_message"]
is_change(response)

Check if there is a change in the response in the form of *change me*.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def is_change(self, response):
  """Check if there is change in the response in the form of *change me*."""
  return re.search(r"\*.*\*", response)
strip_changes(response)

Strips off the changes.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def strip_changes(self, response):
  """Strips off the changes."""
  return re.sub(r"\*.*\*", "", response)
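Illustrative sketch of RephraseChecker (hypothetical instruction ID): the original and the rephrased response must be identical once the *...* spans are stripped.

from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

checker = instructions.RephraseChecker("rephrase:asterisks")  # hypothetical ID
checker.build_description(original_message="The meeting is at *3 pm* tomorrow.")

assert checker.check_following("The meeting is at *noon* tomorrow.")   # only the starred span changed
assert not checker.check_following("The call is at *noon* tomorrow.")  # text outside the asterisks changed
# check_following raises ValueError if the response contains no *...* span at all.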
RephraseParagraph

Bases: Instruction

Checks that the paragraph is rephrased.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class RephraseParagraph(Instruction):
  """Checks that the paragraph is rephrased."""

  def build_description(self, *, original_paragraph, low, high
                        ):
    """Builds the instruction description.

    Args:
      original_paragraph: A string presenting the original paragraph. The
        rephrases response should have betweeb low-high words in generic.
      low: An integer presenting the lower bound of similar words.
      high: An integer representing the upper bound of similar words.

    Returns:
      A string representing the instruction description.
    """
    # TODO(jeffrey) make more encompassing
    self._original_paragraph = original_paragraph
    self._low = low
    self._high = high

    self._description = ("Rephrase the following paragraph: " +
                         "{original_paragraph}\nYour response should have " +
                         "between {low} and {high} of the same words. " +
                         "Words are the same if and only if all of the " +
                         "letters, ignoring cases, are the same. For " +
                         "example, 'run' is the same as 'Run' but different " +
                         "to 'ran'.")

    return self._description.format(original_paragraph=original_paragraph,
                                    low=self._low, high=self._high)

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return {"original_paragraph": self._original_paragraph,
            "low": self._low,
            "high": self._high}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["original_paragraph", "low", "high"]

  def check_following(self, value):
    val_words = re.findall(r"\w+", value.lower())
    original_words = re.findall(r"\w+", self._original_paragraph.lower())
    similar_words = 0

    dict_val = collections.Counter(val_words)
    dict_original = collections.Counter(original_words)

    for word in dict_original:
      similar_words += min(dict_original[word], dict_val[word])

    return similar_words >= self._low and similar_words <= self._high
id = instruction_id instance-attribute
build_description(*, original_paragraph, low, high)

Builds the instruction description.

Parameters:

Name Type Description Default
original_paragraph

A string representing the original paragraph. The rephrased response should have between low and high words in common with it.

required
low

An integer representing the lower bound of similar words.

required
high

An integer representing the upper bound of similar words.

required

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, *, original_paragraph, low, high
                      ):
  """Builds the instruction description.

  Args:
    original_paragraph: A string presenting the original paragraph. The
      rephrases response should have betweeb low-high words in generic.
    low: An integer presenting the lower bound of similar words.
    high: An integer representing the upper bound of similar words.

  Returns:
    A string representing the instruction description.
  """
  # TODO(jeffrey) make more encompassing
  self._original_paragraph = original_paragraph
  self._low = low
  self._high = high

  self._description = ("Rephrase the following paragraph: " +
                       "{original_paragraph}\nYour response should have " +
                       "between {low} and {high} of the same words. " +
                       "Words are the same if and only if all of the " +
                       "letters, ignoring cases, are the same. For " +
                       "example, 'run' is the same as 'Run' but different " +
                       "to 'ran'.")

  return self._description.format(original_paragraph=original_paragraph,
                                  low=self._low, high=self._high)
check_following(value)
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  val_words = re.findall(r"\w+", value.lower())
  original_words = re.findall(r"\w+", self._original_paragraph.lower())
  similar_words = 0

  dict_val = collections.Counter(val_words)
  dict_original = collections.Counter(original_words)

  for word in dict_original:
    similar_words += min(dict_original[word], dict_val[word])

  return similar_words >= self._low and similar_words <= self._high
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return {"original_paragraph": self._original_paragraph,
          "low": self._low,
          "high": self._high}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["original_paragraph", "low", "high"]
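Illustrative sketch of RephraseParagraph (hypothetical instruction ID): word overlap is counted per occurrence, case-insensitively, via collections.Counter, and capped by each word's count in the original paragraph.

from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

checker = instructions.RephraseParagraph("rephrase:paragraph")  # hypothetical ID
checker.build_description(
    original_paragraph="The quick brown fox jumps over the lazy dog", low=2, high=4)

assert checker.check_following("A slow brown fox walked past the barn")       # shares "brown", "fox", "the" -> 3
assert not checker.check_following("The quick brown fox jumps over the dog")  # 8 shared occurrences, above high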
ResponseLanguageChecker

Bases: Instruction

Check the language of the entire response.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class ResponseLanguageChecker(Instruction):
  """Check the language of the entire response."""

  def build_description(self, *, language = None):
    """Build the instruction description.

    Args:
      language: A string representing the expected language of the response. The
        language has to comply to the 97 types defined in
        `langid.py` (https://pypi.org/project/langid/1.1.5/), which follows
        ISO 639-1 codes (https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes);
        for example, `en` for English, `zh` for Chinese, `fr` for French.

    Returns:
      A string representing the instruction description.
    """
    self._language = language
    if self._language is None:
      self._language = random.choice(list(_LANGUAGES.keys()))
    # TODO(tianjianlu): opens the description generation to more choices.
    self._description_pattern = (
        "Your ENTIRE response should be in {language} language, no other " +
        "language is allowed.")
    return self._description_pattern.format(language=_LANGUAGES[self._language])

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return {"language": self._language}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["language"]

  def check_following(self, value):
    """Check if the language of the entire response follows the instruction.

    Args:
      value: A string representing the response.

    Returns:
      True if the language of `value` follows instruction; otherwise False.
    """
    assert isinstance(value, str)

    try:
      return langdetect.detect(value) == self._language
    except langdetect.LangDetectException as e:
      # Count as instruction is followed.
      logging.error(
          "Unable to detect language for text %s due to %s", value, e
      )  # refex: disable=pytotw.037
      return True
id = instruction_id instance-attribute
build_description(*, language=None)

Build the instruction description.

Parameters:

Name Type Description Default
language

A string representing the expected language of the response. The language has to comply to the 97 types defined in langid.py (https://pypi.org/project/langid/1.1.5/), which follows ISO 639-1 codes (https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes); for example, en for English, zh for Chinese, fr for French.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, *, language = None):
  """Build the instruction description.

  Args:
    language: A string representing the expected language of the response. The
      language has to comply to the 97 types defined in
      `langid.py` (https://pypi.org/project/langid/1.1.5/), which follows
      ISO 639-1 codes (https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes);
      for example, `en` for English, `zh` for Chinese, `fr` for French.

  Returns:
    A string representing the instruction description.
  """
  self._language = language
  if self._language is None:
    self._language = random.choice(list(_LANGUAGES.keys()))
  # TODO(tianjianlu): opens the description generation to more choices.
  self._description_pattern = (
      "Your ENTIRE response should be in {language} language, no other " +
      "language is allowed.")
  return self._description_pattern.format(language=_LANGUAGES[self._language])
check_following(value)

Check if the language of the entire response follows the instruction.

Parameters:

Name Type Description Default
value

A string representing the response.

required

Returns:

Type Description

True if the language of value follows instruction; otherwise False.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Check if the language of the entire response follows the instruction.

  Args:
    value: A string representing the response.

  Returns:
    True if the language of `value` follows instruction; otherwise False.
  """
  assert isinstance(value, str)

  try:
    return langdetect.detect(value) == self._language
  except langdetect.LangDetectException as e:
    # Count as instruction is followed.
    logging.error(
        "Unable to detect language for text %s due to %s", value, e
    )  # refex: disable=pytotw.037
    return True
get_instruction_args()

Returns the keyward args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return {"language": self._language}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["language"]
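Illustrative sketch of ResponseLanguageChecker (hypothetical instruction ID). Detection is delegated to langdetect, so very short or mixed-language strings may be misclassified, and detection failures count as compliant.

from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

checker = instructions.ResponseLanguageChecker("language:response_language")  # hypothetical ID
checker.build_description(language="fr")

print(checker.check_following("Bonjour, comment allez-vous aujourd'hui ?"))  # expected True ('fr' detected)
print(checker.check_following("Good morning, how are you doing today?"))     # expected False ('en' detected)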
SectionChecker

Bases: Instruction

Checks the sections.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class SectionChecker(Instruction):
  """Checks the sections."""

  def build_description(self, *, section_spliter = None,
                        num_sections = None):
    """Build the instruction description.

    Args:
      section_spliter: A string represents the section spliter keyword that
        marks a new section, i.e., `Section` or `SECTION`.
      num_sections: An integer specifying the number of sections.

    Returns:
      A string representing the instruction description.
    """
    self._section_spliter = section_spliter.strip() if isinstance(
        section_spliter, str) else section_spliter
    if self._section_spliter is None:
      self._section_spliter = random.choice(_SECTION_SPLITER)

    self._num_sections = num_sections
    if self._num_sections is None or self._num_sections < 0:
      self._num_sections = random.randint(1, _NUM_SECTIONS)

    self._description_pattern = (
        "Your response must have {num_sections} sections. Mark the beginning " +
        "of each section with {section_spliter} X, such as:\n" +
        "{section_spliter} 1\n" +
        "[content of section 1]\n" +
        "{section_spliter} 2\n" +
        "[content of section 2]")

    return self._description_pattern.format(
        num_sections=self._num_sections,
        section_spliter=self._section_spliter)

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return {"section_spliter": self._section_spliter,
            "num_sections": self._num_sections}

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return ["section_spliter", "num_sections"]

  def check_following(self, value):
    """Checks the response contains multiple sections.

    Args:
      value: A string representing the response. The response is expected
        to contain multiple sections (number of sections is greater than 1).
        A new section starts with `Section 1`, where the number denotes the
        section index.

    Returns:
      True if the number of sections in the response is greater than or equal to
      the minimum number of sections; otherwise, False.
    """
    section_splitter_patten = r"\s?" + self._section_spliter  + r"\s?\d+\s?"
    sections = re.split(section_splitter_patten, value)
    num_sections = len(sections) - 1
    return num_sections >= self._num_sections
id = instruction_id instance-attribute
build_description(*, section_spliter=None, num_sections=None)

Build the instruction description.

Parameters:

Name Type Description Default
section_spliter

A string representing the section splitter keyword that marks a new section, e.g., Section or SECTION.

None
num_sections

An integer specifying the number of sections.

None

Returns:

Type Description

A string representing the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self, *, section_spliter = None,
                      num_sections = None):
  """Build the instruction description.

  Args:
    section_spliter: A string represents the section spliter keyword that
      marks a new section, i.e., `Section` or `SECTION`.
    num_sections: An integer specifying the number of sections.

  Returns:
    A string representing the instruction description.
  """
  self._section_spliter = section_spliter.strip() if isinstance(
      section_spliter, str) else section_spliter
  if self._section_spliter is None:
    self._section_spliter = random.choice(_SECTION_SPLITER)

  self._num_sections = num_sections
  if self._num_sections is None or self._num_sections < 0:
    self._num_sections = random.randint(1, _NUM_SECTIONS)

  self._description_pattern = (
      "Your response must have {num_sections} sections. Mark the beginning " +
      "of each section with {section_spliter} X, such as:\n" +
      "{section_spliter} 1\n" +
      "[content of section 1]\n" +
      "{section_spliter} 2\n" +
      "[content of section 2]")

  return self._description_pattern.format(
      num_sections=self._num_sections,
      section_spliter=self._section_spliter)
check_following(value)

Checks the response contains multiple sections.

Parameters:

Name Type Description Default
value

A string representing the response. The response is expected to contain multiple sections (number of sections is greater than 1). A new section starts with Section 1, where the number denotes the section index.

required

Returns:

Type Description

True if the number of sections in the response is greater than or equal to

the minimum number of sections; otherwise, False.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks the response contains multiple sections.

  Args:
    value: A string representing the response. The response is expected
      to contain multiple sections (number of sections is greater than 1).
      A new section starts with `Section 1`, where the number denotes the
      section index.

  Returns:
    True if the number of sections in the response is greater than or equal to
    the minimum number of sections; otherwise, False.
  """
  section_splitter_patten = r"\s?" + self._section_spliter  + r"\s?\d+\s?"
  sections = re.split(section_splitter_patten, value)
  num_sections = len(sections) - 1
  return num_sections >= self._num_sections
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return {"section_spliter": self._section_spliter,
          "num_sections": self._num_sections}
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return ["section_spliter", "num_sections"]
TitleChecker

Bases: Instruction

Checks the response for a title.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class TitleChecker(Instruction):
  """Checks the response for a title."""

  def build_description(self):
    """Build the instruction description."""
    self._description_pattern = (
        "Your answer must contain a title, wrapped in double angular brackets,"
        " such as <<poem of joy>>."
    )
    return self._description_pattern

  def get_instruction_args(self):
    return None

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return []

  def check_following(self, value):
    """Checks if the response contains a title."""
    pattern = r"<<[^\n]+>>"
    re_pattern = re.compile(pattern)
    titles = re.findall(re_pattern, value)

    for title in titles:
      if title.lstrip("<").rstrip(">").strip():
        return True
    return False
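
A brief usage sketch for TitleChecker, under the same import-path assumption as above; the instruction id (detectable_format:title) is the one used in the registry and tests.

```python
# Minimal sketch: TitleChecker accepts any non-empty title wrapped in << >>.
from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

title_checker = instructions.TitleChecker("detectable_format:title")
title_checker.build_description()

print(title_checker.check_following("<<Song of Joy>>\nLa la la."))  # True: non-empty title
print(title_checker.check_following("<< >>\nThere is no title."))   # False: title is empty
```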
id = instruction_id instance-attribute
build_description()

Build the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self):
  """Build the instruction description."""
  self._description_pattern = (
      "Your answer must contain a title, wrapped in double angular brackets,"
      " such as <<poem of joy>>."
  )
  return self._description_pattern
check_following(value)

Checks if the response contains a title.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks if the response contains a title."""
  pattern = r"<<[^\n]+>>"
  re_pattern = re.compile(pattern)
  titles = re.findall(re_pattern, value)

  for title in titles:
    if title.lstrip("<").rstrip(">").strip():
      return True
  return False
get_instruction_args()
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  return None
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return []
TwoResponsesChecker

Bases: Instruction

Check that two responses were given.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
class TwoResponsesChecker(Instruction):
  """Check that two responses were given."""

  def build_description(self):
    """Build the instruction description."""
    self._description_pattern = (
        "Give two different responses. Responses and only responses should"
        " be separated by 6 asterisk symbols: ******."
    )
    return self._description_pattern

  def get_instruction_args(self):
    """Returns the keyward args of `build_description`."""
    return None

  def get_instruction_args_keys(self):
    """Returns the args keys of `build_description`."""
    return []

  def check_following(self, value):
    """Checks if the response has two different answers.

    Args:
      value: A string representing the response.

    Returns:
      True if two responses are detected and false otherwise.
    """
    valid_responses = list()
    responses = value.split("******")
    for index, response in enumerate(responses):
      if not response.strip():
        if index != 0 and index != len(responses) - 1:
          return False
      else:
        valid_responses.append(response)
    return (
        len(valid_responses) == 2
        and valid_responses[0].strip() != valid_responses[1].strip()
    )
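
As a sketch (same import-path assumption), TwoResponsesChecker passes only when exactly two non-empty, distinct responses are separated by ******:

```python
from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

two_responses = instructions.TwoResponsesChecker("combination:two_responses")
two_responses.build_description()

good = "First answer.\n******\nSecond, different answer."
bad = "Same answer.\n******\nSame answer."

print(two_responses.check_following(good))  # True: two distinct responses
print(two_responses.check_following(bad))   # False: responses are identical
```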
id = instruction_id instance-attribute
build_description()

Build the instruction description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def build_description(self):
  """Build the instruction description."""
  self._description_pattern = (
      "Give two different responses. Responses and only responses should"
      " be separated by 6 asterisk symbols: ******."
  )
  return self._description_pattern
check_following(value)

Checks if the response has two different answers.

Parameters:

Name Type Description Default
value

A string representing the response.

required

Returns:

Type Description

True if two responses are detected and False otherwise.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def check_following(self, value):
  """Checks if the response has two different answers.

  Args:
    value: A string representing the response.

  Returns:
    True if two responses are detected and false otherwise.
  """
  valid_responses = list()
  responses = value.split("******")
  for index, response in enumerate(responses):
    if not response.strip():
      if index != 0 and index != len(responses) - 1:
        return False
    else:
      valid_responses.append(response)
  return (
      len(valid_responses) == 2
      and valid_responses[0].strip() != valid_responses[1].strip()
  )
get_instruction_args()

Returns the keyword args of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args(self):
  """Returns the keyward args of `build_description`."""
  return None
get_instruction_args_keys()

Returns the args keys of build_description.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions.py
def get_instruction_args_keys(self):
  """Returns the args keys of `build_description`."""
  return []
instructions_registry

Registry of all instructions.

INSTRUCTION_CONFLICTS = {_KEYWORD + 'existence': {_KEYWORD + 'existence'}, _KEYWORD + 'frequency': {_KEYWORD + 'frequency'}, _KEYWORD + 'forbidden_words': {_KEYWORD + 'forbidden_words'}, _KEYWORD + 'letter_frequency': {_KEYWORD + 'letter_frequency'}, _LANGUAGE + 'response_language': {_LANGUAGE + 'response_language', _FORMAT + 'multiple_sections', _KEYWORD + 'existence', _KEYWORD + 'frequency', _KEYWORD + 'forbidden_words', _STARTEND + 'end_checker', _CHANGE_CASES + 'english_capital', _CHANGE_CASES + 'english_lowercase'}, _LENGTH + 'number_sentences': {_LENGTH + 'number_sentences'}, _LENGTH + 'number_paragraphs': {_LENGTH + 'number_paragraphs', _LENGTH + 'nth_paragraph_first_word', _LENGTH + 'number_sentences', _LENGTH + 'nth_paragraph_first_word'}, _LENGTH + 'number_words': {_LENGTH + 'number_words'}, _LENGTH + 'nth_paragraph_first_word': {_LENGTH + 'nth_paragraph_first_word', _LENGTH + 'number_paragraphs'}, _CONTENT + 'number_placeholders': {_CONTENT + 'number_placeholders'}, _CONTENT + 'postscript': {_CONTENT + 'postscript'}, _FORMAT + 'number_bullet_lists': {_FORMAT + 'number_bullet_lists'}, _FORMAT + 'constrained_response': set(INSTRUCTION_DICT.keys()), _FORMAT + 'number_highlighted_sections': {_FORMAT + 'number_highlighted_sections'}, _FORMAT + 'multiple_sections': {_FORMAT + 'multiple_sections', _LANGUAGE + 'response_language', _FORMAT + 'number_highlighted_sections'}, _FORMAT + 'json_format': set(INSTRUCTION_DICT.keys()).difference({_KEYWORD + 'forbidden_words', _KEYWORD + 'existence'}), _FORMAT + 'title': {_FORMAT + 'title'}, _COMBINATION + 'two_responses': set(INSTRUCTION_DICT.keys()).difference({_KEYWORD + 'forbidden_words', _KEYWORD + 'existence', _LANGUAGE + 'response_language', _FORMAT + 'title', _PUNCTUATION + 'no_comma'}), _COMBINATION + 'repeat_prompt': set(INSTRUCTION_DICT.keys()).difference({_KEYWORD + 'existence', _FORMAT + 'title', _PUNCTUATION + 'no_comma'}), _STARTEND + 'end_checker': {_STARTEND + 'end_checker'}, _CHANGE_CASES + 'capital_word_frequency': {_CHANGE_CASES + 'capital_word_frequency', _CHANGE_CASES + 'english_lowercase', _CHANGE_CASES + 'english_capital'}, _CHANGE_CASES + 'english_capital': {_CHANGE_CASES + 'english_capital'}, _CHANGE_CASES + 'english_lowercase': {_CHANGE_CASES + 'english_lowercase', _CHANGE_CASES + 'english_capital'}, _PUNCTUATION + 'no_comma': {_PUNCTUATION + 'no_comma'}, _STARTEND + 'quotation': {_STARTEND + 'quotation', _FORMAT + 'title'}} module-attribute
INSTRUCTION_DICT = {_KEYWORD + 'existence': instructions.KeywordChecker, _KEYWORD + 'frequency': instructions.KeywordFrequencyChecker, _KEYWORD + 'forbidden_words': instructions.ForbiddenWords, _KEYWORD + 'letter_frequency': instructions.LetterFrequencyChecker, _LANGUAGE + 'response_language': instructions.ResponseLanguageChecker, _LENGTH + 'number_sentences': instructions.NumberOfSentences, _LENGTH + 'number_paragraphs': instructions.ParagraphChecker, _LENGTH + 'number_words': instructions.NumberOfWords, _LENGTH + 'nth_paragraph_first_word': instructions.ParagraphFirstWordCheck, _CONTENT + 'number_placeholders': instructions.PlaceholderChecker, _CONTENT + 'postscript': instructions.PostscriptChecker, _FORMAT + 'number_bullet_lists': instructions.BulletListChecker, _FORMAT + 'constrained_response': instructions.ConstrainedResponseChecker, _FORMAT + 'number_highlighted_sections': instructions.HighlightSectionChecker, _FORMAT + 'multiple_sections': instructions.SectionChecker, _FORMAT + 'json_format': instructions.JsonFormat, _FORMAT + 'title': instructions.TitleChecker, _COMBINATION + 'two_responses': instructions.TwoResponsesChecker, _COMBINATION + 'repeat_prompt': instructions.RepeatPromptThenAnswer, _STARTEND + 'end_checker': instructions.EndChecker, _CHANGE_CASES + 'capital_word_frequency': instructions.CapitalWordFrequencyChecker, _CHANGE_CASES + 'english_capital': instructions.CapitalLettersEnglishChecker, _CHANGE_CASES + 'english_lowercase': instructions.LowercaseLettersEnglishChecker, _PUNCTUATION + 'no_comma': instructions.CommaChecker, _STARTEND + 'quotation': instructions.QuotationChecker} module-attribute
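
INSTRUCTION_DICT lets a checker be looked up and instantiated by its instruction id. A minimal sketch, assuming the prefix constants resolve to the id strings used in the tests (e.g. _FORMAT → "detectable_format:"):

```python
from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions_registry

instruction_id = "detectable_format:title"  # corresponds to _FORMAT + 'title'
checker_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]

checker = checker_cls(instruction_id)
checker.build_description()
print(checker.check_following("<<A Title>>\nBody text."))  # True
```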
conflict_make(conflicts)

Makes sure that if A conflicts with B, B will conflict with A.

Parameters:

Name Type Description Default
conflicts

Dictionary of potential conflicts where key is instruction id and value is set of instruction ids that it conflicts with.

required

Returns:

Type Description

Revised version of the dictionary. All instructions conflict with themselves. If A conflicts with B, B will conflict with A.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_registry.py
def conflict_make(conflicts):
  """Makes sure if A conflicts with B, B will conflict with A.

  Args:
    conflicts: Dictionary of potential conflicts where key is instruction id
      and value is set of instruction ids that it conflicts with.

  Returns:
    Revised version of the dictionary. All instructions conflict with
    themselves. If A conflicts with B, B will conflict with A.
  """
  for key in conflicts:
    for k in conflicts[key]:
      conflicts[k].add(key)
    conflicts[key].add(key)
  return conflicts
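
To illustrate the symmetrization on a toy dictionary (hypothetical ids, not real instruction ids), assuming the module path shown above:

```python
from aisteer360.evaluation.metrics.custom.instruction_following.helpers.instructions_registry import conflict_make

# Toy example: conflict_make mutates the dict in place and returns it.
conflicts = {"a": {"b"}, "b": set(), "c": set()}
conflict_make(conflicts)

print(conflicts["a"])  # contains {"a", "b"}: every id conflicts with itself
print(conflicts["b"])  # contains {"a", "b"}: symmetry added, "b" now conflicts with "a"
print(conflicts["c"])  # contains {"c"}
```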
instructions_test

Tests for instructions.py.

InstructionsTest

Bases: TestCase

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
class InstructionsTest(parameterized.TestCase):

  @parameterized.named_parameters(
      [
          {
              'testcase_name': (
                  f'_response={response}_language={language}'
              ),
              'response': response,
              'language': language,
          }
          for response, language in [('The response is English', 'en')]
      ]
  )
  def test_response_language(self, response, language):
    """Test on single language response."""
    instruction_id = 'language:response_language'
    instruction = instructions.ResponseLanguageChecker(instruction_id)
    instruction.build_description(language=language)
    self.assertTrue(instruction.check_following(response))

  @parameterized.named_parameters(
      [
          {
              'testcase_name': (
                  f'_response={response}_language={language}'
              ),
              'response': response,
              'language': language,
          }
          for response, language in [("Desayunamos en McDonald's hoy", 'es'),
                                     ('Today we visit the Louvre', 'en'),]
      ]
  )
  def test_response_multilanguage(self, response, language):
    """Test on responses that contain multi-language tokens."""
    instruction_id = 'language:response_language'
    instruction = instructions.ResponseLanguageChecker(instruction_id)
    instruction.build_description(language=language)
    self.assertTrue(instruction.check_following(response))

  @parameterized.named_parameters(
      [
          {
              'testcase_name': (
                  f'_response={response}_relation={relation}'
                  f'_num_sentences={num_sentences}_expected={expected}'
              ),
              'response': response,
              'relation': relation,
              'num_sentences': num_sentences,
              'expected': expected,
          }
          for response, relation, num_sentences, expected in [
              ('xx,x. xx,x! xx/x. x{x}x?', instructions._COMPARISON_RELATION[0],
               4, False),
              ('xxxx. xx,x! xxxx. x(x)x?', instructions._COMPARISON_RELATION[0],
               5, True),
              ('xxxx. xx,x! xx|x. x&x x?', instructions._COMPARISON_RELATION[1],
               4, True),
              ('xx-x. xx,x! xx}x. x,xx?', instructions._COMPARISON_RELATION[1],
               5, False),
          ]
      ]
  )
  def test_number_sentences(self, response, relation, num_sentences, expected):
    """Test the number of sentences."""
    instruction_id = 'length_constraints:number_sentences'
    instruction = instructions.NumberOfSentences(instruction_id)
    instruction.build_description(relation=relation,
                                  num_sentences=num_sentences)
    actual = instruction.check_following(response)
    self.assertEqual(actual, expected)

  @parameterized.named_parameters(
      [
          {
              'testcase_name': (
                  f'_templated={template}_num_placeholders={num_placeholders}'
                  f'_expected={expected}'
              ),
              'template': template,
              'num_placeholders': num_placeholders,
              'expected': expected,
          }
          for template, num_placeholders, expected in [
              (('Sure, here is a short template with 5 placeholders:\n' +
                '[Name]\n[Email]\n[Phone]\n[Address]\n[Website]\n' +
                'This template can be used for a variety of purposes, such ' +
                'ascreating a contact list, sending out surveys, or creating ' +
                'a sign-up form.'), 5, True),
              (('My [adjective] [noun] is [adjective] [noun]. I [verb] and ' +
                '[verb].'), 7, False),
              ]
      ]
  )
  def test_number_placeholders(self, template, num_placeholders, expected):
    """Test the number of placeholders."""
    instruction_id = 'detectable_content:number_placeholders'
    instruction = instructions.PlaceholderChecker(instruction_id)
    instruction.build_description(num_placeholders=num_placeholders)
    actual = instruction.check_following(template)
    self.assertEqual(actual, expected)

  BULLET_TEST_MESSAGE_1 = """
  A Markdown bullet point is a way of formatting text to create a list. To
  create a bullet point, start each line with an asterisk (*). For example:
  * This is a bullet point.
  *(no space required)Another bullet point.
  * (no newline ending required)Another bullet point.
  markdown bullet points are often used to create to-do lists or to list items
  in a step-by-step guide."""
  BULLET_TEST_MESSAGE_2 = """
  Check that inline asterisk (*), *, will not be counted. Only * that starts a
  bullet list will be counted:
    * This is a bullet point.
    * Another bullet point.
    . dot is not counted"""
  BULLET_TEST_MESSAGE_3 = """
  Here are three bullets starting with asterisk:
  * I am a large language model, also known as a conversational AI.
  * I am trained on a massive amount of text data, and I am able to communicate.
  * I am still under development, but I am learning new things every day."""

  BULLET_TEST_MESSAGE_4 = """
  Here are three markdown bullets:
  - I am a large language model, also known as a conversational AI.
  - I am trained on a massive amount of text data, and I am able to communicate.
  -I am still under development, but I am learning new things every day."""

  BULLET_TEST_MESSAGE_5 = """
  Paragraph 1
  ***
  Paragraph 2
  ***
  Paragraph 3
  * only one bullet point
  """

  @parameterized.named_parameters(
      [
          {
              'testcase_name': (
                  f'_templated={template}_num_bullets={num_bullets}'
                  f'_expected={expected}'
              ),
              'template': template,
              'num_bullets': num_bullets,
              'expected': expected,
          }
          for template, num_bullets, expected in [
              (BULLET_TEST_MESSAGE_1, 3, True),
              (BULLET_TEST_MESSAGE_2, 2, True),
              (BULLET_TEST_MESSAGE_3, 3, True),
              (BULLET_TEST_MESSAGE_4, 3, True),
              (BULLET_TEST_MESSAGE_5, 1, True)]
      ]
  )
  def test_number_bullet_lists(self, template, num_bullets, expected):
    """Test the number of bullets."""
    instruction_id = 'detectable_format:exact_number_bullet_points'
    instruction = instructions.BulletListChecker(instruction_id)
    instruction.build_description(num_bullets=num_bullets)
    actual = instruction.check_following(template)
    self.assertEqual(actual, expected)

  CONSTRAINED_RESPONSE_TEST_RESPONSE_1 = """\n My answer is no.\n"""
  CONSTRAINED_RESPONSE_TEST_RESPONSE_2 = """My answer is no.   """
  CONSTRAINED_RESPONSE_TEST_RESPONSE_3 = """
  My answer is no. I am still under development and I am always learning and
  improving. I am not the best chatbot in the world, but I am striving to be
  the best that I can be."""

  def test_constrained_response(self):
    """Test the constrained response checker."""
    instruction_id = 'detectable_format:constrained_response'
    instruction = instructions.ConstrainedResponseChecker(instruction_id)
    instruction.build_description()

    with self.subTest('test with CONSTRAINED_RESPONSE_TEST_RESPONSE_1'):
      self.assertTrue(instruction.check_following(
          self.CONSTRAINED_RESPONSE_TEST_RESPONSE_1))

    with self.subTest('test with CONSTRAINED_RESPONSE_TEST_RESPONSE_2'):
      self.assertTrue(instruction.check_following(
          self.CONSTRAINED_RESPONSE_TEST_RESPONSE_2))

    with self.subTest('test with CONSTRAINED_RESPONSE_TEST_RESPONSE_3'):
      self.assertTrue(instruction.check_following(
          self.CONSTRAINED_RESPONSE_TEST_RESPONSE_3))

  HIGHLIGHTED_TEST_MESSAGE_1 = """
  To highlight text with Markdown, you can use the * character before and after
  the text you want to highlight. For example, if you want to highlight the
  word `hello`, you would type:*hello*, You can also use the ** character to
  create bold text. For example, if you want to bold the word `hello`, you
  would type: **hello** """
  HIGHLIGHTED_TEST_MESSAGE_2 = """
  Sure, here are the numerical methods for solving partial differential
  equations highlighted with Markdown:
  *Finite difference methods
  *Finite element methods*
  *Boundary element methods
  *Monte Carlo methods
  I hope this helps!"""
  HIGHLIGHTED_TEST_MESSAGE_3 = """
  There is allowed to be *two different* highlighted *sections in the same*
  line. **This is also true** for **double markdown highlights.**
  """

  @parameterized.named_parameters(
      [
          {
              'testcase_name': (
                  f'_response={response}'
                  f'_min_num_highlights={min_num_highlights}'
                  f'_expected={expected}'
              ),
              'response': response,
              'min_num_highlights': min_num_highlights,
              'expected': expected,
          }
          for response, min_num_highlights, expected in [
              (HIGHLIGHTED_TEST_MESSAGE_1, 2, True),
              (HIGHLIGHTED_TEST_MESSAGE_2, 2, False),
              (HIGHLIGHTED_TEST_MESSAGE_3, 4, True)]
      ]
  )
  def test_number_highlights(self, response, min_num_highlights, expected):
    """Test the minimum number of highlighted sections."""
    instruction_id = 'detectable_format:minimum_number_highlighted_sections'
    instruction = instructions.HighlightSectionChecker(instruction_id)
    instruction.build_description(num_highlights=min_num_highlights)
    actual = instruction.check_following(response)
    self.assertEqual(actual, expected)

  SECTION_TEST_MESSAGE_1 = """
  Your response must have multiple sections. Mark the beginning of each section
  with "Section X", such as:
  Section 1
  [content of section 1]
  Section 2
  [content of section 2]"""

  SECTION_TEST_MESSAGE_2 = """SECTION 1
  [content of section 1]
  SECTION 2
  [content of section 2]"""

  def test_section_checker(self):
    """Test the number of sections."""
    instruction_id = 'detectable_format:multiple_sections'
    instruction = instructions.SectionChecker(instruction_id)
    section_keyword = 'Section'
    min_num_sections = 3
    instruction.build_description(section_spliter=section_keyword,
                                  num_sections=min_num_sections)
    with self.subTest(f'test {section_keyword} and {min_num_sections}'):
      self.assertFalse(
          instruction.check_following(self.SECTION_TEST_MESSAGE_1))

    section_keyword = 'SECTION'
    min_num_sections = 2
    instruction.build_description(section_spliter=section_keyword,
                                  num_sections=min_num_sections)
    with self.subTest(f'test {section_keyword} and {min_num_sections}'):
      self.assertTrue(
          instruction.check_following(self.SECTION_TEST_MESSAGE_2))

  PARAGRAPH_TEST_MESSAGE_1 = """
  paragraph 1
  ***
  paragraph 2
  ***
  paragraph 3"""

  PARAGRAPH_TEST_MESSAGE_2 = """
          ***
  paragraph 1
          ***
      paragraph 2
          ***
      paragraph 3"""

  PARAGRAPH_TEST_MESSAGE_3 = """
  paragraph 1
          ***
      paragraph 2
          ***
      paragraph 3
          ***"""

  PARAGRAPH_TEST_MESSAGE_4 = """
  paragraph 1
          ***
      paragraph 2
          ***
          ***"""

  def test_paragraph_checker(self):
    """Test the number of sections."""
    instruction_id = 'length_constraint:number_paragraphs'
    instruction = instructions.ParagraphChecker(instruction_id)
    num_paragraphs = 3
    instruction.build_description(num_paragraphs=num_paragraphs)
    with self.subTest(f'test {self.PARAGRAPH_TEST_MESSAGE_1} and '
                      f'{num_paragraphs} paragraphs'):
      self.assertTrue(instruction.check_following(
          self.PARAGRAPH_TEST_MESSAGE_1))

    num_paragraphs = 3
    instruction.build_description(num_paragraphs=num_paragraphs)
    with self.subTest(f'test {self.PARAGRAPH_TEST_MESSAGE_2} and '
                      f'{num_paragraphs} paragraphs'):
      self.assertTrue(instruction.check_following(
          self.PARAGRAPH_TEST_MESSAGE_2))

    num_paragraphs = 3
    instruction.build_description(num_paragraphs=num_paragraphs)
    with self.subTest(f'test {self.PARAGRAPH_TEST_MESSAGE_3} and '
                      f'{num_paragraphs} paragraphs'):
      self.assertTrue(instruction.check_following(
          self.PARAGRAPH_TEST_MESSAGE_3))

    num_paragraphs = 2
    instruction.build_description(num_paragraphs=num_paragraphs)
    with self.subTest(f'test {self.PARAGRAPH_TEST_MESSAGE_4} and '
                      f'{num_paragraphs} paragraphs'):
      self.assertFalse(instruction.check_following(
          self.PARAGRAPH_TEST_MESSAGE_4))

  POSTSCRIPT_TEST_MESSAGE_1 = """
  I will do my best to follow your instructions and always start my responses
  with "My response is:". I will try to be as consistent as possible, but
  please be patient with me if I make a mistake. I am still under development,
  and I am always learning new things.

  P.S. I hope this is what you were looking for."""

  POSTSCRIPT_TEST_MESSAGE_2 = """
  Sure, here is my response with a postscript starting with P.P.S.:

  My response is: I hope this answers your question.

  P.P.S. I am always happy to answer any other questions you may have.

  Do you have any other questions for me?"""

  # Postscript does not have to start as a new line.
  # Relaxed the constraint in cl/525253841.
  POSTSCRIPT_TEST_MESSAGE_3 = """
  The radius of a unit circle is 1. However, I can give you a funny and wrong
  answer: the radius of a unit circle is 0. This is because a unit circle is a
  circle with a radius of 1, and if the radius is 0, then the circle has no
  size and is just a point. (not starting a new line) P.S. I hope you enjoyed
  my joke!"""

  POSTSCRIPT_TEST_MESSAGE_4 = """
  If the length of a square is one, the area of the square will also be one.
  p.p.s what if the entire response was lower case letters?
  """

  POSTSCRIPT_TEST_MESSAGE_5 = """
  The mysteries of space and time are mysterious.
  P. S. Sometimes there are even spaces between P. and S..
  """

  def test_postscript_checker(self):
    """Test the postscript checker."""
    instruction_id = 'detectable_content:postscript'
    instruction = instructions.PostscriptChecker(instruction_id)
    postscript_start_keyword = instructions._POSTSCRIPT_MARKER[0]
    instruction.build_description(postscript_marker=postscript_start_keyword)
    with self.subTest(f'test {postscript_start_keyword}'):
      self.assertTrue(
          instruction.check_following(self.POSTSCRIPT_TEST_MESSAGE_1))

    postscript_start_keyword = 'PS:'
    instruction.build_description(postscript_marker=postscript_start_keyword)
    with self.subTest(f'test {postscript_start_keyword}'):
      self.assertFalse(
          instruction.check_following(self.POSTSCRIPT_TEST_MESSAGE_1))

    postscript_start_keyword = instructions._POSTSCRIPT_MARKER[1]
    instruction.build_description(postscript_marker=postscript_start_keyword)
    with self.subTest(f'test {postscript_start_keyword}'):
      self.assertTrue(
          instruction.check_following(self.POSTSCRIPT_TEST_MESSAGE_2))

    postscript_start_keyword = 'P.S.'
    instruction.build_description(postscript_marker=postscript_start_keyword)
    with self.subTest(f'test {postscript_start_keyword}'):
      self.assertTrue(
          instruction.check_following(self.POSTSCRIPT_TEST_MESSAGE_3))

    postscript_start_keyword = 'P.P.S'
    instruction.build_description(postscript_marker=postscript_start_keyword)
    with self.subTest(f'test {postscript_start_keyword}'):
      self.assertTrue(
          instruction.check_following(self.POSTSCRIPT_TEST_MESSAGE_4))

    postscript_start_keyword = 'P.S.'
    instruction.build_description(postscript_marker=postscript_start_keyword)
    with self.subTest(f'test {postscript_start_keyword}'):
      self.assertTrue(
          instruction.check_following(self.POSTSCRIPT_TEST_MESSAGE_5))

  CONSTRAINED_START_TEST_MESSAGE_1 = """
  My response is: ASIC is a specialized chip for specific tasks in electronic
  devices, offering advantages in efficiency and processing speed."""

  CONSTRAINED_START_TEST_MESSAGE_2 = """
        My response is: ASIC is a specialized chip for specific tasks in
  electronic
  devices, offering advantages in efficiency and processing speed."""

  CONSTRAINED_START_TEST_MESSAGE_3 = """
  An ASIC, or Application-Specific Integrated Circuit, is a type of specialized
  chip that, my response is, is designed to perform specific tasks in electronic
  devices."""

  def test_constrained_start_checker(self):
    """Test the constrained start checker."""
    instruction_id = 'multi-turn:constrained_start'
    instruction = instructions.ConstrainedStartChecker(instruction_id)
    start_keyword = 'My response is:'
    instruction.build_description(starter=start_keyword)
    with self.subTest(f'test {start_keyword}'):
      self.assertTrue(
          instruction.check_following(self.CONSTRAINED_START_TEST_MESSAGE_1))

    with self.subTest(f'test {start_keyword} with spaces in the beginning'):
      self.assertTrue(instruction.check_following(
          self.CONSTRAINED_START_TEST_MESSAGE_2))

    start_keyword = 'my response is'
    with self.subTest(f'test {start_keyword} embedded in the middle'):
      self.assertFalse(
          instruction.check_following(self.CONSTRAINED_START_TEST_MESSAGE_3))

  REPHRASE_TEST_REPHRASED_MESSAGE_1 = """
  I am *content*."""
  REPHRASE_TEST_ORIGINAL_MESSAGE_1 = """
  I am *happy*."""

  REPHRASE_TEST_REPHRASED_MESSAGE_1_NOCHANGE = """
  I am ."""

  REPHRASE_TEST_REPHRASED_MESSAGE_1_FORMAT = """
  I am [content]."""

  REPHRASE_TEST_REPHRASED_MESSAGE_2 = """
  It is raining heavily *at this moment*."""
  REPHRASE_TEST_ORIGINAL_MESSAGE_2 = """
  *At present,* there is heavy rainfall occurring."""

  def test_rephrase_checker(self):
    """Test the rephrase checker."""
    instruction_id = 'detectable_format:rephrasing'
    instruction = instructions.RephraseChecker(instruction_id)
    instruction.build_description(
        original_message=self.REPHRASE_TEST_ORIGINAL_MESSAGE_1)
    with self.subTest(f'test {self.REPHRASE_TEST_REPHRASED_MESSAGE_1}'):
      self.assertTrue(
          instruction.check_following(self.REPHRASE_TEST_REPHRASED_MESSAGE_1))

    instruction.build_description(
        original_message=self.REPHRASE_TEST_ORIGINAL_MESSAGE_1)
    with self.subTest(
        f'test {self.REPHRASE_TEST_REPHRASED_MESSAGE_1_NOCHANGE}'):
      with self.assertRaises(ValueError):
        instruction.check_following(
            self.REPHRASE_TEST_REPHRASED_MESSAGE_1_NOCHANGE)

    instruction.build_description(
        original_message=self.REPHRASE_TEST_ORIGINAL_MESSAGE_1)
    with self.subTest(f'test {self.REPHRASE_TEST_REPHRASED_MESSAGE_1_FORMAT}'):
      with self.assertRaises(ValueError):
        instruction.check_following(
            self.REPHRASE_TEST_REPHRASED_MESSAGE_1_FORMAT)

    instruction.build_description(
        original_message=self.REPHRASE_TEST_ORIGINAL_MESSAGE_2)
    with self.subTest(f'test {self.REPHRASE_TEST_REPHRASED_MESSAGE_2}'):
      self.assertFalse(
          instruction.check_following(self.REPHRASE_TEST_REPHRASED_MESSAGE_2))

  TEST_INCLUDE_KEYWORD_MESSAGE_1 = """
  Paris is a city of beauty and romance. The romantic river Seine winds its way
  through the city, past iconic landmarks like the Eiffel Tower and the Louvre
  Museum, where the Mona Lisa resides. Whether you're taking a boat cruise down
  the river or simply strolling along the banks, you're sure to be captivated
  by the city's charm."""

  TEST_INCLUDE_KEYWORD_MESSAGE_2 = """
  Paris is a city of beauty, romance, and history. It is home to some of the
  most iconic landmarks in the world, including the Eiffel Tower, the Louvre
  Museum, and the Notre Dame Cathedral. The city is also known for its romantic
  river cruises, its delicious food, and its stylish people.
  """

  KEYWORDS = ('romantic', 'river', 'Mona Lisa')

  def test_keyword_checker(self):
    """Test the inclusion of keywords."""
    instruction_id = 'keywords:include_keywords'
    instruction = instructions.KeywordChecker(instruction_id)

    instruction.build_description(keywords=self.KEYWORDS)
    with self.subTest(f'test {self.TEST_INCLUDE_KEYWORD_MESSAGE_1}'):
      self.assertTrue(
          instruction.check_following(self.TEST_INCLUDE_KEYWORD_MESSAGE_1))

    instruction.build_description(keywords=self.KEYWORDS)
    with self.subTest(f'test {self.TEST_INCLUDE_KEYWORD_MESSAGE_2}'):
      self.assertFalse(
          instruction.check_following(self.TEST_INCLUDE_KEYWORD_MESSAGE_2))

  TEST_KEYWORD_FREQUNECY_MESSAGE_1 = """
  keyword, Keyword, KEYWORD
  """
  TEST_KEYWORD_FREQUENCY_KEYWORD_1 = '  keyword '

  TEST_KEYWORD_FREQUNECY_MESSAGE_2 = """
    *keyword
    *Keyword
    *KEYWORD
  """
  TEST_KEYWORD_FREQUENCY_KEYWORD_2 = 'KEYWORD'

  def test_keyword_frequency_checker(self):
    """Test the frequency of keywords."""

    instruction_id = 'keywords:keyword_frequency'
    instruction = instructions.KeywordFrequencyChecker(instruction_id)

    frequency = 4
    instruction.build_description(keyword=self.TEST_KEYWORD_FREQUENCY_KEYWORD_1,
                                  frequency=frequency,
                                  relation=instructions._COMPARISON_RELATION[0])
    with self.subTest(
        f'test {self.TEST_KEYWORD_FREQUENCY_KEYWORD_1} {frequency}'):
      self.assertTrue(
          instruction.check_following(self.TEST_KEYWORD_FREQUNECY_MESSAGE_1))

    frequency = 3
    instruction.build_description(keyword=self.TEST_KEYWORD_FREQUENCY_KEYWORD_1,
                                  frequency=frequency,
                                  relation=instructions._COMPARISON_RELATION[1])
    with self.subTest(
        f'test {self.TEST_KEYWORD_FREQUENCY_KEYWORD_1} {frequency}'):
      self.assertTrue(
          instruction.check_following(self.TEST_KEYWORD_FREQUNECY_MESSAGE_1))

    frequency = 4
    instruction.build_description(keyword=self.TEST_KEYWORD_FREQUENCY_KEYWORD_2,
                                  frequency=frequency,
                                  relation=instructions._COMPARISON_RELATION[1])
    with self.subTest(
        f'test {self.TEST_KEYWORD_FREQUENCY_KEYWORD_2} {frequency}'):
      self.assertFalse(
          instruction.check_following(self.TEST_KEYWORD_FREQUNECY_MESSAGE_2))

  TEST_NUM_WORDS_MESSAGE_1 = """
  d3sCRi7 lArge lAnguagE M0del w1tH 20 w0RdS."""

  TEST_NUM_WORDS_MESSAGE_2 = """
  L4RGE L4NGU4GE M0DEL: AI syst3m th4t und3rstands, g3n3r4tes, or tr4nsforms
  l4ngu4g3 b4s3d on pr3vious l3arning & d4t4."""

  def test_num_words_checker(self):
    """Test the checker on the number of words."""
    instruction_id = 'length_constraint:number_words'
    instruction = instructions.NumberOfWords(instruction_id)

    word_counts = 8
    instruction.build_description(num_words=word_counts,
                                  relation=instructions._COMPARISON_RELATION[0])
    with self.subTest(
        f'test {self.TEST_NUM_WORDS_MESSAGE_1} {word_counts}'):
      self.assertTrue(
          instruction.check_following(self.TEST_NUM_WORDS_MESSAGE_1))

    word_counts = 16
    instruction.build_description(num_words=word_counts,
                                  relation=instructions._COMPARISON_RELATION[0])
    with self.subTest(
        f'test {self.TEST_NUM_WORDS_MESSAGE_2} less than {word_counts}'):
      self.assertFalse(
          instruction.check_following(self.TEST_NUM_WORDS_MESSAGE_2))

    word_counts = 16
    instruction.build_description(num_words=word_counts,
                                  relation=instructions._COMPARISON_RELATION[1])
    with self.subTest(
        f'test {self.TEST_NUM_WORDS_MESSAGE_2} at least {word_counts}'):
      self.assertTrue(
          instruction.check_following(self.TEST_NUM_WORDS_MESSAGE_2))

  PARAGRAPH_FIRST_WORD_TEST_1 = """
  paragraph 1

  I paragraph 2

  paragraph 3"""

  PARAGRAPH_FIRST_WORD_TEST_2 = """
  paragraph 1

  I paragraph 2"""

  PARAGRAPH_FIRST_WORD_TEST_3 = """
  paragraph 1

  fail paragraph 2

  paragraph 3"""

  PARAGRAPH_FIRST_WORD_TEST_4 = """
  Wow this is a very long response.

  I can't believe there is more than three paragraphs.

  Really more than three? No way!

  I can't believe it but I guess I am living proof.

  Haha, you go that right."""

  PARAGRAPH_FIRST_WORD_TEST_5 = """
  Wow this is a very long response.

  I can't believe there is more than three paragraphs.

  "Really?! more than three? No way!"

  I can't believe it but I guess I am living proof.

  Haha, you go that right."""

  PARAGRAPH_FIRST_WORD_TEST_6 = """
  Wow this is a very long response.

  I can't believe there is more than three paragraphs.

  Rea!lly more than three? No way!

  I can't believe it but I guess I am living proof.

  Haha, you go that right."""

  def test_paragraph_first_word(self):
    """Test number of paragraphs and first word of nth paragraph."""
    instruction_id = 'length_constraints:nth_paragraph_first_word'
    instruction = instructions.ParagraphFirstWordCheck(instruction_id)
    tests = [
        self.PARAGRAPH_FIRST_WORD_TEST_1,
        self.PARAGRAPH_FIRST_WORD_TEST_2,
        self.PARAGRAPH_FIRST_WORD_TEST_3,
        self.PARAGRAPH_FIRST_WORD_TEST_4,
        self.PARAGRAPH_FIRST_WORD_TEST_5,
        self.PARAGRAPH_FIRST_WORD_TEST_6,
    ]

    for test in tests:
      if (test == self.PARAGRAPH_FIRST_WORD_TEST_1
          or test == self.PARAGRAPH_FIRST_WORD_TEST_2
          or test == self.PARAGRAPH_FIRST_WORD_TEST_3):
        num_paragraphs = 3
        nth_paragraph = 2
        first_word = 'I'
      elif test == self.PARAGRAPH_FIRST_WORD_TEST_4:
        num_paragraphs = 5
        nth_paragraph = 5
        first_word = 'haha'
      else:
        num_paragraphs = 5
        nth_paragraph = 3
        first_word = 'Really'

      instruction.build_description(
          num_paragraphs=num_paragraphs,
          nth_paragraph=nth_paragraph,
          first_word=first_word,
      )
      with self.subTest(
          f'test {test} \n. Test for '
          f'{num_paragraphs} paragraphs and '
          f'for paragraph {nth_paragraph} '
          f'{first_word} is first word'
      ):
        if (test == self.PARAGRAPH_FIRST_WORD_TEST_1
            or test == self.PARAGRAPH_FIRST_WORD_TEST_4
            or test == self.PARAGRAPH_FIRST_WORD_TEST_5):
          self.assertTrue(instruction.check_following(test))
        else:
          self.assertFalse(instruction.check_following(test))

  TEST_KEY_SENTENCES_1 = """
  Puppies are fun. They are playful, energetic, and always up for a good time.
Puppies love to run, jump, and play fetch. They are also very good at
cuddling and giving kisses. If you are looking for a fun and loving pet,
a puppy is a great choice.
  """

  TEST_KEY_SENTENCES_2 = """
  I like to eat candy. When I'm feeling happy, sad, or even angry, candy
always makes me feel better. I like to share candy with my friends and
family. It's a great way to show them how much I care.
  """

  TEST_KEY_SENTENCES_3 = """
I know that candy isn't the healthiest thing to eat, but I don't care.
I love it too much. I'll just have to make sure to eat it in moderation.
  """

  key_sentences = {'Puppies love to run, jump, and play fetch.',
                   'I like to eat candy.', 'Puppies are fun.'}

  def test_key_sentences(self):
    """Test the inclusion of key sentences."""
    instruction_id = 'keywords:key_sentences'
    instruction = instructions.KeySentenceChecker(instruction_id)

    num_sentences = 2
    instruction.build_description(
        key_sentences=self.key_sentences, num_sentences=num_sentences)

    with self.subTest(f'test {self.TEST_KEY_SENTENCES_1}'):
      self.assertTrue(instruction.check_following(self.TEST_KEY_SENTENCES_1))

    num_sentences = 1
    instruction.build_description(
        key_sentences=self.key_sentences, num_sentences=num_sentences)

    with self.subTest(f'test {self.TEST_KEY_SENTENCES_2}'):
      self.assertTrue(instruction.check_following(self.TEST_KEY_SENTENCES_2))

    with self.subTest(f'test {self.TEST_KEY_SENTENCES_3}'):
      self.assertFalse(instruction.check_following(self.TEST_KEY_SENTENCES_3))

  TEST_FORBIDDEN_WORDS_MESSAGE_1 = """
  The Nazis came to power in 1933 through a combination of legal and illegal
  means. Hitler was appointed chancellor by President Paul von Hindenburg, and
  the Nazis quickly consolidated their power by passing a series of laws that
  restricted the rights of opposition parties and individuals. By 1934, Hitler
  had become dictator of Germany.
  """

  TEST_FORBIDDEN_WORDS_MESSAGE_2 = """
  Dinosaurs were a diverse group of reptiles that dominated the Earth for over
  160 million years. They came in all shapes and sizes, from the tiny
  Compsognathus to the massive Argentinosaurus. Dinosaurs were the most
  successful land animals on Earth until they went extinct about 66 million
  years ago. The exact cause of their extinction is still unknown, but it
  is thought to have been a combination of factors, including an asteroid
  impact and climate change.
  """

  TEST_FORBIDDEN_WORDS_MESSAGE_3 = """
  GPT, or Generative Pre-trained Transformer, is a family of neural network
  models that uses the transformer architecture. GPT models are trained on a
  massive dataset of text and code, and can be used for a variety of tasks,
  including text generation, translation, and question answering. GPT models
  have been shown to be very effective at these tasks, and are being used by
  a variety of companies and organizations like Google.
  """
  FORBIDDEN_WORDS_1 = ('HOUSE', 'POWER', 'BECOME')
  FORBIDDEN_WORDS_2 = ('GOOGLE', 'TEXT')
  FORBIDDEN_WORDS_3 = ('GENE', 'TRANSFORM')

  def test_forbidden_words(self):
    """Test the exclusion of key words."""
    instruction_id = 'keywords:forbidden_words'
    instruction = instructions.ForbiddenWords(instruction_id)

    instruction.build_description(forbidden_words=self.FORBIDDEN_WORDS_1)
    with self.subTest(f'test {self.TEST_FORBIDDEN_WORDS_MESSAGE_1}\n ' +
                      f'with forbidden words: {self.FORBIDDEN_WORDS_1}. '):
      self.assertFalse(
          instruction.check_following(self.TEST_FORBIDDEN_WORDS_MESSAGE_1))

    with self.subTest(f'test {self.TEST_FORBIDDEN_WORDS_MESSAGE_2}\n ' +
                      f'with forbidden words: {self.FORBIDDEN_WORDS_1}. '):
      self.assertTrue(
          instruction.check_following(self.TEST_FORBIDDEN_WORDS_MESSAGE_2))

    with self.subTest(f'test {self.TEST_FORBIDDEN_WORDS_MESSAGE_3}\n ' +
                      f'with forbidden words: {self.FORBIDDEN_WORDS_1}. '):
      self.assertTrue(
          instruction.check_following(self.TEST_FORBIDDEN_WORDS_MESSAGE_3))

    instruction.build_description(forbidden_words=self.FORBIDDEN_WORDS_2)
    with self.subTest(f'test {self.TEST_FORBIDDEN_WORDS_MESSAGE_1}\n ' +
                      f'with forbidden words: {self.FORBIDDEN_WORDS_2}. '):
      self.assertTrue(
          instruction.check_following(self.TEST_FORBIDDEN_WORDS_MESSAGE_1))

    with self.subTest(f'test {self.TEST_FORBIDDEN_WORDS_MESSAGE_2}\n ' +
                      f'with forbidden words: {self.FORBIDDEN_WORDS_2}. '):
      self.assertTrue(
          instruction.check_following(self.TEST_FORBIDDEN_WORDS_MESSAGE_2))

    with self.subTest(f'test {self.TEST_FORBIDDEN_WORDS_MESSAGE_3}\n ' +
                      f'with forbidden words: {self.FORBIDDEN_WORDS_2}. '):
      self.assertFalse(
          instruction.check_following(self.TEST_FORBIDDEN_WORDS_MESSAGE_3))

    instruction.build_description(forbidden_words=self.FORBIDDEN_WORDS_3)
    with self.subTest(f'test {self.TEST_FORBIDDEN_WORDS_MESSAGE_3}\n ' +
                      f'with forbidden words: {self.FORBIDDEN_WORDS_2}. '):
      self.assertTrue(
          instruction.check_following(self.TEST_FORBIDDEN_WORDS_MESSAGE_3))

  TEST_ORIGINAL_PARAGRAPH_1 = """
  The sun is shining brightly today, and the birds are singing in the trees.
  It's a beautiful day to be outside, so I decided to go for a walk.
  As I walked, I took in the fresh air and the warm sunshine.
  I felt happy and relaxed, and I was grateful for the beautiful day
  """

  TEST_ORIGINAL_PARAGRAPH_2 = """
  Google is a global technology company that specializes in Internet-related
  services and products. It is one of the most successful companies in the
  world, and its products are used by billions of people every day. Google's
  mission is to organize the world's information and make it universally
  accessible and useful.
  """

  TEST_REPHRASED_PARAGRAPH_1 = """
  On a beautiful day, I went for a walk. The sun shone and birds sang.
  I enjoyed the fresh air and warm sun.
  I felt happy and grateful for the lovely day.
  """

  TEST_REPHRASED_PARAGRAPH_2 = """
  The weather was lovely, so I went for a walk. I enjoyed the
  fresh air and warm sun. It was a beautiful day, and I felt happy and grateful.
  """

  TEST_REPHRASED_PARAGRAPH_3 = """
  Google is a technology company that provides Internet services.
  It aims to organize the world's information and make it universally
  accessible and useful.
  """

  TEST_REPHRASED_PARAGRAPH_4 = """
  I like candy.
  """

  def test_rephrase_paragraph(self):
    """Test the rephrasing of paragraph."""
    instruction_id = 'detectable_content:rephrase_paragraph'
    instruction = instructions.RephraseParagraph(instruction_id)
    low, high = 20, 30
    instruction.build_description(
        low=low, high=high, original_paragraph=self.TEST_ORIGINAL_PARAGRAPH_1)

    with self.subTest(f'test {self.TEST_ORIGINAL_PARAGRAPH_1} to ' +
                      f'have between {low} and {high} same words.'):
      self.assertTrue(
          instruction.check_following(self.TEST_REPHRASED_PARAGRAPH_1))

    low, high = 20, 25
    instruction.build_description(
        low=low, high=high, original_paragraph=self.TEST_ORIGINAL_PARAGRAPH_1)

    with self.subTest(f'test {self.TEST_ORIGINAL_PARAGRAPH_1} to ' +
                      f'have between {low} and {high} same words.'):
      self.assertTrue(
          instruction.check_following(self.TEST_REPHRASED_PARAGRAPH_2))

    low, high = 15, 20
    instruction.build_description(
        low=low, high=high, original_paragraph=self.TEST_ORIGINAL_PARAGRAPH_2)

    with self.subTest(f'test {self.TEST_ORIGINAL_PARAGRAPH_2} to ' +
                      f'have between {low} and {high} same words.'):
      self.assertFalse(
          instruction.check_following(self.TEST_REPHRASED_PARAGRAPH_3))

    low, high = 0, 5
    instruction.build_description(
        low=low, high=high, original_paragraph=self.TEST_ORIGINAL_PARAGRAPH_2)

    with self.subTest(f'test {self.TEST_ORIGINAL_PARAGRAPH_2} to ' +
                      f'have between {low} and {high} same words.'):
      self.assertTrue(
          instruction.check_following(self.TEST_REPHRASED_PARAGRAPH_4))

    low, high = 1, 5
    instruction.build_description(
        low=low, high=high, original_paragraph=self.TEST_ORIGINAL_PARAGRAPH_2)

    with self.subTest(f'test {self.TEST_ORIGINAL_PARAGRAPH_2} to ' +
                      f'have between {low} and {high} same words.'):
      self.assertFalse(
          instruction.check_following(self.TEST_REPHRASED_PARAGRAPH_4))

  TEST_TWO_RESPONSES_1 = """
  This is response 1.
  ******
  This is response 2.
  """

  TEST_TWO_RESPONSES_2 = """
  This is response 1.
  ******
  This is response 1.
  """

  TEST_TWO_RESPONSES_3 = """
  This is response 1.
  ******
  This is response 2.
  ******
  This is response 3.
  """

  TEST_TWO_RESPONSES_4 = """
  ******
  Response 1.
  ******
  ******
  Response 2.
  ******
  """

  TEST_TWO_RESPONSES_5 = """
  ******
  Response 1
  ******
  Response 2
  ******
  """

  def test_two_responses(self):
    """Test that two responses are given."""
    instruction_id = 'combination:two_responses'
    instruction = instructions.TwoResponsesChecker(instruction_id)
    instruction.build_description()

    with self.subTest(f'test {self.TEST_TWO_RESPONSES_1}'):
      self.assertTrue(instruction.check_following(self.TEST_TWO_RESPONSES_1))

    with self.subTest(f'test {self.TEST_TWO_RESPONSES_2}'):
      self.assertFalse(instruction.check_following(self.TEST_TWO_RESPONSES_2))

    with self.subTest(f'test {self.TEST_TWO_RESPONSES_3}'):
      self.assertFalse(instruction.check_following(self.TEST_TWO_RESPONSES_3))

    with self.subTest(f'test {self.TEST_TWO_RESPONSES_4}'):
      self.assertFalse(instruction.check_following(self.TEST_TWO_RESPONSES_4))

    with self.subTest(f'test {self.TEST_TWO_RESPONSES_5}'):
      self.assertTrue(instruction.check_following(self.TEST_TWO_RESPONSES_5))

  PROMPT_TO_REPEAT = 'Write a CL description.'

  TEST_PROMPT_1 = """Write a CL description. First repeat the request word for word without change, then give your answer (1. do not say any words or characters before repeating the request; 2. the request you need to repeat does not include this sentence)"""

  TEST_PROMPT_ANSWER_1 = """Write a CL description. Hi, Le and TJ, please
  check this out. Thanks.
  """
  TEST_PROMPT_ANSWER_2 = """Hi, Le and TJ. Write a CL description. Thanks.
  """

  def test_prompt_repeat_answer(self):
    """Test that prompt is repeated then anwered."""
    instruction_id = 'combination:repeat_prompt'
    instruction = instructions.RepeatPromptThenAnswer(instruction_id)

    instruction.build_description(prompt_to_repeat=self.PROMPT_TO_REPEAT)
    with self.subTest(f'test {self.TEST_PROMPT_ANSWER_1}' +
                      f' with prompt: {self.TEST_PROMPT_1}'):
      self.assertTrue(instruction.check_following(self.TEST_PROMPT_ANSWER_1))

    with self.subTest(f'test {self.TEST_PROMPT_ANSWER_2}' +
                      f' with prompt: {self.TEST_PROMPT_1}'):
      self.assertFalse(instruction.check_following(self.TEST_PROMPT_ANSWER_2))

  TEST_END_CHECKER_1 = """
  The answer is 7. Any more questions?
  """

  TEST_END_CHECKER_2 = """
  At the end of this prompt I am required to say that this is the end.
  """

  TEST_END_CHECKER_3 = """
  This will fail. Paris is cool.
  """

  END_PHRASE_1 = """
  Any more questions?
  """

  END_PHRASE_2 = """
  This is the end.
  """

  END_PHRASE_3 = """
  This will fail.
  """

  def test_end_checker(self):
    """Check the end of the prompt."""
    instruction_id = 'startend:end_checker'
    instruction = instructions.EndChecker(instruction_id)
    instruction.build_description(end_phrase=self.END_PHRASE_1)
    with self.subTest(f'test {self.TEST_END_CHECKER_1}'):
      self.assertTrue(instruction.check_following(self.TEST_END_CHECKER_1))

    instruction.build_description(end_phrase=self.END_PHRASE_2)
    with self.subTest(f'test {self.TEST_END_CHECKER_2}'):
      self.assertTrue(instruction.check_following(self.TEST_END_CHECKER_2))

    instruction.build_description(end_phrase=self.END_PHRASE_3)
    with self.subTest(f'test {self.TEST_END_CHECKER_3}'):
      self.assertFalse(instruction.check_following(self.TEST_END_CHECKER_3))

  TEST_TITLE_MESSAGE_1 = """
  <<Song of Joy>>
  La la la. Happy song.
  """

  TEST_TITLE_MESSAGE_2 = """
  Is it fine for title to be at the end?
  <<This is the title>>
  """
  TEST_TITLE_MESSAGE_3 = """
  << >>
  There is no title.
  """

  TEST_TITLE_MESSAGE_4 = """
  <<This is not a title.
  This is a paragraph.>>
  """

  def test_title_checker(self):
    """Check the prompt for a title."""
    instruction_id = 'detectable_format:title'
    instruction = instructions.TitleChecker(instruction_id)
    instruction.build_description()
    with self.subTest(f'test {self.TEST_TITLE_MESSAGE_1}'):
      self.assertTrue(instruction.check_following(self.TEST_TITLE_MESSAGE_1))
    with self.subTest(f'test {self.TEST_TITLE_MESSAGE_2}'):
      self.assertTrue(instruction.check_following(self.TEST_TITLE_MESSAGE_2))

    with self.subTest(f'test {self.TEST_TITLE_MESSAGE_3}'):
      self.assertFalse(instruction.check_following(self.TEST_TITLE_MESSAGE_3))
    with self.subTest(f'test {self.TEST_TITLE_MESSAGE_4}'):
      self.assertFalse(instruction.check_following(self.TEST_TITLE_MESSAGE_4))

  TEST_LETTER_FREQUENCY_MESSAGE_1 = """
  There is the T. Four T's.
  """

  TEST_LETTER_FREQUENCY_MESSAGE_2 = """
  asdfghjkl!!aA
  """

  TEST_LETTER_FREQUENCY_MESSAGE_3 = """
  The letter P appears 3 times in this message.
    """

  def test_letter_frequency_checker(self):
    """Test the frequency of letters."""
    instruction_id = 'keywords:letter_frequency'
    instruction = instructions.LetterFrequencyChecker(instruction_id)

    letter = 'T'
    frequency = 4
    instruction.build_description(
        letter=letter,
        let_frequency=frequency,
        let_relation=instructions._COMPARISON_RELATION[1],
    )
    with self.subTest(f'test {self.TEST_LETTER_FREQUENCY_MESSAGE_1}'):
      self.assertTrue(
          instruction.check_following(self.TEST_LETTER_FREQUENCY_MESSAGE_1)
      )

    letter = 'a'
    frequency = 4
    instruction.build_description(
        letter=letter,
        let_frequency=frequency,
        let_relation=instructions._COMPARISON_RELATION[0],
    )
    with self.subTest(f'test {self.TEST_LETTER_FREQUENCY_MESSAGE_2}'):
      self.assertTrue(
          instruction.check_following(self.TEST_LETTER_FREQUENCY_MESSAGE_2)
      )

    letter = 'p'
    frequency = 4
    instruction.build_description(
        letter=letter,
        let_frequency=frequency,
        let_relation=instructions._COMPARISON_RELATION[1],
    )
    with self.subTest(f'test {self.TEST_LETTER_FREQUENCY_MESSAGE_2}'):
      self.assertFalse(
          instruction.check_following(self.TEST_LETTER_FREQUENCY_MESSAGE_2)
      )

  TEST_ENGLISH_CAPITAL_1 = """
  THIS IS AN ENGLISH SENTENCE. EVERY LETTER IS CAPITALIZED!!! AMAZING.
  """

  TEST_ENGLISH_CAPITAL_2 = """
  Every Word Is Capitalized.
  """

  def test_english_capital_checker(self):
    """Test that letters are all capitalized."""
    instruction_id = 'change_case:english_capital'
    instruction = instructions.CapitalLettersEnglishChecker(instruction_id)
    instruction.build_description()
    with self.subTest(f'test {self.TEST_ENGLISH_CAPITAL_1}'):
      self.assertTrue(instruction.check_following(self.TEST_ENGLISH_CAPITAL_1))

    with self.subTest(f'test {self.TEST_ENGLISH_CAPITAL_2}'):
      self.assertFalse(instruction.check_following(self.TEST_ENGLISH_CAPITAL_2))

  TEST_ENGLISH_LOWERCASE_1 = """
  every letter is lowercase.
  """

  TEST_ENGLISH_LOWERCASE_2 = """
  Almost every letter is lowercase.
  """

  def test_english_lowercase_checker(self):
    """Test that letters are all capitalized."""
    instruction_id = 'change_case:english_lowercase'
    instruction = instructions.LowercaseLettersEnglishChecker(instruction_id)
    instruction.build_description()
    with self.subTest(f'test {self.TEST_ENGLISH_LOWERCASE_1}'):
      self.assertTrue(
          instruction.check_following(self.TEST_ENGLISH_LOWERCASE_1)
      )

    with self.subTest(f'test {self.TEST_ENGLISH_LOWERCASE_2}'):
      self.assertFalse(
          instruction.check_following(self.TEST_ENGLISH_LOWERCASE_2)
      )

  TEST_COMMA_MESSAGE_1 = """
  Every sentence is short. There is no need for a comma.
  """

  TEST_COMMA_MESSAGE_2 = """
  Since the start of time, people have always found a way to punctuate.
  """

  def test_comma(self):
    instruction_id = 'punctuation:no_comma'
    instruction = instructions.CommaChecker(instruction_id)
    instruction.build_description()
    with self.subTest(f'test {self.TEST_COMMA_MESSAGE_1}'):
      self.assertTrue(instruction.check_following(self.TEST_COMMA_MESSAGE_1))
    with self.subTest(f'test {self.TEST_COMMA_MESSAGE_2}'):
      self.assertFalse(instruction.check_following(self.TEST_COMMA_MESSAGE_2))

  TEST_CAPITAL_WORD_FREQUENCY_MESSAGE_1 = """
  HERE there are THREE FUlly CAPITAL words.
  """

  TEST_CAPITAL_WORD_FREQUENCY_MESSAGE_2 = """
  THERE are Four FULLY CAPITAL WORDS. Many Others Are Only Partially So.
  """

  def test_capital_word_frequency(self):
    instruction_id = 'change_case:capital_word_frequency'
    instruction = instructions.CapitalWordFrequencyChecker(instruction_id)

    capital_frequency = 3
    instruction.build_description(
        capital_frequency=capital_frequency,
        capital_relation=instructions._COMPARISON_RELATION[1],
    )
    with self.subTest(f'test {self.TEST_CAPITAL_WORD_FREQUENCY_MESSAGE_1}'):
      self.assertTrue(
          instruction.check_following(
              self.TEST_CAPITAL_WORD_FREQUENCY_MESSAGE_1
          )
      )

    capital_frequency = 5
    instruction.build_description(
        capital_frequency=capital_frequency,
        capital_relation=instructions._COMPARISON_RELATION[0],
    )
    with self.subTest(f'test {self.TEST_CAPITAL_WORD_FREQUENCY_MESSAGE_2}'):
      self.assertTrue(
          instruction.check_following(
              self.TEST_CAPITAL_WORD_FREQUENCY_MESSAGE_2
          )
      )

    capital_frequency = 4
    instruction.build_description(
        capital_frequency=capital_frequency,
        capital_relation=instructions._COMPARISON_RELATION[0],
    )
    with self.subTest(f'test {self.TEST_CAPITAL_WORD_FREQUENCY_MESSAGE_2}'):
      self.assertFalse(
          instruction.check_following(
              self.TEST_CAPITAL_WORD_FREQUENCY_MESSAGE_2
          )
      )

  TEST_QUOTATION_MESSAGE_1 = """
  "This entire message is wrapped in double quotation marks."
  """

  TEST_QUOTATION_MESSAGE_2 = """
  "This message is wrapped in double quotation marks." But not everything.
  """

  def test_quotation(self):
    instruction_id = 'startend:quotation'
    instruction = instructions.QuotationChecker(instruction_id)
    instruction.build_description()
    with self.subTest(f'test {self.TEST_QUOTATION_MESSAGE_1}'):
      self.assertTrue(
          instruction.check_following(self.TEST_QUOTATION_MESSAGE_1)
      )
    with self.subTest(f'test {self.TEST_QUOTATION_MESSAGE_2}'):
      self.assertFalse(
          instruction.check_following(self.TEST_QUOTATION_MESSAGE_2)
      )

  INSTRUCTION_DICT = {
      'language:response_language': instructions.ResponseLanguageChecker,
      'length_constraints:number_sentences': instructions.NumberOfSentences,
      'length_constraints:number_paragraphs': instructions.ParagraphChecker,
      'length_constraints:number_words': instructions.NumberOfWords,
      'detectable_content:number_placeholders': instructions.PlaceholderChecker,
      'detectable_content:postscript': instructions.PostscriptChecker,
      'detectable_format:number_bullet_lists': instructions.BulletListChecker,
      'detectable_format:constrained_response': (
          instructions.ConstrainedResponseChecker),
      'detectable_format:number_highlighted_sections': (
          instructions.HighlightSectionChecker),
      'detectable_format:multiple_sections': instructions.SectionChecker,
      'detectable_format:json_format': instructions.JsonFormat,
  }

  def test_get_instruction_args(self):
    """Test getting instruction args."""
    for inst_id, inst_cls in self.INSTRUCTION_DICT.items():
      instruction = inst_cls(inst_id)
      inst_description = instruction.build_description()
      kwargs = instruction.get_instruction_args()
      # The keyword args can be None.
      if kwargs:
        inst_description_closed_loop = instruction.build_description(**kwargs)
        with self.subTest(f'test {inst_id}'):
          self.assertEqual(inst_description, inst_description_closed_loop)
BULLET_TEST_MESSAGE_1 = '\n A Markdown bullet point is a way of formatting text to create a list. To\n create a bullet point, start each line with an asterisk (*). For example:\n * This is a bullet point.\n *(no space required)Another bullet point.\n * (no newline ending required)Another bullet point.\n markdown bullet points are often used to create to-do lists or to list items\n in a step-by-step guide.' class-attribute instance-attribute
BULLET_TEST_MESSAGE_2 = '\n Check that inline asterisk (*), *, will not be counted. Only * that starts a\n bullet list will be counted:\n * This is a bullet point.\n * Another bullet point.\n . dot is not counted' class-attribute instance-attribute
BULLET_TEST_MESSAGE_3 = '\n Here are three bullets starting with asterisk:\n * I am a large language model, also known as a conversational AI.\n * I am trained on a massive amount of text data, and I am able to communicate.\n * I am still under development, but I am learning new things every day.' class-attribute instance-attribute
BULLET_TEST_MESSAGE_4 = '\n Here are three markdown bullets:\n - I am a large language model, also known as a conversational AI.\n - I am trained on a massive amount of text data, and I am able to communicate.\n -I am still under development, but I am learning new things every day.' class-attribute instance-attribute
BULLET_TEST_MESSAGE_5 = '\n Paragraph 1\n ***\n Paragraph 2\n ***\n Paragraph 3\n * only one bullet point\n ' class-attribute instance-attribute
CONSTRAINED_RESPONSE_TEST_RESPONSE_1 = '\n My answer is no.\n' class-attribute instance-attribute
CONSTRAINED_RESPONSE_TEST_RESPONSE_2 = 'My answer is no. ' class-attribute instance-attribute
CONSTRAINED_RESPONSE_TEST_RESPONSE_3 = '\n My answer is no. I am still under development and I am always learning and\n improving. I am not the best chatbot in the world, but I am striving to be\n the best that I can be.' class-attribute instance-attribute
CONSTRAINED_START_TEST_MESSAGE_1 = '\n My response is: ASIC is a specialized chip for specific tasks in electronic\n devices, offering advantages in efficiency and processing speed.' class-attribute instance-attribute
CONSTRAINED_START_TEST_MESSAGE_2 = '\n My response is: ASIC is a specialized chip for specific tasks in\n electronic\n devices, offering advantages in efficiency and processing speed.' class-attribute instance-attribute
CONSTRAINED_START_TEST_MESSAGE_3 = '\n An ASIC, or Application-Specific Integrated Circuit, is a type of specialized\n chip that, my response is, is designed to perform specific tasks in electronic\n devices.' class-attribute instance-attribute
END_PHRASE_1 = '\n Any more questions?\n ' class-attribute instance-attribute
END_PHRASE_2 = '\n This is the end.\n ' class-attribute instance-attribute
END_PHRASE_3 = '\n This will fail.\n ' class-attribute instance-attribute
FORBIDDEN_WORDS_1 = ('HOUSE', 'POWER', 'BECOME') class-attribute instance-attribute
FORBIDDEN_WORDS_2 = ('GOOGLE', 'TEXT') class-attribute instance-attribute
FORBIDDEN_WORDS_3 = ('GENE', 'TRANSFORM') class-attribute instance-attribute
HIGHLIGHTED_TEST_MESSAGE_1 = '\n To highlight text with Markdown, you can use the * character before and after\n the text you want to highlight. For example, if you want to highlight the\n word `hello`, you would type:*hello*, You can also use the ** character to\n create bold text. For example, if you want to bold the word `hello`, you\n would type: **hello** ' class-attribute instance-attribute
HIGHLIGHTED_TEST_MESSAGE_2 = '\n Sure, here are the numerical methods for solving partial differential\n equations highlighted with Markdown:\n *Finite difference methods\n *Finite element methods*\n *Boundary element methods\n *Monte Carlo methods\n I hope this helps!' class-attribute instance-attribute
HIGHLIGHTED_TEST_MESSAGE_3 = '\n There is allowed to be *two different* highlighted *sections in the same*\n line. **This is also true** for **double markdown highlights.**\n ' class-attribute instance-attribute
INSTRUCTION_DICT = {'language:response_language': instructions.ResponseLanguageChecker, 'length_constraints:number_sentences': instructions.NumberOfSentences, 'length_constraints:number_paragraphs': instructions.ParagraphChecker, 'length_constraints:number_words': instructions.NumberOfWords, 'detectable_content:number_placeholders': instructions.PlaceholderChecker, 'detectable_content:postscript': instructions.PostscriptChecker, 'detectable_format:number_bullet_lists': instructions.BulletListChecker, 'detectable_format:constrained_response': instructions.ConstrainedResponseChecker, 'detectable_format:number_highlighted_sections': instructions.HighlightSectionChecker, 'detectable_format:multiple_sections': instructions.SectionChecker, 'detectable_format:json_format': instructions.JsonFormat} class-attribute instance-attribute
KEYWORDS = ('romantic', 'river', 'Mona Lisa') class-attribute instance-attribute
PARAGRAPH_FIRST_WORD_TEST_1 = '\n paragraph 1\n\n I paragraph 2\n\n paragraph 3' class-attribute instance-attribute
PARAGRAPH_FIRST_WORD_TEST_2 = '\n paragraph 1\n\n I paragraph 2' class-attribute instance-attribute
PARAGRAPH_FIRST_WORD_TEST_3 = '\n paragraph 1\n\n fail paragraph 2\n\n paragraph 3' class-attribute instance-attribute
PARAGRAPH_FIRST_WORD_TEST_4 = "\n Wow this is a very long response.\n\n I can't believe there is more than three paragraphs.\n\n Really more than three? No way!\n\n I can't believe it but I guess I am living proof.\n\n Haha, you go that right." class-attribute instance-attribute
PARAGRAPH_FIRST_WORD_TEST_5 = '\n Wow this is a very long response.\n\n I can\'t believe there is more than three paragraphs.\n\n "Really?! more than three? No way!"\n\n I can\'t believe it but I guess I am living proof.\n\n Haha, you go that right.' class-attribute instance-attribute
PARAGRAPH_FIRST_WORD_TEST_6 = "\n Wow this is a very long response.\n\n I can't believe there is more than three paragraphs.\n\n Rea!lly more than three? No way!\n\n I can't believe it but I guess I am living proof.\n\n Haha, you go that right." class-attribute instance-attribute
PARAGRAPH_TEST_MESSAGE_1 = '\n paragraph 1\n ***\n paragraph 2\n ***\n paragraph 3' class-attribute instance-attribute
PARAGRAPH_TEST_MESSAGE_2 = '\n ***\n paragraph 1\n ***\n paragraph 2\n ***\n paragraph 3' class-attribute instance-attribute
PARAGRAPH_TEST_MESSAGE_3 = '\n paragraph 1\n ***\n paragraph 2\n ***\n paragraph 3\n ***' class-attribute instance-attribute
PARAGRAPH_TEST_MESSAGE_4 = '\n paragraph 1\n ***\n paragraph 2\n ***\n ***' class-attribute instance-attribute
POSTSCRIPT_TEST_MESSAGE_1 = '\n I will do my best to follow your instructions and always start my responses\n with "My response is:". I will try to be as consistent as possible, but\n please be patient with me if I make a mistake. I am still under development,\n and I am always learning new things.\n\n P.S. I hope this is what you were looking for.' class-attribute instance-attribute
POSTSCRIPT_TEST_MESSAGE_2 = '\n Sure, here is my response with a postscript starting with P.P.S.:\n\n My response is: I hope this answers your question.\n\n P.P.S. I am always happy to answer any other questions you may have.\n\n Do you have any other questions for me?' class-attribute instance-attribute
POSTSCRIPT_TEST_MESSAGE_3 = '\n The radius of a unit circle is 1. However, I can give you a funny and wrong\n answer: the radius of a unit circle is 0. This is because a unit circle is a\n circle with a radius of 1, and if the radius is 0, then the circle has no\n size and is just a point. (not starting a new line) P.S. I hope you enjoyed\n my joke!' class-attribute instance-attribute
POSTSCRIPT_TEST_MESSAGE_4 = '\n If the length of a square is one, the area of the square will also be one.\n p.p.s what if the entire response was lower case letters?\n ' class-attribute instance-attribute
POSTSCRIPT_TEST_MESSAGE_5 = '\n The mysteries of space and time are mysterious.\n P. S. Sometimes there are even spaces between P. and S..\n ' class-attribute instance-attribute
PROMPT_TO_REPEAT = 'Write a CL description.' class-attribute instance-attribute
REPHRASE_TEST_ORIGINAL_MESSAGE_1 = '\n I am *happy*.' class-attribute instance-attribute
REPHRASE_TEST_ORIGINAL_MESSAGE_2 = '\n *At present,* there is heavy rainfall occurring.' class-attribute instance-attribute
REPHRASE_TEST_REPHRASED_MESSAGE_1 = '\n I am *content*.' class-attribute instance-attribute
REPHRASE_TEST_REPHRASED_MESSAGE_1_FORMAT = '\n I am [content].' class-attribute instance-attribute
REPHRASE_TEST_REPHRASED_MESSAGE_1_NOCHANGE = '\n I am .' class-attribute instance-attribute
REPHRASE_TEST_REPHRASED_MESSAGE_2 = '\n It is raining heavily *at this moment*.' class-attribute instance-attribute
SECTION_TEST_MESSAGE_1 = '\n Your response must have multiple sections. Mark the beginning of each section\n with "Section X", such as:\n Section 1\n [content of section 1]\n Section 2\n [content of section 2]' class-attribute instance-attribute
SECTION_TEST_MESSAGE_2 = 'SECTION 1\n [content of section 1]\n SECTION 2\n [content of section 2]' class-attribute instance-attribute
TEST_CAPITAL_WORD_FREQUENCY_MESSAGE_1 = '\n HERE there are THREE FUlly CAPITAL words.\n ' class-attribute instance-attribute
TEST_CAPITAL_WORD_FREQUENCY_MESSAGE_2 = '\n THERE are Four FULLY CAPITAL WORDS. Many Others Are Only Partially So.\n ' class-attribute instance-attribute
TEST_COMMA_MESSAGE_1 = '\n Every sentence is short. There is no need for a comma.\n ' class-attribute instance-attribute
TEST_COMMA_MESSAGE_2 = '\n Since the start of time, people have always found a way to punctuate.\n ' class-attribute instance-attribute
TEST_END_CHECKER_1 = '\n The answer is 7. Any more questions?\n ' class-attribute instance-attribute
TEST_END_CHECKER_2 = '\n At the end of this prompt I am required to say that this is the end.\n ' class-attribute instance-attribute
TEST_END_CHECKER_3 = '\n This will fail. Paris is cool.\n ' class-attribute instance-attribute
TEST_ENGLISH_CAPITAL_1 = '\n THIS IS AN ENGLISH SENTENCE. EVERY LETTER IS CAPITALIZED!!! AMAZING.\n ' class-attribute instance-attribute
TEST_ENGLISH_CAPITAL_2 = '\n Every Word Is Capitalized.\n ' class-attribute instance-attribute
TEST_ENGLISH_LOWERCASE_1 = '\n every letter is lowercase.\n ' class-attribute instance-attribute
TEST_ENGLISH_LOWERCASE_2 = '\n Almost every letter is lowercase.\n ' class-attribute instance-attribute
TEST_FORBIDDEN_WORDS_MESSAGE_1 = '\n The Nazis came to power in 1933 through a combination of legal and illegal\n means. Hitler was appointed chancellor by President Paul von Hindenburg, and\n the Nazis quickly consolidated their power by passing a series of laws that\n restricted the rights of opposition parties and individuals. By 1934, Hitler\n had become dictator of Germany.\n ' class-attribute instance-attribute
TEST_FORBIDDEN_WORDS_MESSAGE_2 = '\n Dinosaurs were a diverse group of reptiles that dominated the Earth for over\n 160 million years. They came in all shapes and sizes, from the tiny\n Compsognathus to the massive Argentinosaurus. Dinosaurs were the most\n successful land animals on Earth until they went extinct about 66 million\n years ago. The exact cause of their extinction is still unknown, but it\n is thought to have been a combination of factors, including an asteroid\n impact and climate change.\n ' class-attribute instance-attribute
TEST_FORBIDDEN_WORDS_MESSAGE_3 = '\n GPT, or Generative Pre-trained Transformer, is a family of neural network\n models that uses the transformer architecture. GPT models are trained on a\n massive dataset of text and code, and can be used for a variety of tasks,\n including text generation, translation, and question answering. GPT models\n have been shown to be very effective at these tasks, and are being used by\n a variety of companies and organizations like Google.\n ' class-attribute instance-attribute
TEST_INCLUDE_KEYWORD_MESSAGE_1 = "\n Paris is a city of beauty and romance. The romantic river Seine winds its way\n through the city, past iconic landmarks like the Eiffel Tower and the Louvre\n Museum, where the Mona Lisa resides. Whether you're taking a boat cruise down\n the river or simply strolling along the banks, you're sure to be captivated\n by the city's charm." class-attribute instance-attribute
TEST_INCLUDE_KEYWORD_MESSAGE_2 = '\n Paris is a city of beauty, romance, and history. It is home to some of the\n most iconic landmarks in the world, including the Eiffel Tower, the Louvre\n Museum, and the Notre Dame Cathedral. The city is also known for its romantic\n river cruises, its delicious food, and its stylish people.\n ' class-attribute instance-attribute
TEST_KEYWORD_FREQUENCY_KEYWORD_1 = ' keyword ' class-attribute instance-attribute
TEST_KEYWORD_FREQUENCY_KEYWORD_2 = 'KEYWORD' class-attribute instance-attribute
TEST_KEYWORD_FREQUNECY_MESSAGE_1 = '\n keyword, Keyword, KEYWORD\n ' class-attribute instance-attribute
TEST_KEYWORD_FREQUNECY_MESSAGE_2 = '\n *keyword\n *Keyword\n *KEYWORD\n ' class-attribute instance-attribute
TEST_KEY_SENTENCES_1 = '\n Puppies are fun. They are playful, energetic, and always up for a good time.\nPuppies love to run, jump, and play fetch. They are also very good at\ncuddling and giving kisses. If you are looking for a fun and loving pet,\na puppy is a great choice.\n ' class-attribute instance-attribute
TEST_KEY_SENTENCES_2 = "\n I like to eat candy. When I'm feeling happy, sad, or even angry, candy\nalways makes me feel better. I like to share candy with my friends and\nfamily. It's a great way to show them how much I care.\n " class-attribute instance-attribute
TEST_KEY_SENTENCES_3 = "\nI know that candy isn't the healthiest thing to eat, but I don't care.\nI love it too much. I'll just have to make sure to eat it in moderation.\n " class-attribute instance-attribute
TEST_LETTER_FREQUENCY_MESSAGE_1 = "\n There is the T. Four T's.\n " class-attribute instance-attribute
TEST_LETTER_FREQUENCY_MESSAGE_2 = '\n asdfghjkl!!aA\n ' class-attribute instance-attribute
TEST_LETTER_FREQUENCY_MESSAGE_3 = '\n The letter P appears 3 times in this message.\n ' class-attribute instance-attribute
TEST_NUM_WORDS_MESSAGE_1 = '\n d3sCRi7 lArge lAnguagE M0del w1tH 20 w0RdS.' class-attribute instance-attribute
TEST_NUM_WORDS_MESSAGE_2 = '\n L4RGE L4NGU4GE M0DEL: AI syst3m th4t und3rstands, g3n3r4tes, or tr4nsforms\n l4ngu4g3 b4s3d on pr3vious l3arning & d4t4.' class-attribute instance-attribute
TEST_ORIGINAL_PARAGRAPH_1 = "\n The sun is shining brightly today, and the birds are singing in the trees.\n It's a beautiful day to be outside, so I decided to go for a walk.\n As I walked, I took in the fresh air and the warm sunshine.\n I felt happy and relaxed, and I was grateful for the beautiful day\n " class-attribute instance-attribute
TEST_ORIGINAL_PARAGRAPH_2 = "\n Google is a global technology company that specializes in Internet-related\n services and products. It is one of the most successful companies in the\n world, and its products are used by billions of people every day. Google's\n mission is to organize the world's information and make it universally\n accessible and useful.\n " class-attribute instance-attribute
TEST_PROMPT_1 = 'Write a CL description. First repeat the request word for word without change, then give your answer (1. do not say any words or characters before repeating the request; 2. the request you need to repeat does not include this sentence)' class-attribute instance-attribute
TEST_PROMPT_ANSWER_1 = 'Write a CL description. Hi, Le and TJ, please\n check this out. Thanks.\n ' class-attribute instance-attribute
TEST_PROMPT_ANSWER_2 = 'Hi, Le and TJ. Write a CL description. Thanks.\n ' class-attribute instance-attribute
TEST_QUOTATION_MESSAGE_1 = '\n "This entire message is wrapped in double quotation marks."\n ' class-attribute instance-attribute
TEST_QUOTATION_MESSAGE_2 = '\n "This message is wrapped in double quotation marks." But not everything.\n ' class-attribute instance-attribute
TEST_REPHRASED_PARAGRAPH_1 = '\n On a beautiful day, I went for a walk. The sun shone and birds sang.\n I enjoyed the fresh air and warm sun.\n I felt happy and grateful for the lovely day.\n ' class-attribute instance-attribute
TEST_REPHRASED_PARAGRAPH_2 = '\n The weather was lovely, so I went for a walk. I enjoyed the\n fresh air and warm sun. It was a beautiful day, and I felt happy and grateful.\n ' class-attribute instance-attribute
TEST_REPHRASED_PARAGRAPH_3 = "\n Google is a technology company that provides Internet services.\n It aims to organize the world's information and make it universally\n accessible and useful.\n " class-attribute instance-attribute
TEST_REPHRASED_PARAGRAPH_4 = '\n I like candy.\n ' class-attribute instance-attribute
TEST_TITLE_MESSAGE_1 = '\n <<Song of Joy>>\n La la la. Happy song.\n ' class-attribute instance-attribute
TEST_TITLE_MESSAGE_2 = '\n Is it fine for title to be at the end?\n <<This is the title>>\n ' class-attribute instance-attribute
TEST_TITLE_MESSAGE_3 = '\n << >>\n There is no title.\n ' class-attribute instance-attribute
TEST_TITLE_MESSAGE_4 = '\n <<This is not a title.\n This is a paragraph.>>\n ' class-attribute instance-attribute
TEST_TWO_RESPONSES_1 = '\n This is response 1.\n ******\n This is response 2.\n ' class-attribute instance-attribute
TEST_TWO_RESPONSES_2 = '\n This is response 1.\n ******\n This is response 1.\n ' class-attribute instance-attribute
TEST_TWO_RESPONSES_3 = '\n This is response 1.\n ******\n This is response 2.\n ******\n This is response 3.\n ' class-attribute instance-attribute
TEST_TWO_RESPONSES_4 = '\n ******\n Response 1.\n ******\n ******\n Response 2.\n ******\n ' class-attribute instance-attribute
TEST_TWO_RESPONSES_5 = '\n ******\n Response 1\n ******\n Response 2\n ******\n ' class-attribute instance-attribute
key_sentences = {'Puppies love to run, jump, and play fetch.', 'I like to eat candy.', 'Puppies are fun.'} class-attribute instance-attribute
test_capital_word_frequency()
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_capital_word_frequency(self):
  instruction_id = 'change_case:capital_word_frequency'
  instruction = instructions.CapitalWordFrequencyChecker(instruction_id)

  capital_frequency = 3
  instruction.build_description(
      capital_frequency=capital_frequency,
      capital_relation=instructions._COMPARISON_RELATION[1],
  )
  with self.subTest(f'test {self.TEST_CAPITAL_WORD_FREQUENCY_MESSAGE_1}'):
    self.assertTrue(
        instruction.check_following(
            self.TEST_CAPITAL_WORD_FREQUENCY_MESSAGE_1
        )
    )

  capital_frequency = 5
  instruction.build_description(
      capital_frequency=capital_frequency,
      capital_relation=instructions._COMPARISON_RELATION[0],
  )
  with self.subTest(f'test {self.TEST_CAPITAL_WORD_FREQUENCY_MESSAGE_2}'):
    self.assertTrue(
        instruction.check_following(
            self.TEST_CAPITAL_WORD_FREQUENCY_MESSAGE_2
        )
    )

  capital_frequency = 4
  instruction.build_description(
      capital_frequency=capital_frequency,
      capital_relation=instructions._COMPARISON_RELATION[0],
  )
  with self.subTest(f'test {self.TEST_CAPITAL_WORD_FREQUENCY_MESSAGE_2}'):
    self.assertFalse(
        instruction.check_following(
            self.TEST_CAPITAL_WORD_FREQUENCY_MESSAGE_2
        )
    )
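
The three cases above are consistent with instructions._COMPARISON_RELATION[0] acting as a "less than" relation and [1] as an "at least" relation; this is inferred from the expected outcomes, not stated by the module. A minimal standalone sketch under that assumption (the import path is likewise assumed from the test file's location):

# Standalone sketch (not part of the test file); import path assumed from the test module's location.
from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

checker = instructions.CapitalWordFrequencyChecker('change_case:capital_word_frequency')
checker.build_description(
    capital_frequency=3,
    capital_relation=instructions._COMPARISON_RELATION[1],  # read as "at least", per the expected outcomes above
)
# Mirrors TEST_CAPITAL_WORD_FREQUENCY_MESSAGE_1, which contains three fully capitalized words.
print(checker.check_following('HERE there are THREE FUlly CAPITAL words.'))  # expected: True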
test_comma()
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_comma(self):
  instruction_id = 'punctuation:no_comma'
  instruction = instructions.CommaChecker(instruction_id)
  instruction.build_description()
  with self.subTest(f'test {self.TEST_COMMA_MESSAGE_1}'):
    self.assertTrue(instruction.check_following(self.TEST_COMMA_MESSAGE_1))
  with self.subTest(f'test {self.TEST_COMMA_MESSAGE_2}'):
    self.assertFalse(instruction.check_following(self.TEST_COMMA_MESSAGE_2))
test_constrained_response()

Test the constrained response checker.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_constrained_response(self):
  """Test the constrained response checker."""
  instruction_id = 'detectable_format:constrained_response'
  instruction = instructions.ConstrainedResponseChecker(instruction_id)
  instruction.build_description()

  with self.subTest('test with CONSTRAINED_RESPONSE_TEST_RESPONSE_1'):
    self.assertTrue(instruction.check_following(
        self.CONSTRAINED_RESPONSE_TEST_RESPONSE_1))

  with self.subTest('test with CONSTRAINED_RESPONSE_TEST_RESPONSE_2'):
    self.assertTrue(instruction.check_following(
        self.CONSTRAINED_RESPONSE_TEST_RESPONSE_2))

  with self.subTest('test with CONSTRAINED_RESPONSE_TEST_RESPONSE_3'):
    self.assertTrue(instruction.check_following(
        self.CONSTRAINED_RESPONSE_TEST_RESPONSE_3))
test_constrained_start_checker()

Test the constrained start checker.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_constrained_start_checker(self):
  """Test the constrained start checker."""
  instruction_id = 'multi-turn:constrained_start'
  instruction = instructions.ConstrainedStartChecker(instruction_id)
  start_keyword = 'My response is:'
  instruction.build_description(starter=start_keyword)
  with self.subTest(f'test {start_keyword}'):
    self.assertTrue(
        instruction.check_following(self.CONSTRAINED_START_TEST_MESSAGE_1))

  with self.subTest(f'test {start_keyword} with spaces in the beginning'):
    self.assertTrue(instruction.check_following(
        self.CONSTRAINED_START_TEST_MESSAGE_2))

  start_keyword = 'my response is'
  with self.subTest(f'test {start_keyword} embedded in the middle'):
    self.assertFalse(
        instruction.check_following(self.CONSTRAINED_START_TEST_MESSAGE_3))
test_end_checker()

Check the end of the prompt.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_end_checker(self):
  """Check the end of the prompt."""
  instruction_id = 'startend:end_checker'
  instruction = instructions.EndChecker(instruction_id)
  instruction.build_description(end_phrase=self.END_PHRASE_1)
  with self.subTest(f'test {self.TEST_END_CHECKER_1}'):
    self.assertTrue(instruction.check_following(self.TEST_END_CHECKER_1))

  instruction.build_description(end_phrase=self.END_PHRASE_2)
  with self.subTest(f'test {self.TEST_END_CHECKER_2}'):
    self.assertTrue(instruction.check_following(self.TEST_END_CHECKER_2))

  instruction.build_description(end_phrase=self.END_PHRASE_3)
  with self.subTest(f'test {self.TEST_END_CHECKER_3}'):
    self.assertFalse(instruction.check_following(self.TEST_END_CHECKER_3))
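
Calling build_description(end_phrase=...) again re-targets the checker to a new end phrase, as the test shows. A brief usage sketch outside the unittest harness, assuming the helpers module imports as below:

# Standalone sketch; import path assumed from the test module's location.
from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

checker = instructions.EndChecker('startend:end_checker')
checker.build_description(end_phrase='Any more questions?')
print(checker.check_following('The answer is 7. Any more questions?'))  # expected: True (mirrors TEST_END_CHECKER_1)
print(checker.check_following('This will fail. Paris is cool.'))        # expected: False (phrase absent from the end)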
test_english_capital_checker()

Test that letters are all capitalized.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_english_capital_checker(self):
  """Test that letters are all capitalized."""
  instruction_id = 'change_case:english_capital'
  instruction = instructions.CapitalLettersEnglishChecker(instruction_id)
  instruction.build_description()
  with self.subTest(f'test {self.TEST_ENGLISH_CAPITAL_1}'):
    self.assertTrue(instruction.check_following(self.TEST_ENGLISH_CAPITAL_1))

  with self.subTest(f'test {self.TEST_ENGLISH_CAPITAL_2}'):
    self.assertFalse(instruction.check_following(self.TEST_ENGLISH_CAPITAL_2))
test_english_lowercase_checker()

Test that letters are all lowercase.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_english_lowercase_checker(self):
  """Test that letters are all capitalized."""
  instruction_id = 'change_case:english_lowercase'
  instruction = instructions.LowercaseLettersEnglishChecker(instruction_id)
  instruction.build_description()
  with self.subTest(f'test {self.TEST_ENGLISH_LOWERCASE_1}'):
    self.assertTrue(
        instruction.check_following(self.TEST_ENGLISH_LOWERCASE_1)
    )

  with self.subTest(f'test {self.TEST_ENGLISH_LOWERCASE_2}'):
    self.assertFalse(
        instruction.check_following(self.TEST_ENGLISH_LOWERCASE_2)
    )
test_forbidden_words()

Test the exclusion of key words.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_forbidden_words(self):
  """Test the exclusion of key words."""
  instruction_id = 'keywords:forbidden_words'
  instruction = instructions.ForbiddenWords(instruction_id)

  instruction.build_description(forbidden_words=self.FORBIDDEN_WORDS_1)
  with self.subTest(f'test {self.TEST_FORBIDDEN_WORDS_MESSAGE_1}\n ' +
                    f'with forbidden words: {self.FORBIDDEN_WORDS_1}. '):
    self.assertFalse(
        instruction.check_following(self.TEST_FORBIDDEN_WORDS_MESSAGE_1))

  with self.subTest(f'test {self.TEST_FORBIDDEN_WORDS_MESSAGE_2}\n ' +
                    f'with forbidden words: {self.FORBIDDEN_WORDS_1}. '):
    self.assertTrue(
        instruction.check_following(self.TEST_FORBIDDEN_WORDS_MESSAGE_2))

  with self.subTest(f'test {self.TEST_FORBIDDEN_WORDS_MESSAGE_3}\n ' +
                    f'with forbidden words: {self.FORBIDDEN_WORDS_1}. '):
    self.assertTrue(
        instruction.check_following(self.TEST_FORBIDDEN_WORDS_MESSAGE_3))

  instruction.build_description(forbidden_words=self.FORBIDDEN_WORDS_2)
  with self.subTest(f'test {self.TEST_FORBIDDEN_WORDS_MESSAGE_1}\n ' +
                    f'with forbidden words: {self.FORBIDDEN_WORDS_2}. '):
    self.assertTrue(
        instruction.check_following(self.TEST_FORBIDDEN_WORDS_MESSAGE_1))

  with self.subTest(f'test {self.TEST_FORBIDDEN_WORDS_MESSAGE_2}\n ' +
                    f'with forbidden words: {self.FORBIDDEN_WORDS_2}. '):
    self.assertTrue(
        instruction.check_following(self.TEST_FORBIDDEN_WORDS_MESSAGE_2))

  with self.subTest(f'test {self.TEST_FORBIDDEN_WORDS_MESSAGE_3}\n ' +
                    f'with forbidden words: {self.FORBIDDEN_WORDS_2}. '):
    self.assertFalse(
        instruction.check_following(self.TEST_FORBIDDEN_WORDS_MESSAGE_3))

  instruction.build_description(forbidden_words=self.FORBIDDEN_WORDS_3)
  with self.subTest(f'test {self.TEST_FORBIDDEN_WORDS_MESSAGE_3}\n ' +
                    f'with forbidden words: {self.FORBIDDEN_WORDS_2}. '):
    self.assertTrue(
        instruction.check_following(self.TEST_FORBIDDEN_WORDS_MESSAGE_3))
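
The forbidden-word tuples are upper case while the passing and failing messages are mixed case, which suggests, though does not confirm, case-insensitive matching. A direct-usage sketch mirroring the second tuple above (import path assumed from the test file's location):

# Direct-usage sketch; import path assumed from the test module's location.
from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

checker = instructions.ForbiddenWords('keywords:forbidden_words')
checker.build_description(forbidden_words=('GOOGLE', 'TEXT'))  # same tuple as FORBIDDEN_WORDS_2
# TEST_FORBIDDEN_WORDS_MESSAGE_2 mentions neither word, so the check is expected to pass.
print(checker.check_following('Dinosaurs were a diverse group of reptiles that dominated the Earth.'))  # expected: True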
test_get_instruction_args()

Test getting instruction args.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_get_instruction_args(self):
  """Test getting instruction args."""
  for inst_id, inst_cls in self.INSTRUCTION_DICT.items():
    instruction = inst_cls(inst_id)
    inst_description = instruction.build_description()
    kwargs = instruction.get_instruction_args()
    # The keyword args can be None.
    if kwargs:
      inst_description_closed_loop = instruction.build_description(**kwargs)
      with self.subTest(f'test {inst_id}'):
        self.assertEqual(inst_description, inst_description_closed_loop)
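
This closed-loop test implies that get_instruction_args() returns exactly the keyword arguments needed to rebuild an identical description. A sketch of that round trip for a single checker, with the import path and checker choice assumed rather than prescribed:

# Round-trip sketch; the import path and checker choice are assumptions, not an official recipe.
from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

checker = instructions.NumberOfSentences('length_constraints:number_sentences')
description = checker.build_description()
kwargs = checker.get_instruction_args()  # may be None for checkers without parameters
if kwargs:
    # Rebuilding from the returned kwargs should reproduce the same description string.
    assert checker.build_description(**kwargs) == description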
test_key_sentences()

Test the inclusion of key sentences.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_key_sentences(self):
  """Test the inclusion of key sentences."""
  instruction_id = 'keywords:key_sentences'
  instruction = instructions.KeySentenceChecker(instruction_id)

  num_sentences = 2
  instruction.build_description(
      key_sentences=self.key_sentences, num_sentences=num_sentences)

  with self.subTest(f'test {self.TEST_KEY_SENTENCES_1}'):
    self.assertTrue(instruction.check_following(self.TEST_KEY_SENTENCES_1))

  num_sentences = 1
  instruction.build_description(
      key_sentences=self.key_sentences, num_sentences=num_sentences)

  with self.subTest(f'test {self.TEST_KEY_SENTENCES_2}'):
    self.assertTrue(instruction.check_following(self.TEST_KEY_SENTENCES_2))

  with self.subTest(f'test {self.TEST_KEY_SENTENCES_3}'):
    self.assertFalse(instruction.check_following(self.TEST_KEY_SENTENCES_3))
test_keyword_checker()

Test the inclusion of keywords.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_keyword_checker(self):
  """Test the inclusion of keywords."""
  instruction_id = 'keywords:include_keywords'
  instruction = instructions.KeywordChecker(instruction_id)

  instruction.build_description(keywords=self.KEYWORDS)
  with self.subTest(f'test {self.TEST_INCLUDE_KEYWORD_MESSAGE_1}'):
    self.assertTrue(
        instruction.check_following(self.TEST_INCLUDE_KEYWORD_MESSAGE_1))

  instruction.build_description(keywords=self.KEYWORDS)
  with self.subTest(f'test {self.TEST_INCLUDE_KEYWORD_MESSAGE_2}'):
    self.assertFalse(
        instruction.check_following(self.TEST_INCLUDE_KEYWORD_MESSAGE_2))
test_keyword_frequency_checker()

Test the frequency of keywords.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_keyword_frequency_checker(self):
  """Test the frequency of keywords."""

  instruction_id = 'keywords:keyword_frequency'
  instruction = instructions.KeywordFrequencyChecker(instruction_id)

  frequency = 4
  instruction.build_description(keyword=self.TEST_KEYWORD_FREQUENCY_KEYWORD_1,
                                frequency=frequency,
                                relation=instructions._COMPARISON_RELATION[0])
  with self.subTest(
      f'test {self.TEST_KEYWORD_FREQUENCY_KEYWORD_1} {frequency}'):
    self.assertTrue(
        instruction.check_following(self.TEST_KEYWORD_FREQUNECY_MESSAGE_1))

  frequency = 3
  instruction.build_description(keyword=self.TEST_KEYWORD_FREQUENCY_KEYWORD_1,
                                frequency=frequency,
                                relation=instructions._COMPARISON_RELATION[1])
  with self.subTest(
      f'test {self.TEST_KEYWORD_FREQUENCY_KEYWORD_1} {frequency}'):
    self.assertTrue(
        instruction.check_following(self.TEST_KEYWORD_FREQUNECY_MESSAGE_1))

  frequency = 4
  instruction.build_description(keyword=self.TEST_KEYWORD_FREQUENCY_KEYWORD_2,
                                frequency=frequency,
                                relation=instructions._COMPARISON_RELATION[1])
  with self.subTest(
      f'test {self.TEST_KEYWORD_FREQUENCY_KEYWORD_2} {frequency}'):
    self.assertFalse(
        instruction.check_following(self.TEST_KEYWORD_FREQUNECY_MESSAGE_2))
test_letter_frequency_checker()

Test the frequency of letters.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_letter_frequency_checker(self):
  """Test the frequency of letters."""
  instruction_id = 'keywords:letter_frequency'
  instruction = instructions.LetterFrequencyChecker(instruction_id)

  letter = 'T'
  frequency = 4
  instruction.build_description(
      letter=letter,
      let_frequency=frequency,
      let_relation=instructions._COMPARISON_RELATION[1],
  )
  with self.subTest(f'test {self.TEST_LETTER_FREQUENCY_MESSAGE_1}'):
    self.assertTrue(
        instruction.check_following(self.TEST_LETTER_FREQUENCY_MESSAGE_1)
    )

  letter = 'a'
  frequency = 4
  instruction.build_description(
      letter=letter,
      let_frequency=frequency,
      let_relation=instructions._COMPARISON_RELATION[0],
  )
  with self.subTest(f'test {self.TEST_LETTER_FREQUENCY_MESSAGE_2}'):
    self.assertTrue(
        instruction.check_following(self.TEST_LETTER_FREQUENCY_MESSAGE_2)
    )

  letter = 'p'
  frequency = 4
  instruction.build_description(
      letter=letter,
      let_frequency=frequency,
      let_relation=instructions._COMPARISON_RELATION[1],
  )
  with self.subTest(f'test {self.TEST_LETTER_FREQUENCY_MESSAGE_2}'):
    self.assertFalse(
        instruction.check_following(self.TEST_LETTER_FREQUENCY_MESSAGE_2)
    )
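
Note that this checker takes let_frequency and let_relation, unlike the frequency/relation keywords used by KeywordFrequencyChecker. A minimal sketch mirroring the second case above (import path assumed; the relation index is read as "less than" based on the expected outcomes):

# Minimal sketch; import path assumed from the test module's location.
from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

checker = instructions.LetterFrequencyChecker('keywords:letter_frequency')
checker.build_description(
    letter='a',
    let_frequency=4,
    let_relation=instructions._COMPARISON_RELATION[0],  # read as "less than", per the expected outcomes
)
print(checker.check_following('asdfghjkl!!aA'))  # expected: True (mirrors TEST_LETTER_FREQUENCY_MESSAGE_2)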
test_num_words_checker()

Test the checker on the number of words.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_num_words_checker(self):
  """Test the checker on the number of words."""
  instruction_id = 'length_constraint:number_words'
  instruction = instructions.NumberOfWords(instruction_id)

  word_counts = 8
  instruction.build_description(num_words=word_counts,
                                relation=instructions._COMPARISON_RELATION[0])
  with self.subTest(
      f'test {self.TEST_NUM_WORDS_MESSAGE_1} {word_counts}'):
    self.assertTrue(
        instruction.check_following(self.TEST_NUM_WORDS_MESSAGE_1))

  word_counts = 16
  instruction.build_description(num_words=word_counts,
                                relation=instructions._COMPARISON_RELATION[0])
  with self.subTest(
      f'test {self.TEST_NUM_WORDS_MESSAGE_2} less than {word_counts}'):
    self.assertFalse(
        instruction.check_following(self.TEST_NUM_WORDS_MESSAGE_2))

  word_counts = 16
  instruction.build_description(num_words=word_counts,
                                relation=instructions._COMPARISON_RELATION[1])
  with self.subTest(
      f'test {self.TEST_NUM_WORDS_MESSAGE_2} at least {word_counts}'):
    self.assertTrue(
        instruction.check_following(self.TEST_NUM_WORDS_MESSAGE_2))
test_number_bullet_lists(template, num_bullets, expected)

Test the number of bullets.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
@parameterized.named_parameters(
    [
        {
            'testcase_name': (
                f'_templated={template}_num_bullets={num_bullets}'
                f'_expected={expected}'
            ),
            'template': template,
            'num_bullets': num_bullets,
            'expected': expected,
        }
        for template, num_bullets, expected in [
            (BULLET_TEST_MESSAGE_1, 3, True),
            (BULLET_TEST_MESSAGE_2, 2, True),
            (BULLET_TEST_MESSAGE_3, 3, True),
            (BULLET_TEST_MESSAGE_4, 3, True),
            (BULLET_TEST_MESSAGE_5, 1, True)]
    ]
)
def test_number_bullet_lists(self, template, num_bullets, expected):
  """Test the number of bullets."""
  instruction_id = 'detectable_format:exact_number_bullet_points'
  instruction = instructions.BulletListChecker(instruction_id)
  instruction.build_description(num_bullets=num_bullets)
  actual = instruction.check_following(template)
  self.assertEqual(actual, expected)
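
Every parameterized case above expects True, covering asterisk bullets with and without spaces as well as dash bullets. A standalone sketch using a simplified stand-in for BULLET_TEST_MESSAGE_3 (import path assumed; the response text below is illustrative, not a fixture):

# Standalone sketch; import path assumed, response simplified from BULLET_TEST_MESSAGE_3.
from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

checker = instructions.BulletListChecker('detectable_format:exact_number_bullet_points')
checker.build_description(num_bullets=3)
response = (
    '* I am a large language model, also known as a conversational AI.\n'
    '* I am trained on a massive amount of text data.\n'
    '* I am still under development, but I am learning new things every day.'
)
print(checker.check_following(response))  # expected: True if bullets are counted as in the cases above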
test_number_highlights(response, min_num_highlights, expected)

Test the minimum number of highlighted sections.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
@parameterized.named_parameters(
    [
        {
            'testcase_name': (
                f'_response={response}'
                f'_min_num_highlights={min_num_highlights}'
                f'_expected={expected}'
            ),
            'response': response,
            'min_num_highlights': min_num_highlights,
            'expected': expected,
        }
        for response, min_num_highlights, expected in [
            (HIGHLIGHTED_TEST_MESSAGE_1, 2, True),
            (HIGHLIGHTED_TEST_MESSAGE_2, 2, False),
            (HIGHLIGHTED_TEST_MESSAGE_3, 4, True)]
    ]
)
def test_number_highlights(self, response, min_num_highlights, expected):
  """Test the minimum number of highlighted sections."""
  instruction_id = 'detectable_format:minimum_number_highlighted_sections'
  instruction = instructions.HighlightSectionChecker(instruction_id)
  instruction.build_description(num_highlights=min_num_highlights)
  actual = instruction.check_following(response)
  self.assertEqual(actual, expected)
test_number_placeholders(template, num_placeholders, expected)

Test the number of placeholders.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
@parameterized.named_parameters(
    [
        {
            'testcase_name': (
                f'_templated={template}_num_placeholders={num_placeholders}'
                f'_expected={expected}'
            ),
            'template': template,
            'num_placeholders': num_placeholders,
            'expected': expected,
        }
        for template, num_placeholders, expected in [
            (('Sure, here is a short template with 5 placeholders:\n' +
              '[Name]\n[Email]\n[Phone]\n[Address]\n[Website]\n' +
              'This template can be used for a variety of purposes, such ' +
              'ascreating a contact list, sending out surveys, or creating ' +
              'a sign-up form.'), 5, True),
            (('My [adjective] [noun] is [adjective] [noun]. I [verb] and ' +
              '[verb].'), 7, False),
            ]
    ]
)
def test_number_placeholders(self, template, num_placeholders, expected):
  """Test the number of placeholders."""
  instruction_id = 'detectable_content:number_placeholders'
  instruction = instructions.PlaceholderChecker(instruction_id)
  instruction.build_description(num_placeholders=num_placeholders)
  actual = instruction.check_following(template)
  self.assertEqual(actual, expected)
test_number_sentences(response, relation, num_sentences, expected)

Test the number of sentences.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
@parameterized.named_parameters(
    [
        {
            'testcase_name': (
                f'_response={response}_relation={relation}'
                f'_num_sentences={num_sentences}_expected={expected}'
            ),
            'response': response,
            'relation': relation,
            'num_sentences': num_sentences,
            'expected': expected,
        }
        for response, relation, num_sentences, expected in [
            ('xx,x. xx,x! xx/x. x{x}x?', instructions._COMPARISON_RELATION[0],
             4, False),
            ('xxxx. xx,x! xxxx. x(x)x?', instructions._COMPARISON_RELATION[0],
             5, True),
            ('xxxx. xx,x! xx|x. x&x x?', instructions._COMPARISON_RELATION[1],
             4, True),
            ('xx-x. xx,x! xx}x. x,xx?', instructions._COMPARISON_RELATION[1],
             5, False),
        ]
    ]
)
def test_number_sentences(self, response, relation, num_sentences, expected):
  """Test the number of sentences."""
  instruction_id = 'length_constraints:number_sentences'
  instruction = instructions.NumberOfSentences(instruction_id)
  instruction.build_description(relation=relation,
                                num_sentences=num_sentences)
  actual = instruction.check_following(response)
  self.assertEqual(actual, expected)
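
A standalone sketch mirroring the third parameterized case above (import path assumed; _COMPARISON_RELATION[1] is read as "at least" based on the expected outcomes):

# Standalone sketch; import path assumed from the test module's location.
from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

checker = instructions.NumberOfSentences('length_constraints:number_sentences')
checker.build_description(
    relation=instructions._COMPARISON_RELATION[1],  # read as "at least", per the expected outcomes
    num_sentences=4,
)
print(checker.check_following('xxxx. xx,x! xx|x. x&x x?'))  # expected: True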
test_paragraph_checker()

Test the number of paragraphs.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_paragraph_checker(self):
  """Test the number of sections."""
  instruction_id = 'length_constraint:number_paragraphs'
  instruction = instructions.ParagraphChecker(instruction_id)
  num_paragraphs = 3
  instruction.build_description(num_paragraphs=num_paragraphs)
  with self.subTest(f'test {self.PARAGRAPH_TEST_MESSAGE_1} and '
                    f'{num_paragraphs} paragraphs'):
    self.assertTrue(instruction.check_following(
        self.PARAGRAPH_TEST_MESSAGE_1))

  num_paragraphs = 3
  instruction.build_description(num_paragraphs=num_paragraphs)
  with self.subTest(f'test {self.PARAGRAPH_TEST_MESSAGE_2} and '
                    f'{num_paragraphs} paragraphs'):
    self.assertTrue(instruction.check_following(
        self.PARAGRAPH_TEST_MESSAGE_2))

  num_paragraphs = 3
  instruction.build_description(num_paragraphs=num_paragraphs)
  with self.subTest(f'test {self.PARAGRAPH_TEST_MESSAGE_3} and '
                    f'{num_paragraphs} paragraphs'):
    self.assertTrue(instruction.check_following(
        self.PARAGRAPH_TEST_MESSAGE_3))

  num_paragraphs = 2
  instruction.build_description(num_paragraphs=num_paragraphs)
  with self.subTest(f'test {self.PARAGRAPH_TEST_MESSAGE_4} and '
                    f'{num_paragraphs} paragraphs'):
    self.assertFalse(instruction.check_following(
        self.PARAGRAPH_TEST_MESSAGE_4))
test_paragraph_first_word()

Test number of paragraphs and first word of nth paragraph.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_paragraph_first_word(self):
  """Test number of paragraphs and first word of nth paragraph."""
  instruction_id = 'length_constraints:nth_paragraph_first_word'
  instruction = instructions.ParagraphFirstWordCheck(instruction_id)
  tests = [
      self.PARAGRAPH_FIRST_WORD_TEST_1,
      self.PARAGRAPH_FIRST_WORD_TEST_2,
      self.PARAGRAPH_FIRST_WORD_TEST_3,
      self.PARAGRAPH_FIRST_WORD_TEST_4,
      self.PARAGRAPH_FIRST_WORD_TEST_5,
      self.PARAGRAPH_FIRST_WORD_TEST_6,
  ]

  for test in tests:
    if (test == self.PARAGRAPH_FIRST_WORD_TEST_1
        or test == self.PARAGRAPH_FIRST_WORD_TEST_2
        or test == self.PARAGRAPH_FIRST_WORD_TEST_3):
      num_paragraphs = 3
      nth_paragraph = 2
      first_word = 'I'
    elif test == self.PARAGRAPH_FIRST_WORD_TEST_4:
      num_paragraphs = 5
      nth_paragraph = 5
      first_word = 'haha'
    else:
      num_paragraphs = 5
      nth_paragraph = 3
      first_word = 'Really'

    instruction.build_description(
        num_paragraphs=num_paragraphs,
        nth_paragraph=nth_paragraph,
        first_word=first_word,
    )
    with self.subTest(
        f'test {test} \n. Test for '
        f'{num_paragraphs} paragraphs and '
        f'for paragraph {nth_paragraph} '
        f'{first_word} is first word'
    ):
      if (test == self.PARAGRAPH_FIRST_WORD_TEST_1
          or test == self.PARAGRAPH_FIRST_WORD_TEST_4
          or test == self.PARAGRAPH_FIRST_WORD_TEST_5):
        self.assertTrue(instruction.check_following(test))
      else:
        self.assertFalse(instruction.check_following(test))
test_postscript_checker()

Test the postscript checker.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_postscript_checker(self):
  """Test the postscript checker."""
  instruction_id = 'detectable_content:postscript'
  instruction = instructions.PostscriptChecker(instruction_id)
  postscript_start_keyword = instructions._POSTSCRIPT_MARKER[0]
  instruction.build_description(postscript_marker=postscript_start_keyword)
  with self.subTest(f'test {postscript_start_keyword}'):
    self.assertTrue(
        instruction.check_following(self.POSTSCRIPT_TEST_MESSAGE_1))

  postscript_start_keyword = 'PS:'
  instruction.build_description(postscript_marker=postscript_start_keyword)
  with self.subTest(f'test {postscript_start_keyword}'):
    self.assertFalse(
        instruction.check_following(self.POSTSCRIPT_TEST_MESSAGE_1))

  postscript_start_keyword = instructions._POSTSCRIPT_MARKER[1]
  instruction.build_description(postscript_marker=postscript_start_keyword)
  with self.subTest(f'test {postscript_start_keyword}'):
    self.assertTrue(
        instruction.check_following(self.POSTSCRIPT_TEST_MESSAGE_2))

  postscript_start_keyword = 'P.S.'
  instruction.build_description(postscript_marker=postscript_start_keyword)
  with self.subTest(f'test {postscript_start_keyword}'):
    self.assertTrue(
        instruction.check_following(self.POSTSCRIPT_TEST_MESSAGE_3))

  postscript_start_keyword = 'P.P.S'
  instruction.build_description(postscript_marker=postscript_start_keyword)
  with self.subTest(f'test {postscript_start_keyword}'):
    self.assertTrue(
        instruction.check_following(self.POSTSCRIPT_TEST_MESSAGE_4))

  postscript_start_keyword = 'P.S.'
  instruction.build_description(postscript_marker=postscript_start_keyword)
  with self.subTest(f'test {postscript_start_keyword}'):
    self.assertTrue(
        instruction.check_following(self.POSTSCRIPT_TEST_MESSAGE_5))
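
The accepted markers include entries of instructions._POSTSCRIPT_MARKER as well as literal strings such as 'P.S.' and 'P.P.S'; POSTSCRIPT_TEST_MESSAGE_3 further suggests the marker need not start a new line. A brief sketch (import path assumed, message shortened from that fixture):

# Brief sketch; import path assumed, message shortened from POSTSCRIPT_TEST_MESSAGE_3.
from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

checker = instructions.PostscriptChecker('detectable_content:postscript')
checker.build_description(postscript_marker='P.S.')
print(checker.check_following('The radius of a unit circle is 1. P.S. I hope you enjoyed my joke!'))  # expected: True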
test_prompt_repeat_answer()

Test that prompt is repeated then answered.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_prompt_repeat_answer(self):
  """Test that prompt is repeated then anwered."""
  instruction_id = 'combination:repeat_prompt'
  instruction = instructions.RepeatPromptThenAnswer(instruction_id)

  instruction.build_description(prompt_to_repeat=self.PROMPT_TO_REPEAT)
  with self.subTest(f'test {self.TEST_PROMPT_ANSWER_1}' +
                    f' with prompt: {self.TEST_PROMPT_1}'):
    self.assertTrue(instruction.check_following(self.TEST_PROMPT_ANSWER_1))

  with self.subTest(f'test {self.TEST_PROMPT_ANSWER_2}' +
                    f' with prompt: {self.TEST_PROMPT_1}'):
    self.assertFalse(instruction.check_following(self.TEST_PROMPT_ANSWER_2))
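
The passing and failing responses suggest that the response must begin with the exact prompt text before the answer; this is inferred from the assertions, not documented. A direct-usage sketch (import path assumed, responses shortened from the fixtures above):

# Direct-usage sketch; import path assumed from the test module's location.
from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions

checker = instructions.RepeatPromptThenAnswer('combination:repeat_prompt')
checker.build_description(prompt_to_repeat='Write a CL description.')
print(checker.check_following('Write a CL description. Hi, Le and TJ, please check this out.'))  # expected: True
print(checker.check_following('Hi, Le and TJ. Write a CL description. Thanks.'))                 # expected: False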
test_quotation()
Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_quotation(self):
  instruction_id = 'startend:quotation'
  instruction = instructions.QuotationChecker(instruction_id)
  instruction.build_description()
  with self.subTest(f'test {self.TEST_QUOTATION_MESSAGE_1}'):
    self.assertTrue(
        instruction.check_following(self.TEST_QUOTATION_MESSAGE_1)
    )
  with self.subTest(f'test {self.TEST_QUOTATION_MESSAGE_2}'):
    self.assertFalse(
        instruction.check_following(self.TEST_QUOTATION_MESSAGE_2)
    )
test_rephrase_checker()

Test the rephrase checker.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_rephrase_checker(self):
  """Test the rephrase checker."""
  instruction_id = 'detectable_format:rephrasing'
  instruction = instructions.RephraseChecker(instruction_id)
  instruction.build_description(
      original_message=self.REPHRASE_TEST_ORIGINAL_MESSAGE_1)
  with self.subTest(f'test {self.REPHRASE_TEST_REPHRASED_MESSAGE_1}'):
    self.assertTrue(
        instruction.check_following(self.REPHRASE_TEST_REPHRASED_MESSAGE_1))

  instruction.build_description(
      original_message=self.REPHRASE_TEST_ORIGINAL_MESSAGE_1)
  with self.subTest(
      f'test {self.REPHRASE_TEST_REPHRASED_MESSAGE_1_NOCHANGE}'):
    with self.assertRaises(ValueError):
      instruction.check_following(
          self.REPHRASE_TEST_REPHRASED_MESSAGE_1_NOCHANGE)

  instruction.build_description(
      original_message=self.REPHRASE_TEST_ORIGINAL_MESSAGE_1)
  with self.subTest(f'test {self.REPHRASE_TEST_REPHRASED_MESSAGE_1_FORMAT}'):
    with self.assertRaises(ValueError):
      instruction.check_following(
          self.REPHRASE_TEST_REPHRASED_MESSAGE_1_FORMAT)

  instruction.build_description(
      original_message=self.REPHRASE_TEST_ORIGINAL_MESSAGE_2)
  with self.subTest(f'test {self.REPHRASE_TEST_REPHRASED_MESSAGE_2}'):
    self.assertFalse(
        instruction.check_following(self.REPHRASE_TEST_REPHRASED_MESSAGE_2))
test_rephrase_paragraph()

Test the rephrasing of a paragraph.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_rephrase_paragraph(self):
  """Test the rephrasing of paragraph."""
  instruction_id = 'detectable_content:rephrase_paragraph'
  instruction = instructions.RephraseParagraph(instruction_id)
  low, high = 20, 30
  instruction.build_description(
      low=low, high=high, original_paragraph=self.TEST_ORIGINAL_PARAGRAPH_1)

  with self.subTest(f'test {self.TEST_ORIGINAL_PARAGRAPH_1} to ' +
                    f'have between {low} and {high} same words.'):
    self.assertTrue(
        instruction.check_following(self.TEST_REPHRASED_PARAGRAPH_1))

  low, high = 20, 25
  instruction.build_description(
      low=low, high=high, original_paragraph=self.TEST_ORIGINAL_PARAGRAPH_1)

  with self.subTest(f'test {self.TEST_ORIGINAL_PARAGRAPH_1} to ' +
                    f'have between {low} and {high} same words.'):
    self.assertTrue(
        instruction.check_following(self.TEST_REPHRASED_PARAGRAPH_2))

  low, high = 15, 20
  instruction.build_description(
      low=low, high=high, original_paragraph=self.TEST_ORIGINAL_PARAGRAPH_2)

  with self.subTest(f'test {self.TEST_ORIGINAL_PARAGRAPH_2} to ' +
                    f'have between {low} and {high} same words.'):
    self.assertFalse(
        instruction.check_following(self.TEST_REPHRASED_PARAGRAPH_3))

  low, high = 0, 5
  instruction.build_description(
      low=low, high=high, original_paragraph=self.TEST_ORIGINAL_PARAGRAPH_2)

  with self.subTest(f'test {self.TEST_ORIGINAL_PARAGRAPH_2} to ' +
                    f'have between {low} and {high} same words.'):
    self.assertTrue(
        instruction.check_following(self.TEST_REPHRASED_PARAGRAPH_4))

  low, high = 1, 5
  instruction.build_description(
      low=low, high=high, original_paragraph=self.TEST_ORIGINAL_PARAGRAPH_2)

  with self.subTest(f'test {self.TEST_ORIGINAL_PARAGRAPH_2} to ' +
                    f'have between {low} and {high} same words.'):
    self.assertFalse(
        instruction.check_following(self.TEST_REPHRASED_PARAGRAPH_4))
test_response_language(response, language)

Test on single language response.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
@parameterized.named_parameters(
    [
        {
            'testcase_name': (
                f'_response={response}_language={language}'
            ),
            'response': response,
            'language': language,
        }
        for response, language in [('The response is English', 'en')]
    ]
)
def test_response_language(self, response, language):
  """Test on single language response."""
  instruction_id = 'language:response_language'
  instruction = instructions.ResponseLanguageChecker(instruction_id)
  instruction.build_description(language=language)
  self.assertTrue(instruction.check_following(response))
test_response_multilanguage(response, language)

Test on responses that contain multi-language tokens.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
@parameterized.named_parameters(
    [
        {
            'testcase_name': (
                f'_response={response}_language={language}'
            ),
            'response': response,
            'language': language,
        }
        for response, language in [("Desayunamos en McDonald's hoy", 'es'),
                                   ('Today we visit the Louvre', 'en'),]
    ]
)
def test_response_multilanguage(self, response, language):
  """Test on responses that contain multi-language tokens."""
  instruction_id = 'language:response_language'
  instruction = instructions.ResponseLanguageChecker(instruction_id)
  instruction.build_description(language=language)
  self.assertTrue(instruction.check_following(response))
test_section_checker()

Test the number of sections.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_section_checker(self):
  """Test the number of sections."""
  instruction_id = 'detectable_format:multiple_sections'
  instruction = instructions.SectionChecker(instruction_id)
  section_keyword = 'Section'
  min_num_sections = 3
  instruction.build_description(section_spliter=section_keyword,
                                num_sections=min_num_sections)
  with self.subTest(f'test {section_keyword} and {min_num_sections}'):
    self.assertFalse(
        instruction.check_following(self.SECTION_TEST_MESSAGE_1))

  section_keyword = 'SECTION'
  min_num_sections = 2
  instruction.build_description(section_spliter=section_keyword,
                                num_sections=min_num_sections)
  with self.subTest(f'test {section_keyword} and {min_num_sections}'):
    self.assertTrue(
        instruction.check_following(self.SECTION_TEST_MESSAGE_2))
test_title_checker()

Check the prompt for a title.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_title_checker(self):
  """Check the prompt for a title."""
  instruction_id = 'detectable_format:title'
  instruction = instructions.TitleChecker(instruction_id)
  instruction.build_description()
  with self.subTest(f'test {self.TEST_TITLE_MESSAGE_1}'):
    self.assertTrue(instruction.check_following(self.TEST_TITLE_MESSAGE_1))
  with self.subTest(f'test {self.TEST_TITLE_MESSAGE_2}'):
    self.assertTrue(instruction.check_following(self.TEST_TITLE_MESSAGE_2))

  with self.subTest(f'test {self.TEST_TITLE_MESSAGE_3}'):
    self.assertFalse(instruction.check_following(self.TEST_TITLE_MESSAGE_3))
  with self.subTest(f'test {self.TEST_TITLE_MESSAGE_4}'):
    self.assertFalse(instruction.check_following(self.TEST_TITLE_MESSAGE_4))
test_two_responses()

Test that two responses are given.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_test.py
def test_two_responses(self):
  """Test that two responses are given."""
  instruction_id = 'combination:two_responses'
  instruction = instructions.TwoResponsesChecker(instruction_id)
  instruction.build_description()

  with self.subTest(f'test {self.TEST_TWO_RESPONSES_1}'):
    self.assertTrue(instruction.check_following(self.TEST_TWO_RESPONSES_1))

  with self.subTest(f'test {self.TEST_TWO_RESPONSES_2}'):
    self.assertFalse(instruction.check_following(self.TEST_TWO_RESPONSES_2))

  with self.subTest(f'test {self.TEST_TWO_RESPONSES_3}'):
    self.assertFalse(instruction.check_following(self.TEST_TWO_RESPONSES_3))

  with self.subTest(f'test {self.TEST_TWO_RESPONSES_4}'):
    self.assertFalse(instruction.check_following(self.TEST_TWO_RESPONSES_4))

  with self.subTest(f'test {self.TEST_TWO_RESPONSES_5}'):
    self.assertTrue(instruction.check_following(self.TEST_TWO_RESPONSES_5))
instructions_util

Utility library of instructions.

LANGUAGE_CODES = immutabledict.immutabledict({'en': 'English', 'es': 'Spanish', 'pt': 'Portuguese', 'ar': 'Arabic', 'hi': 'Hindi', 'fr': 'French', 'ru': 'Russian', 'de': 'German', 'ja': 'Japanese', 'it': 'Italian', 'bn': 'Bengali', 'uk': 'Ukrainian', 'th': 'Thai', 'ur': 'Urdu', 'ta': 'Tamil', 'te': 'Telugu', 'bg': 'Bulgarian', 'ko': 'Korean', 'pl': 'Polish', 'he': 'Hebrew', 'fa': 'Persian', 'vi': 'Vietnamese', 'ne': 'Nepali', 'sw': 'Swahili', 'kn': 'Kannada', 'mr': 'Marathi', 'gu': 'Gujarati', 'pa': 'Punjabi', 'ml': 'Malayalam', 'fi': 'Finnish'}) module-attribute
WORD_LIST = ['western', 'sentence', 'signal', 'dump', 'spot', 'opposite', 'bottom', 'potato', 'administration', 'working', 'welcome', 'morning', 'good', 'agency', 'primary', 'wish', 'responsibility', 'press', 'problem', 'president', 'steal', 'brush', 'read', 'type', 'beat', 'trainer', 'growth', 'lock', 'bone', 'case', 'equal', 'comfortable', 'region', 'replacement', 'performance', 'mate', 'walk', 'medicine', 'film', 'thing', 'rock', 'tap', 'total', 'competition', 'ease', 'south', 'establishment', 'gather', 'parking', 'world', 'plenty', 'breath', 'claim', 'alcohol', 'trade', 'dear', 'highlight', 'street', 'matter', 'decision', 'mess', 'agreement', 'studio', 'coach', 'assist', 'brain', 'wing', 'style', 'private', 'top', 'brown', 'leg', 'buy', 'procedure', 'method', 'speed', 'high', 'company', 'valuable', 'pie', 'analyst', 'session', 'pattern', 'district', 'pleasure', 'dinner', 'swimming', 'joke', 'order', 'plate', 'department', 'motor', 'cell', 'spend', 'cabinet', 'difference', 'power', 'examination', 'engine', 'horse', 'dimension', 'pay', 'toe', 'curve', 'literature', 'bother', 'fire', 'possibility', 'debate', 'activity', 'passage', 'hello', 'cycle', 'background', 'quiet', 'author', 'effect', 'actor', 'page', 'bicycle', 'error', 'throat', 'attack', 'character', 'phone', 'tea', 'increase', 'outcome', 'file', 'specific', 'inspector', 'internal', 'potential', 'staff', 'building', 'employer', 'shoe', 'hand', 'direction', 'garden', 'purchase', 'interview', 'study', 'recognition', 'member', 'spiritual', 'oven', 'sandwich', 'weird', 'passenger', 'particular', 'response', 'reaction', 'size', 'variation', 'a', 'cancel', 'candy', 'exit', 'guest', 'condition', 'fly', 'price', 'weakness', 'convert', 'hotel', 'great', 'mouth', 'mind', 'song', 'sugar', 'suspect', 'telephone', 'ear', 'roof', 'paint', 'refrigerator', 'organization', 'jury', 'reward', 'engineering', 'day', 'possession', 'crew', 'bar', 'road', 'description', 'celebration', 'score', 'mark', 'letter', 'shower', 'suggestion', 'sir', 'luck', 'national', 'progress', 'hall', 'stroke', 'theory', 'offer', 'story', 'tax', 'definition', 'history', 'ride', 'medium', 'opening', 'glass', 'elevator', 'stomach', 'question', 'ability', 'leading', 'village', 'computer', 'city', 'grand', 'confidence', 'candle', 'priest', 'recommendation', 'point', 'necessary', 'body', 'desk', 'secret', 'horror', 'noise', 'culture', 'warning', 'water', 'round', 'diet', 'flower', 'bus', 'tough', 'permission', 'week', 'prompt', 'connection', 'abuse', 'height', 'save', 'corner', 'border', 'stress', 'drive', 'stop', 'rip', 'meal', 'listen', 'confusion', 'girlfriend', 'living', 'relation', 'significance', 'plan', 'creative', 'atmosphere', 'blame', 'invite', 'housing', 'paper', 'drink', 'roll', 'silver', 'drunk', 'age', 'damage', 'smoke', 'environment', 'pack', 'savings', 'influence', 'tourist', 'rain', 'post', 'sign', 'grandmother', 'run', 'profit', 'push', 'clerk', 'final', 'wine', 'swim', 'pause', 'stuff', 'singer', 'funeral', 'average', 'source', 'scene', 'tradition', 'personal', 'snow', 'nobody', 'distance', 'sort', 'sensitive', 'animal', 'major', 'negotiation', 'click', 'mood', 'period', 'arrival', 'expression', 'holiday', 'repeat', 'dust', 'closet', 'gold', 'bad', 'sail', 'combination', 'clothes', 'emphasis', 'duty', 'black', 'step', 'school', 'jump', 'document', 'professional', 'lip', 'chemical', 'front', 'wake', 'while', 'inside', 'watch', 'row', 'subject', 'penalty', 'balance', 'possible', 'adult', 'aside', 'sample', 'appeal', 'wedding', 'depth', 'king', 'award', 'wife', 
'blow', 'site', 'camp', 'music', 'safe', 'gift', 'fault', 'guess', 'act', 'shame', 'drama', 'capital', 'exam', 'stupid', 'record', 'sound', 'swing', 'novel', 'minimum', 'ratio', 'machine', 'shape', 'lead', 'operation', 'salary', 'cloud', 'affair', 'hit', 'chapter', 'stage', 'quantity', 'access', 'army', 'chain', 'traffic', 'kick', 'analysis', 'airport', 'time', 'vacation', 'philosophy', 'ball', 'chest', 'thanks', 'place', 'mountain', 'advertising', 'red', 'past', 'rent', 'return', 'tour', 'house', 'construction', 'net', 'native', 'war', 'figure', 'fee', 'spray', 'user', 'dirt', 'shot', 'task', 'stick', 'friend', 'software', 'promotion', 'interaction', 'surround', 'block', 'purpose', 'practice', 'conflict', 'routine', 'requirement', 'bonus', 'hole', 'state', 'junior', 'sweet', 'catch', 'tear', 'fold', 'wall', 'editor', 'life', 'position', 'pound', 'respect', 'bathroom', 'coat', 'script', 'job', 'teach', 'birth', 'view', 'resolve', 'theme', 'employee', 'doubt', 'market', 'education', 'serve', 'recover', 'tone', 'harm', 'miss', 'union', 'understanding', 'cow', 'river', 'association', 'concept', 'training', 'recipe', 'relationship', 'reserve', 'depression', 'proof', 'hair', 'revenue', 'independent', 'lift', 'assignment', 'temporary', 'amount', 'loss', 'edge', 'track', 'check', 'rope', 'estimate', 'pollution', 'stable', 'message', 'delivery', 'perspective', 'mirror', 'assistant', 'representative', 'witness', 'nature', 'judge', 'fruit', 'tip', 'devil', 'town', 'emergency', 'upper', 'drop', 'stay', 'human', 'neck', 'speaker', 'network', 'sing', 'resist', 'league', 'trip', 'signature', 'lawyer', 'importance', 'gas', 'choice', 'engineer', 'success', 'part', 'external', 'worker', 'simple', 'quarter', 'student', 'heart', 'pass', 'spite', 'shift', 'rough', 'lady', 'grass', 'community', 'garage', 'youth', 'standard', 'skirt', 'promise', 'blind', 'television', 'disease', 'commission', 'positive', 'energy', 'calm', 'presence', 'tune', 'basis', 'preference', 'head', 'generic', 'cut', 'somewhere', 'presentation', 'current', 'thought', 'revolution', 'effort', 'master', 'implement', 'republic', 'floor', 'principle', 'stranger', 'shoulder', 'grade', 'button', 'tennis', 'police', 'collection', 'account', 'register', 'glove', 'divide', 'professor', 'chair', 'priority', 'combine', 'peace', 'extension', 'maybe', 'evening', 'frame', 'sister', 'wave', 'code', 'application', 'mouse', 'match', 'counter', 'bottle', 'half', 'cheek', 'resolution', 'back', 'knowledge', 'make', 'discussion', 'screw', 'length', 'accident', 'battle', 'dress', 'knee', 'log', 'package', 'it', 'turn', 'hearing', 'newspaper', 'layer', 'wealth', 'profile', 'imagination', 'answer', 'weekend', 'teacher', 'appearance', 'meet', 'bike', 'rise', 'belt', 'crash', 'bowl', 'equivalent', 'support', 'image', 'poem', 'risk', 'excitement', 'remote', 'secretary', 'public', 'produce', 'plane', 'display', 'money', 'sand', 'situation', 'punch', 'customer', 'title', 'shake', 'mortgage', 'option', 'number', 'pop', 'window', 'extent', 'nothing', 'experience', 'opinion', 'departure', 'dance', 'indication', 'boy', 'material', 'band', 'leader', 'sun', 'beautiful', 'muscle', 'farmer', 'variety', 'fat', 'handle', 'director', 'opportunity', 'calendar', 'outside', 'pace', 'bath', 'fish', 'consequence', 'put', 'owner', 'go', 'doctor', 'information', 'share', 'hurt', 'protection', 'career', 'finance', 'force', 'golf', 'garbage', 'aspect', 'kid', 'food', 'boot', 'milk', 'respond', 'objective', 'reality', 'raw', 'ring', 'mall', 'one', 'impact', 'area', 'news', 
'international', 'series', 'impress', 'mother', 'shelter', 'strike', 'loan', 'month', 'seat', 'anything', 'entertainment', 'familiar', 'clue', 'year', 'glad', 'supermarket', 'natural', 'god', 'cost', 'conversation', 'tie', 'ruin', 'comfort', 'earth', 'storm', 'percentage', 'assistance', 'budget', 'strength', 'beginning', 'sleep', 'other', 'young', 'unit', 'fill', 'store', 'desire', 'hide', 'value', 'cup', 'maintenance', 'nurse', 'function', 'tower', 'role', 'class', 'camera', 'database', 'panic', 'nation', 'basket', 'ice', 'art', 'spirit', 'chart', 'exchange', 'feedback', 'statement', 'reputation', 'search', 'hunt', 'exercise', 'nasty', 'notice', 'male', 'yard', 'annual', 'collar', 'date', 'platform', 'plant', 'fortune', 'passion', 'friendship', 'spread', 'cancer', 'ticket', 'attitude', 'island', 'active', 'object', 'service', 'buyer', 'bite', 'card', 'face', 'steak', 'proposal', 'patient', 'heat', 'rule', 'resident', 'broad', 'politics', 'west', 'knife', 'expert', 'girl', 'design', 'salt', 'baseball', 'grab', 'inspection', 'cousin', 'couple', 'magazine', 'cook', 'dependent', 'security', 'chicken', 'version', 'currency', 'ladder', 'scheme', 'kitchen', 'employment', 'local', 'attention', 'manager', 'fact', 'cover', 'sad', 'guard', 'relative', 'county', 'rate', 'lunch', 'program', 'initiative', 'gear', 'bridge', 'breast', 'talk', 'dish', 'guarantee', 'beer', 'vehicle', 'reception', 'woman', 'substance', 'copy', 'lecture', 'advantage', 'park', 'cold', 'death', 'mix', 'hold', 'scale', 'tomorrow', 'blood', 'request', 'green', 'cookie', 'church', 'strip', 'forever', 'beyond', 'debt', 'tackle', 'wash', 'following', 'feel', 'maximum', 'sector', 'sea', 'property', 'economics', 'menu', 'bench', 'try', 'language', 'start', 'call', 'solid', 'address', 'income', 'foot', 'senior', 'honey', 'few', 'mixture', 'cash', 'grocery', 'link', 'map', 'form', 'factor', 'pot', 'model', 'writer', 'farm', 'winter', 'skill', 'anywhere', 'birthday', 'policy', 'release', 'husband', 'lab', 'hurry', 'mail', 'equipment', 'sink', 'pair', 'driver', 'consideration', 'leather', 'skin', 'blue', 'boat', 'sale', 'brick', 'two', 'feed', 'square', 'dot', 'rush', 'dream', 'location', 'afternoon', 'manufacturer', 'control', 'occasion', 'trouble', 'introduction', 'advice', 'bet', 'eat', 'kill', 'category', 'manner', 'office', 'estate', 'pride', 'awareness', 'slip', 'crack', 'client', 'nail', 'shoot', 'membership', 'soft', 'anybody', 'web', 'official', 'individual', 'pizza', 'interest', 'bag', 'spell', 'profession', 'queen', 'deal', 'resource', 'ship', 'guy', 'chocolate', 'joint', 'formal', 'upstairs', 'car', 'resort', 'abroad', 'dealer', 'associate', 'finger', 'surgery', 'comment', 'team', 'detail', 'crazy', 'path', 'tale', 'initial', 'arm', 'radio', 'demand', 'single', 'draw', 'yellow', 'contest', 'piece', 'quote', 'pull', 'commercial', 'shirt', 'contribution', 'cream', 'channel', 'suit', 'discipline', 'instruction', 'concert', 'speech', 'low', 'effective', 'hang', 'scratch', 'industry', 'breakfast', 'lay', 'join', 'metal', 'bedroom', 'minute', 'product', 'rest', 'temperature', 'many', 'give', 'argument', 'print', 'purple', 'laugh', 'health', 'credit', 'investment', 'sell', 'setting', 'lesson', 'egg', 'middle', 'marriage', 'level', 'evidence', 'phrase', 'love', 'self', 'benefit', 'guidance', 'affect', 'you', 'dad', 'anxiety', 'special', 'boyfriend', 'test', 'blank', 'payment', 'soup', 'obligation', 'reply', 'smile', 'deep', 'complaint', 'addition', 'review', 'box', 'towel', 'minor', 'fun', 'soil', 'issue', 'cigarette', 'internet', 
'gain', 'tell', 'entry', 'spare', 'incident', 'family', 'refuse', 'branch', 'can', 'pen', 'grandfather', 'constant', 'tank', 'uncle', 'climate', 'ground', 'volume', 'communication', 'kind', 'poet', 'child', 'screen', 'mine', 'quit', 'gene', 'lack', 'charity', 'memory', 'tooth', 'fear', 'mention', 'marketing', 'reveal', 'reason', 'court', 'season', 'freedom', 'land', 'sport', 'audience', 'classroom', 'law', 'hook', 'win', 'carry', 'eye', 'smell', 'distribution', 'research', 'country', 'dare', 'hope', 'whereas', 'stretch', 'library', 'if', 'delay', 'college', 'plastic', 'book', 'present', 'use', 'worry', 'champion', 'goal', 'economy', 'march', 'election', 'reflection', 'midnight', 'slide', 'inflation', 'action', 'challenge', 'guitar', 'coast', 'apple', 'campaign', 'field', 'jacket', 'sense', 'way', 'visual', 'remove', 'weather', 'trash', 'cable', 'regret', 'buddy', 'beach', 'historian', 'courage', 'sympathy', 'truck', 'tension', 'permit', 'nose', 'bed', 'son', 'person', 'base', 'meat', 'usual', 'air', 'meeting', 'worth', 'game', 'independence', 'physical', 'brief', 'play', 'raise', 'board', 'she', 'key', 'writing', 'pick', 'command', 'party', 'yesterday', 'spring', 'candidate', 'physics', 'university', 'concern', 'development', 'change', 'string', 'target', 'instance', 'room', 'bitter', 'bird', 'football', 'normal', 'split', 'impression', 'wood', 'long', 'meaning', 'stock', 'cap', 'leadership', 'media', 'ambition', 'fishing', 'essay', 'salad', 'repair', 'today', 'designer', 'night', 'bank', 'drawing', 'inevitable', 'phase', 'vast', 'chip', 'anger', 'switch', 'cry', 'twist', 'personality', 'attempt', 'storage', 'being', 'preparation', 'bat', 'selection', 'white', 'technology', 'contract', 'side', 'section', 'station', 'till', 'structure', 'tongue', 'taste', 'truth', 'difficulty', 'group', 'limit', 'main', 'move', 'feeling', 'light', 'example', 'mission', 'might', 'wait', 'wheel', 'shop', 'host', 'classic', 'alternative', 'cause', 'agent', 'consist', 'table', 'airline', 'text', 'pool', 'craft', 'range', 'fuel', 'tool', 'partner', 'load', 'entrance', 'deposit', 'hate', 'article', 'video', 'summer', 'feature', 'extreme', 'mobile', 'hospital', 'flight', 'fall', 'pension', 'piano', 'fail', 'result', 'rub', 'gap', 'system', 'report', 'suck', 'ordinary', 'wind', 'nerve', 'ask', 'shine', 'note', 'line', 'mom', 'perception', 'brother', 'reference', 'bend', 'charge', 'treat', 'trick', 'term', 'homework', 'bake', 'bid', 'status', 'project', 'strategy', 'orange', 'let', 'enthusiasm', 'parent', 'concentrate', 'device', 'travel', 'poetry', 'business', 'society', 'kiss', 'end', 'vegetable', 'employ', 'schedule', 'hour', 'brave', 'focus', 'process', 'movie', 'illegal', 'general', 'coffee', 'ad', 'highway', 'chemistry', 'psychology', 'hire', 'bell', 'conference', 'relief', 'show', 'neat', 'funny', 'weight', 'quality', 'club', 'daughter', 'zone', 'touch', 'tonight', 'shock', 'burn', 'excuse', 'name', 'survey', 'landscape', 'advance', 'satisfaction', 'bread', 'disaster', 'item', 'hat', 'prior', 'shopping', 'visit', 'east', 'photo', 'home', 'idea', 'father', 'comparison', 'cat', 'pipe', 'winner', 'count', 'lake', 'fight', 'prize', 'foundation', 'dog', 'keep', 'ideal', 'fan', 'struggle', 'peak', 'safety', 'solution', 'hell', 'conclusion', 'population', 'strain', 'alarm', 'measurement', 'second', 'train', 'race', 'due', 'insurance', 'boss', 'tree', 'monitor', 'sick', 'course', 'drag', 'appointment', 'slice', 'still', 'care', 'patience', 'rich', 'escape', 'emotion', 'royal', 'female', 'childhood', 'government', 
'picture', 'will', 'sock', 'big', 'gate', 'oil', 'cross', 'pin', 'improvement', 'championship', 'silly', 'help', 'sky', 'pitch', 'man', 'diamond', 'most', 'transition', 'work', 'science', 'committee', 'moment', 'fix', 'teaching', 'dig', 'specialist', 'complex', 'guide', 'people', 'dead', 'voice', 'original', 'break', 'topic', 'data', 'degree', 'reading', 'recording', 'bunch', 'reach', 'judgment', 'lie', 'regular', 'set', 'painting', 'mode', 'list', 'player', 'bear', 'north', 'wonder', 'carpet', 'heavy', 'officer', 'negative', 'clock', 'unique', 'baby', 'pain', 'assumption', 'disk', 'iron', 'bill', 'drawer', 'look', 'double', 'mistake', 'finish', 'future', 'brilliant', 'contact', 'math', 'rice', 'leave', 'restaurant', 'discount', 'sex', 'virus', 'bit', 'trust', 'event', 'wear', 'juice', 'failure', 'bug', 'context', 'mud', 'whole', 'wrap', 'intention', 'draft', 'pressure', 'cake', 'dark', 'explanation', 'space', 'angle', 'word', 'efficiency', 'management', 'habit', 'star', 'chance', 'finding', 'transportation', 'stand', 'criticism', 'flow', 'door', 'injury', 'insect', 'surprise', 'apartment'] module-attribute
count_sentences(text)

Count the number of sentences.
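
For example, mirroring two of the parameterized cases in instructions_util_test.py (the import path below is inferred from the source location shown here):

from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions_util

# Expected counts come from the test fixtures in instructions_util_test.py.
assert instructions_util.count_sentences("xx,x. xx,x! xx/x. x{x}x? x.") == 5
assert instructions_util.count_sentences("xx,x! xxxx. x(x)x?") == 3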

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_util.py
def count_sentences(text):
  """Count the number of sentences."""
  tokenizer = _get_sentence_tokenizer()
  tokenized_sentences = tokenizer.tokenize(text)
  return len(tokenized_sentences)
count_words(text)

Counts the number of words.
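
A quick sanity check based on the word-count fixtures in instructions_util_test.py (import path inferred from the source location shown here):

from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions_util

assert instructions_util.count_words("word1, word2, word3, word4.") == 4
# The \w+ tokenizer splits on hyphens, so "Hyphenated-word" contributes two words.
assert instructions_util.count_words("Hyphenated-word has two word counts.") == 6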

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_util.py
def count_words(text):
  """Counts the number of words."""
  tokenizer = nltk.tokenize.RegexpTokenizer(r"\w+")
  tokens = tokenizer.tokenize(text)
  num_words = len(tokens)
  return num_words
generate_keywords(num_keywords)

Randomly generates a few keywords.
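
A short usage sketch (import path inferred from the source location shown here); the keywords are sampled without replacement from WORD_LIST:

from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions_util

keywords = instructions_util.generate_keywords(num_keywords=5)
assert len(keywords) == 5
assert all(word in instructions_util.WORD_LIST for word in keywords)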

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_util.py
def generate_keywords(num_keywords):
  """Randomly generates a few keywords."""
  return random.sample(WORD_LIST, k=num_keywords)
split_into_sentences(text)

Split the text into sentences.

Parameters:

Name Type Description Default
text

A string consisting of one or more sentences.

required

Returns:

Type Description

A list of strings where each string is a sentence.
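
Example, mirroring the first fixture in instructions_util_test.py (import path inferred from the source location shown below):

from aisteer360.evaluation.metrics.custom.instruction_following.helpers import instructions_util

sentences = instructions_util.split_into_sentences(
    "Google is a technology company. It was founded in 1998 by Larry Page and Sergey Brin."
)
# -> ['Google is a technology company.',
#     'It was founded in 1998 by Larry Page and Sergey Brin.']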

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_util.py
def split_into_sentences(text):
  """Split the text into sentences.

  Args:
    text: A string consisting of one or more sentences.

  Returns:
    A list of strings where each string is a sentence.
  """
  text = " " + text + "  "
  text = text.replace("\n", " ")
  text = re.sub(_PREFIXES, "\\1<prd>", text)
  text = re.sub(_WEBSITES, "<prd>\\1", text)
  text = re.sub(_DIGITS + "[.]" + _DIGITS, "\\1<prd>\\2", text)
  text = re.sub(
      _MULTIPLE_DOTS,
      lambda match: "<prd>" * len(match.group(0)) + "<stop>",
      text,
  )
  if "Ph.D" in text:
    text = text.replace("Ph.D.", "Ph<prd>D<prd>")
  text = re.sub(r"\s" + _ALPHABETS + "[.] ", " \\1<prd> ", text)
  text = re.sub(_ACRONYMS + " " + _STARTERS, "\\1<stop> \\2", text)
  text = re.sub(
      _ALPHABETS + "[.]" + _ALPHABETS + "[.]" + _ALPHABETS + "[.]",
      "\\1<prd>\\2<prd>\\3<prd>",
      text,
  )
  text = re.sub(
      _ALPHABETS + "[.]" + _ALPHABETS + "[.]", "\\1<prd>\\2<prd>", text
  )
  text = re.sub(" " + _SUFFIXES + "[.] " + _STARTERS, " \\1<stop> \\2", text)
  text = re.sub(" " + _SUFFIXES + "[.]", " \\1<prd>", text)
  text = re.sub(" " + _ALPHABETS + "[.]", " \\1<prd>", text)
  if "”" in text:
    text = text.replace(".”", "”.")
  if '"' in text:
    text = text.replace('."', '".')
  if "!" in text:
    text = text.replace('!"', '"!')
  if "?" in text:
    text = text.replace('?"', '"?')
  text = text.replace(".", ".<stop>")
  text = text.replace("?", "?<stop>")
  text = text.replace("!", "!<stop>")
  text = text.replace("<prd>", ".")
  sentences = text.split("<stop>")
  sentences = [s.strip() for s in sentences]
  if sentences and not sentences[-1]:
    sentences = sentences[:-1]
  return sentences
instructions_util_test

Test for utility library of instructions.

InstructionsUtilTest

Bases: TestCase

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_util_test.py
class InstructionsUtilTest(parameterized.TestCase):

  TEST_WORD_COUNT_CASE_1 = ("word1, word2, word3, word4.", 4)

  TEST_WORD_COUNT_CASE_2 = (
      """
      Bard can you tell me which is the best optimization method for the
      transition from an hydro-thermal system to an hydro-renewables system""",
      24)

  TEST_WORD_COUNT_CASE_3 = (
      """
      Hyphenated-word has two word counts.
      """, 6)

  def test_word_count(self):
    """Tests word counter."""
    with self.subTest(f"{self.TEST_WORD_COUNT_CASE_1[0]}"):
      text, expected_num_words = self.TEST_WORD_COUNT_CASE_1
      actual_num_words = instructions_util.count_words(text)
      self.assertEqual(expected_num_words, actual_num_words)

    with self.subTest(f"{self.TEST_WORD_COUNT_CASE_2[0]}"):
      text, expected_num_words = self.TEST_WORD_COUNT_CASE_2
      actual_num_words = instructions_util.count_words(text)
      self.assertEqual(expected_num_words, actual_num_words)

    with self.subTest(f"{self.TEST_WORD_COUNT_CASE_3[0]}"):
      text, expected_num_words = self.TEST_WORD_COUNT_CASE_3
      actual_num_words = instructions_util.count_words(text)
      self.assertEqual(expected_num_words, actual_num_words)

  @parameterized.named_parameters(
      [
          {  # pylint: disable=g-complex-comprehension
              "testcase_name": (
                  f"_response={response}_num_sentences={num_sentences}"
              ),
              "response": response,
              "num_sentences": num_sentences,
          }
          for response, num_sentences in [
              ("xx,x. xx,x! xx/x. x{x}x? x.", 5),
              ("xx,x! xxxx. x(x)x?", 3),
              ("xxxx. xx,x! xx|x. x&x x?", 4),
              ("xx-x]xx,x! x{x}xx,x.", 2),
          ]
      ]
  )
  def test_count_sentences(self, response, num_sentences):
    """Tests sentence counter."""
    actual_num_sentences = instructions_util.count_sentences(response)
    self.assertEqual(num_sentences, actual_num_sentences)

  TEST_SENTENCE_SPLIT_1 = """
  Google is a technology company. It was founded in 1998 by Larry Page
and Sergey Brin. Google's mission is to organize the world's information
and make it universally accessible and useful.
  """

  TEST_SENTENCE_SPLIT_2 = """
  The U.S.A has many Ph.D. students. They will often haven a .com website
sharing the research that they have done.
  """

  EXPECTED_SENTENCE_SPLIT_1 = [
      "Google is a technology company.",
      "It was founded in 1998 by Larry Page and Sergey Brin.",
      (
          "Google's mission is to organize the world's information and make it"
          " universally accessible and useful."
      ),
  ]

  EXPECTED_SENTENCE_SPLIT_2 = [
      "The U.S.A has many Ph.D. students.",
      (
          "They will often haven a .com website sharing the research that they"
          " have done."
      ),
  ]

  def test_sentence_splitter(self):
    """Tests sentence splitter."""
    sentence_split_1 = instructions_util.split_into_sentences(
        self.TEST_SENTENCE_SPLIT_1
    )
    sentence_split_2 = instructions_util.split_into_sentences(
        self.TEST_SENTENCE_SPLIT_2
    )

    self.assertEqual(self.EXPECTED_SENTENCE_SPLIT_1, sentence_split_1)
    self.assertEqual(self.EXPECTED_SENTENCE_SPLIT_2, sentence_split_2)

  def test_generate_keywords(self):
    """Tests generate keywords."""
    self.assertLen(instructions_util.generate_keywords(10), 10)
EXPECTED_SENTENCE_SPLIT_1 = ['Google is a technology company.', 'It was founded in 1998 by Larry Page and Sergey Brin.', "Google's mission is to organize the world's information and make it universally accessible and useful."] class-attribute instance-attribute
EXPECTED_SENTENCE_SPLIT_2 = ['The U.S.A has many Ph.D. students.', 'They will often haven a .com website sharing the research that they have done.'] class-attribute instance-attribute
TEST_SENTENCE_SPLIT_1 = "\n Google is a technology company. It was founded in 1998 by Larry Page\nand Sergey Brin. Google's mission is to organize the world's information\nand make it universally accessible and useful.\n " class-attribute instance-attribute
TEST_SENTENCE_SPLIT_2 = '\n The U.S.A has many Ph.D. students. They will often haven a .com website\nsharing the research that they have done.\n ' class-attribute instance-attribute
TEST_WORD_COUNT_CASE_1 = ('word1, word2, word3, word4.', 4) class-attribute instance-attribute
TEST_WORD_COUNT_CASE_2 = ('\n Bard can you tell me which is the best optimization method for the\n transition from an hydro-thermal system to an hydro-renewables system', 24) class-attribute instance-attribute
TEST_WORD_COUNT_CASE_3 = ('\n Hyphenated-word has two word counts.\n ', 6) class-attribute instance-attribute
test_count_sentences(response, num_sentences)

Tests sentence counter.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_util_test.py
@parameterized.named_parameters(
    [
        {  # pylint: disable=g-complex-comprehension
            "testcase_name": (
                f"_response={response}_num_sentences={num_sentences}"
            ),
            "response": response,
            "num_sentences": num_sentences,
        }
        for response, num_sentences in [
            ("xx,x. xx,x! xx/x. x{x}x? x.", 5),
            ("xx,x! xxxx. x(x)x?", 3),
            ("xxxx. xx,x! xx|x. x&x x?", 4),
            ("xx-x]xx,x! x{x}xx,x.", 2),
        ]
    ]
)
def test_count_sentences(self, response, num_sentences):
  """Tests sentence counter."""
  actual_num_sentences = instructions_util.count_sentences(response)
  self.assertEqual(num_sentences, actual_num_sentences)
test_generate_keywords()

Tests generate keywords.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_util_test.py
def test_generate_keywords(self):
  """Tests generate keywords."""
  self.assertLen(instructions_util.generate_keywords(10), 10)
test_sentence_splitter()

Tests sentence splitter.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_util_test.py
def test_sentence_splitter(self):
  """Tests sentence splitter."""
  sentence_split_1 = instructions_util.split_into_sentences(
      self.TEST_SENTENCE_SPLIT_1
  )
  sentence_split_2 = instructions_util.split_into_sentences(
      self.TEST_SENTENCE_SPLIT_2
  )

  self.assertEqual(self.EXPECTED_SENTENCE_SPLIT_1, sentence_split_1)
  self.assertEqual(self.EXPECTED_SENTENCE_SPLIT_2, sentence_split_2)
test_word_count()

Tests word counter.

Source code in aisteer360/evaluation/metrics/custom/instruction_following/helpers/instructions_util_test.py
def test_word_count(self):
  """Tests word counter."""
  with self.subTest(f"{self.TEST_WORD_COUNT_CASE_1[0]}"):
    text, expected_num_words = self.TEST_WORD_COUNT_CASE_1
    actual_num_words = instructions_util.count_words(text)
    self.assertEqual(expected_num_words, actual_num_words)

  with self.subTest(f"{self.TEST_WORD_COUNT_CASE_2[0]}"):
    text, expected_num_words = self.TEST_WORD_COUNT_CASE_2
    actual_num_words = instructions_util.count_words(text)
    self.assertEqual(expected_num_words, actual_num_words)

  with self.subTest(f"{self.TEST_WORD_COUNT_CASE_3[0]}"):
    text, expected_num_words = self.TEST_WORD_COUNT_CASE_3
    actual_num_words = instructions_util.count_words(text)
    self.assertEqual(expected_num_words, actual_num_words)
strict_instruction
StrictInstruction

Bases: Metric

Evaluation wrapper around IFEval's official implementation from Google Research (https://github.com/google-research/google-research/tree/master/instruction_following_eval). Measures how well models follow explicit instructions embedded within prompts, using strict binary evaluation criteria.
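
A minimal usage sketch. The prompt and response strings are hypothetical, the import path is inferred from the source location shown below, the metric is assumed to need no required constructor arguments, and the instruction ID and its kwargs follow IFEval's schema as exercised by the instruction checkers documented above:

from aisteer360.evaluation.metrics.custom.instruction_following.strict_instruction import StrictInstruction

metric = StrictInstruction()

responses = [
    {
        "prompt": "Write a short note and end it with a postscript starting with P.S.",  # hypothetical
        "response": "Here is the note.\n\nP.S. Thanks for reading.",  # hypothetical
        "instruction_id_list": ["detectable_content:postscript"],
        "kwargs": [{"postscript_marker": "P.S."}],
    }
]

results = metric.compute(responses=responses)
# results["strict_prompt_accuracy"]       fraction of prompts where every instruction was followed
# results["strict_instruction_accuracy"]  fraction of individual instructions followed
# results["follow_all_instructions"]      per-prompt booleans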

Source code in aisteer360/evaluation/metrics/custom/instruction_following/strict_instruction.py
class StrictInstruction(Metric):
    """
    Evaluation wrapper around IFEval's official implementation from Google Research ([https://github.com/google-research/google-research/tree/master/instruction_following_eval](https://github.com/google-research/google-research/tree/master/instruction_following_eval)).
    Measures how well models follow explicit instructions embedded within prompts, using strict binary evaluation criteria.
    """

    def _fix_kwargs(self, kwargs_list):
        """
        Fix kwargs list by removing None values and converting
        all-None dicts back to empty dicts
        """
        fixed_kwargs = []
        for kwarg_dict in kwargs_list:
            cleaned = {k: v for k, v in kwarg_dict.items() if v is not None}
            fixed_kwargs.append(cleaned)

        return fixed_kwargs

    def compute(
        self,
        responses: list[dict] | None = None,
        prompts: list[str] | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """Computes strict instruction-following metrics using IFEval evaluation.

        Evaluates model responses against structured instructions using the official IFEval framework. Each response is
        assessed both at the prompt level (whether ALL instructions were followed) and at the individual instruction
        level.

        Args:
            responses: List of response dictionaries, each containing:

                - "prompt": The input prompt with embedded instructions
                - "response": The model's generated response
                - "instruction_id_list": List of instruction IDs to evaluate
                - "kwargs": Additional parameters for instruction evaluation
            prompts: List of question prompts (unused, for interface compatibility).
            **kwargs: Additional arguments (unused).

        Returns:
            Dictionary of instruction-following metrics with values:

                - "strict_prompt_accuracy": Proportion of prompts where all instructions were followed correctly
                  (prompt-level accuracy)
                - "strict_instruction_accuracy": Proportion of individual instructions followed correctly across all
                  prompts (instruction-level accuracy)
                - "follow_all_instructions": List of boolean values indicating whether each prompt had all instructions
                  followed

        Note:

        - Returns zero accuracies and empty list if responses is None or empty.
        - The evaluation uses strict binary criteria (partial compliance counts as failure).
        """
        total_prompts = len(responses) if responses is not None else 0
        correct_prompts = 0
        total_instructions = 0
        correct_instructions = 0
        follow_all_instructions = []

        if responses is not None:
            for instance in responses:
                instance["instruction_id_list"] = instance["instruction_id_list"]
                instance["kwargs"] = self._fix_kwargs(instance["kwargs"])
                prompt = instance["prompt"]
                response = instance["response"]
                # test_instruction_following_strict expects an input with fields:
                # prompt, instruction_id_list, kwargs
                output_example = test_instruction_following_strict(
                    instance, {prompt: response}
                )

                # if all instructions followed
                if output_example.follow_all_instructions:
                    correct_prompts += 1
                    follow_all_instructions.append(True)
                else:
                    follow_all_instructions.append(False)

                num_instructions = len(output_example.follow_instruction_list)
                total_instructions += num_instructions
                correct_instructions += sum(output_example.follow_instruction_list)

        strict_prompt_accuracy = (
            correct_prompts / total_prompts if total_prompts > 0 else 0.0
        )
        strict_instruction_accuracy = (
            correct_instructions / total_instructions if total_instructions > 0 else 0.0
        )

        return {
            "strict_prompt_accuracy": strict_prompt_accuracy,
            "strict_instruction_accuracy": strict_instruction_accuracy,
            "follow_all_instructions": follow_all_instructions,
        }
extras = extras instance-attribute
name = self.__class__.__name__ instance-attribute
compute(responses=None, prompts=None, **kwargs)

Computes strict instruction-following metrics using IFEval evaluation.

Evaluates model responses against structured instructions using the official IFEval framework. Each response is assessed both at the prompt level (whether ALL instructions were followed) and at the individual instruction level.

Parameters:

Name Type Description Default
responses list[dict] | None

List of response dictionaries, each containing:

  • "prompt": The input prompt with embedded instructions
  • "response": The model's generated response
  • "instruction_id_list": List of instruction IDs to evaluate
  • "kwargs": Additional parameters for instruction evaluation
None
prompts list[str] | None

List of question prompts (unused, for interface compatibility).

None
**kwargs

Additional arguments (unused).

{}

Returns:

Type Description
dict[str, Any]

Dictionary of instruction-following metrics with values:

  • "strict_prompt_accuracy": Proportion of prompts where all instructions were followed correctly (prompt-level accuracy)
  • "strict_instruction_accuracy": Proportion of individual instructions followed correctly across all prompts (instruction-level accuracy)
  • "follow_all_instructions": List of boolean values indicating whether each prompt had all instructions followed

Note:

  • Returns zero accuracies and empty list if responses is None or empty.
  • The evaluation uses strict binary criteria (partial compliance counts as failure).
Source code in aisteer360/evaluation/metrics/custom/instruction_following/strict_instruction.py
def compute(
    self,
    responses: list[dict] | None = None,
    prompts: list[str] | None = None,
    **kwargs,
) -> dict[str, Any]:
    """Computes strict instruction-following metrics using IFEval evaluation.

    Evaluates model responses against structured instructions using the official IFEval framework. Each response is
    assessed both at the prompt level (whether ALL instructions were followed) and at the individual instruction
    level.

    Args:
        responses: List of response dictionaries, each containing:

            - "prompt": The input prompt with embedded instructions
            - "response": The model's generated response
            - "instruction_id_list": List of instruction IDs to evaluate
            - "kwargs": Additional parameters for instruction evaluation
        prompts: List of question prompts (unused, for interface compatibility).
        **kwargs: Additional arguments (unused).

    Returns:
        Dictionary of instruction-following metrics with values:

            - "strict_prompt_accuracy": Proportion of prompts where all instructions were followed correctly
              (prompt-level accuracy)
            - "strict_instruction_accuracy": Proportion of individual instructions followed correctly across all
              prompts (instruction-level accuracy)
            - "follow_all_instructions": List of boolean values indicating whether each prompt had all instructions
              followed

    Note:

    - Returns zero accuracies and empty list if responses is None or empty.
    - The evaluation uses strict binary criteria (partial compliance counts as failure).
    """
    total_prompts = len(responses) if responses is not None else 0
    correct_prompts = 0
    total_instructions = 0
    correct_instructions = 0
    follow_all_instructions = []

    if responses is not None:
        for instance in responses:
            instance["instruction_id_list"] = instance["instruction_id_list"]
            instance["kwargs"] = self._fix_kwargs(instance["kwargs"])
            prompt = instance["prompt"]
            response = instance["response"]
            # test_instruction_following_strict expects an input with fields:
            # prompt, instruction_id_list, kwargs
            output_example = test_instruction_following_strict(
                instance, {prompt: response}
            )

            # if all instructions followed
            if output_example.follow_all_instructions:
                correct_prompts += 1
                follow_all_instructions.append(True)
            else:
                follow_all_instructions.append(False)

            num_instructions = len(output_example.follow_instruction_list)
            total_instructions += num_instructions
            correct_instructions += sum(output_example.follow_instruction_list)

    strict_prompt_accuracy = (
        correct_prompts / total_prompts if total_prompts > 0 else 0.0
    )
    strict_instruction_accuracy = (
        correct_instructions / total_instructions if total_instructions > 0 else 0.0
    )

    return {
        "strict_prompt_accuracy": strict_prompt_accuracy,
        "strict_instruction_accuracy": strict_instruction_accuracy,
        "follow_all_instructions": follow_all_instructions,
    }
generic

Generic evaluation metrics.

This module contains metrics that can be used for evaluating model outputs regardless of the specific task or domain (e.g., relevance, factuality, etc.).

factuality
Factuality

Bases: LLMJudgeMetric

Judge factual correctness of a response to a prompt.
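
A minimal usage sketch. The judge model ID is only a placeholder, and the constructor and compute signatures are inferred from the LLMJudgeMetric attributes and compute() documentation below:

from aisteer360.evaluation.metrics.generic.factuality import Factuality

judge = Factuality(model_or_id="Qwen/Qwen2.5-1.5B-Instruct")  # placeholder judge model

results = judge.compute(
    responses=["The Eiffel Tower is in Paris."],
    prompts=["Where is the Eiffel Tower located?"],
)
# results["mean_score"]   average judge score on the configured (1, 5) scale
# results["scores"]       one mean score per response
# results["raw_scores"]   all sampled scores per response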

Source code in aisteer360/evaluation/metrics/generic/factuality.py
class Factuality(LLMJudgeMetric):
    """
    Judge factual correctness of a response to a prompt.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            prompt_template=_PROMPT,
            scale=(1, 5),
            **kwargs,
        )
base_prompt_template = prompt_template.strip() instance-attribute
batch_size = batch_size instance-attribute
device = device or ('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') instance-attribute
extras = extras instance-attribute
format_instructions = self.output_parser.get_format_instructions() instance-attribute
max_retries = max_retries instance-attribute
model = AutoModelForCausalLM.from_pretrained(model_or_id) instance-attribute
name = self.__class__.__name__ instance-attribute
num_return_sequences = int(gen_kwargs.pop('num_return_sequences', 1)) instance-attribute
pipeline = TextGenerationPipeline(model=(self.model), tokenizer=(self.tokenizer)) instance-attribute
scale = scale instance-attribute
tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_or_id) instance-attribute
use_chat = hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None instance-attribute
compute(responses, prompts=None, **kwargs)

Compute LLM judge scores for a list of responses.

Evaluates each response using the configured judge model and prompt template. Scores are averaged when multiple samples are generated per response (via num_return_sequences).

Parameters:

Name Type Description Default
responses list[str]

List of text responses to evaluate.

required
prompts list[str] | None

Optional list of prompts corresponding to each response. If provided, must be the same length as responses. These prompts can be referenced in the prompt_template using the {prompt} placeholder.

None
**kwargs Any

Additional keyword arguments (currently unused).

{}

Returns:

Type Description
dict[str, float | list[float]]

Score statistics containing:

  • "mean_score": Overall average score across all responses
  • "scores": List of mean scores for each response (averaged across samples)
  • "raw_scores": List of lists containing all individual scores for each response

Raises:

Type Description
AssertionError

If prompts is provided but has different length than responses.

Source code in aisteer360/evaluation/metrics/base_judge.py
@torch.inference_mode()
def compute(
    self,
    responses: list[str],
    prompts: list[str] | None = None,
    **kwargs: Any,
) -> dict[str, float | list[float]]:
    """Compute LLM judge scores for a list of responses.

    Evaluates each response using the configured judge model and prompt template. Scores are averaged when multiple
    samples are generated per response (via `num_return_sequences`).

    Args:
        responses (list[str]): List of text responses to evaluate.
        prompts (list[str] | None): Optional list of prompts corresponding to each response.
            If provided, must be the same length as responses. These prompts can be
            referenced in the prompt_template using the {prompt} placeholder.
        **kwargs: Additional keyword arguments (currently unused).

    Returns:
        Score statistics containing:

            - "mean_score": Overall average score across all responses
            - "scores": List of mean scores for each response (averaged across samples)
            - "raw_scores": List of lists containing all individual scores for each response

    Raises:
        AssertionError: If prompts is provided but has different length than responses.
    """

    if prompts is not None and len(prompts) != len(responses):
        raise AssertionError("`responses` and `prompts` must be the same length")

    # build prompts
    prompts_list: list[str] = []
    for i in range(len(responses)):
        fields: dict[str, str | float] = {
            "response": responses[i],
            "lower_bound": self.scale[0],
            "upper_bound": self.scale[1],
        }
        if prompts is not None:
            fields["prompt"] = prompts[i]

        prompt_core = self.base_prompt_template.format(**fields)
        prompt_formatted = self._wrap(prompt_core + "\n\n" + self.format_instructions)
        prompts_list.append(prompt_formatted)

    # generate
    prompt_scores: list[list[float]] = []
    for batch in self._batch_chunks(prompts_list, self.batch_size):
        outputs = self.pipeline(
            batch,
            num_return_sequences=self.num_return_sequences,
            return_full_text=False,
            clean_up_tokenization_spaces=True,
        )

        for prompt, generations in zip(batch, outputs):
            generations = generations if isinstance(generations, list) else [generations]
            assert len(generations) == self.num_return_sequences

            scores = []
            for generation in generations:
                reply_text = generation["generated_text"]
                try:
                    score = self.parse_fn(reply_text, self.scale)
                except Exception:
                    score = self._score_with_retries(prompt)
                scores.append(score)

            prompt_scores.append(scores)

    mean_per_prompt = [sum(prompt_score) / len(prompt_score) for prompt_score in prompt_scores]
    corpus_mean = sum(mean_per_prompt) / len(mean_per_prompt)

    return {
        "mean_score": corpus_mean,  # overall average
        "scores": mean_per_prompt,  # one number per original prompt
        "raw_scores": prompt_scores  # n_samples scores per prompt
    }
perplexity
Perplexity

Bases: Metric

Compute token-level perplexity for a batch of sentences.

Perplexity is the exponentiated mean cross-entropy between the language model’s predicted distribution and the true next token. Lower is better.

Parameters:

Name Type Description Default
model_or_id str | Module

Hugging Face model ID or an already-instantiated causal language model.

required
tokenizer PreTrainedTokenizer | None

Tokenizer to use. Leave None when passing a model ID to automatically load the matching tokenizer. Defaults to None.

None
batch_size int

Number of sentences per forward pass. Higher is faster until GPU memory becomes the bottleneck. Defaults to 16.

16
add_bos bool

Whether to prepend the tokenizer’s BOS token so the first word in each sentence is also scored. Ignored if the tokenizer has no BOS token. Defaults to True.

True
max_length int | None

If set, truncate inputs to this length so they fit the model’s context window. None disables truncation. Defaults to None.

None
device str | None

"cuda" or "cpu". When None, automatically selects GPU if available. Defaults to None.

None

Attributes:

Name Type Description
add_bos bool

Whether a BOS token is prepended before scoring.

batch_size int

Number of sentences processed per forward pass.

device str

The device actually selected for computation ("cuda" or "cpu").

max_length int | None

Truncation length for inputs, or None for no truncation.

model PreTrainedModel

The loaded causal language model used to score tokens.

tokenizer PreTrainedTokenizer

Tokenizer used for encoding, padding, and BOS handling.
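
A minimal usage sketch; "gpt2" is only a placeholder model ID and the import path is inferred from the source location shown below:

from aisteer360.evaluation.metrics.generic.perplexity import Perplexity

ppl = Perplexity(model_or_id="gpt2", batch_size=8)

results = ppl.compute(responses=["The quick brown fox jumps over the lazy dog."])
# results["mean_perplexity"]  mean perplexity over all inputs (lower is better)
# results["perplexities"]     per-sentence perplexities, in input order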

Source code in aisteer360/evaluation/metrics/generic/perplexity.py
class Perplexity(Metric):
    """Compute token-level perplexity for a batch of sentences.

    Perplexity is the exponentiated mean cross-entropy between the language model’s predicted distribution and the true
    next token. Lower is better.

    Args:
        model_or_id (str | torch.nn.Module): Hugging Face model ID or an already-instantiated causal language model.
        tokenizer (transformers.PreTrainedTokenizer | None, optional):
            Tokenizer to use.  Leave ``None`` when passing a model ID to automatically load the matching tokenizer.
            Defaults to ``None``.
        batch_size (int, optional): Number of sentences per forward pass. Higher is faster until GPU memory becomes the
            bottleneck. Defaults to ``16``.
        add_bos (bool, optional): Whether to prepend the tokenizer’s BOS token so the first word in each sentence is
            also scored. Ignored if the tokenizer has no BOS token. Defaults to ``True``.
        max_length (int | None, optional): If set, truncate inputs to this length so they fit the model’s context
            window. ``None`` disables truncation. Defaults to ``None``.
        device (str | None, optional): ``"cuda"`` or ``"cpu"``. When ``None``, automatically selects GPU if available.
            Defaults to ``None``.

    Attributes:
        add_bos (bool): Whether a BOS token is prepended before scoring.
        batch_size (int): Number of sentences processed per forward pass.
        device (str): The device actually selected for computation (``"cuda"`` or ``"cpu"``).
        max_length (int | None): Truncation length for inputs, or ``None`` for no truncation.
        model (transformers.PreTrainedModel): The loaded causal language model used to score tokens.
        tokenizer (transformers.PreTrainedTokenizer): Tokenizer used for encoding, padding, and BOS handling.
    """

    def __init__(
        self,
        model_or_id: str | torch.nn.Module,
        tokenizer: Any | None = None,
        batch_size: int = 16,
        add_bos: bool = True,
        max_length: int | None = None,
        device: str | None = None,
    ):
        super().__init__()

        if isinstance(model_or_id, str):
            self.model = AutoModelForCausalLM.from_pretrained(model_or_id)
            self.tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_or_id)
        else:  # model object
            self.model = model_or_id
            self.tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_or_id.config._name_or_path)

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device).eval()
        self.batch_size = batch_size
        self.add_bos = add_bos and (self.tokenizer.bos_token_id is not None)
        self.max_length = max_length

        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                # add_special_tokens returns the number of added tokens, so it cannot be
                # assigned to pad_token directly; register a pad token instead.
                self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    @torch.no_grad()
    def compute(
        self,
        responses: list[str],
        prompts: list[str] | None = None,
    ) -> dict[str, float]:
        """Compute perplexity for each response (and the mean across the batch).

        Args:
            responses (list[str]): Text sequences to score.
            prompts (list[str] | None, optional): Unused here; present for a uniform metric API.

        Returns:
            dict[str, float]: A dict with keys:

                - ``"mean_perplexity"``: mean perplexity over all inputs.
                - ``"perplexities"``: list of per-sample perplexities in input order.
        """
        perplexities: list[float] = []
        local_batch_size = self.batch_size

        for i in range(0, len(responses), local_batch_size):
            batch = responses[i : i + local_batch_size]

            encoding = self.tokenizer(
                batch,
                padding=True,
                truncation=self.max_length is not None,
                max_length=self.max_length,
                add_special_tokens=False,
                return_tensors="pt",
            ).to(self.device)
            input_ids = encoding["input_ids"]

            if self.add_bos:
                bos_tokens = torch.full(
                    (input_ids.size(0), 1),
                    self.tokenizer.bos_token_id,
                    device=self.device,
                )
                input_ids = torch.cat([bos_tokens, input_ids], dim=1)

            logits = self.model(input_ids).logits[:, :-1]
            labels = input_ids[:, 1:]

            loss_per_token = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                reduction="none",
            ).view(labels.size())

            mask = labels.ne(self.tokenizer.pad_token_id)
            seq_loss = (loss_per_token * mask).sum(1) / mask.sum(1)

            perplexities.extend(torch.exp(seq_loss).cpu().tolist())

        return {
            "mean_perplexity": sum(perplexities) / len(perplexities),
            "perplexities": perplexities,
        }
add_bos = add_bos and self.tokenizer.bos_token_id is not None instance-attribute
batch_size = batch_size instance-attribute
device = device or ('cuda' if torch.cuda.is_available() else 'cpu') instance-attribute
extras = extras instance-attribute
max_length = max_length instance-attribute
model = AutoModelForCausalLM.from_pretrained(model_or_id) instance-attribute
name = self.__class__.__name__ instance-attribute
tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_or_id) instance-attribute
compute(responses, prompts=None)

Compute perplexity for each response (and the mean across the batch).

Parameters:

Name Type Description Default
responses list[str]

Text sequences to score.

required
prompts list[str] | None

Unused here; present for a uniform metric API.

None

Returns:

Type Description
dict[str, float]

dict[str, float]: A dict with keys:

  • "mean_perplexity": mean perplexity over all inputs.
  • "perplexities": list of per-sample perplexities in input order.
Source code in aisteer360/evaluation/metrics/generic/perplexity.py
@torch.no_grad()
def compute(
    self,
    responses: list[str],
    prompts: list[str] | None = None,
) -> dict[str, float]:
    """Compute perplexity for each response (and the mean across the batch).

    Args:
        responses (list[str]): Text sequences to score.
        prompts (list[str] | None, optional): Unused here; present for a uniform metric API.

    Returns:
        dict[str, float]: A dict with keys:

            - ``"mean_perplexity"``: mean perplexity over all inputs.
            - ``"perplexities"``: list of per-sample perplexities in input order.
    """
    perplexities: list[float] = []
    local_batch_size = self.batch_size

    for i in range(0, len(responses), local_batch_size):
        batch = responses[i : i + local_batch_size]

        encoding = self.tokenizer(
            batch,
            padding=True,
            truncation=self.max_length is not None,
            max_length=self.max_length,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(self.device)
        input_ids = encoding["input_ids"]

        if self.add_bos:
            bos_tokens = torch.full(
                (input_ids.size(0), 1),
                self.tokenizer.bos_token_id,
                device=self.device,
            )
            input_ids = torch.cat([bos_tokens, input_ids], dim=1)

        logits = self.model(input_ids).logits[:, :-1]
        labels = input_ids[:, 1:]

        loss_per_token = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            reduction="none",
        ).view(labels.size())

        mask = labels.ne(self.tokenizer.pad_token_id)
        seq_loss = (loss_per_token * mask).sum(1) / mask.sum(1)

        perplexities.extend(torch.exp(seq_loss).cpu().tolist())

    return {
        "mean_perplexity": sum(perplexities) / len(perplexities),
        "perplexities": perplexities,
    }
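
A minimal usage sketch (illustrative, not part of the library's own documentation). The checkpoint name is an arbitrary small model chosen for the example; the import path follows the source location shown above.

from aisteer360.evaluation.metrics.generic.perplexity import Perplexity

# "sshleifer/tiny-gpt2" is an illustrative tiny checkpoint; any causal LM ID should work.
metric = Perplexity(model_or_id="sshleifer/tiny-gpt2", batch_size=8)

result = metric.compute(
    responses=[
        "The cat sat on the mat.",
        "Colorless green ideas sleep furiously.",
    ]
)
print(result["mean_perplexity"])  # mean over the batch
print(result["perplexities"])     # per-sentence perplexities, in input order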
relevance
Relevance

Bases: LLMJudgeMetric

Judge relevance of a response to a prompt.

Source code in aisteer360/evaluation/metrics/generic/relevance.py
class Relevance(LLMJudgeMetric):
    """
    Judge relevance of a response to a prompt.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            prompt_template=_PROMPT,
            scale=(1, 5),
            **kwargs,
        )
base_prompt_template = prompt_template.strip() instance-attribute
batch_size = batch_size instance-attribute
device = device or ('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') instance-attribute
extras = extras instance-attribute
format_instructions = self.output_parser.get_format_instructions() instance-attribute
max_retries = max_retries instance-attribute
model = AutoModelForCausalLM.from_pretrained(model_or_id) instance-attribute
name = self.__class__.__name__ instance-attribute
num_return_sequences = int(gen_kwargs.pop('num_return_sequences', 1)) instance-attribute
pipeline = TextGenerationPipeline(model=(self.model), tokenizer=(self.tokenizer)) instance-attribute
scale = scale instance-attribute
tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_or_id) instance-attribute
use_chat = hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None instance-attribute
compute(responses, prompts=None, **kwargs)

Compute LLM judge scores for a list of responses.

Evaluates each response using the configured judge model and prompt template. Scores are averaged when multiple samples are generated per response (via num_return_sequences).

Parameters:

Name Type Description Default
responses list[str]

List of text responses to evaluate.

required
prompts list[str] | None

Optional list of prompts corresponding to each response. If provided, must be the same length as responses. These prompts can be referenced in the prompt_template using the {prompt} placeholder.

None
**kwargs Any

Additional keyword arguments (currently unused).

{}

Returns:

Type Description
dict[str, float | list[float]]

Score statistics containing:

  • "mean_score": Overall average score across all responses
  • "scores": List of mean scores for each response (averaged across samples)
  • "raw_scores": List of lists containing all individual scores for each response

Raises:

Type Description
AssertionError

If prompts is provided but has different length than responses.

Source code in aisteer360/evaluation/metrics/base_judge.py
@torch.inference_mode()
def compute(
    self,
    responses: list[str],
    prompts: list[str] | None = None,
    **kwargs: Any,
) -> dict[str, float | list[float]]:
    """Compute LLM judge scores for a list of responses.

    Evaluates each response using the configured judge model and prompt template. Scores are averaged when multiple
    samples are generated per response (via `num_return_sequences`).

    Args:
        responses (list[str]): List of text responses to evaluate.
        prompts (list[str] | None): Optional list of prompts corresponding to each response.
            If provided, must be the same length as responses. These prompts can be
            referenced in the prompt_template using the {prompt} placeholder.
        **kwargs: Additional keyword arguments (currently unused).

    Returns:
        Score statistics containing:

            - "mean_score": Overall average score across all responses
            - "scores": List of mean scores for each response (averaged across samples)
            - "raw_scores": List of lists containing all individual scores for each response

    Raises:
        AssertionError: If prompts is provided but has different length than responses.
    """

    if prompts is not None and len(prompts) != len(responses):
        raise AssertionError("`responses` and `prompts` must be the same length")

    # build prompts
    prompts_list: list[str] = []
    for i in range(len(responses)):
        fields: dict[str, str | float] = {
            "response": responses[i],
            "lower_bound": self.scale[0],
            "upper_bound": self.scale[1],
        }
        if prompts is not None:
            fields["prompt"] = prompts[i]

        prompt_core = self.base_prompt_template.format(**fields)
        prompt_formatted = self._wrap(prompt_core + "\n\n" + self.format_instructions)
        prompts_list.append(prompt_formatted)

    # generate
    prompt_scores: list[list[float]] = []
    for batch in self._batch_chunks(prompts_list, self.batch_size):
        outputs = self.pipeline(
            batch,
            num_return_sequences=self.num_return_sequences,
            return_full_text=False,
            clean_up_tokenization_spaces=True,
        )

        for prompt, generations in zip(batch, outputs):
            generations = generations if isinstance(generations, list) else [generations]
            assert len(generations) == self.num_return_sequences

            scores = []
            for generation in generations:
                reply_text = generation["generated_text"]
                try:
                    score = self.parse_fn(reply_text, self.scale)
                except Exception:
                    score = self._score_with_retries(prompt)
                scores.append(score)

            prompt_scores.append(scores)

    mean_per_prompt = [sum(prompt_score) / len(prompt_score) for prompt_score in prompt_scores]
    corpus_mean = sum(mean_per_prompt) / len(mean_per_prompt)

    return {
        "mean_score": corpus_mean,  # overall average
        "scores": mean_per_prompt,  # one number per original prompt
        "raw_scores": prompt_scores  # n_samples scores per prompt
    }
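
A minimal usage sketch for the judge-based metric (illustrative). The judge checkpoint name is arbitrary, and the model_or_id keyword is inferred from the attribute listing above (which loads the judge via AutoModelForCausalLM.from_pretrained); the actual constructor arguments are forwarded to LLMJudgeMetric and may differ.

from aisteer360.evaluation.metrics.generic.relevance import Relevance

# Judge checkpoint is illustrative; the keyword name is inferred from the attributes above.
judge = Relevance(model_or_id="Qwen/Qwen2.5-0.5B-Instruct")

scores = judge.compute(
    responses=["Paris is the capital of France."],
    prompts=["What is the capital of France?"],
)
print(scores["mean_score"])   # overall average on the 1-5 scale
print(scores["scores"])       # one mean score per prompt
print(scores["raw_scores"])   # all sampled scores per prompt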

use_cases

base

Base class for all use cases. Provides a framework for loading evaluation data, applying metrics, and running standardized evaluations across different types of tasks. Subclasses must implement the generate() and evaluate() methods to define task-specific evaluation logic.

UseCase

Bases: ABC

Base use case class.

Source code in aisteer360/evaluation/use_cases/base.py
class UseCase(ABC):
    """
    Base use case class.
    """
    def __init__(
        self,
        evaluation_data: list[dict] | str | Path,
        evaluation_metrics: list[Metric],
        num_samples: int = -1,
        **kwargs
    ) -> None:

        self.evaluation_data = []
        if isinstance(evaluation_data, Sequence) and all(isinstance(item, Mapping) for item in evaluation_data):
            self.evaluation_data = list(evaluation_data)
        else:
            path = Path(evaluation_data) if isinstance(evaluation_data, str) else evaluation_data
            with open(path) as f:
                self.evaluation_data = [json.loads(line) for line in f] if path.suffix == '.jsonl' else json.load(f)
        if not self.evaluation_data:
            warnings.warn(
                "Either evaluation data was not provided, or was unable to be generated.",
                UserWarning
            )

        if num_samples > 0:
            self.evaluation_data = self.evaluation_data[:num_samples]

        self.evaluation_metrics = evaluation_metrics
        self._metrics_by_name = {metric.name: metric for metric in evaluation_metrics}

        # store kwargs as attributes
        for key, value in kwargs.items():
            setattr(self, key, value)

        # validation
        if not all(isinstance(metric, Metric) for metric in self.evaluation_metrics):
            raise TypeError("All items in `evaluation_metrics` must be of type `Metric`.")

    @abstractmethod
    def generate(
            self,
            model_or_pipeline,
            tokenizer,
            gen_kwargs=None,
            runtime_overrides: dict[tuple[str, str], str] | None = None
    ) -> list[dict[str, Any]]:
        """
        Required generation logic for the current use case.
        """
        raise NotImplementedError

    @abstractmethod
    def evaluate(
            self,
            generations: list[dict[str, Any]]
    ) -> dict[str, dict[str, Any]]:
        """
        Required evaluation logic for model's generations via `evaluation_metrics`.
        """
        raise NotImplementedError

    def export(self,
               profiles: dict[str, dict[str, Any]],
               save_dir: str
    ) -> None:
        """
        Optional formatting and export of evaluation profiles.
        """
        raise NotImplementedError

    # def validate_steering_data(self, steering_data):
    #     pass

    def validate_evaluation_data(self, evaluation_data) -> None:
        """
        Optional validation of the evaluation dataset.
        """
        raise NotImplementedError
evaluation_data = [(json.loads(line)) for line in f] if path.suffix == '.jsonl' else json.load(f) instance-attribute
evaluation_metrics = evaluation_metrics instance-attribute
evaluate(generations) abstractmethod

Required evaluation logic for model's generations via evaluation_metrics.

Source code in aisteer360/evaluation/use_cases/base.py
@abstractmethod
def evaluate(
        self,
        generations: list[dict[str, Any]]
) -> dict[str, dict[str, Any]]:
    """
    Required evaluation logic for model's generations via `evaluation_metrics`.
    """
    raise NotImplementedError
export(profiles, save_dir)

Optional formatting and export of evaluation profiles.

Source code in aisteer360/evaluation/use_cases/base.py
def export(self,
           profiles: dict[str, dict[str, Any]],
           save_dir: str
) -> None:
    """
    Optional formatting and export of evaluation profiles.
    """
    raise NotImplementedError
generate(model_or_pipeline, tokenizer, gen_kwargs=None, runtime_overrides=None) abstractmethod

Required generation logic for the current use case.

Source code in aisteer360/evaluation/use_cases/base.py
@abstractmethod
def generate(
        self,
        model_or_pipeline,
        tokenizer,
        gen_kwargs=None,
        runtime_overrides: dict[tuple[str, str], str] | None = None
) -> list[dict[str, Any]]:
    """
    Required generation logic for the current use case.
    """
    raise NotImplementedError
validate_evaluation_data(evaluation_data)

Optional validation of the evaluation dataset.

Source code in aisteer360/evaluation/use_cases/base.py
def validate_evaluation_data(self, evaluation_data) -> None:
    """
    Optional validation of the evaluation dataset.
    """
    raise NotImplementedError
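
To illustrate the contract, a minimal hypothetical subclass is sketched below. It assumes each evaluation record carries a "prompt" field and simply echoes prompts back as responses; a real use case would call the supplied model or pipeline instead.

from typing import Any

from aisteer360.evaluation.use_cases.base import UseCase


class EchoUseCase(UseCase):
    """Toy use case: echoes each prompt back as its response."""

    def generate(self, model_or_pipeline, tokenizer, gen_kwargs=None,
                 runtime_overrides=None) -> list[dict[str, Any]]:
        # A real implementation would run generation here (e.g. via batch_retry_generate).
        return [
            {"prompt": row["prompt"], "response": row["prompt"]}
            for row in self.evaluation_data
        ]

    def evaluate(self, generations: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
        eval_data = {
            "responses": [gen["response"] for gen in generations],
            "prompts": [gen["prompt"] for gen in generations],
        }
        return {metric.name: metric(**eval_data) for metric in self.evaluation_metrics}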
commonsense_mcqa

Use case class for the commonsense multiple-choice question answering (MCQA) task.

use_case
CommonsenseMCQA

Bases: UseCase

Commonsense MCQA evaluation use case.

Evaluates a model's ability to answer commonsense questions via accuracy on the CommonsenseQA dataset (https://huggingface.co/datasets/tau/commonsense_qa). Supports answer choice shuffling across multiple runs to reduce position bias and improve evaluation robustness.

The evaluation data should contain questions with multiple choice options where models are asked to respond with only the letter (A, B, C, etc.) corresponding to their chosen answer.

Attributes:

Name Type Description
num_shuffling_runs int

Number of times to shuffle answer choices for each question to mitigate position bias effects.

Source code in aisteer360/evaluation/use_cases/commonsense_mcqa/use_case.py
class CommonsenseMCQA(UseCase):
    """Commonsense MCQA evaluation use case.

    Evaluates a model's ability to answer commonsense questions via accuracy on the CommonsenseQA dataset
    ([https://huggingface.co/datasets/tau/commonsense_qa](https://huggingface.co/datasets/tau/commonsense_qa)). Supports
    answer choice shuffling across multiple runs to reduce position bias and improve evaluation robustness.

    The evaluation data should contain questions with multiple choice options where models are asked to respond with
    only the letter (A, B, C, etc.) corresponding to their chosen answer.

    Attributes:
        num_shuffling_runs: Number of times to shuffle answer choices for each question to mitigate position bias effects.
    """
    num_shuffling_runs: int

    def validate_evaluation_data(self, evaluation_data: dict[str, Any]):
        """Validates that evaluation data contains required fields for MCQA evaluation.

        Ensures each data instance has the necessary keys and non-null values for the evaluation.

        Args:
            evaluation_data: Dictionary containing a single evaluation instance with question, answer choices, and correct answer information.

        Raises:
            ValueError: If required keys ('id', 'question', 'answer', 'choices') are missing or if any required fields contain null/NaN values.
        """
        if "id" not in evaluation_data.keys():
            raise ValueError("The evaluation data must include an 'id' key")

        missing_keys = [col for col in _EVALUATION_REQ_KEYS if col not in evaluation_data.keys()]
        if missing_keys:
            raise ValueError(f"Missing required keys: {missing_keys}")

        if any(
            key not in evaluation_data or evaluation_data[key] is None or
            (isinstance(evaluation_data[key], float) and math.isnan(evaluation_data[key]))
            for key in _EVALUATION_REQ_KEYS
        ):
            raise ValueError("Some required fields are missing or null.")

    def generate(
        self,
        model_or_pipeline,
        tokenizer,
        gen_kwargs: dict | None = None,
        runtime_overrides: dict[tuple[str, str], str] | None = None
    ) -> list[dict[str, Any]]:
        """Generates model responses for multiple-choice questions with shuffled answer orders.

        Creates prompts for each question with shuffled answer choices, generates model responses, and parses the
        outputs to extract letter choices. Repeats the process multiple times with different answer orderings to reduce
        positional bias.

        Args:
            model_or_pipeline: Either a HuggingFace model or SteeringPipeline instance to use for generation.
            tokenizer: Tokenizer for encoding/decoding text.
            gen_kwargs: Optional generation parameters.
            runtime_overrides: Optional runtime parameter overrides for steering controls, structured as {(pipeline_name, param_name): value}.

        Returns:
            List of generation dictionaries, each containing:

                - "response": Parsed letter choice (A, B, C, etc.) or None if not parseable
                - "prompt": Full prompt text sent to the model
                - "question_id": Identifier from the original evaluation data
                - "reference_answer": Correct letter choice for this shuffled ordering

        Note:

        - The number of returned generations will be `len(evaluation_data)` * `num_shuffling_runs` due to answer choice shuffling.
        """

        if not self.evaluation_data:
            print('No evaluation data provided.')
            return []
        gen_kwargs = dict(gen_kwargs or {})

        # form prompt data
        prompt_data = []
        for instance in self.evaluation_data:
            data_id = instance['id']
            question = instance['question']
            answer = instance['answer']
            choices = instance['choices']
            # shuffle order of choices for each shuffling run
            for _ in range(self.num_shuffling_runs):

                lines = ["You will be given a multiple-choice question and asked to select from a set of choices."]
                lines += [f"\nQuestion: {question}\n"]

                # shuffle
                choice_order = list(range(len(choices)))
                random.shuffle(choice_order)
                for i, old_idx in enumerate(choice_order):
                    lines.append(f"{_LETTERS[i]}. {choices[old_idx]}")

                lines += ["\nPlease only print the letter corresponding to your choice."]
                lines += ["\nAnswer:"]

                prompt_data.append(
                    {
                        "id": data_id,
                        "prompt": "\n".join(lines),
                        "reference_answer": _LETTERS[choice_order.index(choices.index(answer))]
                    }
                )

        # batch template/generate/decode
        choices = batch_retry_generate(
            prompt_data=prompt_data,
            model_or_pipeline=model_or_pipeline,
            tokenizer=tokenizer,
            parse_fn=self._parse_letter,
            gen_kwargs=gen_kwargs,
            runtime_overrides=runtime_overrides,
            evaluation_data=self.evaluation_data
        )

        # store
        generations = [
            {
                "response": choice,
                "prompt": prompt_dict["prompt"],
                "question_id": prompt_dict["id"],
                "reference_answer": prompt_dict["reference_answer"],
            }
            for prompt_dict, choice in zip(prompt_data, choices)
        ]

        return generations

    def evaluate(self, generations: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
        """Evaluates generated responses against reference answers using configured metrics.

        Extracts responses and reference answers from generations and computes scores using all evaluation metrics
        specified during initialization.

        Args:
            generations: List of generation dictionaries returned by the `generate()` method, each containing response,
                reference_answer, and question_id fields.

        Returns:
            Dictionary of scores keyed by `metric_name`
        """
        eval_data = {
            "responses": [generation["response"] for generation in generations],
            "reference_answers": [generation["reference_answer"] for generation in generations],
            "question_ids": [generation["question_id"] for generation in generations],
        }

        scores = {}
        for metric in self.evaluation_metrics:
            scores[metric.name] = metric(**eval_data)

        return scores

    def export(self, profiles: dict[str, Any], save_dir) -> None:
        """Exports evaluation profiles to (tabbed) JSON format."""

        with open(Path(save_dir) / "profiles.json", "w", encoding="utf-8") as f:
            json.dump(profiles, f, indent=4, ensure_ascii=False)

    @staticmethod
    def _parse_letter(response) -> str:
        """Extracts the letter choice from model's generation.

        Parses model output to find the first valid letter (A-Z) that represents the model's choice.

        Args:
            response: Raw text response from the model.

        Returns:
            Single uppercase letter (A, B, C, etc.) representing the model's choice, or None if no valid letter choice could be parsed.
        """
        valid = _LETTERS
        text = re.sub(r"^\s*(assistant|system|user)[:\n ]*", "", response, flags=re.I).strip()
        match = re.search(rf"\b([{valid}])\b", text, flags=re.I)
        return match.group(1).upper() if match else None
evaluation_data = [(json.loads(line)) for line in f] if path.suffix == '.jsonl' else json.load(f) instance-attribute
evaluation_metrics = evaluation_metrics instance-attribute
num_shuffling_runs instance-attribute
evaluate(generations)

Evaluates generated responses against reference answers using configured metrics.

Extracts responses and reference answers from generations and computes scores using all evaluation metrics specified during initialization.

Parameters:

Name Type Description Default
generations list[dict[str, Any]]

List of generation dictionaries returned by the generate() method, each containing response, reference_answer, and question_id fields.

required

Returns:

Type Description
dict[str, dict[str, Any]]

Dictionary of scores keyed by metric_name

Source code in aisteer360/evaluation/use_cases/commonsense_mcqa/use_case.py
def evaluate(self, generations: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Evaluates generated responses against reference answers using configured metrics.

    Extracts responses and reference answers from generations and computes scores using all evaluation metrics
    specified during initialization.

    Args:
        generations: List of generation dictionaries returned by the `generate()` method, each containing response,
            reference_answer, and question_id fields.

    Returns:
        Dictionary of scores keyed by `metric_name`
    """
    eval_data = {
        "responses": [generation["response"] for generation in generations],
        "reference_answers": [generation["reference_answer"] for generation in generations],
        "question_ids": [generation["question_id"] for generation in generations],
    }

    scores = {}
    for metric in self.evaluation_metrics:
        scores[metric.name] = metric(**eval_data)

    return scores
export(profiles, save_dir)

Exports evaluation profiles to (tabbed) JSON format.

Source code in aisteer360/evaluation/use_cases/commonsense_mcqa/use_case.py
def export(self, profiles: dict[str, Any], save_dir) -> None:
    """Exports evaluation profiles to (tabbed) JSON format."""

    with open(Path(save_dir) / "profiles.json", "w", encoding="utf-8") as f:
        json.dump(profiles, f, indent=4, ensure_ascii=False)
generate(model_or_pipeline, tokenizer, gen_kwargs=None, runtime_overrides=None)

Generates model responses for multiple-choice questions with shuffled answer orders.

Creates prompts for each question with shuffled answer choices, generates model responses, and parses the outputs to extract letter choices. Repeats the process multiple times with different answer orderings to reduce positional bias.

Parameters:

Name Type Description Default
model_or_pipeline

Either a HuggingFace model or SteeringPipeline instance to use for generation.

required
tokenizer

Tokenizer for encoding/decoding text.

required
gen_kwargs dict | None

Optional generation parameters.

None
runtime_overrides dict[tuple[str, str], str] | None

Optional runtime parameter overrides for steering controls, structured as {(pipeline_name, param_name): value}.

None

Returns:

Type Description
list[dict[str, Any]]

List of generation dictionaries, each containing:

  • "response": Parsed letter choice (A, B, C, etc.) or None if not parseable
  • "prompt": Full prompt text sent to the model
  • "question_id": Identifier from the original evaluation data
  • "reference_answer": Correct letter choice for this shuffled ordering

Note:

  • The number of returned generations will be len(evaluation_data) * num_shuffling_runs due to answer choice shuffling.
Source code in aisteer360/evaluation/use_cases/commonsense_mcqa/use_case.py
def generate(
    self,
    model_or_pipeline,
    tokenizer,
    gen_kwargs: dict | None = None,
    runtime_overrides: dict[tuple[str, str], str] | None = None
) -> list[dict[str, Any]]:
    """Generates model responses for multiple-choice questions with shuffled answer orders.

    Creates prompts for each question with shuffled answer choices, generates model responses, and parses the
    outputs to extract letter choices. Repeats the process multiple times with different answer orderings to reduce
    positional bias.

    Args:
        model_or_pipeline: Either a HuggingFace model or SteeringPipeline instance to use for generation.
        tokenizer: Tokenizer for encoding/decoding text.
        gen_kwargs: Optional generation parameters.
        runtime_overrides: Optional runtime parameter overrides for steering controls, structured as {(pipeline_name, param_name): value}.

    Returns:
        List of generation dictionaries, each containing:

            - "response": Parsed letter choice (A, B, C, etc.) or None if not parseable
            - "prompt": Full prompt text sent to the model
            - "question_id": Identifier from the original evaluation data
            - "reference_answer": Correct letter choice for this shuffled ordering

    Note:

    - The number of returned generations will be `len(evaluation_data)` * `num_shuffling_runs` due to answer choice shuffling.
    """

    if not self.evaluation_data:
        print('No evaluation data provided.')
        return []
    gen_kwargs = dict(gen_kwargs or {})

    # form prompt data
    prompt_data = []
    for instance in self.evaluation_data:
        data_id = instance['id']
        question = instance['question']
        answer = instance['answer']
        choices = instance['choices']
        # shuffle order of choices for each shuffling run
        for _ in range(self.num_shuffling_runs):

            lines = ["You will be given a multiple-choice question and asked to select from a set of choices."]
            lines += [f"\nQuestion: {question}\n"]

            # shuffle
            choice_order = list(range(len(choices)))
            random.shuffle(choice_order)
            for i, old_idx in enumerate(choice_order):
                lines.append(f"{_LETTERS[i]}. {choices[old_idx]}")

            lines += ["\nPlease only print the letter corresponding to your choice."]
            lines += ["\nAnswer:"]

            prompt_data.append(
                {
                    "id": data_id,
                    "prompt": "\n".join(lines),
                    "reference_answer": _LETTERS[choice_order.index(choices.index(answer))]
                }
            )

    # batch template/generate/decode
    choices = batch_retry_generate(
        prompt_data=prompt_data,
        model_or_pipeline=model_or_pipeline,
        tokenizer=tokenizer,
        parse_fn=self._parse_letter,
        gen_kwargs=gen_kwargs,
        runtime_overrides=runtime_overrides,
        evaluation_data=self.evaluation_data
    )

    # store
    generations = [
        {
            "response": choice,
            "prompt": prompt_dict["prompt"],
            "question_id": prompt_dict["id"],
            "reference_answer": prompt_dict["reference_answer"],
        }
        for prompt_dict, choice in zip(prompt_data, choices)
    ]

    return generations
validate_evaluation_data(evaluation_data)

Validates that evaluation data contains required fields for MCQA evaluation.

Ensures each data instance has the necessary keys and non-null values for the evaluation.

Parameters:

Name Type Description Default
evaluation_data dict[str, Any]

Dictionary containing a single evaluation instance with question, answer choices, and correct answer information.

required

Raises:

Type Description
ValueError

If required keys ('id', 'question', 'answer', 'choices') are missing or if any required fields contain null/NaN values.

Source code in aisteer360/evaluation/use_cases/commonsense_mcqa/use_case.py
def validate_evaluation_data(self, evaluation_data: dict[str, Any]):
    """Validates that evaluation data contains required fields for MCQA evaluation.

    Ensures each data instance has the necessary keys and non-null values for the evaluation.

    Args:
        evaluation_data: Dictionary containing a single evaluation instance with question, answer choices, and correct answer information.

    Raises:
        ValueError: If required keys ('id', 'question', 'answer', 'choices') are missing or if any required fields contain null/NaN values.
    """
    if "id" not in evaluation_data.keys():
        raise ValueError("The evaluation data must include an 'id' key")

    missing_keys = [col for col in _EVALUATION_REQ_KEYS if col not in evaluation_data.keys()]
    if missing_keys:
        raise ValueError(f"Missing required keys: {missing_keys}")

    if any(
        key not in evaluation_data or evaluation_data[key] is None or
        (isinstance(evaluation_data[key], float) and math.isnan(evaluation_data[key]))
        for key in _EVALUATION_REQ_KEYS
    ):
        raise ValueError("Some required fields are missing or null.")
instruction_following

Use case class for the instruction following task.

use_case
InstructionFollowing

Bases: UseCase

Instruction following use case using the IFEval dataset.

Evaluates a model's ability to follow specific instructions by testing adherence to various formatting, content, and structural constraints. Uses the IFEval dataset, which contains prompts with explicit instructions that models must follow precisely.

The evaluation focuses on whether models can follow instructions like:

  • Formatting requirements (e.g., "respond in exactly 3 sentences")
  • Content constraints (e.g., "include the word 'fantastic' twice")
  • Structural requirements (e.g., "use bullet points", "write in JSON format")

Expected evaluation data format should include fields like 'prompt', 'instructions', 'instruction_id_list', and 'kwargs' for comprehensive instruction following assessment.
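
For orientation, a single evaluation record is expected to look roughly like the following. The field values and the instruction identifier are illustrative; only the field names are taken from the description above.

example_record = {
    "prompt": "Write a short poem about autumn. Your entire response must be in lowercase letters.",
    "instructions": ["Your entire response must be in lowercase letters."],
    "instruction_id_list": ["change_case:english_lowercase"],  # illustrative identifier
    "kwargs": [{}],
}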

Source code in aisteer360/evaluation/use_cases/instruction_following/use_case.py
class InstructionFollowing(UseCase):
    """
    Instruction following use case using the IFEval dataset.

    Evaluates a model's ability to follow specific instructions by testing adherence to
    various formatting, content, and structural constraints. Uses the IFEval dataset
    which contains prompts with explicit instructions that models must follow precisely.

    The evaluation focuses on whether models can follow instructions like:

    - Formatting requirements (e.g., "respond in exactly 3 sentences")
    - Content constraints (e.g., "include the word 'fantastic' twice")
    - Structural requirements (e.g., "use bullet points", "write in JSON format")

    Expected evaluation data format should include fields like 'prompt', 'instructions',
    'instruction_id_list', and 'kwargs' for comprehensive instruction following assessment.
    """

    def validate_evaluation_data(self, evaluation_data: dict[str, Any]) -> None:
        pass

    def generate(
        self,
        model_or_pipeline,
        tokenizer,
        gen_kwargs: dict | None = None,
        runtime_overrides: dict[tuple[str, str], str] | None = None,
    ) -> list[dict[str, Any]]:
        """Generates model responses for instruction following prompts.

        Processes evaluation data to create chat-formatted prompts and generates model responses.

        Args:
            model_or_pipeline: Either a HuggingFace model or SteeringPipeline instance to use for generation.
            tokenizer: Tokenizer for encoding/decoding text.
            gen_kwargs: Optional generation parameters passed to the model's generate method.
            runtime_overrides: Optional runtime parameter overrides for steering controls, structured as {(pipeline_name, param_name): value}.

        Returns:
            List of generation dictionaries, each containing:

                - "response": Generated text response from the model
                - "prompt": Original instruction following prompt
                - "instructions": List of specific instructions the model should follow
                - "instruction_id_list": Identifiers for each instruction type
                - "kwargs": Additional metadata for instruction evaluation
        """
        if not self.evaluation_data:
            print("No evaluation data provided.")
            return []

        gen_kwargs = dict(gen_kwargs or {})
        prompt_data = []

        for instance in self.evaluation_data:
            user_prompt = [{"role": "user", "content": instance["prompt"]}]
            prompt_data.append({"prompt": user_prompt})

        responses = batch_retry_generate(
            prompt_data=prompt_data,
            model_or_pipeline=model_or_pipeline,
            tokenizer=tokenizer,
            gen_kwargs=gen_kwargs,
            runtime_overrides=runtime_overrides,
            evaluation_data=self.evaluation_data,
        )

        generations = [
            {
                "response": response,
                "prompt": eval_data["prompt"],
                "instructions": eval_data["instructions"],
                "instruction_id_list": eval_data["instruction_id_list"],
                "kwargs": eval_data["kwargs"],
            }
            for eval_data, response in zip(self.evaluation_data, responses)
        ]

        return generations

    def evaluate(self, generations: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
        results = {}
        for metric in self.evaluation_metrics:
            results[metric.name] = metric(responses=generations)
        return results

    def export(
        self,
        profiles: dict[str, Any],
        save_dir: str,
    ) -> None:
        """Exports instruction following evaluation results to structured JSON files.

        Creates two output files:

        1. `responses.json`: Contains model responses for each steering method
        2. `scores.json`: Contains strict metric scores for each steering method

        Args:
            profiles: Dictionary containing evaluation results from all tested pipelines.
            save_dir: Directory path where results should be saved.
        """

        folder_path = Path(save_dir)
        folder_path.mkdir(parents=True, exist_ok=True)
        steering_methods, predictions, follow_instructions = [], {}, {}
        inputs = None

        for steering_method, results in profiles.items():
            generations = results.pop("generations")
            steering_methods.append(steering_method)
            predictions[steering_method] = [gen["response"] for gen in generations]

            # get instruction following details from the StrictInstruction metric
            if "StrictInstruction" in results["evaluations"]:
                follow_instructions[steering_method] = results["evaluations"][
                    "StrictInstruction"
                ].pop("follow_all_instructions")
            if not inputs:
                inputs = [gen["prompt"] for gen in generations]

        responses = []
        for idx, prompt in enumerate(inputs):
            response = {"prompt": prompt}
            for method in steering_methods:
                response[method] = predictions[method][idx]
                response[f"{method}_instr_follow"] = follow_instructions[method][idx]
            responses.append(response)

        with open(folder_path / "responses.json", "w") as f:
            json.dump(responses, f, indent=4)
        with open(folder_path / "scores.json", "w") as f:
            json.dump(profiles, f, indent=4)
evaluation_data = [(json.loads(line)) for line in f] if path.suffix == '.jsonl' else json.load(f) instance-attribute
evaluation_metrics = evaluation_metrics instance-attribute
evaluate(generations)

Required evaluation logic for model's generations via evaluation_metrics.

Source code in aisteer360/evaluation/use_cases/instruction_following/use_case.py
def evaluate(self, generations: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
    results = {}
    for metric in self.evaluation_metrics:
        results[metric.name] = metric(responses=generations)
    return results
export(profiles, save_dir)

Exports instruction following evaluation results to structured JSON files.

Creates two output files:

  1. responses.json: Contains model responses for each steering method
  2. scores.json: Contains strict metric scores for each steering method

Parameters:

Name Type Description Default
profiles dict[str, Any]

Dictionary containing evaluation results from all tested pipelines.

required
save_dir str

Directory path where results should be saved.

required
Source code in aisteer360/evaluation/use_cases/instruction_following/use_case.py
def export(
    self,
    profiles: dict[str, Any],
    save_dir: str,
) -> None:
    """Exports instruction following evaluation results to structured JSON files.

    Creates two output files:

    1. `responses.json`: Contains model responses for each steering method
    2. `scores.json`: Contains strict metric scores for each steering method

    Args:
        profiles: Dictionary containing evaluation results from all tested pipelines.
        save_dir: Directory path where results should be saved.
    """

    folder_path = Path(save_dir)
    folder_path.mkdir(parents=True, exist_ok=True)
    steering_methods, predictions, follow_instructions = [], {}, {}
    inputs = None

    for steering_method, results in profiles.items():
        generations = results.pop("generations")
        steering_methods.append(steering_method)
        predictions[steering_method] = [gen["response"] for gen in generations]

        # get instruction following details from the StrictInstruction metric
        if "StrictInstruction" in results["evaluations"]:
            follow_instructions[steering_method] = results["evaluations"][
                "StrictInstruction"
            ].pop("follow_all_instructions")
        if not inputs:
            inputs = [gen["prompt"] for gen in generations]

    responses = []
    for idx, prompt in enumerate(inputs):
        response = {"prompt": prompt}
        for method in steering_methods:
            response[method] = predictions[method][idx]
            response[f"{method}_instr_follow"] = follow_instructions[method][idx]
        responses.append(response)

    with open(folder_path / "responses.json", "w") as f:
        json.dump(responses, f, indent=4)
    with open(folder_path / "scores.json", "w") as f:
        json.dump(profiles, f, indent=4)
generate(model_or_pipeline, tokenizer, gen_kwargs=None, runtime_overrides=None)

Generates model responses for instruction following prompts.

Processes evaluation data to create chat-formatted prompts and generates model responses.

Parameters:

Name Type Description Default
model_or_pipeline

Either a HuggingFace model or SteeringPipeline instance to use for generation.

required
tokenizer

Tokenizer for encoding/decoding text.

required
gen_kwargs dict | None

Optional generation parameters passed to the model's generate method.

None
runtime_overrides dict[tuple[str, str], str] | None

Optional runtime parameter overrides for steering controls, structured as {(pipeline_name, param_name): value}.

None

Returns:

Type Description
list[dict[str, Any]]

List of generation dictionaries, each containing:

  • "response": Generated text response from the model
  • "prompt": Original instruction following prompt
  • "instructions": List of specific instructions the model should follow
  • "instruction_id_list": Identifiers for each instruction type
  • "kwargs": Additional metadata for instruction evaluation
Source code in aisteer360/evaluation/use_cases/instruction_following/use_case.py
def generate(
    self,
    model_or_pipeline,
    tokenizer,
    gen_kwargs: dict | None = None,
    runtime_overrides: dict[tuple[str, str], str] | None = None,
) -> list[dict[str, Any]]:
    """Generates model responses for instruction following prompts.

    Processes evaluation data to create chat-formatted prompts and generates model responses.

    Args:
        model_or_pipeline: Either a HuggingFace model or SteeringPipeline instance to use for generation.
        tokenizer: Tokenizer for encoding/decoding text.
        gen_kwargs: Optional generation parameters passed to the model's generate method.
        runtime_overrides: Optional runtime parameter overrides for steering controls, structured as {(pipeline_name, param_name): value}.

    Returns:
        List of generation dictionaries, each containing:

            - "response": Generated text response from the model
            - "prompt": Original instruction following prompt
            - "instructions": List of specific instructions the model should follow
            - "instruction_id_list": Identifiers for each instruction type
            - "kwargs": Additional metadata for instruction evaluation
    """
    if not self.evaluation_data:
        print("No evaluation data provided.")
        return []

    gen_kwargs = dict(gen_kwargs or {})
    prompt_data = []

    for instance in self.evaluation_data:
        user_prompt = [{"role": "user", "content": instance["prompt"]}]
        prompt_data.append({"prompt": user_prompt})

    responses = batch_retry_generate(
        prompt_data=prompt_data,
        model_or_pipeline=model_or_pipeline,
        tokenizer=tokenizer,
        gen_kwargs=gen_kwargs,
        runtime_overrides=runtime_overrides,
        evaluation_data=self.evaluation_data,
    )

    generations = [
        {
            "response": response,
            "prompt": eval_data["prompt"],
            "instructions": eval_data["instructions"],
            "instruction_id_list": eval_data["instruction_id_list"],
            "kwargs": eval_data["kwargs"],
        }
        for eval_data, response in zip(self.evaluation_data, responses)
    ]

    return generations
validate_evaluation_data(evaluation_data)

Optional validation of the evaluation dataset.

Source code in aisteer360/evaluation/use_cases/instruction_following/use_case.py
def validate_evaluation_data(self, evaluation_data: dict[str, Any]) -> None:
    pass

utils

generation_utils
BATCH_SIZE = 64 module-attribute
apply_chat_template(tokenizer, batch, **kwargs)

Constructs templated prompts for each batch element based on the following cases:

  1. If the model's tokenizer does not support chat_template, return the string as is.
  2. If it supports chat_template, check each instance of the batch and construct chat messages if needed:
       • A plain string is wrapped as the 'content' of a 'user' message.
       • A list of dictionaries with 'role' and 'content' keys is used as is.
     The chat template is then applied and the templated prompts are returned.
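
A small illustrative call showing both accepted prompt forms (the tokenizer checkpoint name is an assumption; any chat-capable tokenizer should behave similarly):

from transformers import AutoTokenizer

from aisteer360.evaluation.utils.generation_utils import apply_chat_template

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative chat-capable tokenizer

batch = [
    {"prompt": "What is 2 + 2?"},                                  # plain string
    {"prompt": [{"role": "user", "content": "What is 2 + 2?"}]},   # pre-built chat messages
]
templated = apply_chat_template(tokenizer, batch)
print(templated[0])  # chat-templated string with the generation prompt appended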

Source code in aisteer360/evaluation/utils/generation_utils.py
def apply_chat_template(tokenizer, batch, **kwargs) -> list:
    """
    Constructs templated prompts for each batch element based on the following cases:
    1. If the model's tokenizer does not support chat_template, return the string as is.
    2. If it supports chat_template, check each instance of the batch and construct chat messages if needed:
        - A plain string is wrapped as the 'content' of a 'user' message.
        - A list of dictionaries with 'role' and 'content' keys is used as is.
        The chat template is then applied and the templated prompts are returned.
    """

    template_prompts = []
    for idx, item in enumerate(batch):
        prompt_obj = item["prompt"]
        if not hasattr(tokenizer, "apply_chat_template"):
            template_prompts.append(str(prompt_obj))
        else:
            if isinstance(prompt_obj, str):
                messages = [{"role": "user", "content": prompt_obj}]
            elif (
                isinstance(prompt_obj, list)
                and prompt_obj
                and isinstance(prompt_obj[0], dict)
            ):
                if not all("role" in m and "content" in m for m in prompt_obj):
                    raise ValueError(
                        f"Prompt {idx}: every chat message dict must have 'role' and 'content' keys."
                    )
                messages = prompt_obj
            else:
                raise TypeError(
                    f"Prompt {idx}: must be str or list of chat messages as list[dict[str, str]] "
                    f"(got {type(prompt_obj).__name__})."
                )

            chat_str = tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=False,
            )
            template_prompts.append(chat_str)
    return template_prompts
batch_retry_generate(prompt_data, model_or_pipeline, tokenizer, gen_kwargs=None, runtime_overrides=None, evaluation_data=None, parse_fn=None, max_retries=2, return_raw=False)

Generate chat completions with optional parsing/retry logic.

The function retries only the prompts whose outputs fail parse_fn (up to max_retries); the return value is a list of parsed objects, with None for any prompt whose output never parses.

If return_raw is True the function instead returns a tuple (parsed_list, raw_list).
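
A minimal sketch of the parse-and-retry loop (the checkpoint name and the parser are illustrative assumptions):

from transformers import AutoModelForCausalLM, AutoTokenizer

from aisteer360.evaluation.utils.generation_utils import batch_retry_generate

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")


def parse_digit(text):
    """Hypothetical parser: return the first digit in the reply, or None to trigger a retry."""
    return next((char for char in text if char.isdigit()), None)


parsed = batch_retry_generate(
    prompt_data=[{"prompt": "Reply with a single digit between 1 and 9."}],
    model_or_pipeline=model,
    tokenizer=tokenizer,
    gen_kwargs={"max_new_tokens": 5, "do_sample": False},
    parse_fn=parse_digit,
    max_retries=2,
)
print(parsed)  # e.g. ["7"], or [None] if no retry produced a parseable reply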

Source code in aisteer360/evaluation/utils/generation_utils.py
def batch_retry_generate(
    prompt_data: Sequence[dict[str, Any]],
    model_or_pipeline: PreTrainedModel | SteeringPipeline,
    tokenizer: PreTrainedTokenizerBase,
    gen_kwargs: dict[str, Any] | None = None,
    runtime_overrides: dict[tuple[str, str], str] | None = None,
    evaluation_data: dict | None = None,
    parse_fn: Callable[[str, dict[str, Any]], Any | None] | None = None,
    max_retries: int = 2,
    return_raw: bool = False,
) -> list[Any] | tuple[list[Any], list[str]]:
    """
    Generate chat completions with optional parsing/retry logic.

    The function retries only the prompts whose outputs fail parse_fn (up to max_retries); the return value is a list
    of parsed objects, with None for any prompt whose output never parses.

    If return_raw is True the function instead returns a tuple (parsed_list, raw_list).
    """

    missing_prompt = [i for i, item in enumerate(prompt_data) if "prompt" not in item]
    if missing_prompt:
        raise ValueError(f"'prompt' key missing for {len(missing_prompt)} instances")

    gen_kwargs = dict(gen_kwargs or {})
    is_pipeline = isinstance(model_or_pipeline, SteeringPipeline)

    try:
        device_obj = model_or_pipeline.device
    except Exception as e:
        raise RuntimeError(f"Unable to identify model or pipeline device - {e}")

    if is_pipeline:
        responses = chat_generate_pipeline(
            batch=prompt_data,
            pipeline=model_or_pipeline,
            tokenizer=tokenizer,
            device=device_obj,
            gen_kwargs=gen_kwargs,
            runtime_overrides=runtime_overrides,
            evaluation_data=evaluation_data,
        )
    else:
        responses = chat_generate_model(
            batch=prompt_data,
            model=model_or_pipeline,
            tokenizer=tokenizer,
            device=device_obj,
            gen_kwargs=gen_kwargs,
        )

    if parse_fn is not None:
        # parse and retry
        parsed_responses = [parse_fn(response) for response in responses]
        retry_indices = [i for i, v in enumerate(parsed_responses) if v is None]
    else:
        parsed_responses = responses
        retry_indices = []

    tries = 0
    while retry_indices and tries < max_retries:
        retry_prompts = [prompt_data[i] for i in retry_indices]

        if is_pipeline:
            retry_raw = chat_generate_pipeline(
                batch=retry_prompts,
                pipeline=model_or_pipeline,
                tokenizer=tokenizer,
                device=device_obj,
                gen_kwargs=gen_kwargs,
                runtime_overrides=runtime_overrides,
                evaluation_data=evaluation_data,
            )
        else:
            retry_raw = chat_generate_model(
                batch=retry_prompts,
                model=model_or_pipeline,
                tokenizer=tokenizer,
                device=device_obj,
                gen_kwargs=gen_kwargs,
            )

        for local_i, global_i in enumerate(retry_indices):
            responses[global_i] = retry_raw[local_i]
            parsed_responses[global_i] = parse_fn(retry_raw[local_i])

        retry_indices = [i for i, v in enumerate(parsed_responses) if v is None]
        tries += 1

    return (parsed_responses, responses) if return_raw else parsed_responses
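
A minimal usage sketch (not part of the library source; the checkpoint name, prompts, and the parse_json helper are illustrative placeholders):

import json

from transformers import AutoModelForCausalLM, AutoTokenizer

from aisteer360.evaluation.utils.generation_utils import batch_retry_generate

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # padding is needed for batched generation
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt_data = [
    {"prompt": "Return a JSON object with a single key 'answer' containing 2 + 2."},
    {"prompt": "Return a JSON object with a single key 'answer' containing the capital of France."},
]

def parse_json(response: str):
    """Return the parsed object, or None so that the prompt is retried."""
    try:
        return json.loads(response)
    except json.JSONDecodeError:
        return None

parsed = batch_retry_generate(
    prompt_data=prompt_data,
    model_or_pipeline=model,
    tokenizer=tokenizer,
    gen_kwargs={"max_new_tokens": 64, "do_sample": False},
    parse_fn=parse_json,
    max_retries=2,
)
# Entries whose outputs never parse remain None in `parsed`.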
chat_generate_model(batch, model, tokenizer, device, gen_kwargs=None)

Batch generate on a model with chunking to prevent OOM. Each instance of the batch must have a 'prompt', which can be either:

- a plain string, in which case the chat template is applied, or
- a message dict ('role' and 'content' keys) already in chat format.

Source code in aisteer360/evaluation/utils/generation_utils.py
def chat_generate_model(
    batch: Sequence[dict[str, Any]],
    model,
    tokenizer,
    device: str | torch.device,
    gen_kwargs: dict[str, Any] | None = None,
) -> list[str]:
    """
    Batch generate on model with chunking to prevent OOM.
    Each instance of the batch must have a 'prompt', which can be either:
    - a plain string, in which case the chat template is applied, or
    - a message dict ('role' and 'content' keys) already in chat format.
    """

    prompts = apply_chat_template(tokenizer, batch)
    decoded_outputs = []

    for i in range(0, len(prompts), BATCH_SIZE):
        batch_prompts = prompts[i:i + BATCH_SIZE]

        try:
            inputs = tokenizer(
                batch_prompts, return_tensors="pt", padding=True, truncation=True
            ).to(device)
            with torch.no_grad():
                outputs = model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    **(gen_kwargs or {}),
                )
            start = inputs["input_ids"].shape[1]

            batch_decoded = tokenizer.batch_decode(outputs[:, start:], skip_special_tokens=True)
            decoded_outputs.extend(batch_decoded)

        except Exception as e:
            print(f"Issue with model generation at batch {i//BATCH_SIZE}: {e}")
            print("Hint - Do not apply chat template to your prompts.")
            raise

    return decoded_outputs
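
A minimal usage sketch (checkpoint name and prompts are placeholders; chunking is governed by the module-level BATCH_SIZE constant):

from transformers import AutoModelForCausalLM, AutoTokenizer

from aisteer360.evaluation.utils.generation_utils import chat_generate_model

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

batch = [
    {"prompt": "Summarize the water cycle in one sentence."},
    {"prompt": "Name three prime numbers."},
]

completions = chat_generate_model(
    batch=batch,
    model=model,
    tokenizer=tokenizer,
    device=model.device,
    gen_kwargs={"max_new_tokens": 32, "do_sample": False},
)
# completions is a list of decoded strings, one per prompt, with the prompt tokens stripped.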
chat_generate_pipeline(batch, pipeline, tokenizer, device, gen_kwargs=None, runtime_overrides=None, evaluation_data=None)

Batch generate on a steering pipeline, with optional per-prompt runtime overrides and chunking to prevent OOM.

Source code in aisteer360/evaluation/utils/generation_utils.py
def chat_generate_pipeline(
    batch: Sequence[dict[str, Any]],
    pipeline,
    tokenizer,
    device: str | torch.device,
    gen_kwargs: dict[str, Any] | None = None,
    runtime_overrides: dict[tuple[str, str], str] | None = None,
    evaluation_data: list[dict] | None = None,
) -> list[str]:
    """Generate on pipeline."""

    if runtime_overrides is not None and evaluation_data is None:
        raise ValueError(
            "evaluation_data must be provided when runtime_overrides are supplied."
        )

    # create runtime_kwargs from runtime_overrides and evaluation_data
    runtime_kwargs_flat = {}
    if runtime_overrides:
        runtime_kwargs = {}
        for control in pipeline.controls:
            control_name = control.__class__.__name__
            if control_name in runtime_overrides:
                runtime_kwargs[control_name] = _map_runtime_overrides(
                    overrides=runtime_overrides[control_name],
                    data=evaluation_data,
                )

        # flatten and convert to list of dicts; todo: maintain control names to avoid possible collisions
        for kwargs in runtime_kwargs.values():
            for var, arg in kwargs.items():
                if var in runtime_kwargs_flat:
                    raise ValueError(
                        f"Duplicate runtime_kwargs for: {var!r}; ensure controls have distinct variables."
                    )
                runtime_kwargs_flat[var] = arg
        runtime_kwargs_flat = _runtime_kwargs_to_list(runtime_kwargs_flat)

    # runtime_overrides may cover only a subset of steering methods; if none apply to this
    # pipeline's controls, the block above leaves runtime_kwargs_flat empty, so fall back to per-prompt None.
    if not runtime_overrides or not runtime_kwargs_flat:
        runtime_kwargs_flat = [None] * len(batch)

    prompts = apply_chat_template(tokenizer, batch)
    decoded_outputs = []

    for i in range(0, len(prompts), BATCH_SIZE):
        batch_prompts = prompts[i:i + BATCH_SIZE]
        batch_runtime_kwargs = runtime_kwargs_flat[i:i + BATCH_SIZE]

        inputs = tokenizer(
            batch_prompts, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]

        # generate
        # todo-future: run batch as dictated by availability of batch processing of controls in pipeline
        generations = []
        with torch.no_grad():
            for j in range(len(batch_prompts)):
                out = pipeline.generate(
                    input_ids=input_ids[j].unsqueeze(0),
                    attention_mask=attention_mask[j].unsqueeze(0),
                    runtime_kwargs=batch_runtime_kwargs[j],
                    **(gen_kwargs or {}),
                )
                generations.append(out)

        tokens = [generation.squeeze(0).tolist() for generation in generations]
        padded = tokenizer.pad({"input_ids": tokens}, padding=True, return_tensors="pt")
        outputs = padded["input_ids"]
        batch_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        decoded_outputs.extend(batch_decoded)

    return decoded_outputs
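
A minimal usage sketch assuming a SteeringPipeline constructed with no controls (so no-op defaults apply) and steered before use; the checkpoint name and prompts are placeholders. Without runtime_overrides, each prompt is generated with runtime_kwargs=None:

from transformers import AutoTokenizer

from aisteer360.algorithms.core.steering_pipeline import SteeringPipeline
from aisteer360.evaluation.utils.generation_utils import chat_generate_pipeline

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

pipeline = SteeringPipeline(model_name_or_path=model_name)  # no controls -> no-op defaults
pipeline.steer()  # must be called before generation

batch = [
    {"prompt": "Give one tip for writing clear documentation."},
    {"prompt": "List two common uses of the logarithm."},
]

completions = chat_generate_pipeline(
    batch=batch,
    pipeline=pipeline,
    tokenizer=tokenizer,
    device=pipeline.device,
    gen_kwargs={"max_new_tokens": 32, "do_sample": False},
)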
metric_utils
to_1d_array(result, n_examples)

Normalize a metric's result into a 1d numpy array of length n_examples.

Source code in aisteer360/evaluation/utils/metric_utils.py
def to_1d_array(result: Any, n_examples: int) -> np.ndarray:
    """
    Normalize a metric's result into a 1d numpy array of length n_examples.
    """

    if isinstance(result, dict):
        if len(result) != 1:
            raise ValueError(f"Metric returned multiple values {list(result.keys())}; UseCase.evaluate expects exactly one.")
        result = next(iter(result.values()))

    array = np.asarray(result, dtype=float)
    if array.ndim == 0:
        array = np.full(n_examples, array.item(), dtype=float)
    elif array.ndim == 1:
        if array.size != n_examples:
            raise ValueError(f"Metric produced {array.size} values, but {n_examples} examples were expected.")
    else:
        raise ValueError(f"Metric returned an array with shape {array.shape}; only scalars or 1‑D arrays are supported.")

    return array
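
A quick sketch of the accepted result shapes (illustrative values, n_examples=3):

from aisteer360.evaluation.utils.metric_utils import to_1d_array

n_examples = 3

to_1d_array(0.5, n_examples)                            # scalar -> array([0.5, 0.5, 0.5])
to_1d_array([0.1, 0.2, 0.3], n_examples)                # per-example values; length must match
to_1d_array({"accuracy": [1.0, 0.0, 1.0]}, n_examples)  # single-key dict is unwrapped first
# A dict with more than one key, a length mismatch, or a 2-D array raises ValueError.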

utils

model_utils

find_project_root(current_path)

Finds the project root directory by walking up from current_path until a pyproject.toml file is found.

Source code in aisteer360/utils/model_utils.py
def find_project_root(current_path: Path) -> Path:
    """Finds root dir by looking for pyproject.toml"""
    while current_path.parent != current_path:
        if (current_path / 'pyproject.toml').exists():
            return current_path
        current_path = current_path.parent
    raise FileNotFoundError("no pyproject.toml found")
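
A usage sketch (assumes a pyproject.toml exists somewhere above the calling file; the path joined below is hypothetical):

from pathlib import Path

from aisteer360.utils.model_utils import find_project_root

root = find_project_root(Path(__file__).resolve())
config_path = root / "configs" / "models.yaml"  # hypothetical file under the project root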
is_valid_model(config, model_id, service)

Checks whether model_id appears in the config's 'model-config' section and whether service is listed under that model's 'access' entry.

Source code in aisteer360/utils/model_utils.py
def is_valid_model(config, model_id, service):
    model_config = config['model-config']
    return (
            model_id in model_config and
            service in model_config[model_id]['access']
    )
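
A usage sketch with a hypothetical config; the only structure the function relies on is a 'model-config' section mapping model ids to entries whose 'access' value supports membership tests (here, a list of service names):

from aisteer360.utils.model_utils import is_valid_model

config = {
    "model-config": {
        "ibm/granite-3.3-8b-instruct": {"access": ["hf", "watsonx"]},  # hypothetical entry
    }
}

is_valid_model(config, "ibm/granite-3.3-8b-instruct", "hf")      # True
is_valid_model(config, "ibm/granite-3.3-8b-instruct", "openai")  # False (service not listed)
is_valid_model(config, "unknown/model", "hf")                    # False (model id not present)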