# Spaces: Runtime error
| import logging | |
| import pandas as pd | |
| from datasets import load_metric | |
| from os.path import exists | |
| from os.path import join as pjoin | |
| import utils | |
| from utils import dataset_utils as ds_utils | |
# Logger returned by the project helper `utils.prepare_logging`; it is not
# referenced elsewhere in the visible code.
logs = utils.prepare_logging(__file__)
# Model id passed to the perplexity metric to score each text.
TOK_MODEL = "gpt2"
# Perplexity metric, loaded once at import time.
# NOTE(review): `datasets.load_metric` is deprecated in favour of the separate
# `evaluate` library (`evaluate.load("perplexity", module_type="metric")`) and
# was removed in datasets 3.0 -- confirm the pinned `datasets` version still
# provides it.
PERPLEXITY = load_metric("perplexity")
# Column name under which perplexity scores are stored in the results DataFrame.
PERPLEXITY_FIELD = "perplexity"
class DMTHelper:
    """Compute, cache and expose per-text perplexity scores for a dataset.

    Scores come from the module-level ``PERPLEXITY`` metric run with the
    ``TOK_MODEL`` model. They are kept in a DataFrame sorted from the highest
    to the lowest perplexity.
    """

    def __init__(self, dstats, load_only=False):
        """Set up state and the cache-file path.

        Args:
            dstats: Dataset-statistics object. This class reads its
                ``dataset_cache_dir``, ``use_cache``, ``save`` and
                ``text_dset`` attributes.
            load_only: If True, results are only loaded from the cache and
                never computed.
        """
        self.dstats = dstats
        self.load_only = load_only
        # Raw per-column results; filled by prepare_text_perplexities.
        self.results_dict = {}
        # Where in the Dataset object to find the text for the calculation
        self.text_field = ds_utils.OUR_TEXT_FIELD
        # Results in DataFrame form; None until loaded or computed.
        self.df = None
        # Cache file for the results DataFrame.
        self.perplexities_df_fid = pjoin(
            self.dstats.dataset_cache_dir, "perplexities_df.json"
        )

    def run_DMT_processing(self):
        """Load perplexities from the cache, or compute (and optionally save) them.

        The cache is read when ``dstats.use_cache`` is set and the cache file
        exists. Otherwise, unless ``load_only`` is set, the perplexities are
        computed and, if ``dstats.save`` is set, written to the cache file.
        If no cache is used and ``load_only`` is set, nothing is computed or
        written and ``self.df`` is left unchanged.
        """
        if self.dstats.use_cache and exists(self.perplexities_df_fid):
            self.df = ds_utils.read_df(self.perplexities_df_fid)
            return
        if self.load_only:
            # Nothing usable in the cache and computing is not allowed:
            # there is no fresh DataFrame, so never write to the cache file.
            return
        self.prepare_text_perplexities()
        # Persist only results that were freshly computed in this call.
        if self.dstats.save:
            ds_utils.write_df(self.df, self.perplexities_df_fid)

    def prepare_text_perplexities(self):
        """Score every text in ``dstats.text_dset`` and build ``self.df``.

        ``self.df`` holds two columns, ``PERPLEXITY_FIELD`` and the text
        field, sorted by perplexity in descending order.
        """
        # Extract the text column once and reuse it for both the metric
        # input and the DataFrame, instead of pulling it from the dataset twice.
        texts = self.dstats.text_dset[self.text_field]
        eval_results = PERPLEXITY.compute(input_texts=texts, model_id=TOK_MODEL)
        # TODO: What other stuff might be useful to grab?
        self.results_dict = {
            PERPLEXITY_FIELD: eval_results["perplexities"],
            self.text_field: texts,
        }
        self.df = pd.DataFrame(self.results_dict).sort_values(
            by=PERPLEXITY_FIELD, ascending=False
        )

    def get_df(self):
        """Return the perplexity DataFrame, or None if not yet loaded/computed."""
        return self.df