Spaces:
Runtime error
Runtime error
Try to run on zero with custom components
Browse files
app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
|
| 2 |
import requests
|
| 3 |
import logging
|
| 4 |
import duckdb
|
|
@@ -73,23 +73,25 @@ prompt = system_prompt + example_prompt + main_prompt
|
|
| 73 |
|
| 74 |
llama2 = TextGeneration(generator, prompt=prompt)
|
| 75 |
representation_model = {
|
| 76 |
-
|
| 77 |
"Llama2": llama2,
|
| 78 |
# "MMR": mmr,
|
| 79 |
}
|
| 80 |
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
| 93 |
|
| 94 |
|
| 95 |
def get_parquet_urls(dataset, config, split):
|
|
@@ -111,19 +113,19 @@ def get_docs_from_parquet(parquet_urls, column, offset, limit):
|
|
| 111 |
return df[column].tolist()
|
| 112 |
|
| 113 |
|
| 114 |
-
|
| 115 |
def calculate_embeddings(docs):
|
| 116 |
return sentence_model.encode(docs, show_progress_bar=True, batch_size=100)
|
| 117 |
|
| 118 |
|
| 119 |
-
|
| 120 |
def fit_model(base_model, docs, embeddings):
|
| 121 |
new_model = BERTopic(
|
| 122 |
"english",
|
| 123 |
# Sub-models
|
| 124 |
embedding_model=sentence_model,
|
| 125 |
-
|
| 126 |
-
|
| 127 |
representation_model=representation_model,
|
| 128 |
# Hyperparameters
|
| 129 |
top_n_words=10,
|
|
@@ -140,10 +142,7 @@ def fit_model(base_model, docs, embeddings):
|
|
| 140 |
updated_model = BERTopic.merge_models([base_model, new_model])
|
| 141 |
nr_new_topics = len(set(updated_model.topics_)) - len(set(base_model.topics_))
|
| 142 |
new_topics = list(updated_model.topic_labels_.values())[-nr_new_topics:]
|
| 143 |
-
logging.info("The following topics are newly found:")
|
| 144 |
-
logging.info(f"{new_topics}\n")
|
| 145 |
-
# updated_model.set_topic_labels(updated_model.topic_labels_)
|
| 146 |
-
|
| 147 |
return updated_model, new_model
|
| 148 |
|
| 149 |
|
|
@@ -176,9 +175,7 @@ def generate_topics(dataset, config, split, column, nested_column):
|
|
| 176 |
logging.info(f"Topics: {llama2_labels}")
|
| 177 |
base_model.set_topic_labels(llama2_labels)
|
| 178 |
|
| 179 |
-
reduced_embeddings =
|
| 180 |
-
n_neighbors=10, n_components=2, min_dist=0.0, metric="cosine"
|
| 181 |
-
).fit_transform(embeddings)
|
| 182 |
|
| 183 |
all_docs.extend(docs)
|
| 184 |
all_reduced_embeddings = np.vstack((all_reduced_embeddings, reduced_embeddings))
|
|
|
|
| 1 |
+
import spaces
|
| 2 |
import requests
|
| 3 |
import logging
|
| 4 |
import duckdb
|
|
|
|
| 73 |
|
| 74 |
llama2 = TextGeneration(generator, prompt=prompt)
|
| 75 |
representation_model = {
|
| 76 |
+
"KeyBERT": keybert,
|
| 77 |
"Llama2": llama2,
|
| 78 |
# "MMR": mmr,
|
| 79 |
}
|
| 80 |
|
| 81 |
+
umap_model = UMAP(
|
| 82 |
+
n_neighbors=15, n_components=5, min_dist=0.0, metric="cosine", random_state=42
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
hdbscan_model = HDBSCAN(
|
| 86 |
+
min_cluster_size=15,
|
| 87 |
+
metric="euclidean",
|
| 88 |
+
cluster_selection_method="eom",
|
| 89 |
+
prediction_data=True,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
reduce_umap_model = UMAP(
|
| 93 |
+
n_neighbors=15, n_components=2, min_dist=0.0, metric="cosine", random_state=42
|
| 94 |
+
)
|
| 95 |
|
| 96 |
|
| 97 |
def get_parquet_urls(dataset, config, split):
|
|
|
|
| 113 |
return df[column].tolist()
|
| 114 |
|
| 115 |
|
| 116 |
+
@spaces.GPU
|
| 117 |
def calculate_embeddings(docs):
|
| 118 |
return sentence_model.encode(docs, show_progress_bar=True, batch_size=100)
|
| 119 |
|
| 120 |
|
| 121 |
+
@spaces.GPU
|
| 122 |
def fit_model(base_model, docs, embeddings):
|
| 123 |
new_model = BERTopic(
|
| 124 |
"english",
|
| 125 |
# Sub-models
|
| 126 |
embedding_model=sentence_model,
|
| 127 |
+
umap_model=umap_model,
|
| 128 |
+
hdbscan_model=hdbscan_model,
|
| 129 |
representation_model=representation_model,
|
| 130 |
# Hyperparameters
|
| 131 |
top_n_words=10,
|
|
|
|
| 142 |
updated_model = BERTopic.merge_models([base_model, new_model])
|
| 143 |
nr_new_topics = len(set(updated_model.topics_)) - len(set(base_model.topics_))
|
| 144 |
new_topics = list(updated_model.topic_labels_.values())[-nr_new_topics:]
|
| 145 |
+
logging.info(f"The following topics are newly found: {new_topics}")
|
|
|
|
|
|
|
|
|
|
| 146 |
return updated_model, new_model
|
| 147 |
|
| 148 |
|
|
|
|
| 175 |
logging.info(f"Topics: {llama2_labels}")
|
| 176 |
base_model.set_topic_labels(llama2_labels)
|
| 177 |
|
| 178 |
+
reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
|
|
|
|
|
|
|
| 179 |
|
| 180 |
all_docs.extend(docs)
|
| 181 |
all_reduced_embeddings = np.vstack((all_reduced_embeddings, reduced_embeddings))
|