import torch
from torch import nn
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer


class GPTRewardModel(nn.Module):
    def __init__(self, config, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        model = AutoModelForCausalLM.from_config(config)
        # Resize the embeddings in case extra tokens were added to the
        # tokenizer, and reuse EOS as the pad/end token: GPT-style
        # tokenizers ship without a pad token by default.
        model.resize_token_embeddings(len(tokenizer))
        model.config.end_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = model.config.eos_token_id
        # Read the pad id after the fallback above so it is never None
        self.PAD_ID = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else model.config.pad_token_id

        self.config = model.config
        # `gpt-neo(x)` models use the `hidden_size` attribute name instead of `n_embd`
        self.config.n_embd = getattr(self.config, "hidden_size", self.config.n_embd)
        self.transformer = model.transformer
        self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        # the remaining arguments are unused; kept for interface compatibility
        mc_token_ids=None,
        labels=None,
        return_dict=False,
        output_attentions=False,
        output_hidden_states=False,
    ):
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        hidden_states = transformer_outputs[0]

        # One scalar reward per token position
        rewards = self.v_head(hidden_states).squeeze(-1)
        chosen_end_scores = []
        rejected_end_scores = []

        # Split the inputs and rewards into two halves: the first `bs` rows
        # are the chosen responses, the last `bs` rows the rejected ones
        assert len(input_ids.shape) == 2
        bs = input_ids.shape[0] // 2
        chosen = input_ids[:bs]
        rejected = input_ids[bs:]
        chosen_rewards = rewards[:bs]
        rejected_rewards = rewards[bs:]

        # Compute pairwise loss. Only backprop on the last value before padding
        loss = 0
        inference = False
        for i in range(bs):
            # Identical halves signal an inference-only batch (the input was
            # duplicated), so record the end score and skip the loss
            if torch.all(torch.eq(chosen[i], rejected[i])).item():
                c_inds = (chosen[i] == self.PAD_ID).nonzero()
                c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen.shape[1]
                chosen_end_scores.append(chosen_rewards[i, c_ind - 1])
                inference = True
                continue

            # Check if there is any padding; otherwise take the full sequence length
            c_inds = (chosen[i] == self.PAD_ID).nonzero()
            c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen.shape[1]
            r_inds = (rejected[i] == self.PAD_ID).nonzero()
            r_ind = r_inds[0].item() if len(r_inds) > 0 else rejected.shape[1]
            end_ind = max(c_ind, r_ind)

            # Retrieve the first index where the trajectories diverge
            divergence_ind = (chosen[i] != rejected[i]).nonzero()[0].item()
            assert divergence_ind > 0

            # Index into the rewards over the divergent span
            c_truncated_reward = chosen_rewards[i][divergence_ind:end_ind]
            r_truncated_reward = rejected_rewards[i][divergence_ind:end_ind]

            # The reward at the last pre-padding token is the sequence's end score
            chosen_end_scores.append(c_truncated_reward[-1])
            rejected_end_scores.append(r_truncated_reward[-1])

            # Pairwise ranking loss: -log sigmoid(r_chosen - r_rejected),
            # averaged over the divergent tokens (logsigmoid is the
            # numerically stable form of log(sigmoid(...)))
            loss += -F.logsigmoid(c_truncated_reward - r_truncated_reward).mean()
        loss = loss / bs

        if inference:
            chosen_end_scores = torch.stack(chosen_end_scores)
            return {"chosen_end_scores": chosen_end_scores}

        chosen_end_scores = torch.stack(chosen_end_scores)
        rejected_end_scores = torch.stack(rejected_end_scores)

        return {
            "loss": loss,
            "chosen_end_scores": chosen_end_scores,
            "rejected_end_scores": rejected_end_scores,
        }
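

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of building the reward model from a GPT-2 config and
# scoring a duplicated batch in inference mode. The "gpt2" checkpoint name
# is an assumption; any causal LM exposing a `transformer` attribute works.
if __name__ == "__main__":
    from transformers import AutoConfig

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
    config = AutoConfig.from_pretrained("gpt2")
    reward_model = GPTRewardModel(config, tokenizer)
    reward_model.eval()

    # Inference expects the batch duplicated so both halves are identical
    prompts = ["Question: How do I bake bread?\nAnswer: Mix, knead, rest, bake."] * 2
    enc = tokenizer(prompts, return_tensors="pt", padding=True)
    with torch.no_grad():
        out = reward_model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
    print(out["chosen_end_scores"])  # one scalar score per prompt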