Upload modeling_nemotron_h.py (#10)
Commit 0c2b9f771b8918befc13593f3243d94446e159c4

modeling_nemotron_h.py  (CHANGED, +6 -1)
@@ -42,7 +42,7 @@ from transformers.utils.import_utils import (
     is_causal_conv1d_available,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
-    is_mamba_2_ssm_available,
+    is_mamba_2_ssm_available,
 )
 from .configuration_nemotron_h import NemotronHConfig
 
@@ -1542,6 +1542,11 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and empty_past_kv:
+            # TODO(pjin): workaround fix for properly extending inputs_embeds;
+            # longer term, may be better handled elsewhere in .generate().
+            if input_ids is not None and inputs_embeds.shape[1] < input_ids.shape[1]:
+                new_token_embeds = self.get_input_embeddings()(input_ids[:,inputs_embeds.shape[1]:])
+                inputs_embeds = torch.cat([inputs_embeds, new_token_embeds], dim=1)
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
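The added branch handles the case where, at generation time, input_ids contains more tokens than the inputs_embeds that were passed in: it embeds the uncovered trailing tokens with the model's input embedding layer and appends them, so the embedded sequence stays aligned with input_ids. Below is a minimal, self-contained sketch of that concatenation step. It is not part of the commit; the small nn.Embedding and the example tensors are illustrative stand-ins for the model's real embedding layer (self.get_input_embeddings()) and its actual inputs.

import torch
import torch.nn as nn

# Stand-in for self.get_input_embeddings(); the real model returns its token embedding layer.
embedding = nn.Embedding(num_embeddings=1000, embedding_dim=16)

inputs_embeds = embedding(torch.tensor([[1, 2, 3]]))  # embeddings supplied for the prompt (batch=1, seq=3)
input_ids = torch.tensor([[1, 2, 3, 42, 7]])          # prompt ids plus two extra tokens (batch=1, seq=5)

if input_ids is not None and inputs_embeds.shape[1] < input_ids.shape[1]:
    # Embed only the tokens not yet covered by inputs_embeds ...
    new_token_embeds = embedding(input_ids[:, inputs_embeds.shape[1]:])
    # ... and append them so inputs_embeds matches input_ids in length.
    inputs_embeds = torch.cat([inputs_embeds, new_token_embeds], dim=1)

print(inputs_embeds.shape)  # torch.Size([1, 5, 16])

Since the new branch sits under the `inputs_embeds is not None and empty_past_kv` check, it only runs on a first generation step with no cache; the TODO in the commit notes that a longer-term fix may be better handled elsewhere in .generate().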
