Text Generation
Transformers
Safetensors
minimax_m2
conversational
custom_code
fp8
rogeryoungh commited on
Commit
78ba357
·
verified ·
1 Parent(s): c2b7e11

Prepare support transformers (#24)

Browse files

- update: prepare for transformers (6d30eb0302d6f70383e70ab9f6dc60466c542eeb)
- update: add modeling_minimax_m2.py (35c5c796737083a6db5becdc0d4a4ffbcd1192db)
- update: transformers docs (30a4a95ba4b15478b7ad7834cbe99db88330b79e)
- update: fix generation_config in docs (8f96eee3cf95d2565e1274f85771639a2e3272b1)
- fix: import Unpack from processing_utils (3df78b4d6ccbef8925c7a9c5e345298b72bf1ae8)

README.md CHANGED
@@ -179,6 +179,9 @@ We recommend using [vLLM](https://docs.vllm.ai/en/stable/) to serve MiniMax-M2.
179
 
180
  We recommend using [MLX-LM](https://github.com/ml-explore/mlx-lm) to serve MiniMax-M2. Please refer to our [MLX Deployment Guide](https://huggingface.co/MiniMaxAI/MiniMax-M2/blob/main/docs/mlx_deploy_guide.md) for more details.
181
 
 
 
 
182
 
183
  ### Inference Parameters
184
  We recommend using the following parameters for best performance: `temperature=1.0`, `top_p = 0.95`, `top_k = 40`.
 
179
 
180
  We recommend using [MLX-LM](https://github.com/ml-explore/mlx-lm) to serve MiniMax-M2. Please refer to our [MLX Deployment Guide](https://huggingface.co/MiniMaxAI/MiniMax-M2/blob/main/docs/mlx_deploy_guide.md) for more details.
181
 
182
+ ### Transformers
183
+
184
+ We recommend using [Transformers](https://github.com/huggingface/transformers) to serve MiniMax-M2. Please refer to our [Transformers Deployment Guide](https://huggingface.co/MiniMaxAI/MiniMax-M2/blob/main/docs/transformers_deploy_guide.md) for more details.
185
 
186
  ### Inference Parameters
187
  We recommend using the following parameters for best performance: `temperature=1.0`, `top_p = 0.95`, `top_k = 40`.
config.json CHANGED
@@ -67,6 +67,10 @@
67
  1,
68
  1
69
  ],
 
 
 
 
70
  "bos_token_id": null,
71
  "eos_token_id": null,
72
  "head_dim": 128,
@@ -79,7 +83,7 @@
79
  "layernorm_mlp_beta": 1.0,
80
  "max_position_embeddings": 196608,
81
  "mlp_intermediate_size": 8192,
82
- "model_type": "minimax",
83
  "mtp_transformer_layers": 1,
84
  "num_attention_heads": 48,
85
  "num_experts_per_tok": 8,
@@ -96,6 +100,11 @@
96
  "weight_block_size": [
97
  128,
98
  128
 
 
 
 
 
99
  ]
100
  },
101
  "rms_norm_eps": 1e-06,
@@ -108,10 +117,10 @@
108
  "shared_moe_mode": "sigmoid",
109
  "sliding_window": null,
110
  "tie_word_embeddings": false,
111
- "transformers_version": "4.46.1",
112
  "use_cache": true,
113
  "use_mtp": true,
114
  "use_qk_norm": true,
115
  "use_routing_bias": true,
116
  "vocab_size": 200064
117
- }
 
67
  1,
68
  1
69
  ],
70
+ "auto_map": {
71
+ "AutoConfig": "configuration_minimax_m2.MiniMaxM2Config",
72
+ "AutoModelForCausalLM": "modeling_minimax_m2.MiniMaxM2ForCausalLM"
73
+ },
74
  "bos_token_id": null,
75
  "eos_token_id": null,
76
  "head_dim": 128,
 
83
  "layernorm_mlp_beta": 1.0,
84
  "max_position_embeddings": 196608,
85
  "mlp_intermediate_size": 8192,
86
+ "model_type": "minimax_m2",
87
  "mtp_transformer_layers": 1,
88
  "num_attention_heads": 48,
89
  "num_experts_per_tok": 8,
 
100
  "weight_block_size": [
101
  128,
102
  128
103
+ ],
104
+ "modules_to_not_convert": [
105
+ "gate",
106
+ "e_score_correction_bias",
107
+ "lm_head"
108
  ]
109
  },
110
  "rms_norm_eps": 1e-06,
 
117
  "shared_moe_mode": "sigmoid",
118
  "sliding_window": null,
119
  "tie_word_embeddings": false,
120
+ "transformers_version": "4.57.1",
121
  "use_cache": true,
122
  "use_mtp": true,
123
  "use_qk_norm": true,
124
  "use_routing_bias": true,
125
  "vocab_size": 200064
126
+ }
configuration_minimax_m2.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/minimax_m2/modular_minimax_m2.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_minimax_m2.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 the HuggingFace Team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+
23
+ from transformers.configuration_utils import PretrainedConfig
24
+
25
+
26
+ class MiniMaxM2Config(PretrainedConfig):
27
+ r"""
28
+ This is the configuration class to store the configuration of a [`MiniMaxM2Model`]. It is used to instantiate an
29
+ MiniMaxM2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
30
+ with the defaults will yield a similar configuration to that of the MiniMaxM2-7B-v0.1 or MiniMaxM2-7B-Instruct-v0.1.
31
+
32
+ [minimax_m2ai/MiniMaxM2-8x7B](https://huggingface.co/minimax_m2ai/MiniMaxM2-8x7B)
33
+ [minimax_m2ai/MiniMaxM2-7B-Instruct-v0.1](https://huggingface.co/minimax_m2ai/MiniMaxM2-7B-Instruct-v0.1)
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 32000):
41
+ Vocabulary size of the MiniMaxM2 model. Defines the number of different tokens that can be represented by the
42
+ `inputs_ids` passed when calling [`MiniMaxM2Model`]
43
+ hidden_size (`int`, *optional*, defaults to 4096):
44
+ Dimension of the hidden representations.
45
+ intermediate_size (`int`, *optional*, defaults to 14336):
46
+ Dimension of the MLP representations.
47
+ num_hidden_layers (`int`, *optional*, defaults to 32):
48
+ Number of hidden layers in the Transformer encoder.
49
+ num_attention_heads (`int`, *optional*, defaults to 32):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ num_key_value_heads (`int`, *optional*, defaults to 8):
52
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
53
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
54
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
55
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
56
+ by meanpooling all the original heads within that group. For more details, check out [this
57
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
58
+ head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
59
+ The attention head dimension.
60
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
61
+ The non-linear activation function (function or string) in the decoder.
62
+ max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
63
+ The maximum sequence length that this model might ever be used with. MiniMaxM2's sliding window attention
64
+ allows sequence of up to 4096*32 tokens.
65
+ initializer_range (`float`, *optional*, defaults to 0.02):
66
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
67
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
68
+ The epsilon used by the rms normalization layers.
69
+ use_cache (`bool`, *optional*, defaults to `True`):
70
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
71
+ relevant if `config.is_decoder=True`.
72
+ pad_token_id (`int`, *optional*):
73
+ The id of the padding token.
74
+ bos_token_id (`int`, *optional*, defaults to 1):
75
+ The id of the "beginning-of-sequence" token.
76
+ eos_token_id (`int`, *optional*, defaults to 2):
77
+ The id of the "end-of-sequence" token.
78
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
79
+ Whether the model's input and output word embeddings should be tied.
80
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
81
+ The base period of the RoPE embeddings.
82
+ sliding_window (`int`, *optional*):
83
+ Sliding window attention window size. If not specified, will default to `4096`.
84
+ attention_dropout (`float`, *optional*, defaults to 0.0):
85
+ The dropout ratio for the attention probabilities.
86
+ num_experts_per_tok (`int`, *optional*, defaults to 2):
87
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
88
+ parameter
89
+ num_local_experts (`int`, *optional*, defaults to 8):
90
+ Number of experts per Sparse MLP layer.
91
+ output_router_logits (`bool`, *optional*, defaults to `False`):
92
+ Whether or not the router logits should be returned by the model. Enabling this will also
93
+ allow the model to output the auxiliary loss. See [here]() for more details
94
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
95
+ The aux loss factor for the total loss.
96
+ router_jitter_noise (`float`, *optional*, defaults to 0.0):
97
+ Amount of noise to add to the router.
98
+
99
+ ```python
100
+ >>> from transformers import MiniMaxM2Model, MiniMaxM2Config
101
+
102
+ >>> # Initializing a MiniMaxM2 7B style configuration
103
+ >>> configuration = MiniMaxM2Config()
104
+
105
+ >>> # Initializing a model from the MiniMaxM2 7B style configuration
106
+ >>> model = MiniMaxM2Model(configuration)
107
+
108
+ >>> # Accessing the model configuration
109
+ >>> configuration = model.config
110
+ ```"""
111
+
112
+ model_type = "minimax_m2"
113
+ keys_to_ignore_at_inference = ["past_key_values"]
114
+ base_model_tp_plan = {
115
+ "layers.*.self_attn.q_proj": "colwise",
116
+ "layers.*.self_attn.k_proj": "colwise",
117
+ "layers.*.self_attn.v_proj": "colwise",
118
+ "layers.*.self_attn.o_proj": "rowwise",
119
+ "layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts
120
+ "layers.*.block_sparse_moe.experts.*.w1": "colwise",
121
+ "layers.*.block_sparse_moe.experts.*.w2": "rowwise",
122
+ "layers.*.block_sparse_moe.experts.*.w3": "colwise",
123
+ }
124
+ base_model_pp_plan = {
125
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
126
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
127
+ "norm": (["hidden_states"], ["hidden_states"]),
128
+ }
129
+
130
+ def __init__(
131
+ self,
132
+ vocab_size=32000,
133
+ hidden_size=4096,
134
+ intermediate_size=14336,
135
+ num_hidden_layers=32,
136
+ num_attention_heads=32,
137
+ num_key_value_heads=8,
138
+ head_dim=None,
139
+ hidden_act="silu",
140
+ max_position_embeddings=4096 * 32,
141
+ initializer_range=0.02,
142
+ rms_norm_eps=1e-5,
143
+ use_cache=True,
144
+ pad_token_id=None,
145
+ bos_token_id=1,
146
+ eos_token_id=2,
147
+ tie_word_embeddings=False,
148
+ rope_theta=1e6,
149
+ sliding_window=None,
150
+ attention_dropout=0.0,
151
+ num_experts_per_tok=2,
152
+ num_local_experts=8,
153
+ output_router_logits=False,
154
+ router_aux_loss_coef=0.001,
155
+ router_jitter_noise=0.0,
156
+ **kwargs,
157
+ ):
158
+ self.vocab_size = vocab_size
159
+ self.max_position_embeddings = max_position_embeddings
160
+ self.hidden_size = hidden_size
161
+ self.intermediate_size = intermediate_size
162
+ self.num_hidden_layers = num_hidden_layers
163
+ self.num_attention_heads = num_attention_heads
164
+ self.sliding_window = sliding_window
165
+
166
+ # for backward compatibility
167
+ if num_key_value_heads is None:
168
+ num_key_value_heads = num_attention_heads
169
+
170
+ self.num_key_value_heads = num_key_value_heads
171
+ self.hidden_act = hidden_act
172
+ self.initializer_range = initializer_range
173
+ self.rms_norm_eps = rms_norm_eps
174
+ self.use_cache = use_cache
175
+ self.rope_theta = rope_theta
176
+ self.attention_dropout = attention_dropout
177
+ self.head_dim = head_dim
178
+
179
+ self.num_experts_per_tok = num_experts_per_tok
180
+ self.num_local_experts = num_local_experts
181
+ self.output_router_logits = output_router_logits
182
+ self.router_aux_loss_coef = router_aux_loss_coef
183
+ self.router_jitter_noise = router_jitter_noise
184
+
185
+ self.use_qk_norm = kwargs.pop("use_qk_norm", False)
186
+ self.rotary_dim = kwargs.pop("rotary_dim", self.head_dim)
187
+ self.partial_rotary_factor = kwargs.pop("partial_rotary_factor", 1)
188
+ if self.head_dim is not None:
189
+ self.partial_rotary_factor = self.rotary_dim / self.head_dim
190
+
191
+ super().__init__(
192
+ pad_token_id=pad_token_id,
193
+ bos_token_id=bos_token_id,
194
+ eos_token_id=eos_token_id,
195
+ tie_word_embeddings=tie_word_embeddings,
196
+ **kwargs,
197
+ )
198
+
199
+
200
+ __all__ = ["MiniMaxM2Config"]
docs/transformers_deploy_guide.md ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MiniMax M2 Model Transformers Deployment Guide
2
+
3
+ [English Version](./tramsformers_deploy_guide.md) | [Chinese Version](./tramsformers_deploy_guide_cn.md)
4
+
5
+ ## Applicable Models
6
+
7
+ This document applies to the following models. You only need to change the model name during deployment.
8
+
9
+ - [MiniMaxAI/MiniMax-M2](https://huggingface.co/MiniMaxAI/MiniMax-M2)
10
+
11
+ The deployment process is illustrated below using MiniMax-M2 as an example.
12
+
13
+ ## System Requirements
14
+
15
+ - OS: Linux
16
+
17
+ - Python: 3.9 - 3.12
18
+
19
+ - Transformers: 4.57.1
20
+
21
+ - GPU:
22
+
23
+ - compute capability 7.0 or higher
24
+
25
+ - Memory requirements: 220 GB for weights.
26
+
27
+ ## Deployment with Python
28
+
29
+ It is recommended to use a virtual environment (such as **venv**, **conda**, or **uv**) to avoid dependency conflicts.
30
+
31
+ We recommend installing Transformers in a fresh Python environment:
32
+
33
+ ```bash
34
+ uv pip install transformers torch accelerate --torch-backend=auto
35
+ ```
36
+
37
+ Run the following Python script to run the model. Transformers will automatically download and cache the MiniMax-M2 model from Hugging Face.
38
+
39
+ ```python
40
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
41
+ import torch
42
+
43
+ MODEL_PATH = "MiniMaxAI/MiniMax-M2"
44
+
45
+ model = AutoModelForCausalLM.from_pretrained(
46
+ MODEL_PATH,
47
+ device_map="auto",
48
+ trust_remote_code=True,
49
+ )
50
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
51
+
52
+ messages = [
53
+ {"role": "user", "content": [{"type": "text", "text": "What is your favourite condiment?"}]},
54
+ {"role": "assistant", "content": [{"type": "text", "text": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"}]},
55
+ {"role": "user", "content": [{"type": "text", "text": "Do you have mayonnaise recipes?"}]}
56
+ ]
57
+
58
+ model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda")
59
+
60
+ generated_ids = model.generate(model_inputs, max_new_tokens=100, generation_config=model.generation_config)
61
+
62
+ response = tokenizer.batch_decode(generated_ids)[0]
63
+
64
+ print(response)
65
+ ```
66
+
67
+ ## Common Issues
68
+
69
+ ### Hugging Face Network Issues
70
+
71
+ If you encounter network issues, you can set up a proxy before pulling the model.
72
+
73
+ ```bash
74
+ export HF_ENDPOINT=https://hf-mirror.com
75
+ ```
76
+
77
+ ### MiniMax-M2 model is not currently supported
78
+
79
+ Please check that trust_remote_code=True.
80
+
81
+ ## Getting Support
82
+
83
+ If you encounter any issues while deploying the MiniMax model:
84
+
85
+ - Contact our technical support team through official channels such as email at [model@minimax.io](mailto:model@minimax.io)
86
+
87
+ - Submit an issue on our [GitHub](https://github.com/MiniMax-AI) repository
88
+
89
+ We continuously optimize the deployment experience for our models. Feedback is welcome!
90
+
docs/transformers_deploy_guide_cn.md ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MiniMax M2 模型 Transformers 部署指南
2
+
3
+ [英文版](./transformers_deploy_guide.md) | [中文版](./transformers_deploy_guide_cn.md)
4
+
5
+ ## 本文档适用模型
6
+
7
+ 本文档适用以下模型,只需在部署时修改模型名称即可。
8
+
9
+ - [MiniMaxAI/MiniMax-M2](https://huggingface.co/MiniMaxAI/MiniMax-M2)
10
+
11
+ 以下以 MiniMax-M2 为例说明部署流程。
12
+
13
+ ## 环境要求
14
+
15
+ - OS:Linux
16
+
17
+ - Python:3.9 - 3.12
18
+
19
+ - Transformers: 4.57.1
20
+
21
+ - GPU:
22
+
23
+ - compute capability 7.0 or higher
24
+
25
+ - 显存需求:权重需要 220 GB
26
+
27
+ ## 使用 Python 部署
28
+
29
+ 建议使用虚拟环境(如 **venv**、**conda**、**uv**)以避免依赖冲突。
30
+
31
+ 建议在全新的 Python 环境中安装 Transformers:
32
+
33
+ ```bash
34
+ uv pip install transformers torch accelerate --torch-backend=auto
35
+ ```
36
+
37
+ 运行如下 Python 命令运行模型,Transformers 会自动从 Huggingface 下载并缓存 MiniMax-M2 模型。
38
+
39
+ ```python
40
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
41
+ import torch
42
+
43
+ MODEL_PATH = "MiniMaxAI/MiniMax-M2"
44
+
45
+ model = AutoModelForCausalLM.from_pretrained(
46
+ MODEL_PATH,
47
+ device_map="auto",
48
+ trust_remote_code=True,
49
+ )
50
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
51
+
52
+ messages = [
53
+ {"role": "user", "content": [{"type": "text", "text": "What is your favourite condiment?"}]},
54
+ {"role": "assistant", "content": [{"type": "text", "text": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"}]},
55
+ {"role": "user", "content": [{"type": "text", "text": "Do you have mayonnaise recipes?"}]}
56
+ ]
57
+
58
+ model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda")
59
+
60
+ generated_ids = model.generate(model_inputs, max_new_tokens=100, generation_config=model.generation_config)
61
+
62
+ response = tokenizer.batch_decode(generated_ids)[0]
63
+
64
+ print(response)
65
+ ```
66
+
67
+ ## 常见问题
68
+
69
+ ### Huggingface 网络问题
70
+
71
+ 如果遇到网络问题,可以设置代理后再进行拉取。
72
+
73
+ ```bash
74
+ export HF_ENDPOINT=https://hf-mirror.com
75
+ ```
76
+
77
+ ### MiniMax-M2 model is not currently supported
78
+
79
+ 请确认开启 trust_remote_code=True。
80
+
81
+ ## 获取支持
82
+
83
+ 如果在部署 MiniMax 模型过程中遇到任何问题:
84
+
85
+ - 通过邮箱 [model@minimax.io](mailto:model@minimax.io) 等官方渠道联系我们的技术支持团队
86
+
87
+ - 在我们的 [GitHub](https://github.com/MiniMax-AI) 仓库提交 Issue
88
+
89
+ - 通过我们的 [官方企业微信交流群](https://github.com/MiniMax-AI/MiniMax-AI.github.io/blob/main/images/wechat-qrcode.jpeg) 反馈
90
+
91
+ 我们会持续优化模型的部署体验,欢迎反馈!
generation_config.json CHANGED
@@ -1,5 +1,7 @@
1
  {
 
2
  "do_sample": true,
 
3
  "temperature": 1.0,
4
  "top_p": 0.95,
5
  "top_k": 40,
 
1
  {
2
+ "bos_token_id": 200019,
3
  "do_sample": true,
4
+ "eos_token_id": 200020,
5
  "temperature": 1.0,
6
  "top_p": 0.95,
7
  "top_k": 40,
modeling_minimax_m2.py ADDED
@@ -0,0 +1,707 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/minimax_m2/modular_minimax_m2.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_minimax_m2.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 the HuggingFace Team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+
23
+ from collections.abc import Callable
24
+ from typing import Optional, Union
25
+
26
+ import torch
27
+ from torch import nn
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.cache_utils import Cache, DynamicCache
31
+ from transformers.generation import GenerationMixin
32
+ from transformers.integrations import use_kernel_forward_from_hub
33
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
34
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
35
+ from transformers.modeling_layers import (
36
+ GenericForQuestionAnswering,
37
+ GenericForSequenceClassification,
38
+ GenericForTokenClassification,
39
+ GradientCheckpointingLayer,
40
+ )
41
+ from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
42
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
43
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
44
+ from ...processing_utils import Unpack
45
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
46
+ from transformers.utils.deprecation import deprecate_kwarg
47
+ from transformers.utils.generic import OutputRecorder, check_model_inputs
48
+ from .configuration_minimax_m2 import MiniMaxM2Config
49
+
50
+
51
+ class MiniMaxM2MLP(nn.Module):
52
+ def __init__(self, config: MiniMaxM2Config):
53
+ super().__init__()
54
+ self.ffn_dim = config.intermediate_size
55
+ self.hidden_dim = config.hidden_size
56
+
57
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
58
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
59
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
60
+
61
+ self.act_fn = ACT2FN[config.hidden_act]
62
+
63
+ def forward(self, hidden_states):
64
+ current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
65
+ current_hidden_states = self.w2(current_hidden_states)
66
+ return current_hidden_states
67
+
68
+
69
+ class MiniMaxM2Experts(nn.ModuleList):
70
+ """
71
+ ModuleList of experts.
72
+ """
73
+
74
+ def __init__(self, config: MiniMaxM2Config):
75
+ super().__init__()
76
+ self.top_k = config.num_experts_per_tok
77
+ self.num_experts = config.num_local_experts
78
+ for _ in range(self.num_experts):
79
+ self.append(MiniMaxM2MLP(config))
80
+
81
+ def forward(
82
+ self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
83
+ ) -> torch.Tensor:
84
+ """
85
+ Args:
86
+ hidden_states: (batch_size * sequence_length, hidden_dim)
87
+ selected_experts: (batch_size * sequence_length, top_k)
88
+ routing_weights: (batch_size * sequence_length, top_k)
89
+ Returns:
90
+ (batch_size * sequence_length, hidden_dim)
91
+ """
92
+ final_hidden_states = torch.zeros_like(hidden_states)
93
+ expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
94
+
95
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
96
+ for expert_idx in expert_hit:
97
+ idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
98
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
99
+ current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
100
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
101
+ return final_hidden_states
102
+
103
+
104
+ class MiniMaxM2SparseMoeBlock(nn.Module):
105
+ def __init__(self, config):
106
+ super().__init__()
107
+ self.top_k = config.num_experts_per_tok
108
+ self.jitter_noise = config.router_jitter_noise
109
+ self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
110
+ self.experts = MiniMaxM2Experts(config)
111
+ self.register_buffer("e_score_correction_bias", torch.zeros(config.num_local_experts))
112
+
113
+ def route_tokens_to_experts(self, router_logits):
114
+ routing_weights = torch.nn.functional.sigmoid(router_logits.float())
115
+ scores_for_choice = routing_weights + self.e_score_correction_bias
116
+ _, top_k_index = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False)
117
+ top_k_weights = routing_weights.gather(1, top_k_index)
118
+ top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
119
+ return top_k_index, top_k_weights.to(router_logits.dtype)
120
+
121
+ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
122
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
123
+ if self.training and self.jitter_noise > 0:
124
+ hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
125
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
126
+ router_logits = self.gate(hidden_states)
127
+ top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
128
+ hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype))
129
+ hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
130
+ return hidden_states, router_logits
131
+
132
+
133
+ @use_kernel_forward_from_hub("RMSNorm")
134
+ class MiniMaxM2RMSNorm(nn.Module):
135
+ def __init__(self, hidden_size, eps=1e-6):
136
+ """
137
+ MiniMaxM2RMSNorm is equivalent to T5LayerNorm
138
+ """
139
+ super().__init__()
140
+ self.weight = nn.Parameter(torch.ones(hidden_size))
141
+ self.variance_epsilon = eps
142
+
143
+ def forward(self, hidden_states):
144
+ input_dtype = hidden_states.dtype
145
+ hidden_states = hidden_states.to(torch.float32)
146
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
147
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
148
+ return self.weight * hidden_states.to(input_dtype)
149
+
150
+ def extra_repr(self):
151
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
152
+
153
+
154
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
155
+ """
156
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
157
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
158
+ """
159
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
160
+ if n_rep == 1:
161
+ return hidden_states
162
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
163
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
164
+
165
+
166
+ def eager_attention_forward(
167
+ module: nn.Module,
168
+ query: torch.Tensor,
169
+ key: torch.Tensor,
170
+ value: torch.Tensor,
171
+ attention_mask: Optional[torch.Tensor],
172
+ scaling: float,
173
+ dropout: float = 0.0,
174
+ **kwargs: Unpack[TransformersKwargs],
175
+ ):
176
+ key_states = repeat_kv(key, module.num_key_value_groups)
177
+ value_states = repeat_kv(value, module.num_key_value_groups)
178
+
179
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
180
+ if attention_mask is not None:
181
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
182
+ attn_weights = attn_weights + causal_mask
183
+
184
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
185
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
186
+ attn_output = torch.matmul(attn_weights, value_states)
187
+ attn_output = attn_output.transpose(1, 2).contiguous()
188
+
189
+ return attn_output, attn_weights
190
+
191
+
192
+ def rotate_half(x):
193
+ """Rotates half the hidden dims of the input."""
194
+ x1 = x[..., : x.shape[-1] // 2]
195
+ x2 = x[..., x.shape[-1] // 2 :]
196
+ return torch.cat((-x2, x1), dim=-1)
197
+
198
+
199
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
200
+ """Applies Rotary Position Embedding to the query and key tensors.
201
+
202
+ Args:
203
+ q (`torch.Tensor`): The query tensor.
204
+ k (`torch.Tensor`): The key tensor.
205
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
206
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
207
+ position_ids (`torch.Tensor`, *optional*):
208
+ Deprecated and unused.
209
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
210
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
211
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
212
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
213
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
214
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
215
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
216
+ Returns:
217
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
218
+ """
219
+ cos = cos.unsqueeze(unsqueeze_dim)
220
+ sin = sin.unsqueeze(unsqueeze_dim)
221
+
222
+ # Keep half or full tensor for later concatenation
223
+ rotary_dim = cos.shape[-1]
224
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
225
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
226
+
227
+ # Apply rotary embeddings on the first half or full tensor
228
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
229
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
230
+
231
+ # Concatenate back to full shape
232
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
233
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
234
+ return q_embed, k_embed
235
+
236
+
237
+ class MiniMaxM2Attention(nn.Module):
238
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
239
+
240
+ def __init__(self, config: MiniMaxM2Config, layer_idx: int):
241
+ super().__init__()
242
+ self.config = config
243
+ self.layer_idx = layer_idx
244
+ self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
245
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
246
+ self.scaling = self.head_dim**-0.5
247
+ self.attention_dropout = config.attention_dropout
248
+ self.is_causal = True
249
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
250
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
251
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
252
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
253
+
254
+ self.use_qk_norm = config.use_qk_norm
255
+ if self.use_qk_norm:
256
+ self.q_norm = MiniMaxM2RMSNorm(self.head_dim * config.num_attention_heads, eps=config.rms_norm_eps)
257
+ self.k_norm = MiniMaxM2RMSNorm(self.head_dim * config.num_key_value_heads, eps=config.rms_norm_eps)
258
+
259
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
260
+ def forward(
261
+ self,
262
+ hidden_states: torch.Tensor,
263
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
264
+ attention_mask: Optional[torch.Tensor],
265
+ past_key_values: Optional[Cache] = None,
266
+ cache_position: Optional[torch.LongTensor] = None,
267
+ **kwargs: Unpack[FlashAttentionKwargs],
268
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
269
+ input_shape = hidden_states.shape[:-1]
270
+ hidden_shape = (*input_shape, -1, self.head_dim)
271
+
272
+ query_states = self.q_proj(hidden_states)
273
+ key_states = self.k_proj(hidden_states)
274
+ value_states = self.v_proj(hidden_states)
275
+
276
+ if self.use_qk_norm: # main diff from Llama
277
+ query_states = self.q_norm(query_states)
278
+ key_states = self.k_norm(key_states)
279
+
280
+ key_states = key_states.view(hidden_shape)
281
+ query_states = query_states.view(hidden_shape)
282
+ value_states = value_states.view(hidden_shape)
283
+
284
+ query_states = query_states.transpose(1, 2)
285
+ key_states = key_states.transpose(1, 2)
286
+ value_states = value_states.transpose(1, 2)
287
+
288
+ cos, sin = position_embeddings
289
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
290
+
291
+ if past_key_values is not None:
292
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
293
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
294
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
295
+
296
+ attention_interface: Callable = eager_attention_forward
297
+ if self.config._attn_implementation != "eager":
298
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
299
+
300
+ attn_output, attn_weights = attention_interface(
301
+ self,
302
+ query_states,
303
+ key_states,
304
+ value_states,
305
+ attention_mask,
306
+ dropout=0.0 if not self.training else self.attention_dropout,
307
+ scaling=self.scaling,
308
+ **kwargs,
309
+ )
310
+
311
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
312
+ attn_output = self.o_proj(attn_output)
313
+ return attn_output, attn_weights
314
+
315
+
316
+ class MiniMaxM2DecoderLayer(GradientCheckpointingLayer):
317
+ def __init__(self, config: MiniMaxM2Config, layer_idx: int):
318
+ super().__init__()
319
+ self.hidden_size = config.hidden_size
320
+
321
+ self.self_attn = MiniMaxM2Attention(config, layer_idx)
322
+
323
+ self.block_sparse_moe = MiniMaxM2SparseMoeBlock(config)
324
+ self.input_layernorm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
325
+ self.post_attention_layernorm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
326
+
327
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
328
+ def forward(
329
+ self,
330
+ hidden_states: torch.Tensor,
331
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
332
+ attention_mask: Optional[torch.Tensor] = None,
333
+ position_ids: Optional[torch.LongTensor] = None,
334
+ past_key_values: Optional[Cache] = None,
335
+ cache_position: Optional[torch.LongTensor] = None,
336
+ **kwargs: Unpack[TransformersKwargs],
337
+ ) -> torch.FloatTensor:
338
+ residual = hidden_states
339
+
340
+ hidden_states = self.input_layernorm(hidden_states)
341
+
342
+ # Self Attention
343
+ hidden_states, _ = self.self_attn(
344
+ hidden_states=hidden_states,
345
+ position_embeddings=position_embeddings,
346
+ attention_mask=attention_mask,
347
+ position_ids=position_ids,
348
+ past_key_values=past_key_values,
349
+ cache_position=cache_position,
350
+ **kwargs,
351
+ )
352
+ hidden_states = residual + hidden_states
353
+
354
+ # Fully Connected
355
+ residual = hidden_states
356
+ hidden_states = self.post_attention_layernorm(hidden_states)
357
+ hidden_states, _ = self.block_sparse_moe(hidden_states)
358
+ hidden_states = residual + hidden_states
359
+
360
+ return hidden_states
361
+
362
+
363
+ class MiniMaxM2RotaryEmbedding(nn.Module):
364
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
365
+
366
+ def __init__(self, config: MiniMaxM2Config, device=None):
367
+ super().__init__()
368
+ # BC: "rope_type" was originally "type"
369
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
370
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
371
+ else:
372
+ self.rope_type = "default"
373
+ self.max_seq_len_cached = config.max_position_embeddings
374
+ self.original_max_seq_len = config.max_position_embeddings
375
+
376
+ self.config = config
377
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
378
+
379
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
380
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
381
+ self.original_inv_freq = self.inv_freq
382
+
383
+ @torch.no_grad()
384
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
385
+ def forward(self, x, position_ids):
386
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
387
+ position_ids_expanded = position_ids[:, None, :].float()
388
+
389
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
390
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
391
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
392
+ emb = torch.cat((freqs, freqs), dim=-1)
393
+ cos = emb.cos() * self.attention_scaling
394
+ sin = emb.sin() * self.attention_scaling
395
+
396
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
397
+
398
+
399
+ @auto_docstring
400
+ class MiniMaxM2PreTrainedModel(PreTrainedModel):
401
+ config: MiniMaxM2Config
402
+ base_model_prefix = "model"
403
+ supports_gradient_checkpointing = True
404
+ _no_split_modules = ["MiniMaxM2DecoderLayer"]
405
+ _skip_keys_device_placement = ["past_key_values"]
406
+ _supports_flash_attn = True
407
+ _supports_sdpa = True
408
+ _supports_flex_attn = True
409
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
410
+ _supports_attention_backend = True
411
+ _can_record_outputs = {
412
+ "router_logits": OutputRecorder(MiniMaxM2SparseMoeBlock, index=1),
413
+ "hidden_states": MiniMaxM2DecoderLayer,
414
+ "attentions": MiniMaxM2Attention,
415
+ }
416
+
417
+
418
+ @auto_docstring
419
+ class MiniMaxM2Model(MiniMaxM2PreTrainedModel):
420
+ def __init__(self, config: MiniMaxM2Config):
421
+ super().__init__(config)
422
+ self.padding_idx = config.pad_token_id
423
+ self.vocab_size = config.vocab_size
424
+
425
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
426
+ self.layers = nn.ModuleList(
427
+ [MiniMaxM2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
428
+ )
429
+ self.norm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
430
+ self.rotary_emb = MiniMaxM2RotaryEmbedding(config=config)
431
+ self.gradient_checkpointing = False
432
+
433
+ # Initialize weights and apply final processing
434
+ self.post_init()
435
+
436
+ @check_model_inputs
437
+ @auto_docstring
438
+ def forward(
439
+ self,
440
+ input_ids: Optional[torch.LongTensor] = None,
441
+ attention_mask: Optional[torch.Tensor] = None,
442
+ position_ids: Optional[torch.LongTensor] = None,
443
+ past_key_values: Optional[Cache] = None,
444
+ inputs_embeds: Optional[torch.FloatTensor] = None,
445
+ use_cache: Optional[bool] = None,
446
+ cache_position: Optional[torch.LongTensor] = None,
447
+ **kwargs: Unpack[TransformersKwargs],
448
+ ) -> MoeModelOutputWithPast:
449
+ if (input_ids is None) ^ (inputs_embeds is not None):
450
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
451
+
452
+ if use_cache and past_key_values is None:
453
+ past_key_values = DynamicCache(config=self.config)
454
+
455
+ if inputs_embeds is None:
456
+ inputs_embeds = self.embed_tokens(input_ids)
457
+
458
+ if cache_position is None:
459
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
460
+ cache_position = torch.arange(
461
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
462
+ )
463
+ if position_ids is None:
464
+ position_ids = cache_position.unsqueeze(0)
465
+
466
+ mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
467
+ causal_mask = mask_function(
468
+ config=self.config,
469
+ input_embeds=inputs_embeds,
470
+ attention_mask=attention_mask,
471
+ cache_position=cache_position,
472
+ past_key_values=past_key_values,
473
+ position_ids=position_ids,
474
+ )
475
+
476
+ hidden_states = inputs_embeds
477
+
478
+ # create position embeddings to be shared across the decoder layers
479
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
480
+
481
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
482
+ hidden_states = decoder_layer(
483
+ hidden_states,
484
+ position_embeddings=position_embeddings,
485
+ attention_mask=causal_mask,
486
+ position_ids=position_ids,
487
+ past_key_values=past_key_values,
488
+ use_cache=use_cache,
489
+ cache_position=cache_position,
490
+ **kwargs,
491
+ )
492
+
493
+ hidden_states = self.norm(hidden_states)
494
+
495
+ return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
496
+ last_hidden_state=hidden_states,
497
+ past_key_values=past_key_values,
498
+ )
499
+
500
+
501
+ def load_balancing_loss_func(
502
+ gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
503
+ num_experts: Optional[int] = None,
504
+ top_k=2,
505
+ attention_mask: Optional[torch.Tensor] = None,
506
+ ) -> Union[torch.Tensor, int]:
507
+ r"""
508
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
509
+
510
+ See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
511
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
512
+ experts is too unbalanced.
513
+
514
+ Args:
515
+ gate_logits:
516
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
517
+ shape [batch_size X sequence_length, num_experts].
518
+ num_experts:
519
+ Number of experts
520
+ top_k:
521
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
522
+ parameter.
523
+ attention_mask (`torch.Tensor`, *optional*):
524
+ The attention_mask used in forward function
525
+ shape [batch_size X sequence_length] if not None.
526
+
527
+ Returns:
528
+ The auxiliary loss.
529
+ """
530
+ if gate_logits is None or not isinstance(gate_logits, tuple):
531
+ return 0
532
+
533
+ if isinstance(gate_logits, tuple):
534
+ compute_device = gate_logits[0].device
535
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
536
+
537
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
538
+
539
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
540
+
541
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
542
+
543
+ if attention_mask is None:
544
+ # Compute the percentage of tokens routed to each experts
545
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
546
+
547
+ # Compute the average probability of routing to these experts
548
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
549
+ else:
550
+ batch_size, sequence_length = attention_mask.shape
551
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
552
+
553
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
554
+ expert_attention_mask = (
555
+ attention_mask[None, :, :, None, None]
556
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
557
+ .reshape(-1, top_k, num_experts)
558
+ .to(compute_device)
559
+ )
560
+
561
+ # Compute the percentage of tokens routed to each experts
562
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
563
+ expert_attention_mask, dim=0
564
+ )
565
+
566
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
567
+ router_per_expert_attention_mask = (
568
+ attention_mask[None, :, :, None]
569
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
570
+ .reshape(-1, num_experts)
571
+ .to(compute_device)
572
+ )
573
+
574
+ # Compute the average probability of routing to these experts
575
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
576
+ router_per_expert_attention_mask, dim=0
577
+ )
578
+
579
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
580
+ return overall_loss * num_experts
581
+
582
+
583
+ @auto_docstring
584
+ class MiniMaxM2ForCausalLM(MiniMaxM2PreTrainedModel, GenerationMixin):
585
+ _tied_weights_keys = ["lm_head.weight"]
586
+ _tp_plan = {"lm_head": "colwise_rep"}
587
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
588
+
589
+ def __init__(self, config):
590
+ super().__init__(config)
591
+ self.model = MiniMaxM2Model(config)
592
+ self.vocab_size = config.vocab_size
593
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
594
+ self.router_aux_loss_coef = config.router_aux_loss_coef
595
+ self.num_experts = config.num_local_experts
596
+ self.num_experts_per_tok = config.num_experts_per_tok
597
+
598
+ # Initialize weights and apply final processing
599
+ self.post_init()
600
+
601
+ @can_return_tuple
602
+ @auto_docstring
603
+ def forward(
604
+ self,
605
+ input_ids: Optional[torch.LongTensor] = None,
606
+ attention_mask: Optional[torch.Tensor] = None,
607
+ position_ids: Optional[torch.LongTensor] = None,
608
+ past_key_values: Optional[Cache] = None,
609
+ inputs_embeds: Optional[torch.FloatTensor] = None,
610
+ labels: Optional[torch.LongTensor] = None,
611
+ use_cache: Optional[bool] = None,
612
+ output_router_logits: Optional[bool] = None,
613
+ cache_position: Optional[torch.LongTensor] = None,
614
+ logits_to_keep: Union[int, torch.Tensor] = 0,
615
+ **kwargs: Unpack[TransformersKwargs],
616
+ ) -> MoeCausalLMOutputWithPast:
617
+ r"""
618
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
619
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
620
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
621
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
622
+
623
+ Example:
624
+
625
+ ```python
626
+ >>> from transformers import AutoTokenizer, MiniMaxM2ForCausalLM
627
+
628
+ >>> model = MiniMaxM2ForCausalLM.from_pretrained("mistralai/MiniMaxM2-8x7B-v0.1")
629
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/MiniMaxM2-8x7B-v0.1")
630
+
631
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
632
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
633
+
634
+ >>> # Generate
635
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
636
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
637
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
638
+ ```"""
639
+
640
+ output_router_logits = (
641
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
642
+ )
643
+
644
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
645
+ outputs: MoeModelOutputWithPast = self.model(
646
+ input_ids=input_ids,
647
+ attention_mask=attention_mask,
648
+ position_ids=position_ids,
649
+ past_key_values=past_key_values,
650
+ inputs_embeds=inputs_embeds,
651
+ use_cache=use_cache,
652
+ output_router_logits=output_router_logits,
653
+ cache_position=cache_position,
654
+ **kwargs,
655
+ )
656
+
657
+ hidden_states = outputs.last_hidden_state
658
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
659
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
660
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
661
+
662
+ loss = None
663
+ if labels is not None:
664
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
665
+
666
+ aux_loss = None
667
+ if output_router_logits:
668
+ aux_loss = load_balancing_loss_func(
669
+ outputs.router_logits,
670
+ self.num_experts,
671
+ self.num_experts_per_tok,
672
+ attention_mask,
673
+ )
674
+ if labels is not None:
675
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
676
+
677
+ return MoeCausalLMOutputWithPast(
678
+ loss=loss,
679
+ aux_loss=aux_loss,
680
+ logits=logits,
681
+ past_key_values=outputs.past_key_values,
682
+ hidden_states=outputs.hidden_states,
683
+ attentions=outputs.attentions,
684
+ router_logits=outputs.router_logits,
685
+ )
686
+
687
+
688
+ class MiniMaxM2ForSequenceClassification(GenericForSequenceClassification, MiniMaxM2PreTrainedModel):
689
+ pass
690
+
691
+
692
+ class MiniMaxM2ForTokenClassification(GenericForTokenClassification, MiniMaxM2PreTrainedModel):
693
+ pass
694
+
695
+
696
+ class MiniMaxM2ForQuestionAnswering(GenericForQuestionAnswering, MiniMaxM2PreTrainedModel):
697
+ pass
698
+
699
+
700
+ __all__ = [
701
+ "MiniMaxM2ForCausalLM",
702
+ "MiniMaxM2ForQuestionAnswering",
703
+ "MiniMaxM2Model",
704
+ "MiniMaxM2PreTrainedModel",
705
+ "MiniMaxM2ForSequenceClassification",
706
+ "MiniMaxM2ForTokenClassification",
707
+ ]