Spaces:
Runtime error
Runtime error
Disable ZeroGPU
Browse files
app.py
CHANGED
|
@@ -7,7 +7,7 @@ from transformers import (
|
|
| 7 |
)
|
| 8 |
|
| 9 |
# These imports at the end because of torch/datamapplot issue in Zero GPU
|
| 10 |
-
import spaces
|
| 11 |
import gradio as gr
|
| 12 |
|
| 13 |
import logging
|
|
@@ -93,8 +93,6 @@ representation_model = TextGeneration(generator, prompt=REPRESENTATION_PROMPT)
|
|
| 93 |
|
| 94 |
vectorizer_model = CountVectorizer(stop_words="english")
|
| 95 |
|
| 96 |
-
global_topic_model = None
|
| 97 |
-
|
| 98 |
|
| 99 |
def get_split_rows(dataset, config, split):
|
| 100 |
config_size = session.get(
|
|
@@ -131,7 +129,7 @@ def get_docs_from_parquet(parquet_urls, column, offset, limit):
|
|
| 131 |
return df[column].tolist()
|
| 132 |
|
| 133 |
|
| 134 |
-
@spaces.GPU
|
| 135 |
def calculate_embeddings(docs):
|
| 136 |
return sentence_model.encode(docs, show_progress_bar=True, batch_size=32)
|
| 137 |
|
|
@@ -142,10 +140,8 @@ def calculate_n_neighbors_and_components(n_rows):
|
|
| 142 |
return n_neighbors, n_components
|
| 143 |
|
| 144 |
|
| 145 |
-
@spaces.GPU
|
| 146 |
def fit_model(docs, embeddings, n_neighbors, n_components):
|
| 147 |
-
global global_topic_model
|
| 148 |
-
|
| 149 |
umap_model = UMAP(
|
| 150 |
n_neighbors=n_neighbors,
|
| 151 |
n_components=n_components,
|
|
@@ -180,9 +176,7 @@ def fit_model(docs, embeddings, n_neighbors, n_components):
|
|
| 180 |
new_model.fit(docs, embeddings)
|
| 181 |
logging.info("End fitting new model")
|
| 182 |
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
logging.info("Global model updated")
|
| 186 |
|
| 187 |
|
| 188 |
def _push_to_hub(
|
|
@@ -207,7 +201,6 @@ def _push_to_hub(
|
|
| 207 |
|
| 208 |
|
| 209 |
def generate_topics(dataset, config, split, column, nested_column, plot_type):
|
| 210 |
-
global global_topic_model
|
| 211 |
logging.info(
|
| 212 |
f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
|
| 213 |
)
|
|
@@ -257,12 +250,12 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
|
|
| 257 |
)
|
| 258 |
|
| 259 |
embeddings = calculate_embeddings(docs)
|
| 260 |
-
fit_model(docs, embeddings, n_neighbors, n_components)
|
| 261 |
|
| 262 |
if base_model is None:
|
| 263 |
-
base_model =
|
| 264 |
else:
|
| 265 |
-
updated_model = BERTopic.merge_models([base_model,
|
| 266 |
nr_new_topics = len(set(updated_model.topics_)) - len(
|
| 267 |
set(base_model.topics_)
|
| 268 |
)
|
|
|
|
| 7 |
)
|
| 8 |
|
| 9 |
# These imports at the end because of torch/datamapplot issue in Zero GPU
|
| 10 |
+
# import spaces
|
| 11 |
import gradio as gr
|
| 12 |
|
| 13 |
import logging
|
|
|
|
| 93 |
|
| 94 |
vectorizer_model = CountVectorizer(stop_words="english")
|
| 95 |
|
|
|
|
|
|
|
| 96 |
|
| 97 |
def get_split_rows(dataset, config, split):
|
| 98 |
config_size = session.get(
|
|
|
|
| 129 |
return df[column].tolist()
|
| 130 |
|
| 131 |
|
| 132 |
+
# @spaces.GPU
|
| 133 |
def calculate_embeddings(docs):
|
| 134 |
return sentence_model.encode(docs, show_progress_bar=True, batch_size=32)
|
| 135 |
|
|
|
|
| 140 |
return n_neighbors, n_components
|
| 141 |
|
| 142 |
|
| 143 |
+
# @spaces.GPU
|
| 144 |
def fit_model(docs, embeddings, n_neighbors, n_components):
|
|
|
|
|
|
|
| 145 |
umap_model = UMAP(
|
| 146 |
n_neighbors=n_neighbors,
|
| 147 |
n_components=n_components,
|
|
|
|
| 176 |
new_model.fit(docs, embeddings)
|
| 177 |
logging.info("End fitting new model")
|
| 178 |
|
| 179 |
+
return new_model
|
|
|
|
|
|
|
| 180 |
|
| 181 |
|
| 182 |
def _push_to_hub(
|
|
|
|
| 201 |
|
| 202 |
|
| 203 |
def generate_topics(dataset, config, split, column, nested_column, plot_type):
|
|
|
|
| 204 |
logging.info(
|
| 205 |
f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
|
| 206 |
)
|
|
|
|
| 250 |
)
|
| 251 |
|
| 252 |
embeddings = calculate_embeddings(docs)
|
| 253 |
+
new_model = fit_model(docs, embeddings, n_neighbors, n_components)
|
| 254 |
|
| 255 |
if base_model is None:
|
| 256 |
+
base_model = new_model
|
| 257 |
else:
|
| 258 |
+
updated_model = BERTopic.merge_models([base_model, new_model])
|
| 259 |
nr_new_topics = len(set(updated_model.topics_)) - len(
|
| 260 |
set(base_model.topics_)
|
| 261 |
)
|