import os

import torch
import gradio as gr
from transformers import AutoModel, AutoTokenizer

# Model and revision are configurable via env vars; REV=None pulls the latest commit.
MODEL_ID = os.getenv("MODEL_ID", "Dream-org/Dream-v0-Instruct-7B")
REV = os.getenv("REV", None)
print(f"[INFO] Using MODEL_ID={MODEL_ID} REV={REV or '(latest)'}")

tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, revision=REV)

# bf16 on GPU, fp32 on CPU; Dream ships custom modeling code, hence trust_remote_code.
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained(
    MODEL_ID, trust_remote_code=True, torch_dtype=dtype, revision=REV
).to(device).eval()
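# Example launch pinning a specific snapshot (the SHA below is a placeholder,
# not a real commit):
#   MODEL_ID=Dream-org/Dream-v0-Instruct-7B REV=<commit-sha> python app.py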
def check_loss():
    """Probe whether the remote modeling code returns a loss when labels are passed."""
    msgs = [
        {"role": "system", "content": "Output only a single number."},
        {"role": "user", "content": "Compute: 1+1"},
    ]
    enc = tok.apply_chat_template(
        msgs, return_tensors="pt", return_dict=True, add_generation_prompt=False
    )
    # Make sure dtype / device are right; a bool attention_mask is widely compatible.
    input_ids = enc["input_ids"].to(device)
    attn = enc.get("attention_mask", None)
    if attn is not None:
        attn = attn.to(device).to(torch.bool)
    labels = input_ids.clone()
    try:
        out = model(input_ids=input_ids, attention_mask=attn, labels=labels)
        has_loss = getattr(out, "loss", None) is not None
        return f"[CHECK] supports labels->loss? {has_loss} | type={type(out)}"
    except Exception as e:
        return f"[CHECK] raised: {repr(e)}"
def quick_infer(q: str):
    """Greedy diffusion decoding for a single prompt; returns only the new text."""
    if not q.strip():
        return ""
    messages = [{"role": "user", "content": q}]
    inputs = tok.apply_chat_template(
        messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
    )
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device).to(torch.bool)
    with torch.no_grad():
        out = model.diffusion_generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=64,
            steps=64,            # number of denoising steps
            temperature=0.0,     # greedy decoding
            return_dict_in_generate=True,
        )
    # Decode only the tokens generated after the prompt.
    text = tok.decode(out.sequences[0][input_ids.shape[1]:], skip_special_tokens=True).strip()
    return text
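# A sampled (non-greedy) variant; `top_p` and `alg` follow the Dream model
# card's examples, but are assumptions if this revision's remote code differs:
#   out = model.diffusion_generate(
#       input_ids, attention_mask=attention_mask,
#       max_new_tokens=64, steps=64,
#       temperature=0.2, top_p=0.95, alg="entropy",
#       return_dict_in_generate=True,
#   )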
with gr.Blocks() as demo:
    gr.Markdown(
        "## Dream Loss Probe\n"
        "- Click **Run self-check** to see whether a `loss` is returned\n"
        "- Test-run `diffusion_generate` on the right"
    )
    with gr.Row():
        check_btn = gr.Button("Run self-check")
        check_out = gr.Textbox(label="Result")
    check_btn.click(fn=check_loss, inputs=None, outputs=check_out)
    with gr.Row():
        q = gr.Textbox(label="Quick inference prompt", value="Compute: 1+1")
        a = gr.Textbox(label="Model output")
    run = gr.Button("Generate (diffusion_generate)")
    run.click(fn=quick_infer, inputs=q, outputs=a)
if __name__ == "__main__":
    demo.launch()