| # Copyright 2020-2025 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import inspect | |
| import os | |
| import random | |
| import textwrap | |
| import warnings | |
| from collections import defaultdict | |
| from contextlib import contextmanager, nullcontext | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, Callable, Literal, Optional, Union | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from accelerate import PartialState | |
| from accelerate.utils import tqdm | |
| from datasets import Dataset, IterableDataset | |
| from torch import autocast | |
| from torch.utils.data import DataLoader | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| BaseImageProcessor, | |
| DataCollator, | |
| FeatureExtractionMixin, | |
| PreTrainedModel, | |
| PreTrainedTokenizerBase, | |
| ProcessorMixin, | |
| Trainer, | |
| is_comet_available, | |
| is_wandb_available, | |
| ) | |
| from transformers.data.data_collator import DataCollatorMixin | |
| from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES | |
| from transformers.trainer_callback import TrainerCallback | |
| from transformers.trainer_utils import EvalLoopOutput | |
| from transformers.utils import is_liger_kernel_available, is_peft_available | |
| from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt | |
| from ..models import create_reference_model, prepare_deepspeed | |
| from ..models.utils import prepare_fsdp | |
| from .callbacks import SyncRefModelCallback | |
| from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType | |
| from .utils import ( | |
| RunningMoments, | |
| cap_exp, | |
| disable_dropout_in_model, | |
| empty_cache, | |
| flush_left, | |
| flush_right, | |
| generate_model_card, | |
| get_comet_experiment_url, | |
| log_table_to_comet_experiment, | |
| pad, | |
| pad_to_length, | |
| peft_module_casting_to_bf16, | |
| selective_log_softmax, | |
| ) | |
| if is_peft_available(): | |
| from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training | |
| if is_liger_kernel_available(): | |
| from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss | |
| if is_wandb_available(): | |
| import wandb | |
| def shift_tokens_right(input_ids: torch.Tensor, decoder_start_token_id: int) -> torch.Tensor: | |
| """Shift input ids one token to the right, prepending `decoder_start_token_id`.""" | |
| shifted_input_ids = input_ids.new_zeros(input_ids.shape) | |
| shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() | |
| shifted_input_ids[:, 0] = decoder_start_token_id | |
| return shifted_input_ids | |
| @dataclass | |
| class DataCollatorForPreference(DataCollatorMixin): | |
| """ | |
| Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch if they | |
| are not all of the same length. | |
| Args: | |
| pad_token_id (`int`): | |
| Token ID to use for padding. | |
| return_tensors (`str`, *optional*, defaults to `"pt"`): | |
| Type of Tensor to return. Only `"pt"` is currently supported. | |
| Examples: | |
| ```python | |
| >>> from trl import DataCollatorForPreference | |
| >>> collator = DataCollatorForPreference(pad_token_id=0) | |
| >>> examples = [ | |
| ... {"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6]}, | |
| ... {"prompt_input_ids": [7, 8], "chosen_input_ids": [9, 10], "rejected_input_ids": [11, 12, 13]} | |
| ... ] | |
| >>> collator(examples) | |
| {'prompt_input_ids': tensor([[1, 2, 3], | |
| [0, 7, 8]]), | |
| 'prompt_attention_mask': tensor([[1, 1, 1], | |
| [0, 1, 1]]), | |
| 'chosen_input_ids': tensor([[ 4, 5], | |
| [ 9, 10]]), | |
| 'chosen_attention_mask': tensor([[1, 1], | |
| [1, 1]]), | |
| 'rejected_input_ids': tensor([[ 6, 0, 0], | |
| [11, 12, 13]]), | |
| 'rejected_attention_mask': tensor([[1, 0, 0], | |
| [1, 1, 1]]) | |
| } | |
| ``` | |
| """ | |
| pad_token_id: int | |
| return_tensors: str = "pt" | |
| def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: | |
| # Convert to tensor | |
| prompt_input_ids = [torch.tensor(example["prompt_input_ids"]) for example in examples] | |
| prompt_attention_mask = [torch.ones_like(input_ids) for input_ids in prompt_input_ids] | |
| chosen_input_ids = [torch.tensor(example["chosen_input_ids"]) for example in examples] | |
| chosen_attention_mask = [torch.ones_like(input_ids) for input_ids in chosen_input_ids] | |
| rejected_input_ids = [torch.tensor(example["rejected_input_ids"]) for example in examples] | |
| rejected_attention_mask = [torch.ones_like(input_ids) for input_ids in rejected_input_ids] | |
| if "pixel_values" in examples[0]: | |
| pixel_values = [torch.tensor(example["pixel_values"]) for example in examples] | |
| if "pixel_attention_mask" in examples[0]: | |
| pixel_attention_mask = [torch.tensor(example["pixel_attention_mask"]) for example in examples] | |
| if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]: | |
| ref_chosen_logps = torch.tensor([example["ref_chosen_logps"] for example in examples]) | |
| ref_rejected_logps = torch.tensor([example["ref_rejected_logps"] for example in examples]) | |
| # Pad | |
| output = {} | |
| output["prompt_input_ids"] = pad(prompt_input_ids, padding_value=self.pad_token_id, padding_side="left") | |
| output["prompt_attention_mask"] = pad(prompt_attention_mask, padding_value=0, padding_side="left") | |
| output["chosen_input_ids"] = pad(chosen_input_ids, padding_value=self.pad_token_id) | |
| output["chosen_attention_mask"] = pad(chosen_attention_mask, padding_value=0) | |
| output["rejected_input_ids"] = pad(rejected_input_ids, padding_value=self.pad_token_id) | |
| output["rejected_attention_mask"] = pad(rejected_attention_mask, padding_value=0) | |
| if "pixel_values" in examples[0]: | |
| output["pixel_values"] = pad(pixel_values, padding_value=0.0) | |
| if "pixel_attention_mask" in examples[0]: | |
| output["pixel_attention_mask"] = pad(pixel_attention_mask, padding_value=0) | |
| if "image_sizes" in examples[0]: | |
| output["image_sizes"] = torch.tensor([example["image_sizes"] for example in examples]) | |
| if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]: | |
| output["ref_chosen_logps"] = ref_chosen_logps | |
| output["ref_rejected_logps"] = ref_rejected_logps | |
| return output | |
| class DPOTrainer(Trainer): | |
| """ | |
| Trainer for Direct Preference Optimization (DPO) method. | |
| This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods. | |
| Args: | |
| model (`Union[str, PreTrainedModel]`): | |
| Model to be trained. Can be either: | |
| - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or | |
| a path to a *directory* containing model weights saved using | |
| [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is | |
| loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments | |
| in `args.model_init_kwargs`. | |
| - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. | |
| ref_model (`PreTrainedModelWrapper`): | |
| Hugging Face transformer model with a causal language modeling head. Used for implicit reward computation and loss. If no | |
| reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized. | |
| args ([`DPOConfig`], *optional*, defaults to `None`): | |
| Configuration for this trainer. If `None`, a default configuration is used. | |
| data_collator (`DataCollator`, *optional*): | |
| Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. | |
| Will default to [`DataCollatorForPreference`]. | |
| train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): | |
| Dataset to use for training. DPO supports [preference](#preference) type datasets. The format of the samples can | |
| be either: | |
| - [Standard](dataset_formats#standard): Each sample contains plain text. | |
| - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role | |
| and content). | |
| eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): | |
| Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. | |
| processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`): | |
| Processing class used to process the data. If `None`, the processing class is loaded from the model's name | |
| with [`~transformers.AutoTokenizer.from_pretrained`]. | |
| compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): | |
| The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return | |
| a dictionary mapping strings to metric values. *Note* When passing `TrainingArgs` with `batch_eval_metrics` set to | |
| `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered | |
| after the last eval batch to signal that the function needs to calculate and return the global summary | |
| statistics rather than accumulating the batch-level statistics. | |
| callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`): | |
| List of callbacks to customize the training loop. Will add those to the list of default callbacks | |
| detailed [here](https://huggingface.co/docs/transformers/main_classes/callback). | |
| If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] | |
| method. | |
| optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): | |
| A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your | |
| model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. | |
| optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`): | |
| A tuple containing the optimizer class and keyword arguments to use. | |
| Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument. | |
| preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`): | |
| A function that preprocesses the logits right before caching them at each evaluation step. Must take two | |
| tensors, the logits and the labels, and return the logits once processed as desired. The modifications made | |
| by this function will be reflected in the predictions received by `compute_metrics`. | |
| Note that the labels (second parameter) will be `None` if the dataset does not have them. | |
| peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`): | |
| PEFT configuration used to wrap the model. If `None`, the model is not wrapped. | |
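| Example: | |
| A minimal usage sketch (the dataset and model names below are illustrative, not requirements): | |
| ```python | |
| from datasets import load_dataset | |
| from trl import DPOConfig, DPOTrainer | |
| dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") | |
| trainer = DPOTrainer( | |
| model="Qwen/Qwen2-0.5B-Instruct", | |
| args=DPOConfig(output_dir="Qwen2-0.5B-DPO"), | |
| train_dataset=dataset, | |
| ) | |
| trainer.train() | |
| ``` | |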
| """ | |
| _tag_names = ["trl", "dpo"] | |
| def __init__( | |
| self, | |
| model: Union[str, nn.Module, PreTrainedModel], | |
| ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, | |
| args: Optional[DPOConfig] = None, | |
| data_collator: Optional[DataCollator] = None, # type: ignore | |
| train_dataset: Optional[Union[Dataset, IterableDataset]] = None, | |
| eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, | |
| processing_class: Optional[ | |
| Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] | |
| ] = None, | |
| compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None, | |
| callbacks: Optional[list[TrainerCallback]] = None, | |
| optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), | |
| optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None, | |
| preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, | |
| peft_config: Optional["PeftConfig"] = None, | |
| ): | |
| # Args | |
| model_id = model if isinstance(model, str) else model.config._name_or_path | |
| if args is None: | |
| model_name = model_id.split("/")[-1] | |
| args = DPOConfig(f"{model_name}-DPO") | |
| # Handle the tokenizer | |
| if processing_class is None: | |
| processing_class = AutoTokenizer.from_pretrained(model_id) | |
| if args.padding_value is not None: | |
| self.padding_value = args.padding_value | |
| else: | |
| if hasattr(processing_class, "pad_token_id") and processing_class.pad_token_id is not None: | |
| self.padding_value = processing_class.pad_token_id | |
| elif hasattr(processing_class, "tokenizer") and processing_class.tokenizer.pad_token_id is not None: | |
| self.padding_value = processing_class.tokenizer.pad_token_id | |
| else: | |
| raise ValueError( | |
| "`padding_value` is not specified in `DPOConfig`, and `pad_token_id` is missing in the " | |
| "`processing_class`. Please either set the `padding_value` argument in `DPOConfig`, or set " | |
| "`tokenizer.pad_token` (e.g., `tokenizer.pad_token = tokenizer.eos_token`) before instantiating " | |
| "the trainer." | |
| ) | |
| # Model | |
| if not isinstance(model, str) and ref_model is model: | |
| raise ValueError( | |
| "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " | |
| "same as `model`, you must mass a copy of it, or `None` if you use peft." | |
| ) | |
| if args.model_init_kwargs is not None and not isinstance(model, str): | |
| warnings.warn( | |
| "You passed model_init_kwargs to the `DPOConfig`, but your model is already instantiated. " | |
| "The `model_init_kwargs` will be ignored." | |
| ) | |
| if isinstance(model, str): | |
| model = self._create_model_from_path(model, args) | |
| if args.ref_model_init_kwargs is not None and not isinstance(ref_model, str): | |
| warnings.warn( | |
| "You passed ref_model_init_kwargs to the `DPOConfig`, but your ref_model is already instantiated. " | |
| "The `ref_model_init_kwargs` will be ignored." | |
| ) | |
| if isinstance(ref_model, str): | |
| ref_model = self._create_model_from_path(ref_model, args, is_ref=True) | |
| # PEFT configuration and model wrapping | |
| model = self._prepare_peft_model(model, ref_model, peft_config, args) | |
| if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): | |
| raise ValueError( | |
| "`generate_during_eval=True` requires Weights and Biases or Comet to be installed." | |
| " Please install `wandb` or `comet-ml` to resolve." | |
| ) | |
| self.is_encoder_decoder = model.config.is_encoder_decoder | |
| self.is_vision_model = model.config.model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.keys() | |
| self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) | |
| self.model_adapter_name = args.model_adapter_name | |
| self.ref_adapter_name = args.ref_adapter_name | |
| self.reference_free = args.reference_free | |
| if ref_model: | |
| self.ref_model = ref_model | |
| elif self.is_peft_model or args.precompute_ref_log_probs: | |
| # The `model` with adapters turned off will be used as the reference model | |
| self.ref_model = None | |
| else: | |
| self.ref_model = create_reference_model(model) | |
| # Disable dropout in the model and reference model | |
| if args.disable_dropout: | |
| disable_dropout_in_model(model) | |
| if self.ref_model is not None: | |
| disable_dropout_in_model(self.ref_model) | |
| # Liger kernel | |
| if args.use_liger_loss: | |
| if not is_liger_kernel_available(): | |
| raise ImportError( | |
| "You set `use_liger_loss=True` but the liger kernel is not available. " | |
| "Please install liger-kernel first: `pip install liger-kernel`" | |
| ) | |
| if args.loss_type != "sigmoid": | |
| raise ValueError( | |
| "You set `use_liger_loss=True` but the loss type is not `sigmoid`. " | |
| "Please set `loss_type='sigmoid'` to use the liger kernel." | |
| ) | |
| self.dpo_loss_fn = LigerFusedLinearDPOLoss( | |
| ignore_index=args.label_pad_token_id, | |
| beta=args.beta, | |
| use_ref_model=not args.reference_free, | |
| average_log_prob=False, | |
| ) | |
| # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the | |
| # input tensor associated with the key "input_ids". However, in DPO, the sampled data does not include the | |
| # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and | |
| # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens | |
| # of the input, floating-point operations will not be computed." To suppress this warning, we set the | |
| # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate | |
| # that the warning has already been issued. | |
| model.warnings_issued["estimate_tokens"] = True | |
| # Data collator | |
| if data_collator is None: | |
| data_collator = DataCollatorForPreference(pad_token_id=self.padding_value) | |
| self.generate_during_eval = args.generate_during_eval | |
| self.label_pad_token_id = args.label_pad_token_id | |
| self.max_prompt_length = args.max_prompt_length | |
| self.max_completion_length = args.max_completion_length | |
| self.max_length = args.max_length | |
| self.truncation_mode = args.truncation_mode | |
| self.precompute_ref_log_probs = args.precompute_ref_log_probs | |
| self.use_logits_to_keep = args.use_logits_to_keep | |
| if args.padding_free: | |
| if model.config._attn_implementation != "flash_attention_2": | |
| warnings.warn( | |
| "Padding-free training is enabled, but the attention implementation is not set to " | |
| "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and " | |
| "'flash_attention_2' is the only known attention mechanism that reliably supports this. Using " | |
| "other implementations may lead to unexpected behavior. To ensure compatibility, set " | |
| "`attn_implementation='flash_attention_2'` in the model configuration, or verify that your " | |
| "attention mechanism can handle flattened sequences." | |
| ) | |
| if args.per_device_train_batch_size == 1: | |
| warnings.warn( | |
| "You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size " | |
| "of 1 anihilate the benefits of padding-free training. Please consider increasing the batch size " | |
| "to at least 2." | |
| ) | |
| self.padding_free = args.padding_free | |
| # Since ref_log_probs are precomputed on the first call to get_train/eval_dataloader, | |
| # keep track of whether that first call has happened to avoid recomputing them on later calls | |
| self._precomputed_train_ref_log_probs = False | |
| self._precomputed_eval_ref_log_probs = False | |
| if ( | |
| args.loss_type in ["hinge", "ipo", "bco_pair", "sppo_hard", "nca_pair", "apo_zero", "apo_down"] | |
| and args.label_smoothing > 0 | |
| ): | |
| warnings.warn( | |
| f"You are using the {args.loss_type} loss type that does not support label smoothing. The " | |
| "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.", | |
| UserWarning, | |
| ) | |
| if args.loss_type == "kto_pair": | |
| raise ValueError("Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer.") | |
| self.beta = args.beta | |
| self.label_smoothing = args.label_smoothing | |
| self.loss_type = args.loss_type | |
| self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) | |
| self.use_weighting = args.use_weighting | |
| self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) | |
| if self.aux_loss_enabled and self.aux_loss_coef == 0.0: | |
| warnings.warn( | |
| "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " | |
| "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " | |
| "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " | |
| "loss.", | |
| UserWarning, | |
| ) | |
| self._stored_metrics = defaultdict(lambda: defaultdict(list)) | |
| self.f_divergence_type = args.f_divergence_type | |
| self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef} | |
| self.dataset_num_proc = args.dataset_num_proc | |
| # Dataset preparation | |
| train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") | |
| if eval_dataset is not None: | |
| if isinstance(eval_dataset, dict): | |
| eval_dataset = { | |
| key: self._prepare_dataset(dataset, processing_class, args, key) | |
| for key, dataset in eval_dataset.items() | |
| } | |
| else: | |
| eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") | |
| super().__init__( | |
| model=model, | |
| args=args, | |
| data_collator=data_collator, | |
| train_dataset=train_dataset, | |
| eval_dataset=eval_dataset, | |
| processing_class=processing_class, | |
| compute_metrics=compute_metrics, | |
| callbacks=callbacks, | |
| optimizers=optimizers, | |
| optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, | |
| preprocess_logits_for_metrics=preprocess_logits_for_metrics, | |
| ) | |
| # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the | |
| # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set | |
| # self.model_accepts_loss_kwargs to False to enable scaling. | |
| self.model_accepts_loss_kwargs = False | |
| # Add tags for models that have been loaded with the correct transformers version | |
| if hasattr(self.model, "add_model_tags"): | |
| self.model.add_model_tags(self._tag_names) | |
| if not hasattr(self, "accelerator"): | |
| raise AttributeError( | |
| "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." | |
| ) | |
| # Deepspeed Zero-3 does not support precompute_ref_log_probs | |
| if self.is_deepspeed_enabled: | |
| if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: | |
| raise ValueError( | |
| "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." | |
| ) | |
| if self.ref_model is None: | |
| if not (self.is_peft_model or self.precompute_ref_log_probs): | |
| raise ValueError( | |
| "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" | |
| ) | |
| if args.sync_ref_model: | |
| raise ValueError( | |
| "You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`." | |
| ) | |
| else: | |
| if self.is_deepspeed_enabled: | |
| self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) | |
| elif self.is_fsdp_enabled: | |
| self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) | |
| else: | |
| self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) | |
| if args.sync_ref_model: | |
| if self.precompute_ref_log_probs: | |
| raise ValueError( | |
| "You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`." | |
| ) | |
| self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) | |
| if self.loss_type == "bco_pair": | |
| self.running = RunningMoments(self.accelerator) | |
| def _create_model_from_path(self, model_path: str, args: DPOConfig, is_ref: bool = False) -> PreTrainedModel: | |
| """Creates a model from a path or model identifier.""" | |
| if not is_ref: | |
| model_init_kwargs = args.model_init_kwargs or {} | |
| else: | |
| model_init_kwargs = args.ref_model_init_kwargs or {} | |
| # Handle torch dtype | |
| torch_dtype = model_init_kwargs.get("torch_dtype") | |
| if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: | |
| pass # torch_dtype is already a torch.dtype or "auto" or None | |
| elif isinstance(torch_dtype, str): # it's a str, but not "auto" | |
| torch_dtype = getattr(torch, torch_dtype) | |
| model_init_kwargs["torch_dtype"] = torch_dtype | |
| else: | |
| raise ValueError( | |
| "Invalid `torch_dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing " | |
| f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." | |
| ) | |
| # Disable caching if gradient checkpointing is enabled (not supported) | |
| # if args.gradient_checkpointing: | |
| # model_init_kwargs["use_cache"] = False | |
| # Create model | |
| model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs) | |
| return model | |
| def _prepare_peft_model( | |
| self, model: PreTrainedModel, ref_model: PreTrainedModel, peft_config: Any, args: DPOConfig | |
| ) -> PreTrainedModel: | |
| """Prepares a model for PEFT training.""" | |
| # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` | |
| # has been called in order to properly call autocast if needed. | |
| self._peft_has_been_casted_to_bf16 = False | |
| if not is_peft_available() and peft_config is not None: | |
| raise ValueError( | |
| "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" | |
| ) | |
| elif is_peft_available() and peft_config is not None: | |
| # if model is a peft model and we have a peft_config, we merge and unload it first | |
| if isinstance(model, PeftModel): | |
| model = model.merge_and_unload() | |
| if ref_model is not None and not args.force_use_ref_model: | |
| raise ValueError( | |
| "You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference" | |
| " model. Please pass `ref_model=None` in case you want to train PEFT adapters, or pass a ref_model with `force_use_ref_model=True` in DPOTrainer's init." | |
| " if you want to use a different ref_model." | |
| ) | |
| if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): | |
| _support_gc_kwargs = hasattr( | |
| args, "gradient_checkpointing_kwargs" | |
| ) and "gradient_checkpointing_kwargs" in list( | |
| inspect.signature(prepare_model_for_kbit_training).parameters | |
| ) | |
| prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} | |
| if _support_gc_kwargs: | |
| prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs | |
| model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) | |
| else: | |
| model = self._prepare_gradient_checkpointing(model, args) | |
| # get peft model with the given config | |
| model = get_peft_model(model, peft_config) | |
| if args.bf16 and getattr(model, "is_loaded_in_4bit", False): | |
| peft_module_casting_to_bf16(model) | |
| # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager | |
| self._peft_has_been_casted_to_bf16 = True | |
| else: | |
| model = self._prepare_gradient_checkpointing(model, args) | |
| return model | |
| def _prepare_gradient_checkpointing(self, model: PreTrainedModel, args: DPOConfig): | |
| """Prepare the gradienting checkpointing for the model.""" | |
| # For models that use gradient checkpointing, we need to attach a hook that makes the inputs | |
| # explicitly require gradients (`requires_grad=True`); otherwise training will fail, either | |
| # silently or with an error. | |
| if args.gradient_checkpointing: | |
| # For backward compatibility with older versions of transformers | |
| if hasattr(model, "enable_input_require_grads"): | |
| model.enable_input_require_grads() | |
| else: | |
| def make_inputs_require_grad(module, input, output): | |
| output.requires_grad_(True) | |
| model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) | |
| return model | |
| def _prepare_dataset( | |
| self, | |
| dataset: Union[Dataset, IterableDataset], | |
| processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin], | |
| args: DPOConfig, | |
| dataset_name: str, | |
| ) -> Union[Dataset, IterableDataset]: | |
| # Build the kwargs for the `map` function | |
| map_kwargs = {} | |
| if isinstance(dataset, Dataset): # IterableDataset does not support num_proc nor writer_batch_size | |
| map_kwargs["num_proc"] = args.dataset_num_proc | |
| map_kwargs["writer_batch_size"] = 10 | |
| with PartialState().main_process_first(): | |
| # Extract prompt if needed | |
| if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` | |
| map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset" | |
| dataset = dataset.map(maybe_extract_prompt, **map_kwargs) | |
| # Apply the chat template if needed | |
| if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` | |
| map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" | |
| dataset = dataset.map( | |
| maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, **map_kwargs | |
| ) | |
| # Tokenize the dataset | |
| if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` | |
| map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" | |
| dataset = dataset.map( | |
| self.tokenize_row if not self.is_vision_model else self.process_row, | |
| remove_columns=["chosen", "rejected"], | |
| fn_kwargs={ | |
| "processing_class": processing_class, | |
| "max_prompt_length": args.max_prompt_length, | |
| "max_completion_length": args.max_completion_length, | |
| # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token]) | |
| "add_special_tokens": False, | |
| }, | |
| **map_kwargs, | |
| ) | |
| return dataset | |
| @staticmethod | |
| def tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens): | |
| """ | |
| Tokenize a row of the dataset. | |
| Args: | |
| features (`dict[str, str]`): | |
| Row of the dataset, should contain the keys `"prompt"`, `"chosen"`, and `"rejected"`. | |
| processing_class (`PreTrainedTokenizerBase`): | |
| Processing class used to process the data. | |
| max_prompt_length (`int` or `None`): | |
| Maximum length of the prompt sequence. If `None`, the prompt sequence is not truncated. | |
| max_completion_length (`int` or `None`): | |
| Maximum length of the completion sequences. If `None`, the completion sequences are not truncated. | |
| add_special_tokens (`bool`): | |
| Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If `True`, | |
| the prompt sequence will have a bos token prepended and an eos token appended. In any case, the | |
| completion sequences will have an eos token appended. | |
| Returns: | |
| `dict[str, list[int]]`: | |
| Tokenized sequences with the keys `"prompt_input_ids"`, `"chosen_input_ids"`, and | |
| `"rejected_input_ids". | |
| Example: | |
| ```python | |
| >>> from transformers import GPT2Tokenizer | |
| >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") | |
| >>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"} | |
| >>> DPOTrainer.tokenize_row( | |
| ... features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False | |
| ... ) | |
| {'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]} | |
| ``` | |
| """ | |
| tokenizer = processing_class # the processing class is a tokenizer | |
| prompt_input_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"] | |
| chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"] | |
| rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"] | |
| # Add special tokens (typically for encoder-decoder models) | |
| if add_special_tokens: | |
| if tokenizer.bos_token_id is not None: | |
| prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids | |
| if tokenizer.eos_token_id is not None: | |
| prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] | |
| chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] | |
| rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] | |
| # Truncate prompt and completion sequences | |
| if max_prompt_length is not None: | |
| prompt_input_ids = prompt_input_ids[-max_prompt_length:] | |
| if max_completion_length is not None: | |
| chosen_input_ids = chosen_input_ids[:max_completion_length] | |
| rejected_input_ids = rejected_input_ids[:max_completion_length] | |
| return { | |
| "prompt_input_ids": prompt_input_ids, | |
| "chosen_input_ids": chosen_input_ids, | |
| "rejected_input_ids": rejected_input_ids, | |
| } | |
| @staticmethod | |
| def process_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens): | |
| """ | |
| Same as `tokenize_row` but for vision models. Please refer to `tokenize_row` for more information. | |
| """ | |
| processor, tokenizer = processing_class, processing_class.tokenizer # the processing class is a processor | |
| processed_features = processor(images=features["images"], text=features["prompt"], add_special_tokens=False) | |
| prompt_input_ids = processed_features["input_ids"][0] | |
| pixel_values = processed_features["pixel_values"][0] | |
| chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"] | |
| rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"] | |
| # Add special tokens (typically for encoder-decoder models) | |
| if add_special_tokens: | |
| if tokenizer.bos_token_id is not None: | |
| prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids | |
| if tokenizer.eos_token_id is not None: | |
| prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] | |
| chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] | |
| rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] | |
| # Truncate prompt and completion sequences | |
| if max_prompt_length is not None: | |
| prompt_input_ids = prompt_input_ids[-max_prompt_length:] | |
| if max_completion_length is not None: | |
| chosen_input_ids = chosen_input_ids[:max_completion_length] | |
| rejected_input_ids = rejected_input_ids[:max_completion_length] | |
| output = { | |
| "prompt_input_ids": prompt_input_ids, | |
| "pixel_values": pixel_values, | |
| "chosen_input_ids": chosen_input_ids, | |
| "rejected_input_ids": rejected_input_ids, | |
| } | |
| if "pixel_attention_mask" in processed_features: | |
| output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0] | |
| if "image_sizes" in processed_features: | |
| output["image_sizes"] = processed_features["image_sizes"][0] | |
| return output | |
| def _set_signature_columns_if_needed(self): | |
| # If `self.args.remove_unused_columns` is True, non-signature columns are removed. | |
| # By default, this method sets `self._signature_columns` to the model's expected inputs. | |
| # In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work. | |
| # Instead, we set them to the columns expected by `DataCollatorForPreference`, hence the override. | |
| if self._signature_columns is None: | |
| self._signature_columns = [ | |
| "prompt_input_ids", | |
| "chosen_input_ids", | |
| "rejected_input_ids", | |
| "image_sizes", | |
| "ref_chosen_logps", | |
| "ref_rejected_logps", | |
| ] | |
| def get_train_dataloader(self) -> DataLoader: | |
| """ | |
| Returns the training [`~torch.utils.data.DataLoader`]. | |
| Overrides `transformers.Trainer.get_train_dataloader` to precompute `ref_log_probs`. | |
| """ | |
| if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: | |
| batch_size = self.args.precompute_ref_batch_size or self.args.per_device_train_batch_size | |
| dataloader_params = { | |
| "batch_size": batch_size, | |
| "collate_fn": self.data_collator, | |
| "num_workers": self.args.dataloader_num_workers, | |
| "pin_memory": self.args.dataloader_pin_memory, | |
| "shuffle": False, | |
| } | |
| # prepare dataloader | |
| data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) | |
| ref_chosen_logps = [] | |
| ref_rejected_logps = [] | |
| for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): | |
| ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch) | |
| ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics( | |
| (ref_chosen_logp, ref_rejected_logp) | |
| ) | |
| ref_chosen_logps.append(ref_chosen_logp.cpu()) | |
| ref_rejected_logps.append(ref_rejected_logp.cpu()) | |
| # Clear cache to avoid OOM | |
| empty_cache() | |
| self.accelerator.free_memory() | |
| all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy() | |
| all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy() | |
| self.train_dataset = self.train_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps) | |
| self.train_dataset = self.train_dataset.add_column( | |
| name="ref_rejected_logps", column=all_ref_rejected_logps | |
| ) | |
| self._precomputed_train_ref_log_probs = True | |
| return super().get_train_dataloader() | |
| def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: | |
| """ | |
| Returns the evaluation [`~torch.utils.data.DataLoader`]. | |
| Overrides `transformers.Trainer.get_eval_dataloader` to precompute `ref_log_probs`. | |
| Args: | |
| eval_dataset (`torch.utils.data.Dataset`, *optional*): | |
| If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted | |
| by the `model.forward()` method are automatically removed. It must implement `__len__`. | |
| """ | |
| if eval_dataset is None and self.eval_dataset is None: | |
| raise ValueError("Trainer: evaluation requires an eval_dataset.") | |
| eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset | |
| if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: | |
| batch_size = self.args.precompute_ref_batch_size or self.args.per_device_eval_batch_size | |
| dataloader_params = { | |
| "batch_size": batch_size, | |
| "collate_fn": self.data_collator, | |
| "num_workers": self.args.dataloader_num_workers, | |
| "pin_memory": self.args.dataloader_pin_memory, | |
| "shuffle": False, | |
| } | |
| # prepare dataloader | |
| data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) | |
| ref_chosen_logps = [] | |
| ref_rejected_logps = [] | |
| for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): | |
| ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch) | |
| ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics( | |
| (ref_chosen_logp, ref_rejected_logp) | |
| ) | |
| ref_chosen_logps.append(ref_chosen_logp.cpu()) | |
| ref_rejected_logps.append(ref_rejected_logp.cpu()) | |
| all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy() | |
| all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy() | |
| eval_dataset = eval_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps) | |
| eval_dataset = eval_dataset.add_column(name="ref_rejected_logps", column=all_ref_rejected_logps) | |
| # Save calculated ref_chosen_logps and ref_rejected_logps to the eval_dataset for subsequent runs | |
| if self.eval_dataset is not None: | |
| self.eval_dataset = eval_dataset | |
| self._precomputed_eval_ref_log_probs = True | |
| return super().get_eval_dataloader(eval_dataset=eval_dataset) | |
| @contextmanager | |
| def null_ref_context(self): | |
| """Context manager for handling null reference model (that is, peft adapter manipulation).""" | |
| with ( | |
| self.accelerator.unwrap_model(self.model).disable_adapter() | |
| if self.is_peft_model and not self.ref_adapter_name | |
| else nullcontext() | |
| ): | |
| if self.ref_adapter_name: | |
| self.model.set_adapter(self.ref_adapter_name) | |
| yield | |
| if self.ref_adapter_name: | |
| self.model.set_adapter(self.model_adapter_name or "default") | |
| def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> tuple[torch.FloatTensor, torch.FloatTensor]: | |
| """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset.""" | |
| compute_ref_context_manager = ( | |
| autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() | |
| ) | |
| with torch.no_grad(), compute_ref_context_manager: | |
| if self.ref_model is None: | |
| with self.null_ref_context(): | |
| ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True) | |
| else: | |
| ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True) | |
| return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"] | |
| @staticmethod | |
| def concatenated_inputs( | |
| batch: dict[str, Union[list, torch.LongTensor]], padding_value: int | |
| ) -> dict[str, torch.LongTensor]: | |
| """ | |
| Concatenate the `chosen` and `rejected` inputs from the batch into a single tensor for both the prompt | |
| and completion sequences. | |
| Args: | |
| batch (`dict[str, Union[list, torch.LongTensor]]`): | |
| A batch of input data. The batch must contain the following keys: | |
| - `"prompt_input_ids"`: Tensor of shape `(batch_size, prompt_length)` representing the prompt input IDs. | |
| - `"chosen_input_ids"`: Tensor of shape `(batch_size, chosen_length)` representing the chosen completion input IDs. | |
| - `"rejected_input_ids"`: Tensor of shape `(batch_size, rejected_length)` representing the rejected completion input IDs. | |
| - `"prompt_pixel_values"` (optional): Tensor for pixel values, if available. | |
| - `"prompt_pixel_attention_mask"` (optional): Tensor for pixel attention masks, if available. | |
| padding_value (`int`): | |
| The padding value to use for the concatenated completion sequences (`chosen_input_ids` and | |
| `rejected_input_ids`). | |
| Returns: | |
| `dict[str, torch.LongTensor]`: A dictionary containing: | |
| - `"prompt_input_ids"`: Concatenated prompt input IDs of shape `(2 * batch_size, prompt_length)`. | |
| - `"completion_input_ids"`: Concatenated chosen and rejected completion input IDs of shape `(2 * batch_size, max_completion_length)`. | |
| - `"prompt_attention_mask"`: Concatenated prompt attention masks of shape `(2 * batch_size, prompt_length)`. | |
| - `"completion_attention_mask"`: Concatenated chosen and rejected attention masks of shape `(2 * batch_size, max_completion_length)`. | |
| - `"pixel_values"` (optional): Concatenated pixel values if `"prompt_pixel_values"` are present. | |
| - `"pixel_attention_mask"` (optional): Concatenated pixel attention masks if `"prompt_pixel_attention_mask"` are present. | |
| Notes: | |
| The completion input IDs and attention masks are padded to the maximum completion length of the chosen | |
| or rejected sequences. | |
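| Example: | |
| A minimal sketch with toy tensors (output formatting abbreviated): | |
| ```python | |
| >>> import torch | |
| >>> batch = { | |
| ...     "prompt_input_ids": torch.tensor([[1, 2]]), | |
| ...     "prompt_attention_mask": torch.tensor([[1, 1]]), | |
| ...     "chosen_input_ids": torch.tensor([[3, 4]]), | |
| ...     "chosen_attention_mask": torch.tensor([[1, 1]]), | |
| ...     "rejected_input_ids": torch.tensor([[5]]), | |
| ...     "rejected_attention_mask": torch.tensor([[1]]), | |
| ... } | |
| >>> DPOTrainer.concatenated_inputs(batch, padding_value=0) | |
| {'prompt_input_ids': tensor([[1, 2], | |
| [1, 2]]), | |
| 'prompt_attention_mask': tensor([[1, 1], | |
| [1, 1]]), | |
| 'completion_input_ids': tensor([[3, 4], | |
| [5, 0]]), | |
| 'completion_attention_mask': tensor([[1, 1], | |
| [1, 0]])} | |
| ``` | |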
| """ | |
| output = {} | |
| # For the prompt, the input_ids are the same for both the chosen and rejected responses | |
| output["prompt_input_ids"] = torch.cat([batch["prompt_input_ids"], batch["prompt_input_ids"]], dim=0) | |
| output["prompt_attention_mask"] = torch.cat( | |
| [batch["prompt_attention_mask"], batch["prompt_attention_mask"]], dim=0 | |
| ) | |
| if "pixel_values" in batch: | |
| output["pixel_values"] = torch.cat([batch["pixel_values"], batch["pixel_values"]], dim=0) | |
| if "pixel_attention_mask" in batch: | |
| output["pixel_attention_mask"] = torch.cat( | |
| [batch["pixel_attention_mask"], batch["pixel_attention_mask"]], dim=0 | |
| ) | |
| if "image_sizes" in batch: | |
| output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0) | |
| # Concatenate the chosen and rejected completions | |
| max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) | |
| output["completion_input_ids"] = torch.cat( | |
| ( | |
| pad_to_length(batch["chosen_input_ids"], max_completion_length, pad_value=padding_value), | |
| pad_to_length(batch["rejected_input_ids"], max_completion_length, pad_value=padding_value), | |
| ), | |
| ) | |
| output["completion_attention_mask"] = torch.cat( | |
| ( | |
| pad_to_length(batch["chosen_attention_mask"], max_completion_length, pad_value=0), | |
| pad_to_length(batch["rejected_attention_mask"], max_completion_length, pad_value=0), | |
| ), | |
| ) | |
| return output | |
| def dpo_loss( | |
| self, | |
| chosen_logps: torch.FloatTensor, | |
| rejected_logps: torch.FloatTensor, | |
| ref_chosen_logps: torch.FloatTensor, | |
| ref_rejected_logps: torch.FloatTensor, | |
| ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: | |
| """ | |
| Compute the DPO loss for a batch of policy and reference model log probabilities. | |
| Args: | |
| chosen_logps (`torch.FloatTensor`): | |
| Log probabilities of the model for the chosen responses. Shape: `(batch_size,)`. | |
| rejected_logps (`torch.FloatTensor`): | |
| Log probabilities of the model for the rejected responses. Shape: `(batch_size,)`. | |
| ref_chosen_logps (`torch.FloatTensor`): | |
| Log probabilities of the reference model for the chosen responses. Shape: `(batch_size,)`. | |
| ref_rejected_logps (`torch.FloatTensor`): | |
| Log probabilities of the reference model for the rejected responses. Shape: `(batch_size,)`. | |
| Returns: | |
| A tuple of three tensors: `(losses, chosen_rewards, rejected_rewards)`. | |
| The losses tensor contains the DPO loss for each example in the batch. | |
| The `chosen_rewards` and `rejected_rewards` tensors contain the rewards for the chosen and rejected | |
| responses, respectively. | |
| """ | |
| device = self.accelerator.device | |
| # Get the log ratios for the chosen and rejected responses | |
| chosen_logratios = chosen_logps.to(device) - (not self.reference_free) * ref_chosen_logps.to(device) | |
| rejected_logratios = rejected_logps.to(device) - (not self.reference_free) * ref_rejected_logps.to(device) | |
| if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value: | |
| # The alpha-divergence formula: (1 - u^-alpha) / alpha | |
| # The divergence difference between the chosen and rejected sample is: | |
| # (1 - u[w]^-alpha) / alpha - (1 - u[l]^-alpha) / alpha | |
| # = (u[l]^-alpha - u[w]^-alpha) / alpha | |
| # where u[w] and u[l] are the policy/reference probability ratios | |
| # for the chosen and rejected samples, respectively. | |
| alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT | |
| if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params: | |
| alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY]) | |
| logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef | |
| else: | |
| logratios = chosen_logps - rejected_logps | |
| if self.reference_free: | |
| ref_logratios = torch.tensor([0], dtype=logratios.dtype, device=logratios.device) | |
| else: | |
| ref_logratios = ref_chosen_logps - ref_rejected_logps | |
| logratios = logratios.to(self.accelerator.device) | |
| ref_logratios = ref_logratios.to(self.accelerator.device) | |
| logits = logratios - ref_logratios | |
| if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value: | |
| # The js-divergence formula: log(2 * u / (1 + u)) | |
| # The divergence difference between the chosen and rejected sample is: | |
| # log(2 * u[w] / (1 + u[w])) - log(2 * u[l] / (1 + u[l])) | |
| # = log(u[w]) - log(u[l]) - (log(1 + u[w]) - log(1 + u[l])) | |
| # where u[w] and u[l] are the policy/reference probability ratios | |
| # for the chosen and rejected samples, respectively. | |
| logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios) | |
| # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. | |
| # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the | |
| # labels and calculates a conservative DPO loss. | |
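| # For reference, with the default "sigmoid" loss and no label smoothing, the per-example loss is the | |
| # original DPO objective: | |
| #     loss = -logsigmoid(beta * ((chosen_logps - ref_chosen_logps) - (rejected_logps - ref_rejected_logps))) | |
| # which matches -logsigmoid(beta * logits) when the default reverse-KL f-divergence is used. | |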
| if self.loss_type == "sigmoid": | |
| losses = ( | |
| -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) | |
| - F.logsigmoid(-self.beta * logits) * self.label_smoothing | |
| ) | |
| elif self.loss_type == "robust": | |
| losses = ( | |
| -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) | |
| + F.logsigmoid(-self.beta * logits) * self.label_smoothing | |
| ) / (1 - 2 * self.label_smoothing) | |
| elif self.loss_type == "exo_pair": | |
| # eqn (16) of the EXO paper: https://huggingface.co/papers/2402.00856 | |
| import math | |
| if self.label_smoothing == 0: | |
| self.label_smoothing = 1e-3 | |
| losses = (self.beta * logits).sigmoid() * ( | |
| F.logsigmoid(self.beta * logits) - math.log(1 - self.label_smoothing) | |
| ) + (-self.beta * logits).sigmoid() * (F.logsigmoid(-self.beta * logits) - math.log(self.label_smoothing)) | |
| elif self.loss_type == "hinge": | |
| losses = torch.relu(1 - self.beta * logits) | |
| elif self.loss_type == "ipo": | |
| # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper. | |
| losses = (logits - 1 / (2 * self.beta)) ** 2 | |
| elif self.loss_type == "bco_pair": | |
| chosen_logratios = chosen_logps - ref_chosen_logps | |
| rejected_logratios = rejected_logps - ref_rejected_logps | |
| chosen_rewards = self.beta * chosen_logratios | |
| rejected_rewards = self.beta * rejected_logratios | |
| rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach() | |
| self.running.update(rewards) | |
| delta = self.running.mean | |
| losses = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid( | |
| -(self.beta * rejected_logratios - delta) | |
| ) | |
| elif self.loss_type == "sppo_hard": | |
| # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach, | |
| # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class. | |
| # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is | |
| # set to 1 for the winner and 0 for the loser. | |
| a = chosen_logps - ref_chosen_logps | |
| b = rejected_logps - ref_rejected_logps | |
| losses = (a - 0.5 / self.beta) ** 2 + (b + 0.5 / self.beta) ** 2 | |
| elif self.loss_type == "nca_pair": | |
| chosen_rewards = (chosen_logps - ref_chosen_logps) * self.beta | |
| rejected_rewards = (rejected_logps - ref_rejected_logps) * self.beta | |
| losses = ( | |
| -F.logsigmoid(chosen_rewards) | |
| - 0.5 * F.logsigmoid(-chosen_rewards) | |
| - 0.5 * F.logsigmoid(-rejected_rewards) | |
| ) | |
| elif self.loss_type == "aot_pair": | |
| chosen_logratios = chosen_logps - ref_chosen_logps | |
| rejected_logratios = rejected_logps - ref_rejected_logps | |
| chosen_logratios_sorted, _ = torch.sort(chosen_logratios, dim=0) | |
| rejected_logratios_sorted, _ = torch.sort(rejected_logratios, dim=0) | |
| delta = chosen_logratios_sorted - rejected_logratios_sorted | |
| losses = ( | |
| -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing) | |
| - F.logsigmoid(-self.beta * delta) * self.label_smoothing | |
| ) | |
| elif self.loss_type == "aot": | |
| logratios = chosen_logps - rejected_logps | |
| ref_logratios = ref_chosen_logps - ref_rejected_logps | |
| logratios_sorted, _ = torch.sort(logratios, dim=0) | |
| ref_logratios_sorted, _ = torch.sort(ref_logratios, dim=0) | |
| delta = logratios_sorted - ref_logratios_sorted | |
| losses = ( | |
| -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing) | |
| - F.logsigmoid(-self.beta * delta) * self.label_smoothing | |
| ) | |
| elif self.loss_type == "apo_zero": | |
| # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266) | |
| # Use this loss when you believe the chosen outputs are better than your model's default output | |
| losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios) # Increase chosen likelihood | |
| losses_rejected = F.sigmoid(self.beta * rejected_logratios) # Decrease rejected likelihood | |
| losses = losses_chosen + losses_rejected | |
| elif self.loss_type == "apo_down": | |
| # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266) | |
| # Use this loss when you believe the chosen outputs are worse than your model's default output. | |
| # Decrease chosen likelihood and decrease rejected likelihood more | |
| losses_chosen = F.sigmoid(self.beta * chosen_logratios) | |
| losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios)) | |
| losses = losses_chosen + losses_rejected | |
| elif self.loss_type == "discopop": | |
| # Eqn (5) of the DiscoPOP paper (https://huggingface.co/papers/2406.08414) | |
| # This loss was discovered with LLM discovery | |
| logratios = chosen_logps - rejected_logps | |
| ref_logratios = ref_chosen_logps - ref_rejected_logps | |
| logits = logratios - ref_logratios | |
| logits = logits * self.beta | |
| # Modulate the mixing coefficient based on the log ratio magnitudes | |
| log_ratio_modulation = torch.sigmoid(logits / self.args.discopop_tau) | |
| logistic_component = -F.logsigmoid(logits) | |
| exp_component = torch.exp(-logits) | |
| # Blend between logistic and exponential component based on log ratio modulation | |
| losses = logistic_component * (1 - log_ratio_modulation) + exp_component * log_ratio_modulation | |
| else: | |
| raise ValueError( | |
| f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'exo_pair', " | |
| "'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'discopop', 'apo_zero', 'apo_down']" | |
| ) | |
| chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach() | |
| rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach() | |
| return losses, chosen_rewards, rejected_rewards | |
| def _compute_loss_liger(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]): | |
| unwrapped_model = self.accelerator.unwrap_model(model) | |
| concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value) | |
| model_kwargs = {} | |
| if self.aux_loss_enabled: | |
| model_kwargs["output_router_logits"] = True | |
| # Add the pixel values and attention masks for vision models | |
| if "pixel_values" in concatenated_batch: | |
| model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] | |
| if "pixel_attention_mask" in concatenated_batch: | |
| model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] | |
| if "image_sizes" in concatenated_batch: | |
| model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] | |
| prompt_attention_mask = concatenated_batch["prompt_attention_mask"] | |
| completion_attention_mask = concatenated_batch["completion_attention_mask"] | |
| if self.is_encoder_decoder: | |
| # 1. Get encoder outputs | |
| encoder_outputs = unwrapped_model.get_encoder()( | |
| concatenated_batch["prompt_input_ids"], | |
| attention_mask=concatenated_batch["prompt_attention_mask"], | |
| return_dict=True, | |
| ) | |
| # 2. Prepare decoder inputs | |
| decoder_input_ids = shift_tokens_right( | |
| concatenated_batch["completion_input_ids"], | |
| unwrapped_model.config.decoder_start_token_id, | |
| ) | |
| # 3. Get decoder outputs | |
| decoder_outputs = unwrapped_model.get_decoder()( | |
| input_ids=decoder_input_ids, | |
| attention_mask=concatenated_batch["completion_attention_mask"], | |
| encoder_hidden_states=encoder_outputs.last_hidden_state, | |
| encoder_attention_mask=concatenated_batch["prompt_attention_mask"], | |
| use_cache=False, | |
| ) | |
| hidden_states = decoder_outputs.last_hidden_state | |
| ref_hidden_states = None | |
| if not self.reference_free and self.ref_model is not None: | |
| unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) | |
| ref_encoder_outputs = unwrapped_ref_model.get_encoder()( | |
| concatenated_batch["prompt_input_ids"], | |
| attention_mask=concatenated_batch["prompt_attention_mask"], | |
| return_dict=True, | |
| ) | |
| ref_decoder_outputs = unwrapped_ref_model.get_decoder()( | |
| input_ids=decoder_input_ids, | |
| attention_mask=concatenated_batch["completion_attention_mask"], | |
| encoder_hidden_states=ref_encoder_outputs.last_hidden_state, | |
| encoder_attention_mask=concatenated_batch["prompt_attention_mask"], | |
| use_cache=False, | |
| ) | |
| ref_hidden_states = ref_decoder_outputs.last_hidden_state | |
| elif not self.reference_free: | |
| with self.null_ref_context(): | |
| ref_encoder_outputs = unwrapped_model.get_encoder()( | |
| concatenated_batch["prompt_input_ids"], | |
| attention_mask=concatenated_batch["prompt_attention_mask"], | |
| return_dict=True, | |
| ) | |
| ref_decoder_outputs = unwrapped_model.get_decoder()( | |
| input_ids=decoder_input_ids, | |
| attention_mask=concatenated_batch["completion_attention_mask"], | |
| encoder_hidden_states=ref_encoder_outputs.last_hidden_state, | |
| encoder_attention_mask=concatenated_batch["prompt_attention_mask"], | |
| use_cache=False, | |
| ) | |
| ref_hidden_states = ref_decoder_outputs.last_hidden_state | |
| labels = concatenated_batch["completion_input_ids"] | |
| loss_mask = completion_attention_mask.bool() | |
| else: | |
| # For decoder-only models | |
| input_ids = torch.cat( | |
| (concatenated_batch["prompt_input_ids"], concatenated_batch["completion_input_ids"]), dim=1 | |
| ) | |
| attention_mask = torch.cat( | |
| (concatenated_batch["prompt_attention_mask"], concatenated_batch["completion_attention_mask"]), | |
| dim=1, | |
| ) | |
| # Mask the prompt but not the completion for the loss | |
| loss_mask = torch.cat( | |
| (torch.zeros_like(prompt_attention_mask), completion_attention_mask), | |
| dim=1, | |
| ) | |
| # Flush and truncate | |
| if self.max_length is not None and self.max_length < attention_mask.size(1): | |
| if self.truncation_mode == "keep_start": | |
| # Flush left to reduce the memory usage | |
| # [[0, 0, x, x, x, x], -> [[x, x, x, x], | |
| # [0, x, x, x, 0, 0]] [x, x, x, 0]] | |
| attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) | |
| attention_mask = attention_mask[:, : self.max_length] | |
| input_ids = input_ids[:, : self.max_length] | |
| loss_mask = loss_mask[:, : self.max_length] | |
| elif self.truncation_mode == "keep_end": | |
| # Flush right before truncating left, then flush left | |
| # [[0, 0, x, x, x, x], -> [[0, 0, x, x], | |
| # [0, x, x, x, 0, 0]] [0, x, x, x]] | |
| attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask) | |
| input_ids = input_ids[:, -self.max_length :] | |
| attention_mask = attention_mask[:, -self.max_length :] | |
| loss_mask = loss_mask[:, -self.max_length :] | |
| attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) | |
| else: | |
| raise ValueError( | |
| f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', " | |
| "'keep_start']." | |
| ) | |
| else: | |
| # Flush left to reduce the memory usage | |
| # [[0, 0, x, x, x, x], -> [[x, x, x, x], | |
| # [0, x, x, x, 0, 0]] [x, x, x, 0]] | |
| attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) | |
| # Add logits_to_keep optimization | |
| if self.use_logits_to_keep: | |
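| # Only logits from the first completion token onward (plus one extra position for the label shift) | |
| # are needed for the loss, so prompt-only positions can be skipped entirely. | |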
| first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() | |
| logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 | |
| model_kwargs["logits_to_keep"] = logits_to_keep | |
| model_kwargs["output_hidden_states"] = True | |
| # Add padding-free training support | |
| if self.padding_free: | |
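| # Flatten the batch into a single packed row and rebuild position_ids from the attention mask, | |
| # e.g. attention_mask = [[1, 1, 1, 0], [1, 1, 0, 0]] -> position_ids = [[0, 1, 2, 0, 1]] | |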
| input_ids = input_ids[attention_mask.bool()].unsqueeze(0) | |
| loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0) | |
| position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1 | |
| model_kwargs["position_ids"] = position_ids | |
| else: | |
| model_kwargs["attention_mask"] = attention_mask | |
| # Get the base model outputs (before LM head) | |
| if hasattr(unwrapped_model, "get_decoder"): | |
| base_model = unwrapped_model.get_decoder() | |
| else: | |
| base_model = getattr(unwrapped_model, self.args.base_model_attribute_name, unwrapped_model) | |
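| # Run only the backbone here: the LM head is applied later inside the fused Liger loss, so hidden | |
| # states are kept instead of logits. | |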
| outputs = base_model( | |
| input_ids, | |
| use_cache=False, | |
| **model_kwargs, | |
| ) | |
| hidden_states = outputs.last_hidden_state[:, :-1] | |
| # Get reference hidden states if needed | |
| ref_hidden_states = None | |
| if not self.reference_free and self.ref_model is not None: | |
| unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) | |
| if hasattr(unwrapped_ref_model, "get_decoder"): | |
| ref_base_model = unwrapped_ref_model.get_decoder() | |
| else: | |
| ref_base_model = getattr( | |
| unwrapped_ref_model, self.args.base_model_attribute_name, unwrapped_ref_model | |
| ) | |
| ref_outputs = ref_base_model( | |
| input_ids, | |
| use_cache=False, | |
| **model_kwargs, | |
| ) | |
| ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] | |
| elif not self.reference_free: | |
| if hasattr(unwrapped_model, "get_decoder"): | |
| ref_base_model = unwrapped_model.get_decoder() | |
| else: | |
| ref_base_model = getattr(unwrapped_model, self.args.base_model_attribute_name, unwrapped_model) | |
| with self.null_ref_context(): | |
| ref_outputs = ref_base_model( | |
| input_ids, | |
| attention_mask=attention_mask, | |
| use_cache=False, | |
| **model_kwargs, | |
| ) | |
| ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] | |
| masked_input_ids = torch.where(loss_mask != 0, input_ids, self.label_pad_token_id) | |
| labels = masked_input_ids[:, 1:] # Shift labels to align with the causal LM logits | |
| # Get the LM head | |
| lm_head = unwrapped_model.get_output_embeddings() | |
| # Get reference model weights if needed | |
| ref_weight = None | |
| ref_bias = None | |
| if not self.reference_free: | |
| if self.ref_model is not None: | |
| unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) | |
| ref_lm_head = unwrapped_ref_model.get_output_embeddings() | |
| else: | |
| with self.null_ref_context(): | |
| ref_lm_head = unwrapped_model.get_output_embeddings() | |
| ref_weight = ref_lm_head.weight | |
| ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None | |
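| # Passing the reference LM head weights to the fused kernel lets it compute the reference | |
| # log-probabilities in the same chunked pass, without a separate full forward through the head. | |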
| # Compute loss using Liger kernel | |
| loss_output = self.dpo_loss_fn( | |
| lm_head.weight, | |
| hidden_states, | |
| labels, | |
| bias=lm_head.bias if hasattr(lm_head, "bias") else None, | |
| ref_input=ref_hidden_states if not self.reference_free else None, | |
| ref_weight=ref_weight if not self.reference_free else None, | |
| ref_bias=ref_bias if not self.reference_free else None, | |
| ) | |
| ( | |
| loss, | |
| (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs), | |
| ) = loss_output | |
| output = { | |
| "loss": loss, | |
| "chosen_logps": chosen_logps, | |
| "rejected_logps": rejected_logps, | |
| "mean_chosen_logits": chosen_logits_mean, | |
| "mean_rejected_logits": rejected_logits_mean, | |
| "nll_loss": nll_loss, | |
| "chosen_rewards": aux_outputs[0], | |
| "rejected_rewards": aux_outputs[1], | |
| } | |
| if self.aux_loss_enabled: | |
| output["aux_loss"] = outputs.aux_loss | |
| return output | |
| def concatenated_forward( | |
| self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]], is_ref_model: bool = False | |
| ): | |
| """ | |
| Runs the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. | |
| We concatenate them to avoid doing two forward passes, which is faster for FSDP. | |
| Args: | |
| model: | |
| Model to run the forward pass on. | |
| batch: | |
| Batch of input data. | |
| is_ref_model: | |
| Whether this method is being called for the reference model. If `True`, length desensitization is not | |
| applied. | |
| """ | |
| num_examples = batch["prompt_input_ids"].shape[0] | |
| concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value) | |
| model_kwargs = {"use_cache": False} | |
| if self.aux_loss_enabled: | |
| model_kwargs["output_router_logits"] = True | |
| # Add the pixel values and attention masks for vision models | |
| if "pixel_values" in concatenated_batch: | |
| model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] | |
| if "pixel_attention_mask" in concatenated_batch: | |
| model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] | |
| if "image_sizes" in concatenated_batch: | |
| model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] | |
| prompt_input_ids = concatenated_batch["prompt_input_ids"] | |
| prompt_attention_mask = concatenated_batch["prompt_attention_mask"] | |
| completion_input_ids = concatenated_batch["completion_input_ids"] | |
| completion_attention_mask = concatenated_batch["completion_attention_mask"] | |
| if self.is_encoder_decoder: | |
| labels = completion_input_ids | |
| labels[completion_attention_mask == 0] = self.label_pad_token_id | |
| outputs = model( | |
| input_ids=prompt_input_ids, | |
| attention_mask=prompt_attention_mask, | |
| labels=labels, # we need the labels for the logits to be returned | |
| **model_kwargs, | |
| ) | |
| logits = outputs.logits | |
| loss_mask = completion_attention_mask.bool() | |
| else: | |
| # Concatenate the prompt and completion inputs | |
| input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1) | |
| attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1) | |
| # Mask the prompt but not the completion for the loss | |
| loss_mask = torch.cat( | |
| (torch.zeros_like(prompt_attention_mask), completion_attention_mask), | |
| dim=1, | |
| ) | |
| # Flush and truncate | |
| if self.max_length is not None and self.max_length < attention_mask.size(1): | |
| if self.truncation_mode == "keep_start": | |
| # Flush left to reduce the memory usage | |
| # [[0, 0, x, x, x, x], -> [[x, x, x, x], | |
| # [0, x, x, x, 0, 0]] [x, x, x, 0]] | |
| attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) | |
| attention_mask = attention_mask[:, : self.max_length] | |
| input_ids = input_ids[:, : self.max_length] | |
| loss_mask = loss_mask[:, : self.max_length] | |
| elif self.truncation_mode == "keep_end": | |
| # Flush right before truncating left, then flush left | |
| # [[0, 0, x, x, x, x], -> [[0, 0, x, x], | |
| # [0, x, x, x, 0, 0]] [0, x, x, x]] | |
| attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask) | |
| input_ids = input_ids[:, -self.max_length :] | |
| attention_mask = attention_mask[:, -self.max_length :] | |
| loss_mask = loss_mask[:, -self.max_length :] | |
| attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) | |
| else: | |
| raise ValueError( | |
| f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', " | |
| "'keep_start']." | |
| ) | |
| else: | |
| # Flush left to reduce the memory usage | |
| # [[0, 0, x, x, x, x], -> [[x, x, x, x], | |
| # [0, x, x, x, 0, 0]] [x, x, x, 0]] | |
| attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) | |
| if self.use_logits_to_keep: | |
| # Compute logits_to_keep based on loss_mask pattern: | |
| # [[0, 0, 0, x, x, x, x], | |
| # [0, 0, 0, x, x, x, 0]] | |
| # ^ start computing logits from here ([:, -(7-3+1):]) | |
| first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() | |
| logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 # +1 for the first label | |
| model_kwargs["logits_to_keep"] = logits_to_keep | |
| model_kwargs["output_hidden_states"] = True | |
| if self.padding_free: | |
| # Flatten the input_ids, position_ids, and loss_mask | |
| # input_ids = [[a, b, c, 0], -> input_ids = [[a, b, c, d, e, f, g]] | |
| # [d, e, f, g]] position_ids = [[0, 1, 2, 0, 1, 2, 3]] | |
| input_ids = input_ids[attention_mask.bool()].unsqueeze(0) | |
| loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0) | |
| position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1 | |
| model_kwargs["position_ids"] = position_ids | |
| else: | |
| model_kwargs["attention_mask"] = attention_mask | |
| outputs = model(input_ids, **model_kwargs) | |
| logits = outputs.logits | |
| # Offset the logits by one to align with the labels | |
| labels = torch.roll(input_ids, shifts=-1, dims=1) | |
| loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool() | |
| if self.use_logits_to_keep: | |
| # Align labels with logits | |
| # logits: -, -, [x2, x3, x4, x5, x6] | |
| # ^ --------- ^ after logits[:, :-1, :] | |
| # labels: [y0, y1, y2, y3, y4, y5, y6] | |
| # ^ --------- ^ with logits_to_keep=4, [:, -4:] | |
| # loss_mask: [0, 0, 0, 1, 1, 1, 1] | |
| labels = labels[:, -logits_to_keep:] | |
| loss_mask = loss_mask[:, -logits_to_keep:] | |
| if logits.shape[:2] != labels.shape[:2]: | |
| # For LLaVA-style models, the returned logits include the image tokens (placed before the text tokens) | |
| seq_len = labels.shape[1] | |
| logits = logits[:, -seq_len:] | |
| # Compute the log probabilities of the labels | |
| labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later | |
| per_token_logps = selective_log_softmax(logits, labels) | |
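| # selective_log_softmax gathers the log-probability assigned to each label token without | |
| # materializing the full log-softmax over the vocabulary. | |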
| per_token_logps[~loss_mask] = 0 | |
| per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1) | |
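| # Rolling the per-token log-probs right by one undoes the earlier label shift, so position t again | |
| # holds the log-prob of token t; position 0 is meaningless and is excluded by the [:, 1:] sum below. | |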
| if self.padding_free: | |
| # Unflatten the per_token_logps (shape: [1, sum_seq_len] -> [batch_size, seq_len]) | |
| batch_size, seq_len = attention_mask.shape | |
| per_token_logps_ = torch.zeros( | |
| batch_size, seq_len, device=outputs.logits.device, dtype=outputs.logits.dtype | |
| ) | |
| per_token_logps_[attention_mask.bool()] = per_token_logps | |
| per_token_logps = per_token_logps_ | |
| all_logps = per_token_logps[:, 1:].sum(-1) | |
| output = {} | |
| if self.use_weighting: | |
| with torch.no_grad(): | |
| # Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827 | |
| logprobs = F.log_softmax(logits, dim=-1) | |
| weights_adjustment_factor = torch.logsumexp(2 * logprobs, dim=-1) # same as sum(probs**2) in log space | |
| per_token_logps_adjusted = per_token_logps - weights_adjustment_factor | |
| all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1) | |
| chosen_weights = all_weights[:num_examples] | |
| rejected_weights = all_weights[num_examples:] | |
| output["policy_weights"] = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1) | |
| if self.args.rpo_alpha is not None: | |
| # Only use the chosen logits for the RPO loss | |
| chosen_logits = logits[:num_examples, :-1] if not self.is_encoder_decoder else logits[:num_examples] | |
| chosen_labels = labels[:num_examples, :-1] if not self.is_encoder_decoder else labels[:num_examples] | |
| # Compute the log probabilities of the labels | |
| output["nll_loss"] = F.cross_entropy( | |
| torch.flatten(chosen_logits, end_dim=1), torch.flatten(chosen_labels, end_dim=1), ignore_index=0 | |
| ) | |
| if self.loss_type == "ipo": | |
| all_logps = all_logps / loss_mask.sum(-1) | |
| if self.args.ld_alpha is not None and not is_ref_model: | |
| # Compute response lengths based on loss_mask | |
| completion_lengths = loss_mask.sum(dim=1) | |
| chosen_lengths = completion_lengths[:num_examples] | |
| rejected_lengths = completion_lengths[num_examples:] | |
| public_lengths = torch.min(chosen_lengths, rejected_lengths) # l_p in the paper | |
| public_lengths = torch.cat([public_lengths, public_lengths], dim=0) | |
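| # Tokens within the shared (public) length keep full weight, while tokens beyond it are scaled by | |
| # ld_alpha, desensitizing the preference loss to the raw length difference between responses. | |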
| seq_len = per_token_logps.size(1) | |
| position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps) | |
| ld_mask = position_ids < public_lengths.unsqueeze(1) | |
| mask = position_ids < completion_lengths.unsqueeze(1) | |
| front_mask = (ld_mask & mask).float() | |
| rear_mask = (~ld_mask & mask).float() | |
| front_logps = (per_token_logps * front_mask).sum(dim=1) | |
| rear_logps = (per_token_logps * rear_mask).sum(dim=1) | |
| all_logps = front_logps + self.args.ld_alpha * rear_logps | |
| output["chosen_logps"] = all_logps[:num_examples] | |
| output["rejected_logps"] = all_logps[num_examples:] | |
| # Compute the mean logits | |
| if self.padding_free: | |
| # position_ids contains a sequence of range identifiers (e.g., [[0, 1, 2, 0, 1, 2, 3, ...]]). | |
| # There are 2*num_examples ranges in total: the first half corresponds to the chosen tokens, | |
| # and the second half to the rejected tokens. | |
| # To find the start of the rejected tokens, we look for the (num_examples + 1)-th zero in position_ids. | |
| split_idx = (position_ids == 0).nonzero(as_tuple=True)[1][num_examples] | |
| mean_chosen_logits = logits[0, :split_idx][loss_mask[0, :split_idx]].mean() | |
| mean_rejected_logits = logits[0, split_idx:][loss_mask[0, split_idx:]].mean() | |
| else: | |
| mean_chosen_logits = logits[:num_examples][loss_mask[:num_examples]].mean() | |
| mean_rejected_logits = logits[num_examples:][loss_mask[num_examples:]].mean() | |
| output["mean_chosen_logits"] = mean_chosen_logits | |
| output["mean_rejected_logits"] = mean_rejected_logits | |
| if self.aux_loss_enabled: | |
| output["aux_loss"] = outputs.aux_loss | |
| return output | |
| def get_batch_loss_metrics( | |
| self, | |
| model, | |
| batch: dict[str, Union[list, torch.LongTensor]], | |
| train_eval: Literal["train", "eval"] = "train", | |
| ): | |
| """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" | |
| metrics = {} | |
| if self.args.use_liger_loss: | |
| model_output = self._compute_loss_liger(model, batch) | |
| losses = model_output["loss"] | |
| chosen_rewards = model_output["chosen_rewards"] | |
| rejected_rewards = model_output["rejected_rewards"] | |
| else: | |
| model_output = self.concatenated_forward(model, batch) | |
| # If ref_chosen_logps and ref_rejected_logps are already in the batch (precomputed), use them; otherwise compute them with the reference model | |
| if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch: | |
| ref_chosen_logps = batch["ref_chosen_logps"] | |
| ref_rejected_logps = batch["ref_rejected_logps"] | |
| else: | |
| ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch) | |
| losses, chosen_rewards, rejected_rewards = self.dpo_loss( | |
| model_output["chosen_logps"], model_output["rejected_logps"], ref_chosen_logps, ref_rejected_logps | |
| ) | |
| reward_accuracies = (chosen_rewards > rejected_rewards).float() | |
| if self.args.rpo_alpha is not None: | |
| losses = losses + self.args.rpo_alpha * model_output["nll_loss"] # RPO loss from V3 of the paper | |
| if self.use_weighting: | |
| losses = losses * model_output["policy_weights"] | |
| if self.aux_loss_enabled: | |
| losses = losses + self.aux_loss_coef * model_output["aux_loss"] | |
| prefix = "eval_" if train_eval == "eval" else "" | |
| metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item() | |
| metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item() | |
| metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item() | |
| metrics[f"{prefix}rewards/margins"] = ( | |
| self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item() | |
| ) | |
| metrics[f"{prefix}logps/chosen"] = ( | |
| self.accelerator.gather_for_metrics(model_output["chosen_logps"]).detach().mean().item() | |
| ) | |
| metrics[f"{prefix}logps/rejected"] = ( | |
| self.accelerator.gather_for_metrics(model_output["rejected_logps"]).detach().mean().item() | |
| ) | |
| metrics[f"{prefix}logits/chosen"] = ( | |
| self.accelerator.gather_for_metrics(model_output["mean_chosen_logits"]).detach().mean().item() | |
| ) | |
| metrics[f"{prefix}logits/rejected"] = ( | |
| self.accelerator.gather_for_metrics(model_output["mean_rejected_logits"]).detach().mean().item() | |
| ) | |
| if self.args.rpo_alpha is not None: | |
| metrics[f"{prefix}nll_loss"] = ( | |
| self.accelerator.gather_for_metrics(model_output["nll_loss"]).detach().mean().item() | |
| ) | |
| if self.aux_loss_enabled: | |
| metrics[f"{prefix}aux_loss"] = ( | |
| self.accelerator.gather_for_metrics(model_output["aux_loss"]).detach().mean().item() | |
| ) | |
| return losses.mean(), metrics | |
| def compute_loss( | |
| self, | |
| model: Union[PreTrainedModel, nn.Module], | |
| inputs: dict[str, Union[torch.Tensor, Any]], | |
| return_outputs=False, | |
| num_items_in_batch=None, | |
| ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: | |
| compute_loss_context_manager = ( | |
| autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() | |
| ) | |
| with compute_loss_context_manager: | |
| loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") | |
| # Make sure the loss is on the same device as the accumulated loss in the parent `Trainer` class: | |
| loss = loss.to(self.args.device) | |
| # force log the metrics | |
| self.store_metrics(metrics, train_eval="train") | |
| if return_outputs: | |
| return loss, metrics | |
| return loss | |
| def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]: | |
| """Generate samples from the model and reference model for the given batch of inputs.""" | |
| # If `generate_during_eval` is used with PEFT + bf16, we need to explicitly call generate under | |
| # the torch amp context manager, as some hidden states are silently cast to full precision. | |
| generate_context_manager = ( | |
| autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() | |
| ) | |
| with generate_context_manager: | |
| policy_output = model.generate( | |
| input_ids=batch["prompt_input_ids"], | |
| attention_mask=batch["prompt_attention_mask"], | |
| max_length=self.max_length, | |
| do_sample=True, | |
| pad_token_id=self.padding_value, | |
| ) | |
| # If ref_output is already in the batch, use it; otherwise generate with the reference model | |
| if "ref_output" in batch: | |
| ref_output = batch["ref_output"] | |
| else: | |
| if self.ref_model is None: | |
| with self.null_ref_context(): | |
| ref_output = self.model.generate( | |
| input_ids=batch["prompt_input_ids"], | |
| attention_mask=batch["prompt_attention_mask"], | |
| max_length=self.max_length, | |
| do_sample=True, | |
| pad_token_id=self.padding_value, | |
| ) | |
| else: | |
| ref_output = self.ref_model.generate( | |
| input_ids=batch["prompt_input_ids"], | |
| attention_mask=batch["prompt_attention_mask"], | |
| max_length=self.max_length, | |
| do_sample=True, | |
| pad_token_id=self.padding_value, | |
| ) | |
| policy_output = pad_to_length(policy_output, self.max_length, self.padding_value) | |
| policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) | |
| ref_output = pad_to_length(ref_output, self.max_length, self.padding_value) | |
| ref_output_decoded = self.processing_class.batch_decode(ref_output, skip_special_tokens=True) | |
| return policy_output_decoded, ref_output_decoded | |
| def prediction_step( | |
| self, | |
| model: Union[PreTrainedModel, nn.Module], | |
| inputs: dict[str, Union[torch.Tensor, Any]], | |
| prediction_loss_only: bool, | |
| ignore_keys: Optional[list[str]] = None, | |
| ): | |
| if ignore_keys is None: | |
| if hasattr(model, "config"): | |
| ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) | |
| else: | |
| ignore_keys = [] | |
| prediction_context_manager = ( | |
| autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() | |
| ) | |
| with torch.no_grad(), prediction_context_manager: | |
| loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") | |
| # force log the metrics | |
| self.store_metrics(metrics, train_eval="eval") | |
| if prediction_loss_only: | |
| return loss.detach(), None, None | |
| # logits for the chosen and rejected samples from model | |
| logits_dict = { | |
| "eval_logits/chosen": metrics["eval_logits/chosen"], | |
| "eval_logits/rejected": metrics["eval_logits/rejected"], | |
| } | |
| logits = [v for k, v in logits_dict.items() if k not in ignore_keys] | |
| logits = torch.tensor(logits, device=self.accelerator.device) | |
| labels = torch.zeros(logits.shape[0], device=self.accelerator.device) | |
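| # The returned "logits" are just the mean chosen/rejected logits and the labels are zero | |
| # placeholders so the parent `Trainer` evaluation loop has tensors to aggregate; the actual DPO | |
| # metrics are stored via `store_metrics` above. | |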
| return (loss.detach(), logits, labels) | |
| def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: | |
| for key, value in metrics.items(): | |
| self._stored_metrics[train_eval][key].append(value) | |
| def evaluation_loop( | |
| self, | |
| dataloader: DataLoader, | |
| description: str, | |
| prediction_loss_only: Optional[bool] = None, | |
| ignore_keys: Optional[list[str]] = None, | |
| metric_key_prefix: str = "eval", | |
| ) -> EvalLoopOutput: | |
| """ | |
| Overriding built-in evaluation loop to store metrics for each batch. | |
| Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. | |
| Works both with or without labels. | |
| """ | |
| # Sample and save to game log if requested (for one batch to save time) | |
| if self.generate_during_eval: | |
| # Generate random indices within the range of the total number of samples | |
| num_samples = len(dataloader.dataset) | |
| random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) | |
| # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader | |
| random_batch_dataset = dataloader.dataset.select(random_indices) | |
| random_batch = self.data_collator(random_batch_dataset) | |
| random_batch = self._prepare_inputs(random_batch) | |
| policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, random_batch) | |
| table = pd.DataFrame( | |
| columns=["Prompt", "Policy", "Ref Model"], | |
| data=[ | |
| [prompt, pol[len(prompt) :], ref[len(prompt) :]] | |
| for prompt, pol, ref in zip( | |
| random_batch_dataset["prompt"], policy_output_decoded, ref_output_decoded | |
| ) | |
| ], | |
| ) | |
| if "wandb" in self.args.report_to and self.accelerator.is_main_process: | |
| wandb.log({"game_log": wandb.Table(data=table)}) | |
| if "comet_ml" in self.args.report_to: | |
| log_table_to_comet_experiment( | |
| name="game_log.csv", | |
| table=table, | |
| ) | |
| # Base evaluation | |
| initial_output = super().evaluation_loop( | |
| dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix | |
| ) | |
| return initial_output | |
| def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: | |
| """ | |
| Log `logs` on the various objects watching training, including stored metrics. | |
| Args: | |
| logs (`dict[str, float]`): | |
| The values to log. | |
| start_time (`float` or `None`, *optional*, defaults to `None`): | |
| Start time of the training. | |
| """ | |
| # logs either has 'loss' or 'eval_loss' | |
| train_eval = "train" if "loss" in logs else "eval" | |
| # Add averaged stored metrics to logs | |
| for key, metrics in self._stored_metrics[train_eval].items(): | |
| logs[key] = torch.tensor(metrics).mean().item() | |
| del self._stored_metrics[train_eval] | |
| return super().log(logs, start_time) | |
| # Ensure the model card is saved along with the checkpoint | |
| def _save_checkpoint(self, model, trial): | |
| if self.args.hub_model_id is None: | |
| model_name = Path(self.args.output_dir).name | |
| else: | |
| model_name = self.args.hub_model_id.split("/")[-1] | |
| self.create_model_card(model_name=model_name) | |
| super()._save_checkpoint(model, trial) | |
| def create_model_card( | |
| self, | |
| model_name: Optional[str] = None, | |
| dataset_name: Optional[str] = None, | |
| tags: Union[str, list[str], None] = None, | |
| ): | |
| """ | |
| Creates a draft of a model card using the information available to the `Trainer`. | |
| Args: | |
| model_name (`str` or `None`, *optional*, defaults to `None`): | |
| Name of the model. | |
| dataset_name (`str` or `None`, *optional*, defaults to `None`): | |
| Name of the dataset used for training. | |
| tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): | |
| Tags to be associated with the model card. | |
| """ | |
| if not self.is_world_process_zero(): | |
| return | |
| if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): | |
| base_model = self.model.config._name_or_path | |
| else: | |
| base_model = None | |
| tags = tags or set() | |
| if isinstance(tags, str): | |
| tags = {tags} | |
| if hasattr(self.model.config, "unsloth_version"): | |
| tags.add("unsloth") | |
| tags.update(self._tag_names) | |
| citation = textwrap.dedent( | |
| """\ | |
| @inproceedings{rafailov2023direct, | |
| title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}}, | |
| author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn}, | |
| year = 2023, | |
| booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023}, | |
| url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html}, | |
| editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine}, | |
| }""" | |
| ) | |
| model_card = generate_model_card( | |
| base_model=base_model, | |
| model_name=model_name, | |
| hub_model_id=self.hub_model_id, | |
| dataset_name=dataset_name, | |
| tags=tags, | |
| wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, | |
| comet_url=get_comet_experiment_url(), | |
| trainer_name="DPO", | |
| trainer_citation=citation, | |
| paper_title="Direct Preference Optimization: Your Language Model is Secretly a Reward Model", | |
| paper_id="2305.18290", | |
| ) | |
| model_card.save(os.path.join(self.args.output_dir, "README.md")) | |