Spaces:
Runtime error
Runtime error
| import logging | |
| from typing import Any, List, Optional | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| from bokeh.models import ColumnDataSource, HoverTool | |
| from bokeh.palettes import Cividis256 as Pallete | |
| from bokeh.plotting import figure | |
| from bokeh.transform import factor_cmap | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.manifold import TSNE | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| SEED = 0 | |
| def load_model(): | |
| embedder = "distiluse-base-multilingual-cased-v1" | |
| return SentenceTransformer(embedder) | |
| def embed_text(text: List[str]) -> np.ndarray: | |
| embedder_model = load_model() | |
| return embedder_model.encode(text) | |
| def encode_labels(labels: pd.Series) -> pd.Series: | |
| if pd.api.types.is_numeric_dtype(labels): | |
| return labels | |
| return labels.astype("category").cat.codes | |
| def get_tsne_embeddings( | |
| embeddings: np.ndarray, perplexity: int = 30, n_components: int = 2, init: str = "pca", n_iter: int = 5000, random_state: int = SEED | |
| ) -> np.ndarray: | |
| tsne = TSNE(perplexity=perplexity, n_components=n_components, init=init, n_iter=n_iter, random_state=random_state) | |
| return tsne.fit_transform(embeddings) | |
| def draw_interactive_scatter_plot( | |
| texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str | |
| ) -> Any: | |
| # Normalize values to range between 0-255, to assign a color for each value | |
| max_value = values.max() | |
| min_value = values.min() | |
| values_color = ((values - min_value) / (max_value - min_value) * 255).round().astype(int).astype(str) | |
| values_color_set = sorted(values_color) | |
| values_list = values.astype(str).tolist() | |
| values_set = sorted(values_list) | |
| labels_list = labels.astype(str).tolist() | |
| source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels_list)) | |
| hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")]) | |
| p = figure(plot_width=800, plot_height=800, tools=[hover], title="Embedding Lenses") | |
| p.circle("x", "y", size=10, source=source, fill_color=factor_cmap("label", palette=[Pallete[int(id_)] for id_ in values_color_set], factors=values_set)) | |
| return p | |
| def generate_plot(tsv: st.uploaded_file_manager.UploadedFile, text_column: str, label_column: str, sample: Optional[int]): | |
| logger.info("Loading dataset in memory") | |
| df = pd.read_csv(tsv, sep="\t") | |
| if label_column not in df.columns: | |
| df[label_column] = 0 | |
| df = df.dropna(subset=[text_column, label_column]) | |
| if sample: | |
| df = df.sample(min(sample, df.shape[0]), random_state=SEED) | |
| logger.info("Embedding sentences") | |
| embeddings = embed_text(df[text_column].values.tolist()) | |
| logger.info("Encoding labels") | |
| encoded_labels = encode_labels(df[label_column]) | |
| logger.info("Running t-SNE") | |
| tsne_embeddings = get_tsne_embeddings(embeddings) | |
| logger.info("Generating figure") | |
| plot = draw_interactive_scatter_plot( | |
| df[text_column].values, tsne_embeddings[:, 0], tsne_embeddings[:, 1], encoded_labels.values, df[label_column].values, text_column, label_column | |
| ) | |
| return plot | |
| st.title("Embedding Lenses") | |
| uploaded_file = st.file_uploader("Choose an csv/tsv file...", type=["csv", "tsv"]) | |
| text_column = st.text_input("Text column name", "text") | |
| label_column = st.text_input("Numerical/categorical column name (ignore if not applicable)", "label") | |
| sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000) | |
| if uploaded_file: | |
| plot = generate_plot(uploaded_file, text_column, label_column, sample) | |
| logger.info("Displaying plot") | |
| st.bokeh_chart(plot) | |
| logger.info("Done") | |