Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pickle | |
| import faiss | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoTokenizer, AutoModel | |
| import torch | |
| import pandas as pd | |
# --- Data, model and index loading -------------------------------------------
# Streamlit re-executes this script from top to bottom on every widget
# interaction.  Each expensive load is therefore wrapped in a cached loader so
# it runs once per server process instead of once per rerun.

# Device used for the Hugging Face encoders and for query tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_data
def _load_dataframe(csv_path):
    """Read the CSV at *csv_path* into a DataFrame (cached across reruns)."""
    return pd.read_csv(csv_path)


@st.cache_resource
def _load_sentence_transformer(model_name):
    """Load a SentenceTransformer model by name (cached across reruns)."""
    return SentenceTransformer(model_name)


@st.cache_resource
def _load_hf_encoder(model_name):
    """Load a tokenizer/model pair by name and move the model to ``device``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.to(device)
    return tokenizer, model


@st.cache_resource
def _load_pickled_index(index_path):
    """Unpickle the search index stored at *index_path* (cached across reruns)."""
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file.  Only load index files that were produced by this project.
    with open(index_path, 'rb') as f:
        return pickle.load(f)


df = _load_dataframe('data/cleaned_df.csv')

labse_model = _load_sentence_transformer('sentence-transformers/LaBSE')
distilbert_tokenizer, distilbert_model = _load_hf_encoder('distilbert-base-multilingual-cased')
tiny2_tokenizer, tiny2_model = _load_hf_encoder('cointegrated/rubert-tiny2')

labse_index = _load_pickled_index('models/dashas/labse_index.pkl')
distilbert_index = _load_pickled_index('models/dashas/distilbert_index.pkl')
tiny2_index = _load_pickled_index('models/dashas/tiny2_index.pkl')
def search_series(query, model, tokenizer=None, index=None, top_k=5):
    """Return the rows of ``df`` closest to *query* according to *index*.

    Args:
        query: Free-text search string.
        model: Encoder for the query.  When *tokenizer* is given it is called
            as ``model(**inputs)`` and its ``last_hidden_state`` is mean-pooled
            over the token axis; otherwise ``model.encode([query])`` is used.
        tokenizer: Optional tokenizer paired with *model*.  ``None`` selects
            the ``model.encode`` path.
        index: Object exposing ``search(embeddings, k)`` that returns
            ``(distances, indices)`` (presumably a FAISS index; the indexes
            are unpickled elsewhere in this file).
        top_k: Number of neighbours to request from the index.

    Returns:
        The rows of the module-level ``df`` selected by position, in the order
        the index returned them.  May contain fewer than *top_k* rows when the
        index reports fewer neighbours.

    Raises:
        ValueError: If *index* is ``None``.
    """
    if index is None:
        raise ValueError("search_series() requires an index to search in")

    if tokenizer is not None:
        inputs = tokenizer(
            [query],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128,
        )
        # Tensors must live on the same device as the model.
        inputs = {key: val.to(device) for key, val in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        # Mean-pool the token states into a single vector for the query.
        query_embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    else:
        query_embedding = model.encode([query])

    # FAISS requires a C-contiguous float32 matrix; coerce defensively so an
    # encoder returning another dtype/layout does not fail inside the search.
    query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32)

    _, indices = index.search(query_embedding, top_k)

    # FAISS pads missing neighbours with -1.  Passed to ``iloc`` unchanged,
    # -1 would silently select the LAST row of ``df``, so drop those slots.
    hit_positions = [int(pos) for pos in indices[0] if pos >= 0]
    return df.iloc[hit_positions]
# --- Streamlit UI ------------------------------------------------------------
st.title("Умный поиск сериалов")
st.image("images/logo_1.jpeg", width=800)  # logo shown under the title

query = st.text_input("Введите запрос:")
model_choice = st.selectbox("Выберите модель:", ["LaBSE", "DistilBERT", "tiny2"])
top_k = st.slider("Количество результатов:", min_value=1, max_value=20, value=5)

# Maps each selectbox label to the (model, tokenizer, index) triple passed to
# search_series.  LaBSE has no tokenizer: it goes through ``model.encode``.
_SEARCH_BACKENDS = {
    "LaBSE": (labse_model, None, labse_index),
    "DistilBERT": (distilbert_model, distilbert_tokenizer, distilbert_index),
    "tiny2": (tiny2_model, tiny2_tokenizer, tiny2_index),
}

if st.button("Найти"):
    # Treat a whitespace-only query as empty instead of searching for blanks.
    cleaned_query = query.strip()
    if cleaned_query:
        model, tokenizer, index = _SEARCH_BACKENDS[model_choice]
        results = search_series(cleaned_query, model, tokenizer, index, top_k=top_k)

        st.write("Результаты поиска:")
        for _, row in results.iterrows():
            st.write(f"**{row['title']}**")
            st.write(row['description'])
            # Skip the poster when the URL is missing (NaN) or blank; st.image
            # would otherwise raise and abort rendering of the remaining rows.
            image_url = row['image_url']
            if isinstance(image_url, str) and image_url.strip():
                st.image(image_url, width=600)
    else:
        st.write("Пожалуйста, введите запрос.")