Core

aisteer360.algorithms.core

Core functionality for steering pipelines, steering utilities, and argument parsing.

base_args

Base argument validation for steering method configuration.

T = TypeVar('T', bound='BaseArgs') module-attribute

BaseArgs dataclass

Base class for all methods' args classes.

Source code in aisteer360/algorithms/core/base_args.py
@dataclass
class BaseArgs:
    """Base class for all method's args classes."""

    @classmethod
    def validate(cls: Type[T], data: Any | None = None, **kwargs) -> T:
        """Create and validate an Args instance from dict, kwargs, or existing instance.

        Args:
            data: Existing instance, dict of args, or None
            **kwargs: Additional args (override values in data if both provided)

        Returns:
            Validated instance of the Args class
        """

        if isinstance(data, cls):
            return data

        if isinstance(data, Mapping):
            kwargs = {**data, **kwargs}

        return cls(**kwargs)
validate(data=None, **kwargs) classmethod

Create and validate an Args instance from dict, kwargs, or existing instance.

Parameters:

  • data (Any | None, default None): Existing instance, dict of args, or None
  • **kwargs: Additional args (override values in data if both provided)

Returns:

  • T: Validated instance of the Args class
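
A short usage sketch of the three accepted input forms (MyMethodArgs is a hypothetical subclass, shown for illustration):

from dataclasses import dataclass

from aisteer360.algorithms.core.base_args import BaseArgs


@dataclass
class MyMethodArgs(BaseArgs):  # hypothetical args class for some steering method
    alpha: float = 1.0
    layer: int = 0


args = MyMethodArgs.validate(alpha=0.5)                    # from kwargs
args = MyMethodArgs.validate({"alpha": 0.5, "layer": 3})   # from a dict
args = MyMethodArgs.validate({"alpha": 0.5}, alpha=0.9)    # kwargs override dict values
assert args.alpha == 0.9
assert MyMethodArgs.validate(args) is args                 # instances pass through unchanged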

Source code in aisteer360/algorithms/core/base_args.py
@classmethod
def validate(cls: Type[T], data: Any | None = None, **kwargs) -> T:
    """Create and validate an Args instance from dict, kwargs, or existing instance.

    Args:
        data: Existing instance, dict of args, or None
        **kwargs: Additional args (override values in data if both provided)

    Returns:
        Validated instance of the Args class
    """

    if isinstance(data, cls):
        return data

    if isinstance(data, Mapping):
        kwargs = {**data, **kwargs}

    return cls(**kwargs)

specs

Specification utilities for steering controls.

Provides:

  • ControlSpec: a description of a steering control plus a hyperparameter search space.

Space = Mapping[str, Sequence[Any]] | Sequence[Mapping[str, Any]] | Callable[[dict], Iterable[Mapping[str, Any]]] module-attribute

ControlSpec dataclass

Specification for a parameterized steering control.

A ControlSpec describes a control class plus a search space over its constructor arguments. It is used by a benchmark object to instantiate control instances for different hyperparameter settings.

Attributes:

  • control_cls (Type): The steering control class to instantiate.
  • params (Mapping[str, Any]): Fixed constructor arguments for the control.
  • vars (Space | None): Optional search space over additional constructor arguments. May be a mapping (cartesian grid), a list of parameter dicts, or a callable that yields parameter dicts given a context.
  • name (str | None): Optional short name for this spec; defaults to control_cls.__name__ if omitted.
  • search_strategy (Literal['grid', 'random']): Strategy for traversing vars when it is a mapping or a sequence. Either "grid" (use all points) or "random" (sample a subset).
  • num_samples (int | None): Number of points to sample when search_strategy="random" and vars is a mapping or sequence; ignored when vars is callable.
  • seed (int | None): Optional random seed used when search_strategy="random".
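
A short sketch of the mapping (grid) form; MyControl is a hypothetical control class used only for illustration:

from aisteer360.algorithms.core.specs import ControlSpec


class MyControl:  # stand-in; any steering control class works here
    def __init__(self, mode: str = "add", strength: float = 1.0, layer: int = 0):
        self.mode, self.strength, self.layer = mode, strength, layer


spec = ControlSpec(
    control_cls=MyControl,
    params={"mode": "add"},                            # fixed constructor args
    vars={"strength": [0.5, 1.0], "layer": [3, 7]},    # 2 x 2 cartesian grid
)
points = list(spec.iter_points(context={}))
# -> [{'strength': 0.5, 'layer': 3}, {'strength': 0.5, 'layer': 7},
#     {'strength': 1.0, 'layer': 3}, {'strength': 1.0, 'layer': 7}]

Setting search_strategy="random" with num_samples (and optionally seed) draws a subset of the same grid instead of enumerating it fully.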

Source code in aisteer360/algorithms/core/specs.py
@dataclass(slots=True)
class ControlSpec:
    """Specification for a parameterized steering control.

    A `ControlSpec` describes a control class plus a search space over its constructor arguments. It is used by a
    benchmark object to instantiate control instances for different hyperparameter settings.

    Attributes:
        control_cls: The steering control class to instantiate.
        params: Fixed constructor arguments for the control.
        vars: Optional search space over additional constructor arguments. May be:

            - mapping (cartesian grid)
            - list of parameter dicts
            - callable that yields parameter dicts given a context
        name: Optional short name for this spec; defaults to `control_cls.__name__` if omitted.
        search_strategy: Strategy for traversing `vars` when it is a mapping or a sequence. Either `"grid"` (use all
            points) or `"random"` (sample a subset).
        num_samples: Number of points to sample when `search_strategy="random"` and `vars` is a mapping or sequence;
            ignored when `vars` is callable.
        seed: Optional random seed used when `search_strategy="random"`.
    """

    control_cls: Type
    params: Mapping[str, Any] = field(default_factory=dict)
    vars: Space | None = None
    name: str | None = None
    search_strategy: Literal["grid", "random"] = "grid"
    num_samples: int | None = None
    seed: int | None = None

    def iter_points(self, context: dict) -> Iterable[dict[str, Any]]:
        """Iterate over local search points for this spec.

        Args:
            context: Context dictionary; passed through to functional `vars` if `vars` is callable.

        Yields:
            Parameter dictionaries (possibly empty) that will be merged into `params` when constructing a concrete
            control instance.
        """
        search_space = self.vars

        # no search space
        if search_space is None:
            yield {}
            return

        # forward the context
        if callable(search_space):
            yield from search_space(context)  # callable controls sampling
            return

        # Mapping[str, Sequence[Any]]: potentially large cartesian product
        if isinstance(search_space, Mapping):
            param_names = list(search_space.keys())
            param_values = [list(search_space[name]) for name in param_names]

            if any(len(vals) == 0 for vals in param_values):
                return

            sizes = [len(vals) for vals in param_values]
            n_points = math.prod(sizes)

            # GRID SEARCH: iterate over the cartesian product
            if (
                self.search_strategy == "grid"
                or self.num_samples is None
                or self.num_samples >= n_points
            ):
                for combo in itertools.product(*param_values):
                    yield dict(zip(param_names, combo))
                return

            # RANDOM SEARCH: sample indices
            rng = random.Random(self.seed)
            k = min(self.num_samples, n_points)
            index_samples = rng.sample(range(n_points), k)

            # decode (flat) index into a combination of parameter choices
            for flat_index in index_samples:
                idx = flat_index
                indices_per_dim: list[int] = []
                for size in reversed(sizes):
                    indices_per_dim.append(idx % size)
                    idx //= size
                indices_per_dim.reverse()

                values = [param_values[dim][indices_per_dim[dim]] for dim in range(len(param_names))]
                yield dict(zip(param_names, values))
            return

        # Sequence[Mapping[str, Any]]: explicit list of parameter dicts
        combinations = [dict(param_dict) for param_dict in search_space]

        if (
            self.search_strategy == "random"
            and self.num_samples is not None
            and self.num_samples < len(combinations)
        ):
            rng = random.Random(self.seed)
            combinations = rng.sample(combinations, self.num_samples)

        for combination in combinations:
            yield combination

    def resolve_params(self, chosen: dict[str, Any], context: dict) -> dict[str, Any]:
        """Compute the full kwargs for this control at a given search point.
        """
        local_context = dict(context)
        local_context["search_params"] = chosen

        resolved_params = {
            key: (value(local_context) if callable(value) else value)
            for key, value in self.params.items()
        }

        resolved_params.update(chosen)
        return resolved_params
  • control_cls (instance-attribute)
  • name = None (class-attribute, instance-attribute)
  • num_samples = None (class-attribute, instance-attribute)
  • params = field(default_factory=dict) (class-attribute, instance-attribute)
  • search_strategy = 'grid' (class-attribute, instance-attribute)
  • seed = None (class-attribute, instance-attribute)
  • vars = None (class-attribute, instance-attribute)
iter_points(context)

Iterate over local search points for this spec.

Parameters:

  • context (dict, required): Context dictionary; passed through to functional vars if vars is callable.

Yields:

  • Iterable[dict[str, Any]]: Parameter dictionaries (possibly empty) that will be merged into params when constructing a concrete control instance.
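
The list and callable forms, continuing the hypothetical MyControl sketch above:

# explicit list of parameter dicts: yielded as-is (optionally subsampled under "random")
spec = ControlSpec(control_cls=MyControl, vars=[{"strength": 0.5}, {"strength": 2.0}])
assert list(spec.iter_points(context={})) == [{"strength": 0.5}, {"strength": 2.0}]

# callable form: the callable receives the context and fully controls sampling
def strengths_from_context(context):
    for strength in context.get("strengths", [1.0]):
        yield {"strength": strength}

spec = ControlSpec(control_cls=MyControl, vars=strengths_from_context)
assert list(spec.iter_points(context={"strengths": [0.1, 0.2]})) == [
    {"strength": 0.1},
    {"strength": 0.2},
]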

Source code in aisteer360/algorithms/core/specs.py
def iter_points(self, context: dict) -> Iterable[dict[str, Any]]:
    """Iterate over local search points for this spec.

    Args:
        context: Context dictionary; passed through to functional `vars` if `vars` is callable.

    Yields:
        Parameter dictionaries (possibly empty) that will be merged into `params` when constructing a concrete
        control instance.
    """
    search_space = self.vars

    # no search space
    if search_space is None:
        yield {}
        return

    # forward the context
    if callable(search_space):
        yield from search_space(context)  # callable controls sampling
        return

    # Mapping[str, Sequence[Any]]: potentially large cartesian product
    if isinstance(search_space, Mapping):
        param_names = list(search_space.keys())
        param_values = [list(search_space[name]) for name in param_names]

        if any(len(vals) == 0 for vals in param_values):
            return

        sizes = [len(vals) for vals in param_values]
        n_points = math.prod(sizes)

        # GRID SEARCH: iterate over the cartesian product
        if (
            self.search_strategy == "grid"
            or self.num_samples is None
            or self.num_samples >= n_points
        ):
            for combo in itertools.product(*param_values):
                yield dict(zip(param_names, combo))
            return

        # RANDOM SEARCH: sample indices
        rng = random.Random(self.seed)
        k = min(self.num_samples, n_points)
        index_samples = rng.sample(range(n_points), k)

        # decode (flat) index into a combination of parameter choices
        for flat_index in index_samples:
            idx = flat_index
            indices_per_dim: list[int] = []
            for size in reversed(sizes):
                indices_per_dim.append(idx % size)
                idx //= size
            indices_per_dim.reverse()

            values = [param_values[dim][indices_per_dim[dim]] for dim in range(len(param_names))]
            yield dict(zip(param_names, values))
        return

    # Sequence[Mapping[str, Any]]: explicit list of parameter dicts
    combinations = [dict(param_dict) for param_dict in search_space]

    if (
        self.search_strategy == "random"
        and self.num_samples is not None
        and self.num_samples < len(combinations)
    ):
        rng = random.Random(self.seed)
        combinations = rng.sample(combinations, self.num_samples)

    for combination in combinations:
        yield combination
resolve_params(chosen, context)

Compute the full kwargs for this control at a given search point. Callable values in params are resolved against a copy of the context (with the chosen point available under "search_params"); the chosen search-point values then override any overlapping fixed params.
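
For example (continuing the hypothetical MyControl sketch above), a fixed param may be a callable that reads the chosen point from context["search_params"]:

spec = ControlSpec(
    control_cls=MyControl,
    params={
        "mode": "add",
        # resolved lazily at each search point
        "layer": lambda ctx: 7 if ctx["search_params"]["strength"] > 1.0 else 3,
    },
    vars={"strength": [0.5, 2.0]},
)
for point in spec.iter_points(context={}):
    kwargs = spec.resolve_params(point, context={})
    # {'mode': 'add', 'layer': 3, 'strength': 0.5}, then {'mode': 'add', 'layer': 7, 'strength': 2.0}
    control = spec.control_cls(**kwargs)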

Source code in aisteer360/algorithms/core/specs.py
127
128
129
130
131
132
133
134
135
136
137
138
139
def resolve_params(self, chosen: dict[str, Any], context: dict) -> dict[str, Any]:
    """Compute the full kwargs for this control at a given search point.
    """
    local_context = dict(context)
    local_context["search_params"] = chosen

    resolved_params = {
        key: (value(local_context) if callable(value) else value)
        for key, value in self.params.items()
    }

    resolved_params.update(chosen)
    return resolved_params

steering_pipeline

Core steering pipeline for composing and applying multiple LLM control methods.

SteeringPipeline dataclass

Main steering pipeline for applying various control methods to Hugging Face causal language models.

Enables application of structural, state, input, and output controls in a coordinated manner. Controls are applied in a fixed bottom-up order during steering, then used together during generation.

Workflow:

  1. Instantiate with a base model checkpoint and/or control objects
  2. Call steer() once to apply all controls in order (structural → input → state → output)
  3. Use generate() or generate_text() for inference with steering applied
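
A minimal end-to-end sketch (the checkpoint name is illustrative; an empty controls list exercises the no-op defaults):

from aisteer360.algorithms.core.steering_pipeline import SteeringPipeline

pipeline = SteeringPipeline(
    model_name_or_path="gpt2",   # any Hugging Face causal LM checkpoint
    controls=[],                 # no-op defaults for all four categories
)
pipeline.steer()

encoded = pipeline.tokenizer("Hello", return_tensors="pt")
texts = pipeline.generate_text(
    input_ids=encoded.input_ids,
    attention_mask=encoded.attention_mask,
    max_new_tokens=20,
)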

Parameters:

  • model_name_or_path (str or Path, default None): HuggingFace model hub name or local directory. Required when lazy_init=False. Ignored when lazy_init=True and the structural control returns a model.
  • controls (Sequence[StructuralControl | StateControl | InputControl | OutputControl], default ()): Controls for the steering pipeline, at most one control per category. Omitted categories fall back to no-op controls (see control base classes).
  • tokenizer_name_or_path (str, default None): Tokenizer location. Defaults to model_name_or_path.
  • device_map (str or dict[str, int], default "auto"): Device map (passed to transformers.AutoModelForCausalLM.from_pretrained). Cannot be used together with the device parameter.
  • device (torch.device or str, default None): Device (passed to the model's .to() method). When specified, device_map must remain at its default value of "auto".
  • hf_model_kwargs (dict, default dict()): Extra keyword arguments passed to transformers.AutoModelForCausalLM.from_pretrained.
  • lazy_init (bool, default False): If True, defers loading the base model until steer() time. Useful when a StructuralControl will itself load or create the final weights (e.g., MergeKit). When False, the model is loaded during SteeringPipeline construction.

Raises:

  • RuntimeError: If generate() is called before steer()
  • ValueError: If multiple controls are provided for the same category or required arguments are missing

Note:

  • Maximum one control per category; omitted categories use no-op defaults
  • Controls with a tokenizer attribute will have it auto-injected if not already set
Source code in aisteer360/algorithms/core/steering_pipeline.py
@dataclass(slots=True)
class SteeringPipeline:
    """Main steering pipeline for applying various control methods to Hugging Face causal language models.

    Enables application of structural, state, input, and output controls in a coordinated manner.
    Controls are applied in a fixed bottom-up order during steering, then used together during generation.

    Workflow:

    1. Instantiate with a base model checkpoint and/or control objects
    2. Call `steer()` once to apply all controls in order (structural → input → state → output)
    3. Use `generate()` or `generate_text()` for inference with steering applied

    Args:
        model_name_or_path (str or pathlib.Path, optional): HuggingFace model hub name or local directory.
            Required when `lazy_init=False`. Ignored when `lazy_init=True` and the structural
            control returns a model.
        controls (Sequence[StructuralControl | StateControl | InputControl | OutputControl], optional):
            Controls for the steering pipeline, max one control per category. Omitted categories
            fall back to no-op controls (see control base classes).
        tokenizer_name_or_path (str, optional): Tokenizer location. Defaults to `model_name_or_path`.
        device_map (str or dict[str, int], optional): Device map (passed to
            `transformers.AutoModelForCausalLM.from_pretrained`). Defaults to `"auto"`.
            Cannot be used together with `device` parameter.
        device (torch.device, str, optional): Device (passed to model's `.to()` method).
            When specified, `device_map` must remain at its default value of `"auto"`.
        hf_model_kwargs (dict, optional): Extra keyword arguments passed to
            `transformers.AutoModelForCausalLM.from_pretrained`.
        lazy_init (bool, optional): If `True`, defers loading the base model until `steer()` time.
            Useful when a `StructuralControl` will itself load or create the final weights
            (e.g., MergeKit). When `False`, the model is loaded during `SteeringPipeline`
            construction. Defaults to `False`.

    Raises:
        RuntimeError: If `generate()` is called before `steer()`
        ValueError: If multiple controls provided for same category or required arguments missing

    Note:

    - Maximum one control per category; omitted categories use no-op defaults
    - Controls with a `tokenizer` attribute will have it auto-injected if not already set
    """

    # construction args
    model_name_or_path: str | Path | None = None
    controls: Sequence[StructuralControl | StateControl | InputControl | OutputControl] = ()
    tokenizer_name_or_path: str | None = None
    device_map: str | dict[str, int] | int | torch.device | None = "auto"
    device: torch.device | str | None = None
    hf_model_kwargs: dict = field(default_factory=dict)
    lazy_init: bool = False

    # lazy‑filled fields
    model: PreTrainedModel | None = field(init=False, default=None)
    tokenizer: AutoTokenizer | None = field(init=False, default=None)

    structural_control: StructuralControl = field(init=False)
    input_control: InputControl = field(init=False)
    state_control: StateControl = field(init=False)
    output_control: OutputControl = field(init=False)

    _is_steered: bool = field(default=False, init=False, repr=False)

    def __post_init__(self) -> None:

        # sort/validate the supplied steering methods
        controls_merged = merge_controls(self.controls)
        self.structural_control = controls_merged["structural_control"]
        self.input_control = controls_merged["input_control"]
        self.state_control = controls_merged["state_control"]
        self.output_control = controls_merged["output_control"]

        # load HF artifacts
        if not self.lazy_init:
            if self.model_name_or_path is None:
                raise ValueError("`model_name_or_path` must be provided when lazy_init=False")

            if self.device is not None and self.device_map != "auto":
                raise ValueError("Cannot specify both `device` and `device_map`.")

            if self.device is not None:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name_or_path,
                    **self.hf_model_kwargs,
                )
                self.model = self.model.to(self.device)
                self.device = self.model.device
            else:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name_or_path,
                    device_map=self.device_map,
                    **self.hf_model_kwargs,
                )
                self.device = self.model.device

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.tokenizer_name_or_path or self.model_name_or_path,
                trust_remote_code=True,
            )
            self.tokenizer = ensure_pad_token(self.tokenizer)
        else:
            if isinstance(self.tokenizer_name_or_path, (str, Path)):
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.tokenizer_name_or_path,
                    trust_remote_code=True
                )
                self.tokenizer = ensure_pad_token(self.tokenizer)

        # late‑inject tokenizer into controls that accept it
        controls_iter = (self.structural_control, self.input_control, self.state_control, self.output_control)
        for control in controls_iter:
            if hasattr(control, "tokenizer") and getattr(control, "tokenizer") is None:
                setattr(control, "tokenizer", self.tokenizer)

    @property
    def supports_batching(self) -> bool:
        """Return True if all enabled controls in this pipeline are batch-safe.
        """
        controls = (
            self.structural_control,
            self.input_control,
            self.state_control,
            self.output_control,
        )
        return all(
            getattr(control, "supports_batching", False)
            for control in controls
            if getattr(control, "enabled", True)
        )

    def steer(self, **steer_kwargs) -> None:
        """Apply all steering controls to the model in place.

        Executes each control's steer() method in a fixed bottom-up order: structural -> input -> state -> output.
        This ensures that higher-level controls always see the final configured model from lower levels.

        If any control's steer() method returns a PreTrainedModel instance, it replaces the current model for subsequent
        controls.

        Args:
            **steer_kwargs: Keyword arguments passed to all control steer() methods

        Raises:
            RuntimeError: If called more than once or no model available after steering
        """
        if self._is_steered:
            return

        # steer each control (bottom-up order: structural -> input -> state -> output)
        for control in (self.structural_control, self.input_control, self.state_control, self.output_control):
            steer_fn = getattr(control, "steer", None)
            if callable(steer_fn):
                maybe_new_model = steer_fn(self.model, tokenizer=self.tokenizer, **steer_kwargs)
                if isinstance(maybe_new_model, nn.Module):
                    self.model = maybe_new_model

        # safety checks
        if self.model is None:
            raise RuntimeError(
                "No model is available after steering. Either provide a base model (lazy_init=False) or ensure a "
                "`StructuralControl` returns one."
            )

        if self.tokenizer is None:
            repo = getattr(self.model, "name_or_path", None)
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    repo or Path(getattr(self.structural_control.args, "out_path", "")),
                    trust_remote_code=True,
                )
                self.tokenizer = ensure_pad_token(self.tokenizer)

            except Exception as exception:
                raise RuntimeError("Failed to resolve tokenizer post‑steer.") from exception

        for control in (self.structural_control, self.input_control, self.state_control, self.output_control):
            if hasattr(control, "tokenizer") and getattr(control, "tokenizer", None) is None:
                setattr(control, "tokenizer", self.tokenizer)

        # return steered pipeline
        self._is_steered = True

    def _prepare_inputs(
            self,
            input_ids: list[int] | torch.LongTensor,
            attention_mask: torch.Tensor | None,
            runtime_kwargs: dict | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply input control and normalize input tensors.

        Transforms the prompt via the input control's adapter and ensures both input_ids and attention_mask are
        properly shaped tensors on the correct device.

        Args:
            input_ids: Input token IDs as list or tensor [seq_len] or [batch, seq_len]
            attention_mask: Optional attention mask matching input_ids shape
            runtime_kwargs: Per-call parameters for input control

        Returns:
            tuple[torch.Tensor, torch.Tensor]: (steered_input_ids, attention_mask), both as 2D tensors on model device
        """
        runtime_kwargs = runtime_kwargs or {}
        device = self.model.device

        # apply input control adapter
        adapter = self.input_control.get_prompt_adapter(runtime_kwargs)
        steered_input_ids = adapter(input_ids, runtime_kwargs)

        # normalize input_ids to 2D tensor
        if isinstance(steered_input_ids, list):
            steered_input_ids = torch.tensor(steered_input_ids, dtype=torch.long)
        if steered_input_ids.ndim == 1:
            steered_input_ids = steered_input_ids.unsqueeze(0)
        steered_input_ids = steered_input_ids.to(device)

        # normalize attention_mask
        if attention_mask is not None:
            if isinstance(attention_mask, list):
                attention_mask = torch.as_tensor(attention_mask, dtype=torch.long)
            if attention_mask.ndim == 1:
                attention_mask = attention_mask.unsqueeze(0)
            # rebuild if length mismatch after input control transformation
            if attention_mask.shape[-1] != steered_input_ids.shape[-1]:
                attention_mask = None

        if attention_mask is None:
            if self.tokenizer is not None and self.tokenizer.pad_token_id is not None:
                attention_mask = (steered_input_ids != self.tokenizer.pad_token_id).long()
            else:
                attention_mask = torch.ones_like(steered_input_ids, dtype=torch.long)

        attention_mask = attention_mask.to(dtype=steered_input_ids.dtype, device=device)

        return steered_input_ids, attention_mask

    def _setup_state_control(
            self,
            steered_input_ids: torch.Tensor,
            runtime_kwargs: dict | None,
            **kwargs,
    ) -> None:
        """Configure state control hooks for the current forward/generate call.

        Prepares the state control by computing hooks based on the (already transformed) input and setting up the model
        reference for the context manager.

        Args:
            steered_input_ids: Input token IDs after input control transformation
            runtime_kwargs: Per-call parameters for state control
            **kwargs: Additional arguments passed to get_hooks()
        """
        hooks = self.state_control.get_hooks(steered_input_ids, runtime_kwargs, **kwargs)
        self.state_control.set_hooks(hooks)
        self.state_control._model_ref = self.model
        self.state_control.reset()

    def generate(
            self,
            input_ids: list[int] | torch.LongTensor,
            attention_mask: torch.Tensor | None = None,
            runtime_kwargs: dict | None = None,
            **gen_kwargs
    ) -> torch.Tensor:
        """Generate text with all steering controls applied.

        Applies controls in sequence during generation:

        1. Input control adapts the prompt
        2. State control registers hooks for state control (e.g., activation steering)
        3. Output control handles the actual generation

        Args:
            input_ids: Token IDs as list or tensor (shape: [seq_len] or [batch, seq_len])
            attention_mask: Optional attention mask matching input_ids shape
            runtime_kwargs: Per-generation parameters for controls (e.g., {"substrings": [...]})
            **gen_kwargs: Generation parameters passed to `model.generate()`

        Returns:
            Generated token IDs (shape: [batch, generated_len])

        Raises:
            RuntimeError: If steer() has not yet been called
        """
        if not self._is_steered:
            raise RuntimeError("Must call `.steer()` before `.generate()`.")

        runtime_kwargs = runtime_kwargs or {}
        return_full_sequence = bool(gen_kwargs.pop("return_full_sequence", False))

        # input control
        steered_input_ids, attention_mask = self._prepare_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            runtime_kwargs=runtime_kwargs,
        )

        # state control
        self._setup_state_control(steered_input_ids, runtime_kwargs, **gen_kwargs)

        # output control
        with self.state_control:  # hooks live only for duration of decoding
            output_ids = self.output_control.generate(
                input_ids=steered_input_ids,
                attention_mask=attention_mask,
                runtime_kwargs=runtime_kwargs,
                model=self.model,
                **gen_kwargs
            )

        if not return_full_sequence:
            output_ids = output_ids[:, steered_input_ids.size(1):]

        return output_ids

    def generate_text(self, *args, **kwargs) -> str | list[str]:
        """Generate text and decode to string(s).

        Convenience wrapper that calls generate() and decodes the output tokens.

        Args:
            *args: Arguments passed to generate()
            **kwargs: Keyword arguments passed to generate()

        Returns:
            Decoded text string (single prompt) or list of strings (batch)
        """
        ids = self.generate(*args, **kwargs)
        if ids.ndim == 1:
            return self.tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        return self.tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    def compute_logprobs(
            self,
            input_ids: list[int] | torch.LongTensor,
            attention_mask: torch.Tensor | None = None,
            ref_output_ids: list[int] | torch.LongTensor = None,
            runtime_kwargs: dict | None = None,
            **forward_kwargs: Any,
    ) -> torch.Tensor:
        """Compute per-token log-probabilities of ref_output_ids with structural, input, and state steering controls
        applied. Note that output controls are *not* applied since they concern scoring, not generation.

        The strategy below uses teacher forcing, computes log P(ref_t | steered_input, ref_1, ..., ref_{t-1}) for each
        token in the reference sequence.

        Args:
            input_ids: Input token IDs as list or tensor [seq_len] or [batch, seq_len]
            attention_mask: Optional attention mask matching input_ids shape
            ref_output_ids: Reference tokens to score [ref_len] or [batch, ref_len]
            runtime_kwargs: Per-call parameters for controls (e.g., {"substrings": [...]})
            **forward_kwargs: Additional arguments passed to model forward pass

        Returns:
            torch.Tensor: Log probabilities of shape [batch, ref_len] for decoder-only models,
                or [batch, ref_len - 1] for encoder-decoder models (excludes first decoder token)

        Raises:
            RuntimeError: If steer() has not been called
            ValueError: If ref_output_ids is None
        """
        if not self._is_steered:
            raise RuntimeError("Must call `.steer()` before `.compute_logprobs()`.")
        if ref_output_ids is None:
            raise ValueError("`ref_output_ids` is required for `compute_logprobs()`.")

        runtime_kwargs = runtime_kwargs or {}
        device = self.model.device

        # input control
        steered_input_ids, attention_mask = self._prepare_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            runtime_kwargs=runtime_kwargs,
        )

        # normalize ref_output_ids
        if isinstance(ref_output_ids, list):
            ref_output_ids = torch.tensor(ref_output_ids, dtype=torch.long)
        if ref_output_ids.ndim == 1:
            ref_output_ids = ref_output_ids.unsqueeze(0)
        ref_output_ids = ref_output_ids.to(device)

        batch_size = steered_input_ids.size(0)
        ref_len = ref_output_ids.size(1)

        # broadcast single ref sequence across batch
        if ref_output_ids.size(0) == 1 and batch_size > 1:
            ref_output_ids = ref_output_ids.expand(batch_size, -1)

        if ref_len == 0:
            return torch.zeros((batch_size, 0), device=device, dtype=torch.float32)

        # state control
        self._setup_state_control(steered_input_ids, runtime_kwargs, **forward_kwargs)

        # forward pass under state control context
        is_encoder_decoder = getattr(self.model.config, "is_encoder_decoder", False)

        with self.state_control:
            with torch.no_grad():
                if is_encoder_decoder:
                    outputs = self.model(
                        input_ids=steered_input_ids,
                        attention_mask=attention_mask,
                        decoder_input_ids=ref_output_ids,
                        **forward_kwargs,
                    )
                    # predicts ref[t+1] from ref[0:t]; logits[:, t, :] -> ref[t+1]
                    # logits[:, :-1, :] aligns with targets ref[:, 1:]
                    logits = outputs.logits[:, :-1, :]
                    target_ids = ref_output_ids[:, 1:]
                else:
                    # concatenate input + ref for causal teacher forcing
                    combined_ids = torch.cat([steered_input_ids, ref_output_ids], dim=1)
                    combined_mask = torch.cat([
                        attention_mask,
                        torch.ones(batch_size, ref_len, device=device, dtype=attention_mask.dtype),
                    ], dim=1)

                    outputs = self.model(
                        input_ids=combined_ids,
                        attention_mask=combined_mask,
                        **forward_kwargs,
                    )

                    # logits at [input_len - 1] predicts ref[0]
                    # logits at [input_len + ref_len - 2] predicts ref[ref_len - 1]
                    input_len = steered_input_ids.size(1)
                    logits = outputs.logits[:, input_len - 1: input_len + ref_len - 1, :]
                    target_ids = ref_output_ids

        # compute logprobs
        logprobs = torch.log_softmax(logits, dim=-1)
        token_logprobs = logprobs.gather(
            dim=-1,
            index=target_ids.unsqueeze(-1),
        ).squeeze(-1)

        return token_logprobs
  • controls = () (class-attribute, instance-attribute)
  • device = None (class-attribute, instance-attribute)
  • device_map = 'auto' (class-attribute, instance-attribute)
  • hf_model_kwargs = field(default_factory=dict) (class-attribute, instance-attribute)
  • input_control = field(init=False) (class-attribute, instance-attribute)
  • lazy_init = False (class-attribute, instance-attribute)
  • model = field(init=False, default=None) (class-attribute, instance-attribute)
  • model_name_or_path = None (class-attribute, instance-attribute)
  • output_control = field(init=False) (class-attribute, instance-attribute)
  • state_control = field(init=False) (class-attribute, instance-attribute)
  • structural_control = field(init=False) (class-attribute, instance-attribute)
  • supports_batching (property): Returns True if all enabled controls in this pipeline are batch-safe.
  • tokenizer = field(init=False, default=None) (class-attribute, instance-attribute)
  • tokenizer_name_or_path = None (class-attribute, instance-attribute)
compute_logprobs(input_ids, attention_mask=None, ref_output_ids=None, runtime_kwargs=None, **forward_kwargs)

Compute per-token log-probabilities of ref_output_ids with structural, input, and state steering controls applied. Output controls are *not* applied: they shape generation, whereas this method only scores a fixed reference sequence.

The method uses teacher forcing, computing log P(ref_t | steered_input, ref_1, ..., ref_{t-1}) for each token in the reference sequence.

Parameters:

  • input_ids (list[int] | LongTensor, required): Input token IDs as list or tensor [seq_len] or [batch, seq_len]
  • attention_mask (Tensor | None, default None): Optional attention mask matching input_ids shape
  • ref_output_ids (list[int] | LongTensor, default None): Reference tokens to score [ref_len] or [batch, ref_len]
  • runtime_kwargs (dict | None, default None): Per-call parameters for controls (e.g., {"substrings": [...]})
  • **forward_kwargs (Any): Additional arguments passed to the model forward pass

Returns:

  • torch.Tensor: Log probabilities of shape [batch, ref_len] for decoder-only models, or [batch, ref_len - 1] for encoder-decoder models (excludes the first decoder token)

Raises:

  • RuntimeError: If steer() has not been called
  • ValueError: If ref_output_ids is None
Source code in aisteer360/algorithms/core/steering_pipeline.py
def compute_logprobs(
        self,
        input_ids: list[int] | torch.LongTensor,
        attention_mask: torch.Tensor | None = None,
        ref_output_ids: list[int] | torch.LongTensor = None,
        runtime_kwargs: dict | None = None,
        **forward_kwargs: Any,
) -> torch.Tensor:
    """Compute per-token log-probabilities of ref_output_ids with structural, input, and state steering controls
    applied. Note that output controls are *not* applied since they concern scoring, not generation.

    The strategy below uses teacher forcing, computes log P(ref_t | steered_input, ref_1, ..., ref_{t-1}) for each
    token in the reference sequence.

    Args:
        input_ids: Input token IDs as list or tensor [seq_len] or [batch, seq_len]
        attention_mask: Optional attention mask matching input_ids shape
        ref_output_ids: Reference tokens to score [ref_len] or [batch, ref_len]
        runtime_kwargs: Per-call parameters for controls (e.g., {"substrings": [...]})
        **forward_kwargs: Additional arguments passed to model forward pass

    Returns:
        torch.Tensor: Log probabilities of shape [batch, ref_len] for decoder-only models,
            or [batch, ref_len - 1] for encoder-decoder models (excludes first decoder token)

    Raises:
        RuntimeError: If steer() has not been called
        ValueError: If ref_output_ids is None
    """
    if not self._is_steered:
        raise RuntimeError("Must call `.steer()` before `.compute_logprobs()`.")
    if ref_output_ids is None:
        raise ValueError("`ref_output_ids` is required for `compute_logprobs()`.")

    runtime_kwargs = runtime_kwargs or {}
    device = self.model.device

    # input control
    steered_input_ids, attention_mask = self._prepare_inputs(
        input_ids=input_ids,
        attention_mask=attention_mask,
        runtime_kwargs=runtime_kwargs,
    )

    # normalize ref_output_ids
    if isinstance(ref_output_ids, list):
        ref_output_ids = torch.tensor(ref_output_ids, dtype=torch.long)
    if ref_output_ids.ndim == 1:
        ref_output_ids = ref_output_ids.unsqueeze(0)
    ref_output_ids = ref_output_ids.to(device)

    batch_size = steered_input_ids.size(0)
    ref_len = ref_output_ids.size(1)

    # broadcast single ref sequence across batch
    if ref_output_ids.size(0) == 1 and batch_size > 1:
        ref_output_ids = ref_output_ids.expand(batch_size, -1)

    if ref_len == 0:
        return torch.zeros((batch_size, 0), device=device, dtype=torch.float32)

    # state control
    self._setup_state_control(steered_input_ids, runtime_kwargs, **forward_kwargs)

    # forward pass under state control context
    is_encoder_decoder = getattr(self.model.config, "is_encoder_decoder", False)

    with self.state_control:
        with torch.no_grad():
            if is_encoder_decoder:
                outputs = self.model(
                    input_ids=steered_input_ids,
                    attention_mask=attention_mask,
                    decoder_input_ids=ref_output_ids,
                    **forward_kwargs,
                )
                # predicts ref[t+1] from ref[0:t]; logits[:, t, :] -> ref[t+1]
                # logits[:, :-1, :] aligns with targets ref[:, 1:]
                logits = outputs.logits[:, :-1, :]
                target_ids = ref_output_ids[:, 1:]
            else:
                # concatenate input + ref for causal teacher forcing
                combined_ids = torch.cat([steered_input_ids, ref_output_ids], dim=1)
                combined_mask = torch.cat([
                    attention_mask,
                    torch.ones(batch_size, ref_len, device=device, dtype=attention_mask.dtype),
                ], dim=1)

                outputs = self.model(
                    input_ids=combined_ids,
                    attention_mask=combined_mask,
                    **forward_kwargs,
                )

                # logits at [input_len - 1] predicts ref[0]
                # logits at [input_len + ref_len - 2] predicts ref[ref_len - 1]
                input_len = steered_input_ids.size(1)
                logits = outputs.logits[:, input_len - 1: input_len + ref_len - 1, :]
                target_ids = ref_output_ids

    # compute logprobs
    logprobs = torch.log_softmax(logits, dim=-1)
    token_logprobs = logprobs.gather(
        dim=-1,
        index=target_ids.unsqueeze(-1),
    ).squeeze(-1)

    return token_logprobs
generate(input_ids, attention_mask=None, runtime_kwargs=None, **gen_kwargs)

Generate text with all steering controls applied.

Applies controls in sequence during generation:

  1. Input control adapts the prompt
  2. State control registers forward hooks (e.g., for activation steering)
  3. Output control handles the actual generation

Parameters:

  • input_ids (list[int] | LongTensor, required): Token IDs as list or tensor (shape: [seq_len] or [batch, seq_len])
  • attention_mask (Tensor | None, default None): Optional attention mask matching input_ids shape
  • runtime_kwargs (dict | None, default None): Per-generation parameters for controls (e.g., {"substrings": [...]})
  • **gen_kwargs: Generation parameters passed to model.generate()

Returns:

  • Tensor: Generated token IDs (shape: [batch, generated_len])

Raises:

  • RuntimeError: If steer() has not yet been called
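
Continuing the pipeline sketch above:

output_ids = pipeline.generate(
    input_ids=encoded.input_ids,
    attention_mask=encoded.attention_mask,
    runtime_kwargs={},      # per-call parameters consumed by the controls
    max_new_tokens=20,
)
# only newly generated tokens are returned by default; pass
# return_full_sequence=True to keep the prompt tokens as well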

Source code in aisteer360/algorithms/core/steering_pipeline.py
def generate(
        self,
        input_ids: list[int] | torch.LongTensor,
        attention_mask: torch.Tensor | None = None,
        runtime_kwargs: dict | None = None,
        **gen_kwargs
) -> torch.Tensor:
    """Generate text with all steering controls applied.

    Applies controls in sequence during generation:

    1. Input control adapts the prompt
    2. State control registers hooks for state control (e.g., activation steering)
    3. Output control handles the actual generation

    Args:
        input_ids: Token IDs as list or tensor (shape: [seq_len] or [batch, seq_len])
        attention_mask: Optional attention mask matching input_ids shape
        runtime_kwargs: Per-generation parameters for controls (e.g., {"substrings": [...]})
        **gen_kwargs: Generation parameters passed to `model.generate()`

    Returns:
        Generated token IDs (shape: [batch, generated_len])

    Raises:
        RuntimeError: If steer() has not yet been called
    """
    if not self._is_steered:
        raise RuntimeError("Must call `.steer()` before `.generate()`.")

    runtime_kwargs = runtime_kwargs or {}
    return_full_sequence = bool(gen_kwargs.pop("return_full_sequence", False))

    # input control
    steered_input_ids, attention_mask = self._prepare_inputs(
        input_ids=input_ids,
        attention_mask=attention_mask,
        runtime_kwargs=runtime_kwargs,
    )

    # state control
    self._setup_state_control(steered_input_ids, runtime_kwargs, **gen_kwargs)

    # output control
    with self.state_control:  # hooks live only for duration of decoding
        output_ids = self.output_control.generate(
            input_ids=steered_input_ids,
            attention_mask=attention_mask,
            runtime_kwargs=runtime_kwargs,
            model=self.model,
            **gen_kwargs
        )

    if not return_full_sequence:
        output_ids = output_ids[:, steered_input_ids.size(1):]

    return output_ids
generate_text(*args, **kwargs)

Generate text and decode to string(s).

Convenience wrapper that calls generate() and decodes the output tokens.

Parameters:

  • *args: Arguments passed to generate()
  • **kwargs: Keyword arguments passed to generate()

Returns:

  • str | list[str]: Decoded text string (single prompt) or list of strings (batch)
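
Continuing the pipeline sketch above; the isinstance guard covers both documented return shapes:

texts = pipeline.generate_text(
    input_ids=encoded.input_ids,
    attention_mask=encoded.attention_mask,
    max_new_tokens=20,
)
first = texts[0] if isinstance(texts, list) else texts   # single-prompt convenience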

Source code in aisteer360/algorithms/core/steering_pipeline.py
def generate_text(self, *args, **kwargs) -> str | list[str]:
    """Generate text and decode to string(s).

    Convenience wrapper that calls generate() and decodes the output tokens.

    Args:
        *args: Arguments passed to generate()
        **kwargs: Keyword arguments passed to generate()

    Returns:
        Decoded text string (single prompt) or list of strings (batch)
    """
    ids = self.generate(*args, **kwargs)
    if ids.ndim == 1:
        return self.tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return self.tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
steer(**steer_kwargs)

Apply all steering controls to the model in place.

Executes each control's steer() method in a fixed bottom-up order: structural -> input -> state -> output. This ensures that higher-level controls always see the final configured model from lower levels.

If any control's steer() method returns a PreTrainedModel instance, it replaces the current model for subsequent controls.

Parameters:

  • **steer_kwargs: Keyword arguments passed to all control steer() methods

Raises:

  • RuntimeError: If no model is available after steering (repeated calls are a no-op rather than an error)
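
Repeated calls are safe; once the pipeline is marked as steered, a second steer() returns immediately:

pipeline.steer()    # applies structural -> input -> state -> output
pipeline.steer()    # no-op: the pipeline is already steered

Any extra keyword arguments are forwarded verbatim to every control's steer() method.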

Source code in aisteer360/algorithms/core/steering_pipeline.py
def steer(self, **steer_kwargs) -> None:
    """Apply all steering controls to the model in place.

    Executes each control's steer() method in a fixed bottom-up order: structural -> input -> state -> output.
    This ensures that higher-level controls always see the final configured model from lower levels.

    If any control's steer() method returns a PreTrainedModel instance, it replaces the current model for subsequent
    controls.

    Args:
        **steer_kwargs: Keyword arguments passed to all control steer() methods

    Raises:
        RuntimeError: If called more than once or no model available after steering
    """
    if self._is_steered:
        return

    # steer each control (bottom-up order: structural -> input -> state -> output)
    for control in (self.structural_control, self.input_control, self.state_control, self.output_control):
        steer_fn = getattr(control, "steer", None)
        if callable(steer_fn):
            maybe_new_model = steer_fn(self.model, tokenizer=self.tokenizer, **steer_kwargs)
            if isinstance(maybe_new_model, nn.Module):
                self.model = maybe_new_model

    # safety checks
    if self.model is None:
        raise RuntimeError(
            "No model is available after steering. Either provide a base model (lazy_init=False) or ensure a "
            "`StructuralControl` returns one."
        )

    if self.tokenizer is None:
        repo = getattr(self.model, "name_or_path", None)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                repo or Path(getattr(self.structural_control.args, "out_path", "")),
                trust_remote_code=True,
            )
            self.tokenizer = ensure_pad_token(self.tokenizer)

        except Exception as exception:
            raise RuntimeError("Failed to resolve tokenizer post‑steer.") from exception

    for control in (self.structural_control, self.input_control, self.state_control, self.output_control):
        if hasattr(control, "tokenizer") and getattr(control, "tokenizer", None) is None:
            setattr(control, "tokenizer", self.tokenizer)

    # return steered pipeline
    self._is_steered = True

steering_utils

Helper functions for steering.

ensure_pad_token(tokenizer)

Set pad token to eos token if not already defined.

Parameters:

  • tokenizer (PreTrainedTokenizerBase, required): HuggingFace tokenizer instance

Returns:

  • PreTrainedTokenizerBase: The same tokenizer with pad_token configured
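
Example (the checkpoint is illustrative; gpt2 ships without a pad token):

from transformers import AutoTokenizer

from aisteer360.algorithms.core.steering_utils import ensure_pad_token

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer = ensure_pad_token(tokenizer)
assert tokenizer.pad_token_id == tokenizer.eos_token_id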

Source code in aisteer360/algorithms/core/steering_utils.py
def ensure_pad_token(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBase:
    """Set pad token to eos token if not already defined.

    Args:
       tokenizer: HuggingFace tokenizer instance

    Returns:
       The same tokenizer with pad_token configured
    """
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer

merge_controls(supplied)

Sort supplied controls by category and ensure at most one per category.

Parameters:

  • supplied (Iterable[StructuralControl | StateControl | InputControl | OutputControl], required): List of control instances to organize

Returns:

  • dict[str, object]: Dict mapping field names to control instances (with default no-ops for unspecified categories)

Raises:

  • ValueError: If multiple controls of the same category are supplied
  • TypeError: If an unrecognized control type is supplied
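
With no controls supplied, every category is filled with a fresh no-op default:

from aisteer360.algorithms.core.steering_utils import merge_controls

merged = merge_controls([])
sorted(merged)
# -> ['input_control', 'output_control', 'state_control', 'structural_control']

# supplying two controls of the same category raises ValueError;
# anything that is not a recognized control type raises TypeError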

Source code in aisteer360/algorithms/core/steering_utils.py
def merge_controls(
        supplied: Iterable[StructuralControl | StateControl | InputControl | OutputControl]
) -> dict[str, object]:
    """Sort supplied controls by category and ensure at most one per category.

    Args:
       supplied: List of control instances to organize

    Returns:
       Dict mapping field names to control instances (with default no-ops for unspecified categories)

    Raises:
       ValueError: If multiple controls of the same category are supplied
       TypeError: If an unrecognized control type is supplied
    """
    bucket: dict[type, list] = defaultdict(list)
    for control in supplied:
        for category in _DEFAULT_FACTORIES:
            if isinstance(control, category):
                bucket[category].append(control)
                break
        else:
            raise TypeError(f"Unknown control type: {type(control)}")

    for category, controls in bucket.items():
        if len(controls) > 1:
            names = [type(control).__name__ for control in controls]
            raise ValueError(f"Multiple {category.__name__}s supplied: {names}")

    out: dict[str, object] = {}
    for category, factory in _DEFAULT_FACTORIES.items():
        instance = bucket.get(category, [factory()])[0]  # fresh instance every time
        out_key = (
            "input_control" if category is InputControl else
            "structural_control" if category is StructuralControl else
            "state_control" if category is StateControl else
            "output_control"
        )
        out[out_key] = instance
    return out