| import argparse | |
| import torch | |
| import glob | |
| import os | |
| import numpy as np | |
class MMapIndexDataset():
    """Random-access view over a pre-serialized token dataset.

    Expects two companion files produced by ``convert_py_to_npy``:
      * ``<datapath>.npy`` — an (N, 2) array of [start, end) offsets.
      * ``<datapath>.bin`` — a flat binary buffer of all sequences concatenated.
    Both files are memory-mapped, so construction is cheap and individual
    items are read lazily on access.
    """

    def __init__(self, datapath):
        # Map both files read-only; nothing is loaded into RAM eagerly.
        # NOTE(review): 'long' is the platform C long (32-bit on Windows) —
        # must match the dtype used by the writer; verify on non-Linux hosts.
        self.idxfp = np.load(datapath + '.npy', mmap_mode='r')
        self.binfp = np.memmap(datapath + '.bin', dtype='long', mode='r')

    def __len__(self):
        # One offset row per stored sequence.
        return self.idxfp.shape[0]

    def __getitem__(self, idx):
        # Slice the flat buffer with this item's [start, end) offset pair.
        begin = self.idxfp[idx, 0]
        end = self.idxfp[idx, 1]
        return self.binfp[begin:end]
def convert_py_to_npy(input_tensor, bin_out, idx_out):
    """Serialize a list of 1-D integer sequences into a flat binary file plus an index.

    Writes two files in the layout that ``MMapIndexDataset`` reads:
      * ``idx_out`` (``np.save`` appends ``.npy``): an (N, 2) int64 array of
        [start, end) offsets, one row per input sequence.
      * ``bin_out``: a raw binary buffer holding all sequences concatenated.

    Args:
        input_tensor: sized iterable of 1-D integer sequences (e.g. torch tensors).
        bin_out: path of the flat binary output file.
        idx_out: path (without ``.npy`` suffix) of the offset-index output file.
    """
    offsets = torch.empty(len(input_tensor), 2, dtype=torch.long)
    total = 0
    for i, seq in enumerate(input_tensor):
        offsets[i, 0] = total
        total += len(seq)
        offsets[i, 1] = total
    np.save(idx_out, offsets)

    if total == 0:
        # np.memmap cannot map a zero-length file; still create an empty .bin
        # so the reader's open() does not fail on a missing file.
        open(bin_out, 'wb').close()
        return

    # NOTE(review): 'long' is the platform C long (32-bit on Windows); kept
    # as-is for compatibility with MMapIndexDataset, which reads the same dtype.
    binfp = np.memmap(bin_out, dtype='long', mode='w+', shape=(total,))
    pos = 0
    for seq in input_tensor:
        n = len(seq)
        # Bulk slice assignment instead of a per-element Python loop.
        binfp[pos:pos + n] = np.asarray(seq)
        pos += n
    binfp.flush()  # ensure the data actually reaches disk (don't rely on GC)
    del binfp
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Text infilling.")
    parser.add_argument('--data_path', type=str,
                        default='/cognitive_comp/gaoxinyu/data/wudao')
    args = parser.parse_args()

    # Keys in the saved dict that each get their own .bin / .npy pair.
    process_key = [
        'incorrect_input_ids_list',
        'label_ids_list',
        'target_ids_list',
    ]

    if not os.path.exists(args.data_path):
        print(
            f'Please create the synthetic datafile {args.data_path} with create_synthetic_data.py.')
    else:
        print(f'Loading data from {args.data_path}')
        data_dict = torch.load(args.data_path)
        for key in process_key:
            # Build per-key output paths by injecting the key name before '.pt'.
            stem_parts = args.data_path.rsplit('.pt', 1)
            bin_out = ('_' + key + '.bin').join(stem_parts)
            idx_out = ('_' + key).join(stem_parts)
            convert_py_to_npy(data_dict[key], bin_out, idx_out)