ysn-rfd's picture
Update app.py
705a10e verified
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# مدل مورد نظر: می‌تونی اینو با هر مدل دیگه مثل HooshvareLab/gpt2-fa عوض کنی
model_name = "microsoft/DialoGPT-medium"
# بارگذاری مدل و توکنایزر
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# تاریخچه چت را نگه می‌داریم
chat_history_ids = None
def chat_with_bot(user_input, history=[]):
global chat_history_ids
# توکنایز ورودی
new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
# ترکیب با تاریخچه قبلی (context-aware)
bot_input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1) if chat_history_ids is not None else new_input_ids
# تولید پاسخ
chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
# گرفتن پاسخ آخر فقط
response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
# به تاریخچه اضافه کن
history.append((user_input, response))
return history, history
# رابط Gradio
with gr.Blocks() as demo:
gr.Markdown("## 🤖 Chat with DialoGPT")
chatbot = gr.Chatbot()
msg = gr.Textbox(label="Type your message")
clear = gr.Button("Clear")
state = gr.State([])
msg.submit(chat_with_bot, [msg, state], [chatbot, state])
clear.click(lambda: ([], []), None, [chatbot, state])
demo.launch()