# Hugging Face Spaces page — Space status: Runtime error
import random
from statistics import mean

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from transformers import BertModel, BertTokenizerFast
# Cut-off that an intent's mean similarity score must exceed before chat()
# answers with that intent's canned reply.
threshold = 0.4

# LaBSE tokenizer/encoder pair consumed by embed(); the model is put into
# evaluation mode (Module.eval() returns the module itself).
tokenizer = BertTokenizerFast.from_pretrained("setu4993/LaBSE")
model = BertModel.from_pretrained("setu4993/LaBSE").eval()
# Example utterances for the "order food" intent. Each phrase is embedded once
# at import time and compared against incoming messages in chat().
order_food_ex = [
    "food",
    "I am hungry, I want to order food",
    "How do I order food",
    "What are the food options",
    "I need dinner",
    "I want lunch",
    "What are the menu options",
    "I want a hamburger",
]
# Example utterances for the "talk to a human" intent. Each phrase is embedded
# once at import time and compared against incoming messages in chat().
talk_to_human_ex = [
    "I need to talk to someone",
    "Connect me with a human",
    "I need to speak with a person",
    "Put me on with a human",
    "Connect me with customer service",
    "human",
]
def embed(text, tokenizer, model):
    """Encode *text* with *model* and return its pooled output.

    Args:
        text: The string (or list of strings) to encode; it is passed
            straight to *tokenizer* with padding enabled.
        tokenizer: A Hugging Face-style tokenizer callable that accepts
            ``return_tensors="pt"``, ``padding`` and ``truncation``.
        model: Encoder called with the tokenizer's outputs as keyword
            arguments; its result must expose ``pooler_output``.

    Returns:
        ``model(**inputs).pooler_output``. For BERT-style models this is
        presumably a ``(batch, hidden)`` tensor — confirm for the model in use.
    """
    # truncation=True clips over-long user messages to the tokenizer's maximum
    # length instead of feeding the encoder more positions than it supports.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.pooler_output
def similarity(embeddings_1, embeddings_2):
    """Return the pairwise cosine-similarity matrix of two embedding batches.

    Each row of both inputs is L2-normalised along dim 1 (the
    ``F.normalize`` default), then every row of *embeddings_1* is
    dot-multiplied with every row of *embeddings_2*. The result has shape
    ``(len(embeddings_1), len(embeddings_2))``.
    """
    unit_rows_a = F.normalize(embeddings_1, p=2)
    unit_rows_b = F.normalize(embeddings_2, p=2)
    # Swap the first two axes so that each column of the right operand is a
    # row of embeddings_2.
    return unit_rows_a @ unit_rows_b.transpose(0, 1)
# Embed every example phrase once at import time, one embed() call per phrase,
# so chat() only has to embed the incoming message.
order_food_embed = [embed(phrase, tokenizer, model) for phrase in order_food_ex]
talk_to_human_embed = [
    embed(phrase, tokenizer, model) for phrase in talk_to_human_ex
]
def _mean_similarity(example_embeds, message_embed):
    """Return the mean cosine similarity between *message_embed* and each example."""
    return mean(float(similarity(em, message_embed)) for em in example_embeds)


def chat(message, history):
    """Route *message* to an intent and append a canned reply to *history*.

    The message is embedded with the module-level ``tokenizer``/``model`` and
    compared against the precomputed example embeddings of each intent. The
    intent with the highest mean similarity wins, provided that mean exceeds
    ``threshold``. Otherwise a fallback reply asking to rephrase is used.

    Args:
        message: The user's text from the Gradio text input.
        history: The Gradio ``state`` value: a list of ``(message, response)``
            tuples, or a falsy value (e.g. ``None``) on the first call.

    Returns:
        ``(history, history)``: the updated list, once for the chatbot output
        and once for the state output.
    """
    history = history or []
    message_embed = embed(message, tokenizer, model)

    order_score = _mean_similarity(order_food_embed, message_embed)
    human_score = _mean_similarity(talk_to_human_embed, message_embed)

    # Choose the intent the message actually resembles more, rather than
    # always checking "order food" first. Ties keep the food-first priority.
    if order_score > threshold and order_score >= human_score:
        response = random.choice([
            "We have hamburgers or pizza! Which one do you want?",
            "Do you want a hamburger or a pizza?"])
    elif human_score > threshold:
        response = random.choice([
            "Sure, a customer service agent will jump into this convo shortly!",
            "No problem. Let me forward on this conversation to a person that can respond."])
    else:
        response = "Sorry, I didn't catch that. Could you rephrase?"

    history.append((message, response))
    return history, history
# Gradio UI. Inputs are a text box and the session state; outputs are a chatbot
# transcript and the updated state. chat() returns (history, history) so that
# one value feeds each output.
iface = gr.Interface(
    chat,
    ["text", "state"],
    ["chatbot", "state"],
    # `allow_screenshot` is not a gr.Interface parameter in Gradio 3+ (the
    # screenshot feature was removed), so passing it breaks start-up on
    # current releases. It is intentionally omitted.
    allow_flagging="never",
)
iface.launch()