import re
from typing import Dict, Sequence, Tuple

import numpy as np
import torch


def postprocess_classification_generation(predictions: str) -> str:
    """Truncate a generated string at the first "Prompt"/"Completion" marker."""
    return re.split("Prompt|Completion", predictions, maxsplit=1)[0]
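
# Illustrative usage sketch (hypothetical input string): the generation is cut at
# the first "Prompt" or "Completion" marker, keeping only the text before it, e.g.
#   postprocess_classification_generation("tabby cat Prompt: What is ...")
#   -> "tabby cat "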


def compute_classification_accuracy(predictions: Sequence[Dict[str, str]]) -> float:
    """Compute the accuracy of a sequence of predictions."""

    def _preprocess_fn(s):
        """Function to preprocess both targets and predictions."""
        return s.lower()

    is_correct = [
        _preprocess_fn(x["prediction"]) == _preprocess_fn(x["class_label"])
        for x in predictions
    ]

    return np.mean(is_correct).item()
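
# Illustrative usage sketch (made-up predictions): each entry carries the model's
# "prediction" and the ground-truth "class_label"; matching is case-insensitive.
#   compute_classification_accuracy(
#       [
#           {"prediction": "Tabby Cat", "class_label": "tabby cat"},
#           {"prediction": "beagle", "class_label": "boxer"},
#       ]
#   )
#   -> 0.5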


def compute_shifted_logits_and_labels(
    logits: torch.Tensor, encodings, tokenizer, eoc_token_id
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Helper function to compute shifted logits and labels.

    This allows for straightforward computation of the loss on shift_logits
    and shift_labels: after shifting, the nth element of shift_logits is the
    model's prediction for the nth element of shift_labels (i.e. the (n+1)th
    element of the original labels).

    Elements in shift_labels that correspond to inputs are masked with values
    of -100 (the default ignore_index of Hugging Face / PyTorch cross-entropy,
    so the loss is only computed on non-negative label IDs).

    Returns: tuple containing two elements:
        shift_logits: a float Tensor of shape [batch_size, seq_len - 1, vocab_size].
        shift_labels: an integer Tensor of shape [batch_size, seq_len - 1].
    """
    labels = encodings["input_ids"].clone()

    # Convert padding and EOC tokens to -100 so they are ignored in the loss.
    labels[labels == tokenizer.pad_token_id] = -100
    labels[labels == eoc_token_id] = -100

    # Convert all tokens in the prefix, up to and including the eos separator,
    # to -100 so they are ignored in the loss.
    for idx in range(len(labels)):
        # Find the location of the last token of the prefix *from the right*,
        # since the first non-padding token of the sequence will also be
        # eos_token (because bos_token and eos_token are the same for
        # the tokenizer).
        end_of_prefix = -labels[idx].tolist()[::-1].index(tokenizer.eos_token_id) - 1
        labels[idx, : end_of_prefix + 1] = -100

    # Shift so that tokens < n predict n. shift_labels has shape
    # [batch_size, seq_len - 1]; shift_logits has shape
    # [batch_size, seq_len - 1, vocab_size].
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    return shift_logits, shift_labels
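
# Worked sketch of the masking on a toy sequence (all token IDs below are made up
# for illustration; assume eos_token_id == 2, pad_token_id == 0, eoc_token_id == 9):
#   input_ids:            [2, 11, 12, 2, 21, 22, 9, 0]
#   labels after masking: [-100, -100, -100, -100, 21, 22, -100, -100]
#     (pad/EOC tokens and everything up to the last eos separator become -100)
#   shift_labels = labels[1:] -> [-100, -100, -100, 21, 22, -100, -100]
#   shift_logits = logits[:-1], so shift_logits[n] is scored against shift_labels[n].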


def compute_per_sample_probs(
    encodings, tokenizer, logits: torch.Tensor, eoc_token_id
) -> torch.Tensor:
    """Helper function to compute the per-sample probability of the target sequence.

    Assumes the <eos> token is used to separate inputs from targets in the
    prompt text.
    """
    shift_logits, shift_labels = compute_shifted_logits_and_labels(
        logits, encodings, tokenizer, eoc_token_id
    )

    # Tuple of tensors for unmasked label tokens. The first element of the
    # tuple contains the batch indices; the second element contains the
    # sequence indices.
    unmasked_indices = torch.nonzero(shift_labels != -100, as_tuple=True)
    # Tensor where the i^th element is the token_id corresponding to the i^th
    # element of unmasked_indices.
    unmasked_token_ids = shift_labels[unmasked_indices]

    # 2D tensor whose rows are [batch_idx, sequence_position, token_id] for
    # unmasked tokens.
    target_idxs = torch.column_stack([*unmasked_indices, unmasked_token_ids])
    target_idxs = target_idxs.to(shift_logits.device)

    # Sanity check that every element in the batch has at least one unmasked
    # target token. minlength ensures trailing batch elements with no unmasked
    # tokens are also counted.
    assert torch.all(
        torch.bincount(target_idxs[:, 0], minlength=len(shift_labels)) != 0
    ), "At least one element in batch has no unmasked target tokens."

    # Renormalize over tokens to make sure they are proper probabilities via
    # softmax over the token (vocabulary) dimension.
    shift_probs = torch.nn.functional.softmax(shift_logits, dim=-1)

    # Compute the probability of the target sequence (as the product of the
    # probabilities of the individual tokens in the sequence).
    target_probs = torch.ones(len(shift_labels), device=shift_logits.device)
    for i, j, k in target_idxs:
        target_probs[i] *= shift_probs[i, j, k]

    return target_probs
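
# Usage sketch (hypothetical names such as `outputs` and `class_names` are
# assumptions, not part of this module): for classification, the same prompt is
# scored once per candidate class label, and the class with the highest
# per-sample probability is selected, e.g.
#   probs = compute_per_sample_probs(encodings, tokenizer, outputs.logits, eoc_token_id)
#   predicted_class = class_names[probs.argmax().item()]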


def compute_per_sample_loss(encodings, tokenizer, logits, eoc_token_id) -> torch.Tensor:
    """Helper function to compute the per-sample classification loss.

    Assumes the <eos> token is used to separate inputs from targets in the
    prompt text.
    """
    shift_logits, shift_labels = compute_shifted_logits_and_labels(
        logits, encodings, tokenizer, eoc_token_id
    )
    device = shift_logits.device

    # Loss is computed token-wise, on tensors of shape
    # [batch_size * (seq_len - 1), vocab_size], and returns a loss tensor of
    # shape [batch_size * (seq_len - 1)]. Most of the tokens will be masked
    # in this computation.
    loss = torch.nn.functional.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1).to(device),
        reduction="none",
    )

    # Reshape to [batch_size, seq_len - 1].
    loss = loss.view(shift_logits.size(0), shift_logits.size(1)).cpu()

    # loss_mask is 1 for tokens we want included in the loss, and 0 for tokens
    # that should be ignored in the loss.
    loss_mask = (shift_labels != -100).int().cpu()
    loss *= loss_mask

    # Compute per-element loss: sum the loss over all (unmasked) tokens and
    # divide by the number of unmasked tokens to obtain a tensor of shape
    # [batch_size,]. loss_mask is reused here so the division happens on the
    # same (CPU) device as loss.
    loss = loss.sum(dim=1) / loss_mask.sum(dim=1).float()

    return loss
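
# Minimal end-to-end sketch (hypothetical `model`, `tokenizer`, `prompts`, and
# `class_names`; an illustration of the intended call pattern, not this module's API):
#   encodings = tokenizer(prompts, padding=True, return_tensors="pt")
#   with torch.no_grad():
#       logits = model(**encodings).logits  # any causal LM returning [batch, seq, vocab]
#   losses = compute_per_sample_loss(encodings, tokenizer, logits, eoc_token_id)
#   predicted_class = class_names[losses.argmin().item()]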