JqzAugUST committed
Commit 075d36d · verified · 1 Parent(s): 7149349

Update eagle/model/modeling_minicpm_kv.py

Files changed (1)
  1. eagle/model/modeling_minicpm_kv.py +9 -14
eagle/model/modeling_minicpm_kv.py CHANGED

@@ -1575,19 +1575,11 @@ class MiniCPMSdpaAttention(MiniCPMAttention):
 
         return attn_output, None, past_key_value
 
-# <mod> dev: support sdpa only
-# MINICPM_ATTENTION_CLASSES = {
-#     'eager': MiniCPMAttention,
-#     'flash_attention_2': MiniCPMFlashAttention2,
-#     'sdpa': MiniCPMSdpaAttention,
-# }
-# <before-after-mod> -------------------------------------------------
 MINICPM_ATTENTION_CLASSES = {
     'eager': MiniCPMAttention,
-    'flash_attention_2': MiniCPMAttention,
-    'sdpa': MiniCPMAttention,
+    'flash_attention_2': MiniCPMFlashAttention2,
+    'sdpa': MiniCPMSdpaAttention,
 }
-# </mod>
 
 class MiniCPMDecoderLayer(nn.Module):
     def __init__(self, config: MiniCPMConfig, layer_idx: int):
@@ -1596,7 +1588,11 @@ class MiniCPMDecoderLayer(nn.Module):
         if config.sparse_config is not None and torch.cuda.is_available():
             self.self_attn = MiniCPMInfLLMv2Attention(config=config, layer_idx=layer_idx)
         else:
-            self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+            # <mod>
+            # self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+            # <before-after-mod> -------------------------------------------------
+            self.self_attn = MINICPM_ATTENTION_CLASSES["eager"](config=config, layer_idx=layer_idx)
+            # </mod>
 
         self.mlp = MiniCPMMLP(config)
         self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -1805,7 +1801,7 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
         # self._use_sdpa = config._attn_implementation == 'sdpa'
         # self._use_flash_attention_2 = config._attn_implementation == 'flash_attention_2'
         # <before-after-mod> -------------------------------------------------
-        self._use_sdpa, self._use_flash_attention_2 = True, False
+        self._use_sdpa, self._use_flash_attention_2 = False, False
         # </mod>
 
         self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -1963,8 +1959,7 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
         #     attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
         # )
         # <before-after-mod> -------------------------------------------------
-        if not self._use_sdpa:
-            raise NotImplementedError("JQZ 250917 | Currently support sdpa **ONLY**, further impl for flash attention or infllm attention not finished yet.")
+        # For HF space demo, use MiniCPMAttention **ONLY**
         # # below is copied from modeling_llama_kv.py, Line 1110
         if attention_mask is None:
             attention_mask = torch.ones(
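
Net effect of the diff, for readers of the Space code: MINICPM_ATTENTION_CLASSES again maps 'flash_attention_2' and 'sdpa' to MiniCPMFlashAttention2 and MiniCPMSdpaAttention, but MiniCPMDecoderLayer no longer consults config._attn_implementation and always instantiates the 'eager' MiniCPMAttention, with self._use_sdpa and self._use_flash_attention_2 both forced to False. The sketch below is a standalone toy of the dispatch pattern being overridden; SimpleConfig, the Dummy* classes, and build_attention are hypothetical names used only for illustration and are not part of modeling_minicpm_kv.py.

# Toy illustration of the attention-class lookup this commit hard-codes.
# All names here (SimpleConfig, Dummy*, build_attention) are hypothetical.

class DummyEagerAttention:            # stand-in for MiniCPMAttention
    def __init__(self, config, layer_idx):
        self.kind = 'eager'

class DummyFlashAttention2:           # stand-in for MiniCPMFlashAttention2
    def __init__(self, config, layer_idx):
        self.kind = 'flash_attention_2'

class DummySdpaAttention:             # stand-in for MiniCPMSdpaAttention
    def __init__(self, config, layer_idx):
        self.kind = 'sdpa'

ATTENTION_CLASSES = {
    'eager': DummyEagerAttention,
    'flash_attention_2': DummyFlashAttention2,
    'sdpa': DummySdpaAttention,
}

class SimpleConfig:
    def __init__(self, attn_implementation='sdpa'):
        self._attn_implementation = attn_implementation

def build_attention(config, layer_idx, force_eager=True):
    # Before this commit the key came from config._attn_implementation;
    # after it, the decoder layer uses 'eager' unconditionally (HF Space demo).
    key = 'eager' if force_eager else config._attn_implementation
    return ATTENTION_CLASSES[key](config=config, layer_idx=layer_idx)

attn = build_attention(SimpleConfig('sdpa'), layer_idx=0)
print(attn.kind)  # prints 'eager' regardless of the configured implementation

Pinning the eager path avoids flash-attn kernels and SDPA-specific mask handling on the demo hardware, which matches the added comment "For HF space demo, use MiniCPMAttention **ONLY**".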