# dream-s1k-demo / loss_probe.py
# Author: 况兑
# fix: cast attention_mask to bool to satisfy Dream forward/generate expectations
# commit: a769d64
import os

import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer

# Model/revision come from the environment so the demo can be repointed
# (e.g. to a pinned commit) without editing code.
MODEL_ID = os.getenv("MODEL_ID", "Dream-org/Dream-v0-Instruct-7B")
REV = os.getenv("REV", None)
print(f"[INFO] Using MODEL_ID={MODEL_ID} REV={REV or '(latest)'}")

# trust_remote_code is required: Dream ships its model class as remote code.
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, revision=REV)

# bf16 on GPU to fit the 7B model; fp32 fallback on CPU.
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True, torch_dtype=dtype, revision=REV).to(device).eval()
def check_loss() -> str:
    """Probe whether the Dream model computes a loss when given ``labels``.

    Runs one tiny teacher-forced forward pass on a fixed two-message chat
    and reports, as a human-readable string (shown in the Gradio textbox),
    whether the output object carries a ``loss`` attribute. Any exception
    is caught and reported rather than raised, so the UI never crashes.

    Returns:
        A ``[CHECK] ...`` status string.
    """
    msgs = [
        {"role": "system", "content": "只输出一个数字"},
        {"role": "user", "content": "Compute: 1+1"},
    ]
    enc = tok.apply_chat_template(msgs, return_tensors="pt", return_dict=True, add_generation_prompt=False)
    # Ensure correct device; cast attention_mask to bool, which the Dream
    # remote-code forward expects (see commit message).
    input_ids = enc["input_ids"].to(device)
    attn = enc.get("attention_mask", None)
    if attn is not None:
        attn = attn.to(device).to(torch.bool)
    # Teacher forcing: label every position with the input itself.
    labels = input_ids.clone()
    try:
        # no_grad: this is a read-only probe, so skip autograd graph
        # construction (matches quick_infer, which already runs grad-free).
        with torch.no_grad():
            out = model(input_ids=input_ids, attention_mask=attn, labels=labels)
        has_loss = getattr(out, "loss", None) is not None
        return f"[CHECK] supports labels->loss? {has_loss} | type={type(out)}"
    except Exception as e:
        return f"[CHECK] raised: {repr(e)}"
def quick_infer(q: str):
    """Generate a short completion for one user prompt via ``diffusion_generate``.

    Blank/whitespace-only prompts short-circuit to an empty string. The
    completion is decoded from the tokens produced after the prompt, with
    special tokens dropped and surrounding whitespace trimmed.
    """
    if not q.strip():
        return ""

    chat = [{"role": "user", "content": q}]
    batch = tok.apply_chat_template(chat, return_tensors="pt", return_dict=True, add_generation_prompt=True)
    ids = batch.input_ids.to(device)
    # Bool attention mask — required by the Dream remote-code generate path.
    mask = batch.attention_mask.to(device).to(torch.bool)

    with torch.no_grad():
        result = model.diffusion_generate(
            ids,
            attention_mask=mask,
            max_new_tokens=64,
            steps=64,
            temperature=0.0,
            return_dict_in_generate=True,
        )

    # Decode only the newly generated suffix, not the echoed prompt.
    prompt_len = ids.shape[1]
    return tok.decode(result.sequences[0][prompt_len:], skip_special_tokens=True).strip()
# UI layout: component creation order inside each context defines the layout.
with gr.Blocks() as demo:
    # Markdown copy is user-facing (Chinese) runtime content — left verbatim.
    gr.Markdown("## Dream Loss Probe\n- 点击 **Run self-check** 看是否有 `loss`\n- 右侧可用 `diffusion_generate` 试跑")
    with gr.Row():
        check_btn = gr.Button("Run self-check")
        check_out = gr.Textbox(label="Result")
    # Self-check takes no inputs; its status string fills the result textbox.
    check_btn.click(fn=check_loss, inputs=None, outputs=check_out)
    with gr.Row():
        q = gr.Textbox(label="Quick inference prompt", value="Compute: 1+1")
        a = gr.Textbox(label="Model output")
        run = gr.Button("Generate (diffusion_generate)")
    # Generation wiring: prompt textbox -> quick_infer -> output textbox.
    run.click(fn=quick_infer, inputs=q, outputs=a)

if __name__ == "__main__":
    demo.launch()