Adding logs
app.py CHANGED
@@ -15,6 +15,7 @@ from bertopic.representation import KeyBERTInspired
 from huggingface_hub import HfApi, InferenceClient
 from sklearn.feature_extraction.text import CountVectorizer
 from sentence_transformers import SentenceTransformer
+from torch import cuda
 
 from src.hub import create_space_with_content
 from src.templates import LLAMA_3_8B_PROMPT, SPACE_REPO_CARD_CONTENT
@@ -167,14 +168,11 @@ def generate_topics(dataset, config, split, column, plot_type):
 
     try:
         while offset < limit:
+            logging.info(f"----> Getting records from {offset=} with {CHUNK_SIZE=}")
             docs = get_docs_from_parquet(parquet_urls, column, offset, CHUNK_SIZE)
             if not docs:
                 break
-
-            logging.info(
-                f"----> Processing chunk: {offset=} {CHUNK_SIZE=} with {len(docs)} docs"
-            )
-
+            logging.info(f"Got {len(docs)} docs ✓")
             embeddings = calculate_embeddings(docs)
             new_model = fit_model(docs, embeddings, n_neighbors, n_components)
 
@@ -192,14 +190,18 @@ def generate_topics(dataset, config, split, column, plot_type):
             logging.info(f"The following topics are newly found: {new_topics}")
             base_model = updated_model
 
+            logging.info("Reducing embeddings to 2D")
             reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
             reduced_embeddings_list.append(reduced_embeddings)
 
             all_docs.extend(docs)
             reduced_embeddings_array = np.vstack(reduced_embeddings_list)
+            logging.info("Reducing embeddings to 2D ✓")
 
             topics_info = base_model.get_topic_info()
             all_topics = base_model.topics_
+            logging.info(f"Preparing topics {plot_type} plot")
+
             topic_plot = (
                 base_model.visualize_document_datamap(
                     docs=all_docs,
@@ -224,11 +226,13 @@ def generate_topics(dataset, config, split, column, plot_type):
                 if plot_type == "DataMapPlot"
                 else base_model.visualize_documents(
                     docs=all_docs,
+                    topics=all_topics,
                     reduced_embeddings=reduced_embeddings_array,
                     custom_labels=True,
                     title="",
                 )
             )
+            logging.info("Plot done ✓")
             rows_processed += len(docs)
             progress = min(rows_processed / limit, 1.0)
             logging.info(f"Progress: {progress} % - {rows_processed} of {limit}")
@@ -403,7 +407,7 @@ def generate_topics(dataset, config, split, column, plot_type):
         del (
             base_model,
             all_topics,
-
+            topics_info,
             topic_names_array,
             interactive_plot,
        )
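
The new log lines use Python 3.8+ self-documenting f-strings: f"{offset=}" renders as offset=<value>, so the variable name appears in the message without being typed twice. Below is a minimal, self-contained sketch of the chunked loop's logging pattern; get_chunk and the CHUNK_SIZE/limit values are illustrative stand-ins for the app's get_docs_from_parquet and real configuration, not the actual implementation.

import logging

logging.basicConfig(level=logging.INFO)

# Illustrative stand-ins for the app's configuration.
CHUNK_SIZE = 10_000
limit = 25_000

def get_chunk(offset: int, size: int) -> list[str]:
    # Hypothetical replacement for get_docs_from_parquet(): returns a
    # shrinking batch so the loop terminates.
    remaining = max(limit - offset, 0)
    return ["doc"] * min(size, remaining)

rows_processed = 0
offset = 0
while offset < limit:
    # f"{offset=}" is self-documenting syntax; it renders as
    # "offset=0", "offset=10000", ... without repeating the variable name.
    logging.info(f"----> Getting records from {offset=} with {CHUNK_SIZE=}")
    docs = get_chunk(offset, CHUNK_SIZE)
    if not docs:
        break
    logging.info(f"Got {len(docs)} docs ✓")

    rows_processed += len(docs)
    progress = min(rows_processed / limit, 1.0)
    # progress is a 0-1 fraction; the :.1% format scales it to a percentage.
    logging.info(f"Progress: {progress:.1%} - {rows_processed} of {limit}")
    offset += CHUNK_SIZE

The added torch cuda import is not exercised in the hunks shown; it presumably supports GPU memory cleanup elsewhere in the file (for example cuda.empty_cache() after the del block), though that use is outside this diff.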