Spaces:
Runtime error
Runtime error
| import os | |
| import re | |
| import json | |
| import numpy as np | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from torch.utils.data import Dataset | |
| from torchvision.datasets.utils import download_url | |
| from .constants import COCO_ROOT, FLICKR_ROOT | |
| from .utils import AverageMeter | |
def pre_caption(caption, max_words=50):
    """
    Normalize a raw caption for retrieval.
    caption: the raw caption string.
    max_words: maximum number of space-separated words kept in the result.
    Returns the lower-cased caption with the punctuation characters
    . ! " ( ) * # : ; ~ replaced by spaces, runs of two or more whitespace
    characters collapsed to one space, trailing newlines and surrounding
    spaces removed, and at most `max_words` words retained.
    """
    lowered = caption.lower()
    # Punctuation becomes a space (not removed) so adjacent words stay separated.
    without_punct = re.sub(r"([.!\"()*#:;~])", ' ', lowered)
    # Only runs of 2+ whitespace characters are collapsed; a lone tab/newline
    # in the middle of the caption is left as-is.
    collapsed = re.sub(r"\s{2,}", ' ', without_punct)
    cleaned = collapsed.rstrip('\n').strip(' ')

    # Truncate to the first `max_words` words when the caption is too long.
    words = cleaned.split(' ')
    if len(words) <= max_words:
        return cleaned
    return ' '.join(words[:max_words])
class COCO_Retrieval(Dataset):
    """
    Image-text retrieval dataset built from a COCO Karpathy-split annotation file.
    Each item is an image; `text`, `img2txt` and `txt2img` hold the pre-processed
    captions and the index mappings between images and captions.
    """

    def __init__(self, image_preprocess=None, root_dir=COCO_ROOT, max_words=30, split="test",
                 image_perturb_fn=None, download=False):
        """
        COCO Retrieval Dataset.
        image_preprocess: image preprocessing function
        root_dir: The directory of the coco dataset. Image paths from the annotation file
            are resolved relative to this directory, and the annotation file is stored here.
        max_words: Cropping the caption to max_words.
        split: 'val' or 'test'
        image_perturb_fn: image perturbation function for patch permutation experiments.
        download: Whether to download the dataset if it does not exist.
        Raises RuntimeError when `root_dir` does not exist and `download` is False.
        """
        self.root_dir = root_dir
        if not os.path.exists(root_dir):
            print("Directory for COCO could not be found!")
            if download:
                print("Downloading COCO now.")
                self.download()
            else:
                raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")

        urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
        filenames = {'val': 'coco_karpathy_val.json', 'test': 'coco_karpathy_test.json'}

        # Fetch the annotation file for the requested split into root_dir.
        download_url(urls[split], root_dir)

        # Use a context manager so the annotation file handle is closed promptly.
        annotation_path = os.path.join(root_dir, filenames[split])
        with open(annotation_path, 'r', encoding="utf-8") as annotation_file:
            self.annotation = json.load(annotation_file)

        self.image_preprocess = image_preprocess
        self.image_perturb_fn = image_perturb_fn
        self.image_root = root_dir

        self.text = []      # all pre-processed captions, indexed by txt_id
        self.image = []     # image path per img_id, as given in the annotation
        self.txt2img = {}   # txt_id -> img_id
        self.img2txt = {}   # img_id -> list of txt_ids

        # Caption ids are assigned consecutively across all images.
        txt_id = 0
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann['image'])
            self.img2txt[img_id] = []
            for caption in ann['caption']:
                self.text.append(pre_caption(caption, max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1

    def __len__(self):
        """Return the number of annotation entries (one per image)."""
        return len(self.annotation)

    def __getitem__(self, index):
        """
        Load the image at `index` as RGB, apply `image_preprocess` and then
        `image_perturb_fn` when they are set, and return {"image": ..., "idx": index}.
        """
        image_path = os.path.join(self.image_root, self.annotation[index]['image'])
        image = Image.open(image_path).convert('RGB')

        if self.image_preprocess is not None:
            image = self.image_preprocess(image)

        if self.image_perturb_fn is not None:
            image = self.image_perturb_fn(image)

        return {"image": image, "idx": index}

    def download(self):
        """
        Download and unzip the val2014 and test2014 image archives into `root_dir`
        using the external `wget` and `unzip` commands. Exit codes are not checked.
        """
        import subprocess
        os.makedirs(self.root_dir, exist_ok=True)
        # The train2014 archive is intentionally left disabled.
        #subprocess.call(["wget", "http://images.cocodataset.org/zips/train2014.zip"], cwd=self.root_dir)
        #subprocess.call(["unzip", "train2014.zip"], cwd=self.root_dir)

        subprocess.call(["wget", "http://images.cocodataset.org/zips/val2014.zip"], cwd=self.root_dir)
        subprocess.call(["unzip", "val2014.zip"], cwd=self.root_dir)

        subprocess.call(["wget", "http://images.cocodataset.org/zips/test2014.zip"], cwd=self.root_dir)
        subprocess.call(["unzip", "test2014.zip"], cwd=self.root_dir)

    def evaluate_scores(self, scores):
        """
        Compute top-1 / top-5 retrieval hit rates from similarity scores.
        scores: either a single score matrix indexed as [image, text], or a tuple
            (image-to-text scores, text-to-image scores) where the second entry is
            transposed so that it is also indexed as [image, text].
        Returns a one-element list holding a dict with the keys
        "ImagePrec@1", "ImagePrec@5", "TextPrec@1" and "TextPrec@5".
        """
        if isinstance(scores, tuple):
            scores_i2t = scores[0]
            scores_t2i = scores[1].T  # Make it N_ims x N_text
        else:
            scores_t2i = scores
            scores_i2t = scores

        print(f"COCO results across {scores_i2t.shape} samples. ")
        prec_at_1 = AverageMeter()
        prec_at_5 = AverageMeter()

        # Text retrieval: for each image, is any of its captions ranked in the top 1 / top 5?
        tqdm_iterator = tqdm(range(len(self.img2txt)))
        for i in tqdm_iterator:
            # argsort is ascending, so the last five indices are the best-scoring captions
            # and the very last one is the top-1 prediction.
            top5_captions = np.argsort(scores_i2t[i])[-5:]
            true_captions = self.img2txt[i]

            prec_at_1.update(len(set(true_captions) & set(top5_captions[-1:])) > 0)
            prec_at_5.update(len(set(true_captions) & set(top5_captions)) > 0)

            tqdm_iterator.set_description(f"Text Retrieval Prec@1: {prec_at_1.avg:.3f}, Prec@5: {prec_at_5.avg:.3f}")

        # Image Retrieval: for each caption, is its source image ranked in the top 1 / top 5?
        image_prec_at_1 = AverageMeter()
        image_prec_at_5 = AverageMeter()

        tqdm_iterator = tqdm(range(len(self.txt2img)))
        for i in tqdm_iterator:
            top5_images = np.argsort(scores_t2i[:, i])[-5:]
            true_image = self.txt2img[i]

            image_prec_at_1.update(true_image in top5_images[-1:])
            image_prec_at_5.update(true_image in top5_images)

            tqdm_iterator.set_description(f"Image Retrieval Prec@1: {image_prec_at_1.avg:.3f}, Prec@5: {image_prec_at_5.avg:.3f}")

        records = [{"ImagePrec@1": image_prec_at_1.avg, "ImagePrec@5": image_prec_at_5.avg, "TextPrec@1": prec_at_1.avg, "TextPrec@5": prec_at_5.avg}]
        return records
class Flickr30k_Retrieval(Dataset):
    """
    Image-text retrieval dataset built from a Flickr30k annotation file.
    Each item is an image; `text`, `img2txt` and `txt2img` hold the pre-processed
    captions and the index mappings between images and captions.
    """

    def __init__(self, image_preprocess, split, root_dir=FLICKR_ROOT, max_words=30,
                 image_perturb_fn=None, *args, **kwargs):
        '''
        Flickr30k dataset for retrieval.
        image_preprocess: image preprocessing function
        split: 'val' or 'test'
        root_dir: The directory of the Flickr30k dataset. Image paths from the annotation
            file are resolved relative to this directory, and the annotation file is stored here.
        max_words: Cropping the caption to max_words.
        image_perturb_fn: image perturbation function for patch permutation experiments.
        *args, **kwargs: accepted and ignored (e.g. a `download` flag passed by callers);
            the images cannot be downloaded automatically.
        Raises RuntimeError when `root_dir` does not exist.
        '''
        urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
        filenames = {'val': 'flickr30k_val.json', 'test': 'flickr30k_test.json'}

        if not os.path.exists(root_dir):
            print("Directory for Flickr30k could not be found!")
            flickr_url = "https://forms.illinois.edu/sec/229675"
            raise RuntimeError(f"You need to manually sign up and download the dataset from {flickr_url} and place it in the `root_dir`.")

        # Fetch the annotation file for the requested split into root_dir.
        download_url(urls[split], root_dir)

        # Use a context manager so the annotation file handle is closed promptly.
        annotation_path = os.path.join(root_dir, filenames[split])
        with open(annotation_path, 'r', encoding="utf-8") as annotation_file:
            self.annotation = json.load(annotation_file)

        self.image_preprocess = image_preprocess
        self.image_perturb_fn = image_perturb_fn
        self.root_dir = root_dir

        self.text = []      # all pre-processed captions, indexed by txt_id
        self.image = []     # image path per img_id, as given in the annotation
        self.txt2img = {}   # txt_id -> img_id
        self.img2txt = {}   # img_id -> list of txt_ids

        # Caption ids are assigned consecutively across all images.
        txt_id = 0
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann['image'])
            self.img2txt[img_id] = []
            for caption in ann['caption']:
                self.text.append(pre_caption(caption, max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1

    def __len__(self):
        """Return the number of annotation entries (one per image)."""
        return len(self.annotation)

    def __getitem__(self, index):
        """
        Load the image at `index` as RGB, apply `image_preprocess` and then
        `image_perturb_fn` when they are set, and return {"image": ..., "idx": index}.
        """
        image_path = os.path.join(self.root_dir, self.annotation[index]['image'])
        image = Image.open(image_path).convert('RGB')

        if self.image_preprocess is not None:
            image = self.image_preprocess(image)

        if self.image_perturb_fn is not None:
            image = self.image_perturb_fn(image)

        return {"image": image, "idx": index}

    def evaluate_scores(self, scores):
        """
        Compute top-1 / top-5 retrieval hit rates from similarity scores.
        scores: either a single score matrix indexed as [image, text], or a tuple
            (image-to-text scores, text-to-image scores) where the second entry is
            transposed so that it is also indexed as [image, text].
        Returns a one-element list holding a dict with the keys
        "ImagePrec@1", "ImagePrec@5", "TextPrec@1" and "TextPrec@5".
        """
        if isinstance(scores, tuple):
            scores_i2t = scores[0]
            scores_t2i = scores[1].T  # Make it N_ims x N_text
        else:
            scores_t2i = scores
            scores_i2t = scores

        print(f"Flickr30k Retrieval results across {scores_i2t.shape} samples. ")
        prec_at_1 = AverageMeter()
        prec_at_5 = AverageMeter()

        # Text retrieval: for each image, is any of its captions ranked in the top 1 / top 5?
        tqdm_iterator = tqdm(range(len(self.img2txt)))
        for i in tqdm_iterator:
            # argsort is ascending, so the last five indices are the best-scoring captions
            # and the very last one is the top-1 prediction.
            top5_captions = np.argsort(scores_i2t[i])[-5:]
            true_captions = self.img2txt[i]

            prec_at_1.update(len(set(true_captions) & set(top5_captions[-1:])) > 0)
            prec_at_5.update(len(set(true_captions) & set(top5_captions)) > 0)

            tqdm_iterator.set_description(f"Text Retrieval Prec@1: {prec_at_1.avg:.3f}, Prec@5: {prec_at_5.avg:.3f}")

        # Image Retrieval: for each caption, is its source image ranked in the top 1 / top 5?
        image_prec_at_1 = AverageMeter()
        image_prec_at_5 = AverageMeter()

        tqdm_iterator = tqdm(range(len(self.txt2img)))
        for i in tqdm_iterator:
            top5_images = np.argsort(scores_t2i[:, i])[-5:]
            true_image = self.txt2img[i]

            image_prec_at_1.update(true_image in top5_images[-1:])
            image_prec_at_5.update(true_image in top5_images)

            tqdm_iterator.set_description(f"Image Retrieval Prec@1: {image_prec_at_1.avg:.3f}, Prec@5: {image_prec_at_5.avg:.3f}")

        records = [{"ImagePrec@1": image_prec_at_1.avg, "ImagePrec@5": image_prec_at_5.avg, "TextPrec@1": prec_at_1.avg, "TextPrec@5": prec_at_5.avg}]
        return records

    def download(self):
        """Always raises NotImplementedError: Flickr30k must be obtained manually."""
        raise NotImplementedError("Flickr30k dataset is not available for download.")
def get_coco_retrieval(image_preprocess, image_perturb_fn, text_perturb_fn, max_words=30, download=False, root_dir=COCO_ROOT, split="test"):
    """
    Build a COCO_Retrieval dataset from the given options.
    `text_perturb_fn` is accepted but not used by this factory.
    """
    return COCO_Retrieval(
        image_preprocess=image_preprocess,
        root_dir=root_dir,
        max_words=max_words,
        split=split,
        image_perturb_fn=image_perturb_fn,
        download=download,
    )
def get_flickr30k_retrieval(image_preprocess, image_perturb_fn, text_perturb_fn, max_words=30, download=False, root_dir=FLICKR_ROOT, split="test"):
    """
    Build a Flickr30k_Retrieval dataset from the given options.
    `text_perturb_fn` is accepted but not used by this factory; `download` is
    forwarded as a keyword argument to the dataset constructor.
    """
    options = {
        "root_dir": root_dir,
        "split": split,
        "image_preprocess": image_preprocess,
        "image_perturb_fn": image_perturb_fn,
        "max_words": max_words,
        "download": download,
    }
    return Flickr30k_Retrieval(**options)