Spaces:
Runtime error
Runtime error
| import argparse | |
| import base64 | |
| import json | |
| import os | |
| import tarfile | |
| import uuid | |
| import zipfile | |
| import time | |
| import braceexpand | |
| import webdataset as wds | |
| from tqdm import tqdm | |
| from tqdm.contrib.concurrent import process_map | |
# Command-line interface. The shard arguments are brace-expandable patterns
# (e.g. "path/shard_{0..23098}_images_v2.tar") expanded later in main().
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--output_dir", type=str)
arg_parser.add_argument(
    "--image_shards",
    type=str,
    help="Pass in a list of shards in the format path_to_shard/shard_{0..23098}_images_v2.tar",
)
arg_parser.add_argument(
    "--doc_shards",
    type=str,
    help="Pass in a list of shards in the format path_to_shard/docs_shard_{0..23098}_v2.jsonl.zip",
)
arg_parser.add_argument(
    "--thread",
    type=int,
    default=128,  # number of worker buckets the shard pairs are split across
)
# NOTE: parsed at import time — this module is meant to run only as a script.
args = arg_parser.parse_args()
def get_txt_to_filename_dict(image_shards, disable_tqdm=False):
    """Map each sample's caption text (truncated at the first '.') to its key.

    Streams the image shard(s) once as a WebDataset of (txt, json) tuples.
    If two samples share the same truncated caption, the later one wins.
    """
    samples = wds.WebDataset(image_shards).decode("pil").to_tuple("txt", "json")
    return {
        caption.split(".")[0]: meta["key"]
        for caption, meta in tqdm(samples, disable=disable_tqdm)
    }
def single_thread(args):
    """Merge one worker's doc shards with their image shards into output tars.

    ``args`` is a task dict with keys:
        i            -- worker index (only worker 0 prints progress)
        output_dir   -- directory receiving the ``%09d.tar`` output shards
        doc_shards   -- list of ``.jsonl.zip`` document shard paths
        image_shards -- list of ``.tar`` image shard paths, parallel to doc_shards

    For every JSON document, each image listed in ``image_info`` is looked up
    in the paired image tar and embedded as a base64 string under
    ``image_base64``; images that cannot be found are stored as ``"null"``.
    """
    i = args["i"]
    output_dir = args["output_dir"]
    doc_shards = args["doc_shards"]
    image_shards = args["image_shards"]
    if i == 0:
        tqdm.write(f"output_dir: {output_dir}")
        tqdm.write(f"doc_shards: {doc_shards[:5]}")
        tqdm.write(f"image_shards: {image_shards[:5]}")
    with wds.ShardWriter(os.path.join(output_dir, "%09d.tar"), maxcount=1000) as sink:
        sink.verbose = False
        for doc_shard, image_shard in tqdm(
            zip(doc_shards, image_shards), disable=(i != 0), total=len(doc_shards)
        ):
            # BUGFIX: these two statements were commented out, so the lookup
            # below raised NameError for every image and the bare except
            # silently wrote "null" for all of them. The tar is now also a
            # context manager, fixing the (commented-out) missing close().
            txt_to_filename_dict = get_txt_to_filename_dict(
                image_shard, disable_tqdm=(i != 0)
            )
            with tarfile.open(image_shard) as image_tar, zipfile.ZipFile(
                doc_shard, "r"
            ) as zip_file:
                # Assumes the JSON file is the first file in the archive
                json_filename = zip_file.namelist()[0]
                with zip_file.open(json_filename, "r") as json_file:
                    pbar = tqdm(json_file, disable=True)
                    total_num = 0
                    exist_num = 0
                    for line in pbar:
                        # get image names from json
                        sample_data = json.loads(line)
                        image_info = sample_data["image_info"]
                        image_names = [image["image_name"] for image in image_info]
                        # Add each image to the tar file
                        for img_idx, image_name in enumerate(image_names):
                            total_num += 1
                            try:
                                member = (
                                    txt_to_filename_dict[image_name.split(".")[0]]
                                    + ".jpg"
                                )
                                image = image_tar.extractfile(member)
                                # convert to base64
                                image_bytes = image.read()
                                image_base64 = base64.b64encode(image_bytes).decode(
                                    "utf-8"
                                )
                                exist_num += 1
                            except (KeyError, AttributeError, OSError, tarfile.TarError):
                                # KeyError: caption missing from the lookup table;
                                # AttributeError: extractfile() returned None
                                # (non-regular member); tar/OS errors: bad archive.
                                tqdm.write(f"{image_name.split('.')[0]}")
                                image_base64 = "null"
                            sample_data["image_info"][img_idx][
                                "image_base64"
                            ] = image_base64
                        key_str = uuid.uuid4().hex
                        sink.write({"__key__": key_str, "json": sample_data})
                        if total_num:  # guard: first sample may list no images
                            pbar.set_description(f"{exist_num/total_num:.2f}")
def main():
    """Split the shard lists into per-worker buckets and process them."""
    timestamp = int(time.time())
    run_dir = os.path.join(args.output_dir, str(timestamp))
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(run_dir, exist_ok=True)

    # One task dict per worker bucket, each writing to its own subdirectory.
    tasks = []
    for worker_idx in range(args.thread):
        worker_dir = os.path.join(run_dir, str(worker_idx))
        os.makedirs(worker_dir, exist_ok=True)
        tasks.append(
            {
                "i": worker_idx,
                "output_dir": worker_dir,
                "doc_shards": [],
                "image_shards": [],
            }
        )

    doc_shards = list(braceexpand.braceexpand(args.doc_shards))
    image_shards = list(braceexpand.braceexpand(args.image_shards))
    assert len(doc_shards) == len(
        image_shards
    ), "Each doc shards must have a corresponding image shard"

    # Round-robin the shard pairs across the worker buckets.
    for shard_idx, pair in enumerate(zip(doc_shards, image_shards)):
        bucket = tasks[shard_idx % args.thread]
        bucket["doc_shards"].append(pair[0])
        bucket["image_shards"].append(pair[1])

    # NOTE(review): the parallel process_map call is commented out and only
    # bucket 0 is processed — presumably a debugging leftover; confirm before
    # running this on a full dataset.
    # process_map(single_thread, tasks, max_workers=args.thread, disable=True)
    single_thread(tasks[0])


if __name__ == "__main__":
    main()