Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # @Time : 2023/5/6 4:29 p.m. | |
| # @Author : JianingWang | |
| # @File : reward_model.py | |
| from typing import Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoModel, AutoConfig | |
| from loss.rl_loss import LogSigLoss, LogExpLoss | |
| from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel | |
| from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel, GPT2Model | |
# RoBERTa for Reward Model
class RobertaForReward(RobertaPreTrainedModel):
    """RoBERTa encoder with a scalar value head, used as a pairwise reward model.

    The final hidden state of every token is projected to a scalar by
    ``value_head``. These per-token values are averaged (excluding the last
    position and, when a mask is given, padding positions) into one reward
    per sequence. When a rejected sequence is supplied, ``LogSigLoss`` is
    applied to the chosen/rejected rewards and returned as ``"loss"``.

    Args:
        config: A RoBERTa model configuration. ``config.hidden_size`` sets the
            input width of the value head.
    """

    def __init__(self, config) -> None:
        super().__init__(config)
        self.config = config
        self.roberta = RobertaModel(config)
        # RobertaConfig exposes ``hidden_size``; it has no GPT-2 style ``n_embd``.
        self.value_head = nn.Linear(config.hidden_size, 1)
        # Recommended post-construction hook (runs weight initialization).
        self.post_init()

    def _sequence_reward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Encode a batch and reduce the per-token values to one reward each.

        Args:
            input_ids: Token ids, presumably shaped (B, L) — TODO confirm with callers.
            attention_mask: Optional mask aligned with ``input_ids``. Non-zero
                entries mark real (non-padding) tokens.

        Returns:
            A tensor of shape (B,) holding one reward per sequence.
        """
        outputs = self.roberta(input_ids, attention_mask=attention_mask)
        # Index 0 is the last hidden state for both ModelOutput and tuple
        # returns (config.use_return_dict=False), unlike a string key.
        hidden_states = outputs[0]
        # (B, L, 1) -> (B, L - 1). The final position is excluded, as in the
        # original implementation.
        values = self.value_head(hidden_states)[:, :-1].squeeze(-1)
        if attention_mask is None:
            return values.mean(dim=1)
        # Masked mean: padding positions must not contribute to the reward.
        # The clamp guards rows that have no unmasked positions left.
        mask = attention_mask[:, :-1].to(values.dtype)
        return (values * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

    def forward(
        self,
        chosen_sequences: torch.LongTensor,
        chosen_attention_mask: Optional[torch.Tensor] = None,
        rejected_sequences: Optional[torch.LongTensor] = None,
        rejected_attention_mask: Optional[torch.Tensor] = None,
    ) -> dict:
        """Score the chosen sequences and, optionally, the rejected ones.

        Args:
            chosen_sequences: Token ids of the preferred responses.
            chosen_attention_mask: Optional padding mask for ``chosen_sequences``.
            rejected_sequences: Optional token ids of the dispreferred responses.
            rejected_attention_mask: Optional padding mask for ``rejected_sequences``.

        Returns:
            A dict with ``"chosen_values"`` of shape (B,). When
            ``rejected_sequences`` is given, it also holds ``"rejected_values"``
            of shape (B,) and ``"loss"`` from ``LogSigLoss`` (presumably the
            pairwise -log(sigmoid(chosen - rejected)) objective; see loss.rl_loss).
        """
        chosen_values = self._sequence_reward(chosen_sequences, chosen_attention_mask)
        return_dict = {"chosen_values": chosen_values}
        # With a rejected sequence, also score it and compute the pairwise loss.
        if rejected_sequences is not None:
            rejected_values = self._sequence_reward(rejected_sequences, rejected_attention_mask)
            return_dict["rejected_values"] = rejected_values
            loss_fn = LogSigLoss()
            return_dict["loss"] = loss_fn(chosen_values, rejected_values)
        return return_dict
# GPT2 for Reward Model
class GPT2ForReward(GPT2PreTrainedModel):
    """GPT-2 decoder with a scalar value head, used as a pairwise reward model.

    The final hidden state of every token is projected to a scalar by
    ``value_head``. These per-token values are averaged (excluding the last
    position and, when a mask is given, padding positions) into one reward
    per sequence. When a rejected sequence is supplied, ``LogSigLoss`` is
    applied to the chosen/rejected rewards and returned as ``"loss"``.

    Args:
        config: A GPT-2 model configuration. ``config.n_embd`` sets the input
            width of the value head.
    """

    # Checkpoint keys that may be absent without triggering a load warning.
    _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]

    def __init__(self, config) -> None:
        super().__init__(config)
        self.config = config
        self.transformer = GPT2Model(config)
        self.value_head = nn.Linear(config.n_embd, 1)
        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.post_init()

    def _sequence_reward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Encode a batch and reduce the per-token values to one reward each.

        Args:
            input_ids: Token ids, presumably shaped (B, L) — TODO confirm with callers.
            attention_mask: Optional mask aligned with ``input_ids``. Non-zero
                entries mark real (non-padding) tokens; works for left or right padding.

        Returns:
            A tensor of shape (B,) holding one reward per sequence.
        """
        outputs = self.transformer(input_ids, attention_mask=attention_mask)
        # Index 0 is the last hidden state for both ModelOutput and tuple
        # returns (config.use_return_dict=False), unlike a string key.
        hidden_states = outputs[0]
        # (B, L, 1) -> (B, L - 1). The final position is excluded, as in the
        # original implementation.
        values = self.value_head(hidden_states)[:, :-1].squeeze(-1)
        if attention_mask is None:
            return values.mean(dim=1)
        # Masked mean: padding positions must not contribute to the reward.
        # The clamp guards rows that have no unmasked positions left.
        mask = attention_mask[:, :-1].to(values.dtype)
        return (values * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

    def forward(
        self,
        chosen_sequences: torch.LongTensor,
        chosen_attention_mask: Optional[torch.Tensor] = None,
        rejected_sequences: Optional[torch.LongTensor] = None,
        rejected_attention_mask: Optional[torch.Tensor] = None,
    ) -> dict:
        """Score the chosen sequences and, optionally, the rejected ones.

        Args:
            chosen_sequences: Token ids of the preferred responses.
            chosen_attention_mask: Optional padding mask for ``chosen_sequences``.
            rejected_sequences: Optional token ids of the dispreferred responses.
            rejected_attention_mask: Optional padding mask for ``rejected_sequences``.

        Returns:
            A dict with ``"chosen_values"`` of shape (B,). When
            ``rejected_sequences`` is given, it also holds ``"rejected_values"``
            of shape (B,) and ``"loss"`` from ``LogSigLoss`` (presumably the
            pairwise -log(sigmoid(chosen - rejected)) objective; see loss.rl_loss).
        """
        chosen_values = self._sequence_reward(chosen_sequences, chosen_attention_mask)
        return_dict = {"chosen_values": chosen_values}
        # With a rejected sequence, also score it and compute the pairwise loss.
        if rejected_sequences is not None:
            rejected_values = self._sequence_reward(rejected_sequences, rejected_attention_mask)
            return_dict["rejected_values"] = rejected_values
            loss_fn = LogSigLoss()
            return_dict["loss"] = loss_fn(chosen_values, rejected_values)
        return return_dict

    @staticmethod
    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
        """Re-order the `past_key_values` cache to follow the selected beams.

        This is used when [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called, so that `past_key_values`
        match the correct beam_idx at every generation step. It is a
        staticmethod because the signature takes no ``self``. Without the
        decorator, an instance call would bind the model to ``past``.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past
        )