JqzAugUST committed
Commit 0750883 · verified · Parent: 075d36d

Update eagle/model/modeling_minicpm_kv.py

Files changed (1)
  1. eagle/model/modeling_minicpm_kv.py +3 -3
eagle/model/modeling_minicpm_kv.py CHANGED
@@ -1589,9 +1589,9 @@ class MiniCPMDecoderLayer(nn.Module):
             self.self_attn = MiniCPMInfLLMv2Attention(config=config, layer_idx=layer_idx)
         else:
             # <mod>
-            # self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+            self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
             # <before-after-mod> -------------------------------------------------
-            self.self_attn = MINICPM_ATTENTION_CLASSES["eager"](config=config, layer_idx=layer_idx)
+            # self.self_attn = MINICPM_ATTENTION_CLASSES["eager"](config=config, layer_idx=layer_idx)
             # </mod>
 
         self.mlp = MiniCPMMLP(config)
@@ -1797,7 +1797,7 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
         self.layers = nn.ModuleList(
             [MiniCPMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
-        # <mod> dev: support sdpa only
+        # <mod>
         # self._use_sdpa = config._attn_implementation == 'sdpa'
         # self._use_flash_attention_2 = config._attn_implementation == 'flash_attention_2'
         # <before-after-mod> -------------------------------------------------
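
The substantive change is in the first hunk: layers outside the InfLLMv2 branch now pick their attention class from MINICPM_ATTENTION_CLASSES via config._attn_implementation instead of being pinned to the "eager" entry. The following is a minimal sketch of that dispatch pattern as it is commonly written in transformers-style modeling files; the variant class names are assumptions for illustration (only MINICPM_ATTENTION_CLASSES and MiniCPMInfLLMv2Attention are named in this diff), not the actual contents of modeling_minicpm_kv.py.

import torch.nn as nn

# Illustrative stand-ins for the backend-specific attention modules;
# the real classes in modeling_minicpm_kv.py may be named differently.
class MiniCPMAttention(nn.Module):                  # assumed eager implementation
    def __init__(self, config, layer_idx):
        super().__init__()
        self.config, self.layer_idx = config, layer_idx

class MiniCPMSdpaAttention(MiniCPMAttention):       # assumed SDPA variant
    pass

class MiniCPMFlashAttention2(MiniCPMAttention):     # assumed FlashAttention-2 variant
    pass

# Dispatch table keyed by the attention backend recorded in the config.
MINICPM_ATTENTION_CLASSES = {
    "eager": MiniCPMAttention,
    "sdpa": MiniCPMSdpaAttention,
    "flash_attention_2": MiniCPMFlashAttention2,
}

def build_self_attn(config, layer_idx):
    # Mirrors the re-enabled line: resolve the class from config._attn_implementation,
    # so a backend chosen at load time (e.g. attn_implementation="sdpa") is honored
    # instead of always instantiating the "eager" entry as before this commit.
    return MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)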