import gradio as gr
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM
import warnings

warnings.filterwarnings("ignore", category=DeprecationWarning, message="__array__ implementation doesn't accept a copy keyword")

# Load the model and tokenizer
model_id = "hsuwill000/DeepSeek-R1-Distill-Qwen-1.5B-openvino"
print("Loading model...")
# OVModelForCausalLM selects an OpenVINO device via `device` (e.g. "CPU");
# transformers-style `device_map` does not apply to OpenVINO models.
model = OVModelForCausalLM.from_pretrained(model_id, device="CPU")
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

# Conversation history, stored as (user_text, assistant_text) pairs
history = []

# Response function: runs one chat turn against the model
def respond(prompt):
    global history  # the conversation history lives in a module-level list

    # Rebuild the conversation in chat-message format, starting with the system prompt
    messages = [{"role": "system", "content": "Answer the question in English only."}]

    # Append the previous turns
    for user_text, assistant_text in history:
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": assistant_text})

    # Append the current user input
    messages.append({"role": "user", "content": prompt})

    # Render the messages with the model's chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Run model inference
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
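    # Sample a response (temperature + nucleus sampling) rather than decoding greedily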
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=4096,
        temperature=0.7,
        top_p=0.9,
        do_sample=True
    )
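    # Keep only the newly generated tokens by stripping each prompt's input_ids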
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
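    # Surface the model's <think>...</think> reasoning markers as bold "**THINK**" text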
    response = response.replace("<think>", "**THINK**").replace("</think>", "**THINK**").strip()

    # Record this turn in the history
    history.append((prompt, response))

    return response

# Clear the conversation history
def clear_history():
    global history
    history = []
    return "History cleared!"

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# DeepSeek-R1-Distill-Qwen-1.5B-openvino")

    with gr.Tabs():
        with gr.TabItem("聊天"):
            chat_if = gr.Interface(
                fn=respond,
                inputs=gr.Textbox(label="Prompt", placeholder="請輸入訊息..."),
                outputs=gr.Textbox(label="Response", interactive=False),
                api_name="hchat",
                title="DeepSeek-R1-Distill-Qwen-1.5B-openvino(with history)",
                description="回傳輸入內容的測試 API",
            )

    with gr.Row():
        clear_button = gr.Button("🧹 Clear History")
        clear_status = gr.Textbox(label="Status", interactive=False)

    # Clicking the button resets the history; the returned status message
    # ("History cleared!") is shown in the status textbox
    clear_button.click(fn=clear_history, inputs=[], outputs=[clear_status])

if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
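
# Example client call (a minimal sketch, assuming the app is reachable at
# http://localhost:7860 and that the `gradio_client` package is installed;
# the endpoint name follows api_name="hchat" above):
#
#   from gradio_client import Client
#
#   client = Client("http://localhost:7860")
#   reply = client.predict("Hello, who are you?", api_name="/hchat")
#   print(reply)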