File size: 2,725 Bytes
48dfdf0
118bf1e
 
48dfdf0
 
 
 
 
 
 
 
 
118bf1e
48dfdf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
import os
os.environ["UNSLOTH_DEVICE"] = "cuda"
from unsloth import FastLanguageModel
import torch
# Hugging Face access token. Indexing os.environ directly means start-up
# aborts with a KeyError when the variable is not set.
HF_TOKEN = os.environ["HF_TOKEN"]
# -------------------- Load Model --------------------
# Load the fine-tuned persona checkpoint together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    token=HF_TOKEN,
    model_name="ak0601/gpt-oss-20b-persona-chat",  # the trained persona model
    load_in_4bit=True,  # quantize weights to 4-bit on load
    device_map="auto",
    max_seq_length=1024,  # maximum sequence length handed to the loader
    dtype=None,  # NOTE(review): presumably lets Unsloth pick the dtype — confirm
)

# -------------------- Conversation Formatter --------------------
# Plain-text tag that introduces each supported role in the prompt.
_ROLE_TAGS = {
    "system": "[SYSTEM]",
    "user": "[USER]",
    "assistant": "[ASSISTANT]",
}


def format_conversation(conversation):
    """Render a chat transcript as the plain-text prompt fed to the model.

    Each turn becomes one line: its role tag (``[SYSTEM]``, ``[USER]`` or
    ``[ASSISTANT]``), a space, the turn's content, then a newline. Turns whose
    role is not one of those three are skipped. A trailing ``[ASSISTANT]`` tag
    (with no newline) cues the model to write the next assistant reply.

    Args:
        conversation: Iterable of dicts with ``"role"`` and ``"content"`` keys.

    Returns:
        The prompt string.
    """
    # Collect the pieces and join once instead of growing a string with += in
    # a loop, which is quadratic in the transcript length.
    lines = [
        f"{_ROLE_TAGS[turn['role']]} {turn['content']}\n"
        for turn in conversation
        if turn["role"] in _ROLE_TAGS
    ]
    return "".join(lines) + "[ASSISTANT]"

def _extract_reply(generated):
    """Return the first assistant turn contained in freshly generated text.

    The prompt format has no dedicated stop token between turns, so the model
    may keep writing further ``[SYSTEM]``/``[USER]``/``[ASSISTANT]`` turns on
    its own. Everything from the first such tag onward is dropped, and the
    remainder is stripped of surrounding whitespace.
    """
    end = len(generated)
    for tag in ("[SYSTEM]", "[USER]", "[ASSISTANT]"):
        pos = generated.find(tag)
        if pos != -1:
            end = min(end, pos)
    return generated[:end].strip()


def generate_reply(
    conversation,
    *,
    max_new_tokens=256,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
):
    """Generate the next assistant message for ``conversation``.

    Args:
        conversation: List of ``{"role": ..., "content": ...}`` dicts, rendered
            into a prompt by ``format_conversation``.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to already-generated tokens.

    Returns:
        The reply text only: no prompt text and no trailing invented turns.
    """
    inputs = tokenizer(
        format_conversation(conversation),
        return_tensors="pt",
    ).to(model.device)
    # NOTE(review): the prompt grows with every turn and is never truncated,
    # while the model is loaded with max_seq_length=1024. Long chats may
    # exceed that; consider trimming old turns.

    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        # Make sampling explicit: with do_sample=False (the transformers
        # default) temperature/top_p have no effect.
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        eos_token_id=tokenizer.eos_token_id,
    )

    # generate() returns prompt + continuation. Decode only the new tokens so
    # prompt text (including tags echoed inside user messages) cannot leak
    # into the reply.
    prompt_len = inputs["input_ids"].shape[-1]
    generated = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
    return _extract_reply(generated)

# -------------------- Gradio Functions --------------------
def start_chat(persona):
    """Begin a new chat session for the given persona.

    Args:
        persona: Free-text persona description, interpolated into the
            system prompt.

    Returns:
        A ``(conversation, history)`` pair. ``conversation`` is a fresh list
        holding only the system turn. ``history`` is the initial chatbot
        transcript: a single ``(None, greeting)`` pair.
    """
    system_prompt = (
        "You are a digital twin.\n"
        "ONLY respond based on persona and user input.\n"
        f"\nPersona: {persona}"
    )
    greeting = [(None, "How can I help you?")]
    return [{"role": "system", "content": system_prompt}], greeting

def chat(user_message, history, conversation):
    """Handle one user message: query the model and extend both transcripts.

    Args:
        user_message: Text typed by the user. Empty or whitespace-only input
            is ignored and the model is not called.
        history: Chatbot transcript as a list of ``(user, bot)`` pairs; ``None``
            is treated as empty.
        conversation: Model-side transcript of ``{"role", "content"}`` dicts;
            ``None`` is treated as empty.

    Returns:
        ``(history, conversation)`` as new lists. The inputs are never mutated,
        so an exception inside ``generate_reply`` cannot leave a dangling user
        turn in the caller's state.
    """
    history = list(history or [])
    conversation = list(conversation or [])

    # Skip blank submissions rather than running a full generation on an
    # empty user turn.
    if not user_message or not user_message.strip():
        return history, conversation

    conversation.append({"role": "user", "content": user_message})
    reply = generate_reply(conversation)
    conversation.append({"role": "assistant", "content": reply})
    history.append((user_message, reply))
    return history, conversation

# -------------------- Gradio UI --------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Digital Twin Chat")

    # Persona text used to seed the system prompt when "Start Chat" is pressed.
    persona_box = gr.Textbox(
        label="Enter your persona",
        value="I am male. I am unsociable. I have a weakness for sweets. I am a jack of all, master of none.",
    )
    start_btn = gr.Button("Start Chat")

    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your message")

    # Model-side transcript (list of role/content dicts). The visible
    # transcript is read from and written back to `chatbot` directly, so no
    # separate history state is needed.
    state_conversation = gr.State([])

    start_btn.click(start_chat, inputs=persona_box, outputs=[state_conversation, chatbot])
    # Once the reply is in, clear the input box so pressing Enter again does
    # not resend the previous message.
    msg.submit(
        chat,
        inputs=[msg, chatbot, state_conversation],
        outputs=[chatbot, state_conversation],
    ).then(lambda: "", inputs=None, outputs=msg)

demo.launch()