Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import pandas as pd | |
| import pickle | |
| from sentence_transformers import SentenceTransformer, util | |
| import re | |
# Sentence-embedding model used to encode user queries at request time.
mdl_name = 'sentence-transformers/all-distilroberta-v1'
model = SentenceTransformer(mdl_name)

# Cache file holding an "embeddings" entry and a "data" entry.
# NOTE(review): pickle.load can execute arbitrary code from the file it reads;
# only ship a cache file that was produced by a trusted process.
embedding_cache_path = "scotch_embd_distilroberta.pkl"
with open(embedding_cache_path, "rb") as fIn:
    cache_data = pickle.load(fIn)
# presumably one embedding per row of ``reviews`` (the scores computed from it
# are later assigned as a column of that table) -- TODO confirm
embedding_table = cache_data["embeddings"]
reviews = cache_data["data"]

# Raw string: "\d" in a normal string literal is an invalid escape sequence.
_PRICE_DIGITS = re.compile(r"\d+")


def _parse_price(raw):
    """Return the first run of digits found in a raw price string.

    Commas, every ".00" and the "$" sign are removed before searching,
    e.g. "$1,500.00" -> "1500". Raises IndexError when no digit is present.
    """
    cleaned = raw.replace(",", "").replace(".00", "").replace("$", "")
    return _PRICE_DIGITS.findall(cleaned)[0]


# Normalise the textual price column to integers so it can be range-filtered.
reviews['price'] = reviews.price.apply(_parse_price).astype('int')
# Inclusive (min, max) price bounds for each option offered in the UI.
# "$150+" is open-ended, so its upper bound is infinity rather than a fixed cap.
_PRICE_RANGES = {
    "$0-$70": (0, 70),
    "$70-$150": (70, 150),
    "$150+": (150, float("inf")),
}


def user_query_recommend(query, price_rng):
    """Recommend up to six scotches matching a flavour query and price range.

    The query is encoded with the module-level ``model`` and compared by
    cosine similarity against ``embedding_table``; the scores are attached to
    a copy of ``reviews``. For every scotch name the two highest-scoring rows
    are kept, their ``description_sent`` texts are joined with a space, and
    the scotches are ranked by their best score.

    Args:
        query: Free-text flavour description.
        price_rng: One of "$0-$70", "$70-$150" or "$150+". Bounds are
            inclusive on both ends, so a price of exactly 70 or 150 matches
            two adjacent ranges.

    Returns:
        A DataFrame with columns ``name``, ``price`` and ``description_sent``,
        holding at most six rows ordered from most to least similar.

    Raises:
        ValueError: If ``price_rng`` is not one of the known options.
    """
    # Validate first: previously an unknown option left min_p/max_p unbound
    # and surfaced as an UnboundLocalError further down.
    try:
        min_p, max_p = _PRICE_RANGES[price_rng]
    except KeyError:
        raise ValueError(
            f"Unknown price range {price_rng!r}; "
            f"expected one of {list(_PRICE_RANGES)}"
        ) from None

    # Embed user query
    embedding = model.encode(query)

    # Calculate similarity with all reviews
    sim_scores = util.cos_sim(embedding, embedding_table)

    recommendations = reviews.copy()
    # cos_sim yields a 2-D score matrix (one row for a single query); flatten
    # it so every review row receives a plain scalar score.
    recommendations['sim'] = sim_scores.reshape(-1).tolist()

    # Two best-scoring rows per scotch name.
    top_two = (
        recommendations
        .groupby("name")
        .sim.nlargest(2)
        .reset_index()[["name", "sim"]]
    )

    # Re-attach the descriptive columns to those (name, sim) pairs.
    detail_cols = ['name', 'category', 'price', 'description',
                   'description_sent', 'sim']
    op = pd.merge(top_two, recommendations[detail_cols],
                  how="left", on=["name", "sim"])

    in_range = (op.price >= min_p) & (op.price <= max_p)
    op = (
        op.loc[in_range, detail_cols]
        # Sorting here fixes the order in which sentences are joined per scotch.
        .sort_values('sim', ascending=False)
        .groupby(['name', 'category', 'price', 'description'])
        .agg(description_sent=("description_sent", " ".join),
             sim=("sim", "max"))
        .reset_index()
        # groupby orders its output by the group keys, so rank by similarity
        # again; otherwise head(6) returns the alphabetically first names.
        .sort_values('sim', ascending=False)
    )

    return op[['name', 'price', 'description_sent']].reset_index(drop=True).head(6)
# Build the web UI: a free-text flavour box plus a price-range selector feed
# user_query_recommend, whose DataFrame result is shown as a paginated table.
# NOTE(review): gr.inputs / gr.outputs, theme="grass" and launch(enable_queue=...)
# look like the legacy Gradio API; newer releases may reject them -- confirm
# against the gradio version this app is pinned to.
_flavour_box = gr.inputs.Textbox(lines=5, label="enter flavour profile")
_price_selector = gr.inputs.Radio(
    choices=["$0-$70", "$70-$150", "$150+"],
    default="$0-$70",
    type="value",
    label='Price range',
)
_results_table = gr.outputs.Dataframe(
    max_rows=3,
    overflow_row_behaviour="paginate",
    type="pandas",
    label="Scotch recommendations",
)

# Example (query, price range) pairs offered to the user.
_example_queries = [
    ["very sweet with lemons and oranges and marmalades", "$0-$70"],
    ["smoky peaty and wood fire", "$70-$150"],
    ["salty and spicy with exotic fruits", "$150+"],
    ["fragrant nose with chocolate, custard, toffee, pudding and caramel", "$70-$150"],
]

interface = gr.Interface(
    user_query_recommend,
    inputs=[_flavour_box, _price_selector],
    outputs=_results_table,
    title="Scotch Recommendation",
    description="Looking for scotch recommendations and have some flavours in mind? \nGet recommendations at a preferred price range using semantic search :) ",
    examples=_example_queries,
    theme="grass",
)

# Start serving the app with request queueing switched on.
interface.launch(
    enable_queue=True,
    # cache_examples=True,
)