Add _supports_flash_attn_2 to Llama 2 32k
modeling_flash_llama.py  CHANGED  (+1, -0)
@@ -499,6 +499,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlamaDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True

     def _init_weights(self, module):
         std = self.config.initializer_range
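For context, _supports_flash_attn_2 is the PreTrainedModel class flag that Transformers checks before it will honor a request to load a checkpoint with Flash Attention 2; without it, from_pretrained rejects that request. A minimal loading sketch, assuming this modeling file ships with a Llama 2 32k checkpoint such as togethercomputer/LLaMA-2-7B-32K (the repo id and the kwargs below are illustrative assumptions, not part of this commit) and that the flash-attn package is installed:

    import torch
    from transformers import AutoModelForCausalLM

    # Repo id is an assumption; point this at the Llama 2 32k checkpoint you actually use.
    model_id = "togethercomputer/LLaMA-2-7B-32K"

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        trust_remote_code=True,      # loads the custom modeling_flash_llama.py from the repo
        use_flash_attention_2=True,  # honored only because _supports_flash_attn_2 is True
        device_map="auto",           # requires accelerate; drop for single-device loading
    )

Newer Transformers releases express the same request as attn_implementation="flash_attention_2"; either way, the check behind it is the class attribute added in this commit.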

