Spaces:
Running
Running
| import torch | |
| from transformers import LogitsProcessor | |
| from transformers.generation.logits_process import _calc_banned_ngram_tokens | |
| from typing import List, Set | |
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    """Forbid next tokens that would repeat an n-gram seen in a recent window.

    Compared with a plain "no repeat n-gram" rule, this variant only looks for
    earlier n-grams that start within the last ``window_size`` positions of the
    sequence, and it never bans tokens listed in ``whitelist_token_ids``.
    ``__call__`` receives the token history of a single sequence as a flat list
    of ids. NOTE(review): this presumably targets a per-request logits-processor
    interface (vLLM-style) rather than the batched ``transformers`` one; confirm
    against the caller.

    Args:
        ngram_size: Length of the n-grams that may not repeat. Must be > 0.
        window_size: Number of most recent positions in which earlier n-grams
            are searched for. Must be > 0.
        whitelist_token_ids: Token ids that are never banned. Any iterable of
            ints is accepted; ``None`` or empty means no whitelist.

    Raises:
        ValueError: If ``ngram_size`` or ``window_size`` is not a strictly
            positive integer.
    """

    def __init__(self, ngram_size: int, window_size: int = 100, whitelist_token_ids: set = None):
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(
                f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}"
            )
        if not isinstance(window_size, int) or window_size <= 0:
            raise ValueError(
                f"`window_size` has to be a strictly positive integer, but is {window_size}"
            )
        self.ngram_size = ngram_size
        self.window_size = window_size
        # Normalize to a private set: the set difference in __call__ raises
        # TypeError for lists/tuples, and copying protects against the caller
        # mutating the collection afterwards.
        self.whitelist_token_ids = set(whitelist_token_ids) if whitelist_token_ids else set()

    def __call__(self, input_ids: List[int], scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` with every repeat-forming token set to ``-inf``.

        A token ``t`` is banned when ``(last n-1 tokens) + (t,)`` already occurs
        as an n-gram whose first token lies in the last ``window_size``
        positions, unless ``t`` is whitelisted.

        Args:
            input_ids: Token ids generated so far for one sequence.
            scores: Next-token logits indexed by token id along the last axis
                (presumably a 1-D vocab-sized tensor; confirm with caller).

        Returns:
            ``scores`` unchanged when nothing is banned, otherwise a masked
            copy. The input tensor is never modified in place.
        """
        seq_len = len(input_ids)
        n = self.ngram_size
        if seq_len < n:
            return scores

        # Last n-1 tokens. Use an explicit start index instead of ``-(n - 1)``:
        # for n == 1 that slice is ``[-0:]`` (the whole sequence), which would
        # never match the empty prefix and silently disable the processor.
        current_prefix = tuple(input_ids[seq_len - (n - 1):])

        # Candidate n-grams start in [search_start, search_end); the last one
        # ends exactly at the final token, so every candidate is complete.
        search_start = max(0, seq_len - self.window_size)
        search_end = seq_len - n + 1
        banned_tokens = {
            input_ids[i + n - 1]
            for i in range(search_start, search_end)
            if tuple(input_ids[i:i + n - 1]) == current_prefix
        }
        banned_tokens -= self.whitelist_token_ids
        if not banned_tokens:
            return scores

        # Clone so the caller's tensor is left untouched, then mask all banned
        # ids in one indexed assignment along the vocabulary (last) axis.
        scores = scores.clone()
        banned_index = torch.tensor(
            sorted(banned_tokens), dtype=torch.long, device=scores.device
        )
        scores[..., banned_index] = float("-inf")
        return scores