Update eagle/model/modeling_minicpm_kv.py
eagle/model/modeling_minicpm_kv.py
CHANGED
@@ -1575,19 +1575,11 @@ class MiniCPMSdpaAttention(MiniCPMAttention):
 
         return attn_output, None, past_key_value
 
-# <mod> dev: support sdpa only
-# MINICPM_ATTENTION_CLASSES = {
-#     'eager': MiniCPMAttention,
-#     'flash_attention_2': MiniCPMFlashAttention2,
-#     'sdpa': MiniCPMSdpaAttention,
-# }
-# <before-after-mod> -------------------------------------------------
 MINICPM_ATTENTION_CLASSES = {
     'eager': MiniCPMAttention,
-    'flash_attention_2':
-    'sdpa':
+    'flash_attention_2': MiniCPMFlashAttention2,
+    'sdpa': MiniCPMSdpaAttention,
 }
-# </mod>
 
 class MiniCPMDecoderLayer(nn.Module):
     def __init__(self, config: MiniCPMConfig, layer_idx: int):
@@ -1596,7 +1588,11 @@ class MiniCPMDecoderLayer(nn.Module):
         if config.sparse_config is not None and torch.cuda.is_available():
             self.self_attn = MiniCPMInfLLMv2Attention(config=config, layer_idx=layer_idx)
         else:
-
+            # <mod>
+            # self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+            # <before-after-mod> -------------------------------------------------
+            self.self_attn = MINICPM_ATTENTION_CLASSES["eager"](config=config, layer_idx=layer_idx)
+            # </mod>
 
         self.mlp = MiniCPMMLP(config)
         self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -1805,7 +1801,7 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
         # self._use_sdpa = config._attn_implementation == 'sdpa'
         # self._use_flash_attention_2 = config._attn_implementation == 'flash_attention_2'
         # <before-after-mod> -------------------------------------------------
-        self._use_sdpa, self._use_flash_attention_2 =
+        self._use_sdpa, self._use_flash_attention_2 = False, False
         # </mod>
 
         self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -1963,8 +1959,7 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
         #     attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
         # )
         # <before-after-mod> -------------------------------------------------
-
-        raise NotImplementedError("JQZ 250917 | Currently support sdpa **ONLY**, further impl for flash attention or infllm attention not finished yet.")
+        # For HF space demo, use MiniCPMAttention **ONLY**
         # # below is copied from modeling_llama_kv.py, Line 1110
         if attention_mask is None:
             attention_mask = torch.ones(
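In the patched decoder layer, the attention implementation is no longer chosen from config._attn_implementation: the MINICPM_ATTENTION_CLASSES lookup is pinned to the "eager" key, and _use_sdpa / _use_flash_attention_2 are forced to False, so the eager MiniCPMAttention path is always used (except when sparse_config selects MiniCPMInfLLMv2Attention). A minimal sketch of that lookup pattern follows; the placeholder classes EagerAttention / FlashAttention2 / SdpaAttention and the helper build_attention with its force_eager flag are illustrative names standing in for the real MiniCPM modules, not code from the file.

from types import SimpleNamespace

class EagerAttention:        # placeholder for MiniCPMAttention
    pass

class FlashAttention2:       # placeholder for MiniCPMFlashAttention2
    pass

class SdpaAttention:         # placeholder for MiniCPMSdpaAttention
    pass

# String key -> attention class, mirroring MINICPM_ATTENTION_CLASSES.
ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "flash_attention_2": FlashAttention2,
    "sdpa": SdpaAttention,
}

def build_attention(config, force_eager=True):
    # With force_eager=True this mirrors the demo change above: the
    # implementation requested in the config is ignored and the eager
    # class is always constructed.
    key = "eager" if force_eager else config._attn_implementation
    return ATTENTION_CLASSES[key]()

config = SimpleNamespace(_attn_implementation="sdpa")
print(type(build_attention(config)).__name__)                     # EagerAttention
print(type(build_attention(config, force_eager=False)).__name__)  # SdpaAttention

Since the config-driven lookup is only commented out rather than deleted, the original dispatch can presumably be restored later by swapping the hard-coded "eager" key back to config._attn_implementation.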