Add files using upload-large-folder tool
- README.md +2 -2
- __init__.py +26 -0
- __pycache__/configuration_minimax_m2.cpython-313.pyc +0 -0
- config.json +6 -2
- configuration_minimax_m2.py +147 -0
- modeling_minimax_m2.py +765 -0
- test_minimax_m2_hf.py +178 -0
README.md
CHANGED
@@ -4,7 +4,7 @@ license: mit
---

<div align="center">
-Upconverted to BFloat16 by <a href='https://x.com/qubitium'>@Qubitum</a> at ModelCloud
+Upconverted to BFloat16 by <a href='https://x.com/qubitium'>@Qubitum</a> at ModelCloud.

<svg width="60%" height="auto" viewBox="0 0 144 48" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M26.6782 7.96523C26.6782 7.02436 25.913 6.26087 24.9739 6.26087C24.0348 6.26087 23.2695 7.0261 23.2695 7.96523V36.2139C23.2695 38.4 21.4904 40.1791 19.3043 40.1791C17.1183 40.1791 15.3391 38.4 15.3391 36.2139V18.0904C15.3391 17.1496 14.5739 16.3861 13.6348 16.3861C12.6956 16.3861 11.9304 17.1513 11.9304 18.0904V25.7722C11.9304 27.9583 10.1513 29.7374 7.96518 29.7374C5.7791 29.7374 4 27.9583 4 25.7722V22.9878C4 22.3635 4.50609 21.8574 5.13043 21.8574C5.75478 21.8574 6.26087 22.3635 6.26087 22.9878V25.7722C6.26087 26.713 7.02605 27.4765 7.96518 27.4765C8.90431 27.4765 9.66954 26.7113 9.66954 25.7722V18.0904C9.66954 15.9044 11.4487 14.1252 13.6348 14.1252C15.8209 14.1252 17.6 15.9044 17.6 18.0904V36.2139C17.6 37.1548 18.3652 37.9183 19.3043 37.9183C20.2435 37.9183 21.0087 37.153 21.0087 36.2139V25.1322V7.96523C21.0087 5.77914 22.7878 4 24.9739 4C27.16 4 28.9391 5.77914 28.9391 7.96523V31.3565C28.9391 31.9809 28.433 32.487 27.8087 32.487C27.1843 32.487 26.6782 31.9809 26.6782 31.3565V7.96523ZM47.6539 14.1252C45.4678 14.1252 43.6887 15.9044 43.6887 18.0904V33.2296C43.6887 34.1704 42.9235 34.9339 41.9843 34.9339C41.0452 34.9339 40.28 34.1687 40.28 33.2296V7.96523C40.28 5.77914 38.5008 4 36.3148 4C34.1287 4 32.3496 5.77914 32.3496 7.96523V40.0348C32.3496 40.9756 31.5843 41.7391 30.6452 41.7391C29.7061 41.7391 28.9409 40.9739 28.9409 40.0348V36.0643C28.9409 35.44 28.4348 34.9339 27.8104 34.9339C27.1861 34.9339 26.68 35.44 26.68 36.0643V40.0348C26.68 42.2209 28.4591 44 30.6452 44C32.8313 44 34.6104 42.2209 34.6104 40.0348V7.96523C34.6104 7.02436 35.3756 6.26087 36.3148 6.26087C37.2539 6.26087 38.0191 7.0261 38.0191 7.96523V33.2296C38.0191 35.4156 39.7982 37.1948 41.9843 37.1948C44.1704 37.1948 45.9496 35.4156 45.9496 33.2296V18.0904C45.9496 17.1496 46.7148 16.3861 47.6539 16.3861C48.593 16.3861 49.3582 17.1513 49.3582 18.0904V31.3565C49.3582 31.9809 49.8643 32.487 50.4887 32.487C51.113 32.487 51.6191 31.9809 51.6191 31.3565V18.0904C51.6191 15.9044 49.84 14.1252 47.6539 14.1252Z" fill="url(#paint0_linear_17_483)"/>
@@ -186,4 +186,4 @@ Please refer to our [Tool Calling Guide](https://huggingface.co/MiniMaxAI/MiniMa

# Contact Us

-Contact us at [model@minimax.io](mailto:model@minimax.io).
+Contact us at [model@minimax.io](mailto:model@minimax.io).
__init__.py
ADDED
@@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
#
# """MiniMax M2 Hugging Face remote code support."""

from .configuration_minimax_m2 import MiniMaxM2Config
from .modeling_minimax_m2 import (
    MiniMaxForCausalLM,
    MiniMaxM2ForCausalLM,
    MiniMaxM2Model,
    MiniMaxM2PreTrainedModel,
    MiniMaxModel,
    MiniMaxPreTrainedModel,
)

__all__ = [
    "MiniMaxM2Config",
    "MiniMaxM2PreTrainedModel",
    "MiniMaxM2Model",
    "MiniMaxM2ForCausalLM",
    "MiniMaxPreTrainedModel",
    "MiniMaxModel",
    "MiniMaxForCausalLM",
]
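For reference, a minimal sketch of how these re-exports can be consumed once the folder is importable as a package. The `gptqmodel.hf_minimax_m2` package path is taken from the test script later in this commit; the tiny config sizes are made up purely to keep the wiring check cheap.

# Example (not part of the commit): smoke-test the re-exported classes.
from gptqmodel.hf_minimax_m2 import MiniMaxM2Config, MiniMaxM2ForCausalLM

# Deliberately tiny, made-up dimensions: this only checks that the classes wire together.
config = MiniMaxM2Config(
    vocab_size=512,
    hidden_size=64,
    intermediate_size=32,
    mlp_intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=16,
    rotary_dim=8,
    num_local_experts=4,
    num_experts_per_tok=2,
)
model = MiniMaxM2ForCausalLM(config)
print(sum(p.numel() for p in model.parameters()))  # parameter count of the toy instance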
__pycache__/configuration_minimax_m2.cpython-313.pyc
ADDED
Binary file (6.02 kB).
config.json
CHANGED
@@ -79,7 +79,7 @@
   "layernorm_mlp_beta": 1.0,
   "max_position_embeddings": 196608,
   "mlp_intermediate_size": 8192,
-  "model_type": "
+  "model_type": "minimax",
   "mtp_transformer_layers": 1,
   "num_attention_heads": 48,
   "num_experts_per_tok": 8,
@@ -105,5 +105,9 @@
   "use_qk_norm": true,
   "use_routing_bias": true,
   "vocab_size": 200064,
-  "torch_dtype": "bfloat16"
+  "torch_dtype": "bfloat16",
+  "auto_map": {
+    "AutoConfig": "configuration_minimax_m2.MiniMaxM2Config",
+    "AutoModelForCausalLM": "modeling_minimax_m2.MiniMaxM2ForCausalLM"
+  }
 }
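The new `auto_map` block is what lets the stock `Auto*` classes resolve to the custom code shipped alongside the checkpoint. A minimal loading sketch follows; the local path is a placeholder, and `trust_remote_code=True` is required so Transformers will import `configuration_minimax_m2.py` and `modeling_minimax_m2.py` from the repository.

# Example (not part of the commit): load via the auto_map entries.
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_path = "/path/to/MiniMax-M2-bf16"  # placeholder checkpoint directory

config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
print(type(config).__name__)  # MiniMaxM2Config, resolved through auto_map

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,  # allows the bundled remote-code modules to be imported
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)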
configuration_minimax_m2.py
ADDED
@@ -0,0 +1,147 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

"""Configuration for the MiniMax M2 architecture."""

from __future__ import annotations

from typing import List, Optional, Union

from transformers.configuration_utils import PretrainedConfig


class _QuantizationConfigDict(dict):
    """Ensure quantization config always exposes a `quant_method`."""

    def __init__(self, data: Optional[dict] = None):
        if data is None:
            data = {}
        super().__init__(data)
        self.setdefault("quant_method", "none")

    def to_dict(self):
        return dict(self)


class MiniMaxM2Config(PretrainedConfig):
    model_type = "minimax"

    def __init__(
        self,
        vocab_size: int = 200_064,
        hidden_size: int = 3_072,
        intermediate_size: int = 1_536,
        mlp_intermediate_size: int = 8_192,
        num_hidden_layers: int = 62,
        num_attention_heads: int = 48,
        num_key_value_heads: int = 8,
        head_dim: Optional[int] = 128,
        num_local_experts: int = 256,
        num_experts_per_tok: int = 8,
        attn_type_list: Optional[List[int]] = None,
        attention_dropout: float = 0.0,
        hidden_act: str = "silu",
        rms_norm_eps: float = 1e-6,
        max_position_embeddings: int = 196_608,
        rope_theta: float = 5_000_000.0,
        rotary_dim: int = 64,
        rope_scaling: Optional[dict] = None,
        use_qk_norm: bool = True,
        qk_norm_type: str = "per_layer",
        use_routing_bias: bool = True,
        scoring_func: str = "sigmoid",
        router_aux_loss_coef: float = 0.001,
        router_jitter_noise: float = 0.0,
        output_router_logits: bool = False,
        use_grouped_topk: bool = True,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        routed_scaling_factor: float = 1.0,
        layernorm_full_attention_beta: float = 1.0,
        layernorm_linear_attention_beta: float = 1.0,
        layernorm_mlp_beta: float = 1.0,
        shared_intermediate_size: int = 0,
        shared_moe_mode: str = "sigmoid",
        use_mtp: bool = True,
        num_mtp_modules: int = 3,
        mtp_transformer_layers: int = 1,
        attn_window_size: Optional[Union[int, List[int]]] = None,
        swa_rope_theta: float = -1.0,
        sliding_window: Optional[int] = None,
        initializer_range: float = 0.02,
        tie_word_embeddings: bool = False,
        max_model_len: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        use_cache: bool = True,
        **kwargs,
    ) -> None:
        quantization_config = kwargs.pop("quantization_config", None)
        if quantization_config is None:
            quantization_config = _QuantizationConfigDict()
        elif not isinstance(quantization_config, _QuantizationConfigDict):
            quantization_config = _QuantizationConfigDict(quantization_config)
        transformers_version = kwargs.pop("transformers_version", None)

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            pad_token_id=pad_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.mlp_intermediate_size = mlp_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.num_local_experts = num_local_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.attn_type_list = attn_type_list or [1] * num_hidden_layers
        self.attention_dropout = attention_dropout
        self.hidden_act = hidden_act
        self.rms_norm_eps = rms_norm_eps
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.rotary_dim = rotary_dim
        self.rope_scaling = rope_scaling
        self.use_qk_norm = use_qk_norm
        self.qk_norm_type = qk_norm_type
        self.use_routing_bias = use_routing_bias
        self.scoring_func = scoring_func
        self.router_aux_loss_coef = router_aux_loss_coef
        self.router_jitter_noise = router_jitter_noise
        self.output_router_logits = output_router_logits
        self.use_grouped_topk = use_grouped_topk
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.routed_scaling_factor = routed_scaling_factor
        self.layernorm_full_attention_beta = layernorm_full_attention_beta
        self.layernorm_linear_attention_beta = layernorm_linear_attention_beta
        self.layernorm_mlp_beta = layernorm_mlp_beta
        self.shared_intermediate_size = shared_intermediate_size
        self.shared_moe_mode = shared_moe_mode
        self.use_mtp = use_mtp
        self.num_mtp_modules = num_mtp_modules
        self.mtp_transformer_layers = mtp_transformer_layers
        self.attn_window_size = attn_window_size
        self.swa_rope_theta = swa_rope_theta
        self.sliding_window = sliding_window
        self.initializer_range = initializer_range
        self.max_model_len = max_model_len
        self.use_cache = use_cache

        # Convenient accessor used by rotary embedding helper
        self.partial_rotary_factor = float(self.rotary_dim) / float(self.head_dim)
        self.quantization_config = quantization_config
        self.transformers_version = transformers_version


__all__ = ["MiniMaxM2Config"]
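A small sanity sketch of the derived attributes set at the end of `__init__` (tiny, made-up sizes; the import path mirrors the test script's package layout):

# Example (not part of the commit): derived config attributes.
from gptqmodel.hf_minimax_m2.configuration_minimax_m2 import MiniMaxM2Config

cfg = MiniMaxM2Config(hidden_size=1024, num_attention_heads=8, head_dim=None, rotary_dim=64)

# head_dim falls back to hidden_size // num_attention_heads when not supplied.
assert cfg.head_dim == 128

# Only the first rotary_dim channels of each head receive rotary embeddings.
assert cfg.partial_rotary_factor == 64 / 128

# The quantization config is normalized so downstream code can always read quant_method.
assert cfg.quantization_config["quant_method"] == "none"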
modeling_minimax_m2.py
ADDED
@@ -0,0 +1,765 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

"""PyTorch implementation of the MiniMax M2 architecture for Hugging Face Transformers."""

from __future__ import annotations

import copy
import time
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, repeat_kv, rotate_half

from .configuration_minimax_m2 import MiniMaxM2Config

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "MiniMaxM2Config"
_CHECKPOINT_FOR_DOC = "MiniMaxAI/MiniMax-M2"


def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
    num_experts: int,
    top_k: int,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if gate_logits is None:
        return torch.tensor(0.0)
    if isinstance(gate_logits, torch.Tensor):
        logits = gate_logits
    else:
        logits = torch.cat([layer_gate.to(gate_logits[0].device) for layer_gate in gate_logits], dim=0)

    routing_weights = torch.softmax(logits, dim=-1, dtype=torch.float32)
    _, selected = torch.topk(routing_weights, top_k, dim=-1)
    expert_mask = torch.nn.functional.one_hot(selected, num_experts)

    if attention_mask is None:
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, seq_len = attention_mask.shape
        num_layers = logits.shape[0] // (batch_size * seq_len)

        expanded_mask = (
            attention_mask[None, :, :, None, None]
            .expand(num_layers, batch_size, seq_len, top_k, num_experts)
            .reshape(-1, top_k, num_experts)
            .to(logits.device)
        )
        tokens_per_expert = torch.sum(expert_mask.float() * expanded_mask, dim=0) / torch.sum(expanded_mask, dim=0)

        router_mask = (
            attention_mask[None, :, :, None]
            .expand(num_layers, batch_size, seq_len, num_experts)
            .reshape(-1, num_experts)
            .to(logits.device)
        )
        router_prob_per_expert = torch.sum(routing_weights * router_mask, dim=0) / torch.sum(router_mask, dim=0)

    loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return loss * num_experts


def apply_rotary_pos_emb_partial(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    rotary_dim: int,
    unsqueeze_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    cos = cos.unsqueeze(unsqueeze_dim)[..., :rotary_dim]
    sin = sin.unsqueeze(unsqueeze_dim)[..., :rotary_dim]
    q_rot = q[..., :rotary_dim]
    k_rot = k[..., :rotary_dim]

    q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin)

    q = torch.cat((q_rot, q[..., rotary_dim:]), dim=-1)
    k = torch.cat((k_rot, k[..., rotary_dim:]), dim=-1)
    return q, k


class MiniMaxM2RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)


class MiniMaxM2MLP(nn.Module):
    def __init__(self, config: MiniMaxM2Config) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gate = self.act_fn(self.w1(hidden_states))
        up = self.w3(hidden_states)
        hidden_states = gate * up
        hidden_states = self.w2(hidden_states)
        return hidden_states


class MiniMaxM2SparseMoeBlock(nn.Module):
    def __init__(self, config: MiniMaxM2Config) -> None:
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.experts = nn.ModuleList([MiniMaxM2MLP(config) for _ in range(config.num_local_experts)])
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        self.jitter_noise = config.router_jitter_noise
        self.use_routing_bias = config.use_routing_bias
        self.scoring_func = getattr(config, "scoring_func", "softmax")
        self.use_grouped_topk = getattr(config, "use_grouped_topk", False)
        self.num_expert_group = getattr(config, "num_expert_group", None)
        self.topk_group = getattr(config, "topk_group", None)
        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)

        if self.use_grouped_topk:
            if self.num_expert_group is None or self.num_expert_group <= 0:
                self.num_expert_group = 1
            if self.topk_group is None or self.topk_group <= 0:
                self.topk_group = min(self.num_expert_group, self.top_k)
        else:
            self.num_expert_group = 1
            self.topk_group = 1

        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        if self.use_routing_bias:
            self.e_score_correction_bias = nn.Parameter(torch.zeros(self.num_experts, dtype=torch.float32))
        else:
            self.register_parameter("e_score_correction_bias", None)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_len, hidden_dim = hidden_states.shape
        if self.training and self.jitter_noise > 0:
            noise = torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise,
                1.0 + self.jitter_noise,
            )
            hidden_states = hidden_states * noise

        hidden_states = hidden_states.view(-1, hidden_dim)
        gate_dtype = self.gate.weight.dtype
        router_logits = self.gate(hidden_states.to(gate_dtype)).to(torch.float32)
        if self.e_score_correction_bias is not None:
            # Bias is applied after scoring (see vLLM/SGLang implementations).
            correction_bias = self.e_score_correction_bias.to(router_logits.device, router_logits.dtype)
        else:
            correction_bias = None

        if self.scoring_func == "sigmoid":
            scores = torch.sigmoid(router_logits)
        elif self.scoring_func == "softmax":
            scores = torch.softmax(router_logits, dim=-1)
        else:
            raise ValueError(f"Unsupported scoring function: {self.scoring_func}")

        if correction_bias is not None:
            original_scores = scores
            scores = scores + correction_bias
        else:
            original_scores = scores
        topk_scores: torch.Tensor
        if self.use_grouped_topk and self.num_expert_group > 1:
            experts_per_group = scores.size(-1) // self.num_expert_group
            scores_grouped = scores.view(scores.size(0), self.num_expert_group, experts_per_group)
            if correction_bias is not None:
                topk_in_group = min(2, experts_per_group)
                if topk_in_group > 0:
                    group_scores = scores_grouped.topk(topk_in_group, dim=-1)[0].sum(dim=-1)
                else:
                    group_scores = torch.zeros_like(scores_grouped[..., 0])
            else:
                group_scores = scores_grouped.max(dim=-1).values
            group_mask = torch.zeros_like(group_scores)
            selected_groups = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True).indices
            group_mask.scatter_(1, selected_groups, 1.0)
            mask = group_mask.unsqueeze(-1).expand(-1, -1, experts_per_group).reshape(scores.size())
            masked_scores = scores.masked_fill(mask == 0, float("-inf"))
            topk_scores, selected_experts = torch.topk(masked_scores, self.top_k, dim=-1, sorted=True)
        else:
            topk_scores, selected_experts = torch.topk(scores, self.top_k, dim=-1, sorted=True)

        if correction_bias is not None:
            routing_weights = original_scores.gather(1, selected_experts)
        else:
            routing_weights = topk_scores

        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-12)
        if self.routed_scaling_factor != 1.0:
            routing_weights = routing_weights * self.routed_scaling_factor
        routing_weights = routing_weights.to(hidden_states.dtype)
        selected_experts = selected_experts.to(torch.long)

        final_hidden_states = torch.zeros_like(hidden_states)
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
        expert_hit = torch.nonzero(expert_mask.sum(dim=(-1, -2)) > 0, as_tuple=False).flatten()

        for expert_idx in expert_hit.tolist():
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
            token_states = hidden_states.index_select(0, top_x)
            expert_output = expert_layer(token_states) * routing_weights[top_x, idx].unsqueeze(-1)
            final_hidden_states.index_add_(0, top_x, expert_output.to(final_hidden_states.dtype))

        final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
        return final_hidden_states, router_logits


class MiniMaxM2Attention(nn.Module):
    def __init__(self, config: MiniMaxM2Config, layer_idx: int) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.head_dim = config.head_dim
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // max(1, self.num_key_value_heads)
        self.rotary_dim = config.rotary_dim
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        max_model_len = getattr(config, "max_model_len", None)
        if max_model_len is not None:
            max_position_embeddings = max(max_position_embeddings, max_model_len)

        attn_window_size = getattr(config, "attn_window_size", None)
        if isinstance(attn_window_size, list):
            sliding_window = attn_window_size[layer_idx]
        else:
            sliding_window = attn_window_size
        if sliding_window is not None and sliding_window <= 0:
            sliding_window = None
        self.sliding_window = sliding_window

        swa_rope_theta = getattr(config, "swa_rope_theta", -1.0)
        rope_theta = config.rope_theta
        if self.sliding_window is not None and swa_rope_theta > 0:
            rope_theta = swa_rope_theta

        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)

        self.use_qk_norm = config.use_qk_norm
        if self.use_qk_norm:
            self.q_norm = MiniMaxM2RMSNorm(self.num_heads * self.head_dim, eps=config.rms_norm_eps)
            self.k_norm = MiniMaxM2RMSNorm(self.num_key_value_heads * self.head_dim, eps=config.rms_norm_eps)

        rope_config = copy.deepcopy(config)
        rope_config.hidden_size = config.hidden_size
        rope_config.num_attention_heads = config.num_attention_heads
        rope_config.partial_rotary_factor = float(config.rotary_dim) / float(self.head_dim)
        rope_config.rope_theta = rope_theta
        rope_config.max_position_embeddings = max_position_embeddings
        self.rotary_emb = LlamaRotaryEmbedding(rope_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if self.use_qk_norm:
            q_flat = query_states.transpose(1, 2).reshape(bsz * q_len, -1)
            k_flat = key_states.transpose(1, 2).reshape(bsz * q_len, -1)
            q_flat = self.q_norm(q_flat)
            k_flat = self.k_norm(k_flat)
            query_states = q_flat.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            key_states = k_flat.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings

        query_states, key_states = apply_rotary_pos_emb_partial(
            query_states.transpose(1, 2), key_states.transpose(1, 2), cos, sin, self.rotary_dim
        )
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)

        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        if self.sliding_window is not None and past_key_values is None:
            query_positions = torch.arange(q_len, device=hidden_states.device).view(1, 1, q_len, 1)
            key_positions = torch.arange(key_states.shape[-2], device=hidden_states.device).view(1, 1, 1, -1)
            window_mask = key_positions < (query_positions - self.sliding_window)
            if window_mask.any():
                attn_weights = attn_weights.masked_fill(window_mask, float("-inf"))

        attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        if self.training and self.attention_dropout > 0:
            attn_weights = F.dropout(attn_weights, p=self.attention_dropout)

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights


class MiniMaxM2LogitsProcessor(nn.Module):
    def __init__(self, config: MiniMaxM2Config) -> None:
        super().__init__()
        self.scale = getattr(config, "logits_scale", 1.0)

    def forward(self, lm_head: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = lm_head(hidden_states)
        if self.scale != 1.0:
            logits = logits * self.scale
        return logits


class MiniMaxM2DecoderLayer(nn.Module):
    def __init__(self, config: MiniMaxM2Config, layer_idx: int) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = MiniMaxM2Attention(config, layer_idx)
        self.block_sparse_moe = MiniMaxM2SparseMoeBlock(config)
        self.input_layernorm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False,
        residual: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        residual_input = hidden_states if residual is None else residual
        hidden_states = self.input_layernorm(hidden_states)

        attn_output, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            output_attentions=output_attentions,
        )
        hidden_states = residual_input + attn_output

        residual_post_attn = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        moe_output, router_logits = self.block_sparse_moe(hidden_states)
        hidden_states = residual_post_attn + moe_output

        return hidden_states, hidden_states, router_logits, attn_weights


class MiniMaxM2PreTrainedModel(PreTrainedModel):
    config_class = MiniMaxM2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MiniMaxM2DecoderLayer"]
    _supports_flash_attn = False
    _supports_sdpa = False
    _supports_attention_backend = False

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _remap_qkv_weights(self, state_dict):
        num_q = self.config.num_attention_heads * self.config.head_dim
        num_kv = self.config.num_key_value_heads * self.config.head_dim

        for layer_idx in range(self.config.num_hidden_layers):
            prefix = f"model.layers.{layer_idx}.self_attn"
            weight_key = f"{prefix}.qkv_proj.weight"
            if weight_key in state_dict:
                qkv_weight = state_dict.pop(weight_key)
                q_weight, k_weight, v_weight = qkv_weight.split([num_q, num_kv, num_kv], dim=0)
                state_dict.setdefault(f"{prefix}.q_proj.weight", q_weight)
                state_dict.setdefault(f"{prefix}.k_proj.weight", k_weight)
                state_dict.setdefault(f"{prefix}.v_proj.weight", v_weight)

    def load_state_dict(self, state_dict, strict: bool = True):
        if not isinstance(state_dict, dict):
            raise TypeError(f"Expected state_dict to be dict, got {type(state_dict)}")

        filtered_state_dict = {}
        drop_suffixes = ("weight_scale_inv", "weight_scale", "input_scale", "scales", "amax")
        for key, value in state_dict.items():
            if key.endswith(drop_suffixes) or "fp8" in key:
                continue
            filtered_state_dict[key] = value

        self._remap_qkv_weights(filtered_state_dict)

        if logger.isEnabledFor(logging.INFO):
            logger.info(
                "MiniMaxM2: loading %d tensors (filtered from %d original).",
                len(filtered_state_dict),
                len(state_dict),
            )

        load_start = time.perf_counter()
        result = super().load_state_dict(filtered_state_dict, strict=strict)
        load_elapsed = time.perf_counter() - load_start
        if logger.isEnabledFor(logging.INFO):
            logger.info("MiniMaxM2: state_dict load finished in %.2f seconds.", load_elapsed)

        return result


class MiniMaxM2Model(MiniMaxM2PreTrainedModel):
    def __init__(self, config: MiniMaxM2Config) -> None:
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [MiniMaxM2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.embed_tokens

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[MoeModelOutputWithPast, Tuple]:
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        if self.config.sliding_window is not None:
            causal_mask = create_sliding_window_causal_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                position_ids=position_ids,
            )
        else:
            causal_mask = create_causal_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                position_ids=position_ids,
            )

        hidden_states = inputs_embeds

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_router_logits = () if output_router_logits else None

        residual = None
        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=None,
                output_attentions=output_attentions,
                residual=residual,
            )

            hidden_states, residual, router_logits, attn_weights = layer_outputs

            if output_router_logits:
                all_router_logits = all_router_logits + (router_logits,)
            if output_attentions:
                all_attentions = all_attentions + (attn_weights,)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            outputs = (hidden_states, past_key_values)
            if output_hidden_states:
                outputs += (all_hidden_states,)
            if output_attentions:
                outputs += (all_attentions,)
            if output_router_logits:
                outputs += (all_router_logits,)
            return outputs

        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            router_logits=all_router_logits,
        )


class MiniMaxM2ForCausalLM(MiniMaxM2PreTrainedModel, GenerationMixin):
    def __init__(self, config: MiniMaxM2Config) -> None:
        super().__init__(config)
        self.model = MiniMaxM2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.logits_processor = MiniMaxM2LogitsProcessor(config)

        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.model.embed_tokens

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.model.embed_tokens = value

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -past_key_values.get_seq_length() - 1 :]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "inputs_embeds": inputs_embeds,
        }

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
    ) -> Union[MoeCausalLMOutputWithPast, Tuple]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        model_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=True,
        )

        hidden_states = model_outputs.last_hidden_state
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) and logits_to_keep > 0 else slice(None)
        logits = self.logits_processor(self.lm_head, hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))

        aux_loss = None
        if output_router_logits and model_outputs.router_logits is not None:
            aux_loss = load_balancing_loss_func(
                model_outputs.router_logits,
                num_experts=self.num_experts,
                top_k=self.num_experts_per_tok,
                attention_mask=attention_mask,
            )
            if loss is not None:
                loss = loss + self.router_aux_loss_coef * aux_loss.to(loss.device)

        if not return_dict:
            output = (logits,) + (model_outputs.past_key_values,)
            if output_hidden_states:
                output += (model_outputs.hidden_states,)
            if output_attentions:
                output += (model_outputs.attentions,)
            if output_router_logits:
                output += (model_outputs.router_logits,)
            return ((loss,) + output) if loss is not None else output

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
            router_logits=model_outputs.router_logits,
        )


# -----------------------------------------------------------------------------
# Backward compatibility aliases
# -----------------------------------------------------------------------------

MiniMaxRMSNorm = MiniMaxM2RMSNorm
MiniMaxSparseMoeBlock = MiniMaxM2SparseMoeBlock
MiniMaxAttention = MiniMaxM2Attention
MiniMaxDecoderLayer = MiniMaxM2DecoderLayer
MiniMaxMLP = MiniMaxM2MLP
MiniMaxPreTrainedModel = MiniMaxM2PreTrainedModel
MiniMaxModel = MiniMaxM2Model


class MiniMaxForCausalLM(MiniMaxM2ForCausalLM):
    """Alias for compatibility with checkpoints exporting MiniMaxForCausalLM."""


__all__ = [
    "MiniMaxM2RMSNorm",
    "MiniMaxM2SparseMoeBlock",
    "MiniMaxM2Attention",
    "MiniMaxM2DecoderLayer",
    "MiniMaxM2Model",
    "MiniMaxM2ForCausalLM",
    "MiniMaxM2PreTrainedModel",
    "MiniMaxRMSNorm",
    "MiniMaxSparseMoeBlock",
    "MiniMaxAttention",
    "MiniMaxDecoderLayer",
    "MiniMaxPreTrainedModel",
    "MiniMaxModel",
    "MiniMaxMLP",
    "MiniMaxForCausalLM",
]
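To make the partial-RoPE convention concrete, here is a small self-contained check of `apply_rotary_pos_emb_partial` with made-up shapes (the import path follows the test script's package layout); it verifies that channels beyond `rotary_dim` pass through untouched.

# Example (not part of the commit): partial rotary embedding sanity check.
import torch

from gptqmodel.hf_minimax_m2.modeling_minimax_m2 import apply_rotary_pos_emb_partial

bsz, seq_len, num_heads, head_dim, rotary_dim = 1, 4, 2, 16, 8

q = torch.randn(bsz, seq_len, num_heads, head_dim)
k = torch.randn(bsz, seq_len, num_heads, head_dim)

# Hand-built cos/sin tables of width rotary_dim (the helper slices to this width anyway).
inv_freq = 1.0 / (10_000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
angles = torch.arange(seq_len).float()[:, None] * inv_freq[None, :]   # (seq_len, rotary_dim / 2)
cos = torch.cat([angles.cos(), angles.cos()], dim=-1)[None, :, :]     # (1, seq_len, rotary_dim)
sin = torch.cat([angles.sin(), angles.sin()], dim=-1)[None, :, :]

q_out, k_out = apply_rotary_pos_emb_partial(q, k, cos, sin, rotary_dim)

# Only the first rotary_dim channels are rotated; the remaining channels are copied through.
assert torch.equal(q_out[..., rotary_dim:], q[..., rotary_dim:])
assert torch.equal(k_out[..., rotary_dim:], k[..., rotary_dim:])
print(q_out.shape)  # torch.Size([1, 4, 2, 16])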
test_minimax_m2_hf.py
ADDED
|
@@ -0,0 +1,178 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

"""
MiniMax-M2 Hugging Face checkpoint sanity check with streaming output.

Usage:
    python test_minimax_m2_hf.py \
        --model-path /monster/data/model/MiniMax-M2-bf16 \
        --question "How many letter A are there in the word Alphabet? Reply with the number only."
"""

from __future__ import annotations

import argparse
import threading
from pathlib import Path

import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from gptqmodel.hf_minimax_m2.modeling_minimax_m2 import (
    MiniMaxAttention,
    MiniMaxDecoderLayer,
    MiniMaxForCausalLM,
    MiniMaxMLP,
    MiniMaxM2Attention,
    MiniMaxM2DecoderLayer,
    MiniMaxM2ForCausalLM,
    MiniMaxM2MLP,
    MiniMaxM2RMSNorm,
    MiniMaxM2SparseMoeBlock,
    MiniMaxRMSNorm,
    MiniMaxSparseMoeBlock,
)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="MiniMax-M2 HF checkpoint smoke test.")
    parser.add_argument(
        "--model-path",
        type=str,
        default="/monster/data/model/MiniMax-M2-bf16",
        help="Path to the MiniMax-M2 Hugging Face checkpoint directory.",
    )
    parser.add_argument(
        "--question",
        type=str,
        default="How many letter A are there in the word Alphabet? Reply with the number only.",
        help="User question to send through the chat template.",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=512,
        help="Maximum number of new tokens to sample from the model.",
    )
    return parser.parse_args()


def build_prompt(tokenizer: AutoTokenizer, question: str) -> str:
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": question},
    ]
    return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)


def assert_module_types(model: MiniMaxM2ForCausalLM) -> None:
    causal_lm_types = (MiniMaxM2ForCausalLM, MiniMaxForCausalLM)
    decoder_layer_types = (MiniMaxM2DecoderLayer, MiniMaxDecoderLayer)
    attention_types = (MiniMaxM2Attention, MiniMaxAttention)
    moe_block_types = (MiniMaxM2SparseMoeBlock, MiniMaxSparseMoeBlock)
    norm_types = (MiniMaxM2RMSNorm, MiniMaxRMSNorm)
    mlp_types = (MiniMaxM2MLP, MiniMaxMLP)

    assert isinstance(
        model, causal_lm_types
    ), f"Expected MiniMaxM2ForCausalLM/MiniMaxForCausalLM, received {type(model).__name__}"

    decoder = getattr(model, "model", None)
    assert decoder is not None, "Model is missing the `model` attribute with decoder layers."

    for layer_idx, layer in enumerate(decoder.layers):
        assert isinstance(
            layer, decoder_layer_types
        ), f"Layer {layer_idx}: expected MiniMax(M2)DecoderLayer, got {type(layer).__name__}"
        assert isinstance(
            layer.self_attn, attention_types
        ), f"Layer {layer_idx}: unexpected self_attn type {type(layer.self_attn).__name__}"
        assert isinstance(
            layer.block_sparse_moe, moe_block_types
        ), f"Layer {layer_idx}: unexpected MoE block type {type(layer.block_sparse_moe).__name__}"
        assert isinstance(
            layer.input_layernorm, norm_types
        ), f"Layer {layer_idx}: unexpected input_layernorm type {type(layer.input_layernorm).__name__}"
        assert isinstance(
            layer.post_attention_layernorm, norm_types
        ), f"Layer {layer_idx}: unexpected post_attention_layernorm type {type(layer.post_attention_layernorm).__name__}"

        moe_block = layer.block_sparse_moe
        assert isinstance(
            moe_block.experts, nn.ModuleList
        ), f"Layer {layer_idx}: expected experts to be a ModuleList, got {type(moe_block.experts).__name__}"
        for expert_idx, expert in enumerate(moe_block.experts):
            assert isinstance(
                expert, mlp_types
            ), f"Layer {layer_idx} expert {expert_idx}: expected MiniMax(M2)MLP, got {type(expert).__name__}"


def main() -> None:
    args = parse_args()
    model_path = Path(args.model_path).expanduser().resolve()

    print(f"Loading tokenizer from {model_path}...")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    print(f"Loading model from {model_path}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        dtype="bfloat16",
        device_map="auto",
        trust_remote_code=True,
    )

    # Uncomment to enforce module type checks.
    # print("Validating module types...")
    # assert_module_types(model)

    prompt = build_prompt(tokenizer, args.question)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    print("Running generation (streaming)...\n")
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=False)
    eos_ids = model.generation_config.eos_token_id
    if eos_ids is None:
        eos_ids = []
    elif isinstance(eos_ids, int):
        eos_ids = [eos_ids]
    think_end_id = tokenizer.convert_tokens_to_ids("</think>")
    if think_end_id is not None and think_end_id not in eos_ids:
        eos_ids = eos_ids + [think_end_id]

    generation_kwargs = dict(
        **inputs,
        max_new_tokens=args.max_new_tokens,
        streamer=streamer,
        eos_token_id=eos_ids if eos_ids else None,
    )

    generation_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    generation_thread.start()

    completion = []
    first_chunk = True
    seen_end_reasoning = False
    for text in streamer:
        if first_chunk:
            print("<think>", end="", flush=True)
            completion.append("<think>")
            first_chunk = False
        print(text, end="", flush=True)
        completion.append(text)
        if "</think>" in text:
            seen_end_reasoning = True

    generation_thread.join()
    print("\n\n=== Completed Response ===")
    final_text = "".join(completion).strip()
    print(final_text or "<empty response>")
    if not seen_end_reasoning:
        print("\n[warning] No </think> token detected in streamed output.", flush=True)


if __name__ == "__main__":
    main()
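
When the threaded `TextIteratorStreamer` setup above gets in the way of debugging, a non-streaming variant of the same check is sometimes easier to reason about. The following is only a sketch, not part of the committed test; the checkpoint path, `dtype` string, and question simply mirror the script's defaults.

```python
# Minimal non-streaming sanity check (sketch; path, dtype, and prompt are assumptions
# copied from the defaults of test_minimax_m2_hf.py).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "/monster/data/model/MiniMax-M2-bf16"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    dtype="bfloat16",
    device_map="auto",
    trust_remote_code=True,
)

prompt = tokenizer.apply_chat_template(
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "How many letter A are there in the word Alphabet? Reply with the number only."},
    ],
    add_generation_prompt=True,
    tokenize=False,
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)

# Decode only the newly generated tokens, keeping special tokens such as </think> visible.
new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=False))
```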