Fix decoder output class
#3 — opened by rom7

torch-ext/mamba_ssm/utils/generation.py
CHANGED
    
    | @@ -11,7 +11,7 @@ import torch.nn.functional as F | |
| 11 | 
             
            from einops import rearrange, repeat
         | 
| 12 | 
             
            from torch import Tensor
         | 
| 13 | 
             
            from torch.profiler import ProfilerActivity, profile, record_function
         | 
| 14 | 
            -
            from transformers.generation import  | 
| 15 |  | 
| 16 |  | 
| 17 | 
             
            @dataclass
         | 
| @@ -146,7 +146,7 @@ def decode( | |
| 146 | 
             
                    max_length: int
         | 
| 147 | 
             
                    teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
         | 
| 148 | 
             
                        logits, the next token is taken from the teacher_outputs. Useful for testing.
         | 
| 149 | 
            -
                Returns:  | 
| 150 | 
             
                    sequences: (batch, max_length)
         | 
| 151 | 
             
                    scores: tuples of (batch, vocab_size)
         | 
| 152 | 
             
                """
         | 
| @@ -240,7 +240,7 @@ def decode( | |
| 240 | 
             
                    end.record()
         | 
| 241 | 
             
                    torch.cuda.synchronize()
         | 
| 242 | 
             
                    print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
         | 
| 243 | 
            -
                output_cls =  | 
| 244 | 
             
                return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
         | 
| 245 |  | 
| 246 |  | 
|  | |
| 11 | 
             
            from einops import rearrange, repeat
         | 
| 12 | 
             
            from torch import Tensor
         | 
| 13 | 
             
            from torch.profiler import ProfilerActivity, profile, record_function
         | 
| 14 | 
            +
            from transformers.generation import GenerateDecoderOnlyOutput, TextStreamer
         | 
| 15 |  | 
| 16 |  | 
| 17 | 
             
            @dataclass
         | 
|  | |
| 146 | 
             
                    max_length: int
         | 
| 147 | 
             
                    teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
         | 
| 148 | 
             
                        logits, the next token is taken from the teacher_outputs. Useful for testing.
         | 
| 149 | 
            +
                Returns: GenerateDecoderOnlyOutput, with the following fields:
         | 
| 150 | 
             
                    sequences: (batch, max_length)
         | 
| 151 | 
             
                    scores: tuples of (batch, vocab_size)
         | 
| 152 | 
             
                """
         | 
|  | |
| 240 | 
             
                    end.record()
         | 
| 241 | 
             
                    torch.cuda.synchronize()
         | 
| 242 | 
             
                    print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
         | 
| 243 | 
            +
                output_cls = GenerateDecoderOnlyOutput
         | 
| 244 | 
             
                return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
         | 
| 245 |  | 
| 246 |  | 
