Spaces:
Runtime error
Runtime error
| import spaces | |
| import gradio as gr | |
| import logging | |
| import os | |
| import datamapplot | |
| import duckdb | |
| import numpy as np | |
| import requests | |
| from dotenv import load_dotenv | |
| from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
| from bertopic import BERTopic | |
| from bertopic.representation import KeyBERTInspired | |
| from bertopic.representation import TextGeneration | |
| from huggingface_hub import HfApi, SpaceCard | |
| from sklearn.feature_extraction.text import CountVectorizer | |
| from sentence_transformers import SentenceTransformer | |
| from templates import REPRESENTATION_PROMPT, SPACE_REPO_CARD_CONTENT | |
| from torch import cuda, bfloat16 | |
| from transformers import ( | |
| BitsAndBytesConfig, | |
| AutoTokenizer, | |
| AutoModelForCausalLM, | |
| pipeline, | |
| ) | |
| """ | |
| TODOs: | |
| - Improve representation layer (Try with llamacpp or TextGeneration) | |
| - Make it run on Zero GPU | |
| - Try with more rows (Current: 50_000/10_000 -> Minimal Target: 1_000_000/20_000) | |
| - Export interactive plots and serve their HTML content (It doesn't work with gr.HTML) | |
| """ | |
# --- Runtime configuration -------------------------------------------------
# Values are read once at import time from the process environment (optionally
# populated from a local .env file by load_dotenv).
load_dotenv()

# Required settings: fail fast with an explicit exception. A plain `assert`
# would be stripped when Python runs with -O and the failure would surface
# much later as an obscure Hub error.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise RuntimeError("You need to set HF_TOKEN in your environment variables")

EXPORTS_REPOSITORY = os.getenv("EXPORTS_REPOSITORY")
if EXPORTS_REPOSITORY is None:
    raise RuntimeError(
        "You need to set EXPORTS_REPOSITORY in your environment variables"
    )

# Upper bound of rows processed per request and number of rows read per chunk.
MAX_ROWS = int(os.getenv("MAX_ROWS", "8_000"))
CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "2_000"))

# No trailing slash: every caller appends "/<endpoint>", so a trailing slash
# here produced URLs such as "https://...co//size".
DATASET_VIEWE_API_URL = "https://datasets-server.huggingface.co"
DATASETS_TOPICS_ORGANIZATION = os.getenv(
    "DATASETS_TOPICS_ORGANIZATION", "datasets-topics"
)

# Integer feature flags (0 = disabled, any other integer = enabled).
USE_ARROW_STYLE = int(os.getenv("USE_ARROW_STYLE", "0"))
USE_CUML = int(os.getenv("USE_CUML", "0"))

# Pick the UMAP / HDBSCAN implementations: cuML when USE_CUML is set,
# otherwise the umap / hdbscan packages.
if USE_CUML:
    from cuml.manifold import UMAP
    from cuml.cluster import HDBSCAN
else:
    from umap import UMAP
    from hdbscan import HDBSCAN

USE_LLM_TEXT_GENERATION = int(os.getenv("USE_LLM_TEXT_GENERATION", "1"))

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# Shared clients, created once at import time.
api = HfApi(token=HF_TOKEN)
session = requests.Session()

# Sentence encoder used for the document embeddings.
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")

# Representation model: KeyBERT-inspired keywords by default, or labels
# generated by a 4-bit quantized chat LLM when USE_LLM_TEXT_GENERATION is set.
if not USE_LLM_TEXT_GENERATION:
    representation_model = KeyBERTInspired()
else:
    model_id = "meta-llama/Llama-2-7b-chat-hf"

    # 4-bit NF4 quantization with double quantization, computing in bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=bfloat16,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        quantization_config=bnb_config,
        device_map="auto",
    )
    model.eval()

    generator = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        temperature=0.1,
        max_new_tokens=500,
        repetition_penalty=1.1,
    )
    representation_model = TextGeneration(generator, prompt=REPRESENTATION_PROMPT)

# Bag-of-words vectorizer with English stop words removed.
vectorizer_model = CountVectorizer(stop_words="english")
def get_split_rows(dataset, config, split):
    """Return the number of rows of one split, as reported by the viewer API.

    Args:
        dataset: Hub dataset id.
        config: Dataset config (subset) name.
        split: Split name to look up in the ``/size`` response.

    Returns:
        The ``num_rows`` value of the matching split entry.

    Raises:
        Exception: If the response contains an ``error`` field, or if the
            split is not listed for the given config.
    """
    # Let requests build the query string so that values are URL-encoded, and
    # strip any trailing slash from the base URL to avoid "//size".
    config_size = session.get(
        f"{DATASET_VIEWE_API_URL.rstrip('/')}/size",
        params={"dataset": dataset, "config": config},
        timeout=20,
    ).json()
    if "error" in config_size:
        raise Exception(f"Error fetching config size: {config_size['error']}")

    split_size = next(
        (s for s in config_size["size"]["splits"] if s["split"] == split),
        None,
    )
    if split_size is None:
        raise Exception(f"Error fetching split {split} in config {config}")
    return split_size["num_rows"]
def get_parquet_urls(dataset, config, split):
    """Return the parquet file URLs of a split as a SQL-ready, comma-separated list.

    Args:
        dataset: Hub dataset id.
        config: Dataset config (subset) name.
        split: Split name.

    Returns:
        A string such as ``'url1','url2'`` where every URL is wrapped in single
        quotes, ready to be placed inside ``read_parquet([...])``.

    Raises:
        Exception: If the response contains an ``error`` field.
    """
    # Let requests build the query string so that values are URL-encoded, and
    # strip any trailing slash from the base URL to avoid "//parquet".
    parquet_files = session.get(
        f"{DATASET_VIEWE_API_URL.rstrip('/')}/parquet",
        params={"dataset": dataset, "config": config, "split": split},
        timeout=20,
    ).json()
    if "error" in parquet_files:
        raise Exception(f"Error fetching parquet files: {parquet_files['error']}")

    parquet_urls = [file["url"] for file in parquet_files["parquet_files"]]
    logging.debug(f"Parquet files: {parquet_urls}")

    # The URLs end up inside SQL string literals: double any single quote so a
    # quote in a URL cannot terminate the literal early.
    escaped_urls = (url.replace("'", "''") for url in parquet_urls)
    return ",".join(f"'{url}'" for url in escaped_urls)
def get_docs_from_parquet(parquet_urls, column, offset, limit):
    """Read one page of values of ``column`` from the given parquet files.

    Args:
        parquet_urls: Comma-separated, single-quoted URL list (the format
            produced by ``get_parquet_urls``).
        column: Name of the column to read.
        offset: Number of rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        The column values as a Python list.
    """
    # Quote the identifier so names with spaces, dashes or reserved words
    # work, and so the name cannot inject SQL; embedded double quotes are
    # escaped by doubling them.
    quoted_column = '"' + str(column).replace('"', '""') + '"'
    sql_query = (
        f"SELECT {quoted_column} FROM read_parquet([{parquet_urls}]) "
        f"LIMIT {int(limit)} OFFSET {int(offset)};"
    )
    df = duckdb.sql(sql_query).to_df()
    # Exactly one column is selected, so take it by position.
    return df.iloc[:, 0].tolist()
# @spaces.GPU
def calculate_embeddings(docs):
    """Encode ``docs`` with the module-level sentence model.

    Encoding runs in batches of 32 with a progress bar; the result of
    ``sentence_model.encode`` is returned unchanged.
    """
    embeddings = sentence_model.encode(
        docs,
        batch_size=32,
        show_progress_bar=True,
    )
    return embeddings
def calculate_n_neighbors_and_components(n_rows):
    """Derive UMAP ``n_neighbors`` and ``n_components`` from the row count.

    ``n_neighbors`` is one neighbour per 20 rows, clamped to the range
    [15, 100]. ``n_components`` is 10 for more than 1000 rows and 5 otherwise.

    Returns:
        A ``(n_neighbors, n_components)`` tuple.
    """
    scaled_neighbors = n_rows // 20
    if scaled_neighbors < 15:
        n_neighbors = 15
    elif scaled_neighbors > 100:
        n_neighbors = 100
    else:
        n_neighbors = scaled_neighbors

    # Higher components for larger datasets
    n_components = 5 if n_rows <= 1000 else 10
    return n_neighbors, n_components
# @spaces.GPU
def fit_model(docs, embeddings, n_neighbors, n_components):
    """Fit a fresh BERTopic model on one chunk of documents.

    Args:
        docs: Documents of the chunk.
        embeddings: Embeddings passed to ``fit`` together with ``docs``.
        n_neighbors: UMAP ``n_neighbors``; also drives the cluster sizes.
        n_components: Number of UMAP output dimensions.

    Returns:
        The fitted ``BERTopic`` instance.
    """
    # Reducing min_cluster_size for fewer outliers
    min_cluster_size = max(5, n_neighbors // 2)

    # Step 2 - dimensionality reduction of the embeddings
    reducer = UMAP(
        n_neighbors=n_neighbors,
        n_components=n_components,
        min_dist=0.0,
        metric="cosine",
        random_state=42,
    )
    # Step 3 - clustering of the reduced embeddings
    clusterer = HDBSCAN(
        min_cluster_size=min_cluster_size,
        metric="euclidean",
        cluster_selection_method="eom",
        prediction_data=True,
    )

    # Step 1 (embeddings), Step 4 (tokenize topics) and Step 5 (label topics)
    # reuse the module-level models.
    sub_models = {
        "embedding_model": sentence_model,
        "umap_model": reducer,
        "hdbscan_model": clusterer,
        "vectorizer_model": vectorizer_model,
        "representation_model": representation_model,
    }
    # NOTE(review): is min_topic_size coherent with n_neighbors, and does it
    # still apply when a custom hdbscan_model is supplied? Confirm.
    hyperparameters = {
        "top_n_words": 10,
        "verbose": True,
        "min_topic_size": n_neighbors,
    }
    new_model = BERTopic(language="english", **sub_models, **hyperparameters)

    logging.info("Fitting new model")
    new_model.fit(docs, embeddings)
    logging.info("End fitting new model")
    return new_model
def _push_to_hub(
    dataset_id,
    file_path,
):
    """Upload a local file to the root of the exports dataset repository.

    Args:
        dataset_id: Id of the dataset the file was generated for; only used
            in log messages.
        file_path: Local path of the file. Its base name becomes the path in
            the repository.

    Raises:
        Exception: Any error raised by the upload is logged with its
            traceback and re-raised unchanged.
    """
    logging.info(f"Pushing file to hub: {dataset_id} on file {file_path}")
    file_name = os.path.basename(file_path)
    try:
        logging.info(f"About to push {file_path} - {dataset_id}")
        api.upload_file(
            path_or_fileobj=file_path,
            path_in_repo=file_name,
            repo_id=EXPORTS_REPOSITORY,
            repo_type="dataset",
        )
    except Exception:
        # The previous call passed the exception as a %-format argument to a
        # message without a placeholder, which makes logging report a
        # formatting error instead of the real failure.
        logging.exception("Failed to push file %s", file_path)
        raise
def create_space_with_content(dataset_id, html_file_path):
    """Publish an HTML file as the ``index.html`` of a public static Space.

    The Space lives under ``DATASETS_TOPICS_ORGANIZATION`` and is named after
    ``dataset_id`` with ``/`` replaced by ``-``. It is created if missing,
    receives a card rendered from ``SPACE_REPO_CARD_CONTENT`` and then the
    HTML file.

    Returns:
        The repo id of the Space.
    """
    space_name = dataset_id.replace("/", "-")
    repo_id = f"{DATASETS_TOPICS_ORGANIZATION}/{space_name}"
    logging.info(f"Creating space with content: {repo_id} on file {html_file_path}")

    # Arguments shared by the card push and the file upload.
    space_target = {"repo_id": repo_id, "repo_type": "space", "token": HF_TOKEN}

    api.create_repo(
        private=False,
        exist_ok=True,
        space_sdk="static",
        **space_target,
    )

    card = SpaceCard(content=SPACE_REPO_CARD_CONTENT.format(dataset_id=dataset_id))
    card.push_to_hub(**space_target)

    api.upload_file(
        path_or_fileobj=html_file_path,
        path_in_repo="index.html",
        **space_target,
    )
    logging.info("Space creation done")
    return repo_id
def _topic_label(topic_name):
    """Return the second ``_``-separated piece of a topic name, without dashes.

    Names that contain no underscore are returned unchanged instead of
    raising ``IndexError``.
    """
    pieces = topic_name.split("_")
    if len(pieces) < 2:
        return topic_name
    return pieces[1].strip("-")


def _empty_result_update(message):
    """Build the six-output UI update used when there is nothing to plot."""
    return (
        gr.Accordion(open=True),
        gr.DataFrame(value=[], interactive=False, visible=True),
        gr.Plot(value=None, visible=True),
        gr.Label({message: 1.0}, visible=True),
        "",
        "",
    )


def _iter_topic_updates(dataset, config, split, column, nested_column, plot_type):
    """Yield the UI updates of ``generate_topics`` (see its docstring)."""
    logging.info(
        f"Generating topics for {dataset=} {config=} {split=} {column=} {nested_column=} {plot_type=}"
    )
    # NOTE(review): nested_column is only logged; rows are read from `column`
    # as-is. Confirm whether nested values need to be extracted.
    parquet_urls = get_parquet_urls(dataset, config, split)
    split_rows = get_split_rows(dataset, config, split)
    logging.info(f"Split rows: {split_rows}")

    limit = min(split_rows, MAX_ROWS)
    if limit <= 0:
        # Nothing to process; also avoids dividing by zero in the progress.
        yield _empty_result_update("⚠️ The selected split has no rows to process")
        return

    n_neighbors, n_components = calculate_n_neighbors_and_components(limit)
    reduce_umap_model = UMAP(
        n_neighbors=n_neighbors,
        n_components=2,  # For visualization, keeping it for 2D
        min_dist=0.0,
        metric="cosine",
        random_state=42,
    )

    offset = 0
    rows_processed = 0
    base_model = None
    all_docs = []
    reduced_embeddings_list = []
    reduced_embeddings_array = None
    topics_info, topic_plot = None, None
    full_processing = split_rows <= MAX_ROWS

    # Loop-invariant, and needed again after the loop for the interactive plot.
    sub_title = (
        f"Data map for the entire dataset ({limit} rows) using the column '{column}'"
        if full_processing
        else f"Data map for a sample of the dataset (first {limit} rows) using the column '{column}'"
    )
    message = (
        f"⚙️ Processing full dataset: 0 of ({split_rows} rows)"
        if full_processing
        else f"⚙️ Processing partial dataset 0 of ({limit} rows)"
    )
    yield (
        gr.Accordion(open=False),
        gr.DataFrame(value=[], interactive=False, visible=True),
        gr.Plot(value=None, visible=True),
        gr.Label({message: rows_processed / limit}, visible=True),
        "",
        "",
    )

    while offset < limit:
        # Never read past `limit`, even when it is not a multiple of CHUNK_SIZE.
        chunk_size = min(CHUNK_SIZE, limit - offset)
        docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
        if not docs:
            break
        logging.info(
            f"----> Processing chunk: {offset=} {chunk_size=} with {len(docs)} docs"
        )

        embeddings = calculate_embeddings(docs)
        new_model = fit_model(docs, embeddings, n_neighbors, n_components)

        if base_model is None:
            base_model = new_model
        else:
            updated_model = BERTopic.merge_models([base_model, new_model])
            nr_new_topics = len(set(updated_model.topics_)) - len(
                set(base_model.topics_)
            )
            topic_labels = list(updated_model.topic_labels_.values())
            # A [-0:] slice would return every label, so handle "no new topic".
            new_topics = topic_labels[-nr_new_topics:] if nr_new_topics > 0 else []
            logging.info(f"The following topics are newly found: {new_topics}")
            base_model = updated_model

        # Each chunk is reduced to 2D on its own and appended to the map.
        reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
        reduced_embeddings_list.append(reduced_embeddings)
        all_docs.extend(docs)
        reduced_embeddings_array = np.vstack(reduced_embeddings_list)

        topics_info = base_model.get_topic_info()

        topic_plot = (
            base_model.visualize_document_datamap(
                docs=all_docs,
                reduced_embeddings=reduced_embeddings_array,
                title=dataset,
                sub_title=sub_title,
                width=800,
                height=700,
                arrowprops={
                    "arrowstyle": "wedge,tail_width=0.5",
                    "connectionstyle": "arc3,rad=0.05",
                    "linewidth": 0,
                    "fc": "#33333377",
                },
                dynamic_label_size=True,
                label_wrap_width=12,
                label_over_points=True,
                max_font_size=36,
                min_font_size=4,
            )
            if plot_type == "DataMapPlot"
            else base_model.visualize_documents(
                docs=all_docs,
                reduced_embeddings=reduced_embeddings_array,
                custom_labels=True,
                title=dataset,
            )
        )

        rows_processed += len(docs)
        progress = min(rows_processed / limit, 1.0)
        logging.info(f"Progress: {progress} % - {rows_processed} of {limit}")
        message = (
            f"⚙️ Processing full dataset: {rows_processed} of {limit}"
            if full_processing
            else f"⚙️ Processing partial dataset: {rows_processed} of {limit} rows"
        )
        yield (
            gr.Accordion(open=False),
            topics_info,
            topic_plot,
            gr.Label({message: progress}, visible=True),
            "",
            "",
        )
        offset += chunk_size

    logging.info("Finished processing all data")

    if base_model is None or topic_plot is None:
        # No chunk could be read: there is no plot to export.
        yield _empty_result_update("⚠️ No documents could be read from the dataset")
        return

    plot_png = f"{dataset.replace('/', '-')}-{plot_type.lower()}.png"
    if plot_type == "DataMapPlot":
        topic_plot.savefig(plot_png, format="png", dpi=300)
    else:
        topic_plot.write_image(plot_png)
    _push_to_hub(dataset, plot_png)

    all_topics, _ = base_model.transform(all_docs)
    topic_info = base_model.get_topic_info()
    topic_names = dict(zip(topic_info["Topic"], topic_info["Name"]))
    topic_names_array = np.array(
        [_topic_label(topic_names.get(topic, "No Topic")) for topic in all_topics]
    )

    dataset_clear_name = dataset.replace("/", "-")
    interactive_plot = datamapplot.create_interactive_plot(
        reduced_embeddings_array,
        topic_names_array,
        hover_text=all_docs,
        title=dataset,
        sub_title=sub_title.replace(
            "dataset",
            f"<a href='https://huggingface.co/datasets/{dataset}/viewer/{config}/{split}' target='_blank'>dataset</a>",
        ),
        enable_search=True,
        # TODO: Export data to .arrow and also serve it
        inline_data=True,
        # offline_data_prefix=dataset_clear_name,
        initial_zoom_fraction=0.8,
    )
    html_content = str(interactive_plot)
    html_file_path = f"{dataset_clear_name}.html"
    with open(html_file_path, "w", encoding="utf-8") as html_file:
        html_file.write(html_content)

    space_id = create_space_with_content(dataset, html_file_path)
    plot_png_link = (
        f"https://huggingface.co/datasets/{EXPORTS_REPOSITORY}/blob/main/{plot_png}"
    )
    space_link = f"https://huggingface.co/spaces/{space_id}"

    yield (
        gr.Accordion(open=False),
        topics_info,
        topic_plot,
        gr.Label(
            {f"✅ Done: {rows_processed} rows have been processed": 1.0}, visible=True
        ),
        # A Markdown link needs a label, otherwise nothing clickable is rendered.
        f"[Open the plot as PNG]({plot_png_link})",
        f"[Open the interactive data map]({space_link})",
    )


def generate_topics(dataset, config, split, column, nested_column, plot_type):
    """Discover topics for a dataset column and stream progress to the UI.

    Reads up to ``MAX_ROWS`` rows of ``column`` in chunks of ``CHUNK_SIZE``,
    fits one BERTopic model per chunk, merges it into the running model and
    redraws the data map after every chunk. When all chunks are processed the
    plot is exported as PNG to ``EXPORTS_REPOSITORY`` and an interactive HTML
    version is published as a static Space.

    Args:
        dataset: Hub dataset id.
        config: Dataset config (subset) name.
        split: Split name.
        column: Name of the text column to read.
        nested_column: Currently only logged.
        plot_type: ``"DataMapPlot"`` for the datamapplot figure; any other
            value uses ``visualize_documents``.

    Yields:
        6-tuples of updates: data-details accordion, topics table, topics
        plot, progress label, PNG link markdown, Space link markdown.
    """
    try:
        yield from _iter_topic_updates(
            dataset, config, split, column, nested_column, plot_type
        )
    finally:
        # Release cached GPU memory on success, on error and on cancellation.
        cuda.empty_cache()
# NOTE(review): the layout nesting below was reconstructed from the statement
# order; confirm it against the intended page layout.
with gr.Blocks() as demo:
    gr.Markdown("# 💠 Dataset Topic Discovery 🔭")
    gr.Markdown("## Select dataset and text column")

    data_details_accordion = gr.Accordion("Data details", open=True)
    with data_details_accordion:
        with gr.Row():
            with gr.Column(scale=3):
                dataset_name = HuggingfaceHubSearch(
                    label="Hub Dataset ID",
                    placeholder="Search for dataset id on Huggingface",
                    search_type="dataset",
                )
            subset_dropdown = gr.Dropdown(label="Subset", visible=False)
            split_dropdown = gr.Dropdown(label="Split", visible=False)

        with gr.Accordion("Dataset preview", open=False):
            # Target component for the embedded dataset viewer.
            dataset_preview = gr.HTML()

            def embed(name, subset, split):
                """Return an HTML update embedding the dataset viewer iframe."""
                if not (name and subset and split):
                    # Selection incomplete: show nothing instead of a broken iframe.
                    return gr.HTML(value="")
                html_code = f"""
                <iframe
                src="https://huggingface.co/datasets/{name}/embed/viewer/{subset}/{split}"
                frameborder="0"
                width="100%"
                height="600px"
                ></iframe>
                """
                return gr.HTML(value=html_code)

            # Refresh the preview whenever any part of the selection changes.
            for preview_trigger in (dataset_name, subset_dropdown, split_dropdown):
                preview_trigger.change(
                    embed,
                    inputs=[dataset_name, subset_dropdown, split_dropdown],
                    outputs=[dataset_preview],
                )

        with gr.Row():
            text_column_dropdown = gr.Dropdown(label="Text column name")
            nested_text_column_dropdown = gr.Dropdown(
                label="Nested text column name", visible=False
            )
            plot_type_radio = gr.Radio(
                ["DataMapPlot", "Plotly"],
                value="DataMapPlot",
                label="Choose the plot type",
                interactive=True,
            )
        generate_button = gr.Button("Generate Topics", variant="primary")

    gr.Markdown("## Data map")
    full_topics_generation_label = gr.Label(visible=False, show_label=False)
    with gr.Row():
        open_png_label = gr.Markdown()
        open_space_label = gr.Markdown()
    topics_plot = gr.Plot()
    with gr.Accordion("Topics Info", open=False):
        topics_df = gr.DataFrame(interactive=False, visible=True)

    generate_button.click(
        generate_topics,
        inputs=[
            dataset_name,
            subset_dropdown,
            split_dropdown,
            text_column_dropdown,
            nested_text_column_dropdown,
            plot_type_radio,
        ],
        outputs=[
            data_details_accordion,
            topics_df,
            topics_plot,
            full_topics_generation_label,
            open_png_label,
            open_space_label,
        ],
    )

    # Components updated by every dataset-selection handler below.
    selection_outputs = [
        subset_dropdown,
        split_dropdown,
        text_column_dropdown,
        nested_text_column_dropdown,
    ]

    def _hidden_selection() -> dict:
        """Return the update that hides the subset/split/nested dropdowns."""
        return {
            subset_dropdown: gr.Dropdown(visible=False),
            split_dropdown: gr.Dropdown(visible=False),
            text_column_dropdown: gr.Dropdown(label="Text column name"),
            nested_text_column_dropdown: gr.Dropdown(visible=False),
        }

    def _is_string_feature(feature) -> bool:
        """Tell whether a feature description is a dict with dtype "string"."""
        return isinstance(feature, dict) and feature.get("dtype") == "string"

    def _resolve_dataset_selection(
        dataset: str, default_subset: str, default_split: str, text_feature
    ):
        """Compute the dropdown updates for the current dataset selection.

        Queries the viewer API ``/info`` endpoint and returns a dict mapping
        the four selection components to their updates. The dropdowns are
        hidden when the dataset id has no ``/``, when the API reports an
        error, or when no subset is listed.

        Args:
            dataset: Hub dataset id typed or picked by the user.
            default_subset: Subset to select when the dataset has it;
                otherwise the first listed subset is used.
            default_split: Split to select when the subset has it; otherwise
                the first listed split is used.
            text_feature: Currently selected text column, if any. When it is
                a nested feature, the nested dropdown is shown with its
                string-typed keys.
        """
        if not dataset or "/" not in dataset.strip().strip("/"):
            return _hidden_selection()

        # Let requests encode the query string; strip the base URL's trailing
        # slash to avoid "//info".
        info_resp = session.get(
            f"{DATASET_VIEWE_API_URL.rstrip('/')}/info",
            params={"dataset": dataset},
            timeout=20,
        ).json()
        if "error" in info_resp:
            return _hidden_selection()

        dataset_info = info_resp["dataset_info"]
        subsets: list[str] = list(dataset_info)
        if not subsets:
            return _hidden_selection()
        subset = default_subset if default_subset in subsets else subsets[0]

        splits: list[str] = list(dataset_info[subset]["splits"])
        if default_split in splits:
            split = default_split
        else:
            split = splits[0] if splits else None

        features = dataset_info[subset]["features"]
        text_features = [
            feature_name
            for feature_name, feature in features.items()
            if _is_string_feature(feature)
        ]
        # A feature counts as nested when its first value is itself a dict;
        # the None default keeps an empty feature dict from raising.
        nested_features = [
            feature_name
            for feature_name, feature in features.items()
            if isinstance(feature, dict)
            and isinstance(next(iter(feature.values()), None), dict)
        ]
        nested_text_features = [
            feature_name
            for feature_name in nested_features
            if any(
                _is_string_feature(nested_feature)
                for nested_feature in features[feature_name].values()
            )
        ]

        selection = {
            subset_dropdown: gr.Dropdown(
                value=subset, choices=subsets, visible=len(subsets) > 1
            ),
            split_dropdown: gr.Dropdown(
                value=split, choices=splits, visible=len(splits) > 1
            ),
            text_column_dropdown: gr.Dropdown(
                choices=text_features + nested_text_features,
                label="Text column name",
            ),
            nested_text_column_dropdown: gr.Dropdown(visible=False),
        }

        if text_feature and text_feature in nested_text_features:
            nested_keys = [
                feature_name
                for feature_name, feature in features[text_feature].items()
                if _is_string_feature(feature)
            ]
            selection[nested_text_column_dropdown] = gr.Dropdown(
                value=nested_keys[0],
                choices=nested_keys,
                label="Nested text column name",
                visible=True,
            )
        return selection

    # This handler used to share its name with the subset handler below and
    # was therefore overwritten before it could ever be used.
    def show_input_from_dataset_name(dataset: str) -> dict:
        """Reset subset, split and columns after the dataset id changed."""
        return _resolve_dataset_selection(
            dataset, default_subset="default", default_split="train", text_feature=None
        )

    def show_input_from_subset_dropdown(dataset: str, subset: str) -> dict:
        """Refresh split and columns after the subset changed."""
        return _resolve_dataset_selection(
            dataset, default_subset=subset, default_split="train", text_feature=None
        )

    def show_input_from_split_dropdown(dataset: str, subset: str, split: str) -> dict:
        """Refresh the columns after the split changed."""
        return _resolve_dataset_selection(
            dataset, default_subset=subset, default_split=split, text_feature=None
        )

    def show_input_from_text_column_dropdown(
        dataset: str, subset: str, split: str, text_column
    ) -> dict:
        """Show the nested-column dropdown when a nested text column is picked."""
        return _resolve_dataset_selection(
            dataset,
            default_subset=subset,
            default_split=split,
            text_feature=text_column,
        )

    # Attach the handlers; without this the dropdowns are never populated.
    dataset_name.change(
        show_input_from_dataset_name,
        inputs=[dataset_name],
        outputs=selection_outputs,
    )
    subset_dropdown.change(
        show_input_from_subset_dropdown,
        inputs=[dataset_name, subset_dropdown],
        outputs=selection_outputs,
    )
    split_dropdown.change(
        show_input_from_split_dropdown,
        inputs=[dataset_name, subset_dropdown, split_dropdown],
        outputs=selection_outputs,
    )
    text_column_dropdown.change(
        show_input_from_text_column_dropdown,
        inputs=[dataset_name, subset_dropdown, split_dropdown, text_column_dropdown],
        outputs=selection_outputs,
    )

demo.launch()