# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
from os.path import isdir

import gradio as gr

import utils
import widgets
from data_measurements.dataset_statistics import DatasetStatisticsCacheClass as dmt_cls
from utils import dataset_utils
from utils import gradio_utils as gr_utils

logs = utils.prepare_logging(__file__)

# Utility for sidebar description and selection of the dataset
DATASET_NAME_TO_DICT = dataset_utils.get_dataset_info_dicts()


def get_load_prepare_list(dstats):
    """Get the load_or_prepare functions for the measurements we will display."""
    # Measurement calculation:
    # Add any additional modules and their load_or_prepare function here.
    load_prepare_list = [("general stats", dstats.load_or_prepare_general_stats),
                         ("label distribution", dstats.load_or_prepare_labels),
                         ("text_lengths", dstats.load_or_prepare_text_lengths),
                         ("duplicates", dstats.load_or_prepare_text_duplicates),
                         ("npmi", dstats.load_or_prepare_npmi),
                         ("zipf", dstats.load_or_prepare_zipf)]
    return load_prepare_list


def get_ui_widgets():
    """Get the widgets that will be displayed in the UI."""
    return [widgets.DatasetDescription(DATASET_NAME_TO_DICT),
            widgets.GeneralStats(),
            widgets.LabelDistribution(),
            widgets.TextLengths(),
            widgets.Duplicates(),
            widgets.Npmi(),
            widgets.Zipf()]


def get_widgets():
    """
    A measurement widget requires two things:
      - a load_or_prepare function
      - a display function
    These are defined in two separate functions, get_load_prepare_list and get_ui_widgets;
    a new widget can be added by modifying both, and the rest of the app logic will work.
    get_load_prepare_list is returned as a function (not called here) because it requires a
    DatasetStatisticsCacheClass, which is not created until dataset and config values are
    selected in the UI.
    """
    return get_load_prepare_list, get_ui_widgets()


def get_title(dstats):
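    """Build the Markdown header showing the selected dataset, config, split, and text field."""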
    title_str = f"### Showing: {dstats.dset_name} - {dstats.dset_config} - {dstats.split_name} - {'-'.join(dstats.text_field)}"
    logs.info("showing header")
    return title_str


def display_initial_UI():
    """Displays the dataset selection sidebar in the UI and returns its input components."""
    # Extract the selected arguments
    dataset_args = gr_utils.sidebar_selection(DATASET_NAME_TO_DICT)
    return dataset_args


def load_or_prepare_widgets(dstats, load_prepare_list, show_perplexities, live=True, pull_cache_from_hub=False):
    """
    Takes the dataset statistics object built from the GUI selections and uses it to load the
    measurements from the cache or, if no cache is available for those arguments, to compute them.
    Widget data is loaded only when the system is live (deployed for users).
    Otherwise, the data is prepared if it doesn't yet exist.
    Args:
        dstats (DatasetStatisticsCacheClass): the dataset statistics object created from the GUI arguments
        load_prepare_list (list): list of (widget_name, widget_load_or_prepare_function) tuples
        show_perplexities (bool): whether perplexities should be loaded and displayed for this dataset
        live (bool): whether the system is deployed for live use by users
        pull_cache_from_hub (bool): whether the cache should be pulled from the Hub (vs. locally)
    Returns:
        dstats: the computed dataset statistics (from the dataset_statistics class)
    """
    # When we're "live" (tool is being used by users on our servers),
    # the cache is used and the functions are instructed to only try to load the cache,
    # not to prepare/compute anything anew.
    if live:
        # Only use what's cached; don't prepare anything
        load_only = True
        logs.info("Only using cache.")
    else:
        # Prepare things anew and cache them if we're not live.
        load_only = False
        logs.info("Making new calculations if cache is not there.")
    if pull_cache_from_hub:
        dataset_utils.pull_cache_from_hub(dstats.cache_path, dstats.dataset_cache_dir)

    # Data common across DMT:
    # Includes the dataset text/requested feature column,
    # the tokenized dataset, and the vocabulary.
    dstats.load_or_prepare_text_dataset(load_only=load_only)
    # Just a snippet of the dataset
    dstats.load_or_prepare_dset_peek(load_only=load_only)
    # Tokenized dataset
    dstats.load_or_prepare_tokenized_df(load_only=load_only)
    # Vocabulary (uses tokenized dataset)
    dstats.load_or_prepare_vocab(load_only=load_only)

    # Custom widgets
    for widget_name, widget_fn in load_prepare_list:
        try:
            widget_fn(load_only=load_only)
        except Exception as e:
            logs.warning("Issue with %s." % widget_name)
            logs.exception(e)
    # TODO: If these are cached, can't we just show them by default?
    # It won't take up computation time.
    if show_perplexities:
        try:
            dstats.load_or_prepare_text_perplexities(load_only=load_only)
        except Exception as e:
            logs.warning("Issue with %s." % "perplexities")
            logs.exception(e)
    return dstats


def show_column(dstats, display_list, show_perplexities, column_id=""):
    """
    Displays the measurement widgets in the Gradio app.
    Args:
        dstats (class): the dataset_statistics.py DatasetStatisticsCacheClass
        display_list (list): list of (widget_name, widget_display_function) tuples
        show_perplexities (bool): whether perplexities should be loaded and displayed for this dataset
        column_id (str): which column of the dataset the analysis is done on [DEPRECATED for v1]
    """
    # Start showing the widgets.
    gr_utils.expander_header(dstats, DATASET_NAME_TO_DICT)
    for widget_type, widget_fn in display_list:
        logs.info("showing %s." % widget_type)
        try:
            widget_fn(dstats, column_id)
        except Exception as e:
            logs.warning("There was an issue with %s:" % widget_type)
            logs.exception(e)
    # TODO: Fix how this is a weird outlier.
    if show_perplexities:
        gr_utils.expander_text_perplexities(dstats, column_id)
    logs.info("Have finished displaying the widgets.")


def create_demo(live: bool, pull_cache_from_hub: bool):
    with gr.Blocks() as demo:
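        # Shared state holding the DatasetStatisticsCacheClass instance between callbacks.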
        state = gr.State()
        with gr.Row():
            with gr.Column(scale=1):
                dataset_args = display_initial_UI()
                get_load_prepare_list_fn, widget_list = get_widgets()
                # TODO: Make this less of a weird outlier.
                # Doesn't do anything right now.
                show_perplexities = gr.Checkbox(label="Show text perplexities")
            with gr.Column(scale=4):
                gr.Markdown("# Data Measurements Tool")
                title = gr.Markdown()
                for widget in widget_list:
                    widget.render()
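
        # Main callback: builds the statistics object for the selected dataset/config/split/feature,
        # loads or prepares every measurement, and returns updates for the title, the shared state,
        # and each widget's output components.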
        def update_ui(dataset: str, config: str, split: str, feature: str):
            feature = ast.literal_eval(feature)
            label_field, label_names = gr_utils.get_label_names(dataset, config, DATASET_NAME_TO_DICT)
            dstats = dmt_cls(dset_name=dataset, dset_config=config, split_name=split, text_field=feature,
                             label_field=label_field, label_names=label_names, use_cache=True)
            load_prepare_list = get_load_prepare_list_fn(dstats)
            dstats = load_or_prepare_widgets(dstats, load_prepare_list, show_perplexities=False,
                                             live=live, pull_cache_from_hub=pull_cache_from_hub)
            output = {title: get_title(dstats), state: dstats}
            for widget in widget_list:
                output.update(widget.update(dstats))
            return output
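
        # When a new dataset is selected, refresh the config, text field, and split dropdowns
        # with the choices that are valid for that dataset.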
        def update_dataset(dataset: str):
            new_values = gr_utils.update_dataset(dataset, DATASET_NAME_TO_DICT)
            config = new_values[0][1]
            feature = new_values[1][1]
            split = new_values[2][1]
            new_dropdown = {
                dataset_args["dset_config"]: gr.Dropdown.update(choices=new_values[0][0], value=config),
                dataset_args["text_field"]: gr.Dropdown.update(choices=new_values[1][0], value=feature),
                dataset_args["split_name"]: gr.Dropdown.update(choices=new_values[2][0], value=split),
            }
            return new_dropdown
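
        # When a new config is selected, refresh the text field and split dropdowns
        # with the choices that are valid for that config.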
        def update_config(dataset: str, config: str):
            new_values = gr_utils.update_config(dataset, config, DATASET_NAME_TO_DICT)
            feature = new_values[0][1]
            split = new_values[1][1]
            new_dropdown = {
                dataset_args["text_field"]: gr.Dropdown.update(choices=new_values[0][0], value=feature),
                dataset_args["split_name"]: gr.Dropdown.update(choices=new_values[1][0], value=split)
            }
            return new_dropdown
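
        # Flatten every widget's output components into a single list so they can all be
        # passed as outputs to the event handlers below.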
        measurements = [comp for widget in widget_list for comp in widget.output_components]

        demo.load(update_ui,
                  inputs=[dataset_args["dset_name"], dataset_args["dset_config"],
                          dataset_args["split_name"], dataset_args["text_field"]],
                  outputs=[title, state] + measurements)

        for widget in widget_list:
            widget.add_events(state)

        # Changing the dataset or config refreshes the dropdowns that depend on it,
        # including the text field dropdown.
        dataset_args["dset_name"].change(update_dataset,
                                         inputs=[dataset_args["dset_name"]],
                                         outputs=[dataset_args["dset_config"],
                                                  dataset_args["split_name"], dataset_args["text_field"],
                                                  title, state] + measurements)
        dataset_args["dset_config"].change(update_config,
                                           inputs=[dataset_args["dset_name"], dataset_args["dset_config"]],
                                           outputs=[dataset_args["split_name"], dataset_args["text_field"],
                                                    title, state] + measurements)
        dataset_args["calculate_btn"].click(update_ui,
                                            inputs=[dataset_args["dset_name"], dataset_args["dset_config"],
                                                    dataset_args["split_name"], dataset_args["text_field"]],
                                            outputs=[title, state] + measurements)

    return demo


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--live", default=False, required=False, action="store_true",
        help="Flag to specify that the app is running live for users, so only cached measurements are loaded.")
    parser.add_argument(
        "--pull_cache_from_hub", default=False, required=False, action="store_true",
        help="Flag to specify whether to look in the hub for measurements caches. If you are using "
             "this option, you must have HUB_CACHE_ORGANIZATION=<the organization you've set up on "
             "the hub to store your cache> and HF_TOKEN=<your hf token> on separate lines in a file "
             "named .env at the root of this repo.")
    arguments = parser.parse_args()
    live = arguments.live
    pull_cache_from_hub = arguments.pull_cache_from_hub

    # Create and initialize the demo
    demo = create_demo(live, pull_cache_from_hub)
    demo.launch()


if __name__ == "__main__":
    main()