
FewShot

aisteer360.algorithms.input_control.few_shot

args

control

Few-shot learning control for prompt adaptation.

FewShot

Bases: InputControl

Implementation of few-shot learning control for prompt adaptation.

FewShot enables selective behavioral steering by prepending specific examples to user prompts, guiding model responses through demonstration.

The method operates in two modes:

  1. Pool-based sampling: Maintains pools of positive and negative examples from which k_positive positive and k_negative negative examples are selected for each query by a pluggable selector (RandomSelector by default, which is the only selector listed in the selectors module below).

  2. Runtime injection: Accepts examples directly at inference time through runtime_kwargs, enabling context-specific demonstrations without predefined pools. Useful for dynamic or user-provided examples.
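A minimal sketch of pool-based sampling, mirroring `_sample_from_pools` combined with `RandomSelector.sample` from the source listings on this page. `sample_from_pools` is a hypothetical helper name, and one deliberate change is made: it draws from an explicit `random.Random` instance so runs are reproducible, whereas RandomSelector uses the module-level `random` functions.

```python
from __future__ import annotations

import random
from collections.abc import Sequence
from typing import Any


def sample_from_pools(
    positive_pool: Sequence[dict] | None,
    negative_pool: Sequence[dict] | None,
    k_positive: int | None,
    k_negative: int | None,
    rng: random.Random,
) -> list[dict[str, Any]]:
    """Draw up to k_positive / k_negative examples and tag each copy with its label."""
    selected: list[dict[str, Any]] = []
    # A pool is skipped if it is empty/None or its k is None or <= 0.
    if positive_pool and k_positive and k_positive > 0:
        for ex in rng.sample(list(positive_pool), min(k_positive, len(positive_pool))):
            selected.append({**ex, "_label": "positive"})  # copy, pool is not mutated
    if negative_pool and k_negative and k_negative > 0:
        for ex in rng.sample(list(negative_pool), min(k_negative, len(negative_pool))):
            selected.append({**ex, "_label": "negative"})
    return selected


positives = [{"text": f"polite reply {i}"} for i in range(5)]
negatives = [{"text": "rude reply"}]
picked = sample_from_pools(positives, negatives, k_positive=2, k_negative=3, rng=random.Random(0))
print([ex["_label"] for ex in picked])  # ['positive', 'positive', 'negative']
```

Note that k_negative=3 yields a single negative example because the pool holds only one, and that leaving k_positive or k_negative at None means that pool contributes nothing.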

The selected examples are formatted into a system prompt with explicit positive/negative labels. Using the tokenizer's chat template, this system prompt is placed as a system message ahead of the user query, so the model can pick up the desired behavior in context from the demonstrations.
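To make the prompt layout concrete, the following standalone sketch mirrors `_format_examples` and `_format_example_content` from the source listing below, with the three default templates copied verbatim. It does not import aisteer360; `format_examples` and `format_example_content` are hypothetical helper names.

```python
from typing import Any

# Default templates copied from the FewShot class attributes.
SYSTEM_PROMPT_TEMPLATE = "{directive}: \n{example_blocks}\n\n"
POSITIVE_EXAMPLE_TEMPLATE = "### Positive example (behavior to follow)\n{content}\n"
NEGATIVE_EXAMPLE_TEMPLATE = "### Negative example (behavior to avoid)\n{content}\n"


def format_example_content(example: dict[str, Any]) -> str:
    """Render each key/value pair as 'Title Case Key: value', skipping the internal label."""
    return "\n".join(
        f"{key.replace('_', ' ').title()}: {value}"
        for key, value in example.items()
        if key != "_label"
    )


def format_examples(directive: str, examples: list[dict[str, Any]]) -> str:
    """Build the system-prompt text from labeled examples (unlabeled ones count as positive)."""
    blocks = []
    for example in examples:
        template = (
            POSITIVE_EXAMPLE_TEMPLATE
            if example.get("_label", "positive") == "positive"
            else NEGATIVE_EXAMPLE_TEMPLATE
        )
        blocks.append(template.format(content=format_example_content(example)))
    return SYSTEM_PROMPT_TEMPLATE.format(directive=directive, example_blocks="\n".join(blocks))


examples = [
    {"question": "What is 2 + 2?", "answer": "4", "_label": "positive"},
    {"question": "What is 2 + 2?", "answer": "Probably 5?", "_label": "negative"},
]
print(format_examples("Answer concisely", examples))
```

Because the built-in template always emits `{directive}: ` first and the source substitutes an empty string for a missing directive, a FewShot without a directive produces a system prompt that starts with a bare `: `.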

Parameters:

  • directive (str, optional): Instruction text placed before the examples, explaining the task or desired behavior. Defaults to None.
  • positive_example_pool (Sequence[dict], optional): Pool of positive examples demonstrating the desired behavior. Each dict can hold several key-value pairs; each pair is rendered on its own line as "Key: value", with underscores in the key replaced by spaces and the key title-cased. Defaults to None.
  • negative_example_pool (Sequence[dict], optional): Pool of negative examples showing behavior to avoid, in the same format as positive_example_pool. Defaults to None.
  • k_positive (int, optional): Number of positive examples to draw from the pool per query (RandomSelector caps this at the pool size). If None or 0, no positive examples are drawn. Defaults to None.
  • k_negative (int, optional): Number of negative examples to draw from the pool per query (RandomSelector caps this at the pool size). If None or 0, no negative examples are drawn. Defaults to None.
  • selector_name (str, optional): Name of the selection strategy used for the pools, looked up in SELECTOR_REGISTRY (e.g. 'random'). Names not found in the registry fall back to RandomSelector. Defaults to 'random'.
  • template (str, optional): Custom template for the system prompt. It is filled with str.format, so it must contain the {directive} and {example_blocks} placeholders. Defaults to the built-in template "{directive}: \n{example_blocks}\n\n".
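Because `_format_examples` fills the template with `str.format(directive=..., example_blocks=...)`, a custom `template` follows ordinary `str.format` rules: literal braces must be doubled, and any other named field raises `KeyError`. The template text below is a hypothetical example, not one shipped with the library.

```python
# Hypothetical custom template; FewShot fills it with str.format(directive=..., example_blocks=...).
CUSTOM_TEMPLATE = (
    "Task: {directive}\n"
    "Follow the positive examples and avoid the negative ones.\n"
    "{example_blocks}\n"
    'Reply format: {{"answer": ...}}\n'  # doubled braces come out as literal { }
)

example_blocks = "### Positive example (behavior to follow)\nAnswer: 4\n"
prompt = CUSTOM_TEMPLATE.format(directive="Answer arithmetic questions", example_blocks=example_blocks)
print(prompt)
```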

Runtime keyword arguments:

  • positive_examples (list[dict], optional): Positive examples to use for this specific query (overrides pool-based selection).
  • negative_examples (list[dict], optional): Negative examples to use for this specific query (overrides pool-based selection).
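The precedence rule can be sketched as follows. `choose_examples` is a hypothetical helper that mirrors the mode check in the adapter and `_gather_runtime_examples`; `pool_examples` stands in for whatever the pool selector would have returned.

```python
from __future__ import annotations

from typing import Any


def choose_examples(runtime_kwargs: dict[str, Any] | None, pool_examples: list[dict]) -> list[dict]:
    """Runtime examples win over pool samples whenever either runtime key is present."""
    uses_runtime = bool(runtime_kwargs) and (
        "positive_examples" in runtime_kwargs or "negative_examples" in runtime_kwargs
    )
    if not uses_runtime:
        return pool_examples
    # Tag runtime examples the same way _gather_runtime_examples does.
    examples = [{**ex, "_label": "positive"} for ex in runtime_kwargs.get("positive_examples", [])]
    examples += [{**ex, "_label": "negative"} for ex in runtime_kwargs.get("negative_examples", [])]
    return examples


pool = [{"text": "from the pool", "_label": "positive"}]
print(choose_examples({"negative_examples": [{"text": "avoid this"}]}, pool))
# [{'text': 'avoid this', '_label': 'negative'}]
```

Precedence is decided by key presence, not by content: passing `positive_examples=[]` suppresses pool sampling for that call, and the adapter then warns that no examples remain and returns the input unchanged.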

Notes:

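  • Works best with a tokenizer that defines a chat_template; without one, the examples are prepended directly to the decoded user text and a UserWarning is issued
  • Examples are automatically labeled as "### Positive example (behavior to follow)" or "### Negative example (behavior to avoid)"
  • When both pools and runtime examples are available, runtime examples take precedence
  • If no examples are provided, or none remain after selection, the original input is returned unchanged and a UserWarning is issued
  • The adapter decodes input_ids with skip_special_tokens=True and treats the result as the user message, so input_ids should encode the raw user prompt rather than an already chat-templated one (decoding a templated prompt can leave role names that the template writes as ordinary text)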
Source code in aisteer360/algorithms/input_control/few_shot/control.py
class FewShot(InputControl):
    """
    Implementation of few-shot learning control for prompt adaptation.

    FewShot enables selective behavioral steering by prepending specific examples to user prompts, guiding model
    responses through demonstration.

    The method operates in two modes:

    1. **Pool-based sampling**: Maintains pools of positive and negative examples from which `k_positive` and
        `k_negative` examples are selected per query by a pluggable selector (`RandomSelector` by default).

    2. **Runtime injection**: Accepts examples directly at inference time through runtime_kwargs, enabling
        context-specific demonstrations without predefined pools. Useful for dynamic or user-provided examples.

    The selected examples are formatted into a system prompt with clear positive/negative labels and prepended to the
    user query using the model's chat template, allowing the model to learn the desired behavior pattern from the
    demonstrations.

    Args:
        directive (str, optional): Instruction text that precedes the examples, explaining the task or desired behavior.
            Defaults to None.
        positive_example_pool (Sequence[dict], optional): Pool of positive examples demonstrating desired behavior.
            Each dict can contain multiple key-value pairs. Defaults to None.
        negative_example_pool (Sequence[dict], optional): Pool of negative examples showing undesired behavior to avoid.
            Each dict can contain multiple key-value pairs. Defaults to None.
        k_positive (int, optional): Number of positive examples to sample from the pool per query.
            Defaults to None.
        k_negative (int, optional): Number of negative examples to sample from the pool per query.
            Defaults to None.
        selector_name (str, optional): Name of the selection strategy, looked up in `SELECTOR_REGISTRY` (e.g. 'random').
            Unknown names fall back to `RandomSelector`. Defaults to 'random'.
        template (str, optional): Custom template for formatting the system prompt. Should contain
            {directive} and {example_blocks} placeholders. Defaults to built-in template.

    Runtime keyword arguments:

    - `positive_examples` (`list[dict]`, `optional`): Positive examples to use for this specific query (overrides pool-based
    selection).
    - `negative_examples` (`list[dict]`, `optional`): Negative examples to use for this specific query (overrides pool-based
    selection).

    Notes:

    - Works best with a tokenizer that defines `chat_template`; otherwise examples are prepended directly to the
      user text and a `UserWarning` is issued
    - Examples are automatically labeled as "### Positive example" or "### Negative example"
    - When both pools and runtime examples are available, runtime examples take precedence
    - If no examples are provided, the original input is returned unchanged (with a `UserWarning`)
    """

    Args = FewShotArgs

    # default templates
    _SYSTEM_PROMPT_TEMPLATE = "{directive}: \n{example_blocks}\n\n"
    _POSITIVE_EXAMPLE_TEMPLATE = "### Positive example (behavior to follow)\n{content}\n"
    _NEGATIVE_EXAMPLE_TEMPLATE = "### Negative example (behavior to avoid)\n{content}\n"

    # placeholders
    tokenizer: PreTrainedTokenizer | None = None
    selector_name: str | None = None
    directive: str | None = None
    positive_example_pool: Sequence[dict] | None = None
    negative_example_pool: Sequence[dict] | None = None
    k_positive: int | None = None
    k_negative: int | None = None
    selector: Selector | None = None

    def steer(
            self,
            model=None,
            tokenizer: PreTrainedTokenizer | None = None,
            **kwargs
    ) -> None:
        self.tokenizer = tokenizer

        # initialize selector if using pool mode
        if self.positive_example_pool is not None or self.negative_example_pool is not None:
            if self.selector_name:
                selector_cls = SELECTOR_REGISTRY.get(self.selector_name, RandomSelector)
                self.selector = selector_cls()
            else:
                self.selector = RandomSelector()

    def get_prompt_adapter(self) -> Callable[[list[int] | torch.Tensor, dict[str, Any]], list[int] | torch.Tensor]:
        """Return a prompt adapter function that adds few-shot examples to the model's system prompt. Creates and
        returns a closure that modifies input token sequences by prepending few-shot examples.

        The returned adapter function performs the following steps:

        1. Determines operational mode (runtime examples take precedence over pools)
        2. Decodes input tokens to retrieve the original user message
        3. Selects or retrieves appropriate examples based on mode
        4. Formats examples with positive/negative labels
        5. Constructs a system prompt containing the examples
        6. Applies the model's chat template (if available) to combine system prompt and user message
        7. Re-encodes the adapted text to tokens

        Returns:
            A prompt adapter function.

        Raises:
            RuntimeError: If tokenizer is not set (requires calling `steer()` first)

        Warnings:
            UserWarning: Issued when:

                - No examples available from either pools or runtime_kwargs
                - No examples remain after selection/sampling
                - Tokenizer lacks chat_template support (falls back to direct prepending)
        """

        if self.tokenizer is None:
            raise RuntimeError("FewShot needs a tokenizer; call .steer() first.")

        def adapter(input_ids: list[int] | torch.Tensor, runtime_kwargs: dict[str, Any]) -> list[int] | torch.Tensor:

            # infer mode from arguments
            using_runtime_examples = (runtime_kwargs and ("positive_examples" in runtime_kwargs or
                                                          "negative_examples" in runtime_kwargs))
            using_pool_mode = self.positive_example_pool is not None or self.negative_example_pool is not None

            if not using_runtime_examples and not using_pool_mode:
                warnings.warn(
                    "FewShot: No examples provided via runtime_kwargs or example pools. "
                    "Returning original input unchanged.",
                    UserWarning
                )
                return input_ids

            # decode to retrieve user message
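            if isinstance(input_ids, torch.Tensor):
                # accept both (1, seq_len) and (seq_len,) tensors; only the first row of a batch is used
                input_ids_list = input_ids.tolist()[0] if input_ids.dim() > 1 else input_ids.tolist()
            else:
                input_ids_list = input_ids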

            original_text = self.tokenizer.decode(input_ids_list, skip_special_tokens=True)

            # get examples based on mode
            if using_runtime_examples:
                examples = self._gather_runtime_examples(runtime_kwargs)
            else:
                examples = self._sample_from_pools()

            if not examples:
                warnings.warn(
                    "FewShot: No examples available after selection. Returning original input unchanged.",
                    UserWarning
                )
                return input_ids

            examples_text = self._format_examples(examples)

            # apply chat template
            if hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template:
                messages = [
                    {"role": "system", "content": examples_text},
                    {"role": "user", "content": original_text}
                ]
                adapted_text = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            else:
                warnings.warn(
                    "No chat template found for tokenizer. Prepending few-shot examples directly to user query.",
                    UserWarning
                )
                adapted_text = examples_text + original_text

            # encode the adapted text
            adapted_tokens = self.tokenizer.encode(
                adapted_text,
                add_special_tokens=False,
                return_tensors="pt" if isinstance(input_ids, torch.Tensor) else None
            )

            if isinstance(input_ids, torch.Tensor):
                return adapted_tokens.squeeze(0) if adapted_tokens.dim() > 1 else adapted_tokens
            else:
                return adapted_tokens

        return adapter

    def _sample_from_pools(self) -> list[dict[str, Any]]:
        """Sample examples from the pools."""
        all_examples = []

        if self.positive_example_pool and self.k_positive and self.k_positive > 0:
            positive_samples = self.selector.sample(
                self.positive_example_pool,
                self.k_positive
            )
            for example in positive_samples:
                all_examples.append({**example, "_label": "positive"})

        if self.negative_example_pool and self.k_negative and self.k_negative > 0:
            negative_samples = self.selector.sample(
                self.negative_example_pool,
                self.k_negative
            )
            for example in negative_samples:
                all_examples.append({**example, "_label": "negative"})

        return all_examples

    def _format_examples(self, examples: list[dict[str, Any]]) -> str:
        """Format examples for system prompt."""
        if not examples:
            return ""

        example_blocks = []
        for example in examples:
            is_positive = example.get("_label", "positive") == "positive"
            content = self._format_example_content(example)

            if is_positive:
                example_blocks.append(self._POSITIVE_EXAMPLE_TEMPLATE.format(content=content))
            else:
                example_blocks.append(self._NEGATIVE_EXAMPLE_TEMPLATE.format(content=content))

        template = getattr(self, 'template', None) or self._SYSTEM_PROMPT_TEMPLATE
        formatted_blocks = "\n".join(example_blocks)

        return template.format(directive=self.directive or "", example_blocks=formatted_blocks)

    @staticmethod
    def _gather_runtime_examples(runtime_kwargs: dict[str, Any]) -> list[dict[str, Any]]:
        """Gather examples from runtime_kwargs."""
        examples = []
        if "positive_examples" in runtime_kwargs:
            for example in runtime_kwargs["positive_examples"]:
                examples.append({**example, "_label": "positive"})
        if "negative_examples" in runtime_kwargs:
            for example in runtime_kwargs["negative_examples"]:
                examples.append({**example, "_label": "negative"})
        return examples

    @staticmethod
    def _format_example_content(example: dict[str, Any]) -> str:
        segments = []
        for key, value in example.items():
            if key == "_label":
                continue
            formatted_key = key.replace("_", " ").title()
            segments.append(f"{formatted_key}: {value}")

        return "\n".join(segments)

Attributes:

  • args = self.Args.validate(*args, **kwargs) (instance attribute)
  • directive = None (class attribute, instance attribute)
  • enabled = True (class attribute, instance attribute)
  • k_negative = None (class attribute, instance attribute)
  • k_positive = None (class attribute, instance attribute)
  • negative_example_pool = None (class attribute, instance attribute)
  • positive_example_pool = None (class attribute, instance attribute)
  • selector = None (class attribute, instance attribute)
  • selector_name = None (class attribute, instance attribute)
  • tokenizer = None (class attribute, instance attribute)
get_prompt_adapter()

Return a prompt adapter: a closure that rewrites an input token sequence so that the selected few-shot examples appear in a system prompt ahead of the original user message.

The returned adapter function performs the following steps:

  1. Determines operational mode (runtime examples take precedence over pools)
  2. Decodes input tokens to retrieve the original user message
  3. Selects or retrieves appropriate examples based on mode
  4. Formats examples with positive/negative labels
  5. Constructs a system prompt containing the examples
  6. Applies the model's chat template (if available) to combine system prompt and user message
  7. Re-encodes the adapted text to tokens
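The steps above can be traced end to end with a toy stand-in. This sketch is not the aisteer360 code: `ToyTokenizer` is a hypothetical whitespace tokenizer whose `apply_chat_template(messages)` only imitates the role layout (the real Hugging Face method takes further arguments such as `tokenize` and `add_generation_prompt`), and steps 1 and 3-5 are collapsed into a precomputed `examples_text` argument so the focus stays on decoding, templating, the no-template fallback, and re-encoding.

```python
from __future__ import annotations

import warnings


class ToyTokenizer:
    """Hypothetical whitespace tokenizer standing in for a Hugging Face tokenizer."""

    def __init__(self, chat_template: str | None = None) -> None:
        self.chat_template = chat_template
        self.vocab: dict[str, int] = {}

    def encode(self, text: str) -> list[int]:
        # Assign a new id to every unseen whitespace-separated word.
        return [self.vocab.setdefault(word, len(self.vocab)) for word in text.split()]

    def decode(self, ids: list[int]) -> str:
        inverse = {i: w for w, i in self.vocab.items()}
        return " ".join(inverse[i] for i in ids)

    def apply_chat_template(self, messages: list[dict[str, str]]) -> str:
        # Toy rendering only; real chat templates are model-specific.
        return " ".join(f"<{m['role']}> {m['content']}" for m in messages) + " <assistant>"


def adapt(tokenizer: ToyTokenizer, input_ids: list[int], examples_text: str) -> list[int]:
    original_text = tokenizer.decode(input_ids)  # step 2: recover the user message
    if tokenizer.chat_template:  # step 6: system message + user message
        messages = [
            {"role": "system", "content": examples_text},
            {"role": "user", "content": original_text},
        ]
        adapted_text = tokenizer.apply_chat_template(messages)
    else:  # fallback: prepend the examples directly, with a warning
        warnings.warn("No chat template found; prepending examples.", UserWarning)
        adapted_text = examples_text + original_text
    return tokenizer.encode(adapted_text)  # step 7: re-encode


tok = ToyTokenizer(chat_template="toy")
out = adapt(tok, tok.encode("hello there"), "Be kind: Positive example: hi")
print(tok.decode(out))
# <system> Be kind: Positive example: hi <user> hello there <assistant>
```

The real adapter re-encodes with `add_special_tokens=False`, presumably because the chat-templated text already contains the model's special tokens and adding them again would duplicate them.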

Returns:

  • Callable[[list[int] | Tensor, dict[str, Any]], list[int] | Tensor]: A prompt adapter adapter(input_ids, runtime_kwargs) that returns the adapted token IDs in the same container type as input_ids (a tensor input yields a 1-D tensor).

Raises:

  • RuntimeError: If the tokenizer is not set (call steer() with a tokenizer first).

Warns:

  • UserWarning: Issued when no examples are available from either the pools or runtime_kwargs, when no examples remain after selection/sampling, or when the tokenizer lacks chat_template support (the examples are then prepended directly to the user text).
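Since every fallback path is reported through `warnings.warn(..., UserWarning)`, a caller can escalate the "returned unchanged" cases to exceptions, for example in tests. The sketch uses a hypothetical `fake_adapter` that emits the same message text as the source; the filter matches the `FewShot: No examples` prefix shared by the first two warnings, while the chat-template warning has no `FewShot:` prefix and would need its own filter.

```python
import warnings


def fake_adapter(input_ids: list[int]) -> list[int]:
    """Hypothetical stand-in that emits FewShot's 'no examples' warning text verbatim."""
    warnings.warn(
        "FewShot: No examples available after selection. Returning original input unchanged.",
        UserWarning,
    )
    return input_ids


def run_strict(input_ids: list[int]) -> list[int]:
    """Turn FewShot's 'returned unchanged' warnings into exceptions for this call only."""
    with warnings.catch_warnings():
        # The message argument is a regex matched against the start of the warning text.
        warnings.filterwarnings("error", message=r"FewShot: No examples", category=UserWarning)
        return fake_adapter(input_ids)


try:
    run_strict([1, 2, 3])
except UserWarning as exc:
    print(f"adapter fell back: {exc}")
```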
steer(model=None, tokenizer=None, **kwargs)

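Stores the tokenizer and, if either example pool is set, instantiates the selector named by selector_name (looked up in SELECTOR_REGISTRY, falling back to RandomSelector). The model argument is unused. Must be called with a tokenizer before get_prompt_adapter(), which otherwise raises RuntimeError.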
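The selector lookup in `steer()` can be sketched with a local registry. Both `RandomSelector` and `SELECTOR_REGISTRY` below are stand-ins (the real registry's contents are not shown on this page), and `build_selector` is a hypothetical helper mirroring the branch in `steer()`.

```python
from __future__ import annotations

import random
from collections.abc import Sequence


class RandomSelector:
    """Stand-in with the same sample() body as the documented RandomSelector."""

    def sample(self, pool: Sequence[dict], k: int, **_) -> list[dict]:
        return random.sample(list(pool), min(k, len(pool)))


# Hypothetical registry contents for illustration only.
SELECTOR_REGISTRY: dict[str, type] = {"random": RandomSelector}


def build_selector(selector_name: str | None, has_pool: bool):
    """Mirror FewShot.steer(): only build a selector when a pool is configured."""
    if not has_pool:
        return None
    if selector_name:
        # Unknown names fall back to RandomSelector instead of raising.
        return SELECTOR_REGISTRY.get(selector_name, RandomSelector)()
    return RandomSelector()


print(type(build_selector("semantic", has_pool=True)).__name__)  # RandomSelector
```

Because the lookup passes RandomSelector as the default to `.get`, a misspelled or unregistered `selector_name` does not raise; it silently yields random selection.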


selectors

Example selectors for few-shot learning prompt adaptation.

This module provides different strategies for selecting examples from pools during few-shot prompting. Selectors determine which examples are passed as demonstrations to the model.

Available selectors:

  • RandomSelector: Randomly samples examples from the pool

base

Base interface for few-shot example selection strategies.

Selector

Bases: ABC

Base class for example selector.

Source code in aisteer360/algorithms/input_control/few_shot/selectors/base.py
class Selector(ABC):
    """
    Base class for example selector.
    """

    @abstractmethod
    def sample(
        self,
        pool: Sequence[dict],
        k: int,
        **kwargs: Any
    ) -> list[dict]:
        """Return k items chosen from pool."""
        raise NotImplementedError
sample(pool, k, **kwargs) abstractmethod

Return k items chosen from pool (implementations such as RandomSelector return fewer when the pool holds fewer than k items).
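A custom strategy only has to implement `sample(pool, k, **kwargs)`. The sketch reproduces the abstract `Selector` locally so it runs on its own, and adds a hypothetical `SeededRandomSelector` that samples from a private seeded RNG. Two details of the FewShot source shape the design: `steer()` instantiates selectors with no arguments (`selector_cls()`), so the constructor needs defaults, and `_sample_from_pools` calls `sample(pool, k)` without the user query, so a query-aware selector would not receive the query as the code stands. How to add a class to SELECTOR_REGISTRY is not documented on this page.

```python
from __future__ import annotations

import random
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any


class Selector(ABC):
    """Local copy of the abstract interface documented above."""

    @abstractmethod
    def sample(self, pool: Sequence[dict], k: int, **kwargs: Any) -> list[dict]:
        """Return k items chosen from pool."""
        raise NotImplementedError


class SeededRandomSelector(Selector):
    """Hypothetical selector: uniform sampling from a private, seeded RNG."""

    def __init__(self, seed: int = 0) -> None:
        # A default seed keeps the no-argument construction used by FewShot.steer() working.
        self._rng = random.Random(seed)

    def sample(self, pool: Sequence[dict], k: int, **_: Any) -> list[dict]:
        # Cap k at the pool size, as RandomSelector does.
        return self._rng.sample(list(pool), min(k, len(pool)))


pool = [{"id": i} for i in range(10)]
print(SeededRandomSelector(seed=42).sample(pool, 3) == SeededRandomSelector(seed=42).sample(pool, 3))  # True
```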


random_selector

RandomSelector

Bases: Selector

Selects examples uniformly at random from a pool for few-shot prompting.

Source code in aisteer360/algorithms/input_control/few_shot/selectors/random_selector.py
class RandomSelector(Selector):
    """Selects examples uniformly at random from a pool for few-shot prompting."""

    def sample(self, pool: Sequence[dict], k: int, **_) -> list[dict]:
        """Select k examples uniformly at random from the pool.

        Args:
            pool: Available examples to select from
            k: Number of examples to select
            **_: Ignored (for compatibility with other selectors)

        Returns:
            List of randomly selected examples (up to min(k, len(pool)))
        """
        return random.sample(pool, min(k, len(pool)))
sample(pool, k, **_)

Select k examples uniformly at random from the pool.

Parameters:

  • pool (Sequence[dict]): Available examples to select from. Required.
  • k (int): Number of examples to select. Required.
  • **_: Ignored (for compatibility with other selectors). Default: {}.

Returns:

  • list[dict]: List of randomly selected examples (up to min(k, len(pool))).
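`RandomSelector.sample` draws without replacement through the module-level `random.sample`, so seeding the global generator with `random.seed` makes its picks reproducible. The sketch copies its one-line body into a hypothetical `random_select` function to show the seeding and the `min(k, len(pool))` cap.

```python
import random
from collections.abc import Sequence


def random_select(pool: Sequence[dict], k: int) -> list[dict]:
    """Same body as RandomSelector.sample: uniform, without replacement, capped at len(pool)."""
    return random.sample(pool, min(k, len(pool)))


pool = [{"id": i} for i in range(4)]

random.seed(1234)
first = random_select(pool, 2)
random.seed(1234)
second = random_select(pool, 2)
print(first == second)  # True
```

A negative `k` would make `random.sample` raise `ValueError`; FewShot only calls the selector when `k_positive` or `k_negative` is greater than 0. Seeding the global generator also affects any other code that uses the `random` module.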
