Spaces:
Build error
Build error
meg-huggingface
committed on
Commit
·
9f6cc2b
1
Parent(s):
79d1ddf
Addressing lengths bug; changing example; changing default way of running to be the feature-specified version
Browse files - run_data_measurements.py +6 -11
run_data_measurements.py
CHANGED
|
@@ -30,13 +30,15 @@ def load_or_prepare_widgets(ds_args, show_embeddings=False, use_cache=False):
|
|
| 30 |
|
| 31 |
dstats = dataset_statistics.DatasetStatisticsCacheClass(**ds_args,
|
| 32 |
use_cache=use_cache)
|
|
|
|
|
|
|
| 33 |
# Header widget
|
| 34 |
dstats.load_or_prepare_dset_peek()
|
| 35 |
# General stats widget
|
| 36 |
dstats.load_or_prepare_general_stats()
|
| 37 |
# Labels widget
|
| 38 |
try:
|
| 39 |
-
dstats.set_label_field(
|
| 40 |
dstats.load_or_prepare_labels()
|
| 41 |
except:
|
| 42 |
pass
|
|
@@ -79,15 +81,7 @@ def load_or_prepare(dataset_args, do_html=False, use_cache=False):
|
|
| 79 |
|
| 80 |
if all or dataset_args["calculation"] == "lengths":
|
| 81 |
print("\n* Calculating text lengths.")
|
| 82 |
-
fig_tok_length_fid = pjoin(dstats.cache_path, "lengths_fig.html")
|
| 83 |
-
tok_length_json_fid = pjoin(dstats.cache_path, "lengths.json")
|
| 84 |
dstats.load_or_prepare_text_lengths()
|
| 85 |
-
with open(tok_length_json_fid, "w+") as f:
|
| 86 |
-
json.dump(dstats.fig_tok_length.to_json(), f)
|
| 87 |
-
print("Token lengths now available at %s." % tok_length_json_fid)
|
| 88 |
-
if do_html:
|
| 89 |
-
dstats.fig_tok_length.write_html(fig_tok_length_fid)
|
| 90 |
-
print("Figure saved to %s." % fig_tok_length_fid)
|
| 91 |
print("Done!")
|
| 92 |
|
| 93 |
if all or dataset_args["calculation"] == "labels":
|
|
@@ -95,6 +89,7 @@ def load_or_prepare(dataset_args, do_html=False, use_cache=False):
|
|
| 95 |
print("Warning: You asked for label calculation, but didn't provide "
|
| 96 |
"the labels field name. Assuming it is 'label'...")
|
| 97 |
dstats.set_label_field("label")
|
|
|
|
| 98 |
print("\n* Calculating label distribution.")
|
| 99 |
dstats.load_or_prepare_labels()
|
| 100 |
fig_label_html = pjoin(dstats.cache_path, "labels_fig.html")
|
|
@@ -190,7 +185,7 @@ def get_text_label_df(
|
|
| 190 |
"calculation": calculation,
|
| 191 |
"cache_dir": out_dir,
|
| 192 |
}
|
| 193 |
-
|
| 194 |
|
| 195 |
|
| 196 |
def main():
|
|
@@ -272,7 +267,7 @@ def main():
|
|
| 272 |
args = parser.parse_args()
|
| 273 |
print("Proceeding with the following arguments:")
|
| 274 |
print(args)
|
| 275 |
-
# run_data_measurements.py -
|
| 276 |
get_text_label_df(
|
| 277 |
args.dataset,
|
| 278 |
args.config,
|
|
|
|
| 30 |
|
| 31 |
dstats = dataset_statistics.DatasetStatisticsCacheClass(**ds_args,
|
| 32 |
use_cache=use_cache)
|
| 33 |
+
# Embeddings widget
|
| 34 |
+
dstats.load_or_prepare_dataset()
|
| 35 |
# Header widget
|
| 36 |
dstats.load_or_prepare_dset_peek()
|
| 37 |
# General stats widget
|
| 38 |
dstats.load_or_prepare_general_stats()
|
| 39 |
# Labels widget
|
| 40 |
try:
|
| 41 |
+
dstats.set_label_field(ds_args['label_field'])
|
| 42 |
dstats.load_or_prepare_labels()
|
| 43 |
except:
|
| 44 |
pass
|
|
|
|
| 81 |
|
| 82 |
if all or dataset_args["calculation"] == "lengths":
|
| 83 |
print("\n* Calculating text lengths.")
|
|
|
|
|
|
|
| 84 |
dstats.load_or_prepare_text_lengths()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
print("Done!")
|
| 86 |
|
| 87 |
if all or dataset_args["calculation"] == "labels":
|
|
|
|
| 89 |
print("Warning: You asked for label calculation, but didn't provide "
|
| 90 |
"the labels field name. Assuming it is 'label'...")
|
| 91 |
dstats.set_label_field("label")
|
| 92 |
+
else:
|
| 93 |
print("\n* Calculating label distribution.")
|
| 94 |
dstats.load_or_prepare_labels()
|
| 95 |
fig_label_html = pjoin(dstats.cache_path, "labels_fig.html")
|
|
|
|
| 185 |
"calculation": calculation,
|
| 186 |
"cache_dir": out_dir,
|
| 187 |
}
|
| 188 |
+
load_or_prepare(dataset_args, use_cache=use_cache)
|
| 189 |
|
| 190 |
|
| 191 |
def main():
|
|
|
|
| 267 |
args = parser.parse_args()
|
| 268 |
print("Proceeding with the following arguments:")
|
| 269 |
print(args)
|
| 270 |
+
# run_data_measurements.py -d hate_speech18 -c default -s train -f text -w npmi
|
| 271 |
get_text_label_df(
|
| 272 |
args.dataset,
|
| 273 |
args.config,
|