CAST
aisteer360.algorithms.state_control.cast
args
_check_layer_ids(layer_ids)
Checks the validity of the layer_ids list.
Raises an exception if any element is not an int, if any element is negative, or if the elements are not unique.
Source code in aisteer360/algorithms/state_control/cast/args.py
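The check above can be sketched in plain Python. This is a hypothetical re-implementation, not the library's code. The function name check_layer_ids and the choice of TypeError versus ValueError are assumptions, since the description only says that an exception is raised.

```python
def check_layer_ids(layer_ids):
    """Hypothetical re-implementation of the layer_ids check (sketch)."""
    for layer_id in layer_ids:
        if not isinstance(layer_id, int):
            raise TypeError(f"layer id {layer_id!r} is not an int")
        if layer_id < 0:
            raise ValueError(f"layer id {layer_id} is negative")
    # Duplicates collapse in a set, so a length mismatch means repeats
    if len(set(layer_ids)) != len(layer_ids):
        raise ValueError("layer ids must be unique")
    return layer_ids
```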
control
CAST
Bases: StateControl
Implementation of CAST (Conditional Activation Steering) from Lee et al., 2024.
CAST enables selective control of LLM behavior by conditionally applying activation steering based on input context, allowing fine-grained control without affecting responses to non-targeted content.
The method operates in two phases:
- Condition Detection: Analyzes hidden-state activation patterns at the specified condition layers during inference to detect whether the input matches the target condition. Hidden states are projected onto a condition subspace, and the cosine similarity between each hidden state and its projection is compared against a threshold.
- Conditional Behavior Modification: When the condition is met, adds steering vectors to the hidden states at the designated behavior layers. This selectively modifies the model's internal representations to produce the desired behavioral change while preserving normal behavior on non-matching inputs.
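The two phases can be illustrated with a toy NumPy sketch. It assumes a rank-1 projector P = c c^T / (c^T c) built from a condition direction c, mean pooling, and the 'larger' comparator. The function run_toy_cast and its arguments are hypothetical. Each entry of layers_input stands in for the hidden states one layer receives. The condition layer is assumed to run before the behavior layer, so that its flag can gate steering later in the pass.

```python
import numpy as np

def run_toy_cast(layers_input, condition_layer, behavior_layer,
                 projector, behavior_vec, threshold=0.5, strength=1.0):
    """Toy two-phase CAST pass (hypothetical helper, not the library API).

    layers_input: list of [seq_len, hidden_dim] arrays, one per layer.
    """
    condition_met = False
    outputs = []
    for layer_id, h in enumerate(layers_input):
        if layer_id == condition_layer:
            # Phase 1: mean-pool tokens, project, and compare cosine similarity
            pooled = h.mean(axis=0)
            projected = projector @ pooled
            denom = np.linalg.norm(pooled) * np.linalg.norm(projected)
            sim = float(pooled @ projected / denom) if denom > 0 else 0.0
            condition_met = sim > threshold
        if layer_id == behavior_layer and condition_met:
            # Phase 2: steer only if the condition fired at an earlier layer
            h = h + strength * behavior_vec
        outputs.append(h)
    return outputs, condition_met
```

Inputs orthogonal to c give a zero projection and therefore a similarity of 0, so they pass through unsteered.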
Parameters:
Name | Type | Description | Default |
---|---|---|---|
condition_vector | SteeringVector | Steering vector defining the condition subspace used to detect target input patterns. | None |
behavior_vector | SteeringVector | Steering vector applied to modify behavior when the condition is met. | None |
condition_layer_ids | list[int] | Layer indices where condition detection occurs. | None |
behavior_layer_ids | list[int] | Layer indices where behavior modification is applied. | None |
condition_vector_threshold | float | Similarity threshold for condition detection. With the default 'larger' comparator, higher values require a stronger pattern match. | 0.5 |
behavior_vector_strength | float | Scaling factor for the behavior steering vector; controls the intensity of the behavioral modification. | 1.0 |
condition_comparator_threshold_is | str | Comparison mode for the threshold ('larger' or 'smaller'): determines whether the condition is met when the similarity is above or below the threshold. | 'larger' |
condition_threshold_comparison_mode | str | How hidden states are aggregated over tokens before comparison ('mean' or 'last'). | 'mean' |
apply_behavior_on_first_call | bool | Whether to apply behavior steering on the first forward pass. | True |
use_ooi_preventive_normalization | bool | Apply out-of-distribution preventive normalization to maintain hidden-state magnitudes. | False |
use_explained_variance | bool | Scale steering vectors by their explained variance for adaptive layer-wise control. | False |
Reference:
- "Programming Refusal with Conditional Activation Steering." Bruce W. Lee, Inkit Padhi, Karthikeyan Natesan Ramamurthy, Erik Miehling, Pierre Dognin, Manish Nagireddy, Amit Dhurandhar. https://arxiv.org/abs/2409.05907
Source code in aisteer360/algorithms/state_control/cast/control.py
_behavior_layers = None (class-attribute, instance-attribute)
_condition_layers = None (class-attribute, instance-attribute)
_condition_met = defaultdict(bool) (class-attribute, instance-attribute)
_condition_similarities = defaultdict(lambda: defaultdict(float)) (class-attribute, instance-attribute)
_forward_calls = defaultdict(int) (class-attribute, instance-attribute)
_layers = None (class-attribute, instance-attribute)
_layers_names = None (class-attribute, instance-attribute)
_layers_states = None (class-attribute, instance-attribute)
_model_ref = None (class-attribute, instance-attribute)
args = self.Args.validate(*args, **kwargs) (instance-attribute)
device = None (class-attribute, instance-attribute)
enabled = True (class-attribute, instance-attribute)
hooks = {'pre': [], 'forward': [], 'backward': []} (instance-attribute)
model = None (class-attribute, instance-attribute)
registered = [] (instance-attribute)
tokenizer = None (class-attribute, instance-attribute)
__enter__()
Context manager entry: registers hooks on the model.
Raises:
Type | Description |
---|---|
RuntimeError | If the model reference has not been set by the pipeline. |
Source code in aisteer360/algorithms/state_control/base.py
__exit__(exc_type, exc, tb)
Context manager exit: clean up all hooks.
Source code in aisteer360/algorithms/state_control/base.py
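The register-on-enter, remove-on-exit lifecycle that __enter__ and __exit__ describe can be sketched in PyTorch. HookContext and its pre_hooks list of (submodule name, hook) pairs are hypothetical. Only register_forward_pre_hook (with with_kwargs=True, PyTorch 2.0 or later) and the returned handle's remove() are real PyTorch APIs.

```python
import torch
from torch import nn

class HookContext:
    """Hypothetical stand-in for the hook lifecycle, not the library class."""

    def __init__(self, model, pre_hooks):
        self.model = model
        self.pre_hooks = pre_hooks  # list of (submodule_name, hook_fn) pairs
        self.registered = []        # RemovableHandle objects from PyTorch

    def __enter__(self):
        if self.model is None:
            raise RuntimeError("model reference not set")
        modules = dict(self.model.named_modules())
        for name, fn in self.pre_hooks:
            handle = modules[name].register_forward_pre_hook(fn, with_kwargs=True)
            self.registered.append(handle)
        return self

    def __exit__(self, exc_type, exc, tb):
        for handle in self.registered:
            handle.remove()
        self.registered.clear()
        return False  # never swallow exceptions raised inside the block


calls = []

def counting_hook(module, args, kwargs):
    calls.append(type(module).__name__)  # record the call, leave inputs unchanged

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
x = torch.randn(2, 4)
with HookContext(model, [("0", counting_hook)]):
    model(x)  # hook fires on the Linear layer
model(x)      # hook already removed, so nothing is recorded
```

Keeping the handles and removing them in __exit__ guarantees the model is left unhooked even if generation raises.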
__init__(*args, **kwargs)
Source code in aisteer360/algorithms/state_control/base.py
_apply_ooi_normalization(hidden_states, original_norm)
Apply out-of-distribution preventive normalization.
Prevents hidden states from drifting too far from the original distribution by rescaling them after steering so that their norm magnitudes are maintained.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
hidden_states |  | Modified hidden states to normalize. | required |
original_norm |  | Original norm before modification. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | Normalized hidden states. |
Raises:
Type | Description |
---|---|
ValueError | If NaN or Inf values are detected in the hidden states. |
Source code in aisteer360/algorithms/state_control/cast/control.py
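The normalization can be sketched in NumPy under a few assumptions. Every token vector is rescaled unconditionally to its pre-steering norm, and original_norm has shape [..., 1] (computed with keepdims). The function name ooi_normalize is hypothetical. The library may differ in detail, for example by rescaling only when the norm grows.

```python
import numpy as np

def ooi_normalize(hidden_states, original_norm):
    """Rescale steered hidden states to their pre-steering norms (sketch)."""
    if not np.all(np.isfinite(hidden_states)):
        raise ValueError("NaN or Inf detected in hidden states")
    new_norm = np.linalg.norm(hidden_states, axis=-1, keepdims=True)
    # Clamp to avoid dividing by zero on all-zero vectors
    return hidden_states * (original_norm / np.maximum(new_norm, 1e-12))

h = np.array([[3.0, 4.0], [0.0, 2.0]])
original_norm = np.linalg.norm(h, axis=-1, keepdims=True)  # [[5.0], [2.0]]
steered = h + np.array([3.0, 0.0])                          # norms grow after steering
restored = ooi_normalize(steered, original_norm)
```

The steered direction is kept and only the magnitude is pulled back, which is what limits drift away from the original activation scale.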
_apply_single_behavior(hidden_states, layer_id)
Apply behavior steering vector when conditions are met.
Modifies hidden states by adding scaled steering vectors to shift model behavior toward desired outputs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
hidden_states |  | Hidden states to modify, shape [batch, seq_len, hidden_dim]. | required |
layer_id | int | Current layer index. | required |
Source code in aisteer360/algorithms/state_control/cast/control.py
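The behavior step can be sketched in NumPy. The sketch assumes the precomputed, already-scaled vectors are stored per layer (as _setup describes) and that the vector is added at every token position. The dict behavior_vectors and this function are hypothetical, and the library may restrict steering to particular positions.

```python
import numpy as np

def apply_single_behavior(hidden_states, layer_id, behavior_vectors):
    """Add the precomputed behavior vector for layer_id, if any (sketch)."""
    vec = behavior_vectors.get(layer_id)
    if vec is None:  # layer is not a behavior layer: leave it untouched
        return hidden_states
    # A [hidden_dim] vector broadcasts over [batch, seq_len, hidden_dim]
    return hidden_states + vec
```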
_cast_pre_hook(module, input_args, input_kwargs, layer_id)
Apply conditional activation steering as a pre-forward hook.
Detects conditions and applies behavior modifications during the model's forward pass. Each layer is processed independently according to its configuration.
Process:
- Extract hidden states from arguments
- If condition layer: detect if input matches target pattern
- If behavior layer and conditions met: apply steering vector
- Optionally apply OOI normalization to prevent distribution shift
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module |  | The layer module being hooked. | required |
input_args | Tuple | Positional arguments to the forward pass. | required |
input_kwargs | dict | Keyword arguments to the forward pass. | required |
layer_id | int | Index of the current layer. | required |
Returns:
Type | Description |
---|---|
Tuple | The potentially modified (input_args, input_kwargs). |
Raises:
Type | Description |
---|---|
RuntimeError | If hidden states cannot be located. |
Source code in aisteer360/algorithms/state_control/cast/control.py
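A CAST-style pre-forward hook can be sketched in PyTorch. The helper steering_pre_hook, the vectors dict, and the shared state dict are hypothetical. The extra layer_id argument is bound with functools.partial. register_forward_pre_hook(..., with_kwargs=True) makes PyTorch pass (module, args, kwargs) and accept a returned (args, kwargs) pair. Looking for hidden states first positionally and then under a 'hidden_states' keyword is an assumption about decoder-layer signatures.

```python
import functools
import torch
from torch import nn

def steering_pre_hook(module, input_args, input_kwargs, layer_id, vectors, state):
    """CAST-style pre-forward hook (hypothetical sketch, not the library method)."""
    # Locate hidden states: first positional argument, else a keyword argument
    if input_args:
        hidden = input_args[0]
    elif "hidden_states" in input_kwargs:
        hidden = input_kwargs["hidden_states"]
    else:
        raise RuntimeError("could not locate hidden states")
    # Steer only behavior layers, and only once the condition has been met
    if layer_id in vectors and state.get("condition_met", False):
        hidden = hidden + vectors[layer_id]
    # Write the (possibly modified) tensor back where it came from
    if input_args:
        input_args = (hidden,) + tuple(input_args[1:])
    else:
        input_kwargs = {**input_kwargs, "hidden_states": hidden}
    return input_args, input_kwargs


layer = nn.Identity()
state = {"condition_met": True}
hook = functools.partial(steering_pre_hook, layer_id=0,
                         vectors={0: torch.ones(3)}, state=state)
handle = layer.register_forward_pre_hook(hook, with_kwargs=True)
```

Because the hook runs before the layer's forward, the layer sees the steered input without any change to the model's own code.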
_compute_similarity(x, y)
Compute the cosine similarity between two tensors.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Tensor | First tensor. | required |
y | Tensor | Second tensor. | required |
Returns:
Type | Description |
---|---|
float | The cosine similarity as a float. |
Source code in aisteer360/algorithms/state_control/cast/control.py
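Cosine similarity can be sketched with torch.nn.functional.cosine_similarity. Flattening both inputs first is an assumption that lets tensors of any matching size be compared as single vectors. compute_similarity here is a stand-in, not the library method.

```python
import torch
import torch.nn.functional as F

def compute_similarity(x: torch.Tensor, y: torch.Tensor) -> float:
    """Cosine similarity of two tensors, returned as a Python float (sketch)."""
    # Treat each tensor as one long vector, then compare along that single dim
    return F.cosine_similarity(x.flatten(), y.flatten(), dim=0).item()
```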
_process_single_condition(hidden_state, layer_id)
Detect if input matches target condition pattern.
Projects hidden states onto condition subspace and compares similarity against threshold to determine if steering should be activated.
Process:
- Aggregate hidden states (mean or last token based on config)
- Project onto condition subspace using precomputed projector
- Compute cosine similarity between original and projected
- Compare against threshold with specified comparator
Parameters:
Name | Type | Description | Default |
---|---|---|---|
hidden_state |  | Hidden state tensor to analyze, shape [seq_len, hidden_dim]. | required |
layer_id | int | Current layer index. | required |
Source code in aisteer360/algorithms/state_control/cast/control.py
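The four detection steps can be sketched in NumPy. The function and its keyword names are hypothetical. The projector is assumed to be a precomputed [hidden_dim, hidden_dim] matrix. Using strict > and < for the 'larger' and 'smaller' comparators is also an assumption. The library may transform the projected state before comparing, but this sketch uses the plain projection.

```python
import numpy as np

def process_single_condition(hidden_state, projector, threshold=0.5,
                             mode="mean", comparator="larger"):
    """Sketch of CAST condition detection on one [seq_len, hidden_dim] state."""
    # 1. Aggregate over tokens
    if mode == "mean":
        pooled = hidden_state.mean(axis=0)
    elif mode == "last":
        pooled = hidden_state[-1]
    else:
        raise ValueError(f"unknown mode {mode!r}")
    # 2. Project onto the condition subspace
    projected = projector @ pooled
    # 3. Cosine similarity between the original and projected states
    denom = np.linalg.norm(pooled) * np.linalg.norm(projected)
    similarity = float(pooled @ projected / denom) if denom > 0 else 0.0
    # 4. Compare against the threshold with the configured comparator
    if comparator == "larger":
        met = similarity > threshold
    elif comparator == "smaller":
        met = similarity < threshold
    else:
        raise ValueError(f"unknown comparator {comparator!r}")
    return met, similarity
```

The aggregation mode matters. The mean over a prompt can fall below the threshold even when the last token alone clears it.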
_setup(model)
Configure all CAST internals for the given model.
Pre-computes steering vectors, condition projectors, and layer configurations to minimize runtime overhead during generation.
Process:
- Identifies condition and behavior layers from configuration
- Computes condition projection matrices for detection layers
- Prepares scaled behavior vectors for modification layers
- Stores layer-specific parameters in _layers_states
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | PreTrainedModel | Model to configure CAST for. | required |
Source code in aisteer360/algorithms/state_control/cast/control.py
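The precomputation can be sketched in NumPy. The sketch assumes a rank-1 orthogonal projector P = c c^T / (c^T c) per condition layer and behavior vectors pre-multiplied by behavior_vector_strength. The per-layer direction dicts and the name setup_cast are hypothetical stand-ins for what SteeringVector provides.

```python
import numpy as np

def setup_cast(condition_dirs, behavior_dirs, condition_layer_ids,
               behavior_layer_ids, strength=1.0):
    """Precompute per-layer projectors and scaled behavior vectors (sketch).

    condition_dirs / behavior_dirs: hypothetical {layer_id: np.ndarray[hidden_dim]}.
    """
    layers_states = {}
    for layer_id in condition_layer_ids:
        c = condition_dirs[layer_id]
        # Rank-1 orthogonal projector onto span(c): P = c c^T / (c^T c)
        layers_states.setdefault(layer_id, {})["projector"] = np.outer(c, c) / (c @ c)
    for layer_id in behavior_layer_ids:
        # Fold the strength in once so the hook only has to add the vector
        layers_states.setdefault(layer_id, {})["behavior"] = strength * behavior_dirs[layer_id]
    return layers_states
```

Doing this once at setup keeps the per-token hook work to a matrix-vector product and an addition.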
_use_explained_variance_func(vector, layer_id)
Scale steering vector by its explained variance for adaptive control.
Because the variance is retrieved per layer, the magnitude of the scaled vector, and hence its impact, can differ across layers of the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
vector | SteeringVector | Steering vector containing directions and variances. | required |
layer_id | int | Layer index for which to retrieve the variance scaling. | required |
Returns:
Type | Description |
---|---|
ndarray | Direction vector scaled by explained variance. |
Source code in aisteer360/algorithms/state_control/cast/control.py
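The scaling can be sketched under two assumptions: the steering vector stores a direction and a scalar explained variance per layer, and scaling is a plain multiplication. ToySteeringVector and its field names are hypothetical stand-ins for SteeringVector, and the library may normalize the variances differently.

```python
from dataclasses import dataclass, field
import numpy as np

@dataclass
class ToySteeringVector:
    """Hypothetical stand-in for SteeringVector: per-layer directions and variances."""
    directions: dict = field(default_factory=dict)
    explained_variances: dict = field(default_factory=dict)

def scale_by_explained_variance(vector, layer_id):
    # Layers whose direction explains more variance get a proportionally larger push
    return vector.directions[layer_id] * vector.explained_variances[layer_id]
```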
get_hooks(input_ids, runtime_kwargs, **__)
Create pre-forward hooks for conditional activation steering.
Generates hook specifications for all model layers that will conditionally detect patterns and apply behavior modifications during the forward pass.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_ids | Tensor | Input token IDs (unused but required by the interface). | required |
runtime_kwargs | dict \| None | Runtime parameters (currently unused). | required |
**__ |  | Additional arguments (unused). | {} |
Returns:
Type | Description |
---|---|
dict[str, list] | Hook specifications with "pre", "forward", and "backward" keys. Only the "pre" hooks are populated with CAST steering logic. |
Source code in aisteer360/algorithms/state_control/cast/control.py
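The returned structure can be sketched in plain Python. Only the three top-level keys come from the description above. The per-entry keys 'module' and 'hook_func' and the helper build_hook_specs are hypothetical. Binding layer_id with functools.partial lets one hook function serve every layer.

```python
from functools import partial

def build_hook_specs(layer_names, pre_hook):
    """Sketch of a hook-spec dict; entry keys 'module'/'hook_func' are hypothetical."""
    specs = {"pre": [], "forward": [], "backward": []}
    for layer_id, name in enumerate(layer_names):
        specs["pre"].append({
            "module": name,
            # Bind layer_id so the same function knows which layer it is hooking
            "hook_func": partial(pre_hook, layer_id=layer_id),
        })
    return specs
```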
get_model_layer_list(model)
Extract the list of transformer layers from the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | PreTrainedModel | Model to extract layers from. | required |
Returns:
- The list of transformer layers for the given model.
- The module-name prefix of the layer list for the given model.
Source code in aisteer360/algorithms/state_control/cast/control.py
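Layer discovery can be sketched by probing common Hugging Face attribute paths: model.layers for Llama-style models, transformer.h for GPT-2, and gpt_neox.layers for GPT-NeoX. The candidate list and the choice to return the prefix as a single string are assumptions. The library's lookup may cover other architectures or return names in another form.

```python
from types import SimpleNamespace

# Hypothetical list of attribute paths to try, in order
CANDIDATE_PATHS = ("model.layers", "transformer.h", "gpt_neox.layers")

def get_model_layer_list(model):
    """Return (layers, prefix) for the first path that resolves (sketch)."""
    for path in CANDIDATE_PATHS:
        obj = model
        for attr in path.split("."):
            obj = getattr(obj, attr, None)
            if obj is None:
                break
        if obj is not None:
            return list(obj), path
    raise ValueError("could not locate the transformer layer list")


# Usage with a stand-in object shaped like a Llama-style causal LM
toy = SimpleNamespace(model=SimpleNamespace(layers=["block0", "block1"]))
layers, prefix = get_model_layer_list(toy)
```

Returning the prefix alongside the layers lets a caller build per-layer module names such as prefix + "." + str(i) when registering hooks.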
register_hooks(model)
Attach hooks to the model.
Source code in aisteer360/algorithms/state_control/base.py
remove_hooks()
Remove all registered hooks from the model.
Source code in aisteer360/algorithms/state_control/base.py
reset()
Reset internal state tracking between generation calls.
Clears condition detection flags, forward call counters, and similarity scores.
Source code in aisteer360/algorithms/state_control/cast/control.py
set_hooks(hooks)
Update the hook specifications to be registered.
Source code in aisteer360/algorithms/state_control/base.py
steer(model, tokenizer=None, **__)
Initializes CAST by configuring the condition-detection and behavior-modification layers.
Sets up steering vectors, condition projectors, and layer-specific parameters for conditional activation steering, pre-computing projection matrices and behavior vectors.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | PreTrainedModel | The base language model to be steered. | required |
tokenizer | PreTrainedTokenizer \| None | Tokenizer (currently unused but kept for API consistency). If None, the method attempts to retrieve it from the model's attributes. | None |
**__ |  | Additional arguments (unused). | {} |
Returns:
Type | Description |
---|---|
PreTrainedModel | The input model, unchanged. |
Source code in aisteer360/algorithms/state_control/cast/control.py