""" |
|
|
MiniMax-M2 Hugging Face checkpoint sanity check with streaming output. |
|
|
|
|
|
Usage: |
|
|
python test_minimax_m2_hf.py \ |
|
|
--model-path /monster/data/model/MiniMax-M2-bf16 \ |
|
|
--question "How many letter A are there in the word Alphabet? Reply with the number only." |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import argparse |
|
|
import threading |
|
|
from pathlib import Path |
|
|
|
|
|
import torch.nn as nn |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace: |
    parser = argparse.ArgumentParser(description="MiniMax-M2 HF checkpoint smoke test.")
    parser.add_argument(
        "--model-path",
        type=str,
        default="/monster/data/model/MiniMax-M2-bf16",
        help="Path to the MiniMax-M2 Hugging Face checkpoint directory.",
    )
    parser.add_argument(
        "--question",
        type=str,
        default="How many letter A are there in the word Alphabet? Reply with the number only.",
        help="User question to send through the chat template.",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=512,
        help="Maximum number of new tokens to sample from the model.",
    )
    return parser.parse_args()


def build_prompt(tokenizer: AutoTokenizer, question: str) -> str:
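    # Format the question with the model's chat template so the assistant turn
    # is open for completion.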
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": question},
    ]
    return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)


def main() -> None:
    args = parse_args()
    model_path = Path(args.model_path).expanduser().resolve()

    print(f"Loading tokenizer from {model_path}...")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
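    # Load the bf16 checkpoint; device_map="auto" lets Accelerate place the
    # shards across the visible GPUs (spilling to CPU if necessary).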
    print(f"Loading model from {model_path}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        dtype="bfloat16",
        device_map="auto",
        trust_remote_code=True,
    )
|
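    # Tokenize the chat-formatted prompt and move it onto the model's device.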
    prompt = build_prompt(tokenizer, args.question)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
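    # Stream decoded text as it is generated: skip the echoed prompt, but keep
    # special tokens so the </think> marker stays visible.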
    print("Running generation (streaming)...\n")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
|
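    # Gather the model's stop tokens and additionally stop on </think>, so the
    # smoke test ends once the reasoning block closes.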
    eos_ids = model.generation_config.eos_token_id
    if eos_ids is None:
        eos_ids = []
    elif isinstance(eos_ids, int):
        eos_ids = [eos_ids]
    think_end_id = tokenizer.convert_tokens_to_ids("</think>")
    if think_end_id is not None and think_end_id not in eos_ids:
        eos_ids = eos_ids + [think_end_id]
|
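    # eos_token_id=None lets generate() fall back to the model's own defaults.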
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=args.max_new_tokens,
        streamer=streamer,
        eos_token_id=eos_ids if eos_ids else None,
    )
|
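    # generate() blocks until finished, so run it on a worker thread while the
    # main thread drains the streamer.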
    generation_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    generation_thread.start()
|
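    # Echo chunks as they arrive. The chat template is assumed to end the
    # prompt with an opening <think> tag, so the tag is re-emitted here to
    # keep the printed transcript well formed.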
    completion = []
    first_chunk = True
    seen_end_reasoning = False
    for text in streamer:
        if first_chunk:
            print("<think>", end="", flush=True)
            completion.append("<think>")
            first_chunk = False
        print(text, end="", flush=True)
        completion.append(text)
        if "</think>" in text:
            seen_end_reasoning = True
|
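    # Wait for the worker to finish, then report the assembled completion.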
    generation_thread.join()
    print("\n\n=== Completed Response ===")
    final_text = "".join(completion).strip()
    print(final_text or "<empty response>")
    if not seen_end_reasoning:
        print("\n[warning] No </think> token detected in streamed output.", flush=True)


if __name__ == "__main__":
    main()