
TRL

aisteer360.algorithms.structural_control.wrappers.trl

The TRL wrapper implements a variety of methods from Hugging Face's TRL library.

The wrapper currently supports the following methods:

  • SFT (Supervised Fine-Tuning): Standard supervised learning to fine-tune language models on demonstration data
  • DPO (Direct Preference Optimization): Trains models directly on preference data without requiring a separate reward model
  • APO (Anchored Preference Optimization): A variant of DPO that uses an anchor model to improve training stability and performance
  • SPPO (Self-Play Preference Optimization): Iterative preference optimization using self-generated synthetic data to reduce dependency on external preference datasets

For further details, refer to the TRL documentation and the SPPO repository.
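
As a point of reference for what the wrapper orchestrates internally, the sketch below shows a bare supervised fine-tuning run driven directly through TRL's `SFTTrainer`. This is illustrative only, not the wrapper's own API: the model id, the toy dataset, and the output directory are placeholders, and the exact `SFTConfig` keyword names can differ across TRL versions.

```python
# Minimal TRL SFT sketch (illustrative; not the wrapper's own API).
from datasets import Dataset
from trl import SFTConfig, SFTTrainer

# Toy dataset with a "text" column, the default format SFTTrainer expects.
train_dataset = Dataset.from_dict({
    "text": [
        "### Question: What is 2 + 2?\n### Answer: 4",
        "### Question: Name a primary color.\n### Answer: Blue",
    ]
})

trainer = SFTTrainer(
    model="Qwen/Qwen2-0.5B",  # placeholder model id
    train_dataset=train_dataset,
    args=SFTConfig(output_dir="./sft-out", num_train_epochs=1),
)
trainer.train()
trainer.save_model("./sft-out")
```

The structural controls built on this wrapper add the surrounding plumbing (model/tokenizer resolution, optional PEFT/LoRA configuration, artifact saving, and post-training freezing) via the `TRLMixin` helpers documented below.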

args

base_mixin

TRLMixin

Small shared helpers for TRL-based structural controls.

Source code in aisteer360/algorithms/structural_control/wrappers/trl/base_mixin.py
import inspect
from dataclasses import fields, is_dataclass
from typing import Any

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)


class TRLMixin:
    """
    Small shared helpers for TRL-based structural controls.
    """

    # populated from Args by subclasses
    base_model_name_or_path: str | None = None
    tokenizer_name_or_path: str | None = None
    hf_model_kwargs: dict[str, Any] = {}

    training_args: dict[str, Any] = {}
    output_dir: str | None = None
    resume_from_checkpoint: str | None = None

    use_peft: bool = False
    peft_type: Any = None
    lora_kwargs: dict[str, Any] = {}
    adapter_name: str | None = None

    merge_lora_after_train: bool = False
    merged_output_dir: str | None = None

    # resolved at runtime
    model: PreTrainedModel | None = None
    tokenizer: PreTrainedTokenizer | None = None
    device = None

    def _resolve_model_tokenizer(
        self,
        model: PreTrainedModel | None,
        tokenizer: PreTrainedTokenizer | None,
    ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        if model is None:
            if not self.base_model_name_or_path:
                raise ValueError("TRLMixin: model is None and `base_model_name_or_path` was not provided.")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.base_model_name_or_path,
                trust_remote_code=True,
                **(self.hf_model_kwargs or {}),
            )
        else:
            self.model = model

        if tokenizer is None:
            path = (
                self.tokenizer_name_or_path
                or getattr(self.model, "name_or_path", None)
                or self.base_model_name_or_path
            )
            if not path:
                raise ValueError("TRLMixin: could not resolve tokenizer path.")
            self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        else:
            self.tokenizer = tokenizer

        self.device = next(self.model.parameters()).device
        return self.model, self.tokenizer

    @staticmethod
    def _filter_kwargs_for_class_or_callable(target: Any, kwargs: dict[str, Any]) -> dict[str, Any]:
        """Keep only kwargs accepted by a dataclass or callable."""
        if is_dataclass(target):
            allowed = {f.name for f in fields(target)}
        else:
            try:
                allowed = set(inspect.signature(target).parameters.keys())
            except (TypeError, ValueError):
                allowed = set(kwargs.keys())
        return {k: v for k, v in kwargs.items() if k in allowed and v is not None}

    def _post_train_freeze(self) -> PreTrainedModel:
        self.model.eval()
        for parameter in self.model.parameters():
            parameter.requires_grad_(False)
        return self.model

    def _maybe_save_trained_artifacts(self, trainer) -> None:
        output_dir = self.training_args.get("output_dir") or self.output_dir
        if output_dir:
            trainer.save_model(output_dir)
            try:
                self.tokenizer.save_pretrained(output_dir)
            except Exception:
                pass

    def _maybe_merge_lora_in_place(self) -> None:
        """Optionally merge LoRA into the base weights."""
        if not (self.use_peft and self.merge_lora_after_train):
            return

        # trainer often returns a PEFT-wrapped model; merge if possible
        if hasattr(self.model, "merge_and_unload"):
            merged_model = self.model.merge_and_unload()
            self.model = merged_model
            self.device = next(self.model.parameters()).device

            # save if requested
            if self.merged_output_dir:
                self.model.save_pretrained(self.merged_output_dir)
                try:
                    self.tokenizer.save_pretrained(self.merged_output_dir)
                except Exception:
                    pass
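
The `_filter_kwargs_for_class_or_callable` helper above keeps only the keyword arguments that the target dataclass or callable actually accepts and drops any `None` values. A minimal sketch of that behavior, assuming `aisteer360` and `transformers` are installed and using `transformers.TrainingArguments` (a dataclass) as the target:

```python
# Sketch: filtering a user-supplied training_args dict down to what a
# dataclass target actually accepts (assumes aisteer360 is installed).
from transformers import TrainingArguments

from aisteer360.algorithms.structural_control.wrappers.trl.base_mixin import TRLMixin

raw_kwargs = {
    "output_dir": "./runs/sft",
    "num_train_epochs": 1,
    "learning_rate": None,       # dropped: None values are filtered out
    "not_a_real_field": 123,     # dropped: not a TrainingArguments field
}

filtered = TRLMixin._filter_kwargs_for_class_or_callable(TrainingArguments, raw_kwargs)
print(filtered)  # {'output_dir': './runs/sft', 'num_train_epochs': 1}

training_args = TrainingArguments(**filtered)
```

This lets user-supplied `training_args` dictionaries be forwarded safely to configs and trainers whose signatures differ across TRL versions.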
Class attributes (each is both a class attribute and an instance attribute):

  • adapter_name = None
  • base_model_name_or_path = None
  • device = None
  • hf_model_kwargs = {}
  • lora_kwargs = {}
  • merge_lora_after_train = False
  • merged_output_dir = None
  • model = None
  • output_dir = None
  • peft_type = None
  • resume_from_checkpoint = None
  • tokenizer = None
  • tokenizer_name_or_path = None
  • training_args = {}
  • use_peft = False
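
The helpers are intended to be called in a fixed order by the concrete TRL controls (SFT, DPO, APO, SPPO). The hypothetical subclass below sketches that order under stated assumptions: `MyTRLControl`, `run_training`, and the `build_trainer` callable are placeholders for illustration and are not part of the wrapper's API.

```python
# Hypothetical subclass sketch showing the intended call order of the
# TRLMixin helpers; `build_trainer` stands in for whichever TRL trainer
# (SFTTrainer, DPOTrainer, ...) a concrete control constructs.
from typing import Any, Callable

from transformers import PreTrainedModel, PreTrainedTokenizer

from aisteer360.algorithms.structural_control.wrappers.trl.base_mixin import TRLMixin


class MyTRLControl(TRLMixin):
    def run_training(
        self,
        build_trainer: Callable[..., Any],
        model: PreTrainedModel | None = None,
        tokenizer: PreTrainedTokenizer | None = None,
    ) -> PreTrainedModel:
        # 1. Load (or adopt) the model and tokenizer.
        model, tokenizer = self._resolve_model_tokenizer(model, tokenizer)

        # 2. Construct the trainer from only the kwargs it accepts.
        trainer_kwargs = self._filter_kwargs_for_class_or_callable(
            build_trainer, self.training_args
        )
        trainer = build_trainer(model=model, **trainer_kwargs)

        # 3. Train, optionally resuming from a checkpoint.
        trainer.train(resume_from_checkpoint=self.resume_from_checkpoint)

        # 4. Persist artifacts, merge LoRA if requested, then freeze.
        self._maybe_save_trained_artifacts(trainer)
        self._maybe_merge_lora_in_place()
        return self._post_train_freeze()
```

The freeze step at the end mirrors how steered models are used downstream: after training, the model is placed in eval mode with gradients disabled so it can be served for inference.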