Spaces:
Runtime error
Runtime error
Parameterize behavior
Browse files
- app.py +54 -56
- requirements.txt +2 -2
- prompts.py → templates.py +11 -0
app.py
CHANGED
|
@@ -19,7 +19,7 @@ from bertopic.representation import TextGeneration
|
|
| 19 |
from huggingface_hub import HfApi, SpaceCard
|
| 20 |
from sklearn.feature_extraction.text import CountVectorizer
|
| 21 |
from sentence_transformers import SentenceTransformer
|
| 22 |
-
from prompts import REPRESENTATION_PROMPT
|
| 23 |
from torch import cuda, bfloat16
|
| 24 |
from transformers import (
|
| 25 |
BitsAndBytesConfig,
|
|
@@ -27,11 +27,6 @@ from transformers import (
|
|
| 27 |
AutoModelForCausalLM,
|
| 28 |
pipeline,
|
| 29 |
)
|
| 30 |
-
# from cuml.manifold import UMAP
|
| 31 |
-
# from cuml.cluster import HDBSCAN
|
| 32 |
-
|
| 33 |
-
from umap import UMAP
|
| 34 |
-
from hdbscan import HDBSCAN
|
| 35 |
|
| 36 |
"""
|
| 37 |
TODOs:
|
|
@@ -51,52 +46,68 @@ assert (
|
|
| 51 |
EXPORTS_REPOSITORY is not None
|
| 52 |
), "You need to set EXPORTS_REPOSITORY in your environment variables"
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
logging.basicConfig(
|
| 55 |
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
| 56 |
)
|
| 57 |
|
| 58 |
-
MAX_ROWS = 50_000
|
| 59 |
-
CHUNK_SIZE = 10_000
|
| 60 |
-
|
| 61 |
api = HfApi(token=HF_TOKEN)
|
| 62 |
-
|
| 63 |
session = requests.Session()
|
| 64 |
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 65 |
|
| 66 |
# Representation model
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
|
|
|
| 73 |
|
| 74 |
-
model_id = "meta-llama/Llama-2-7b-chat-hf"
|
| 75 |
-
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 76 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
)
|
| 82 |
-
model.eval()
|
| 83 |
-
generator = pipeline(
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
)
|
| 91 |
-
representation_model = TextGeneration(generator, prompt=REPRESENTATION_PROMPT)
|
| 92 |
-
|
|
|
|
| 93 |
|
| 94 |
vectorizer_model = CountVectorizer(stop_words="english")
|
| 95 |
|
| 96 |
|
| 97 |
def get_split_rows(dataset, config, split):
|
| 98 |
config_size = session.get(
|
| 99 |
-
f"
|
| 100 |
timeout=20,
|
| 101 |
).json()
|
| 102 |
if "error" in config_size:
|
|
@@ -112,7 +123,7 @@ def get_split_rows(dataset, config, split):
|
|
| 112 |
|
| 113 |
def get_parquet_urls(dataset, config, split):
|
| 114 |
parquet_files = session.get(
|
| 115 |
-
f"
|
| 116 |
timeout=20,
|
| 117 |
).json()
|
| 118 |
if "error" in parquet_files:
|
|
@@ -125,7 +136,6 @@ def get_parquet_urls(dataset, config, split):
|
|
| 125 |
def get_docs_from_parquet(parquet_urls, column, offset, limit):
|
| 126 |
SQL_QUERY = f"SELECT {column} FROM read_parquet([{parquet_urls}]) LIMIT {limit} OFFSET {offset};"
|
| 127 |
df = duckdb.sql(SQL_QUERY).to_df()
|
| 128 |
-
logging.debug(f"Dataframe: {df.head(5)}")
|
| 129 |
return df[column].tolist()
|
| 130 |
|
| 131 |
|
|
@@ -200,8 +210,7 @@ def _push_to_hub(
|
|
| 200 |
|
| 201 |
|
| 202 |
def create_space_with_content(dataset_id, html_file_path):
|
| 203 |
-
|
| 204 |
-
repo_id = f"datasets-topics/{dataset_id.replace('/', '-')}"
|
| 205 |
logging.info(f"Creating space with content: {repo_id} on file {html_file_path}")
|
| 206 |
api.create_repo(
|
| 207 |
repo_id=repo_id,
|
|
@@ -211,16 +220,6 @@ def create_space_with_content(dataset_id, html_file_path):
|
|
| 211 |
token=HF_TOKEN,
|
| 212 |
space_sdk="static",
|
| 213 |
)
|
| 214 |
-
SPACE_REPO_CARD_CONTENT = """
|
| 215 |
-
---
|
| 216 |
-
title: {dataset_id} topic modeling
|
| 217 |
-
sdk: static
|
| 218 |
-
pinned: false
|
| 219 |
-
datasets:
|
| 220 |
-
- {dataset_id}
|
| 221 |
-
---
|
| 222 |
-
|
| 223 |
-
"""
|
| 224 |
|
| 225 |
SpaceCard(
|
| 226 |
content=SPACE_REPO_CARD_CONTENT.format(dataset_id=dataset_id)
|
|
@@ -233,14 +232,14 @@ datasets:
|
|
| 233 |
repo_id=repo_id,
|
| 234 |
token=HF_TOKEN,
|
| 235 |
)
|
| 236 |
-
logging.info(f"Space
|
| 237 |
return repo_id
|
| 238 |
|
| 239 |
|
| 240 |
@spaces.GPU(duration=120)
|
| 241 |
def generate_topics(dataset, config, split, column, nested_column, plot_type):
|
| 242 |
logging.info(
|
| 243 |
-
f"Generating topics for {dataset}
|
| 244 |
)
|
| 245 |
|
| 246 |
parquet_urls = get_parquet_urls(dataset, config, split)
|
|
@@ -326,8 +325,7 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
|
|
| 326 |
"linewidth": 0,
|
| 327 |
"fc": "#33333377",
|
| 328 |
},
|
| 329 |
-
|
| 330 |
-
dynamic_label_size=False,
|
| 331 |
# label_wrap_width=12,
|
| 332 |
# label_over_points=True,
|
| 333 |
# dynamic_label_size=True,
|
|
@@ -395,7 +393,7 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
|
|
| 395 |
# TODO: Export data to .arrow and also serve it
|
| 396 |
inline_data=True,
|
| 397 |
# offline_data_prefix=dataset_clear_name,
|
| 398 |
-
initial_zoom_fraction=0.
|
| 399 |
)
|
| 400 |
html_content = str(interactive_plot)
|
| 401 |
html_file_path = f"{dataset_clear_name}.html"
|
|
@@ -503,7 +501,7 @@ with gr.Blocks() as demo:
|
|
| 503 |
nested_text_column_dropdown: gr.Dropdown(visible=False),
|
| 504 |
}
|
| 505 |
info_resp = session.get(
|
| 506 |
-
f"
|
| 507 |
).json()
|
| 508 |
if "error" in info_resp:
|
| 509 |
return {
|
|
|
|
| 19 |
from huggingface_hub import HfApi, SpaceCard
|
| 20 |
from sklearn.feature_extraction.text import CountVectorizer
|
| 21 |
from sentence_transformers import SentenceTransformer
|
| 22 |
+
from templates import REPRESENTATION_PROMPT, SPACE_REPO_CARD_CONTENT
|
| 23 |
from torch import cuda, bfloat16
|
| 24 |
from transformers import (
|
| 25 |
BitsAndBytesConfig,
|
|
|
|
| 27 |
AutoModelForCausalLM,
|
| 28 |
pipeline,
|
| 29 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
"""
|
| 32 |
TODOs:
|
|
|
|
| 46 |
EXPORTS_REPOSITORY is not None
|
| 47 |
), "You need to set EXPORTS_REPOSITORY in your environment variables"
|
| 48 |
|
| 49 |
+
MAX_ROWS = int(os.getenv("MAX_ROWS", "10_000"))
|
| 50 |
+
CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "2_000"))
|
| 51 |
+
DATASET_VIEWE_API_URL = "https://datasets-server.huggingface.co/"
|
| 52 |
+
DATASETS_TOPICS_ORGANIZATION = os.getenv(
|
| 53 |
+
"DATASETS_TOPICS_ORGANIZATION", "datasets-topics"
|
| 54 |
+
)
|
| 55 |
+
USE_ARROW_STYLE = int(os.getenv("USE_ARROW_STYLE", "0"))
|
| 56 |
+
USE_CUML = int(os.getenv("USE_CUML", "0"))
|
| 57 |
+
|
| 58 |
+
if USE_CUML:
|
| 59 |
+
from cuml.manifold import UMAP
|
| 60 |
+
from cuml.cluster import HDBSCAN
|
| 61 |
+
else:
|
| 62 |
+
from umap import UMAP
|
| 63 |
+
from hdbscan import HDBSCAN
|
| 64 |
+
|
| 65 |
+
USE_LLM_TEXT_GENERATION = int(os.getenv("USE_LLM_TEXT_GENERATION", "1"))
|
| 66 |
+
|
| 67 |
logging.basicConfig(
|
| 68 |
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
| 69 |
)
|
| 70 |
|
|
|
|
|
|
|
|
|
|
| 71 |
api = HfApi(token=HF_TOKEN)
|
|
|
|
| 72 |
session = requests.Session()
|
| 73 |
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 74 |
|
| 75 |
# Representation model
|
| 76 |
+
if USE_LLM_TEXT_GENERATION:
|
| 77 |
+
bnb_config = BitsAndBytesConfig(
|
| 78 |
+
load_in_4bit=True,
|
| 79 |
+
bnb_4bit_quant_type="nf4",
|
| 80 |
+
bnb_4bit_use_double_quant=True,
|
| 81 |
+
bnb_4bit_compute_dtype=bfloat16,
|
| 82 |
+
)
|
| 83 |
|
| 84 |
+
model_id = "meta-llama/Llama-2-7b-chat-hf"
|
| 85 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 86 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 87 |
+
model_id,
|
| 88 |
+
trust_remote_code=True,
|
| 89 |
+
quantization_config=bnb_config,
|
| 90 |
+
device_map="auto",
|
| 91 |
+
)
|
| 92 |
+
model.eval()
|
| 93 |
+
generator = pipeline(
|
| 94 |
+
model=model,
|
| 95 |
+
tokenizer=tokenizer,
|
| 96 |
+
task="text-generation",
|
| 97 |
+
temperature=0.1,
|
| 98 |
+
max_new_tokens=500,
|
| 99 |
+
repetition_penalty=1.1,
|
| 100 |
+
)
|
| 101 |
+
representation_model = TextGeneration(generator, prompt=REPRESENTATION_PROMPT)
|
| 102 |
+
else:
|
| 103 |
+
representation_model = KeyBERTInspired()
|
| 104 |
|
| 105 |
vectorizer_model = CountVectorizer(stop_words="english")
|
| 106 |
|
| 107 |
|
| 108 |
def get_split_rows(dataset, config, split):
|
| 109 |
config_size = session.get(
|
| 110 |
+
f"{DATASET_VIEWE_API_URL}/size?dataset={dataset}&config={config}",
|
| 111 |
timeout=20,
|
| 112 |
).json()
|
| 113 |
if "error" in config_size:
|
|
|
|
| 123 |
|
| 124 |
def get_parquet_urls(dataset, config, split):
|
| 125 |
parquet_files = session.get(
|
| 126 |
+
f"{DATASET_VIEWE_API_URL}/parquet?dataset={dataset}&config={config}&split={split}",
|
| 127 |
timeout=20,
|
| 128 |
).json()
|
| 129 |
if "error" in parquet_files:
|
|
|
|
| 136 |
def get_docs_from_parquet(parquet_urls, column, offset, limit):
|
| 137 |
SQL_QUERY = f"SELECT {column} FROM read_parquet([{parquet_urls}]) LIMIT {limit} OFFSET {offset};"
|
| 138 |
df = duckdb.sql(SQL_QUERY).to_df()
|
|
|
|
| 139 |
return df[column].tolist()
|
| 140 |
|
| 141 |
|
|
|
|
| 210 |
|
| 211 |
|
| 212 |
def create_space_with_content(dataset_id, html_file_path):
|
| 213 |
+
repo_id = f"{DATASETS_TOPICS_ORGANIZATION}/{dataset_id.replace('/', '-')}"
|
|
|
|
| 214 |
logging.info(f"Creating space with content: {repo_id} on file {html_file_path}")
|
| 215 |
api.create_repo(
|
| 216 |
repo_id=repo_id,
|
|
|
|
| 220 |
token=HF_TOKEN,
|
| 221 |
space_sdk="static",
|
| 222 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
|
| 224 |
SpaceCard(
|
| 225 |
content=SPACE_REPO_CARD_CONTENT.format(dataset_id=dataset_id)
|
|
|
|
| 232 |
repo_id=repo_id,
|
| 233 |
token=HF_TOKEN,
|
| 234 |
)
|
| 235 |
+
logging.info(f"Space creation done")
|
| 236 |
return repo_id
|
| 237 |
|
| 238 |
|
| 239 |
@spaces.GPU(duration=120)
|
| 240 |
def generate_topics(dataset, config, split, column, nested_column, plot_type):
|
| 241 |
logging.info(
|
| 242 |
+
f"Generating topics for {dataset=} {config=} {split=} {column=} {nested_column=} {plot_type=}"
|
| 243 |
)
|
| 244 |
|
| 245 |
parquet_urls = get_parquet_urls(dataset, config, split)
|
|
|
|
| 325 |
"linewidth": 0,
|
| 326 |
"fc": "#33333377",
|
| 327 |
},
|
| 328 |
+
dynamic_label_size=USE_ARROW_STYLE,
|
|
|
|
| 329 |
# label_wrap_width=12,
|
| 330 |
# label_over_points=True,
|
| 331 |
# dynamic_label_size=True,
|
|
|
|
| 393 |
# TODO: Export data to .arrow and also serve it
|
| 394 |
inline_data=True,
|
| 395 |
# offline_data_prefix=dataset_clear_name,
|
| 396 |
+
initial_zoom_fraction=0.8,
|
| 397 |
)
|
| 398 |
html_content = str(interactive_plot)
|
| 399 |
html_file_path = f"{dataset_clear_name}.html"
|
|
|
|
| 501 |
nested_text_column_dropdown: gr.Dropdown(visible=False),
|
| 502 |
}
|
| 503 |
info_resp = session.get(
|
| 504 |
+
f"{DATASET_VIEWE_API_URL}/info?dataset={dataset}", timeout=20
|
| 505 |
).json()
|
| 506 |
if "error" in info_resp:
|
| 507 |
return {
|
requirements.txt
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
spaces
|
| 4 |
gradio
|
| 5 |
torch
|
|
|
|
| 1 |
+
--extra-index-url https://pypi.nvidia.com
|
| 2 |
+
cuml-cu11
|
| 3 |
spaces
|
| 4 |
gradio
|
| 5 |
torch
|
prompts.py → templates.py
RENAMED
|
@@ -29,3 +29,14 @@ Based on the information about the topic above, please create a short label of t
|
|
| 29 |
"""
|
| 30 |
|
| 31 |
REPRESENTATION_PROMPT = SYSTEM_PROMPT + EXAMPLE_PROMPT + MAIN_PROMPT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
"""
|
| 30 |
|
| 31 |
REPRESENTATION_PROMPT = SYSTEM_PROMPT + EXAMPLE_PROMPT + MAIN_PROMPT
|
| 32 |
+
|
| 33 |
+
SPACE_REPO_CARD_CONTENT = """
|
| 34 |
+
---
|
| 35 |
+
title: {dataset_id} topic modeling
|
| 36 |
+
sdk: static
|
| 37 |
+
pinned: false
|
| 38 |
+
datasets:
|
| 39 |
+
- {dataset_id}
|
| 40 |
+
---
|
| 41 |
+
|
| 42 |
+
"""
|