# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou (xueyan@cs.wisc.edu)
# --------------------------------------------------------
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision import transforms

from utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
# Resize transforms: 224 for the captioning branch, 512 for the grounding branch.
t = []
t.append(transforms.Resize(224, interpolation=Image.BICUBIC))
transform_ret = transforms.Compose(t)
t = []
t.append(transforms.Resize(512, interpolation=Image.BICUBIC))
transform_grd = transforms.Compose(t)
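# Small sanity-check sketch (not part of the original demo): torchvision's Resize
# with a single int rescales the *shorter* image side to that length and keeps the
# aspect ratio, so the two transforms above differ only in target scale.
#
#   probe = Image.new('RGB', (640, 480))
#   transform_ret(probe).size  # shorter side becomes 224
#   transform_grd(probe).size  # shorter side becomes 512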
# COCO panoptic metadata (class names/colors) for the Visualizer overlay.
metedata = MetadataCatalog.get('coco_2017_train_panoptic')
def referring_captioning(model, image, texts, inpainting_text, *args, **kwargs):
    model_last, model_cap = model
    with torch.no_grad():
        # Grounding branch: resize the PIL image and convert to a CHW tensor on GPU.
        image_ori = image
        image = transform_grd(image)
        width = image.size[0]
        height = image.size[1]
        image = np.asarray(image)
        image_ori_ = image
        images = torch.from_numpy(image.copy()).permute(2, 0, 1).cuda()
        # Ensure the referring phrase ends with a period.
        texts_input = [[texts.strip() if texts.endswith('.') else (texts + '.')]]
        batch_inputs = [{'image': images, 'groundings': {'texts': texts_input}, 'height': height, 'width': width}]
        outputs = model_last.model.evaluate_grounding(batch_inputs, None)
        grd_mask = (outputs[-1]['grounding_mask'] > 0).float()
        # grd_mask_ is the complement of the grounded region, downsampled to 224x224.
        grd_mask_ = (1 - F.interpolate(grd_mask[None,], (224, 224), mode='nearest')[0]).bool()

        # Overlay the grounding mask and the query text on the resized image.
        color = [252/255, 91/255, 129/255]
        visual = Visualizer(image_ori_, metadata=metedata)
        demo = visual.draw_binary_mask(grd_mask.cpu().numpy()[0], color=color, text=texts)
        res = demo.get_image()
        # If the grounded region covers fewer than 5 pixels at 224x224, return the overlay without a caption.
        if (1 - grd_mask_.float()).sum() < 5:
            torch.cuda.empty_cache()
            return Image.fromarray(res), 'n/a', None
        # Captioning branch: reset the mask to all zeros, resize to 224, and rebuild the batch.
        grd_mask_ = grd_mask_ * 0
        image = transform_ret(image_ori)
        image_ori = np.asarray(image_ori)
        image = np.asarray(image)
        images = torch.from_numpy(image.copy()).permute(2, 0, 1).cuda()
        batch_inputs = [{'image': images, 'image_id': 0, 'captioning_mask': grd_mask_}]
        # Prime the caption decoder with tokens from the referring phrase (all but the last token).
        token_text = texts.replace('.', '') if texts.endswith('.') else texts
        token = model_cap.model.sem_seg_head.predictor.lang_encoder.tokenizer.encode(token_text)
        token = torch.tensor(token)[None, :-1]

        outputs = model_cap.model.evaluate_captioning(batch_inputs, extra={'token': token})
        # outputs = model_cap.model.evaluate_captioning(batch_inputs, extra={})
        text = outputs[-1]['captioning_text']

        torch.cuda.empty_cache()
        return Image.fromarray(res), text, None
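# Minimal usage sketch (not from the original Space): how this entry point is
# typically driven. The model handles and the image path below are assumptions;
# in the demo the two X-Decoder models are built and loaded elsewhere before
# being passed in as the `model` tuple.
#
#   from PIL import Image
#
#   model_seg = ...  # grounding/segmentation X-Decoder wrapper (hypothetical handle)
#   model_cap = ...  # captioning X-Decoder wrapper (hypothetical handle)
#
#   img = Image.open('example.jpg').convert('RGB')
#   overlay, caption, _ = referring_captioning(
#       (model_seg, model_cap), img, 'the dog on the left', None)
#   overlay.save('grounded_overlay.png')
#   print(caption)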