Output control

aisteer360.algorithms.output_control.base

Output control base classes.

This module provides the abstract base classes for methods that intervene during text generation (e.g., by modifying logits, constraining the output space, or implementing alternative decoding strategies).

Two base classes are provided:

  • OutputControl: Base class for all output control methods.
  • NoOutputControl: Identity (null) control; used when no output control is defined in the steering pipeline.

Output controls implement steering through decoding algorithms and constraints: they modify the sampling process to produce generations y ~ᵈ p_θ(x), where ~ᵈ denotes the modified (steered) generation process.
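
To make "modifying the sampling process" concrete, here is a torch-free sketch of one such modification, token masking, applied inside a greedy decoding loop. The names `mask_logits`, `greedy_decode` and `toy_model` are hypothetical and not part of aisteer360; `toy_model` stands in for a language model and plain lists stand in for tensors.

```python
from __future__ import annotations

import math
from typing import Callable, Sequence


def mask_logits(logits: list[float], banned_ids: set[int]) -> list[float]:
    """Return a copy of logits with every banned token id set to -inf."""
    # -inf never wins an argmax and becomes probability 0 after a softmax.
    return [-math.inf if i in banned_ids else x for i, x in enumerate(logits)]


def greedy_decode(
    next_logits: Callable[[Sequence[int]], list[float]],
    prompt: list[int],
    max_new_tokens: int,
    banned_ids: set[int] | None = None,
) -> list[int]:
    """Greedy decoding; when banned_ids is given, masking is applied at every step."""
    ids = list(prompt)
    for _ in range(max_new_tokens):
        logits = next_logits(ids)
        if banned_ids:
            logits = mask_logits(logits, banned_ids)  # the steering intervention
        ids.append(max(range(len(logits)), key=logits.__getitem__))
    return ids


def toy_model(ids: Sequence[int]) -> list[float]:
    """Stand-in model: ignores context and always prefers token 2, then 1, then 0."""
    return [0.1, 0.5, 0.9]


print(greedy_decode(toy_model, [0], 3))                  # → [0, 2, 2, 2]
print(greedy_decode(toy_model, [0], 3, banned_ids={2}))  # → [0, 1, 1, 1]
```

The model's weights are untouched; only the choice made at each decoding step changes, which is what distinguishes an output control from other kinds of steering.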

Examples of output controls:

  • Constrained beam search
  • Reward-augmented decoding
  • Grammar-constrained generation
  • Token filtering and masking
  • Classifier-guided generation
  • Best-of-N sampling
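
Best-of-N sampling is the simplest of these to sketch: draw N candidate generations from the unmodified sampler and return the one a reward function scores highest. The `sample` and `reward` callables below are hypothetical placeholders for a model call and a scorer (for example, a reward model); this illustrates the technique and is not aisteer360's implementation.

```python
from __future__ import annotations

from typing import Callable


def best_of_n(sample: Callable[[], str], reward: Callable[[str], float], n: int) -> str:
    """Draw n candidates and return the highest-reward one (ties go to the earliest)."""
    if n < 1:
        raise ValueError("n must be at least 1")
    candidates = [sample() for _ in range(n)]
    return max(candidates, key=reward)


# Deterministic stand-in sampler that yields canned generations in order.
canned = iter(["ok", "a longer, more detailed answer", "medium answer"])
print(best_of_n(lambda: next(canned), reward=len, n=3))
# → a longer, more detailed answer
```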

See Also:

  • aisteer360.algorithms.output_control: Implementations of output control methods
  • aisteer360.core.steering_pipeline: Integration with steering pipeline

NoOutputControl

Bases: OutputControl

Identity output control.

Used as the default when no output control is needed. Calls the (unsteered) model's generate method.

Source code in aisteer360/algorithms/output_control/base.py
class NoOutputControl(OutputControl):
    """Identity output control.

    Used as the default when no output control is needed. Calls (unsteered) model's generate.
    """
    enabled: bool = False

    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        runtime_kwargs: dict | None,  # only for API compliance as runtime_kwargs are not used in HF models.
        model: PreTrainedModel,
        **gen_kwargs,
    ) -> torch.Tensor:
        """Null generate operation; applies model's generate."""
        return model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)

args = self.Args.validate(*args, **kwargs) instance-attribute

enabled = False class-attribute instance-attribute

generate(input_ids, attention_mask, runtime_kwargs, model, **gen_kwargs)

Null generate operation; applies model's generate.

Source code in aisteer360/algorithms/output_control/base.py
def generate(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    runtime_kwargs: dict | None,  # only for API compliance as runtime_kwargs are not used in HF models.
    model: PreTrainedModel,
    **gen_kwargs,
) -> torch.Tensor:
    """Null generate operation; applies model's generate."""
    return model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
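
The identity behaviour can be checked without loading a model. In this torch-free sketch, `RecordingModel` is a hypothetical stand-in for a `PreTrainedModel` that records what its `generate` receives, and `IdentityControl` mirrors the `NoOutputControl.generate` body above, with plain lists in place of tensors. It shows that `gen_kwargs` pass through unchanged while `runtime_kwargs` is dropped.

```python
from __future__ import annotations


class RecordingModel:
    """Hypothetical stand-in for a Hugging Face model; echoes the prompt and records kwargs."""

    def __init__(self) -> None:
        self.calls: list[dict] = []

    def generate(self, **kwargs):
        self.calls.append(kwargs)
        return kwargs["input_ids"]


class IdentityControl:
    """Mirrors NoOutputControl.generate: forwards everything except runtime_kwargs."""

    def generate(self, input_ids, attention_mask, runtime_kwargs, model, **gen_kwargs):
        return model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)


model = RecordingModel()
out = IdentityControl().generate([[1, 2]], [[1, 1]], {"unused": True}, model, max_new_tokens=5)
print(model.calls[0])
# → {'input_ids': [[1, 2]], 'attention_mask': [[1, 1]], 'max_new_tokens': 5}
```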

steer(model, tokenizer=None, **kwargs)

Optional steering/preparation.

Source code in aisteer360/algorithms/output_control/base.py
def steer(self,
          model: PreTrainedModel,
          tokenizer=None,
          **kwargs) -> None:
    """Optional steering/preparation."""
    pass

OutputControl

Bases: ABC

Abstract base class for output control steering methods.

Subclasses override the generation process with custom logic.

Methods:

  • generate: Custom generation (required)
  • steer: One-time preparation (optional)

Source code in aisteer360/algorithms/output_control/base.py
class OutputControl(ABC):
    """Abstract base class for output control steering methods.

    Overrides the generation process with custom logic.

    Methods:
        generate(input_ids, attention_mask, runtime_kwargs, model, **gen_kwargs) -> Tensor: Custom generation (required)
        steer(model, tokenizer, **kwargs) -> None: One-time preparation (optional)
    """

    Args: Type[BaseArgs] | None = None

    enabled: bool = True

    def __init__(self, *args, **kwargs) -> None:
        if self.Args is None:  # null control
            if args or kwargs:
                raise TypeError(f"{type(self).__name__} accepts no constructor arguments.")
            return

        self.args: BaseArgs = self.Args.validate(*args, **kwargs)

        # move fields to attributes
        for field in fields(self.args):
            setattr(self, field.name, getattr(self.args, field.name))

    @abstractmethod
    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        runtime_kwargs: dict | None,
        model: PreTrainedModel,
        **gen_kwargs,
    ) -> torch.Tensor:
        """Custom generation logic."""
        pass

    def steer(self,
              model: PreTrainedModel,
              tokenizer=None,
              **kwargs) -> None:
        """Optional steering/preparation."""
        pass
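
The constructor above follows a two-path pattern: a control with `Args = None` rejects all arguments, while any other control validates them through its `Args` class and copies every dataclass field onto the instance. The body of `BaseArgs.validate` is not shown on this page, so the stdlib sketch below assumes it simply builds the dataclass from the given arguments and lets `__post_init__` reject bad values. `ControlBase`, `TopKArgs` and `TopKControl` are hypothetical names used only for illustration.

```python
from __future__ import annotations

from dataclasses import dataclass, fields


@dataclass
class BaseArgs:
    @classmethod
    def validate(cls, *args, **kwargs):
        # Assumption: validation = dataclass construction + __post_init__ checks.
        return cls(*args, **kwargs)


@dataclass
class TopKArgs(BaseArgs):
    k: int = 50

    def __post_init__(self) -> None:
        if self.k < 1:
            raise ValueError("k must be >= 1")


class ControlBase:
    Args: type[BaseArgs] | None = None

    def __init__(self, *args, **kwargs) -> None:
        if self.Args is None:  # null control: no configuration allowed
            if args or kwargs:
                raise TypeError(f"{type(self).__name__} accepts no constructor arguments.")
            return
        self.args = self.Args.validate(*args, **kwargs)
        for field in fields(self.args):  # expose each validated field as an attribute
            setattr(self, field.name, getattr(self.args, field.name))


class TopKControl(ControlBase):
    Args = TopKArgs


print(TopKControl(k=10).k)  # → 10
```

Copying the fields onto the instance lets subclass code write `self.k` instead of `self.args.k`, while the validated dataclass remains available as `self.args`.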

args = self.Args.validate(*args, **kwargs) instance-attribute

enabled = True class-attribute instance-attribute

generate(input_ids, attention_mask, runtime_kwargs, model, **gen_kwargs) abstractmethod

Custom generation logic.

Source code in aisteer360/algorithms/output_control/base.py
@abstractmethod
def generate(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    runtime_kwargs: dict | None,
    model: PreTrainedModel,
    **gen_kwargs,
) -> torch.Tensor:
    """Custom generation logic."""
    pass
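
Because `generate` is declared with `@abstractmethod` on an `ABC`, Python refuses to instantiate any subclass that does not override it, while `steer` may be left as the inherited no-op. The sketch below reproduces that contract with the standard library only; `Control`, `Incomplete` and `EchoControl` are hypothetical, and plain lists stand in for `torch.Tensor`.

```python
from __future__ import annotations

from abc import ABC, abstractmethod


class Control(ABC):
    @abstractmethod
    def generate(self, input_ids, attention_mask, runtime_kwargs, model, **gen_kwargs):
        """Custom generation logic (must be overridden)."""

    def steer(self, model, tokenizer=None, **kwargs) -> None:
        """Optional one-time preparation (default: no-op)."""


class Incomplete(Control):
    pass  # does not implement generate, so it stays abstract


class EchoControl(Control):
    def generate(self, input_ids, attention_mask, runtime_kwargs, model, **gen_kwargs):
        # Trivial custom logic: return the prompt unchanged.
        return input_ids


try:
    Incomplete()
except TypeError as exc:
    print("rejected:", exc)

print(EchoControl().generate([[7, 8]], [[1, 1]], None, model=None))  # → [[7, 8]]
```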

steer(model, tokenizer=None, **kwargs)

Optional steering/preparation.

Source code in aisteer360/algorithms/output_control/base.py
def steer(self,
          model: PreTrainedModel,
          tokenizer=None,
          **kwargs) -> None:
    """Optional steering/preparation."""
    pass
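
`steer` runs before generation, which makes it a natural place to precompute anything that depends only on the model or tokenizer. As a hypothetical illustration (not an aisteer360 method), the sketch below resolves a list of banned words to token ids once in `steer` and reuses them on every `generate` call. A plain dict stands in for the tokenizer, a function returning fixed logits stands in for the model, and `input_ids` is simplified to a flat list of ints.

```python
from __future__ import annotations

import math


class BanWordsControl:
    """Hypothetical control: bans words, resolving them to ids once in steer()."""

    def __init__(self, banned_words: list[str]) -> None:
        self.banned_words = banned_words
        self.banned_ids: set[int] = set()

    def steer(self, model, tokenizer=None, **kwargs) -> None:
        # One-time preparation: the vocabulary lookup happens here, not per call.
        self.banned_ids = {tokenizer[w] for w in self.banned_words}

    def generate(self, input_ids, attention_mask, runtime_kwargs, model, max_new_tokens=1, **gen_kwargs):
        ids = list(input_ids)
        for _ in range(max_new_tokens):
            # Mask banned ids, then pick the highest-scoring remaining token.
            logits = [-math.inf if i in self.banned_ids else x for i, x in enumerate(model(ids))]
            ids.append(max(range(len(logits)), key=logits.__getitem__))
        return ids


vocab = {"yes": 0, "no": 1, "maybe": 2}  # tokenizer stand-in: word -> id


def model(ids):  # model stand-in: always prefers "maybe" (id 2)
    return [0.2, 0.1, 0.9]


control = BanWordsControl(["maybe"])
control.steer(model, tokenizer=vocab)
print(control.generate([0], None, None, model, max_new_tokens=2))  # → [0, 0, 0]
```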