CAST

aisteer360.algorithms.state_control.cast

args

_check_layer_ids(layer_ids)

Checks the validity of the layer_ids list.

Raises ValueError if any element is not an int, if any element is negative, or if the elements are not unique.

Source code in aisteer360/algorithms/state_control/cast/args.py
def _check_layer_ids(layer_ids):
    """
    Checks the validity of the layer_ids list.

    Raises ValueError if any element is not an int, if any element is negative, or if the elements are not unique.
    """
    for ii, vv in enumerate(layer_ids):
        if not isinstance(vv, int):
            raise ValueError(f"invalid layer_id[{ii}]={vv} is of type {type(vv)} instead of int.")
        if vv < 0:
            raise ValueError(f"invalid layer_id[{ii}]={vv} < 0, should be >=0.")

    if len(set(layer_ids)) != len(layer_ids):
        raise ValueError(f"{layer_ids=} has duplicate entries. layers ids should be unique")
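
A short usage sketch of the validator may help. The function body is repeated from the listing above so the snippet runs on its own; `is_valid` is a hypothetical helper added only for illustration, and the sample inputs are made up.

```python
def _check_layer_ids(layer_ids):
    """Copy of the validator above, repeated so this snippet is standalone."""
    for ii, vv in enumerate(layer_ids):
        if not isinstance(vv, int):
            raise ValueError(f"invalid layer_id[{ii}]={vv} is of type {type(vv)} instead of int.")
        if vv < 0:
            raise ValueError(f"invalid layer_id[{ii}]={vv} < 0, should be >=0.")
    if len(set(layer_ids)) != len(layer_ids):
        raise ValueError(f"{layer_ids=} has duplicate entries. layers ids should be unique")


def is_valid(layer_ids):
    """Hypothetical helper: True if the validator accepts the list."""
    try:
        _check_layer_ids(layer_ids)
        return True
    except ValueError:
        return False


print(is_valid([7, 8, 9]))   # unique, non-negative ints
print(is_valid([7, -1]))     # contains a negative index
print(is_valid([7, 7]))      # contains a duplicate
print(is_valid([7, 8.0]))    # contains a non-int element
```

Only the first call passes; each of the other three trips one of the checks.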

control

CAST

Bases: StateControl

Implementation of CAST (Conditional Activation Steering) from Lee et al., 2024.

CAST enables selective control of LLM behavior by conditionally applying activation steering based on input context, allowing fine-grained control without affecting responses to non-targeted content.

The method operates in two phases:

  1. Condition Detection: Analyzes hidden state activation patterns at specified layers during inference to detect if the input matches target conditions. This is done by projecting hidden states onto a condition subspace and computing similarity scores against a threshold.

  2. Conditional Behavior Modification: When conditions are met, applies steering vectors to hidden states at designated behavior layers. This selectively modifies the model's internal representations to produce desired behavioral changes while preserving normal functionality for non-matching inputs.
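
The two phases can be sketched in a few lines of plain Python. This is an illustrative sketch rather than the library's implementation: lists stand in for tensors, all function names are hypothetical, a single condition direction is used, steering is modeled as simple addition, and the condition is treated as met when the similarity is at or above the threshold.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


def condition_similarity(hidden, condition_dir):
    """Project `hidden` onto the condition direction, squash with tanh, and
    return the cosine similarity between `hidden` and that projection.

    Assumes `hidden` has a non-zero component along `condition_dir`;
    otherwise the projection is the zero vector and the cosine is undefined.
    """
    # Rank-1 projection: (v v^T / v.v) h == v * (v.h / v.v)
    scale = dot(condition_dir, hidden) / dot(condition_dir, condition_dir)
    projected = [math.tanh(scale * c) for c in condition_dir]
    return cosine(hidden, projected)


def cast_step(hidden, condition_dir, behavior_dir, threshold, strength=1.0):
    """Phase 1: detect the condition. Phase 2: steer only if it was met."""
    similarity = condition_similarity(hidden, condition_dir)
    met = similarity >= threshold
    if met:
        hidden = [h + strength * b for h, b in zip(hidden, behavior_dir)]
    return hidden, met


condition_dir = [1.0, 0.0, 0.0]
behavior_dir = [0.0, 0.0, 1.0]

# Mostly aligned with the condition direction: steering is applied.
steered, met_aligned = cast_step([2.0, 0.1, 0.0], condition_dir, behavior_dir, threshold=0.5)

# Mostly orthogonal to the condition direction: the input is left untouched.
untouched, met_other = cast_step([0.1, 2.0, 0.0], condition_dir, behavior_dir, threshold=0.5)
```

The point of the split is that the behavior vector never touches inputs whose hidden state does not resemble the condition direction.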

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| condition_vector | SteeringVector | Steering vector defining the condition subspace for detecting target input patterns. | None |
| behavior_vector | SteeringVector | Steering vector applied to modify behavior when conditions are met. | None |
| condition_layer_ids | list[int] | Layer indices where condition detection occurs. | None |
| behavior_layer_ids | list[int] | Layer indices where behavior modification is applied. | None |
| condition_vector_threshold | float | Similarity threshold for condition detection. Higher values require stronger pattern matches. | 0.5 |
| behavior_vector_strength | float | Scaling factor for the behavior steering vector. Controls the intensity of behavioral modification. | 1.0 |
| condition_comparator_threshold_is | str | Whether the threshold must be 'larger' or 'smaller' than the similarity for the condition to be met: 'larger' means similarity < threshold, 'smaller' means similarity >= threshold. | 'larger' |
| condition_threshold_comparison_mode | str | How to aggregate hidden states for comparison ('mean' or 'last'). | 'mean' |
| apply_behavior_on_first_call | bool | Whether to apply behavior steering on the first forward pass. | True |
| use_ooi_preventive_normalization | bool | Apply out-of-distribution preventive normalization to maintain hidden state magnitudes. | False |
| use_explained_variance | bool | Scale steering vectors by their explained variance for adaptive layer-wise control. | False |
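
To make the two string-valued options concrete, here is a small sketch of how `condition_threshold_comparison_mode` and `condition_comparator_threshold_is` behave, following the branches in the source listing on this page. The function names `aggregate` and `condition_met` are hypothetical, nested lists stand in for tensors, and raising on an unknown aggregation mode is a choice made for this sketch.

```python
def aggregate(token_states, mode="mean"):
    """Collapse per-token hidden states [seq_len][hidden_dim] into one vector."""
    if mode == "mean":
        n = len(token_states)
        return [sum(column) / n for column in zip(*token_states)]
    if mode == "last":
        return list(token_states[-1])
    raise ValueError(f"invalid mode {mode!r}")


def condition_met(similarity, threshold=0.5, threshold_is="larger"):
    """'larger': threshold is larger than the similarity; 'smaller': the reverse."""
    if threshold_is == "smaller":
        return similarity >= threshold
    if threshold_is == "larger":
        return similarity < threshold
    raise ValueError(f"invalid comparator {threshold_is!r}")


states = [[1.0, 2.0], [3.0, 4.0]]       # two tokens, hidden_dim = 2
mean_state = aggregate(states, "mean")  # average over tokens
last_state = aggregate(states, "last")  # final token only
```

With the same similarity and threshold, the two comparator settings give opposite answers, so the threshold should be tuned together with the comparator.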

Reference:

  • "Programming Refusal with Conditional Activation Steering" Bruce W. Lee, Inkit Padhi, Karthikeyan Natesan Ramamurthy, Erik Miehling, Pierre Dognin, Manish Nagireddy, Amit Dhurandhar https://arxiv.org/abs/2409.05907
Source code in aisteer360/algorithms/state_control/cast/control.py
class CAST(StateControl):
    """
    Implementation of CAST (Conditional Activation Steering) from Lee et al., 2024.

    CAST enables selective control of LLM behavior by conditionally applying activation steering based on input context,
    allowing fine-grained control without affecting responses to non-targeted content.

    The method operates in two phases:

    1. **Condition Detection**: Analyzes hidden state activation patterns at specified layers during inference to detect
        if the input matches target conditions. This is done by projecting hidden states onto a condition subspace and
        computing similarity scores against a threshold.

    2. **Conditional Behavior Modification**: When conditions are met, applies steering vectors to hidden states at
        designated behavior layers. This selectively modifies the model's internal representations to produce desired
        behavioral changes while preserving normal functionality for non-matching inputs.

    Args:
        condition_vector (SteeringVector, optional): Steering vector defining the condition subspace for detecting
            target input patterns. Defaults to None.
        behavior_vector (SteeringVector, optional): Steering vector applied to modify behavior when conditions are met.
            Defaults to None.
        condition_layer_ids (list[int], optional): Layer indices where condition detection occurs. Defaults to None.
        behavior_layer_ids (list[int], optional): Layer indices where behavior modification is applied. Defaults to None.
        condition_vector_threshold (float, optional): Similarity threshold for condition detection. Higher values
            require stronger pattern matches. Defaults to 0.5.
        behavior_vector_strength (float, optional): Scaling factor for the behavior steering vector. Controls the
            intensity of behavioral modification. Defaults to 1.0.
        condition_comparator_threshold_is (str, optional): Whether the threshold must be 'larger' or 'smaller' than
            the similarity for the condition to be met: 'larger' means similarity < threshold, 'smaller' means
            similarity >= threshold. Defaults to 'larger'.
        condition_threshold_comparison_mode (str, optional): How to aggregate hidden states for comparison ('mean'
            or 'last'). Defaults to 'mean'.
        apply_behavior_on_first_call (bool, optional): Whether to apply behavior steering on the first forward pass.
            Defaults to True.
        use_ooi_preventive_normalization (bool, optional): Apply out-of-distribution preventive normalization to
            maintain hidden state magnitudes. Defaults to False.
        use_explained_variance (bool, optional): Scale steering vectors by their explained variance for adaptive
            layer-wise control. Defaults to False.

    Reference:

    - "Programming Refusal with Conditional Activation Steering"
    Bruce W. Lee, Inkit Padhi, Karthikeyan Natesan Ramamurthy, Erik Miehling, Pierre Dognin, Manish Nagireddy, Amit Dhurandhar
    [https://arxiv.org/abs/2409.05907](https://arxiv.org/abs/2409.05907)
    """

    Args = CASTArgs

    # placeholders
    model: PreTrainedModel | None = None
    tokenizer: PreTrainedTokenizer | None = None
    device: torch.device | str | None = None

    # layers list reference
    _layers: list | None = None
    _layers_names: list | None = None
    _layers_states: dict[int, LayerArgs] | None = None

    # Boolean lists for condition and behavior layers
    _condition_layers: dict[int, bool] | None = None
    _behavior_layers: dict[int, bool] | None = None

    # Logic flags
    _condition_met: dict[int, bool] = defaultdict(bool)
    _forward_calls: dict[int, int] = defaultdict(int)

    # condition similarity record
    _condition_similarities: dict = defaultdict(lambda: defaultdict(float))

    def reset(self):
        """Reset internal state tracking between generation calls.

        Clears condition detection flags, forward call counters, and similarity scores.
        """
        self._condition_met = defaultdict(bool)
        self._forward_calls = defaultdict(int)
        self._condition_similarities = defaultdict(lambda: defaultdict(float))

    def steer(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer | None = None,
        **__
    ) -> PreTrainedModel:
        """Initialize CAST by configuring condition detection and behavior modification layers.

        Sets up steering vectors, condition projectors, and layer-specific parameters for conditional activation
        steering. Pre-computes projection matrices and behavior vectors.

        Args:
            model (PreTrainedModel): The base language model to be steered.
            tokenizer (PreTrainedTokenizer | None): Tokenizer (currently unused but maintained
                for API consistency). If None, attempts to retrieve from model attributes.
            **__: Additional arguments (unused).

        Returns:
            PreTrainedModel: The input model, unchanged.
        """
        self.model = model
        self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
        self.device = next(model.parameters()).device

        self._setup(self.model)

        return model

    def get_hooks(
            self,
            input_ids: torch.Tensor,
            runtime_kwargs: dict | None,
            **__,
    ) -> dict[str, list]:
        """Create pre-forward hooks for conditional activation steering.

        Generates hook specifications for all model layers that will conditionally detect patterns and apply behavior
        modifications during the forward pass.

        Args:
            input_ids (torch.Tensor): Input token IDs (unused but required by interface).
            runtime_kwargs (dict | None): Runtime parameters (currently unused).
            **__: Additional arguments (unused).

        Returns:
            dict[str, list]: Hook specifications with "pre", "forward", "backward" keys.
                Only "pre" hooks are populated with CAST steering logic.
        """

        hooks: dict[str, list] = {"pre": [], "forward": [], "backward": []}
        for layer_id, layer_name in enumerate(self._layers_names):
            hooks["pre"].append(
                {
                    "module": layer_name,  # "model.layers.0"
                    "hook_func": partial(
                        self._cast_pre_hook,
                        layer_id=layer_id,
                    ),
                }
            )

        return hooks

    def get_model_layer_list(self, model: PreTrainedModel) -> tuple[list, list[str]]:
        """Extract the list of transformer layers from the model.

        Args:
            model (PreTrainedModel): Model to extract layers from.

        Returns:
            tuple[list, list[str]]: The model's layer modules, and the module name of each layer
                (for example "model.layers.0").
        """
        layers = []
        layers_names = []

        model_layers = None
        model_layers_prefix = ''

        if hasattr(model, "model"):  # mistral-, llama-, gemma-like models
            model_layers = model.model.layers
            model_layers_prefix = "model.layers"
        elif hasattr(model, "transformer"):  # gpt2-like models
            model_layers = model.transformer.h
            model_layers_prefix = "transformer.h"
        else:
            raise ValueError(f"Don't know how to get layer list from model for {type(model)=}")

        for idx, layer in enumerate(model_layers):
            layers.append(layer)
            layers_names.append(f"{model_layers_prefix}.{idx}")

        return layers, layers_names

    def _setup(self, model: PreTrainedModel):
        """Configure all CAST internals for the given model.

        Pre-computes steering vectors, condition projectors, and layer configurations to minimize runtime overhead during generation.

        Process:

        1. Identifies condition and behavior layers from configuration
        2. Computes condition projection matrices for detection layers
        3. Prepares scaled behavior vectors for modification layers
        4. Stores layer-specific parameters in _layer_states

        Args:
            model (PreTrainedModel): Model to configure CAST for.
        """
        self._layers, self._layers_names = self.get_model_layer_list(model)
        num_layers = len(self._layers)

        # Creating dicts for condition and behavior layers
        condition_layers = [False] * num_layers
        behavior_layers = [False] * num_layers

        if self.condition_vector is not None and self.condition_layer_ids is not None:
            for layer_id in self.condition_layer_ids:
                condition_layers[layer_id] = True

        if self.behavior_vector is not None and self.behavior_layer_ids is not None:
            for layer_id in self.behavior_layer_ids:
                behavior_layers[layer_id] = True

        self._condition_layers = {i: v for i, v in enumerate(condition_layers)}
        self._behavior_layers = {i: v for i, v in enumerate(behavior_layers)}

        # Precompute behavior vectors and condition projectors
        condition_layer_ids_set = set(self.condition_layer_ids) if self.condition_layer_ids is not None else set()
        # Guard against behavior_layer_ids=None (its documented default), mirroring the condition case
        behavior_layer_ids_set = set(self.behavior_layer_ids) if self.behavior_layer_ids is not None else set()

        self._layer_states = {}

        for layer_id in range(num_layers):
            # layer = self._layers[layer_id]
            behavior_tensor = None
            if self.behavior_vector is not None and layer_id in behavior_layer_ids_set:
                if self.use_explained_variance:
                    # _use_explained_variance_func requires the layer index
                    behavior_direction = self._use_explained_variance_func(self.behavior_vector, layer_id)
                else:
                    behavior_direction = self.behavior_vector.directions[layer_id]

                behavior_tensor = torch.tensor(self.behavior_vector_strength * behavior_direction, dtype=self.model.dtype).to(self.model.device)

            condition_projector = None
            if self.condition_vector is not None and layer_id in condition_layer_ids_set:
                if self.use_explained_variance:
                    condition_direction = self._use_explained_variance_func(self.condition_vector, layer_id)
                else:
                    condition_direction = self.condition_vector.directions[layer_id]

                condition_tensor = torch.tensor(condition_direction, dtype=self.model.dtype).to(self.model.device)
                # Rank-1 projector onto the condition direction: v v^T / (v . v)
                condition_projector = torch.outer(condition_tensor, condition_tensor) / torch.dot(condition_tensor, condition_tensor)

            layer_control_params = LayerControlParams()

            layer_args = LayerArgs(
                behavior_vector=behavior_tensor,
                condition_projector=condition_projector,
                threshold=self.condition_vector_threshold,
                use_ooi_preventive_normalization=self.use_ooi_preventive_normalization,
                apply_behavior_on_first_call=self.apply_behavior_on_first_call,
                condition_comparator_threshold_is=self.condition_comparator_threshold_is,
                condition_threshold_comparison_mode=self.condition_threshold_comparison_mode,
                params=layer_control_params,
            )

            self._layer_states[layer_id] = layer_args

    def _use_explained_variance_func(self, vector: SteeringVector, layer_id: int) -> np.ndarray:
        """Scale a layer's steering direction by its explained variance.

        If the vector carries no explained variances, or has none recorded for this layer, the direction is
        returned unscaled.

        Args:
            vector (SteeringVector): Steering vector containing directions and variances.
            layer_id (int): Layer index to retrieve variance scaling for.

        Returns:
            np.ndarray: Direction vector scaled by explained variance.
        """
        # Look up the direction first so it is always bound when returned
        direction = vector.directions[layer_id]

        if hasattr(vector, 'explained_variances'):
            variance_scale = vector.explained_variances.get(layer_id, 1)
            direction = direction * variance_scale

        return direction

    def _cast_pre_hook(
        self,
        module,
        input_args: Tuple,
        input_kwargs: dict,
        layer_id: int,
    ):
        """Apply conditional activation steering as a pre-forward hook.

        Detects conditions and applies behavior modifications during the model's forward pass. Processes each layer
        independently based on its configuration.

        Process:

        1. Extract hidden states from arguments
        2. If condition layer: detect if input matches target pattern
        3. If behavior layer and conditions met: apply steering vector
        4. Optionally apply OOI normalization to prevent distribution shift

        Args:
            module: The layer module being hooked.
            input_args: Positional arguments to the forward pass.
            input_kwargs: Keyword arguments to the forward pass.
            layer_id (int): Index of the current layer.

        Returns:
            Tuple of potentially modified (input_args, input_kwargs).

        Raises:
            RuntimeError: If hidden states cannot be located.
        """
        hidden_states = input_args[0] if input_args else input_kwargs.get("hidden_states")
        if hidden_states is None:
            raise RuntimeError("CAST: could not locate hidden states")

        self._forward_calls[layer_id] += 1
        batch_size, seq_length, hidden_dim = hidden_states.shape

        if self._condition_layers is None:
            # CASE 1 -> no steering
            is_condition_layer = False
            is_behavior_layer = False
        else:
            # CASE 2 -> steering
            is_condition_layer = self._condition_layers[layer_id]
            is_behavior_layer = self._behavior_layers[layer_id]

        original_norm = hidden_states.norm(dim=-1, keepdim=True)

        if is_condition_layer:
            self._process_single_condition(hidden_states[0], layer_id)

        if is_behavior_layer:
            self._apply_single_behavior(hidden_states, layer_id)

        if self.use_ooi_preventive_normalization and is_behavior_layer:
            hidden_states = self._apply_ooi_normalization(hidden_states, original_norm)

        if input_args:
            input_list = list(input_args)
            input_list[0] = hidden_states
            my_input_args = tuple(input_list)
        else:
            my_input_args = input_args
            input_kwargs["hidden_states"] = hidden_states

        return my_input_args, input_kwargs

    def _process_single_condition(self, hidden_state, layer_id: int):
        """Detect if input matches target condition pattern.

        Projects hidden states onto condition subspace and compares similarity against threshold to determine if
        steering should be activated.

        Process:

        1. Aggregate hidden states (mean or last token based on config)
        2. Project onto condition subspace using precomputed projector
        3. Compute cosine similarity between original and projected
        4. Compare against threshold with specified comparator

        Args:
            hidden_state: Hidden state tensor to analyze [seq_len, hidden_dim].
            layer_id (int): Current layer index.
        """
        layer_args = self._layer_states[layer_id]

        if not self._condition_met[0] and self._forward_calls[layer_id] == 1:
            if layer_args.condition_threshold_comparison_mode == "mean":
                hidden_state = hidden_state.mean(dim=0)
            elif layer_args.condition_threshold_comparison_mode == "last":
                hidden_state = hidden_state[-1, :]

            projected_hidden_state = torch.tanh(torch.matmul(layer_args.condition_projector, hidden_state))
            condition_similarity = self._compute_similarity(hidden_state, projected_hidden_state)
            self._condition_similarities[0][layer_id] = condition_similarity

            if layer_args.condition_comparator_threshold_is == "smaller":
                condition_met = (condition_similarity >= layer_args.threshold)
            elif layer_args.condition_comparator_threshold_is == "larger":
                condition_met = (condition_similarity < layer_args.threshold)
            else:
                raise ValueError(f"invalid {layer_args.condition_comparator_threshold_is}")

            self._condition_met[0] = condition_met

            print(f"layer {layer_id}:  similarity: {condition_similarity} "
                  f"threshold: {layer_args.threshold} "
                  f"condition comparator threshold '{layer_args.condition_comparator_threshold_is}' -- "
                  f"Condition Met: {condition_met}")

    def _apply_single_behavior(self, hidden_states, layer_id: int):
        """Apply behavior steering vector when conditions are met.

        Modifies hidden states by adding scaled steering vectors to shift model behavior toward desired outputs.

        Args:
            hidden_states: Hidden states to modify [batch, seq_len, hidden_dim].
            layer_id (int): Current layer index.
        """
        layer_args = self._layer_states[layer_id]

        should_apply = not any(self._condition_layers.values()) or self._condition_met[0]

        # print(f"Should Apply Behavior: {should_apply}")

        if should_apply:
            control = layer_args.behavior_vector.to(dtype=hidden_states.dtype)
            if self._forward_calls[layer_id] == 1:
                if layer_args.apply_behavior_on_first_call:
                    hidden_states[0] = layer_args.params.operator(hidden_states[0], control)
                else:
                    print("apply_behavior_on_first_call is False, skipping behavior vector application")
            else:
                hidden_states[0] = layer_args.params.operator(hidden_states[0], control)
                # print(f"{layer_id=}: Applying behavior vector to all tokens")

    def _compute_similarity(self, x: torch.Tensor, y: torch.Tensor) -> float:
        """
        Compute the cosine similarity between two tensors.

        Args:
            x: First tensor.
            y: Second tensor.

        Returns:
            The cosine similarity as a float.
        """
        cossim = torch.dot(x.flatten(), y.flatten()) / (torch.norm(x) * torch.norm(y))
        return float(cossim.item())

    def _apply_ooi_normalization(self, hidden_states, original_norm):
        """Apply out-of-distribution preventive normalization.

        Prevents hidden states from drifting too far from original distribution by rescaling to maintain norm magnitudes after steering.

        Args:
            hidden_states: Modified hidden states to normalize.
            original_norm: Original norm before modifications.

        Returns:
            torch.Tensor: Normalized hidden states.

        Raises:
            ValueError: If NaN or Inf detected in hidden states.
        """
        new_norm = hidden_states.norm(dim=-1, keepdim=True)
        max_ratio = (new_norm / original_norm).max().item()
        has_nan_inf = torch.isnan(hidden_states).any() or torch.isinf(hidden_states).any()

        if has_nan_inf:
            # NaN propagates, decided to raise instead of just applying norm as in original code.
            raise ValueError(f"NaN: {torch.isnan(hidden_states).any()} or Inf: {torch.isinf(hidden_states).any()} detected in hidden_states")

        # NaN/Inf inputs have already raised above, so only the norm ratio is checked here
        if max_ratio > 1:
            print(f"Applying OOI preventive normalization. Max_ratio was {max_ratio}")
            hidden_states = hidden_states * (original_norm / new_norm)
        else:
            print(f"No OOI preventive normalization. Max_ratio was {max_ratio}")

        return hidden_states
_behavior_layers = None class-attribute instance-attribute
_condition_layers = None class-attribute instance-attribute
_condition_met = defaultdict(bool) class-attribute instance-attribute
_condition_similarities = defaultdict(lambda: defaultdict(float)) class-attribute instance-attribute
_forward_calls = defaultdict(int) class-attribute instance-attribute
_layers = None class-attribute instance-attribute
_layers_names = None class-attribute instance-attribute
_layers_states = None class-attribute instance-attribute
_model_ref = None class-attribute instance-attribute
args = self.Args.validate(*args, **kwargs) instance-attribute
device = None class-attribute instance-attribute
enabled = True class-attribute instance-attribute
hooks = {'pre': [], 'forward': [], 'backward': []} instance-attribute
model = None class-attribute instance-attribute
registered = [] instance-attribute
tokenizer = None class-attribute instance-attribute
__enter__()

Context manager entry: register hooks to model.

Raises:

| Type | Description |
| --- | --- |
| RuntimeError | If model reference not set by pipeline |

Source code in aisteer360/algorithms/state_control/base.py
def __enter__(self):
    """Context manager entry: register hooks to model.

    Raises:
        RuntimeError: If model reference not set by pipeline
    """
    if self._model_ref is None:
        raise RuntimeError("Model reference not set before entering context.")
    self.register_hooks(self._model_ref)

    return self
__exit__(exc_type, exc, tb)

Context manager exit: clean up all hooks.

Source code in aisteer360/algorithms/state_control/base.py
def __exit__(self, exc_type, exc, tb):
    """Context manager exit: clean up all hooks."""
    self.remove_hooks()
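
The enter/exit pair above follows the usual context-manager pattern: hooks exist only inside the `with` block. The sketch below illustrates that lifecycle with a hypothetical stand-in class; `HookedControl` and its string "handles" are assumptions for illustration, whereas the real class registers torch hooks on the model.

```python
class HookedControl:
    """Hypothetical stand-in for a state control; not the library class."""

    def __init__(self):
        self._model_ref = None
        self.registered = []

    def register_hooks(self, model):
        # The real class attaches torch hooks; here we only record a label.
        self.registered.append(f"hook-on-{model}")

    def remove_hooks(self):
        self.registered.clear()

    def __enter__(self):
        if self._model_ref is None:
            raise RuntimeError("Model reference not set before entering context.")
        self.register_hooks(self._model_ref)
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove_hooks()


control = HookedControl()
control._model_ref = "toy-model"
with control as active:
    hooks_inside = list(active.registered)  # hooks are live inside the block
hooks_after = list(control.registered)      # and gone once the block exits

# Entering without a model reference fails before any hook is registered.
try:
    with HookedControl():
        entered_without_model = True
except RuntimeError:
    entered_without_model = False
```

Because `__exit__` runs even when the body raises, hooks are cleaned up on error paths as well.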
__init__(*args, **kwargs)
Source code in aisteer360/algorithms/state_control/base.py
def __init__(self, *args, **kwargs) -> None:
    if self.Args is None:  # null control
        if args or kwargs:
            raise TypeError(f"{type(self).__name__} accepts no constructor arguments.")
        return

    self.args: BaseArgs = self.Args.validate(*args, **kwargs)

    # move fields to attributes
    for field in fields(self.args):
        setattr(self, field.name, getattr(self.args, field.name))

    self.hooks: dict[str, list[HookSpec]] = {"pre": [], "forward": [], "backward": []}
    self.registered: list[torch.utils.hooks.RemovableHandle] = []
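
The constructor above validates its arguments into a dataclass and then copies every field onto the instance. A minimal sketch of that pattern follows, assuming a toy args class: `ToyArgs` and `ToyControl` are hypothetical, and `validate` here only constructs the dataclass, since what the real `BaseArgs.validate` checks is not shown in this section.

```python
from dataclasses import dataclass, fields


@dataclass
class ToyArgs:
    """Hypothetical args container standing in for CASTArgs."""
    threshold: float = 0.5
    strength: float = 1.0

    @classmethod
    def validate(cls, *args, **kwargs):
        # Assumption: validation is reduced to plain construction here.
        return cls(*args, **kwargs)


class ToyControl:
    Args = ToyArgs

    def __init__(self, *args, **kwargs):
        self.args = self.Args.validate(*args, **kwargs)
        # Move dataclass fields to instance attributes, as in the listing above.
        for field in fields(self.args):
            setattr(self, field.name, getattr(self.args, field.name))


control = ToyControl(threshold=0.7)
print(control.threshold, control.strength)
```

This is why methods elsewhere on this page can read options such as `self.condition_vector_threshold` directly from the instance.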
_apply_ooi_normalization(hidden_states, original_norm)

Apply out-of-distribution preventive normalization.

Prevents hidden states from drifting too far from original distribution by rescaling to maintain norm magnitudes after steering.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hidden_states | | Modified hidden states to normalize. | required |
| original_norm | | Original norm before modifications. | required |

Returns:

| Type | Description |
| --- | --- |
| torch.Tensor | Normalized hidden states. |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If NaN or Inf detected in hidden states. |

Source code in aisteer360/algorithms/state_control/cast/control.py
def _apply_ooi_normalization(self, hidden_states, original_norm):
    """Apply out-of-distribution preventive normalization.

    Prevents hidden states from drifting too far from original distribution by rescaling to maintain norm magnitudes after steering.

    Args:
        hidden_states: Modified hidden states to normalize.
        original_norm: Original norm before modifications.

    Returns:
        torch.Tensor: Normalized hidden states.

    Raises:
        ValueError: If NaN or Inf detected in hidden states.
    """
    new_norm = hidden_states.norm(dim=-1, keepdim=True)
    max_ratio = (new_norm / original_norm).max().item()
    has_nan_inf = torch.isnan(hidden_states).any() or torch.isinf(hidden_states).any()

    if has_nan_inf:
        # NaN propagates, decided to raise instead of just applying norm as in original code.
        raise ValueError(f"NaN: {torch.isnan(hidden_states).any()} or Inf: {torch.isinf(hidden_states).any()} detected in hidden_states")

    # NaN/Inf inputs have already raised above, so only the norm ratio is checked here
    if max_ratio > 1:
        print(f"Applying OOI preventive normalization. Max_ratio was {max_ratio}")
        hidden_states = hidden_states * (original_norm / new_norm)
    else:
        print(f"No OOI preventive normalization. Max_ratio was {max_ratio}")

    return hidden_states
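
The rescaling idea can be shown on a single vector. This is a simplified sketch under stated assumptions: `ooi_normalize` and `l2_norm` are hypothetical names, a plain list stands in for one token's hidden state, and the decision is made per vector, whereas the method above takes the maximum norm ratio over all positions and rescales the whole tensor.

```python
import math


def l2_norm(vec):
    return math.sqrt(sum(x * x for x in vec))


def ooi_normalize(steered, original_norm):
    """Rescale `steered` back to `original_norm` if steering increased its norm."""
    if any(math.isnan(x) or math.isinf(x) for x in steered):
        raise ValueError("NaN or Inf detected in hidden state")
    new_norm = l2_norm(steered)
    if new_norm / original_norm > 1:
        return [x * (original_norm / new_norm) for x in steered]
    return list(steered)


original_norm = l2_norm([3.0, 4.0])                 # norm before steering
grown = ooi_normalize([6.0, 8.0], original_norm)    # norm doubled, so rescale
shrunk = ooi_normalize([1.5, 2.0], original_norm)   # norm decreased, left as is
```

Only the magnitude is restored; the direction introduced by steering is preserved.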
_apply_single_behavior(hidden_states, layer_id)

Apply behavior steering vector when conditions are met.

Modifies hidden states by adding scaled steering vectors to shift model behavior toward desired outputs.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hidden_states | | Hidden states to modify [batch, seq_len, hidden_dim]. | required |
| layer_id | int | Current layer index. | required |

Source code in aisteer360/algorithms/state_control/cast/control.py
def _apply_single_behavior(self, hidden_states, layer_id: int):
    """Apply behavior steering vector when conditions are met.

    Modifies hidden states by adding scaled steering vectors to shift model behavior toward desired outputs.

    Args:
        hidden_states: Hidden states to modify [batch, seq_len, hidden_dim].
        layer_id (int): Current layer index.
    """
    layer_args = self._layer_states[layer_id]

    should_apply = not any(self._condition_layers.values()) or self._condition_met[0]

    # print(f"Should Apply Behavior: {should_apply}")

    if should_apply:
        control = layer_args.behavior_vector.to(dtype=hidden_states.dtype)
        if self._forward_calls[layer_id] == 1:
            if layer_args.apply_behavior_on_first_call:
                hidden_states[0] = layer_args.params.operator(hidden_states[0], control)
            else:
                print("apply_behavior_on_first_call is False, skipping behavior vector application")
        else:
            hidden_states[0] = layer_args.params.operator(hidden_states[0], control)
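
The gating in the method above combines three things: whether any condition layer is configured, whether the condition was met, and whether this is the first forward call. A pure-function sketch of that decision is given below; `should_steer` is a hypothetical name, and the function only mirrors the branching, not the tensor update.

```python
def should_steer(condition_layers, condition_met, forward_call, apply_on_first_call=True):
    """Return True if the behavior vector would be applied on this call.

    condition_layers: dict mapping layer_id -> bool (is it a condition layer?)
    forward_call: 1 for the first forward pass through the layer, 2+ afterwards
    """
    # With no condition layers configured, steering is unconditional.
    gate_open = not any(condition_layers.values()) or condition_met
    if not gate_open:
        return False
    if forward_call == 1:
        return apply_on_first_call
    return True


unconditional = should_steer({0: False, 1: False}, False, forward_call=1)
blocked = should_steer({0: True}, False, forward_call=2)
skipped_first = should_steer({0: True}, True, forward_call=1, apply_on_first_call=False)
later_call = should_steer({0: True}, True, forward_call=2, apply_on_first_call=False)
```

Setting `apply_behavior_on_first_call` to False therefore delays steering until the second forward call through the layer.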
_cast_pre_hook(module, input_args, input_kwargs, layer_id)

Apply conditional activation steering as a pre-forward hook.

Detects conditions and applies behavior modifications during the model's forward pass. Processes each layer independently based on its configuration.

Process:

  1. Extract hidden states from arguments
  2. If condition layer: detect if input matches target pattern
  3. If behavior layer and conditions met: apply steering vector
  4. Optionally apply OOI normalization to prevent distribution shift

Parameters:

Name Type Description Default
module

The layer module being hooked.

required
input_args Tuple

Positional arguments to the forward pass.

required
input_kwargs dict

Keyword arguments to the forward pass.

required
layer_id int

Index of the current layer.

required

Returns:

Type Description

Tuple of potentially modified (input_args, input_kwargs).

Raises:

Type Description
RuntimeError

If hidden states cannot be located.

Source code in aisteer360/algorithms/state_control/cast/control.py
def _cast_pre_hook(
    self,
    module,
    input_args: Tuple,
    input_kwargs: dict,
    layer_id: int,
):
    """Apply conditional activation steering as a pre-forward hook.

    Detects conditions and applies behavior modifications during the model's forward pass. Processes each layer
    independently based on its configuration.

    Process:

    1. Extract hidden states from arguments
    2. If condition layer: detect if input matches target pattern
    3. If behavior layer and conditions met: apply steering vector
    4. Optionally apply OOI normalization to prevent distribution shift

    Args:
        module: The layer module being hooked.
        input_args: Positional arguments to the forward pass.
        input_kwargs: Keyword arguments to the forward pass.
        layer_id (int): Index of the current layer.

    Returns:
        Tuple of potentially modified (input_args, input_kwargs).

    Raises:
        RuntimeError: If hidden states cannot be located.
    """
    hidden_states = input_args[0] if input_args else input_kwargs.get("hidden_states")
    if hidden_states is None:
        raise RuntimeError("CAST: could not locate hidden states")

    self._forward_calls[layer_id] += 1
    batch_size, seq_length, hidden_dim = hidden_states.shape

    if self._condition_layers is None:
        # CASE 1 -> no steering
        is_condition_layer = False
        is_behavior_layer = False
    else:
        # CASE 2 -> steering
        is_condition_layer = self._condition_layers[layer_id]
        is_behavior_layer = self._behavior_layers[layer_id]

    original_norm = hidden_states.norm(dim=-1, keepdim=True)

    if is_condition_layer:
        self._process_single_condition(hidden_states[0], layer_id)

    if is_behavior_layer:
        self._apply_single_behavior(hidden_states, layer_id)

    if self.use_ooi_preventive_normalization and is_behavior_layer:
        hidden_states = self._apply_ooi_normalization(hidden_states, original_norm)

    if input_args:
        input_list = list(input_args)
        input_list[0] = hidden_states
        my_input_args = tuple(input_list)
    else:
        my_input_args = input_args
        input_kwargs["hidden_states"] = hidden_states

    return my_input_args, input_kwargs
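
The argument handling in this hook (find the hidden states, modify them, hand back the inputs) can be sketched in isolation. This is a simplified, framework-free illustration under stated assumptions: `rewrite_hidden_states`, `double`, and the `transform` callback are hypothetical names, and plain lists stand in for tensors.

```python
def rewrite_hidden_states(input_args, input_kwargs, transform):
    """Locate hidden states in a layer's inputs, transform them, and put them back."""
    # Hidden states are taken from the first positional argument if there is one,
    # otherwise from the "hidden_states" keyword argument.
    hidden = input_args[0] if input_args else input_kwargs.get("hidden_states")
    if hidden is None:
        raise RuntimeError("could not locate hidden states")

    hidden = transform(hidden)

    if input_args:
        input_args = (hidden,) + tuple(input_args[1:])
    else:
        # The sketch copies the kwargs; the method above updates them in place.
        input_kwargs = {**input_kwargs, "hidden_states": hidden}
    return input_args, input_kwargs


def double(values):
    """Toy transform standing in for condition detection plus steering."""
    return [2 * v for v in values]


print(rewrite_hidden_states(([1, 2], "mask"), {}, double))
# (([2, 4], 'mask'), {})
print(rewrite_hidden_states((), {"hidden_states": [1, 2]}, double))
# ((), {'hidden_states': [2, 4]})
```

Returning the `(args, kwargs)` pair is what allows a pre-forward hook registered with `with_kwargs=True` to replace the inputs the layer actually receives.
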
_compute_similarity(x, y)

Compute the cosine similarity between two tensors.

Parameters:

Name Type Description Default
x Tensor

First tensor.

required
y Tensor

Second tensor.

required

Returns:

Type Description
float

The cosine similarity as a float.

Source code in aisteer360/algorithms/state_control/cast/control.py
def _compute_similarity(self, x: torch.Tensor, y: torch.Tensor) -> float:
    """
    Compute the cosine similarity between two tensors.

    Args:
        x: First tensor.
        y: Second tensor.

    Returns:
        The cosine similarity as a float.
    """
    cossim = torch.dot(x.flatten(), y.flatten()) / (torch.norm(x) * torch.norm(y))
    return float(cossim.item())
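
For reference, the same quantity can be computed with the standard library alone. This is a small sketch with plain lists in place of tensors; `cosine_similarity` is a hypothetical helper name, not part of the library.

```python
import math


def cosine_similarity(x, y):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(x, y))
    norm_x = math.sqrt(sum(a * a for a in x))
    norm_y = math.sqrt(sum(b * b for b in y))
    return dot / (norm_x * norm_y)


print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0 (orthogonal vectors)
print(cosine_similarity([3.0, 4.0], [6.0, 8.0]))  # parallel vectors, close to 1.0
```

The sketch does not handle zero-norm inputs, for which the cosine is undefined.
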
_process_single_condition(hidden_state, layer_id)

Detect if input matches target condition pattern.

Projects hidden states onto condition subspace and compares similarity against threshold to determine if steering should be activated.

Process:

  1. Aggregate hidden states over the sequence (mean or last token, depending on config)
  2. Project onto the condition subspace using the precomputed projector, then apply tanh
  3. Compute cosine similarity between the aggregated and the projected hidden state
  4. Compare against the threshold with the specified comparator

Parameters:

Name Type Description Default
hidden_state

Hidden state tensor to analyze [seq_len, hidden_dim].

required
layer_id int

Current layer index.

required
Source code in aisteer360/algorithms/state_control/cast/control.py
def _process_single_condition(self, hidden_state, layer_id: int):
    """Detect if input matches target condition pattern.

    Projects hidden states onto condition subspace and compares similarity against threshold to determine if
    steering should be activated.

    Process:

    1. Aggregate hidden states (mean or last token based on config)
    2. Project onto condition subspace using precomputed projector
    3. Compute cosine similarity between original and projected
    4. Compare against threshold with specified comparator

    Args:
        hidden_state: Hidden state tensor to analyze [seq_len, hidden_dim].
        layer_id (int): Current layer index.
    """
    layer_args = self._layer_states[layer_id]

    if not self._condition_met[0] and self._forward_calls[layer_id] == 1:
        if layer_args.condition_threshold_comparison_mode == "mean":
            hidden_state = hidden_state.mean(dim=0)
        elif layer_args.condition_threshold_comparison_mode == "last":
            hidden_state = hidden_state[-1, :]

        projected_hidden_state = torch.tanh(torch.matmul(layer_args.condition_projector, hidden_state))
        condition_similarity = self._compute_similarity(hidden_state, projected_hidden_state)
        self._condition_similarities[0][layer_id] = condition_similarity

        if layer_args.condition_comparator_threshold_is == "smaller":
            condition_met = (condition_similarity >= layer_args.threshold)
        elif layer_args.condition_comparator_threshold_is == "larger":
            condition_met = (condition_similarity < layer_args.threshold)
        else:
            raise ValueError(f"invalid {layer_args.condition_comparator_threshold_is}")

        self._condition_met[0] = condition_met

        print(f"layer {layer_id}:  similarity: {condition_similarity} "
              f"threshold: {layer_args.threshold} "
              f"condition comparator threshold '{layer_args.condition_comparator_threshold_is}' -- "
              f"Condition Met: {condition_met}")
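
The four detection steps can be walked through on a toy input. The following is a framework-free sketch under stated assumptions, not the library's code: `detect_condition` is a hypothetical name, nested lists stand in for tensors, and, unlike the method above, the sketch rejects unknown aggregation modes.

```python
import math


def detect_condition(hidden_state, projector, threshold, *,
                     mode="mean", threshold_is="larger"):
    """Return (similarity, condition_met) for one [seq_len, hidden_dim] input."""
    # 1. Aggregate over the sequence: mean of all tokens, or the last token only.
    if mode == "mean":
        n = len(hidden_state)
        h = [sum(col) / n for col in zip(*hidden_state)]
    elif mode == "last":
        h = list(hidden_state[-1])
    else:
        raise ValueError(f"invalid mode {mode!r}")

    # 2. Project with the precomputed matrix, then squash with tanh.
    projected = [math.tanh(sum(p * x for p, x in zip(row, h))) for row in projector]

    # 3. Cosine similarity between the aggregated and the projected vector.
    dot = sum(a * b for a, b in zip(h, projected))
    sim = dot / (math.hypot(*h) * math.hypot(*projected))

    # 4. Compare against the threshold in the configured direction.
    if threshold_is == "smaller":
        met = sim >= threshold
    elif threshold_is == "larger":
        met = sim < threshold
    else:
        raise ValueError(f"invalid comparator {threshold_is!r}")
    return sim, met


project_x = [[1.0, 0.0], [0.0, 0.0]]   # projector onto the first axis
tokens = [[2.0, 0.0], [0.0, 2.0]]      # mean over the sequence is [1.0, 1.0]
sim, met = detect_condition(tokens, project_x, threshold=0.9)
print(round(sim, 4), met)  # 0.7071 True
```

Here the aggregated state sits at 45 degrees to the condition direction, so the similarity is about 0.71. Because the threshold (0.9) is larger than the similarity, the `"larger"` comparator reports the condition as met.
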
_setup(model)

Configure all CAST internals for the given model.

Pre-computes steering vectors, condition projectors, and layer configurations to minimize runtime overhead during generation.

Process:

  1. Identifies condition and behavior layers from configuration
  2. Computes condition projection matrices for detection layers
  3. Prepares scaled behavior vectors for modification layers
  4. Stores layer-specific parameters in _layer_states

Parameters:

Name Type Description Default
model PreTrainedModel

Model to configure CAST for.

required
Source code in aisteer360/algorithms/state_control/cast/control.py
def _setup(self, model: PreTrainedModel):
    """Configure all CAST internals for the given model.

    Pre-computes steering vectors, condition projectors, and layer configurations to minimize runtime overhead during generation.

    Process:

    1. Identifies condition and behavior layers from configuration
    2. Computes condition projection matrices for detection layers
    3. Prepares scaled behavior vectors for modification layers
    4. Stores layer-specific parameters in _layer_states

    Args:
        model (PreTrainedModel): Model to configure CAST for.
    """
    self._layers, self._layers_names = self.get_model_layer_list(model)
    num_layers = len(self._layers)

    # Per-layer boolean flags for condition and behavior layers (converted to dicts below)
    condition_layers = [False] * num_layers
    behavior_layers = [False] * num_layers

    if self.condition_vector is not None and self.condition_layer_ids is not None:
        for layer_id in self.condition_layer_ids:
            condition_layers[layer_id] = True

    if self.behavior_vector is not None:
        for layer_id in self.behavior_layer_ids:
            behavior_layers[layer_id] = True

    self._condition_layers = {i: v for i, v in enumerate(condition_layers)}
    self._behavior_layers = {i: v for i, v in enumerate(behavior_layers)}

    # Precompute behavior vectors and condition projectors
    condition_layer_ids_set = set(self.condition_layer_ids) if self.condition_layer_ids is not None else set()
    behavior_layer_ids_set = set(self.behavior_layer_ids)

    self._layer_states = {}

    for layer_id in range(num_layers):
        # layer = self._layers[layer_id]
        behavior_tensor = None
        if self.behavior_vector is not None:
            if layer_id in behavior_layer_ids_set:
                if self.use_explained_variance:
                    behavior_direction = self._use_explained_variance_func(self.behavior_vector, layer_id)
                else:
                    behavior_direction = self.behavior_vector.directions[layer_id]

                behavior_tensor = torch.tensor(self.behavior_vector_strength * behavior_direction, dtype=self.model.dtype).to(self.model.device)

        condition_projector = None
        if self.condition_vector is not None and layer_id in condition_layer_ids_set:
            if self.use_explained_variance:
                condition_direction = self._use_explained_variance_func(self.condition_vector, layer_id)
            else:
                condition_direction = self.condition_vector.directions[layer_id]

            condition_tensor = torch.tensor(condition_direction, dtype=self.model.dtype).to(self.model.device)
            # Rank-one projector onto the condition direction: P = v v^T / (v . v)
            condition_projector = torch.outer(condition_tensor, condition_tensor) / torch.dot(condition_tensor, condition_tensor)

        layer_control_params = LayerControlParams()

        layer_args = LayerArgs(
            behavior_vector=behavior_tensor,
            condition_projector=condition_projector,
            threshold=self.condition_vector_threshold,
            use_ooi_preventive_normalization=self.use_ooi_preventive_normalization,
            apply_behavior_on_first_call=self.apply_behavior_on_first_call,
            condition_comparator_threshold_is=self.condition_comparator_threshold_is,
            condition_threshold_comparison_mode=self.condition_threshold_comparison_mode,
            params=layer_control_params,
        )

        self._layer_states[layer_id] = layer_args
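
The condition projector built above is the outer product of the condition direction with itself, divided by its squared norm. A minimal sketch of that construction with plain lists follows; `rank_one_projector` and `matvec` are hypothetical helper names used only for illustration.

```python
def rank_one_projector(v):
    """Build P = v v^T / (v . v), the orthogonal projector onto span(v)."""
    vv = sum(x * x for x in v)
    return [[a * b / vv for b in v] for a in v]


def matvec(m, x):
    """Multiply matrix m (list of rows) by vector x."""
    return [sum(a * b for a, b in zip(row, x)) for row in m]


P = rank_one_projector([3.0, 4.0])
print(P)  # [[0.36, 0.48], [0.48, 0.64]]

# A vector orthogonal to v projects to (approximately) zero.
print(matvec(P, [4.0, -3.0]))

# Projecting twice gives the same result as projecting once, up to rounding.
once = matvec(P, [1.0, 2.0])
twice = matvec(P, once)
```

Because the projector depends only on the condition direction, it can be computed once per condition layer during setup and reused on every forward pass.
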
_use_explained_variance_func(vector, layer_id)

Scale steering vector by its explained variance for adaptive control.

Multiplies the direction stored for layer_id by that layer's explained variance, so the strength of the steering vector varies from layer to layer. Layers with no recorded explained variance are left unscaled (factor 1).

Parameters:

Name Type Description Default
vector SteeringVector

Steering vector containing directions and variances.

required
layer_id int

Layer index to retrieve variance scaling for.

required

Returns:

Type Description
ndarray

Direction vector scaled by explained variance.

Source code in aisteer360/algorithms/state_control/cast/control.py
def _use_explained_variance_func(self, vector: SteeringVector, layer_id: int) -> np.ndarray:
    """Scale steering vector by its explained variance for adaptive control.

    This method scales the steering vector based on its explained variance,
    potentially adjusting its impact on different layers of the model.

    Args:
        vector (SteeringVector): Steering vector containing directions and variances.
        layer_id (int): Layer index to retrieve variance scaling for.

    Returns:
        np.ndarray: Direction vector scaled by explained variance.
    """

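    # Start from the unscaled direction so that `direction` is always defined,
    # even when the vector carries no explained variances.
    direction = vector.directions[layer_id]
    if hasattr(vector, 'explained_variances'):
        variance_scale = vector.explained_variances.get(layer_id, 1)
        direction = direction * variance_scale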

    return direction
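
A small worked example of the scaling, assuming (as the `.get` calls above suggest) that directions and explained variances are dictionaries keyed by layer index. The function name `scale_by_explained_variance` and the sample values are hypothetical, and lists stand in for arrays.

```python
def scale_by_explained_variance(directions, explained_variances, layer_id):
    """Scale one layer's direction by its explained variance (default factor 1)."""
    scale = explained_variances.get(layer_id, 1)
    return [scale * x for x in directions[layer_id]]


directions = {0: [1.0, -2.0], 1: [0.5, 0.5]}   # hypothetical per-layer directions
variances = {0: 0.5}                            # layer 1 has no recorded variance

print(scale_by_explained_variance(directions, variances, 0))  # [0.5, -1.0]
print(scale_by_explained_variance(directions, variances, 1))  # [0.5, 0.5]
```
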
get_hooks(input_ids, runtime_kwargs, **__)

Create pre-forward hooks for conditional activation steering.

Generates hook specifications for all model layers that will conditionally detect patterns and apply behavior modifications during the forward pass.

Parameters:

Name Type Description Default
input_ids Tensor

Input token IDs (unused but required by interface).

required
runtime_kwargs dict | None

Runtime parameters (currently unused).

required
**__

Additional arguments (unused).

{}

Returns:

Type Description
dict[str, list]

Hook specifications with "pre", "forward", and "backward" keys. Only the "pre" hooks are populated with CAST steering logic.

Source code in aisteer360/algorithms/state_control/cast/control.py
def get_hooks(
        self,
        input_ids: torch.Tensor,
        runtime_kwargs: dict | None,
        **__,
) -> dict[str, list]:
    """Create pre-forward hooks for conditional activation steering.

    Generates hook specifications for all model layers that will conditionally detect patterns and apply behavior
    modifications during the forward pass.

    Args:
        input_ids (torch.Tensor): Input token IDs (unused but required by interface).
        runtime_kwargs (dict | None): Runtime parameters (currently unused).
        **__: Additional arguments (unused).

    Returns:
        dict[str, list]: Hook specifications with "pre", "forward", "backward" keys.
            Only "pre" hooks are populated with CAST steering logic.
    """

    hooks: dict[str, list] = {"pre": [], "forward": [], "backward": []}
    for layer_id, layer_name in enumerate(self._layers_names):
        hooks["pre"].append(
            {
                "module": layer_name,  # "model.layers.0"
                "hook_func": partial(
                    self._cast_pre_hook,
                    layer_id=layer_id,
                ),
            }
        )

    return hooks
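
The use of `functools.partial` is what lets a single hook function serve every layer while still knowing which layer it is attached to. The sketch below shows the pattern on its own; `pre_hook` and `build_hook_specs` are hypothetical stand-ins, and the module is represented by its name string rather than a real module object.

```python
from functools import partial


def pre_hook(module, input_args, input_kwargs, layer_id):
    """Stand-in for a per-layer hook; it only reports which layer it was bound to."""
    return f"layer {layer_id} saw {module}"


def build_hook_specs(layer_names):
    """Create one pre-hook spec per layer, each bound to its own layer index."""
    hooks = {"pre": [], "forward": [], "backward": []}
    for layer_id, name in enumerate(layer_names):
        # partial fixes layer_id now, so each hook keeps its own index
        # even though all hooks share one function.
        hooks["pre"].append({"module": name, "hook_func": partial(pre_hook, layer_id=layer_id)})
    return hooks


specs = build_hook_specs(["model.layers.0", "model.layers.1"])
for spec in specs["pre"]:
    print(spec["hook_func"](spec["module"], (), {}))
# layer 0 saw model.layers.0
# layer 1 saw model.layers.1
```
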
get_model_layer_list(model)

Extract the list of transformer layers from the model.

Parameters:

Name Type Description Default
model PreTrainedModel

Model to extract layers from.

required

Returns:

A tuple (layers, layers_names): the list of layer modules for the given model, and the list of their module names (prefix plus index, for example "model.layers.0").
Source code in aisteer360/algorithms/state_control/cast/control.py
def get_model_layer_list(self, model: PreTrainedModel) -> tuple[list, list[str]]:
    """Extract the list of transformer layers from the model.

    Args:
        model (PreTrainedModel): Model to extract layers from.

    Returns:
        A tuple `(layers, layers_names)`: the layer modules of the given model, and the
        module name of each layer (prefix plus index, e.g. "model.layers.0").
    """
    layers = []
    layers_names = []

    model_layers = None
    model_layers_prefix = ''

    if hasattr(model, "model"):  # mistral-, llama-, gemma-like models
        model_layers = model.model.layers
        model_layers_prefix = "model.layers"
    elif hasattr(model, "transformer"):  # gpt2-like models
        model_layers = model.transformer.h
        model_layers_prefix = "transformer.h"
    else:
        raise ValueError(f"Don't know how to get layer list from model for {type(model)=}")

    for idx, layer in enumerate(model_layers):
        layers.append(layer)
        layers_names.append(f"{model_layers_prefix}.{idx}")

    return layers, layers_names
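
The attribute-based lookup can be tried out on dummy objects. This is a simplified sketch, assuming only the two layouts handled above; `layer_names` is a hypothetical function, and the `SimpleNamespace` objects and strings stand in for real models and layer modules.

```python
from types import SimpleNamespace


def layer_names(model):
    """Return (layers, names) for the two model layouts handled above."""
    if hasattr(model, "model"):            # e.g. Llama-style: model.model.layers
        layers, prefix = model.model.layers, "model.layers"
    elif hasattr(model, "transformer"):    # e.g. GPT-2-style: model.transformer.h
        layers, prefix = model.transformer.h, "transformer.h"
    else:
        raise ValueError(f"unsupported model type: {type(model)}")
    return list(layers), [f"{prefix}.{i}" for i in range(len(layers))]


# Dummy objects standing in for real models; the strings stand in for layer modules.
llama_like = SimpleNamespace(model=SimpleNamespace(layers=["block_a", "block_b"]))
gpt2_like = SimpleNamespace(transformer=SimpleNamespace(h=["block_a"]))

print(layer_names(llama_like)[1])  # ['model.layers.0', 'model.layers.1']
print(layer_names(gpt2_like)[1])   # ['transformer.h.0']
```

The generated names are the strings later resolved back to modules when the hooks are registered.
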
register_hooks(model)

Attach hooks to model.

Source code in aisteer360/algorithms/state_control/base.py
def register_hooks(self, model: PreTrainedModel) -> None:
    """Attach hooks to model."""
    for phase in ("pre", "forward", "backward"):
        for spec in self.hooks[phase]:
            module = model.get_submodule(spec["module"])
            if phase == "pre":
                handle = module.register_forward_pre_hook(spec["hook_func"], with_kwargs=True)
            elif phase == "forward":
                handle = module.register_forward_hook(spec["hook_func"], with_kwargs=True)
            else:
                handle = module.register_full_backward_hook(spec["hook_func"])
            self.registered.append(handle)
remove_hooks()

Remove all registered hooks from the model.

Source code in aisteer360/algorithms/state_control/base.py
def remove_hooks(self) -> None:
    """Remove all registered hooks from the model."""
    for handle in self.registered:
        handle.remove()
    self.registered.clear()
reset()

Reset internal state tracking between generation calls.

Clears condition detection flags, forward call counters, and similarity scores.

Source code in aisteer360/algorithms/state_control/cast/control.py
def reset(self):
    """Reset internal state tracking between generation calls.

    Clears condition detection flags, forward call counters, and similarity scores.
    """
    self._condition_met = defaultdict(bool)
    self._forward_calls = defaultdict(int)
    self._condition_similarities = defaultdict(lambda: defaultdict(float))
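
The choice of `defaultdict` means the state needs no per-layer initialisation after a reset. A brief standard-library sketch of the resulting behavior (the variable names mirror the attributes above but are otherwise illustrative):

```python
from collections import defaultdict

condition_met = defaultdict(bool)
forward_calls = defaultdict(int)
similarities = defaultdict(lambda: defaultdict(float))

# Missing keys take the factory's default, so no explicit initialisation is needed.
print(condition_met[0])       # False
forward_calls[3] += 1         # counter starts from 0
print(forward_calls[3])       # 1
similarities[0][5] = 0.42     # inner dict is created on first access
print(dict(similarities[0]))  # {5: 0.42}
```

Note that merely reading a missing key inserts it with the default value, which is harmless here but does grow the dictionaries.
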
set_hooks(hooks)

Update the hook specifications to be registered.

Source code in aisteer360/algorithms/state_control/base.py
def set_hooks(self, hooks: dict[str, list[HookSpec]]):
    """Update the hook specifications to be registered."""
    self.hooks = hooks
steer(model, tokenizer=None, **__)

Initialize CAST by configuring the condition detection and behavior modification layers.

Sets up steering vectors, condition projectors, and layer-specific parameters for conditional activation steering. Pre-computes projection matrices and behavior vectors.

Parameters:

Name Type Description Default
model PreTrainedModel

The base language model to be steered.

required
tokenizer PreTrainedTokenizer | None

Tokenizer (currently unused but maintained for API consistency). If None, attempts to retrieve from model attributes.

None
**__

Additional arguments (unused).

{}

Returns:

Name Type Description
PreTrainedModel PreTrainedModel

The input model, unchanged.

Source code in aisteer360/algorithms/state_control/cast/control.py
def steer(
    self,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer | None = None,
    **__
) -> PreTrainedModel:
    """Initialize CAST by configuring the condition detection and behavior modification layers.

    Sets up steering vectors, condition projectors, and layer-specific parameters for conditional activation
    steering. Pre-computes projection matrices and behavior vectors.

    Args:
        model (PreTrainedModel): The base language model to be steered.
        tokenizer (PreTrainedTokenizer | None): Tokenizer (currently unused but maintained
            for API consistency). If None, attempts to retrieve from model attributes.
        **__: Additional arguments (unused).

    Returns:
        PreTrainedModel: The input model, unchanged.
    """
    self.model = model
    self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
    self.device = next(model.parameters()).device

    self._setup(self.model)

    return model