Spaces:
Runtime error
Runtime error
| import datasets | |
| import os | |
| from tqdm import tqdm | |
| import webdataset as wds | |
| import json | |
# Directory of The Pile train split, in the format read by `datasets.load_from_disk`.
DATASET_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/the_pile/all/train"
# Directory that receives the generated WebDataset tar shards.
OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/the_pile"
# Maximum number of samples stored in one shard before a new tar is started.
SAMPLE_PER_SHARD = 100000
# Number of rows fetched from the dataset per iteration step.
BATCH_SIZE = 4096
# Samples whose `meta["pile_set_name"]` equals this value are not exported.
EXCLUDED_PILE_SET = "Github"


def main() -> None:
    """Export The Pile to WebDataset tar shards, skipping the Github subset.

    Loads the dataset stored at ``DATASET_ROOT``, iterates over it in batches
    of ``BATCH_SIZE`` rows and writes every kept sample to
    ``OUT_DIR/<shard index>.tar``. Each sample is stored under a running
    1-based key as two entries: ``txt`` (the UTF-8 encoded text) and ``json``
    (the sample's ``meta`` dict serialized as indented JSON).

    Existing shards in ``OUT_DIR`` with the same names are overwritten.
    """
    # exist_ok=True lets the script be re-run (e.g. after an interrupted
    # export) instead of failing because the output directory already exists.
    os.makedirs(OUT_DIR, exist_ok=True)

    print("load dataset...")
    pile = datasets.load_from_disk(DATASET_ROOT)
    total_num = pile.num_rows
    print("total num:", total_num)

    shard_pattern = os.path.join(OUT_DIR, "%05d.tar")
    num = 0  # number of samples written so far; also used as the sample key

    # Both the progress bar and the shard writer are context managers so they
    # are closed even when the export is interrupted by an exception.
    # encoder=False: payloads are handed over as raw bytes, already encoded.
    with tqdm(total=total_num) as pbar, wds.ShardWriter(
        shard_pattern, maxcount=SAMPLE_PER_SHARD, encoder=False
    ) as sink:
        for batch in pile.iter(BATCH_SIZE):
            texts = batch["text"]
            for text, meta in zip(texts, batch["meta"]):
                if meta.get("pile_set_name") == EXCLUDED_PILE_SET:
                    continue
                num += 1
                sink.write({
                    '__key__': str(num),
                    'txt': text.encode("utf-8"),
                    'json': json.dumps(meta, indent=4).encode("utf-8"),
                })
            # The bar counts processed rows (kept and skipped alike).
            pbar.update(len(texts))

    print(f"{num} out of {total_num} is written")


if __name__ == "__main__":
    main()