import os

import gradio as gr
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import numpy as np
import plotly.express as px
from sklearn.metrics.pairwise import cosine_similarity
import umap
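
# Third-party requirements: gradio, torch, transformers, scikit-learn,
# umap-learn (which provides the `umap` module), plotly, and matplotlib.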


class EmbeddingVisualizer:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self, model_name):
        if self.model is not None:
            # Release the previous model so its memory can actually be reclaimed,
            # then clear the CUDA cache if using a GPU
            self.model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=os.environ.get("HF_TOKEN"))
        if "gemma" in model_name:
            # Gemma is a gated causal LM; float16 halves its memory footprint
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                token=os.environ.get("HF_TOKEN"),
                torch_dtype=torch.float16,
            )
        else:
            self.model = AutoModel.from_pretrained(model_name)
        self.model = self.model.to(self.device)
        return f"Loaded model: {model_name}"

    def get_embedding(self, text):
        if not text.strip():
            return None
        inputs = self.tokenizer(text, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]
        # Mean-pool the last hidden state over non-padding tokens
        mask = inputs["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float()
        masked_embeddings = hidden_states * mask
        sum_embeddings = torch.sum(masked_embeddings, dim=1)
        sum_mask = torch.clamp(torch.sum(mask, dim=1), min=1e-9)
        embedding = (sum_embeddings / sum_mask).squeeze().cpu().numpy()
        return embedding
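
    # Illustrative usage (a sketch, assuming a model has been loaded first):
    #   visualizer.load_model("bert-large-uncased")
    #   emb = visualizer.get_embedding("king")  # 1-D np.ndarray, 1024 dims for bert-large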

    def calculate_similarity_matrix(self, embeddings):
        if not embeddings:
            return None
        embeddings_np = np.array(embeddings)
        return cosine_similarity(embeddings_np)
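
    # Sanity check: identical vectors have cosine similarity 1.0, so the matrix
    # diagonal is all ones; e.g. calculate_similarity_matrix([e, e]) -> [[1., 1.], [1., 1.]].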

    def reduce_dimensionality(self, embeddings, n_components, method):
        n_samples = embeddings.shape[0]
        # A single sample cannot be projected; plot it at the origin
        if n_samples == 1:
            return np.zeros((1, n_components))
        # PCA can produce at most n_samples - 1 components
        k = min(n_components, n_samples - 1)
        if method == "pca":
            reducer = PCA(n_components=k)
        elif method == "umap":
            # UMAP is unreliable on very small datasets; fall back to PCA
            if n_samples < 4:
                reducer = PCA(n_components=k)
            else:
                # Adjust parameters to the data size
                n_neighbors = min(15, n_samples - 1)  # n_neighbors must be < n_samples
                min_dist = 0.1 if n_samples > 4 else 0.5  # spread small datasets out more
                reducer = umap.UMAP(
                    n_components=k,
                    n_neighbors=n_neighbors,
                    min_dist=min_dist,
                    metric="euclidean",
                    random_state=42,
                )
        else:
            raise ValueError(f"Invalid dimensionality reduction method: {method}")
        # Convert to a dense array if sparse
        if hasattr(embeddings, "toarray"):
            embeddings = embeddings.toarray()
        reduced = reducer.fit_transform(embeddings)
        # Zero-pad missing axes so callers can always index n_components columns
        if reduced.shape[1] < n_components:
            pad = np.zeros((n_samples, n_components - reduced.shape[1]))
            reduced = np.hstack([reduced, pad])
        return reduced
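
    # e.g. with two words and a 3-D plot, only one principal axis exists, so
    # reduce_dimensionality(embs, 3, "pca") returns shape (2, 3) with the last
    # two columns zero-padded rather than raising an IndexError downstream.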

    def visualize_embeddings(self, model_choice, is_3d,
                             word1, word2, word3, word4, word5, word6, word7, word8,
                             positive_word1, positive_word2,
                             negative_word1, negative_word2,
                             dim_reduction_method):
        words = [word1, word2, word3, word4, word5, word6, word7, word8]
        words = [w for w in words if w.strip()]
        positive_words = [w for w in [positive_word1, positive_word2] if w.strip()]
        negative_words = [w for w in [negative_word1, negative_word2] if w.strip()]
        embeddings = []
        labels = []
        for word in words:
            emb = self.get_embedding(word)
            if emb is not None:
                embeddings.append(emb)
                labels.append(word)
        if positive_words or negative_words:
            # Embed each word once (every get_embedding call is a full forward pass)
            pos_embs = [e for e in (self.get_embedding(w) for w in positive_words) if e is not None]
            neg_embs = [e for e in (self.get_embedding(w) for w in negative_words) if e is not None]
            if pos_embs or neg_embs:
                pos_sum = sum(pos_embs) if pos_embs else 0
                neg_sum = sum(neg_embs) if neg_embs else 0
                arithmetic_emb = pos_sum - neg_sum
                embeddings.append(arithmetic_emb)
                labels.append("Arithmetic Result")
        if not embeddings:
            return None
        embeddings = np.array(embeddings)
        # Reduce dimensionality and plot
        if is_3d:
            embeddings_reduced = self.reduce_dimensionality(embeddings, 3, dim_reduction_method)
            fig = px.scatter_3d(x=embeddings_reduced[:, 0],
                                y=embeddings_reduced[:, 1],
                                z=embeddings_reduced[:, 2],
                                text=labels,
                                title=f"3D Word Embeddings Visualization ({model_choice}) - {dim_reduction_method.upper()}")
        else:
            embeddings_reduced = self.reduce_dimensionality(embeddings, 2, dim_reduction_method)
            fig = px.scatter(x=embeddings_reduced[:, 0],
                             y=embeddings_reduced[:, 1],
                             text=labels,
                             title=f"2D Word Embeddings Visualization ({model_choice}) - {dim_reduction_method.upper()}")
        fig.update_traces(textposition='top center')
        return fig
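
    # The positive/negative boxes drive the classic embedding-arithmetic demo:
    # positive = "king", "woman" and negative = "man" should place the
    # "Arithmetic Result" point near "queen" (a rough tendency with contextual
    # models like BERT or Gemma, not a guarantee).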

    def visualize_similarity_heatmap(self, model_choice,
                                     word1, word2, word3, word4, word5, word6, word7, word8):
        words = [word1, word2, word3, word4, word5, word6, word7, word8]
        words = [w for w in words if w.strip()]
        # Embed each word once (every get_embedding call is a full forward pass)
        embeddings = [e for e in (self.get_embedding(w) for w in words) if e is not None]
        if not embeddings:
            return None
        similarity_matrix = self.calculate_similarity_matrix(embeddings)
        if similarity_matrix is None:
            return None
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111)
        cax = ax.matshow(similarity_matrix, interpolation='nearest')
        fig.colorbar(cax)
        ax.set_xticks(np.arange(len(words)))
        ax.set_yticks(np.arange(len(words)))
        ax.set_xticklabels(words, rotation=45, ha='left')
        ax.set_yticklabels(words)
        ax.set_title(f"Cosine Similarity Heatmap ({model_choice})")
        return fig
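
    # Example: with "cat", "dog", "car" filled in, the heatmap is a 3x3 matrix;
    # the diagonal is 1.0 and each off-diagonal cell is a pairwise cosine similarity.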


# Initialize the visualizer
visualizer = EmbeddingVisualizer()

# Create Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Word Embedding Visualization")
    with gr.Row():
        with gr.Column():
            model_choice = gr.Dropdown(
                choices=["google/gemma-2b", "bert-large-uncased"],
                value="google/gemma-2b",
                label="Select Model"
            )
            load_status = gr.Textbox(label="Model Status", interactive=False)
            is_3d = gr.Checkbox(label="Use 3D Visualization", value=False)
            dim_reduction_method = gr.Radio(
                choices=["pca", "umap"],
                value="pca",
                label="Dimensionality Reduction Method"
            )
        with gr.Column():
            word1 = gr.Textbox(label="Word 1")
            word2 = gr.Textbox(label="Word 2")
            word3 = gr.Textbox(label="Word 3")
            word4 = gr.Textbox(label="Word 4")
            word5 = gr.Textbox(label="Word 5")
            word6 = gr.Textbox(label="Word 6")
            word7 = gr.Textbox(label="Word 7")
            word8 = gr.Textbox(label="Word 8")
        with gr.Column():
            positive_word1 = gr.Textbox(label="Positive Word 1")
            positive_word2 = gr.Textbox(label="Positive Word 2")
            negative_word1 = gr.Textbox(label="Negative Word 1")
            negative_word2 = gr.Textbox(label="Negative Word 2")
    with gr.Tabs():
        with gr.Tab("Scatter Plot"):
            plot_output = gr.Plot()
        with gr.Tab("Similarity Heatmap"):
            heatmap_output = gr.Plot()

    # Load model when selected
    model_choice.change(
        fn=visualizer.load_model,
        inputs=[model_choice],
        outputs=[load_status]
    )

    # Update visualization when any input changes
    inputs = [
        model_choice, is_3d,
        word1, word2, word3, word4, word5, word6, word7, word8,
        positive_word1, positive_word2,
        negative_word1, negative_word2,
        dim_reduction_method
    ]
    for input_component in inputs:
        input_component.change(
            fn=visualizer.visualize_embeddings,
            inputs=inputs,
            outputs=[plot_output]
        )

    # Update the similarity heatmap when any of its inputs changes
    similarity_inputs = [model_choice,
                         word1, word2, word3, word4, word5, word6, word7, word8]
    for input_component in similarity_inputs:
        input_component.change(
            fn=visualizer.visualize_similarity_heatmap,
            inputs=similarity_inputs,
            outputs=[heatmap_output]
        )

    # Add Clear All button
    clear_button = gr.Button("Clear All")

    def clear_all():
        return [""] * 12  # empty strings for the 12 text input components

    clear_button.click(
        fn=clear_all,
        inputs=[],
        outputs=[word1, word2, word3, word4, word5, word6, word7, word8,
                 positive_word1, positive_word2,
                 negative_word1, negative_word2]
    )

if __name__ == "__main__":
    # Load initial model
    visualizer.load_model("google/gemma-2b")
    iface.launch()
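
# Note: google/gemma-2b is gated on the Hugging Face Hub, so the Space needs an
# HF_TOKEN secret for an account with access granted; without one,
# load_model("google/gemma-2b") will fail at startup.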