Qubitium committed
Commit 3301395 · verified · 1 parent: 10a673b

Add files using upload-large-folder tool
README.md CHANGED
@@ -4,7 +4,7 @@ license: mit
 ---
 
 <div align="center">
-Upconverted to BFloat16 by <a href='https://x.com/qubitium'>@Qubitum</a> at ModelCloud
+Upconverted to BFloat16 by <a href='https://x.com/qubitium'>@Qubitum</a> at ModelCloud.
 
 <svg width="60%" height="auto" viewBox="0 0 144 48" fill="none" xmlns="http://www.w3.org/2000/svg">
  <path d="M26.6782 7.96523C26.6782 7.02436 25.913 6.26087 24.9739 6.26087C24.0348 6.26087 23.2695 7.0261 23.2695 7.96523V36.2139C23.2695 38.4 21.4904 40.1791 19.3043 40.1791C17.1183 40.1791 15.3391 38.4 15.3391 36.2139V18.0904C15.3391 17.1496 14.5739 16.3861 13.6348 16.3861C12.6956 16.3861 11.9304 17.1513 11.9304 18.0904V25.7722C11.9304 27.9583 10.1513 29.7374 7.96518 29.7374C5.7791 29.7374 4 27.9583 4 25.7722V22.9878C4 22.3635 4.50609 21.8574 5.13043 21.8574C5.75478 21.8574 6.26087 22.3635 6.26087 22.9878V25.7722C6.26087 26.713 7.02605 27.4765 7.96518 27.4765C8.90431 27.4765 9.66954 26.7113 9.66954 25.7722V18.0904C9.66954 15.9044 11.4487 14.1252 13.6348 14.1252C15.8209 14.1252 17.6 15.9044 17.6 18.0904V36.2139C17.6 37.1548 18.3652 37.9183 19.3043 37.9183C20.2435 37.9183 21.0087 37.153 21.0087 36.2139V25.1322V7.96523C21.0087 5.77914 22.7878 4 24.9739 4C27.16 4 28.9391 5.77914 28.9391 7.96523V31.3565C28.9391 31.9809 28.433 32.487 27.8087 32.487C27.1843 32.487 26.6782 31.9809 26.6782 31.3565V7.96523ZM47.6539 14.1252C45.4678 14.1252 43.6887 15.9044 43.6887 18.0904V33.2296C43.6887 34.1704 42.9235 34.9339 41.9843 34.9339C41.0452 34.9339 40.28 34.1687 40.28 33.2296V7.96523C40.28 5.77914 38.5008 4 36.3148 4C34.1287 4 32.3496 5.77914 32.3496 7.96523V40.0348C32.3496 40.9756 31.5843 41.7391 30.6452 41.7391C29.7061 41.7391 28.9409 40.9739 28.9409 40.0348V36.0643C28.9409 35.44 28.4348 34.9339 27.8104 34.9339C27.1861 34.9339 26.68 35.44 26.68 36.0643V40.0348C26.68 42.2209 28.4591 44 30.6452 44C32.8313 44 34.6104 42.2209 34.6104 40.0348V7.96523C34.6104 7.02436 35.3756 6.26087 36.3148 6.26087C37.2539 6.26087 38.0191 7.0261 38.0191 7.96523V33.2296C38.0191 35.4156 39.7982 37.1948 41.9843 37.1948C44.1704 37.1948 45.9496 35.4156 45.9496 33.2296V18.0904C45.9496 17.1496 46.7148 16.3861 47.6539 16.3861C48.593 16.3861 49.3582 17.1513 49.3582 18.0904V31.3565C49.3582 31.9809 49.8643 32.487 50.4887 32.487C51.113 32.487 51.6191 31.9809 51.6191 31.3565V18.0904C51.6191 15.9044 49.84 14.1252 47.6539 14.1252Z" fill="url(#paint0_linear_17_483)"/>
@@ -186,4 +186,4 @@ Please refer to our [Tool Calling Guide](https://huggingface.co/MiniMaxAI/MiniMa
 
 # Contact Us
 
-Contact us at [model@minimax.io](mailto:model@minimax.io).
+Contact us at [model@minimax.io](mailto:model@minimax.io).
__init__.py ADDED
@@ -0,0 +1,26 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+#
+# """MiniMax M2 Hugging Face remote code support."""
+
+from .configuration_minimax_m2 import MiniMaxM2Config
+from .modeling_minimax_m2 import (
+    MiniMaxForCausalLM,
+    MiniMaxM2ForCausalLM,
+    MiniMaxM2Model,
+    MiniMaxM2PreTrainedModel,
+    MiniMaxModel,
+    MiniMaxPreTrainedModel,
+)
+
+__all__ = [
+    "MiniMaxM2Config",
+    "MiniMaxM2PreTrainedModel",
+    "MiniMaxM2Model",
+    "MiniMaxM2ForCausalLM",
+    "MiniMaxPreTrainedModel",
+    "MiniMaxModel",
+    "MiniMaxForCausalLM",
+]
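The package init only re-exports the configuration and model classes; the un-suffixed MiniMax* names are the backward-compatibility aliases defined at the bottom of modeling_minimax_m2.py, so both naming schemes resolve to the same implementation. A minimal sketch of using it as a local package (the package name `minimax_m2` is hypothetical; on the Hub these modules are loaded through the `auto_map` in config.json instead):

```python
# Hypothetical local-package layout: the directory containing these files is
# importable as "minimax_m2". Requires torch and transformers to be installed.
from minimax_m2 import MiniMaxForCausalLM, MiniMaxM2Config, MiniMaxM2ForCausalLM

# The legacy name is an alias (subclass) of the M2 class.
assert issubclass(MiniMaxForCausalLM, MiniMaxM2ForCausalLM)
assert MiniMaxM2Config.model_type == "minimax"
```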
__pycache__/configuration_minimax_m2.cpython-313.pyc ADDED
Binary file (6.02 kB).
 
config.json CHANGED
@@ -79,7 +79,7 @@
   "layernorm_mlp_beta": 1.0,
   "max_position_embeddings": 196608,
   "mlp_intermediate_size": 8192,
-  "model_type": "mixtral",
+  "model_type": "minimax",
   "mtp_transformer_layers": 1,
   "num_attention_heads": 48,
   "num_experts_per_tok": 8,
@@ -105,5 +105,9 @@
   "use_qk_norm": true,
   "use_routing_bias": true,
   "vocab_size": 200064,
-  "torch_dtype": "bfloat16"
+  "torch_dtype": "bfloat16",
+  "auto_map": {
+    "AutoConfig": "configuration_minimax_m2.MiniMaxM2Config",
+    "AutoModelForCausalLM": "modeling_minimax_m2.MiniMaxM2ForCausalLM"
+  }
 }
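The added `auto_map` entries are what allow transformers to resolve these remote-code classes when the checkpoint is loaded with `trust_remote_code=True`, and the `model_type` now matches the `MiniMaxM2Config.model_type` declared in configuration_minimax_m2.py. A minimal loading sketch (the checkpoint path is a placeholder; `dtype="bfloat16"` assumes a transformers version that accepts the `dtype` keyword, as used in test_minimax_m2_hf.py below):

```python
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_path = "path/to/MiniMax-M2-bf16"  # placeholder checkpoint directory

# auto_map routes these to configuration_minimax_m2.MiniMaxM2Config and
# modeling_minimax_m2.MiniMaxM2ForCausalLM shipped alongside the weights.
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    dtype="bfloat16",
    device_map="auto",
    trust_remote_code=True,
)
```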
configuration_minimax_m2.py ADDED
@@ -0,0 +1,147 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+"""Configuration for the MiniMax M2 architecture."""
+
+from __future__ import annotations
+
+from typing import List, Optional, Union
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class _QuantizationConfigDict(dict):
+    """Ensure quantization config always exposes a `quant_method`."""
+
+    def __init__(self, data: Optional[dict] = None):
+        if data is None:
+            data = {}
+        super().__init__(data)
+        self.setdefault("quant_method", "none")
+
+    def to_dict(self):
+        return dict(self)
+
+
+class MiniMaxM2Config(PretrainedConfig):
+    model_type = "minimax"
+
+    def __init__(
+        self,
+        vocab_size: int = 200_064,
+        hidden_size: int = 3_072,
+        intermediate_size: int = 1_536,
+        mlp_intermediate_size: int = 8_192,
+        num_hidden_layers: int = 62,
+        num_attention_heads: int = 48,
+        num_key_value_heads: int = 8,
+        head_dim: Optional[int] = 128,
+        num_local_experts: int = 256,
+        num_experts_per_tok: int = 8,
+        attn_type_list: Optional[List[int]] = None,
+        attention_dropout: float = 0.0,
+        hidden_act: str = "silu",
+        rms_norm_eps: float = 1e-6,
+        max_position_embeddings: int = 196_608,
+        rope_theta: float = 5_000_000.0,
+        rotary_dim: int = 64,
+        rope_scaling: Optional[dict] = None,
+        use_qk_norm: bool = True,
+        qk_norm_type: str = "per_layer",
+        use_routing_bias: bool = True,
+        scoring_func: str = "sigmoid",
+        router_aux_loss_coef: float = 0.001,
+        router_jitter_noise: float = 0.0,
+        output_router_logits: bool = False,
+        use_grouped_topk: bool = True,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        routed_scaling_factor: float = 1.0,
+        layernorm_full_attention_beta: float = 1.0,
+        layernorm_linear_attention_beta: float = 1.0,
+        layernorm_mlp_beta: float = 1.0,
+        shared_intermediate_size: int = 0,
+        shared_moe_mode: str = "sigmoid",
+        use_mtp: bool = True,
+        num_mtp_modules: int = 3,
+        mtp_transformer_layers: int = 1,
+        attn_window_size: Optional[Union[int, List[int]]] = None,
+        swa_rope_theta: float = -1.0,
+        sliding_window: Optional[int] = None,
+        initializer_range: float = 0.02,
+        tie_word_embeddings: bool = False,
+        max_model_len: Optional[int] = None,
+        bos_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        use_cache: bool = True,
+        **kwargs,
+    ) -> None:
+        quantization_config = kwargs.pop("quantization_config", None)
+        if quantization_config is None:
+            quantization_config = _QuantizationConfigDict()
+        elif not isinstance(quantization_config, _QuantizationConfigDict):
+            quantization_config = _QuantizationConfigDict(quantization_config)
+        transformers_version = kwargs.pop("transformers_version", None)
+
+        super().__init__(
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            pad_token_id=pad_token_id,
+            **kwargs,
+        )
+
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.mlp_intermediate_size = mlp_intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim or hidden_size // num_attention_heads
+        self.num_local_experts = num_local_experts
+        self.num_experts_per_tok = num_experts_per_tok
+        self.attn_type_list = attn_type_list or [1] * num_hidden_layers
+        self.attention_dropout = attention_dropout
+        self.hidden_act = hidden_act
+        self.rms_norm_eps = rms_norm_eps
+        self.max_position_embeddings = max_position_embeddings
+        self.rope_theta = rope_theta
+        self.rotary_dim = rotary_dim
+        self.rope_scaling = rope_scaling
+        self.use_qk_norm = use_qk_norm
+        self.qk_norm_type = qk_norm_type
+        self.use_routing_bias = use_routing_bias
+        self.scoring_func = scoring_func
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.router_jitter_noise = router_jitter_noise
+        self.output_router_logits = output_router_logits
+        self.use_grouped_topk = use_grouped_topk
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.routed_scaling_factor = routed_scaling_factor
+        self.layernorm_full_attention_beta = layernorm_full_attention_beta
+        self.layernorm_linear_attention_beta = layernorm_linear_attention_beta
+        self.layernorm_mlp_beta = layernorm_mlp_beta
+        self.shared_intermediate_size = shared_intermediate_size
+        self.shared_moe_mode = shared_moe_mode
+        self.use_mtp = use_mtp
+        self.num_mtp_modules = num_mtp_modules
+        self.mtp_transformer_layers = mtp_transformer_layers
+        self.attn_window_size = attn_window_size
+        self.swa_rope_theta = swa_rope_theta
+        self.sliding_window = sliding_window
+        self.initializer_range = initializer_range
+        self.max_model_len = max_model_len
+        self.use_cache = use_cache
+
+        # Convenient accessor used by rotary embedding helper
+        self.partial_rotary_factor = float(self.rotary_dim) / float(self.head_dim)
+        self.quantization_config = quantization_config
+        self.transformers_version = transformers_version
+
+
+__all__ = ["MiniMaxM2Config"]
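As a quick illustration of the defaults above: instantiating the config with no arguments reproduces the values seen in config.json, and the partial-rotary factor is derived as rotary_dim / head_dim = 64 / 128 = 0.5. A small sketch (assuming configuration_minimax_m2.py is importable from the working directory):

```python
from configuration_minimax_m2 import MiniMaxM2Config

cfg = MiniMaxM2Config()
print(cfg.model_type)                                  # "minimax"
print(cfg.num_local_experts, cfg.num_experts_per_tok)  # 256 8
print(cfg.partial_rotary_factor)                       # 64 / 128 = 0.5
print(cfg.quantization_config)                         # {'quant_method': 'none'} when none is supplied
```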
modeling_minimax_m2.py ADDED
@@ -0,0 +1,765 @@
1
+ # SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
2
+ # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ # Contact: qubitium@modelcloud.ai, x.com/qubitium
5
+
6
+ """PyTorch implementation of the MiniMax M2 architecture for Hugging Face Transformers."""
7
+
8
+ from __future__ import annotations
9
+
10
+ import copy
11
+ import time
12
+ from typing import Optional, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from torch import nn
17
+
18
+ from transformers.activations import ACT2FN
19
+ from transformers.cache_utils import Cache, DynamicCache
20
+ from transformers.generation import GenerationMixin
21
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
22
+ from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
23
+ from transformers.modeling_utils import PreTrainedModel
24
+ from transformers.utils import logging
25
+
26
+ from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, repeat_kv, rotate_half
27
+
28
+ from .configuration_minimax_m2 import MiniMaxM2Config
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+ _CONFIG_FOR_DOC = "MiniMaxM2Config"
33
+ _CHECKPOINT_FOR_DOC = "MiniMaxAI/MiniMax-M2"
34
+
35
+
36
+ def load_balancing_loss_func(
37
+ gate_logits: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
38
+ num_experts: int,
39
+ top_k: int,
40
+ attention_mask: Optional[torch.Tensor] = None,
41
+ ) -> torch.Tensor:
42
+ if gate_logits is None:
43
+ return torch.tensor(0.0)
44
+ if isinstance(gate_logits, torch.Tensor):
45
+ logits = gate_logits
46
+ else:
47
+ logits = torch.cat([layer_gate.to(gate_logits[0].device) for layer_gate in gate_logits], dim=0)
48
+
49
+ routing_weights = torch.softmax(logits, dim=-1, dtype=torch.float32)
50
+ _, selected = torch.topk(routing_weights, top_k, dim=-1)
51
+ expert_mask = torch.nn.functional.one_hot(selected, num_experts)
52
+
53
+ if attention_mask is None:
54
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
55
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
56
+ else:
57
+ batch_size, seq_len = attention_mask.shape
58
+ num_layers = logits.shape[0] // (batch_size * seq_len)
59
+
60
+ expanded_mask = (
61
+ attention_mask[None, :, :, None, None]
62
+ .expand(num_layers, batch_size, seq_len, top_k, num_experts)
63
+ .reshape(-1, top_k, num_experts)
64
+ .to(logits.device)
65
+ )
66
+ tokens_per_expert = torch.sum(expert_mask.float() * expanded_mask, dim=0) / torch.sum(expanded_mask, dim=0)
67
+
68
+ router_mask = (
69
+ attention_mask[None, :, :, None]
70
+ .expand(num_layers, batch_size, seq_len, num_experts)
71
+ .reshape(-1, num_experts)
72
+ .to(logits.device)
73
+ )
74
+ router_prob_per_expert = torch.sum(routing_weights * router_mask, dim=0) / torch.sum(router_mask, dim=0)
75
+
76
+ loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
77
+ return loss * num_experts
78
+
79
+
80
+ def apply_rotary_pos_emb_partial(
81
+ q: torch.Tensor,
82
+ k: torch.Tensor,
83
+ cos: torch.Tensor,
84
+ sin: torch.Tensor,
85
+ rotary_dim: int,
86
+ unsqueeze_dim: int = 2,
87
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
88
+ cos = cos.unsqueeze(unsqueeze_dim)[..., :rotary_dim]
89
+ sin = sin.unsqueeze(unsqueeze_dim)[..., :rotary_dim]
90
+ q_rot = q[..., :rotary_dim]
91
+ k_rot = k[..., :rotary_dim]
92
+
93
+ q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
94
+ k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin)
95
+
96
+ q = torch.cat((q_rot, q[..., rotary_dim:]), dim=-1)
97
+ k = torch.cat((k_rot, k[..., rotary_dim:]), dim=-1)
98
+ return q, k
99
+
100
+
101
+ class MiniMaxM2RMSNorm(nn.Module):
102
+ def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
103
+ super().__init__()
104
+ self.weight = nn.Parameter(torch.ones(hidden_size))
105
+ self.variance_epsilon = eps
106
+
107
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
108
+ input_dtype = hidden_states.dtype
109
+ hidden_states = hidden_states.to(torch.float32)
110
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
111
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
112
+ return (self.weight * hidden_states).to(input_dtype)
113
+
114
+
115
+ class MiniMaxM2MLP(nn.Module):
116
+ def __init__(self, config: MiniMaxM2Config) -> None:
117
+ super().__init__()
118
+ self.hidden_size = config.hidden_size
119
+ self.intermediate_size = config.intermediate_size
120
+
121
+ self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
122
+ self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
123
+ self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
124
+ self.act_fn = ACT2FN[config.hidden_act]
125
+
126
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
127
+ gate = self.act_fn(self.w1(hidden_states))
128
+ up = self.w3(hidden_states)
129
+ hidden_states = gate * up
130
+ hidden_states = self.w2(hidden_states)
131
+ return hidden_states
132
+
133
+
134
+ class MiniMaxM2SparseMoeBlock(nn.Module):
135
+ def __init__(self, config: MiniMaxM2Config) -> None:
136
+ super().__init__()
137
+ self.hidden_dim = config.hidden_size
138
+ self.experts = nn.ModuleList([MiniMaxM2MLP(config) for _ in range(config.num_local_experts)])
139
+ self.num_experts = config.num_local_experts
140
+ self.top_k = config.num_experts_per_tok
141
+ self.jitter_noise = config.router_jitter_noise
142
+ self.use_routing_bias = config.use_routing_bias
143
+ self.scoring_func = getattr(config, "scoring_func", "softmax")
144
+ self.use_grouped_topk = getattr(config, "use_grouped_topk", False)
145
+ self.num_expert_group = getattr(config, "num_expert_group", None)
146
+ self.topk_group = getattr(config, "topk_group", None)
147
+ self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
148
+
149
+ if self.use_grouped_topk:
150
+ if self.num_expert_group is None or self.num_expert_group <= 0:
151
+ self.num_expert_group = 1
152
+ if self.topk_group is None or self.topk_group <= 0:
153
+ self.topk_group = min(self.num_expert_group, self.top_k)
154
+ else:
155
+ self.num_expert_group = 1
156
+ self.topk_group = 1
157
+
158
+ self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
159
+ if self.use_routing_bias:
160
+ self.e_score_correction_bias = nn.Parameter(torch.zeros(self.num_experts, dtype=torch.float32))
161
+ else:
162
+ self.register_parameter("e_score_correction_bias", None)
163
+
164
+ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
165
+ batch_size, seq_len, hidden_dim = hidden_states.shape
166
+ if self.training and self.jitter_noise > 0:
167
+ noise = torch.empty_like(hidden_states).uniform_(
168
+ 1.0 - self.jitter_noise,
169
+ 1.0 + self.jitter_noise,
170
+ )
171
+ hidden_states = hidden_states * noise
172
+
173
+ hidden_states = hidden_states.view(-1, hidden_dim)
174
+ gate_dtype = self.gate.weight.dtype
175
+ router_logits = self.gate(hidden_states.to(gate_dtype)).to(torch.float32)
176
+ if self.e_score_correction_bias is not None:
177
+ # Bias is applied after scoring (see vLLM/SGLang implementations).
178
+ correction_bias = self.e_score_correction_bias.to(router_logits.device, router_logits.dtype)
179
+ else:
180
+ correction_bias = None
181
+
182
+ if self.scoring_func == "sigmoid":
183
+ scores = torch.sigmoid(router_logits)
184
+ elif self.scoring_func == "softmax":
185
+ scores = torch.softmax(router_logits, dim=-1)
186
+ else:
187
+ raise ValueError(f"Unsupported scoring function: {self.scoring_func}")
188
+
189
+ if correction_bias is not None:
190
+ original_scores = scores
191
+ scores = scores + correction_bias
192
+ else:
193
+ original_scores = scores
194
+ topk_scores: torch.Tensor
195
+ if self.use_grouped_topk and self.num_expert_group > 1:
196
+ experts_per_group = scores.size(-1) // self.num_expert_group
197
+ scores_grouped = scores.view(scores.size(0), self.num_expert_group, experts_per_group)
198
+ if correction_bias is not None:
199
+ topk_in_group = min(2, experts_per_group)
200
+ if topk_in_group > 0:
201
+ group_scores = scores_grouped.topk(topk_in_group, dim=-1)[0].sum(dim=-1)
202
+ else:
203
+ group_scores = torch.zeros_like(scores_grouped[..., 0])
204
+ else:
205
+ group_scores = scores_grouped.max(dim=-1).values
206
+ group_mask = torch.zeros_like(group_scores)
207
+ selected_groups = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True).indices
208
+ group_mask.scatter_(1, selected_groups, 1.0)
209
+ mask = group_mask.unsqueeze(-1).expand(-1, -1, experts_per_group).reshape(scores.size())
210
+ masked_scores = scores.masked_fill(mask == 0, float("-inf"))
211
+ topk_scores, selected_experts = torch.topk(masked_scores, self.top_k, dim=-1, sorted=True)
212
+ else:
213
+ topk_scores, selected_experts = torch.topk(scores, self.top_k, dim=-1, sorted=True)
214
+
215
+ if correction_bias is not None:
216
+ routing_weights = original_scores.gather(1, selected_experts)
217
+ else:
218
+ routing_weights = topk_scores
219
+
220
+ routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-12)
221
+ if self.routed_scaling_factor != 1.0:
222
+ routing_weights = routing_weights * self.routed_scaling_factor
223
+ routing_weights = routing_weights.to(hidden_states.dtype)
224
+ selected_experts = selected_experts.to(torch.long)
225
+
226
+ final_hidden_states = torch.zeros_like(hidden_states)
227
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
228
+ expert_hit = torch.nonzero(expert_mask.sum(dim=(-1, -2)) > 0, as_tuple=False).flatten()
229
+
230
+ for expert_idx in expert_hit.tolist():
231
+ expert_layer = self.experts[expert_idx]
232
+ idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
233
+ token_states = hidden_states.index_select(0, top_x)
234
+ expert_output = expert_layer(token_states) * routing_weights[top_x, idx].unsqueeze(-1)
235
+ final_hidden_states.index_add_(0, top_x, expert_output.to(final_hidden_states.dtype))
236
+
237
+ final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
238
+ return final_hidden_states, router_logits
239
+
240
+
241
+ class MiniMaxM2Attention(nn.Module):
242
+ def __init__(self, config: MiniMaxM2Config, layer_idx: int) -> None:
243
+ super().__init__()
244
+ self.config = config
245
+ self.layer_idx = layer_idx
246
+
247
+ self.head_dim = config.head_dim
248
+ self.num_heads = config.num_attention_heads
249
+ self.num_key_value_heads = config.num_key_value_heads
250
+ self.num_key_value_groups = self.num_heads // max(1, self.num_key_value_heads)
251
+ self.rotary_dim = config.rotary_dim
252
+ self.scaling = self.head_dim**-0.5
253
+ self.attention_dropout = config.attention_dropout
254
+ self.is_causal = True
255
+
256
+ max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
257
+ max_model_len = getattr(config, "max_model_len", None)
258
+ if max_model_len is not None:
259
+ max_position_embeddings = max(max_position_embeddings, max_model_len)
260
+
261
+ attn_window_size = getattr(config, "attn_window_size", None)
262
+ if isinstance(attn_window_size, list):
263
+ sliding_window = attn_window_size[layer_idx]
264
+ else:
265
+ sliding_window = attn_window_size
266
+ if sliding_window is not None and sliding_window <= 0:
267
+ sliding_window = None
268
+ self.sliding_window = sliding_window
269
+
270
+ swa_rope_theta = getattr(config, "swa_rope_theta", -1.0)
271
+ rope_theta = config.rope_theta
272
+ if self.sliding_window is not None and swa_rope_theta > 0:
273
+ rope_theta = swa_rope_theta
274
+
275
+ self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
276
+ self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
277
+ self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
278
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
279
+
280
+ self.use_qk_norm = config.use_qk_norm
281
+ if self.use_qk_norm:
282
+ self.q_norm = MiniMaxM2RMSNorm(self.num_heads * self.head_dim, eps=config.rms_norm_eps)
283
+ self.k_norm = MiniMaxM2RMSNorm(self.num_key_value_heads * self.head_dim, eps=config.rms_norm_eps)
284
+
285
+ rope_config = copy.deepcopy(config)
286
+ rope_config.hidden_size = config.hidden_size
287
+ rope_config.num_attention_heads = config.num_attention_heads
288
+ rope_config.partial_rotary_factor = float(config.rotary_dim) / float(self.head_dim)
289
+ rope_config.rope_theta = rope_theta
290
+ rope_config.max_position_embeddings = max_position_embeddings
291
+ self.rotary_emb = LlamaRotaryEmbedding(rope_config)
292
+
293
+ def forward(
294
+ self,
295
+ hidden_states: torch.Tensor,
296
+ attention_mask: Optional[torch.Tensor] = None,
297
+ position_ids: Optional[torch.LongTensor] = None,
298
+ past_key_values: Optional[Cache] = None,
299
+ use_cache: Optional[bool] = False,
300
+ cache_position: Optional[torch.LongTensor] = None,
301
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
302
+ output_attentions: bool = False,
303
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
304
+ bsz, q_len, _ = hidden_states.size()
305
+
306
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
307
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
308
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
309
+
310
+ if self.use_qk_norm:
311
+ q_flat = query_states.transpose(1, 2).reshape(bsz * q_len, -1)
312
+ k_flat = key_states.transpose(1, 2).reshape(bsz * q_len, -1)
313
+ q_flat = self.q_norm(q_flat)
314
+ k_flat = self.k_norm(k_flat)
315
+ query_states = q_flat.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
316
+ key_states = k_flat.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
317
+
318
+ if position_embeddings is None:
319
+ cos, sin = self.rotary_emb(value_states, position_ids)
320
+ else:
321
+ cos, sin = position_embeddings
322
+
323
+ query_states, key_states = apply_rotary_pos_emb_partial(
324
+ query_states.transpose(1, 2), key_states.transpose(1, 2), cos, sin, self.rotary_dim
325
+ )
326
+ query_states = query_states.transpose(1, 2)
327
+ key_states = key_states.transpose(1, 2)
328
+
329
+ if past_key_values is not None:
330
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
331
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
332
+
333
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
334
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
335
+
336
+ attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling
337
+ if attention_mask is not None:
338
+ attn_weights = attn_weights + attention_mask
339
+
340
+ if self.sliding_window is not None and past_key_values is None:
341
+ query_positions = torch.arange(q_len, device=hidden_states.device).view(1, 1, q_len, 1)
342
+ key_positions = torch.arange(key_states.shape[-2], device=hidden_states.device).view(1, 1, 1, -1)
343
+ window_mask = key_positions < (query_positions - self.sliding_window)
344
+ if window_mask.any():
345
+ attn_weights = attn_weights.masked_fill(window_mask, float("-inf"))
346
+
347
+ attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
348
+ if self.training and self.attention_dropout > 0:
349
+ attn_weights = F.dropout(attn_weights, p=self.attention_dropout)
350
+
351
+ attn_output = torch.matmul(attn_weights, value_states)
352
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
353
+ attn_output = self.o_proj(attn_output)
354
+
355
+ if not output_attentions:
356
+ attn_weights = None
357
+ return attn_output, attn_weights
358
+
359
+
360
+ class MiniMaxM2LogitsProcessor(nn.Module):
361
+ def __init__(self, config: MiniMaxM2Config) -> None:
362
+ super().__init__()
363
+ self.scale = getattr(config, "logits_scale", 1.0)
364
+
365
+ def forward(self, lm_head: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
366
+ logits = lm_head(hidden_states)
367
+ if self.scale != 1.0:
368
+ logits = logits * self.scale
369
+ return logits
370
+
371
+
372
+ class MiniMaxM2DecoderLayer(nn.Module):
373
+ def __init__(self, config: MiniMaxM2Config, layer_idx: int) -> None:
374
+ super().__init__()
375
+ self.hidden_size = config.hidden_size
376
+ self.self_attn = MiniMaxM2Attention(config, layer_idx)
377
+ self.block_sparse_moe = MiniMaxM2SparseMoeBlock(config)
378
+ self.input_layernorm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
379
+ self.post_attention_layernorm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
380
+
381
+ def forward(
382
+ self,
383
+ hidden_states: torch.Tensor,
384
+ attention_mask: Optional[torch.Tensor] = None,
385
+ position_ids: Optional[torch.LongTensor] = None,
386
+ past_key_values: Optional[Cache] = None,
387
+ use_cache: Optional[bool] = False,
388
+ cache_position: Optional[torch.LongTensor] = None,
389
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
390
+ output_attentions: bool = False,
391
+ residual: Optional[torch.Tensor] = None,
392
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
393
+ residual_input = hidden_states if residual is None else residual
394
+ hidden_states = self.input_layernorm(hidden_states)
395
+
396
+ attn_output, attn_weights = self.self_attn(
397
+ hidden_states=hidden_states,
398
+ attention_mask=attention_mask,
399
+ position_ids=position_ids,
400
+ past_key_values=past_key_values,
401
+ use_cache=use_cache,
402
+ cache_position=cache_position,
403
+ position_embeddings=position_embeddings,
404
+ output_attentions=output_attentions,
405
+ )
406
+ hidden_states = residual_input + attn_output
407
+
408
+ residual_post_attn = hidden_states
409
+ hidden_states = self.post_attention_layernorm(hidden_states)
410
+ moe_output, router_logits = self.block_sparse_moe(hidden_states)
411
+ hidden_states = residual_post_attn + moe_output
412
+
413
+ return hidden_states, hidden_states, router_logits, attn_weights
414
+
415
+
416
+ class MiniMaxM2PreTrainedModel(PreTrainedModel):
417
+ config_class = MiniMaxM2Config
418
+ base_model_prefix = "model"
419
+ supports_gradient_checkpointing = True
420
+ _no_split_modules = ["MiniMaxM2DecoderLayer"]
421
+ _supports_flash_attn = False
422
+ _supports_sdpa = False
423
+ _supports_attention_backend = False
424
+
425
+ def _init_weights(self, module: nn.Module) -> None:
426
+ if isinstance(module, nn.Linear):
427
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
428
+ if module.bias is not None:
429
+ module.bias.data.zero_()
430
+ elif isinstance(module, nn.Embedding):
431
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
432
+ if module.padding_idx is not None:
433
+ module.weight.data[module.padding_idx].zero_()
434
+
435
+ def _remap_qkv_weights(self, state_dict):
436
+ num_q = self.config.num_attention_heads * self.config.head_dim
437
+ num_kv = self.config.num_key_value_heads * self.config.head_dim
438
+
439
+ for layer_idx in range(self.config.num_hidden_layers):
440
+ prefix = f"model.layers.{layer_idx}.self_attn"
441
+ weight_key = f"{prefix}.qkv_proj.weight"
442
+ if weight_key in state_dict:
443
+ qkv_weight = state_dict.pop(weight_key)
444
+ q_weight, k_weight, v_weight = qkv_weight.split([num_q, num_kv, num_kv], dim=0)
445
+ state_dict.setdefault(f"{prefix}.q_proj.weight", q_weight)
446
+ state_dict.setdefault(f"{prefix}.k_proj.weight", k_weight)
447
+ state_dict.setdefault(f"{prefix}.v_proj.weight", v_weight)
448
+
449
+ def load_state_dict(self, state_dict, strict: bool = True):
450
+ if not isinstance(state_dict, dict):
451
+ raise TypeError(f"Expected state_dict to be dict, got {type(state_dict)}")
452
+
453
+ filtered_state_dict = {}
454
+ drop_suffixes = ("weight_scale_inv", "weight_scale", "input_scale", "scales", "amax")
455
+ for key, value in state_dict.items():
456
+ if key.endswith(drop_suffixes) or "fp8" in key:
457
+ continue
458
+ filtered_state_dict[key] = value
459
+
460
+ self._remap_qkv_weights(filtered_state_dict)
461
+
462
+ if logger.isEnabledFor(logging.INFO):
463
+ logger.info(
464
+ "MiniMaxM2: loading %d tensors (filtered from %d original).",
465
+ len(filtered_state_dict),
466
+ len(state_dict),
467
+ )
468
+
469
+ load_start = time.perf_counter()
470
+ result = super().load_state_dict(filtered_state_dict, strict=strict)
471
+ load_elapsed = time.perf_counter() - load_start
472
+ if logger.isEnabledFor(logging.INFO):
473
+ logger.info("MiniMaxM2: state_dict load finished in %.2f seconds.", load_elapsed)
474
+
475
+ return result
476
+
477
+
478
+ class MiniMaxM2Model(MiniMaxM2PreTrainedModel):
479
+ def __init__(self, config: MiniMaxM2Config) -> None:
480
+ super().__init__(config)
481
+ self.padding_idx = config.pad_token_id
482
+ self.vocab_size = config.vocab_size
483
+
484
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
485
+ self.layers = nn.ModuleList(
486
+ [MiniMaxM2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
487
+ )
488
+ self.norm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
489
+ self.gradient_checkpointing = False
490
+
491
+ self.post_init()
492
+
493
+ def get_input_embeddings(self) -> nn.Module:
494
+ return self.embed_tokens
495
+
496
+ def set_input_embeddings(self, value: nn.Module) -> None:
497
+ self.embed_tokens = value
498
+
499
+ def forward(
500
+ self,
501
+ input_ids: Optional[torch.LongTensor] = None,
502
+ attention_mask: Optional[torch.Tensor] = None,
503
+ position_ids: Optional[torch.LongTensor] = None,
504
+ past_key_values: Optional[Cache] = None,
505
+ inputs_embeds: Optional[torch.Tensor] = None,
506
+ cache_position: Optional[torch.LongTensor] = None,
507
+ use_cache: Optional[bool] = None,
508
+ output_attentions: bool = False,
509
+ output_hidden_states: bool = False,
510
+ output_router_logits: Optional[bool] = None,
511
+ return_dict: Optional[bool] = None,
512
+ ) -> Union[MoeModelOutputWithPast, Tuple]:
513
+ if (input_ids is None) == (inputs_embeds is None):
514
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")
515
+
516
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
517
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
518
+ output_router_logits = (
519
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
520
+ )
521
+
522
+ if inputs_embeds is None:
523
+ inputs_embeds = self.embed_tokens(input_ids)
524
+
525
+ if use_cache and past_key_values is None:
526
+ past_key_values = DynamicCache(config=self.config)
527
+
528
+ if cache_position is None:
529
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
530
+ cache_position = torch.arange(
531
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
532
+ )
533
+
534
+ if position_ids is None:
535
+ position_ids = cache_position.unsqueeze(0)
536
+
537
+ if self.config.sliding_window is not None:
538
+ causal_mask = create_sliding_window_causal_mask(
539
+ config=self.config,
540
+ input_embeds=inputs_embeds,
541
+ attention_mask=attention_mask,
542
+ cache_position=cache_position,
543
+ past_key_values=past_key_values,
544
+ position_ids=position_ids,
545
+ )
546
+ else:
547
+ causal_mask = create_causal_mask(
548
+ config=self.config,
549
+ input_embeds=inputs_embeds,
550
+ attention_mask=attention_mask,
551
+ cache_position=cache_position,
552
+ past_key_values=past_key_values,
553
+ position_ids=position_ids,
554
+ )
555
+
556
+ hidden_states = inputs_embeds
557
+
558
+ all_hidden_states = () if output_hidden_states else None
559
+ all_attentions = () if output_attentions else None
560
+ all_router_logits = () if output_router_logits else None
561
+
562
+ residual = None
563
+ for decoder_layer in self.layers:
564
+ if output_hidden_states:
565
+ all_hidden_states = all_hidden_states + (hidden_states,)
566
+
567
+ layer_outputs = decoder_layer(
568
+ hidden_states,
569
+ attention_mask=causal_mask,
570
+ position_ids=position_ids,
571
+ past_key_values=past_key_values,
572
+ use_cache=use_cache,
573
+ cache_position=cache_position,
574
+ position_embeddings=None,
575
+ output_attentions=output_attentions,
576
+ residual=residual,
577
+ )
578
+
579
+ hidden_states, residual, router_logits, attn_weights = layer_outputs
580
+
581
+ if output_router_logits:
582
+ all_router_logits = all_router_logits + (router_logits,)
583
+ if output_attentions:
584
+ all_attentions = all_attentions + (attn_weights,)
585
+
586
+ hidden_states = self.norm(hidden_states)
587
+
588
+ if output_hidden_states:
589
+ all_hidden_states = all_hidden_states + (hidden_states,)
590
+
591
+ if not return_dict:
592
+ outputs = (hidden_states, past_key_values)
593
+ if output_hidden_states:
594
+ outputs += (all_hidden_states,)
595
+ if output_attentions:
596
+ outputs += (all_attentions,)
597
+ if output_router_logits:
598
+ outputs += (all_router_logits,)
599
+ return outputs
600
+
601
+ return MoeModelOutputWithPast(
602
+ last_hidden_state=hidden_states,
603
+ past_key_values=past_key_values,
604
+ hidden_states=all_hidden_states,
605
+ attentions=all_attentions,
606
+ router_logits=all_router_logits,
607
+ )
608
+
609
+
610
+ class MiniMaxM2ForCausalLM(MiniMaxM2PreTrainedModel, GenerationMixin):
611
+ def __init__(self, config: MiniMaxM2Config) -> None:
612
+ super().__init__(config)
613
+ self.model = MiniMaxM2Model(config)
614
+ self.vocab_size = config.vocab_size
615
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
616
+ self.router_aux_loss_coef = config.router_aux_loss_coef
617
+ self.num_experts = config.num_local_experts
618
+ self.num_experts_per_tok = config.num_experts_per_tok
619
+ self.logits_processor = MiniMaxM2LogitsProcessor(config)
620
+
621
+ self.post_init()
622
+
623
+ def get_input_embeddings(self) -> nn.Module:
624
+ return self.model.embed_tokens
625
+
626
+ def set_input_embeddings(self, value: nn.Module) -> None:
627
+ self.model.embed_tokens = value
628
+
629
+ def get_output_embeddings(self) -> nn.Module:
630
+ return self.lm_head
631
+
632
+ def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
633
+ self.lm_head = new_embeddings
634
+
635
+ def prepare_inputs_for_generation(
636
+ self,
637
+ input_ids: torch.LongTensor,
638
+ past_key_values: Optional[Cache] = None,
639
+ attention_mask: Optional[torch.Tensor] = None,
640
+ inputs_embeds: Optional[torch.Tensor] = None,
641
+ **kwargs,
642
+ ):
643
+ if past_key_values is not None:
644
+ input_ids = input_ids[:, -1:]
645
+ if attention_mask is not None:
646
+ attention_mask = attention_mask[:, -past_key_values.get_seq_length() - 1 :]
647
+
648
+ return {
649
+ "input_ids": input_ids,
650
+ "attention_mask": attention_mask,
651
+ "past_key_values": past_key_values,
652
+ "inputs_embeds": inputs_embeds,
653
+ }
654
+
655
+ def forward(
656
+ self,
657
+ input_ids: Optional[torch.LongTensor] = None,
658
+ attention_mask: Optional[torch.Tensor] = None,
659
+ position_ids: Optional[torch.LongTensor] = None,
660
+ past_key_values: Optional[Cache] = None,
661
+ inputs_embeds: Optional[torch.Tensor] = None,
662
+ labels: Optional[torch.LongTensor] = None,
663
+ cache_position: Optional[torch.LongTensor] = None,
664
+ use_cache: Optional[bool] = None,
665
+ output_attentions: bool = False,
666
+ output_hidden_states: bool = False,
667
+ output_router_logits: Optional[bool] = None,
668
+ return_dict: Optional[bool] = None,
669
+ logits_to_keep: Union[int, torch.Tensor] = 0,
670
+ ) -> Union[MoeCausalLMOutputWithPast, Tuple]:
671
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
672
+ output_router_logits = (
673
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
674
+ )
675
+
676
+ model_outputs = self.model(
677
+ input_ids=input_ids,
678
+ attention_mask=attention_mask,
679
+ position_ids=position_ids,
680
+ past_key_values=past_key_values,
681
+ inputs_embeds=inputs_embeds,
682
+ cache_position=cache_position,
683
+ use_cache=use_cache,
684
+ output_attentions=output_attentions,
685
+ output_hidden_states=output_hidden_states,
686
+ output_router_logits=output_router_logits,
687
+ return_dict=True,
688
+ )
689
+
690
+ hidden_states = model_outputs.last_hidden_state
691
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) and logits_to_keep > 0 else slice(None)
692
+ logits = self.logits_processor(self.lm_head, hidden_states[:, slice_indices, :])
693
+
694
+ loss = None
695
+ if labels is not None:
696
+ shift_logits = logits[..., :-1, :].contiguous()
697
+ shift_labels = labels[..., 1:].contiguous()
698
+ loss_fct = nn.CrossEntropyLoss()
699
+ loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
700
+
701
+ aux_loss = None
702
+ if output_router_logits and model_outputs.router_logits is not None:
703
+ aux_loss = load_balancing_loss_func(
704
+ model_outputs.router_logits,
705
+ num_experts=self.num_experts,
706
+ top_k=self.num_experts_per_tok,
707
+ attention_mask=attention_mask,
708
+ )
709
+ if loss is not None:
710
+ loss = loss + self.router_aux_loss_coef * aux_loss.to(loss.device)
711
+
712
+ if not return_dict:
713
+ output = (logits,) + (model_outputs.past_key_values,)
714
+ if output_hidden_states:
715
+ output += (model_outputs.hidden_states,)
716
+ if output_attentions:
717
+ output += (model_outputs.attentions,)
718
+ if output_router_logits:
719
+ output += (model_outputs.router_logits,)
720
+ return ((loss,) + output) if loss is not None else output
721
+
722
+ return MoeCausalLMOutputWithPast(
723
+ loss=loss,
724
+ aux_loss=aux_loss,
725
+ logits=logits,
726
+ past_key_values=model_outputs.past_key_values,
727
+ hidden_states=model_outputs.hidden_states,
728
+ attentions=model_outputs.attentions,
729
+ router_logits=model_outputs.router_logits,
730
+ )
731
+
732
+ # -----------------------------------------------------------------------------
733
+ # Backward compatibility aliases
734
+ # -----------------------------------------------------------------------------
735
+
736
+ MiniMaxRMSNorm = MiniMaxM2RMSNorm
737
+ MiniMaxSparseMoeBlock = MiniMaxM2SparseMoeBlock
738
+ MiniMaxAttention = MiniMaxM2Attention
739
+ MiniMaxDecoderLayer = MiniMaxM2DecoderLayer
740
+ MiniMaxMLP = MiniMaxM2MLP
741
+ MiniMaxPreTrainedModel = MiniMaxM2PreTrainedModel
742
+ MiniMaxModel = MiniMaxM2Model
743
+
744
+
745
+ class MiniMaxForCausalLM(MiniMaxM2ForCausalLM):
746
+ """Alias for compatibility with checkpoints exporting MiniMaxForCausalLM."""
747
+
748
+
749
+ __all__ = [
750
+ "MiniMaxM2RMSNorm",
751
+ "MiniMaxM2SparseMoeBlock",
752
+ "MiniMaxM2Attention",
753
+ "MiniMaxM2DecoderLayer",
754
+ "MiniMaxM2Model",
755
+ "MiniMaxM2ForCausalLM",
756
+ "MiniMaxM2PreTrainedModel",
757
+ "MiniMaxRMSNorm",
758
+ "MiniMaxSparseMoeBlock",
759
+ "MiniMaxAttention",
760
+ "MiniMaxDecoderLayer",
761
+ "MiniMaxPreTrainedModel",
762
+ "MiniMaxModel",
763
+ "MiniMaxMLP",
764
+ "MiniMaxForCausalLM",
765
+ ]
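The most distinctive piece of the modeling code is the router in MiniMaxM2SparseMoeBlock: sigmoid scoring, an additive e_score_correction_bias used only for expert selection, grouped top-k over expert groups, and mixture weights renormalized from the unbiased scores. The toy sketch below re-derives that selection on made-up sizes (8 experts, 2 groups, top-2); it is an illustration of the code above, not part of the checkpoint:

```python
import torch

# Toy re-derivation of the routing in MiniMaxM2SparseMoeBlock (sizes are made up).
num_tokens, num_experts, num_expert_group, top_k, topk_group = 3, 8, 2, 2, 1
experts_per_group = num_experts // num_expert_group

router_logits = torch.randn(num_tokens, num_experts)
bias = torch.zeros(num_experts)                      # e_score_correction_bias (zeros at init)

scores = torch.sigmoid(router_logits)                # scoring_func == "sigmoid"
biased = scores + bias                               # bias steers selection only

# Grouped top-k: rank groups by the sum of their two best biased scores,
# keep `topk_group` groups, then pick `top_k` experts inside the kept groups.
grouped = biased.view(num_tokens, num_expert_group, experts_per_group)
group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)
keep = torch.topk(group_scores, k=topk_group, dim=-1).indices
group_mask = torch.zeros_like(group_scores).scatter_(1, keep, 1.0)
mask = group_mask.unsqueeze(-1).expand(-1, -1, experts_per_group).reshape(num_tokens, num_experts)
masked = biased.masked_fill(mask == 0, float("-inf"))
_, selected_experts = torch.topk(masked, top_k, dim=-1)

# Final mixture weights come from the *unbiased* scores, renormalized over the selection.
weights = scores.gather(1, selected_experts)
weights = weights / weights.sum(dim=-1, keepdim=True).clamp(min=1e-12)
print(selected_experts, weights)
```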
test_minimax_m2_hf.py ADDED
@@ -0,0 +1,178 @@
1
+ # SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
2
+ # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ # Contact: qubitium@modelcloud.ai, x.com/qubitium
5
+
6
+ """
7
+ MiniMax-M2 Hugging Face checkpoint sanity check with streaming output.
8
+
9
+ Usage:
10
+ python test_minimax_m2_hf.py \
11
+ --model-path /monster/data/model/MiniMax-M2-bf16 \
12
+ --question "How many letter A are there in the word Alphabet? Reply with the number only."
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import threading
19
+ from pathlib import Path
20
+
21
+ import torch.nn as nn
22
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
23
+
24
+ from gptqmodel.hf_minimax_m2.modeling_minimax_m2 import (
25
+ MiniMaxAttention,
26
+ MiniMaxDecoderLayer,
27
+ MiniMaxForCausalLM,
28
+ MiniMaxMLP,
29
+ MiniMaxM2Attention,
30
+ MiniMaxM2DecoderLayer,
31
+ MiniMaxM2ForCausalLM,
32
+ MiniMaxM2MLP,
33
+ MiniMaxM2RMSNorm,
34
+ MiniMaxM2SparseMoeBlock,
35
+ MiniMaxRMSNorm,
36
+ MiniMaxSparseMoeBlock,
37
+ )
38
+
39
+
40
+ def parse_args() -> argparse.Namespace:
41
+ parser = argparse.ArgumentParser(description="MiniMax-M2 HF checkpoint smoke test.")
42
+ parser.add_argument(
43
+ "--model-path",
44
+ type=str,
45
+ default="/monster/data/model/MiniMax-M2-bf16",
46
+ help="Path to the MiniMax-M2 Hugging Face checkpoint directory.",
47
+ )
48
+ parser.add_argument(
49
+ "--question",
50
+ type=str,
51
+ default="How many letter A are there in the word Alphabet? Reply with the number only.",
52
+ help="User question to send through the chat template.",
53
+ )
54
+ parser.add_argument(
55
+ "--max-new-tokens",
56
+ type=int,
57
+ default=512,
58
+ help="Maximum number of new tokens to sample from the model.",
59
+ )
60
+ return parser.parse_args()
61
+
62
+
63
+ def build_prompt(tokenizer: AutoTokenizer, question: str) -> str:
64
+ messages = [
65
+ {"role": "system", "content": "You are a helpful assistant."},
66
+ {"role": "user", "content": question},
67
+ ]
68
+ return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
69
+
70
+
71
+ def assert_module_types(model: MiniMaxM2ForCausalLM) -> None:
72
+ causal_lm_types = (MiniMaxM2ForCausalLM, MiniMaxForCausalLM)
73
+ decoder_layer_types = (MiniMaxM2DecoderLayer, MiniMaxDecoderLayer)
74
+ attention_types = (MiniMaxM2Attention, MiniMaxAttention)
75
+ moe_block_types = (MiniMaxM2SparseMoeBlock, MiniMaxSparseMoeBlock)
76
+ norm_types = (MiniMaxM2RMSNorm, MiniMaxRMSNorm)
77
+ mlp_types = (MiniMaxM2MLP, MiniMaxMLP)
78
+
79
+ assert isinstance(
80
+ model, causal_lm_types
81
+ ), f"Expected MiniMaxM2ForCausalLM/MiniMaxForCausalLM, received {type(model).__name__}"
82
+
83
+ decoder = getattr(model, "model", None)
84
+ assert decoder is not None, "Model is missing the `model` attribute with decoder layers."
85
+
86
+ for layer_idx, layer in enumerate(decoder.layers):
87
+ assert isinstance(
88
+ layer, decoder_layer_types
89
+ ), f"Layer {layer_idx}: expected MiniMax(M2)DecoderLayer, got {type(layer).__name__}"
90
+ assert isinstance(
91
+ layer.self_attn, attention_types
92
+ ), f"Layer {layer_idx}: unexpected self_attn type {type(layer.self_attn).__name__}"
93
+ assert isinstance(
94
+ layer.block_sparse_moe, moe_block_types
95
+ ), f"Layer {layer_idx}: unexpected MoE block type {type(layer.block_sparse_moe).__name__}"
96
+ assert isinstance(
97
+ layer.input_layernorm, norm_types
98
+ ), f"Layer {layer_idx}: unexpected input_layernorm type {type(layer.input_layernorm).__name__}"
99
+ assert isinstance(
100
+ layer.post_attention_layernorm, norm_types
101
+ ), f"Layer {layer_idx}: unexpected post_attention_layernorm type {type(layer.post_attention_layernorm).__name__}"
102
+
103
+ moe_block = layer.block_sparse_moe
104
+ assert isinstance(
105
+ moe_block.experts, nn.ModuleList
106
+ ), f"Layer {layer_idx}: expected experts to be a ModuleList, got {type(moe_block.experts).__name__}"
107
+ for expert_idx, expert in enumerate(moe_block.experts):
108
+ assert isinstance(
109
+ expert, mlp_types
110
+ ), f"Layer {layer_idx} expert {expert_idx}: expected MiniMax(M2)MLP, got {type(expert).__name__}"
111
+
112
+
113
+ def main() -> None:
114
+ args = parse_args()
115
+ model_path = Path(args.model_path).expanduser().resolve()
116
+
117
+ print(f"Loading tokenizer from {model_path}...")
118
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
119
+
120
+ print(f"Loading model from {model_path}...")
121
+ model = AutoModelForCausalLM.from_pretrained(
122
+ model_path,
123
+ dtype="bfloat16",
124
+ device_map="auto",
125
+ trust_remote_code=True,
126
+ )
127
+
128
+ # Uncomment to enforce module type checks.
129
+ # print("Validating module types...")
130
+ # assert_module_types(model)
131
+
132
+ prompt = build_prompt(tokenizer, args.question)
133
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
134
+
135
+ print("Running generation (streaming)...\n")
136
+ streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=False)
137
+ eos_ids = model.generation_config.eos_token_id
138
+ if eos_ids is None:
139
+ eos_ids = []
140
+ elif isinstance(eos_ids, int):
141
+ eos_ids = [eos_ids]
142
+ think_end_id = tokenizer.convert_tokens_to_ids("</think>")
143
+ if think_end_id is not None and think_end_id not in eos_ids:
144
+ eos_ids = eos_ids + [think_end_id]
145
+
146
+ generation_kwargs = dict(
147
+ **inputs,
148
+ max_new_tokens=args.max_new_tokens,
149
+ streamer=streamer,
150
+ eos_token_id=eos_ids if eos_ids else None,
151
+ )
152
+
153
+ generation_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
154
+ generation_thread.start()
155
+
156
+ completion = []
157
+ first_chunk = True
158
+ seen_end_reasoning = False
159
+ for text in streamer:
160
+ if first_chunk:
161
+ print("<think>", end="", flush=True)
162
+ completion.append("<think>")
163
+ first_chunk = False
164
+ print(text, end="", flush=True)
165
+ completion.append(text)
166
+ if "</think>" in text:
167
+ seen_end_reasoning = True
168
+
169
+ generation_thread.join()
170
+ print("\n\n=== Completed Response ===")
171
+ final_text = "".join(completion).strip()
172
+ print(final_text or "<empty response>")
173
+ if not seen_end_reasoning:
174
+ print("\n[warning] No </think> token detected in streamed output.", flush=True)
175
+
176
+
177
+ if __name__ == "__main__":
178
+ main()