Spaces: Runtime error

Changing sentence transformer

- app.py +54 -29
- requirements.txt +3 -1
app.py
CHANGED
@@ -6,7 +6,11 @@ from bertopic import BERTopic
 import pandas as pd
 import gradio as gr
 from bertopic.representation import KeyBERTInspired
-import
+from umap import UMAP
+
+# from cuml.cluster import HDBSCAN
+# from cuml.manifold import UMAP
+from sentence_transformers import SentenceTransformer
 
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -24,7 +28,7 @@ def get_parquet_urls(dataset, config, split):
     if "error" in parquet_files:
         raise Exception(f"Error fetching parquet files: {parquet_files['error']}")
     parquet_urls = [file["url"] for file in parquet_files["parquet_files"]]
-    logging.
+    logging.debug(f"Parquet files: {parquet_urls}")
     return ",".join(f"'{url}'" for url in parquet_urls)
 
 
@@ -34,7 +38,7 @@ def get_docs_from_parquet(parquet_urls, column, offset, limit):
     logging.debug(f"Dataframe: {df.head(5)}")
     return df[column].tolist()
 
-
+
 def generate_topics(dataset, config, split, column, nested_column):
     logging.info(
         f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
@@ -45,39 +49,60 @@ def generate_topics(dataset, config, split, column, nested_column):
     chunk_size = 300
     offset = 0
     representation_model = KeyBERTInspired()
-
-    docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
-
-    base_model = BERTopic(
-        "english", representation_model=representation_model, min_topic_size=15
-    )
-
-
-
+    base_model = None
+    # docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
+
+    # base_model = BERTopic(
+    #     "english", representation_model=representation_model, min_topic_size=15
+    # )
+    # base_model.fit_transform(docs)
+
+    # yield base_model.get_topic_info(), base_model.visualize_topics()
+    # Create instances of GPU-accelerated UMAP and HDBSCAN
+    # umap_model = UMAP(n_components=5, n_neighbors=15, min_dist=0.0)
+    # hdbscan_model = HDBSCAN(min_samples=10, gen_min_span_tree=True)
+    sentence_model = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
     while True:
+        docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
+        logging.info(f"------------> New chunk data {offset=} {chunk_size=}")
+        embeddings = sentence_model.encode(docs, show_progress_bar=True, batch_size=100)
+        logging.info(f"Embeddings shape: {embeddings.shape}")
         offset = offset + chunk_size
         if not docs or offset >= limit:
             break
 
-        docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
-        logging.info(f"------------> New chunk data {offset=} {chunk_size=}")
-        logging.info(docs[:5])
-
         new_model = BERTopic(
-            "english",
-
-
-
-
-        logging.info("
-
-
-
-
-
+            "english",
+            embedding_model=sentence_model,
+            representation_model=representation_model,
+            min_topic_size=15,  # umap_model=umap_model, hdbscan_model=hdbscan_model
+        )
+        logging.info("Fitting new model")
+        new_model.fit(docs, embeddings)
+        logging.info("End fitting new model")
+        if base_model is not None:
+            updated_model = BERTopic.merge_models([base_model, new_model])
+            nr_new_topics = len(set(updated_model.topics_)) - len(
+                set(base_model.topics_)
+            )
+            new_topics = list(updated_model.topic_labels_.values())[-nr_new_topics:]
+            logging.info("The following topics are newly found:")
+            logging.info(f"{new_topics}\n")
+            base_model = updated_model
+        else:
+            base_model = new_model
         logging.info(base_model.get_topic_info())
-
-
+        reduced_embeddings = UMAP(
+            n_neighbors=10, n_components=2, min_dist=0.0, metric="cosine"
+        ).fit_transform(embeddings)
+        logging.info(f"Reduced embeddings shape: {reduced_embeddings.shape}")
+        yield (
+            base_model.get_topic_info(),
+            new_model.visualize_documents(
+                docs, embeddings=embeddings
+            ),  # TODO: Visualize the merged models
+        )
+    logging.info("Finished processing all data")
     return base_model.get_topic_info(), base_model.visualize_topics()
 
 
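For reference, the new loop fits one BERTopic model per 300-document chunk on embeddings precomputed with SentenceTransformer, then folds each chunk model into a running model via BERTopic.merge_models. A minimal sketch of that pattern in isolation is below; the fetch_chunk helper, the CPU device, and the limit value are placeholders for this example, not part of the Space's code.

from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from sentence_transformers import SentenceTransformer


def incremental_topics(fetch_chunk, chunk_size=300, limit=1200):
    """Fit BERTopic chunk by chunk and merge the partial models.

    fetch_chunk(offset, size) -> list[str] is a hypothetical stand-in for
    get_docs_from_parquet in the Space.
    """
    sentence_model = SentenceTransformer("all-MiniLM-L6-v2")  # the Space passes device="cuda"
    representation_model = KeyBERTInspired()
    base_model = None
    offset = 0
    while True:
        docs = fetch_chunk(offset, chunk_size)
        if not docs or offset >= limit:
            break
        # Precompute embeddings once so they can be reused for fitting and plotting.
        embeddings = sentence_model.encode(docs, batch_size=100)
        new_model = BERTopic(
            "english",
            embedding_model=sentence_model,
            representation_model=representation_model,
            min_topic_size=15,
        ).fit(docs, embeddings)
        # merge_models keeps the base topics and appends topics that only
        # appear in the new chunk model.
        base_model = (
            new_model if base_model is None
            else BERTopic.merge_models([base_model, new_model])
        )
        offset += chunk_size
    return base_model

In app.py the same loop is written as a generator: after every chunk it yields the current topic table together with a visualize_documents figure, so the Gradio UI can refresh while later chunks are still being processed.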
requirements.txt
CHANGED
@@ -4,4 +4,6 @@ umap-learn
 sentence-transformers
 datamapplot
 bertopic
-pandas
+pandas
+torch
+cuml-cu11
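torch and cuml-cu11 support the GPU path: torch backs the CUDA SentenceTransformer, and cuml-cu11 is the CUDA 11 build of RAPIDS cuML whose UMAP/HDBSCAN imports are still commented out in app.py. If that path were enabled, the GPU models would plug into BERTopic roughly as sketched below, using the parameter values from the commented-out lines; this is not code that ships in this commit.

from bertopic import BERTopic
from cuml.cluster import HDBSCAN   # GPU HDBSCAN from RAPIDS cuML
from cuml.manifold import UMAP     # GPU UMAP from RAPIDS cuML

# Values taken from the commented-out lines in app.py.
umap_model = UMAP(n_components=5, n_neighbors=15, min_dist=0.0)
hdbscan_model = HDBSCAN(min_samples=10, gen_min_span_tree=True)

# BERTopic accepts drop-in umap_model / hdbscan_model replacements.
topic_model = BERTopic(
    "english",
    umap_model=umap_model,
    hdbscan_model=hdbscan_model,
    min_topic_size=15,
)

Note that cuml-cu11 targets a CUDA 11 runtime, so it is only expected to install on GPU hardware with matching drivers.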