Spaces:
Runtime error
Runtime error
Apply text generation layer at the end only
Browse files
app.py
CHANGED
|
@@ -44,7 +44,6 @@ DATASETS_TOPICS_ORGANIZATION = os.getenv(
|
|
| 44 |
"DATASETS_TOPICS_ORGANIZATION", "datasets-topics"
|
| 45 |
)
|
| 46 |
USE_CUML = int(os.getenv("USE_CUML", "1"))
|
| 47 |
-
USE_LLM_TEXT_GENERATION = int(os.getenv("USE_LLM_TEXT_GENERATION", "1"))
|
| 48 |
|
| 49 |
# Use cuml lib only if configured
|
| 50 |
if USE_CUML:
|
|
@@ -60,43 +59,39 @@ logging.basicConfig(
|
|
| 60 |
)
|
| 61 |
|
| 62 |
api = HfApi(token=HF_TOKEN)
|
| 63 |
-
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 64 |
-
|
| 65 |
-
# Representation model
|
| 66 |
-
if USE_LLM_TEXT_GENERATION:
|
| 67 |
-
bnb_config = BitsAndBytesConfig(
|
| 68 |
-
load_in_4bit=True,
|
| 69 |
-
bnb_4bit_quant_type="nf4",
|
| 70 |
-
bnb_4bit_use_double_quant=True,
|
| 71 |
-
bnb_4bit_compute_dtype=bfloat16,
|
| 72 |
-
)
|
| 73 |
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
|
|
|
| 95 |
vectorizer_model = CountVectorizer(stop_words="english")
|
|
|
|
| 96 |
|
| 97 |
|
| 98 |
def calculate_embeddings(docs):
|
| 99 |
-
return
|
| 100 |
|
| 101 |
|
| 102 |
def calculate_n_neighbors_and_components(n_rows):
|
|
@@ -126,7 +121,7 @@ def fit_model(docs, embeddings, n_neighbors, n_components):
|
|
| 126 |
new_model = BERTopic(
|
| 127 |
language="english",
|
| 128 |
# Sub-models
|
| 129 |
-
embedding_model=
|
| 130 |
umap_model=umap_model, # Step 2 - UMAP model
|
| 131 |
hdbscan_model=hdbscan_model, # Step 3 - Cluster reduced embeddings
|
| 132 |
vectorizer_model=vectorizer_model, # Step 4 - Tokenize topics
|
|
@@ -294,13 +289,55 @@ def generate_topics(dataset, config, split, column, plot_type):
|
|
| 294 |
all_topics = base_model.topics_
|
| 295 |
topic_info = base_model.get_topic_info()
|
| 296 |
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
)
|
|
|
|
|
|
|
|
|
|
| 304 |
interactive_plot = datamapplot.create_interactive_plot(
|
| 305 |
reduced_embeddings_array,
|
| 306 |
topic_names_array,
|
|
@@ -348,7 +385,6 @@ def generate_topics(dataset, config, split, column, plot_type):
|
|
| 348 |
base_model,
|
| 349 |
all_topics,
|
| 350 |
topic_info,
|
| 351 |
-
topic_names,
|
| 352 |
topic_names_array,
|
| 353 |
interactive_plot,
|
| 354 |
)
|
|
|
|
| 44 |
"DATASETS_TOPICS_ORGANIZATION", "datasets-topics"
|
| 45 |
)
|
| 46 |
USE_CUML = int(os.getenv("USE_CUML", "1"))
|
|
|
|
| 47 |
|
| 48 |
# Use cuml lib only if configured
|
| 49 |
if USE_CUML:
|
|
|
|
| 59 |
)
|
| 60 |
|
| 61 |
api = HfApi(token=HF_TOKEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
+
bnb_config = BitsAndBytesConfig(
|
| 64 |
+
load_in_4bit=True,
|
| 65 |
+
bnb_4bit_quant_type="nf4",
|
| 66 |
+
bnb_4bit_use_double_quant=True,
|
| 67 |
+
bnb_4bit_compute_dtype=bfloat16,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
model_id = "meta-llama/Llama-2-7b-chat-hf"
|
| 71 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 72 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 73 |
+
model_id,
|
| 74 |
+
trust_remote_code=True,
|
| 75 |
+
quantization_config=bnb_config,
|
| 76 |
+
device_map="auto",
|
| 77 |
+
)
|
| 78 |
+
model.eval()
|
| 79 |
+
generator = pipeline(
|
| 80 |
+
model=model,
|
| 81 |
+
tokenizer=tokenizer,
|
| 82 |
+
task="text-generation",
|
| 83 |
+
temperature=0.1,
|
| 84 |
+
max_new_tokens=500,
|
| 85 |
+
repetition_penalty=1.1,
|
| 86 |
+
)
|
| 87 |
|
| 88 |
+
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 89 |
vectorizer_model = CountVectorizer(stop_words="english")
|
| 90 |
+
representation_model = KeyBERTInspired()
|
| 91 |
|
| 92 |
|
| 93 |
def calculate_embeddings(docs):
|
| 94 |
+
return embedding_model.encode(docs, show_progress_bar=True, batch_size=32)
|
| 95 |
|
| 96 |
|
| 97 |
def calculate_n_neighbors_and_components(n_rows):
|
|
|
|
| 121 |
new_model = BERTopic(
|
| 122 |
language="english",
|
| 123 |
# Sub-models
|
| 124 |
+
embedding_model=embedding_model, # Step 1 - Extract embeddings
|
| 125 |
umap_model=umap_model, # Step 2 - UMAP model
|
| 126 |
hdbscan_model=hdbscan_model, # Step 3 - Cluster reduced embeddings
|
| 127 |
vectorizer_model=vectorizer_model, # Step 4 - Tokenize topics
|
|
|
|
| 289 |
all_topics = base_model.topics_
|
| 290 |
topic_info = base_model.get_topic_info()
|
| 291 |
|
| 292 |
+
new_topics_by_text_generation = {}
|
| 293 |
+
for _, row in topic_info.iterrows():
|
| 294 |
+
logging.info(
|
| 295 |
+
f"Processing topic: {row['Topic']} - Representation: {row['Representation']}"
|
| 296 |
+
)
|
| 297 |
+
prompt = f"{REPRESENTATION_PROMPT.replace('[KEYWORDS]', ','.join(row['Representation']))}"
|
| 298 |
+
logging.info(prompt)
|
| 299 |
+
topic_description = generator(prompt)
|
| 300 |
+
logging.info(topic_description)
|
| 301 |
+
new_topics_by_text_generation[row["Topic"]] = topic_description[0][
|
| 302 |
+
"generated_text"
|
| 303 |
+
].replace(prompt, "")
|
| 304 |
+
base_model.set_topic_labels(new_topics_by_text_generation)
|
| 305 |
+
|
| 306 |
+
topics_info = base_model.get_topic_info()
|
| 307 |
+
|
| 308 |
+
topic_plot = (
|
| 309 |
+
base_model.visualize_document_datamap(
|
| 310 |
+
docs=all_docs,
|
| 311 |
+
topics=all_topics,
|
| 312 |
+
custom_labels=True,
|
| 313 |
+
reduced_embeddings=reduced_embeddings_array,
|
| 314 |
+
title="",
|
| 315 |
+
sub_title=sub_title,
|
| 316 |
+
width=800,
|
| 317 |
+
height=700,
|
| 318 |
+
arrowprops={
|
| 319 |
+
"arrowstyle": "wedge,tail_width=0.5",
|
| 320 |
+
"connectionstyle": "arc3,rad=0.05",
|
| 321 |
+
"linewidth": 0,
|
| 322 |
+
"fc": "#33333377",
|
| 323 |
+
},
|
| 324 |
+
dynamic_label_size=True,
|
| 325 |
+
# label_wrap_width=12,
|
| 326 |
+
label_over_points=True,
|
| 327 |
+
max_font_size=36,
|
| 328 |
+
min_font_size=4,
|
| 329 |
+
)
|
| 330 |
+
if plot_type == "DataMapPlot"
|
| 331 |
+
else base_model.visualize_documents(
|
| 332 |
+
docs=all_docs,
|
| 333 |
+
reduced_embeddings=reduced_embeddings_array,
|
| 334 |
+
custom_labels=True,
|
| 335 |
+
title="",
|
| 336 |
+
)
|
| 337 |
)
|
| 338 |
+
custom_labels = base_model.custom_labels_
|
| 339 |
+
topic_names_array = [custom_labels[doc_topic + 1] for doc_topic in all_topics]
|
| 340 |
+
|
| 341 |
interactive_plot = datamapplot.create_interactive_plot(
|
| 342 |
reduced_embeddings_array,
|
| 343 |
topic_names_array,
|
|
|
|
| 385 |
base_model,
|
| 386 |
all_topics,
|
| 387 |
topic_info,
|
|
|
|
| 388 |
topic_names_array,
|
| 389 |
interactive_plot,
|
| 390 |
)
|