# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
"""
MiniMax-M2 Hugging Face checkpoint sanity check with streaming output.

Usage:
    python test_minimax_m2_hf.py \
        --model-path /monster/data/model/MiniMax-M2-bf16 \
        --question "How many letter A are there in the word Alphabet? Reply with the number only."
"""
from __future__ import annotations

import argparse
import threading
from pathlib import Path
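# torch.nn is only referenced by the optional module-type checks below (currently commented out).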
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
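# Optional imports for assert_module_types(); uncomment together with the type checks in main().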
# from gptqmodel.hf_minimax_m2.modeling_minimax_m2 import (
#     MiniMaxAttention,
#     MiniMaxDecoderLayer,
#     MiniMaxForCausalLM,
#     MiniMaxMLP,
#     MiniMaxM2Attention,
#     MiniMaxM2DecoderLayer,
#     MiniMaxM2ForCausalLM,
#     MiniMaxM2MLP,
#     MiniMaxM2RMSNorm,
#     MiniMaxM2SparseMoeBlock,
#     MiniMaxRMSNorm,
#     MiniMaxSparseMoeBlock,
# )
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="MiniMax-M2 HF checkpoint smoke test.")
    parser.add_argument(
        "--model-path",
        type=str,
        default="/monster/data/model/MiniMax-M2-bf16",
        help="Path to the MiniMax-M2 Hugging Face checkpoint directory.",
    )
    parser.add_argument(
        "--question",
        type=str,
        default="How many letter A are there in the word Alphabet? Reply with the number only.",
        help="User question to send through the chat template.",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=512,
        help="Maximum number of new tokens to sample from the model.",
    )
    return parser.parse_args()
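# Wrap the user question in the model's chat template together with a generic system prompt.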
def build_prompt(tokenizer: AutoTokenizer, question: str) -> str:
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": question},
    ]
    return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
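# Optional strict check that the loaded model is built from the expected MiniMax(M2) module classes.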
# def assert_module_types(model: MiniMaxM2ForCausalLM) -> None:
#     causal_lm_types = (MiniMaxM2ForCausalLM, MiniMaxForCausalLM)
#     decoder_layer_types = (MiniMaxM2DecoderLayer, MiniMaxDecoderLayer)
#     attention_types = (MiniMaxM2Attention, MiniMaxAttention)
#     moe_block_types = (MiniMaxM2SparseMoeBlock, MiniMaxSparseMoeBlock)
#     norm_types = (MiniMaxM2RMSNorm, MiniMaxRMSNorm)
#     mlp_types = (MiniMaxM2MLP, MiniMaxMLP)
#
#     assert isinstance(
#         model, causal_lm_types
#     ), f"Expected MiniMaxM2ForCausalLM/MiniMaxForCausalLM, received {type(model).__name__}"
#
#     decoder = getattr(model, "model", None)
#     assert decoder is not None, "Model is missing the `model` attribute with decoder layers."
#
#     for layer_idx, layer in enumerate(decoder.layers):
#         assert isinstance(
#             layer, decoder_layer_types
#         ), f"Layer {layer_idx}: expected MiniMax(M2)DecoderLayer, got {type(layer).__name__}"
#         assert isinstance(
#             layer.self_attn, attention_types
#         ), f"Layer {layer_idx}: unexpected self_attn type {type(layer.self_attn).__name__}"
#         assert isinstance(
#             layer.block_sparse_moe, moe_block_types
#         ), f"Layer {layer_idx}: unexpected MoE block type {type(layer.block_sparse_moe).__name__}"
#         assert isinstance(
#             layer.input_layernorm, norm_types
#         ), f"Layer {layer_idx}: unexpected input_layernorm type {type(layer.input_layernorm).__name__}"
#         assert isinstance(
#             layer.post_attention_layernorm, norm_types
#         ), f"Layer {layer_idx}: unexpected post_attention_layernorm type {type(layer.post_attention_layernorm).__name__}"
#
#         moe_block = layer.block_sparse_moe
#         assert isinstance(
#             moe_block.experts, nn.ModuleList
#         ), f"Layer {layer_idx}: expected experts to be a ModuleList, got {type(moe_block.experts).__name__}"
#         for expert_idx, expert in enumerate(moe_block.experts):
#             assert isinstance(
#                 expert, mlp_types
#             ), f"Layer {layer_idx} expert {expert_idx}: expected MiniMax(M2)MLP, got {type(expert).__name__}"
#
def main() -> None:
    args = parse_args()
    model_path = Path(args.model_path).expanduser().resolve()
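    # trust_remote_code allows any custom tokenizer/modeling code bundled with the checkpoint to be used.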
    print(f"Loading tokenizer from {model_path}...")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
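    # Load weights in bfloat16 and let device_map="auto" shard them across the available devices.
    # Note: the bare `dtype=` kwarg assumes a recent transformers release; older versions use `torch_dtype=`.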
    print(f"Loading model from {model_path}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        dtype="bfloat16",
        device_map="auto",
        trust_remote_code=True,
    )
    # Uncomment to enforce module type checks.
    # print("Validating module types...")
    # assert_module_types(model)
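    # Build the chat-formatted prompt and move the tokenized inputs onto the model's device.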
    prompt = build_prompt(tokenizer, args.question)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
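    # Generation output is streamed; special tokens are kept (skip_special_tokens=False) so
    # reasoning markers such as </think> remain visible in the printout.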
    print("Running generation (streaming)...\n")
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=False)
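    # Normalize eos_token_id to a list and, if the tokenizer knows </think>, treat it as an
    # additional stop token.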
    eos_ids = model.generation_config.eos_token_id
    if eos_ids is None:
        eos_ids = []
    elif isinstance(eos_ids, int):
        eos_ids = [eos_ids]
    think_end_id = tokenizer.convert_tokens_to_ids("</think>")
    if think_end_id is not None and think_end_id not in eos_ids:
        eos_ids = eos_ids + [think_end_id]
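    # Run generate() on a worker thread so the main thread can consume the streamer as text arrives.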
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=args.max_new_tokens,
        streamer=streamer,
        eos_token_id=eos_ids if eos_ids else None,
    )
    generation_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    generation_thread.start()
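    # Prepend an opening <think> tag to the transcript and track whether a closing </think>
    # marker ever shows up in the streamed text.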
    completion = []
    first_chunk = True
    seen_end_reasoning = False
    for text in streamer:
        if first_chunk:
            print("<think>", end="", flush=True)
            completion.append("<think>")
            first_chunk = False
        print(text, end="", flush=True)
        completion.append(text)
        if "</think>" in text:
            seen_end_reasoning = True
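    # Wait for generation to finish, then print the assembled transcript; warn if the reasoning
    # block never closed.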
    generation_thread.join()
    print("\n\n=== Completed Response ===")
    final_text = "".join(completion).strip()
    print(final_text or "<empty response>")
    if not seen_end_reasoning:
        print("\n[warning] No </think> token detected in streamed output.", flush=True)
if __name__ == "__main__":
    main()