Skip to content

Adding a state control method

Required override: get_hooks

State control methods work by defining hooks that are registered on the base model before inference. In this tutorial, we’ll implement an `ActivationBias` method that adds a fixed scalar bias (`alpha`) to the hidden-state output of a specified transformer layer.

First, create the registry file, which pairs the control class with its arguments class under the `state_control` category:

from .control import ActivationBias
from .args import ActivationBiasArgs

# Registry metadata pairing the ActivationBias control with its argument dataclass
# under the "state_control" category.
REGISTRY_ENTRY = dict(
    category="state_control",
    name="activation_bias",
    control=ActivationBias,
    args=ActivationBiasArgs,
)

Next, define the arguments class, which declares the method's two arguments: the transformer layer to patch (`layer_idx`) and the bias magnitude (`alpha`):

from dataclasses import dataclass, field
from aisteer360.algorithms.core.base_args import BaseArgs


@dataclass
class ActivationBiasArgs(BaseArgs):
    """Arguments for the ActivationBias state control.

    Attributes:
        layer_idx: Index of the transformer block whose output is biased; must be >= 0.
        alpha: Constant added to that block's hidden-state output.
    """

    layer_idx: int = field(
        metadata={"help": "Transformer block to patch."},
        default=0,
    )
    alpha: float = field(
        metadata={"help": "Bias magnitude."},
        default=0.02,
    )

    def __post_init__(self):
        # Only the lower bound is validated here; an index past the model's
        # last layer is not caught at construction time.
        is_negative = self.layer_idx < 0
        if is_negative:
            raise ValueError("layer_idx must be non-negative")

Lastly, the control is implemented as follows:

import torch
from transformers import PreTrainedModel, PreTrainedTokenizer

from aisteer360.algorithms.state_control.base import StateControl, HookSpec
from aisteer360.algorithms.state_control.activation_bias.args import ActivationBiasArgs


class ActivationBias(StateControl):
    """Adds a constant bias ``alpha`` to the hidden states output by one transformer layer.

    ``self.layer_idx`` and ``self.alpha`` are read when building/running the hook; presumably the
    ``StateControl`` base class populates them from ``ActivationBiasArgs`` (see ``Args``) — TODO confirm.
    """

    Args = ActivationBiasArgs

    # class attributes (filled by steer)
    model: PreTrainedModel | None = None
    tokenizer: PreTrainedTokenizer | None = None
    device: torch.device | str | None = None

    def steer(
            self,
            model: PreTrainedModel = None,
            tokenizer: PreTrainedTokenizer = None,
            **kwargs) -> None:
        """Store the model, its tokenizer, and the model's device on the control.

        Args:
            model: The model being steered. Required despite the ``None`` default.
            tokenizer: The model's tokenizer; stored but not otherwise used by this control.
            **kwargs: Ignored.

        Raises:
            ValueError: If ``model`` is None.
        """
        if model is None:
            # Fail with a clear message instead of an AttributeError on ``None.parameters()``.
            raise ValueError("ActivationBias.steer requires a model")
        self.model = model
        self.tokenizer = tokenizer
        # Device of the model's first parameter.
        self.device = next(model.parameters()).device

    def get_hooks(
            self,
            input_ids: torch.Tensor,
            runtime_kwargs,
            **__
    ) -> dict[str, list[HookSpec]]:
        """Returns a forward hook that adds alpha to a specific layer's output.

        Args:
            input_ids (torch.Tensor): Input tensor (unused).
            runtime_kwargs: Optional runtime parameters (unused).

        Returns:
            dict[str, list[HookSpec]]: A dictionary mapping hook phases ("pre", "forward", "backward") to lists of hook
            specifications. Each HookSpec contains:
              - "module": The name of the module to hook
              - "hook_func": The hook function to apply (pre, forward, or backward)
        """

        def fwd_hook(module, args, kwargs, output):
            # Every branch returns a NEW object rather than modifying ``output`` in place.
            # In particular, ``dict.copy()`` is shallow, so ``+=`` on the copy's tensor would
            # also mutate the original layer output tensor.
            if isinstance(output, tuple):
                # Hidden states are the first element; other elements are passed through unchanged.
                return (output[0] + self.alpha,) + output[1:]
            elif isinstance(output, dict):
                output = output.copy()
                output["hidden_states"] = output["hidden_states"] + self.alpha
                return output
            else:  # direct tensor
                return output + self.alpha

        return {
            "pre": [],
            "forward": [{
                # NOTE: assumes the Hugging Face Llama-style layout (decoder blocks at ``model.layers.<i>``);
                # other architectures may name their blocks differently.
                "module": f"model.layers.{self.layer_idx}",
                "hook_func": fwd_hook,
            }],
            "backward": [],
        }

The returned hooks are registered on the model by the `register_hooks` method of the state control base class (`aisteer360/algorithms/state_control/base.py`) so that they run on every `generate` call. The control can then be used as follows:

from aisteer360.algorithms.state_control.activation_bias.control import ActivationBias
from aisteer360.algorithms.core.steering_pipeline import SteeringPipeline

MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"

# Add a bias of 0.03 to the output of transformer block 2.
activation_bias_control = ActivationBias(layer_idx=2, alpha=0.03)

activation_bias_pipeline = SteeringPipeline(
    model_name_or_path=MODEL_NAME,
    controls=[activation_bias_control],
)
# Apply the configured controls before generating.
activation_bias_pipeline.steer()

prompt = "What should I do in Prague?"
chat = activation_bias_pipeline.tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    tokenize=False,
    add_generation_prompt=True,
)
# The rendered chat template already contains its special tokens (Llama 3's template starts with
# <|begin_of_text|>), so stop the tokenizer from prepending a second BOS token.
inputs = activation_bias_pipeline.tokenizer(chat, return_tensors="pt", add_special_tokens=False)

print(activation_bias_pipeline.generate_text(inputs.input_ids, max_new_tokens=50))