State control

aisteer360.algorithms.state_control.base

State control base classes.

This module provides the abstract base class for methods that register hooks into the model (e.g., to modify intermediate representations during inference). These methods do not change model weights.

Two classes are provided:

  • StateControl: Abstract base class for all state control methods.
  • NoStateControl: Identity (null) control; used when no state control is defined in the steering pipeline.

State controls implement steering through runtime intervention in the model's forward pass, modifying internal states (activations, attention patterns) to produce generations following y ~ p_θᵃ(x), where "p_θᵃ" is the model with state controls.

Examples of state controls:

  • Activation steering (e.g., adding direction vectors)
  • Attention head manipulation and pruning
  • Layer-wise activation editing
  • Dynamic routing between components
  • Representation engineering techniques
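As a rough illustration of the first item, activation steering can be sketched with a plain PyTorch forward hook. This is not the library's implementation: the layer, `direction`, and `strength` below are hypothetical stand-ins, and a real control would target a submodule of a pretrained model.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for one layer of a larger model.
layer = nn.Linear(4, 4)

# Hypothetical steering direction and strength (assumptions for illustration).
direction = torch.tensor([1.0, 0.0, -1.0, 0.5])
strength = 2.0


def add_direction(module, inputs, output):
    """Forward hook: returning a tensor replaces the module's output."""
    return output + strength * direction


x = torch.randn(3, 4)
with torch.no_grad():
    baseline = layer(x)

    handle = layer.register_forward_hook(add_direction)
    steered = layer(x)       # activations shifted along `direction`
    handle.remove()          # detach the hook again

    after_removal = layer(x) # identical to the baseline; weights never changed
```

The shift happens only while the hook is attached, which is why hook cleanup matters for state controls.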

The base class manages hooks automatically through the context manager protocol: hooks are registered on entry and removed on exit, which ensures cleanup and avoids memory leaks.
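The cleanup guarantee can be sketched with `contextlib` and a plain PyTorch hook handle. `temporary_forward_hook` and `zero_output` are hypothetical names used only for illustration; the base class itself implements `__enter__` and `__exit__` rather than a generator-based manager.

```python
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def temporary_forward_hook(module, hook):
    """Attach `hook` for the duration of the with-block, then always remove it."""
    handle = module.register_forward_hook(hook)
    try:
        yield handle
    finally:
        # Runs on normal exit and when the block raises.
        handle.remove()


def zero_output(module, inputs, output):
    """Hypothetical hook that replaces the output with zeros."""
    return torch.zeros_like(output)


layer = nn.Linear(2, 2)
x = torch.ones(1, 2)

with torch.no_grad():
    baseline = layer(x)

    with temporary_forward_hook(layer, zero_output):
        inside = layer(x)        # hook is active here
    after = layer(x)             # hook has been removed

    try:
        with temporary_forward_hook(layer, zero_output):
            raise ValueError("simulated failure during generation")
    except ValueError:
        pass
    after_error = layer(x)       # hook was removed despite the exception
```

Without the `finally` clause, a failure inside the block would leave the hook attached and silently alter every later forward pass.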

See Also:

  • aisteer360.algorithms.state_control: Implementations of state control methods
  • aisteer360.core.steering_pipeline: Integration with steering pipeline

BackwardHook = Callable[[nn.Module, tuple, tuple], tuple] module-attribute

ForwardHook = Callable[[nn.Module, tuple, torch.Tensor], torch.Tensor] module-attribute

HookSpec = dict[str, str | PreHook | ForwardHook | BackwardHook] module-attribute

PreHook = Callable[[nn.Module, tuple], tuple | torch.Tensor] module-attribute
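Judging from `register_hooks` in the source listing further down, each `HookSpec` is read through the keys `"module"` (a submodule path passed to `get_submodule`) and `"hook_func"`. Pre and forward hooks are registered with `with_kwargs=True`, so PyTorch calls them with an extra `kwargs` argument compared with the aliases above. A minimal sketch under those assumptions, where the two-module model and both hook functions are hypothetical:

```python
import torch
from torch import nn


def scale_inputs(module, args, kwargs):
    """Pre-hook (with_kwargs=True): must return an (args, kwargs) pair or None."""
    scaled = tuple(a * 2.0 for a in args)
    return scaled, kwargs


def add_one(module, args, kwargs, output):
    """Forward hook (with_kwargs=True): a returned tensor replaces the output."""
    return output + 1.0


# Same shape as the `hooks` attribute: one list of specs per phase.
hooks = {
    "pre": [{"module": "0", "hook_func": scale_inputs}],
    "forward": [{"module": "1", "hook_func": add_one}],
    "backward": [],
}

# Identity modules keep the arithmetic easy to follow.
model = nn.Sequential(nn.Identity(), nn.Identity())

handles = []
for spec in hooks["pre"]:
    target = model.get_submodule(spec["module"])
    handles.append(target.register_forward_pre_hook(spec["hook_func"], with_kwargs=True))
for spec in hooks["forward"]:
    target = model.get_submodule(spec["module"])
    handles.append(target.register_forward_hook(spec["hook_func"], with_kwargs=True))

x = torch.ones(2, 3)
steered = model(x)        # every entry is 1 * 2 + 1 = 3

for handle in handles:
    handle.remove()
restored = model(x)       # hooks gone, so the input passes through unchanged
```

Submodule paths such as `"0"` are specific to this toy `nn.Sequential`; for a real model the path would be whatever `model.named_modules()` reports.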

NoStateControl

Bases: StateControl

Identity state control.

Used as the default when no state control is needed. Returns empty hook dictionaries and skips registration.

Source code in aisteer360/algorithms/state_control/base.py
class NoStateControl(StateControl):
    """Identity state control.

    Used as the default when no state control is needed. Returns empty hook dictionaries and skips registration.
    """
    enabled: bool = False

    def get_hooks(self, *_, **__) -> dict[str, list[HookSpec]]:
        """Return empty hooks."""
        return {"pre": [], "forward": [], "backward": []}

    def steer(self,
              model: PreTrainedModel,
              tokenizer=None,
              **kwargs) -> None:
        """Null steering operation."""
        pass

    def register_hooks(self, *_):
        """Null registration operation."""
        pass

    def remove_hooks(self, *_):
        """Null removal operation."""
        pass

    def set_hooks(self, hooks: dict[str, list[HookSpec]]):
        """Null set operation."""
        pass

    def reset(self):
        """Null reset operation."""
        pass

args = self.Args.validate(*args, **kwargs) instance-attribute

enabled = False class-attribute instance-attribute

hooks = {'pre': [], 'forward': [], 'backward': []} instance-attribute

registered = [] instance-attribute

get_hooks(*_, **__)

Return empty hooks.

Source code in aisteer360/algorithms/state_control/base.py
def get_hooks(self, *_, **__) -> dict[str, list[HookSpec]]:
    """Return empty hooks."""
    return {"pre": [], "forward": [], "backward": []}

register_hooks(*_)

Null registration operation.

Source code in aisteer360/algorithms/state_control/base.py
def register_hooks(self, *_):
    """Null registration operation."""
    pass

remove_hooks(*_)

Null removal operation.

Source code in aisteer360/algorithms/state_control/base.py
def remove_hooks(self, *_):
    """Null removal operation."""
    pass

reset()

Null reset operation.

Source code in aisteer360/algorithms/state_control/base.py
def reset(self):
    """Null reset operation."""
    pass

set_hooks(hooks)

Null set operation.

Source code in aisteer360/algorithms/state_control/base.py
def set_hooks(self, hooks: dict[str, list[HookSpec]]):
    """Null set operation."""
    pass

steer(model, tokenizer=None, **kwargs)

Null steering operation.

Source code in aisteer360/algorithms/state_control/base.py
def steer(self,
          model: PreTrainedModel,
          tokenizer=None,
          **kwargs) -> None:
    """Null steering operation."""
    pass

StateControl

Bases: ABC

Abstract base class for state control steering methods.

Modifies internal model states during forward passes via hooks.

Methods:

  • get_hooks: Create hook specs (required)
  • steer: One-time preparation (optional)
  • reset: Reset logic (optional)
  • register_hooks: Attach hooks to model (provided)
  • remove_hooks: Remove all registered hooks (provided)
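The lifecycle implied by the methods listed above (build specs with `get_hooks`, store them with `set_hooks`, then enter the context) can be sketched without installing the library. `SketchControl` is a stand-in that copies only the hook handling shown in the source listing, and `AddOffset` is a hypothetical subclass; the order in which the real pipeline calls these methods is an assumption.

```python
import torch
from torch import nn


class SketchControl:
    """Stand-in for the base class: hook handling only, no argument validation."""

    _model_ref = None

    def __init__(self):
        self.hooks = {"pre": [], "forward": [], "backward": []}
        self.registered = []

    def set_hooks(self, hooks):
        self.hooks = hooks

    def register_hooks(self, model):
        for phase in ("pre", "forward", "backward"):
            for spec in self.hooks[phase]:
                module = model.get_submodule(spec["module"])
                if phase == "pre":
                    handle = module.register_forward_pre_hook(spec["hook_func"], with_kwargs=True)
                elif phase == "forward":
                    handle = module.register_forward_hook(spec["hook_func"], with_kwargs=True)
                else:
                    handle = module.register_full_backward_hook(spec["hook_func"])
                self.registered.append(handle)

    def remove_hooks(self):
        for handle in self.registered:
            handle.remove()
        self.registered.clear()

    def __enter__(self):
        if self._model_ref is None:
            raise RuntimeError("Model reference not set before entering context.")
        self.register_hooks(self._model_ref)
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove_hooks()


class AddOffset(SketchControl):
    """Hypothetical control: adds a constant to one submodule's output."""

    def __init__(self, module_name, offset):
        super().__init__()
        self.module_name = module_name
        self.offset = offset

    def get_hooks(self, input_ids, runtime_kwargs=None, **kwargs):
        def hook(module, args, hook_kwargs, output):
            return output + self.offset

        return {
            "pre": [],
            "forward": [{"module": self.module_name, "hook_func": hook}],
            "backward": [],
        }


model = nn.Sequential(nn.Linear(3, 3), nn.ReLU())
control = AddOffset("0", offset=5.0)
control._model_ref = model                        # assumed to be set by the pipeline
control.set_hooks(control.get_hooks(input_ids=None))

x = torch.zeros(1, 3)
with torch.no_grad():
    baseline = model(x)
    with control:
        steered = model(x)                        # offset applied before the ReLU
        n_active = len(control.registered)
    restored = model(x)                           # hooks removed on exit
```

Only `get_hooks` is specific to the control; registration and removal are inherited, which matches the "required" versus "provided" split in the list.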

Source code in aisteer360/algorithms/state_control/base.py
class StateControl(ABC):
    """Abstract base class for state control steering methods.

    Modifies internal model states during forward passes via hooks.

    Methods:
        get_hooks(input_ids, runtime_kwargs, **kwargs) -> dict: Create hook specs (required)
        steer(model, tokenizer, **kwargs) -> None: One-time preparation (optional)
        reset() -> None: Reset logic (optional)
        register_hooks(model) -> None: Attach hooks to model (provided)
        remove_hooks() -> None: Remove all registered hooks (provided)
    """

    Args: Type[BaseArgs] | None = None

    enabled: bool = True
    _model_ref: PreTrainedModel | None = None

    def __init__(self, *args, **kwargs) -> None:
        if self.Args is None:  # null control
            if args or kwargs:
                raise TypeError(f"{type(self).__name__} accepts no constructor arguments.")
            return

        self.args: BaseArgs = self.Args.validate(*args, **kwargs)

        # move fields to attributes
        for field in fields(self.args):
            setattr(self, field.name, getattr(self.args, field.name))

        self.hooks: dict[str, list[HookSpec]] = {"pre": [], "forward": [], "backward": []}
        self.registered: list[torch.utils.hooks.RemovableHandle] = []

    @abstractmethod
    def get_hooks(
        self,
        input_ids: torch.Tensor,
        runtime_kwargs: dict | None,
        **kwargs,
    ) -> dict[str, list[HookSpec]]:
        """Create hook specifications for the current generation."""
        pass

    def steer(self,
              model: PreTrainedModel,
              tokenizer: PreTrainedTokenizerBase | None = None,
              **kwargs) -> None:
        """Optional steering/preparation."""
        pass

    def register_hooks(self, model: PreTrainedModel) -> None:
        """Attach hooks to model."""
        for phase in ("pre", "forward", "backward"):
            for spec in self.hooks[phase]:
                module = model.get_submodule(spec["module"])
                if phase == "pre":
                    handle = module.register_forward_pre_hook(spec["hook_func"], with_kwargs=True)
                elif phase == "forward":
                    handle = module.register_forward_hook(spec["hook_func"], with_kwargs=True)
                else:
                    handle = module.register_full_backward_hook(spec["hook_func"])
                self.registered.append(handle)

    def remove_hooks(self) -> None:
        """Remove all registered hooks from the model."""
        for handle in self.registered:
            handle.remove()
        self.registered.clear()

    def set_hooks(self, hooks: dict[str, list[HookSpec]]):
        """Update the hook specifications to be registered."""
        self.hooks = hooks

    def __enter__(self):
        """Context manager entry: register hooks to model.

        Raises:
            RuntimeError: If model reference not set by pipeline
        """
        if self._model_ref is None:
            raise RuntimeError("Model reference not set before entering context.")
        self.register_hooks(self._model_ref)

        return self

    def __exit__(self, exc_type, exc, tb):
        """Context manager exit: clean up all hooks."""
        self.remove_hooks()

    def reset(self):
        """Optional reset call for state control."""
        pass

args = self.Args.validate(*args, **kwargs) instance-attribute

enabled = True class-attribute instance-attribute

hooks = {'pre': [], 'forward': [], 'backward': []} instance-attribute

registered = [] instance-attribute

get_hooks(input_ids, runtime_kwargs, **kwargs) abstractmethod

Create hook specifications for the current generation.

Source code in aisteer360/algorithms/state_control/base.py
@abstractmethod
def get_hooks(
    self,
    input_ids: torch.Tensor,
    runtime_kwargs: dict | None,
    **kwargs,
) -> dict[str, list[HookSpec]]:
    """Create hook specifications for the current generation."""
    pass

register_hooks(model)

Attach hooks to model.

Source code in aisteer360/algorithms/state_control/base.py
def register_hooks(self, model: PreTrainedModel) -> None:
    """Attach hooks to model."""
    for phase in ("pre", "forward", "backward"):
        for spec in self.hooks[phase]:
            module = model.get_submodule(spec["module"])
            if phase == "pre":
                handle = module.register_forward_pre_hook(spec["hook_func"], with_kwargs=True)
            elif phase == "forward":
                handle = module.register_forward_hook(spec["hook_func"], with_kwargs=True)
            else:
                handle = module.register_full_backward_hook(spec["hook_func"])
            self.registered.append(handle)

remove_hooks()

Remove all registered hooks from the model.

Source code in aisteer360/algorithms/state_control/base.py
def remove_hooks(self) -> None:
    """Remove all registered hooks from the model."""
    for handle in self.registered:
        handle.remove()
    self.registered.clear()

reset()

Optional reset call for state control.

Source code in aisteer360/algorithms/state_control/base.py
def reset(self):
    """Optional reset call for state control."""
    pass

set_hooks(hooks)

Update the hook specifications to be registered.

Source code in aisteer360/algorithms/state_control/base.py
def set_hooks(self, hooks: dict[str, list[HookSpec]]):
    """Update the hook specifications to be registered."""
    self.hooks = hooks

steer(model, tokenizer=None, **kwargs)

Optional steering/preparation.

Source code in aisteer360/algorithms/state_control/base.py
def steer(self,
          model: PreTrainedModel,
          tokenizer: PreTrainedTokenizerBase | None = None,
          **kwargs) -> None:
    """Optional steering/preparation."""
    pass