hsuwill000's picture
Update app.py
3f48b5b verified
raw
history blame
2.91 kB
import gradio as gr
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM
import warnings

# Silence the NumPy "__array__ copy keyword" DeprecationWarning emitted by
# the transformers/optimum tensor interop on newer NumPy versions.
warnings.filterwarnings("ignore", category=DeprecationWarning, message="__array__ implementation doesn't accept a copy keyword")

# Load the OpenVINO-exported model and its tokenizer.
model_id = "hsuwill000/DeepSeek-R1-Distill-Qwen-1.5B-openvino"
print("Loading model...")
model = OVModelForCausalLM.from_pretrained(model_id, device_map="auto")
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

# Conversation history: list of (user_text, assistant_text) tuples,
# mutated in place by respond() and reset by clear_history().
history = []
# Chat response function
def respond(prompt):
    """Generate a reply to *prompt*, conditioned on the stored chat history.

    Builds a chat-template message list (system prompt + past turns + the
    new user message), runs sampled generation, records the turn in the
    module-level ``history`` list, and returns the decoded reply.
    """
    global history  # the module-level turn list is mutated below

    # Assemble messages: fixed system instruction, then every past turn.
    messages = [{"role": "system", "content": "Answer the question in English only."}]
    for past_user, past_assistant in history:
        messages.extend(
            (
                {"role": "user", "content": past_user},
                {"role": "assistant", "content": past_assistant},
            )
        )
    messages.append({"role": "user", "content": prompt})

    # Render the message list into the model's expected prompt string.
    chat_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Tokenize and run sampled generation.
    encoded = tokenizer([chat_text], return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **encoded,
        max_new_tokens=4096,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    )

    # Drop the prompt tokens from each sequence, keeping only newly
    # generated tokens, then decode the first (and only) sequence.
    trimmed = [seq[len(src):] for src, seq in zip(encoded.input_ids, output_ids)]
    reply = tokenizer.batch_decode(trimmed, skip_special_tokens=True)[0]
    # Surface the reasoning delimiters as Markdown bold markers.
    reply = reply.replace("<think>", "**THINK**").replace("</think>", "**THINK**").strip()

    # Record this turn so later calls see it as context.
    history.append((prompt, reply))
    return reply
# Reset the stored conversation
def clear_history():
    """Forget every stored (user, assistant) turn and report success."""
    global history  # rebinding the module-level list needs the declaration
    history = []
    return "History cleared!"
# Gradio UI
# NOTE(review): original indentation was lost in extraction — the nesting of
# gr.Row relative to the TabItem is reconstructed and should be confirmed
# against the deployed Space.
with gr.Blocks() as demo:
    gr.Markdown("# DeepSeek-R1-Distill-Qwen-1.5B-openvino")
    with gr.Tabs():
        with gr.TabItem("聊天"):
            # Embed a simple request/response Interface as the chat tab.
            chat_if = gr.Interface(
                fn=respond,
                inputs=gr.Textbox(label="Prompt", placeholder="請輸入訊息..."),
                outputs=gr.Textbox(label="Response", interactive=False),
                api_name="hchat",
                title="DeepSeek-R1-Distill-Qwen-1.5B-openvino(with history)",
                description="回傳輸入內容的測試 API",
            )
    with gr.Row():
        clear_button = gr.Button("🧹 Clear History")
        # Clicking the button clears the conversation history.
        clear_button.click(fn=clear_history, inputs=[], outputs=[])

if __name__ == "__main__":
    print("Launching Gradio app...")
    # Bind to all interfaces on the conventional HF Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)