Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| os.environ["UNSLOTH_DEVICE"] = "cuda" | |
| from unsloth import FastLanguageModel | |
| import torch | |
# Hugging Face access token, only needed when the model repo is private/gated.
# Using .get() means a missing secret yields None (anonymous download) instead
# of crashing the whole app at import time with KeyError.
HF_TOKEN = os.environ.get("HF_TOKEN")

# -------------------- Load Model --------------------
# Loads the fine-tuned model and its tokenizer once at startup; both are used
# as module-level globals by the generation code.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="ak0601/gpt-oss-20b-persona-chat",  # your trained model
    max_seq_length=1024,
    dtype=None,  # None lets Unsloth choose the dtype (auto-detect per its docs)
    load_in_4bit=True,  # quantized weights to reduce GPU memory use
    device_map="auto",
    token=HF_TOKEN,
)
| # -------------------- Conversation Formatter -------------------- | |
# Prompt prefix written for each supported conversation role.
_ROLE_TAGS = {
    "system": "[SYSTEM]",
    "user": "[USER]",
    "assistant": "[ASSISTANT]",
}


def format_conversation(conversation):
    """Render a chat transcript as the plain-text prompt fed to the model.

    Each turn becomes one line, ``"<TAG> <content>\\n"``, where the tag is
    ``[SYSTEM]``, ``[USER]`` or ``[ASSISTANT]``. Turns whose role is anything
    else are skipped. A trailing ``[ASSISTANT]`` tag (no newline) is appended
    so the model continues by writing the next assistant reply.

    Args:
        conversation: Iterable of dicts with ``"role"`` and ``"content"`` keys.

    Returns:
        The prompt string.
    """
    # Collect pieces and join once instead of repeatedly growing a str with +=.
    pieces = [
        f"{_ROLE_TAGS[turn['role']]} {turn['content']}\n"
        for turn in conversation
        if turn["role"] in _ROLE_TAGS
    ]
    pieces.append("[ASSISTANT]")
    return "".join(pieces)
# Role tags used in the prompt format. If the model emits one of these, it has
# started inventing the next turn, so the reply should end there.
_TURN_MARKERS = ("[SYSTEM]", "[USER]", "[ASSISTANT]")


def _truncate_at_next_turn(text):
    """Return *text* cut just before the first role tag it contains, stripped."""
    cut = len(text)
    for marker in _TURN_MARKERS:
        idx = text.find(marker)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut].strip()


def generate_reply(conversation):
    """Generate the next assistant reply for *conversation*.

    The conversation is rendered with ``format_conversation``, tokenized onto
    the model's device, and sampled with up to 256 new tokens.

    Args:
        conversation: List of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The reply text. Only newly generated tokens are decoded, and the text
        is truncated at the first role tag the model emits.
    """
    inputs = tokenizer(
        format_conversation(conversation),
        return_tensors="pt",
    ).to(model.device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=256,
        # Sampling must be enabled explicitly, otherwise generate() decodes
        # greedily and temperature/top_p have no effect (unless the model's
        # generation_config already turns sampling on).
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        eos_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the continuation so
    # prompt text can never leak into (or be mistaken for) the reply.
    prompt_len = inputs["input_ids"].shape[-1]
    new_tokens = output_ids[0][prompt_len:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return _truncate_at_next_turn(response)
| # -------------------- Gradio Functions -------------------- | |
def start_chat(persona):
    """Begin a fresh chat session for the given persona.

    Args:
        persona: Free-text persona description, embedded into the system prompt.

    Returns:
        ``(conversation, history)``. ``conversation`` holds only the system
        turn for the model. ``history`` is the chatbot display list, seeded
        with a greeting as a ``(None, text)`` pair. NOTE(review): the greeting
        is display-only and is not part of the model conversation.
    """
    system_prompt = (
        "You are a digital twin.\n"
        "ONLY respond based on persona and user input.\n"
        "\n"
        f"Persona: {persona}"
    )
    opening_message = (None, "How can I help you?")
    return [{"role": "system", "content": system_prompt}], [opening_message]
def chat(user_message, history, conversation):
    """Handle one submitted message: generate a reply and update both histories.

    Args:
        user_message: Text the user submitted.
        history: Chatbot display history of ``(user, assistant)`` pairs; may
            be ``None`` when the chatbot is empty.
        conversation: Role/content dicts fed to the model; may be empty or
            ``None`` if a chat was never started.

    Returns:
        ``(history, conversation)``, the updated display history and model
        conversation, in the order the UI's outputs expect. Blank messages
        return both (normalised to lists) unchanged.
    """
    history = list(history or [])
    conversation = list(conversation or [])

    # Don't spend a model generation on an empty / whitespace-only submission.
    if not user_message or not user_message.strip():
        return history, conversation

    user_turn = {"role": "user", "content": user_message}
    # Generate before committing anything, so a failed generation cannot leave
    # a dangling user turn in the conversation that corrupts later prompts.
    reply = generate_reply(conversation + [user_turn])

    conversation.extend([user_turn, {"role": "assistant", "content": reply}])
    history.append((user_message, reply))
    return history, conversation
| # -------------------- Gradio UI -------------------- | |
# -- UI layout and event wiring -----------------------------------------------
# Persona input + "Start Chat" resets the conversation; submitting a message
# runs one chat turn and then clears the input box.
with gr.Blocks() as demo:
    # Header emoji repaired: "π€" was the UTF-8 bytes of 🤖 mis-decoded.
    gr.Markdown("## 🤖 Digital Twin Chat")
    persona_box = gr.Textbox(
        label="Enter your persona",
        value="I am male. I am unsociable. I have a weakness for sweets. I am a jack of all, master of none.",
    )
    start_btn = gr.Button("Start Chat")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your message")
    state_conversation = gr.State([])
    # NOTE(review): not wired to any event here; kept so the name still exists.
    state_history = gr.State([])

    start_btn.click(start_chat, inputs=persona_box, outputs=[state_conversation, chatbot])
    msg.submit(
        chat,
        inputs=[msg, chatbot, state_conversation],
        outputs=[chatbot, state_conversation],
    ).then(lambda: "", inputs=None, outputs=msg)  # clear the box for the next message

demo.launch()