File size: 2,911 Bytes
4d871c7 76cb536 0801ebc 182208d 3f48b5b 4d871c7 3f48b5b 6453441 7160766 8b4afb4 7160766 155b74f 4d871c7 3f48b5b 8b4afb4 3f48b5b 7160766 8b4afb4 a7464e5 155b74f 7160766 8b4afb4 7160766 246dff9 3f48b5b 76da388 7160766 3f48b5b 155b74f 0801ebc 3f48b5b 0801ebc 7fae2e6 0801ebc ec76eef 7fae2e6 7913ae5 3f48b5b 7fae2e6 0801ebc 75d3a03 3f48b5b df8b6cb 3f48b5b 3486524 4d871c7 e5d3a7a 3f48b5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import gradio as gr
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM
import warnings
# Silence a known numpy/transformers DeprecationWarning about __array__'s
# copy keyword so it doesn't spam the console on every inference call.
warnings.filterwarnings("ignore", category=DeprecationWarning, message="__array__ implementation doesn't accept a copy keyword")
# Model and tokenizer loading (runs at import time; downloads weights on first use).
model_id = "hsuwill000/DeepSeek-R1-Distill-Qwen-1.5B-openvino"
print("Loading model...")
# NOTE(review): device_map="auto" is forwarded to optimum-intel; presumably it
# selects the available OpenVINO device — confirm against the optimum docs.
model = OVModelForCausalLM.from_pretrained(model_id, device_map="auto")
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
# Conversation history: list of (user_text, assistant_text) tuples, shared
# globally across all requests (single-session by design).
history = []
# Response function
def respond(prompt: str, max_new_tokens: int = 4096,
            temperature: float = 0.7, top_p: float = 0.9) -> str:
    """Generate a chat reply for *prompt*, replaying prior turns for context.

    The sampling knobs were previously hard-coded; they are now keyword
    parameters whose defaults match the original values, so existing callers
    (including the Gradio wiring) are unaffected.

    Args:
        prompt: The user's current message.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling cutoff passed to ``model.generate``.

    Returns:
        The decoded assistant reply, with ``<think>``/``</think>`` tags
        rewritten to ``**THINK**`` markers.

    Side effects:
        Appends ``(prompt, response)`` to the module-level ``history`` list,
        so each call sees the full conversation so far.
    """
    global history  # conversation state shared across calls

    # Rebuild the message list: system prompt first, then every stored turn.
    messages = [{"role": "system", "content": "Answer the question in English only."}]
    for user_text, assistant_text in history:
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": assistant_text})
    # Append the current user input.
    messages.append({"role": "user", "content": prompt})

    # Render the chat template to a single prompt string for the tokenizer.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Run model inference.
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    # Drop the prompt tokens so only the newly generated tokens are decoded.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # Surface the model's reasoning delimiters as visible markers instead of
    # raw tags (both open and close tags intentionally map to **THINK**).
    response = response.replace("<think>", "**THINK**").replace("</think>", "**THINK**").strip()

    # Record this turn for future context.
    history.append((prompt, response))
    return response
# Reset the stored conversation
def clear_history():
    """Drop all stored (user, assistant) turns and report success.

    Rebinds the module-level ``history`` to a fresh empty list so the next
    call to ``respond`` starts from a blank context.
    """
    global history
    history = list()
    return "History cleared!"
# Gradio interface: a single chat tab plus a button that clears the history.
with gr.Blocks() as demo:
    gr.Markdown("# DeepSeek-R1-Distill-Qwen-1.5B-openvino")
    with gr.Tabs():
        with gr.TabItem("聊天"):
            chat_if = gr.Interface(
                fn=respond,
                inputs=gr.Textbox(label="Prompt", placeholder="請輸入訊息..."),
                outputs=gr.Textbox(label="Response", interactive=False),
                api_name="hchat",
                title="DeepSeek-R1-Distill-Qwen-1.5B-openvino(with history)",
                description="回傳輸入內容的測試 API",
            )
            # NOTE(review): the original indentation was lost in extraction;
            # the Row is assumed to live inside the chat tab — confirm layout.
            with gr.Row():
                clear_button = gr.Button("🧹 Clear History")
                # Button click resets the stored conversation turns. No output
                # component is wired, so clear_history's return string is not
                # displayed — presumably intentional; verify desired UX.
                clear_button.click(fn=clear_history, inputs=[], outputs=[])
if __name__ == "__main__":
    # Entry point: bind on all interfaces so the app is reachable from
    # outside the container/host; share=True also requests a public tunnel.
    print("Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)