import logging
import os
import time
from collections import Counter

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import polars as pl
import requests
import spaces
import torch
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import PyTorchModelHubMixin
from requests.adapters import HTTPAdapter, Retry
from torch import nn
from tqdm import tqdm
from transformers import AutoConfig, AutoModel, AutoTokenizer

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

# Shared HTTP session that retries transient server errors. All requests below
# go to "https://" URLs, so the retry adapter is mounted for both schemes.
session = requests.Session()
retries = Retry(total=5, backoff_factor=2, status_forcelist=[502, 503, 504])
session.mount("http://", HTTPAdapter(max_retries=retries))
session.mount("https://", HTTPAdapter(max_retries=retries))
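# Sketch of the resulting retry schedule (urllib3 semantics, assuming
# backoff_factor=2): sleeps between consecutive retries grow roughly as
# 2 * 2**(n - 1), i.e. about 2s, 4s, 8s, 16s before giving up.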
class QualityModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config):
        super().__init__()
        self.model = AutoModel.from_pretrained(config["base_model"])
        self.dropout = nn.Dropout(config["fc_dropout"])
        self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"]))

    def forward(self, input_ids, attention_mask):
        features = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        dropped = self.dropout(features)
        outputs = self.fc(dropped)
        # Classify from the [CLS] token (position 0) and return class probabilities
        return torch.softmax(outputs[:, 0, :], dim=1)


device = "cuda" if torch.cuda.is_available() else "cpu"
config = AutoConfig.from_pretrained("nvidia/quality-classifier-deberta")
tokenizer = AutoTokenizer.from_pretrained("nvidia/quality-classifier-deberta")
model = QualityModel.from_pretrained("nvidia/quality-classifier-deberta").to(device)
# model = torch.compile(model)
model.eval()
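# Sanity note: the classifier's id2label is expected to map class indices to
# the "Low" / "Medium" / "High" labels that plot_and_df() relies on below,
# e.g. config.id2label -> {0: "High", 1: "Low", 2: "Medium"} (illustrative order).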
# The `spaces` import is for ZeroGPU support; the decorator below assumes this
# space runs on ZeroGPU hardware (it is a no-op on other hardware).
@spaces.GPU
def predict(texts: list[str]):
    inputs = tokenizer(
        texts, return_tensors="pt", padding="longest", truncation=True
    ).to(device)
    # Inference only: no need to track gradients
    with torch.inference_mode():
        outputs = model(inputs["input_ids"], inputs["attention_mask"])
    predicted_classes = torch.argmax(outputs, dim=1)
    predicted_domains = [
        config.id2label[class_idx.item()] for class_idx in predicted_classes.cpu().numpy()
    ]
    return predicted_domains
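# A minimal usage sketch (the returned label depends on the model's judgement):
#   predict(["A short, low-effort text."])  # -> e.g. ["Low"]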
def plot_and_df(texts, preds):
    texts_df = pd.DataFrame({"quality": preds, "text": texts})
    counts = Counter(preds)
    counts_df = pd.DataFrame(
        {
            "quality": ["Low", "Medium", "High"],
            "count": [counts.get("Low", 0), counts.get("Medium", 0), counts.get("High", 0)],
        }
    )
    # Show up to 20 example texts per quality class
    return (
        gr.BarPlot(counts_df, x="quality", y="count", sort=None),
        texts_df[texts_df["quality"] == "Low"][["text"]][:20],
        texts_df[texts_df["quality"] == "Medium"][["text"]][:20],
        texts_df[texts_df["quality"] == "High"][["text"]][:20],
    )
def get_first_parquet_filename(dataset, config, split):
    parquet_resp = session.get(
        f"https://datasets-server.huggingface.co/parquet?dataset={dataset}&config={config}", timeout=20
    ).json()
    if "error" in parquet_resp:
        raise ValueError(parquet_resp["error"])
    first_parquet_file_url = [file for file in parquet_resp["parquet_files"] if file["split"] == split][0]["url"]
    # Keep the trailing "<config>/<split>/<filename>" part of the file URL
    return "/".join(first_parquet_file_url.split("/")[-3:])
def run_quality_check(dataset, config, split, column, nested_column, batch_size, num_examples):
    logging.info(f"Fetching data for {dataset=} {config=} {split=} {column=}")
    try:
        filename = get_first_parquet_filename(dataset, config, split)
    except Exception as error:
        yield f"❌ {error}", gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
        return
    try:
        logging.info(f"Loading hf://datasets/{dataset}@~parquet/{filename}")
        yield "loading data...", gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
        data = pl.read_parquet(f"hf://datasets/{dataset}@~parquet/{filename}", columns=[column])
    except Exception as error:
        yield f"❌ {error}", gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
        return
    logging.info("Data fetched.")
    data_sample = data.sample(num_examples, seed=16) if data.shape[0] > num_examples else data
    texts = data_sample[column].to_list()
    if nested_column:
        texts = [text[nested_column] for text in texts]
    predictions, texts_processed = [], []
    num_examples = min(len(texts), num_examples)
    for i in range(0, num_examples, batch_size):
        batch_texts = texts[i:i + batch_size]
        try:
            batch_predictions = predict(batch_texts)
        except Exception as error:
            yield f"❌ {error}", gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
            return
        predictions.extend(batch_predictions)
        texts_processed.extend(batch_texts)
        yield {"quality check in progress...": i / num_examples}, *plot_and_df(texts_processed, predictions), pd.DataFrame()
    yield {"quality check finished": 1.}, *plot_and_df(texts_processed, predictions), data_sample
PERSPECTIVE_API_KEY = os.environ.get("PERSPECTIVE_API_KEY")
PERSPECTIVE_URL = f"https://commentanalyzer.googleapis.com/v1alpha1/comments:analyze?key={PERSPECTIVE_API_KEY}"
REQUESTED_ATTRIBUTES = {
    "TOXICITY": {},
    "SEVERE_TOXICITY": {},
    "IDENTITY_ATTACK": {},
    "INSULT": {},
    "PROFANITY": {},
    "THREAT": {},
}
ATT_SCORE = "attributeScores"
SUM_SCORE = "summaryScore"
def plot_toxicity(scores):
    fig, axs = plt.subplots(2, 3)
    # One histogram per requested attribute, laid out on a 2x3 grid
    for x, y, score_name in zip([0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2], scores):
        axs[x, y].hist(scores[score_name], bins=20, range=(0., 1.))
        axs[x, y].set_xlabel(score_name)
    fig.supylabel("Number of texts")
    fig.suptitle("Histogram of toxicity scores")
    fig.tight_layout()
    return fig
def call_perspective_api(texts_df, column_name, nested_column_name, dataset, config, split):
    headers = {
        "content-type": "application/json",
    }
    req_att_scores = {attr: [] for attr in REQUESTED_ATTRIBUTES}
    texts_processed = {column_name: []}
    # Fetch data only if the quality check hasn't populated the hidden DataFrame yet
    if texts_df.values.tolist() == [['', '', '']]:
        logging.info(f"Fetching data for {dataset=} {config=} {split=} {column_name=}")
        try:
            filename = get_first_parquet_filename(dataset, config, split)
        except Exception as error:
            yield f"❌ {error}", None, pd.DataFrame()
            return
        try:
            logging.info(f"Loading hf://datasets/{dataset}@~parquet/{filename}")
            yield "loading data...", None, pd.DataFrame()
            texts_df = pl.read_parquet(f"hf://datasets/{dataset}@~parquet/{filename}", columns=[column_name])
        except Exception as error:
            yield f"❌ {error}", None, pd.DataFrame()
            return
        logging.info("Data fetched.")
        texts_df = texts_df.to_pandas()
    texts = (
        texts_df.sample(100, random_state=16)[column_name].values
        if texts_df.shape[0] > 100
        else texts_df[column_name].values
    )
    if nested_column_name:
        texts = [text[nested_column_name] for text in texts]
    n_samples = len(texts)
    for i, text in tqdm(enumerate(texts), total=n_samples, desc="scanning with perspective"):
        data = {
            "comment": {"text": text},
            "languages": ["en"],
            "requestedAttributes": REQUESTED_ATTRIBUTES,
        }
        time.sleep(1)  # stay under the default Perspective API rate limit (1 QPS)
        try:
            req_response = session.post(PERSPECTIVE_URL, json=data, headers=headers)
        except Exception as e:
            # Skip examples whose request fails outright
            logging.info(e)
            logging.info(data)
            continue
        if req_response.ok:
            response = req_response.json()
            if ATT_SCORE in response:
                texts_processed[column_name].append(text)
                for req_att in REQUESTED_ATTRIBUTES:
                    if req_att in response[ATT_SCORE]:
                        att_score = response[ATT_SCORE][req_att][SUM_SCORE]["value"]
                        req_att_scores[req_att].append(att_score)
                    else:
                        req_att_scores[req_att].append(0)
            else:
                raise ValueError(req_response)
        else:
            # Skip examples the API rejects (e.g. unsupported language)
            try:
                req_response.raise_for_status()
            except Exception as e:
                logging.info(e)
                logging.info(data)
                continue
        if i % 10 == 0:
            plot_toxicity(req_att_scores)
            yield {"toxicity check in progress...": i / n_samples}, plt.gcf(), pd.DataFrame.from_dict({**texts_processed, **req_att_scores})
    plot_toxicity(req_att_scores)
    yield {"toxicity check finished.": 1.}, plt.gcf(), pd.DataFrame.from_dict({**texts_processed, **req_att_scores})
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 📈 Text Data Quality Checker 📉
        This space offers a quick look at the quality of an English text dataset:
        * it runs [NVIDIA's quality classifier model](https://huggingface.co/nvidia/quality-classifier-deberta)
        on a small subset of texts;
        * it uses the [Perspective](https://perspectiveapi.com/how-it-works/) API to check the toxicity of 100 random dataset texts.
        ## Select dataset and text column
        """
    )
    with gr.Row():
        with gr.Column(scale=3):
            dataset_name = HuggingfaceHubSearch(
                label="Hub Dataset ID",
                placeholder="Search for dataset id on Huggingface",
                search_type="dataset",
            )
        subset_dropdown = gr.Dropdown(label="Subset", visible=False)
        split_dropdown = gr.Dropdown(label="Split", visible=False)

    with gr.Accordion("Dataset preview", open=False):
        # Assumed wiring: gr.render re-runs this function (and rebuilds the
        # iframe) whenever the selected dataset, subset or split changes.
        @gr.render(inputs=[dataset_name, subset_dropdown, split_dropdown])
        def embed(name, subset, split):
            html_code = f"""
            <iframe
                src="https://huggingface.co/datasets/{name}/embed/viewer/{subset}/{split}"
                frameborder="0"
                width="100%"
                height="600px"
            ></iframe>
            """
            return gr.HTML(value=html_code)

    with gr.Row():
        text_column_dropdown = gr.Dropdown(label="Text column name")
        nested_text_column_dropdown = gr.Dropdown(label="Nested text column name", visible=False)
    def _resolve_dataset_selection(dataset: str, default_subset: str, default_split: str, text_feature):
        if "/" not in dataset.strip().strip("/"):
            return {
                subset_dropdown: gr.Dropdown(visible=False),
                split_dropdown: gr.Dropdown(visible=False),
                text_column_dropdown: gr.Dropdown(label="Text column name"),
                nested_text_column_dropdown: gr.Dropdown(visible=False),
            }
        info_resp = session.get(f"https://datasets-server.huggingface.co/info?dataset={dataset}", timeout=20).json()
        if "error" in info_resp:
            return {
                subset_dropdown: gr.Dropdown(visible=False),
                split_dropdown: gr.Dropdown(visible=False),
                text_column_dropdown: gr.Dropdown(label="Text column name"),
                nested_text_column_dropdown: gr.Dropdown(visible=False),
            }
        subsets: list[str] = list(info_resp["dataset_info"])
        subset = default_subset if default_subset in subsets else subsets[0]
        splits: list[str] = list(info_resp["dataset_info"][subset]["splits"])
        split = default_split if default_split in splits else splits[0]
        features = info_resp["dataset_info"][subset]["features"]

        def _is_string_feature(feature):
            return isinstance(feature, dict) and feature.get("dtype") == "string"

        text_features = [feature_name for feature_name, feature in features.items() if _is_string_feature(feature)]
        nested_features = [feature_name for feature_name, feature in features.items() if isinstance(feature, dict) and isinstance(next(iter(feature.values())), dict)]
        nested_text_features = [feature_name for feature_name in nested_features if any(_is_string_feature(nested_feature) for nested_feature in features[feature_name].values())]
        if not text_feature:
            return {
                subset_dropdown: gr.Dropdown(value=subset, choices=subsets, visible=len(subsets) > 1),
                split_dropdown: gr.Dropdown(value=split, choices=splits, visible=len(splits) > 1),
                text_column_dropdown: gr.Dropdown(choices=text_features + nested_text_features, label="Text column name"),
                nested_text_column_dropdown: gr.Dropdown(visible=False),
            }
        if text_feature in nested_text_features:
            nested_keys = [feature_name for feature_name, feature in features[text_feature].items() if _is_string_feature(feature)]
            return {
                subset_dropdown: gr.Dropdown(value=subset, choices=subsets, visible=len(subsets) > 1),
                split_dropdown: gr.Dropdown(value=split, choices=splits, visible=len(splits) > 1),
                text_column_dropdown: gr.Dropdown(choices=text_features + nested_text_features, label="Text column name"),
                nested_text_column_dropdown: gr.Dropdown(value=nested_keys[0], choices=nested_keys, label="Nested text column name", visible=True),
            }
        return {
            subset_dropdown: gr.Dropdown(value=subset, choices=subsets, visible=len(subsets) > 1),
            split_dropdown: gr.Dropdown(value=split, choices=splits, visible=len(splits) > 1),
            text_column_dropdown: gr.Dropdown(choices=text_features + nested_text_features, label="Text column name"),
            nested_text_column_dropdown: gr.Dropdown(visible=False),
        }
    def show_input_from_dataset_dropdown(dataset: str) -> dict:
        return _resolve_dataset_selection(dataset, default_subset="default", default_split="train", text_feature=None)

    def show_input_from_subset_dropdown(dataset: str, subset: str) -> dict:
        return _resolve_dataset_selection(dataset, default_subset=subset, default_split="train", text_feature=None)

    def show_input_from_split_dropdown(dataset: str, subset: str, split: str) -> dict:
        return _resolve_dataset_selection(dataset, default_subset=subset, default_split=split, text_feature=None)

    def show_input_from_text_column_dropdown(dataset: str, subset: str, split: str, text_column) -> dict:
        return _resolve_dataset_selection(dataset, default_subset=subset, default_split=split, text_feature=text_column)
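    # Assumed event wiring (a sketch; not shown in the snippet above): cascade
    # the selection widgets so picking a dataset populates subsets, splits and
    # text columns. The callbacks return dicts keyed by component, so all four
    # components are listed as outputs.
    selection_outputs = [subset_dropdown, split_dropdown, text_column_dropdown, nested_text_column_dropdown]
    dataset_name.change(show_input_from_dataset_dropdown, inputs=[dataset_name], outputs=selection_outputs)
    subset_dropdown.select(show_input_from_subset_dropdown, inputs=[dataset_name, subset_dropdown], outputs=selection_outputs)
    split_dropdown.select(show_input_from_split_dropdown, inputs=[dataset_name, subset_dropdown, split_dropdown], outputs=selection_outputs)
    text_column_dropdown.select(show_input_from_text_column_dropdown, inputs=[dataset_name, subset_dropdown, split_dropdown, text_column_dropdown], outputs=selection_outputs)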
| gr.Markdown("## Run nvidia quality classifier") | |
| batch_size = gr.Slider(0, 64, 32, step=4, label="Inference batch size", info="(set this to smaller value if this space crashes.)") | |
| num_examples = gr.Slider(0, 5000, 500, step=10, label="Number of examples", info="Number of random examples to run quality classifier on") | |
| gr_check_btn = gr.Button("Check Quality") | |
| progress_bar = gr.Label(show_label=False) | |
| plot = gr.BarPlot() | |
| with gr.Accordion("Explore some individual examples for each class", open=False): | |
| gr.Markdown("### Low") | |
| df_low = gr.DataFrame() | |
| gr.Markdown("### Medium") | |
| df_medium = gr.DataFrame() | |
| gr.Markdown("### High") | |
| df_high = gr.DataFrame() | |
| texts_df = gr.DataFrame(visible=False) | |
| gr.Examples( | |
| [ | |
| ["HuggingFaceFW/fineweb-edu", "default", "train", "text", None, 16, 500], | |
| # ["fka/awesome-chatgpt-prompts", "default", "train", "prompt", 64, 200], | |
| # ["proj-persona/PersonaHub", "instruction", "train", "synthesized text", 32, 1000], | |
| ["argilla/FinePersonas-v0.1", "default", "train", "persona", None, 64, 5000], | |
| ["allenai/real-toxicity-prompts", "default", "train", "continuation", "text", 64, 5000], | |
| ], | |
| [dataset_name, subset_dropdown, split_dropdown, text_column_dropdown, nested_text_column_dropdown, batch_size, num_examples], | |
| [progress_bar, plot, df_low, df_medium, df_high, texts_df], | |
| fn=run_quality_check, | |
| run_on_click=False, | |
| cache_examples=False, | |
| ) | |
| gr_check_btn.click( | |
| run_quality_check, | |
| inputs=[dataset_name, subset_dropdown, split_dropdown, text_column_dropdown, nested_text_column_dropdown, batch_size, num_examples], | |
| outputs=[progress_bar, plot, df_low, df_medium, df_high, texts_df] | |
| ) | |
| gr.Markdown("""## Explore toxicity | |
| Run [Perspective](https://perspectiveapi.com/how-it-works/) on 100 random samples to check toxicity | |
| """) | |
| gr_toxicity_btn = gr.Button("Check Toxicity") | |
| toxicity_progress_bar = gr.Label(show_label=False) | |
| toxicity_hist = gr.Plot() | |
| with gr.Accordion("Explore examples with toxicity scores:", open=False): | |
| toxicity_df = gr.DataFrame() | |
| gr_toxicity_btn.click( | |
| call_perspective_api, | |
| inputs=[texts_df, text_column_dropdown, nested_text_column_dropdown, dataset_name, subset_dropdown, split_dropdown],#, checkbox], | |
| outputs=[toxicity_progress_bar, toxicity_hist, toxicity_df] | |
| ) | |
| demo.launch() |