import logging
from typing import Tuple

import torch
from outlines.samplers import MultinomialSampler

logger = logging.getLogger(__name__)
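
# PenalizedMultinomialSampler caps how many times the tokens of a registered
# group may be sampled in one generation: once a group reaches its budget, the
# logits of its tokens are masked to -inf. The per-group counters are reset
# whenever every sequence weight is zero, i.e. when a new generation starts.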
class PenalizedMultinomialSampler(MultinomialSampler):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # One entry per token group: the group's token ids, its repeat budget,
        # and the number of times one of its tokens has been sampled so far.
        self.penalized_tokens_group: list[torch.IntTensor] = []
        self.max_repeats_per_token_group: list[int] = []
        self.repeats_per_token_group: list[int] = []
        # Maps a token id to the indices of the groups it belongs to.
        self.token_id_to_tokens_groups: list[list[int]] = []

    def set_max_repeats(self, token_ids: list[int], max_repeats: int) -> None:
        # Register a group of token ids that may be sampled at most `max_repeats` times.
        max_token_id = max(token_ids)
        if max_token_id >= len(self.token_id_to_tokens_groups):
            # Grow the lookup table so every token id in the group has a slot.
            self.token_id_to_tokens_groups += [
                [] for _ in range(len(self.token_id_to_tokens_groups), max_token_id + 1)
            ]
        for token_id in token_ids:
            self.token_id_to_tokens_groups[token_id].append(len(self.penalized_tokens_group))
        self.penalized_tokens_group.append(torch.tensor(token_ids, dtype=torch.int32))
        self.max_repeats_per_token_group.append(max_repeats)
        self.repeats_per_token_group.append(0)

    def __call__(
        self,
        next_token_logits: torch.DoubleTensor,
        sequence_weights: torch.DoubleTensor,
        rng: torch.Generator,
    ) -> Tuple[torch.DoubleTensor, torch.DoubleTensor, torch.DoubleTensor]:
        """Call the multinomial sampler.

        Parameters
        ----------
        next_token_logits
            A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
            probability distribution of the next token over the vocabulary.
        sequence_weights
            A tensor of shape ``(n_seqs,)`` that represents the cumulative
            weight of each sequence.
        rng
            A random number generator.

        Returns
        -------
        A tuple with an array that contains the ids of the sampled tokens of
        shape ``(n_seqs, 1)``, an array that contains the ancestors of each
        sampled id of shape ``(n_seqs,)`` and an array that contains the updated
        cumulative weights of each sequence of shape ``(n_seqs,)``.
        """
        if sequence_weights.min() == sequence_weights.max() == 0:
            # All sequence weights are zero: a new generation is starting,
            # so reset the per-group repeat counters.
            self.repeats_per_token_group = [0] * len(self.repeats_per_token_group)
        else:
            for penalized_tokens_group, max_repeats_per_token_group, repeats_per_token_group in zip(
                self.penalized_tokens_group,
                self.max_repeats_per_token_group,
                self.repeats_per_token_group,
            ):
                if repeats_per_token_group >= max_repeats_per_token_group:
                    # This group has exhausted its budget: mask its tokens out.
                    penalty = torch.zeros_like(next_token_logits)
                    penalty[:, penalized_tokens_group] = -torch.inf
                    next_token_logits = next_token_logits + penalty
        next_token_ids, ancestors, weights = super().__call__(
            next_token_logits=next_token_logits,
            sequence_weights=sequence_weights,
            rng=rng,
        )
        # Update the repeat counters for the groups of the sampled tokens.
        for next_token_id in next_token_ids.flatten().tolist():
            if next_token_id < len(self.token_id_to_tokens_groups):
                for token_group in self.token_id_to_tokens_groups[next_token_id]:
                    self.repeats_per_token_group[token_group] += 1
        return next_token_ids, ancestors, weights
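
# --- Usage sketch (an assumption, not part of the original file) ---
# The commented snippet below assumes the outlines v0.x generate API
# (`outlines.models.transformers`, `outlines.generate.text`); the model name
# is a placeholder and the penalized string is only an example.
#
#   import outlines
#   from transformers import AutoTokenizer
#
#   model_name = "HuggingFaceTB/SmolLM-135M-Instruct"  # placeholder model
#   model = outlines.models.transformers(model_name)
#   tokenizer = AutoTokenizer.from_pretrained(model_name)
#
#   sampler = PenalizedMultinomialSampler()
#   # Let the token ids that encode " maybe" be sampled at most twice.
#   maybe_ids = tokenizer.encode(" maybe", add_special_tokens=False)
#   sampler.set_max_repeats(maybe_ids, max_repeats=2)
#
#   generator = outlines.generate.text(model, sampler=sampler)
#   print(generator("Is structured generation useful? Answer:", max_tokens=32))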