Kernels

Fix decoder output class

#3, opened by rom7
torch-ext/mamba_ssm/utils/generation.py CHANGED
@@ -11,7 +11,7 @@ import torch.nn.functional as F
11
  from einops import rearrange, repeat
12
  from torch import Tensor
13
  from torch.profiler import ProfilerActivity, profile, record_function
14
- from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
15
 
16
 
17
  @dataclass
@@ -146,7 +146,7 @@ def decode(
146
  max_length: int
147
  teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
148
  logits, the next token is taken from the teacher_outputs. Useful for testing.
149
- Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
150
  sequences: (batch, max_length)
151
  scores: tuples of (batch, vocab_size)
152
  """
@@ -240,7 +240,7 @@ def decode(
240
  end.record()
241
  torch.cuda.synchronize()
242
  print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
243
- output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
244
  return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
245
 
246
 
 
11
  from einops import rearrange, repeat
12
  from torch import Tensor
13
  from torch.profiler import ProfilerActivity, profile, record_function
14
+ from transformers.generation import GenerateDecoderOnlyOutput, TextStreamer
15
 
16
 
17
  @dataclass
 
146
  max_length: int
147
  teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
148
  logits, the next token is taken from the teacher_outputs. Useful for testing.
149
+ Returns: GenerateDecoderOnlyOutput, with the following fields:
150
  sequences: (batch, max_length)
151
  scores: tuples of (batch, vocab_size)
152
  """
 
240
  end.record()
241
  torch.cuda.synchronize()
242
  print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
243
+ output_cls = GenerateDecoderOnlyOutput
244
  return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
245
 
246