# run_data_measurements.py -- command-line runner for the Data Measurements Tool.
| import argparse | |
| import json | |
| from dotenv import load_dotenv | |
| import plotly | |
| import shutil | |
| import smtplib | |
| import ssl | |
| import sys | |
| import textwrap | |
| from data_measurements import dataset_statistics | |
| from data_measurements.zipf import zipf | |
| from huggingface_hub import create_repo, Repository, hf_api | |
| from os import getenv | |
| from os.path import exists, join as pjoin | |
| from pathlib import Path | |
| import utils | |
| from utils import dataset_utils | |
# Module-level logger for this script, configured by the project's shared helper.
logs = utils.prepare_logging(__file__)
def load_or_prepare_widgets(ds_args, show_embeddings=False,
                            show_perplexities=False, use_cache=False):
    """Prepare every measurement backing the app's widgets.

    Args:
        ds_args: Keyword arguments forwarded to
            ``dataset_statistics.DatasetStatisticsCacheClass``.
        show_embeddings: When True, also compute the embeddings widget.
        show_perplexities: When True, also compute the text perplexities widget.
        use_cache: Whether previously cached results may be reused.
    """
    dstats = dataset_statistics.DatasetStatisticsCacheClass(
        **ds_args, use_cache=use_cache)
    # (enabled, preparation callable) pairs, in the same order the widgets
    # were originally prepared; the slow ones are gated on their flags.
    widget_steps = [
        (True, dstats.load_or_prepare_dset_peek),          # Header widget
        (True, dstats.load_or_prepare_general_stats),      # General stats widget
        (True, dstats.load_or_prepare_labels),             # Labels widget
        (True, dstats.load_or_prepare_text_lengths),       # Text lengths widget
        (show_embeddings, dstats.load_or_prepare_embeddings),           # Embeddings widget
        (show_perplexities, dstats.load_or_prepare_text_perplexities),  # Perplexities widget
        (True, dstats.load_or_prepare_text_duplicates),    # Text duplicates widget
        (True, dstats.load_or_prepare_npmi),               # nPMI widget
        (True, dstats.load_or_prepare_zipf),               # Zipf widget
    ]
    for enabled, prepare in widget_steps:
        if enabled:
            prepare()
def _log_result_files(fid_dict):
    # Helper: log a measurement's {name: filepath} result map; values may be
    # one level of nested dicts (as with the nPMI results).
    logs.info("If all went well, then results are in the following files:")
    for key, value in fid_dict.items():
        if isinstance(value, dict):
            logs.info("%s:", key)
            for key2, value2 in value.items():
                logs.info("\t%s: %s", key2, value2)
        else:
            logs.info("%s: %s", key, value)


def load_or_prepare(dataset_args, calculation=False, use_cache=False):
    """Compute one measurement (or all of them) and cache the results.

    Args:
        dataset_args: Keyword arguments for
            ``dataset_statistics.DatasetStatisticsCacheClass``.
        calculation: Which measurement to run: "general", "duplicates",
            "lengths", "labels", "npmi", "zipf", "embeddings", or
            "perplexities". Falsy runs everything except the slow
            embeddings and perplexities measurements.
        use_cache: Whether previously cached results may be reused.
    """
    # TODO: Catch error exceptions for each measurement, so that an error
    # for one measurement doesn't break the calculation of all of them.
    do_all = False
    dstats = dataset_statistics.DatasetStatisticsCacheClass(**dataset_args,
                                                            use_cache=use_cache)
    # Tokenization and vocab are prerequisites shared by all measurements.
    logs.info("Tokenizing dataset.")
    dstats.load_or_prepare_tokenized_df()
    logs.info("Calculating vocab.")
    dstats.load_or_prepare_vocab()
    if not calculation:
        do_all = True
    if do_all or calculation == "general":
        logs.info("\n* Calculating general statistics.")
        dstats.load_or_prepare_general_stats()
        logs.info("Done!")
        logs.info("Basic text statistics now available at %s.",
                  dstats.general_stats_json_fid)
    if do_all or calculation == "duplicates":
        logs.info("\n* Calculating text duplicates.")
        dstats.load_or_prepare_text_duplicates()
        _log_result_files(dstats.duplicates_files)
    if do_all or calculation == "lengths":
        logs.info("\n* Calculating text lengths.")
        dstats.load_or_prepare_text_lengths()
        # Previously reported via print(); now uses the module logger for
        # consistency with the other measurements.
        _log_result_files(dstats.length_obj.get_filenames())
    if do_all or calculation == "labels":
        logs.info("\n* Calculating label statistics.")
        if dstats.label_field not in dstats.dset.features:
            logs.warning("No label field found.")
            logs.info("No label statistics to calculate.")
        else:
            dstats.load_or_prepare_labels()
            # (The original misnamed this dict `npmi_fid_dict`; these are
            # the label result files.)
            _log_result_files(dstats.label_files)
    if do_all or calculation == "npmi":
        logs.info("\n* Preparing nPMI.")
        dstats.load_or_prepare_npmi()
        _log_result_files(dstats.npmi_files)
    if do_all or calculation == "zipf":
        logs.info("\n* Preparing Zipf.")
        dstats.load_or_prepare_zipf()
        logs.info("Done!")
        zipf_json_fid, zipf_fig_json_fid, zipf_fig_html_fid = zipf.get_zipf_fids(
            dstats.dataset_cache_dir)
        logs.info("Zipf results now available at %s.", zipf_json_fid)
        logs.info("Figure saved to %s, with corresponding json at %s.",
                  zipf_fig_html_fid, zipf_fig_json_fid)
    # Don't do these until someone specifically asks for them -- they take awhile.
    if calculation == "embeddings":
        logs.info("\n* Preparing text embeddings.")
        dstats.load_or_prepare_embeddings()
    if calculation == "perplexities":
        logs.info("\n* Preparing text perplexities.")
        dstats.load_or_prepare_text_perplexities()
def pass_args_to_DMT(dset_name, dset_config, split_name, text_field, label_field,
                     label_names, calculation, dataset_cache_dir,
                     prepare_gui=False, use_cache=True):
    """Bundle the CLI arguments and dispatch to the right preparation path.

    Args:
        dset_name: Dataset name on the hub.
        dset_config: Dataset configuration name.
        split_name: Dataset split to measure.
        text_field: Column(s) handled as text.
        label_field: Column holding labels, if any.
        label_names: Human-readable label names.
        calculation: Which measurement to run (see ``load_or_prepare``).
        dataset_cache_dir: Local directory for cached results.
        prepare_gui: When True, prepare everything the GUI widgets need
            instead of a single calculation.
        use_cache: Whether previously cached results may be reused.
    """
    if not use_cache:
        logs.info("Not using any cache; starting afresh")
    dataset_args = dict(
        dset_name=dset_name,
        dset_config=dset_config,
        split_name=split_name,
        text_field=text_field,
        label_field=label_field,
        label_names=label_names,
        dataset_cache_dir=dataset_cache_dir,
    )
    if prepare_gui:
        load_or_prepare_widgets(dataset_args, use_cache=use_cache)
    else:
        load_or_prepare(dataset_args, calculation=calculation, use_cache=use_cache)
def set_defaults(args):
    """Fill in fallback values for CLI arguments the user left empty.

    Mutates ``args`` in place and returns it. Empty config/split/feature/
    label_field values fall back to "default"/"train"/"text"/"label",
    logging an informational line for each substitution.
    """
    fallbacks = (
        ("config", "default", "Config name not specified. Assuming it's 'default'."),
        ("split", "train", "Split name not specified. Assuming it's 'train'."),
        ("feature", "text", "Text column name not given. Assuming it's 'text'."),
        ("label_field", "label", "Label column name not given. Assuming it's 'label'."),
    )
    for attr, default, message in fallbacks:
        if not getattr(args, attr):
            setattr(args, attr, default)
            logs.info(message)
    return args
| def main(): | |
| parser = argparse.ArgumentParser( | |
| formatter_class=argparse.RawDescriptionHelpFormatter, | |
| description=textwrap.dedent( | |
| """ | |
| Example for hate speech18 dataset: | |
| python3 run_data_measurements.py --dataset="hate_speech18" --config="default" --split="train" --feature="text" | |
| Example for IMDB dataset: | |
| python3 run_data_measurements.py --dataset="imdb" --config="plain_text" --split="train" --label_field="label" --feature="text" | |
| """ | |
| ), | |
| ) | |
| parser.add_argument( | |
| "-d", "--dataset", required=True, help="Name of dataset to prepare" | |
| ) | |
| parser.add_argument( | |
| "-c", "--config", required=False, default="", help="Dataset configuration to prepare" | |
| ) | |
| parser.add_argument( | |
| "-s", "--split", required=False, default="", type=str, | |
| help="Dataset split to prepare" | |
| ) | |
| parser.add_argument( | |
| "-f", | |
| "--feature", | |
| "-t", | |
| "--text-field", | |
| required=False, | |
| nargs="+", | |
| type=str, | |
| default="", | |
| help="Column to prepare (handled as text)", | |
| ) | |
| parser.add_argument( | |
| "-w", | |
| "--calculation", | |
| help="""What to calculate (defaults to everything except embeddings and perplexities).\n | |
| Options are:\n | |
| - `general` (for duplicate counts, missing values, length statistics.)\n | |
| - `duplicates` for duplicate counts\n | |
| - `lengths` for text length distribution\n | |
| - `labels` for label distribution\n | |
| - `embeddings` (Warning: Slow.)\n | |
| - `perplexities` (Warning: Slow.)\n | |
| - `npmi` for word associations\n | |
| - `zipf` for zipfian statistics | |
| """, | |
| ) | |
| parser.add_argument( | |
| "-l", | |
| "--label_field", | |
| type=str, | |
| required=False, | |
| default="", | |
| help="Field name for label column in dataset (Required if there is a label field that you want information about)", | |
| ) | |
| parser.add_argument('-n', '--label_names', nargs='+', default=[]) | |
| parser.add_argument( | |
| "--use_cache", | |
| default=False, | |
| required=False, | |
| action="store_true", | |
| help="Whether to use cached files (Optional)", | |
| ) | |
| parser.add_argument("--out_dir", default="cache_dir", | |
| help="Where to write out to.") | |
| parser.add_argument( | |
| "--overwrite_previous", | |
| default=False, | |
| required=False, | |
| action="store_true", | |
| help="Whether to overwrite a previous local cache for these same arguments (Optional)", | |
| ) | |
| parser.add_argument( | |
| "--email", | |
| default=None, | |
| help="An email that recieves a message about whether the computation was successful. If email is not None, then you must have EMAIL_PASSWORD=<your email password> for the sender email (data.measurements.tool@gmail.com) in a file named .env at the root of this repo.") | |
| parser.add_argument( | |
| "--push_cache_to_hub", | |
| default=False, | |
| required=False, | |
| action="store_true", | |
| help="Whether to push the cache to an organization on the hub. If you are using this option, you must have HUB_CACHE_ORGANIZATION=<the organization you've set up on the hub to store your cache> and HF_TOKEN=<your hf token> on separate lines in a file named .env at the root of this repo.", | |
| ) | |
| parser.add_argument("--prepare_GUI_data", default=False, required=False, | |
| action="store_true", | |
| help="Use this to process all of the stats used in the GUI.") | |
| parser.add_argument("--keep_local", default=True, required=False, | |
| action="store_true", | |
| help="Whether to save the data locally.") | |
| orig_args = parser.parse_args() | |
| args = set_defaults(orig_args) | |
| logs.info("Proceeding with the following arguments:") | |
| logs.info(args) | |
| # run_data_measurements.py -d hate_speech18 -c default -s train -f text -w npmi | |
| if args.email is not None: | |
| if Path(".env").is_file(): | |
| load_dotenv(".env") | |
| EMAIL_PASSWORD = getenv("EMAIL_PASSWORD") | |
| context = ssl.create_default_context() | |
| port = 465 | |
| server = smtplib.SMTP_SSL("smtp.gmail.com", port, context=context) | |
| server.login("data.measurements.tool@gmail.com", EMAIL_PASSWORD) | |
| dataset_cache_name, local_dataset_cache_dir = dataset_utils.get_cache_dir_naming(args.out_dir, args.dataset, args.config, args.split, args.feature) | |
| if not args.use_cache and exists(local_dataset_cache_dir): | |
| if args.overwrite_previous: | |
| shutil.rmtree(local_dataset_cache_dir) | |
| else: | |
| raise OSError("Cached results for this dataset already exist at %s. " | |
| "Delete it or use the --overwrite_previous argument." % local_dataset_cache_dir) | |
| # Initialize the local cache directory | |
| dataset_utils.make_path(local_dataset_cache_dir) | |
| # Initialize the repository | |
| # TODO: print out local or hub cache directory location. | |
| if args.push_cache_to_hub: | |
| repo = dataset_utils.initialize_cache_hub_repo(local_dataset_cache_dir, dataset_cache_name) | |
| # Run the measurements. | |
| try: | |
| pass_args_to_DMT( | |
| dset_name=args.dataset, | |
| dset_config=args.config, | |
| split_name=args.split, | |
| text_field=args.feature, | |
| label_field=args.label_field, | |
| label_names=args.label_names, | |
| calculation=args.calculation, | |
| dataset_cache_dir=local_dataset_cache_dir, | |
| prepare_gui=args.prepare_GUI_data, | |
| use_cache=args.use_cache, | |
| ) | |
| if args.push_cache_to_hub: | |
| repo.push_to_hub(commit_message="Added dataset cache.") | |
| computed_message = f"Data measurements have been computed for dataset" \ | |
| f" with these arguments: {args}." | |
| logs.info(computed_message) | |
| if args.email is not None: | |
| computed_message += "\nYou can return to the data measurements tool " \ | |
| "to view them." | |
| server.sendmail("data.measurements.tool@gmail.com", args.email, | |
| "Subject: Data Measurements Computed!\n\n" + computed_message) | |
| logs.info(computed_message) | |
| except Exception as e: | |
| logs.exception(e) | |
| error_message = f"An error occurred in computing data measurements " \ | |
| f"for dataset with arguments: {args}. " \ | |
| f"Feel free to make an issue here: " \ | |
| f"https://github.com/huggingface/data-measurements-tool/issues" | |
| if args.email is not None: | |
| server.sendmail("data.measurements.tool@gmail.com", args.email, | |
| "Subject: Data Measurements not Computed\n\n" + error_message) | |
| logs.warning("Data measurements not computed. ☹️") | |
| logs.warning(error_message) | |
| return | |
| if not args.keep_local: | |
| # Remove the dataset from local storage - we only want it stored on the hub. | |
| logs.warning("Deleting measurements data locally at %s" % local_dataset_cache_dir) | |
| shutil.rmtree(local_dataset_cache_dir) | |
| else: | |
| logs.info("Measurements made available locally at %s" % local_dataset_cache_dir) | |
| if __name__ == "__main__": | |
| main() | |