# app.py
import gradio as gr
import spaces
from threading import Thread
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)

# ------------------------------
# 1. Load the model and tokenizer
# ------------------------------
model_name = "agentica-org/DeepScaleR-1.5B-Preview"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# If the tokenizer has no pad_token_id, fall back to eos_token_id explicitly.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
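# A minimal alternative (an assumption, not part of the original app): loading
# in half precision roughly halves GPU memory for this 1.5B model, on hardware
# that supports fp16:
#
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name, torch_dtype=torch.float16, device_map="auto"
#   )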
# ------------------------------
# 2. Chat history -> prompt string
# ------------------------------
def preprocess_messages(history):
    """
    Concatenate the chat history into a minimal prompt.
    You can substitute a prompt format or special tokens better suited
    to this model; see the sketch after this function.
    """
    prompt = ""
    for user_msg, assistant_msg in history:
        if user_msg:
            prompt += f"User: {user_msg}\n"
        if assistant_msg:
            prompt += f"Assistant: {assistant_msg}\n"
    # End with "Assistant: " so the model continues from the assistant turn.
    prompt += "Assistant: "
    return prompt
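# Alternative sketch (assumes the tokenizer ships a chat template, as the
# DeepSeek-R1-distilled base models do): build the prompt with
# tokenizer.apply_chat_template instead of manual string concatenation.
def preprocess_messages_with_template(history):
    messages = []
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    # add_generation_prompt=True appends the template's assistant prefix,
    # mirroring the trailing "Assistant: " above.
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )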
# ------------------------------
# 3. Prediction / inference
# ------------------------------
@spaces.GPU  # let Hugging Face Spaces allocate a GPU for this call
def predict(history, max_length, top_p, temperature):
    """
    Generate text from the current history, streaming tokens back to the UI
    via Hugging Face's TextIteratorStreamer.
    """
    prompt = preprocess_messages(history)
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,     # pad automatically
        truncation=True,  # truncate overlong prompts
        max_length=2048,  # adjust to your GPU memory / model context limit
    )
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)

    # Streamer that yields decoded text chunks as they are generated.
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        timeout=60,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generate_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": max_length,  # number of newly generated tokens
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": 1.2,
        "streamer": streamer,
    }

    # Run generate() in a background thread; the main thread consumes
    # new tokens from the streamer as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Append each new chunk to history[-1][1] and yield the updated history.
    partial_output = ""
    for new_token in streamer:
        partial_output += new_token
        history[-1][1] = partial_output
        yield history
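# Debugging sketch (a hypothetical helper, not wired into the UI): the same
# generation without streaming, returning the full completion at once.
def predict_once(history, max_length, top_p, temperature):
    prompt = preprocess_messages(history)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_length,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=1.2,
    )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)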
# ------------------------------
# 4. Gradio UI
# ------------------------------
def main():
    with gr.Blocks() as demo:
        gr.HTML("<h1 align='center'>DeepScaleR-1.5B Chat Demo</h1>")
        chatbot = gr.Chatbot()
        with gr.Row():
            with gr.Column(scale=2):
                user_input = gr.Textbox(
                    show_label=True,
                    placeholder="Type your question...",
                    label="User Input",
                )
                submitBtn = gr.Button("Submit")
                clearBtn = gr.Button("Clear History")
            with gr.Column(scale=1):
                max_length = gr.Slider(
                    minimum=1,
                    maximum=1024,  # raise or lower as needed
                    value=512,
                    step=1,
                    label="Max New Tokens",
                    interactive=True,
                )
                top_p = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.8,
                    step=0.01,
                    label="Top P",
                    interactive=True,
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.01,
                    label="Temperature",
                    interactive=True,
                )

        # On Submit: first append the user input to the history, then call
        # predict to stream the assistant reply.
        def user(query, history):
            return "", history + [[query, ""]]

        submitBtn.click(
            fn=user,
            inputs=[user_input, chatbot],
            outputs=[user_input, chatbot],
            queue=False,  # run immediately, without queueing
        ).then(
            fn=predict,
            inputs=[chatbot, max_length, top_p, temperature],
            outputs=chatbot,
        )

        # Clear the chat history and the input box (a Textbox expects a
        # string, so return "" rather than an empty list).
        def clear_history():
            return [], ""

        clearBtn.click(fn=clear_history, inputs=[], outputs=[chatbot, user_input], queue=False)

    # Optional: enable the queue to serialize concurrent requests.
    # (`concurrency_count=1` is the Gradio 3.x spelling; Gradio 4.x uses
    # `default_concurrency_limit` instead.)
    demo.queue(default_concurrency_limit=1)
    demo.launch()
# ------------------------------
# Entry point
# ------------------------------
if __name__ == "__main__":
    main()
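# ------------------------------
# Deployment note (assumption)
# ------------------------------
# device_map="auto" needs the `accelerate` package at import time, and a
# missing dependency is a common cause of a Space failing with "Runtime
# error". A minimal requirements.txt for this app might look like:
#
#   gradio
#   transformers
#   torch
#   accelerate
#   spaces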