Spaces:
Runtime error
Runtime error
| import spaces | |
| import requests | |
| import logging | |
| import duckdb | |
| from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
| from bertopic import BERTopic | |
| import pandas as pd | |
| import gradio as gr | |
| from bertopic.representation import KeyBERTInspired | |
| from umap import UMAP | |
| # from cuml.cluster import HDBSCAN | |
| # from cuml.manifold import UMAP | |
| from sentence_transformers import SentenceTransformer | |
# Root logger: INFO and above, each record prefixed with time, logger name and level.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# One HTTP session shared by every datasets-server request in this module.
session = requests.Session()
def get_parquet_urls(dataset, config, split):
    """Return the Parquet file URLs of one dataset split as a DuckDB list body.

    Asks the datasets-server ``/parquet`` endpoint for the files of
    ``dataset``/``config``/``split``. The result is a comma-separated string of
    single-quoted URLs (``'u1','u2'``), meant to be spliced inside
    ``read_parquet([...])``.

    Raises:
        Exception: if an argument is missing, the service reports an error,
            or the split has no Parquet files.
    """
    if not (dataset and config and split):
        # Without this the request would carry the literal text "None".
        raise Exception(
            f"dataset, config and split are required (got {dataset!r}, {config!r}, {split!r})"
        )
    # `params=` URL-encodes the values instead of splicing raw text into the URL.
    parquet_files = session.get(
        "https://datasets-server.huggingface.co/parquet",
        params={"dataset": dataset, "config": config, "split": split},
        timeout=20,
    ).json()
    if "error" in parquet_files:
        raise Exception(f"Error fetching parquet files: {parquet_files['error']}")
    parquet_urls = [file["url"] for file in parquet_files.get("parquet_files", [])]
    if not parquet_urls:
        # An empty list would later become `read_parquet([])` and fail obscurely.
        raise Exception(f"No parquet files found for {dataset} ({config}/{split})")
    logging.debug(f"Parquet files: {parquet_urls}")
    # Double embedded single quotes so each URL stays exactly one SQL string literal.
    return ",".join("'" + url.replace("'", "''") + "'" for url in parquet_urls)
def get_docs_from_parquet(parquet_urls, column, offset, limit):
    """Read up to ``limit`` values of ``column`` starting at row ``offset``.

    Args:
        parquet_urls: Comma-separated, single-quoted URLs, spliced into
            DuckDB's ``read_parquet([...])``.
        column: Name of the column to read (chosen in the UI).
        offset: Number of leading rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        The column's values as a Python list (may be shorter than ``limit``).
    """
    # The column name is user-selected: quote it as a SQL identifier (doubling any
    # embedded double quote) so names with spaces or SQL keywords cannot break or
    # alter the query.
    quoted_column = '"' + str(column).replace('"', '""') + '"'
    # int() ensures only plain row counts reach the LIMIT/OFFSET clauses.
    sql_query = (
        f"SELECT {quoted_column} FROM read_parquet([{parquet_urls}]) "
        f"LIMIT {int(limit)} OFFSET {int(offset)};"
    )
    df = duckdb.sql(sql_query).to_df()
    logging.debug(f"Dataframe: {df.head(5)}")
    # Exactly one column was selected; read it by position.
    return df.iloc[:, 0].tolist()
def calculate_embeddings(sentence_model, docs, *, batch_size=100, show_progress_bar=True):
    """Encode ``docs`` with ``sentence_model`` and return the result.

    Args:
        sentence_model: Object exposing a SentenceTransformer-style ``encode``.
        docs: Texts to embed.
        batch_size: Texts per encoding batch (default 100, the previous fixed value).
        show_progress_bar: Forwarded to ``encode`` (default True, as before).

    Returns:
        Whatever ``encode`` returns. Its ``.shape`` is logged, so presumably an
        array of shape (len(docs), dim) — TODO confirm.
    """
    embeddings = sentence_model.encode(
        docs, show_progress_bar=show_progress_bar, batch_size=batch_size
    )
    logging.info(f"Embeddings shape: {embeddings.shape}")
    return embeddings
def fit_model(base_model, sentence_model, representation_model, docs, embeddings):
    """Fit a BERTopic model on one chunk and merge it into the running model.

    Args:
        base_model: Model accumulated from earlier chunks, or None for the first.
        sentence_model: Passed to BERTopic as ``embedding_model``.
        representation_model: Passed to BERTopic as ``representation_model``.
        docs: Texts of this chunk.
        embeddings: Precomputed embeddings of ``docs``.

    Returns:
        ``(updated_model, new_model)``:
        - ``updated_model`` is ``BERTopic.merge_models([base_model, new_model])``,
          or ``new_model`` itself when ``base_model`` is None.
        - ``new_model`` is the model fitted on this chunk alone.
    """
    new_model = BERTopic(
        "english",
        embedding_model=sentence_model,
        representation_model=representation_model,
        min_topic_size=15,
    )
    logging.info("Fitting new model")
    new_model.fit(docs, embeddings)
    logging.info("End fitting new model")
    if base_model is None:
        return new_model, new_model
    updated_model = BERTopic.merge_models([base_model, new_model])
    nr_new_topics = len(set(updated_model.topics_)) - len(set(base_model.topics_))
    # Guard the slice: `labels[-0:]` is the WHOLE list, which would report every
    # topic as new when the merge found none.
    labels = list(updated_model.topic_labels_.values())
    new_topics = labels[-nr_new_topics:] if nr_new_topics > 0 else []
    logging.info("The following topics are newly found:")
    logging.info(f"{new_topics}\n")
    return updated_model, new_model
def _extract_texts(rows, nested_column=None):
    """Return the non-blank string values of one chunk, in order.

    When ``nested_column`` is set, dict rows are replaced by
    ``row.get(nested_column)``; presumably struct columns arrive as dicts from
    DuckDB's ``to_df()`` — TODO confirm. Anything that is not a non-blank ``str``
    afterwards (e.g. ``None``) is dropped so only real text reaches the encoder.
    """
    texts = []
    for row in rows:
        if nested_column and isinstance(row, dict):
            row = row.get(nested_column)
        if isinstance(row, str) and row.strip():
            texts.append(row)
    return texts


# `spaces` is imported and the encoder below is placed on "cuda", so request a
# GPU for each call (NOTE(review): expected to be a no-op outside ZeroGPU — confirm).
@spaces.GPU
def generate_topics(dataset, config, split, column, nested_column):
    """Stream BERTopic results for the first rows of one dataset split.

    Rows of ``column`` are read in chunks of ``chunk_size``, ``limit`` rows at
    most. Each chunk with usable text is embedded, fitted as its own BERTopic
    model and merged into the running model. After each chunk this yields the
    merged model's topic table and a document plot of that chunk. A final
    yield shows the merged model's topic overview.

    Args:
        dataset, config, split: Hub dataset id, subset and split to read.
        column: Column holding the text.
        nested_column: When set, the field of ``column`` that holds the text.

    Yields:
        ``(topic_info, plot)`` tuples for the Gradio outputs.

    Raises:
        gr.Error: if none of the selected rows contained usable text.
    """
    logging.info(
        f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
    )
    parquet_urls = get_parquet_urls(dataset, config, split)
    limit = 1_000
    chunk_size = 300
    offset = 0
    representation_model = KeyBERTInspired()
    sentence_model = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
    base_model = None
    while offset < limit:
        # Shrink the last chunk so no row beyond `limit` is fetched, and the
        # chunk that reaches `limit` is processed instead of being discarded.
        current_chunk = min(chunk_size, limit - offset)
        logging.info(f"------------> New chunk data {offset=} chunk_size={current_chunk}")
        rows = get_docs_from_parquet(parquet_urls, column, offset, current_chunk)
        offset += current_chunk
        docs = _extract_texts(rows, nested_column)
        if docs:
            embeddings = calculate_embeddings(sentence_model, docs)
            base_model, new_model = fit_model(
                base_model, sentence_model, representation_model, docs, embeddings
            )
            yield (
                base_model.get_topic_info(),
                # TODO: Visualize the merged models
                new_model.visualize_documents(docs, embeddings=embeddings),
            )
        else:
            logging.info("Chunk has no usable text, skipping")
        if len(rows) < current_chunk:
            break  # short read: the split has no more rows
    if base_model is None:
        raise gr.Error(f"No text found in column '{column}' of {dataset}.")
    logging.info("Finished processing all data")
    # Yield (not return): a generator's return value only ends up in
    # StopIteration and is never displayed.
    yield base_model.get_topic_info(), base_model.visualize_topics()
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 💠 Dataset Topic Discovery 🔭
        ## Select dataset and text column
        """
    )
    with gr.Row():
        with gr.Column(scale=3):
            dataset_name = HuggingfaceHubSearch(
                label="Hub Dataset ID",
                placeholder="Search for dataset id on Huggingface",
                search_type="dataset",
            )
            subset_dropdown = gr.Dropdown(label="Subset", visible=False)
            split_dropdown = gr.Dropdown(label="Split", visible=False)

    with gr.Accordion("Dataset preview", open=False):
        dataset_preview = gr.HTML()

        def embed(name, subset, split):
            """Return an HTML update embedding the Hub dataset viewer for one split.

            Returns empty HTML until dataset, subset and split are all chosen, so
            the iframe never points at ".../None/None".
            """
            from urllib.parse import quote

            if not (name and subset and split):
                return gr.HTML(value="")
            # Percent-encode each path segment so typed text cannot break out of
            # the src="..." attribute; keep the "/" of "owner/name".
            src = (
                "https://huggingface.co/datasets/"
                f"{quote(name, safe='/')}/embed/viewer/"
                f"{quote(subset, safe='')}/{quote(split, safe='')}"
            )
            html_code = f"""
            <iframe
            src="{src}"
            frameborder="0"
            width="100%"
            height="600px"
            ></iframe>
            """
            return gr.HTML(value=html_code)

        # Refresh the preview whenever any part of the selection changes.
        for preview_trigger in (dataset_name, subset_dropdown, split_dropdown):
            preview_trigger.change(
                embed,
                inputs=[dataset_name, subset_dropdown, split_dropdown],
                outputs=dataset_preview,
            )

    with gr.Row():
        text_column_dropdown = gr.Dropdown(label="Text column name")
        nested_text_column_dropdown = gr.Dropdown(
            label="Nested text column name", visible=False
        )

    generate_button = gr.Button("Generate Topics", variant="primary")

    gr.Markdown("## Topics info")
    topics_df = gr.DataFrame(interactive=False, visible=True)
    topics_plot = gr.Plot()

    generate_button.click(
        generate_topics,
        inputs=[
            dataset_name,
            subset_dropdown,
            split_dropdown,
            text_column_dropdown,
            nested_text_column_dropdown,
        ],
        outputs=[topics_df, topics_plot],
    )

    # TODO: choose num_rows, random, or offset -> By default limit max to 1176 rows
    #       -> From the article, it could be in GPU 1176/sec

    # Every selection handler rewrites these same four dropdowns.
    selection_outputs = [
        subset_dropdown,
        split_dropdown,
        text_column_dropdown,
        nested_text_column_dropdown,
    ]

    def _empty_selection():
        """Updates used when the dataset cannot be resolved: hide the pickers."""
        return {
            subset_dropdown: gr.Dropdown(visible=False),
            split_dropdown: gr.Dropdown(visible=False),
            text_column_dropdown: gr.Dropdown(label="Text column name"),
            nested_text_column_dropdown: gr.Dropdown(visible=False),
        }

    def _is_string_feature(feature):
        """True for a feature description whose dtype is "string"."""
        return isinstance(feature, dict) and feature.get("dtype") == "string"

    def _resolve_dataset_selection(
        dataset: str, default_subset: str, default_split: str, text_feature
    ):
        """Build dropdown updates from the datasets-server ``/info`` response.

        - Subset and split: the requested default when the dataset has it,
          otherwise the first one listed. Each picker is shown only when there
          is more than one choice.
        - Text column choices: top-level string features, plus features whose
          value is a dict of sub-feature dicts containing at least one string
          field (presumably struct columns — TODO confirm).
        - Nested picker: when ``text_feature`` is one of those nested features,
          it is shown with that feature's string fields, the first preselected.
        - Falls back to hidden pickers when the id has no "/", the service
          reports an error, or it lists no subsets or splits.
        """
        if not dataset or "/" not in dataset.strip().strip("/"):
            return _empty_selection()
        # `params=` URL-encodes the dataset id.
        info_resp = session.get(
            "https://datasets-server.huggingface.co/info",
            params={"dataset": dataset},
            timeout=20,
        ).json()
        if "error" in info_resp:
            return _empty_selection()
        dataset_info = info_resp.get("dataset_info") or {}
        subsets: list[str] = list(dataset_info)
        if not subsets:
            return _empty_selection()
        subset = default_subset if default_subset in subsets else subsets[0]
        splits: list[str] = list(dataset_info[subset].get("splits") or {})
        if not splits:
            return _empty_selection()
        split = default_split if default_split in splits else splits[0]
        features = dataset_info[subset].get("features") or {}

        text_features = [
            feature_name
            for feature_name, feature in features.items()
            if _is_string_feature(feature)
        ]
        nested_text_features = [
            feature_name
            for feature_name, feature in features.items()
            if isinstance(feature, dict)
            # `next(..., None)` keeps an empty dict from raising StopIteration.
            and isinstance(next(iter(feature.values()), None), dict)
            and any(_is_string_feature(sub) for sub in feature.values())
        ]

        updates = {
            subset_dropdown: gr.Dropdown(
                value=subset, choices=subsets, visible=len(subsets) > 1
            ),
            split_dropdown: gr.Dropdown(
                value=split, choices=splits, visible=len(splits) > 1
            ),
            text_column_dropdown: gr.Dropdown(
                choices=text_features + nested_text_features,
                label="Text column name",
            ),
            nested_text_column_dropdown: gr.Dropdown(visible=False),
        }
        if text_feature and text_feature in nested_text_features:
            nested_keys = [
                feature_name
                for feature_name, feature in features[text_feature].items()
                if _is_string_feature(feature)
            ]
            updates[nested_text_column_dropdown] = gr.Dropdown(
                value=nested_keys[0],
                choices=nested_keys,
                label="Nested text column name",
                visible=True,
            )
        return updates

    @dataset_name.change(inputs=[dataset_name], outputs=selection_outputs)
    def show_input_from_dataset_search(dataset: str) -> dict:
        """Resolve the pickers for a newly chosen dataset."""
        return _resolve_dataset_selection(
            dataset, default_subset="default", default_split="train", text_feature=None
        )

    @subset_dropdown.change(
        inputs=[dataset_name, subset_dropdown], outputs=selection_outputs
    )
    def show_input_from_subset_dropdown(dataset: str, subset: str) -> dict:
        """Resolve the pickers after the subset changes."""
        return _resolve_dataset_selection(
            dataset, default_subset=subset, default_split="train", text_feature=None
        )

    @split_dropdown.change(
        inputs=[dataset_name, subset_dropdown, split_dropdown],
        outputs=selection_outputs,
    )
    def show_input_from_split_dropdown(dataset: str, subset: str, split: str) -> dict:
        """Resolve the pickers after the split changes."""
        return _resolve_dataset_selection(
            dataset, default_subset=subset, default_split=split, text_feature=None
        )

    @text_column_dropdown.change(
        inputs=[dataset_name, subset_dropdown, split_dropdown, text_column_dropdown],
        outputs=selection_outputs,
    )
    def show_input_from_text_column_dropdown(
        dataset: str, subset: str, split: str, text_column
    ) -> dict:
        """Resolve the pickers after the text column changes (shows nested fields)."""
        return _resolve_dataset_selection(
            dataset,
            default_subset=subset,
            default_split=split,
            text_feature=text_column,
        )
# Start the server only when run as a script (e.g. `python app.py`), so
# importing this module to reuse `demo` or its helpers has no side effect.
if __name__ == "__main__":
    demo.launch()