Spaces:
Runtime error
| import torch | |
| from typing import Optional, Tuple, Union, List, Callable | |
| from transformers.generation.logits_process import LogitsProcessor | |
| from transformers.generation.beam_search import BeamSearchScorer | |
| from transformers.deepspeed import is_deepspeed_zero3_enabled | |
| from transformers.generation.utils import ( | |
| LogitsProcessorList, | |
| StoppingCriteriaList, | |
| GenerationConfig, | |
| GenerationMixin, | |
| ) | |
| from transformers import LlamaForCausalLM | |
| import warnings | |
| import torch.distributed as dist | |
| from torch import nn | |
| import copy | |
class SteamGenerationMixin(LlamaForCausalLM):
    """``LlamaForCausalLM`` with streaming (generator-based) text generation.

    ``stream_generate`` performs the configuration steps of
    ``GenerationMixin.generate`` and then hands off to one of the decoding
    generators below (greedy search, multinomial sampling, beam search, beam
    sampling). Each generator yields the whole ``input_ids`` tensor (prompt
    plus the tokens produced so far) after every decoding step, so callers can
    display partial output while generation is still running.
    """

    # TODO: group_beam_search

    @staticmethod
    def _stream_eos_token_ids(eos_token_id):
        """Wrap an int ``eos_token_id`` in a list; lists and ``None`` pass through."""
        if isinstance(eos_token_id, int):
            return [eos_token_id]
        return eos_token_id

    @staticmethod
    def _stream_all_peers_finished(this_peer_finished, device):
        """Return True once every rank reports it is finished (``synced_gpus`` only).

        Each rank contributes 0.0 when it is finished and 1.0 otherwise, so the
        summed value is 0.0 only when all ranks are done. ``all_reduce`` is a
        collective call, so every rank must reach it on every step.
        """
        this_peer_finished_flag = torch.tensor(
            0.0 if this_peer_finished else 1.0
        ).to(device)
        dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
        return this_peer_finished_flag.item() == 0.0

    # support for streamly generation
    def stream_generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], List[int]]
        ] = None,
        **kwargs,
    ):
        """Prepare generation settings and return a streaming decoding generator.

        Args:
            input_ids: prompt token ids; ``shape[-1]`` is read as the prompt
                length and ``shape[0]`` as the batch size.
            generation_config: config to use; defaults to
                ``self.generation_config``. It is deep-copied, so the caller's
                object is never mutated.
            logits_processor: extra processors merged with those derived from
                ``generation_config``.
            stopping_criteria: extra criteria merged with those derived from
                ``generation_config``.
            prefix_allowed_tokens_fn: forwarded to ``_get_logits_processor``.
            **kwargs: generation-config overrides. Keys that are not config
                fields are forwarded to the model on every step.

        Returns:
            The generator produced by ``greedy_search``, ``stream_sample``,
            ``beam_search`` or ``beam_sample``, depending on the config.

        Raises:
            ValueError: if ``input_ids`` is None, or beam search is requested
                with more return sequences than beams.
            NotImplementedError: for group beam search, constrained beam search
                and contrastive search.
        """
        if input_ids is None:
            raise ValueError("`input_ids` must be provided to `stream_generate`.")

        # NOTE: `self._reorder_cache` is intentionally NOT rebound to
        # `self.base_model._reorder_cache`. For this LlamaForCausalLM subclass,
        # `base_model` resolves to the wrapped backbone (presumably LlamaModel,
        # whose `_reorder_cache` is only the NotImplementedError placeholder —
        # verify against the installed transformers). The rebinding broke beam
        # search; the implementation inherited from LlamaForCausalLM is used.

        # Under DeepSpeed ZeRO-3 with several ranks, the forward call must keep
        # running until all ranks complete their sequences.
        synced_gpus = bool(
            is_deepspeed_zero3_enabled()
            and dist.is_available()
            and dist.is_initialized()
            and dist.get_world_size() > 1
        )

        # Prompt-learning adapters expose `peft_config.num_virtual_tokens`; the
        # mask is widened to cover those virtual tokens. Models without that
        # attribute (such as a plain Llama model) keep the caller's mask as-is.
        num_virtual_tokens = getattr(
            getattr(self, "peft_config", None), "num_virtual_tokens", None
        )
        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None and num_virtual_tokens:
            prefix_attention_mask = torch.ones(
                input_ids.shape[0],
                num_virtual_tokens,
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )
            kwargs["attention_mask"] = torch.cat(
                (prefix_attention_mask, attention_mask), dim=1
            )
        if kwargs.get("position_ids", None) is not None:
            warnings.warn(
                "Position ids are not supported for parameter efficient tuning. Ignoring position ids."
            )
            kwargs["position_ids"] = None
        if kwargs.get("token_type_ids", None) is not None:
            warnings.warn(
                "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
            )
            kwargs["token_type_ids"] = None

        input_ids_seq_length = input_ids.shape[-1]
        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        # `update` consumes config fields and returns the leftover kwargs,
        # which become the model kwargs for every forward pass.
        model_kwargs = generation_config.update(**kwargs)

        # Same fallback as `GenerationMixin.generate`: without it the greedy and
        # sampling loops raise when EOS is set but no pad token is configured.
        eos_token_id = self._stream_eos_token_ids(generation_config.eos_token_id)
        if generation_config.pad_token_id is None and eos_token_id:
            warnings.warn(
                f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id[0]} for open-end generation."
            )
            generation_config.pad_token_id = eos_token_id[0]

        has_default_max_length = (
            kwargs.get("max_length") is None
            and generation_config.max_length is not None
        )
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = (
                generation_config.max_new_tokens + input_ids_seq_length
            )
        if generation_config.min_new_tokens is not None:
            generation_config.min_length = (
                generation_config.min_new_tokens + input_ids_seq_length
            )
        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = (
                "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            )
            warnings.warn(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`.",
                UserWarning,
            )

        # Determine the generation mode.
        num_beams = generation_config.num_beams
        do_sample = generation_config.do_sample
        is_constraint_gen_mode = (
            generation_config.constraints is not None
            or generation_config.force_words_ids is not None
        )
        is_contrastive_search_gen_mode = (
            generation_config.top_k is not None
            and generation_config.top_k > 1
            and do_sample is False
            and generation_config.penalty_alpha is not None
            and generation_config.penalty_alpha > 0
        )
        # Greedy / sample / beam / beam-sample all need a single beam group and
        # no constraint or contrastive-search settings.
        is_plain_gen_mode = (
            generation_config.num_beam_groups == 1
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        is_greedy_gen_mode = is_plain_gen_mode and num_beams == 1 and do_sample is False
        is_sample_gen_mode = is_plain_gen_mode and num_beams == 1 and do_sample is True
        is_beam_gen_mode = is_plain_gen_mode and num_beams > 1 and do_sample is False
        is_beam_sample_gen_mode = is_plain_gen_mode and num_beams > 1 and do_sample is True
        if not (
            is_greedy_gen_mode
            or is_sample_gen_mode
            or is_beam_gen_mode
            or is_beam_sample_gen_mode
        ):
            raise NotImplementedError(
                "stream_generate does not implement this generation mode "
                "(group beam search, constrained beam search and contrastive search are unsupported)."
            )
        if is_beam_gen_mode and generation_config.num_return_sequences > num_beams:
            raise ValueError(
                "`num_return_sequences` has to be smaller or equal to `num_beams`."
            )

        # Build logits processors / stopping criteria / warpers from the config.
        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=(
                logits_processor
                if logits_processor is not None
                else LogitsProcessorList()
            ),
        )
        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config,
            stopping_criteria=(
                stopping_criteria
                if stopping_criteria is not None
                else StoppingCriteriaList()
            ),
        )
        logits_warper = self._get_logits_warper(generation_config)

        if is_greedy_gen_mode:
            return self.greedy_search(
                input_ids,
                logits_processor,
                stopping_criteria,
                generation_config,
                synced_gpus,
                **model_kwargs,
            )
        if is_sample_gen_mode:
            # Repeat each prompt `num_return_sequences` times so every copy is
            # sampled independently.
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            return self.stream_sample(
                generation_config,
                input_ids,
                logits_processor,
                logits_warper,
                stopping_criteria,
                synced_gpus,
                **model_kwargs,
            )
        if is_beam_gen_mode:
            return self.beam_search(
                generation_config,
                input_ids,
                logits_processor,
                stopping_criteria,
                synced_gpus,
                **model_kwargs,
            )
        # Beam sampling; `beam_sample` expands the inputs itself.
        return self.beam_sample(
            input_ids,
            logits_processor,
            logits_warper,
            stopping_criteria,
            generation_config,
            synced_gpus,
            **model_kwargs,
        )

    def stream_sample(
        self,
        generation_config,
        input_ids,
        logits_processor,
        logits_warper,
        stopping_criteria,
        synced_gpus,
        **model_kwargs,
    ):
        """Multinomial sampling that yields ``input_ids`` after every new token.

        Rows that have produced an EOS id are filled with ``pad_token_id`` on
        later steps. The loop ends when every row is finished or
        ``stopping_criteria`` fires.

        Raises:
            ValueError: if ``eos_token_id`` is set but ``pad_token_id`` is not.
        """
        eos_token_id = self._stream_eos_token_ids(generation_config.eos_token_id)
        pad_token_id = generation_config.pad_token_id
        eos_token_id_tensor = (
            torch.tensor(eos_token_id).to(input_ids.device)
            if eos_token_id is not None
            else None
        )
        # 1 = still generating, 0 = finished, one entry per row
        unfinished_sequences = torch.ones(
            input_ids.shape[0], dtype=torch.long, device=input_ids.device
        )
        this_peer_finished = False  # used by synced_gpus only
        scores = ()  # scores are not collected; stopping criteria get an empty tuple
        while True:
            if synced_gpus and self._stream_all_peers_finished(
                this_peer_finished, input_ids.device
            ):
                break
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(**model_inputs, return_dict=True)
            if synced_gpus and this_peer_finished:
                continue  # keep joining forward passes, skip the bookkeeping
            next_token_logits = outputs.logits[:, -1, :]
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            # finished sentences get a padding token instead of a sampled one
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            yield input_ids
            # a row is finished once its new token matches any EOS id
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1)
                    .ne(eos_token_id_tensor.unsqueeze(1))
                    .prod(dim=0)
                )
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                this_peer_finished = True
        return input_ids

    def empty_cache(self):
        """Release cached, unused CUDA memory held by the PyTorch allocator."""
        torch.cuda.empty_cache()

    def beam_sample(
        self,
        input_ids,
        logits_processor,
        logits_warper,
        stopping_criteria,
        generation_config,
        synced_gpus,
        **model_kwargs,
    ):
        """Beam sampling that yields all beams after each step, then the final sequences.

        Each prompt is expanded into ``num_return_sequences`` independent groups
        of ``num_beams`` beams. Every step yields the expanded ``input_ids``
        (all beams). The last yield is ``beam_scorer.finalize(...)["sequences"]``
        with one sequence per group.
        """
        eos_token_id = self._stream_eos_token_ids(generation_config.eos_token_id)
        pad_token_id = generation_config.pad_token_id
        num_beams = generation_config.num_beams
        # One beam group per (prompt, return sequence) pair. Each group keeps its
        # single best hypothesis, matching the expansion done below. Sizing the
        # scorer with the raw batch broke the reshape when num_return_sequences > 1.
        batch_size = input_ids.shape[0] * generation_config.num_return_sequences
        cur_len = input_ids.shape[-1]
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=input_ids.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=1,
            max_length=generation_config.max_length,
        )
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=num_beams * generation_config.num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        scores = ()
        beam_scores = torch.zeros(
            (batch_size, num_beams), dtype=torch.float, device=input_ids.device
        ).view((batch_size * num_beams,))
        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus and self._stream_all_peers_finished(
                this_peer_finished, input_ids.device
            ):
                break
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(**model_inputs, return_dict=True)
            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # keep joining forward passes, skip the bookkeeping
            next_token_logits = outputs.logits[:, -1, :]
            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)
            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
            # Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers
            # (like top_p) this is indiferent, but on others (like temperature) it is not. For reference, see
            # https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
            next_token_scores = logits_warper(input_ids, next_token_scores)
            # flatten beams so candidates are sampled across the whole group
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
            next_tokens = torch.gather(next_tokens, -1, _indices)
            # split flat indices back into (source beam, token id)
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=None,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]
            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            yield input_ids
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            # the KV cache must follow the beams that survived this step
            if model_kwargs.get("past_key_values") is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
            cur_len = cur_len + 1
            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                this_peer_finished = True
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=None,
        )
        yield sequence_outputs["sequences"]

    def greedy_search(
        self,
        input_ids,
        logits_processor,
        stopping_criteria,
        generation_config,
        synced_gpus,
        **model_kwargs,
    ):
        """Greedy (argmax) decoding that yields ``input_ids`` after every new token.

        Rows that have produced an EOS id are filled with ``pad_token_id`` on
        later steps. After the loop ends, the final ``input_ids`` is yielded
        once more.

        Raises:
            ValueError: if ``eos_token_id`` is set but ``pad_token_id`` is not.
        """
        eos_token_id = self._stream_eos_token_ids(generation_config.eos_token_id)
        pad_token_id = generation_config.pad_token_id
        eos_token_id_tensor = (
            torch.tensor(eos_token_id).to(input_ids.device)
            if eos_token_id is not None
            else None
        )
        scores = ()  # scores are not collected; stopping criteria get an empty tuple
        # 1 = still generating, 0 = finished, one entry per row
        unfinished_sequences = torch.ones(
            input_ids.shape[0], dtype=torch.long, device=input_ids.device
        )
        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus and self._stream_all_peers_finished(
                this_peer_finished, input_ids.device
            ):
                break
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(**model_inputs, return_dict=True)
            if synced_gpus and this_peer_finished:
                continue  # keep joining forward passes, skip the bookkeeping
            next_token_logits = outputs.logits[:, -1, :]
            next_tokens_scores = logits_processor(input_ids, next_token_logits)
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
            # finished sentences get a padding token instead of the argmax
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            yield input_ids
            # a row is finished once its new token matches any EOS id
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1)
                    .ne(eos_token_id_tensor.unsqueeze(1))
                    .prod(dim=0)
                )
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                this_peer_finished = True
        # repeats the most recent state once more after the loop exits
        yield input_ids

    def beam_search(
        self,
        generation_config,
        input_ids,
        logits_processor,
        stopping_criteria,
        synced_gpus,
        **model_kwargs,
    ):
        """Beam search that yields all beams after each step, then the final sequences.

        Each prompt is expanded to ``num_beams`` rows. Every step yields the
        expanded ``input_ids``. The last yield is
        ``beam_scorer.finalize(...)["sequences"]``, keeping
        ``num_return_sequences`` hypotheses per prompt.

        Raises:
            ValueError: if the expanded batch is not ``num_beams * batch_size`` rows.
        """
        eos_token_id = self._stream_eos_token_ids(generation_config.eos_token_id)
        pad_token_id = generation_config.pad_token_id
        num_beams = generation_config.num_beams
        batch_size = input_ids.shape[0]
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=input_ids.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            max_length=generation_config.max_length,
        )
        # interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        batch_beam_size = input_ids.shape[0]
        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )
        # Only the first beam of each group starts "alive"; the others start at
        # -1e9 so the first step does not pick num_beams identical candidates.
        beam_scores = torch.zeros(
            (batch_size, num_beams), dtype=torch.float, device=input_ids.device
        )
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))
        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus and self._stream_all_peers_finished(
                this_peer_finished, input_ids.device
            ):
                break
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )
            if synced_gpus and this_peer_finished:
                continue  # keep joining forward passes, skip the bookkeeping
            next_token_logits = outputs.logits[:, -1, :]
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)
            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
                next_token_scores
            )
            # flatten beams so the top candidates are chosen across the group
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
            # Take 2 candidates per beam so there are spare tokens if some end in EOS
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )
            # split flat indices back into (source beam, token id)
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=None,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]
            input_ids = torch.cat(
                [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1
            )
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            # the KV cache must follow the beams that survived this step
            if model_kwargs.get("past_key_values") is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(
                    model_kwargs["past_key_values"], beam_idx
                )
            yield input_ids
            if beam_scorer.is_done or stopping_criteria(input_ids, None):
                if not synced_gpus:
                    break
                this_peer_finished = True
        final_result = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=None,
        )
        yield final_result["sequences"]