import cv2
import einops
import gradio as gr
import numpy as np
import torch
import random
import os
import json
import datetime

from huggingface_hub import hf_hub_download
from pytorch_lightning import seed_everything

from annotator.util import resize_image, HWC3
from annotator.OneFormer import OneformerSegmenter
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSamplerSpaCFG
from ldm.models.autoencoder import DiagonalGaussianDistribution

SEGMENT_MODEL_DICT = {
    'Oneformer': OneformerSegmenter,
}
MASK_MODEL_DICT = {
    'Oneformer': OneformerSegmenter,
}

urls = {
    'shi-labs/oneformer_coco_swin_large': ['150_16_swin_l_oneformer_coco_100ep.pth'],
    'PAIR/PAIR-diffusion-sdv15-coco-finetune': ['model_e91.ckpt']
}
WTS_DICT = {}

# Download the checkpoints listed above and record their local paths.
os.makedirs('checkpoints', exist_ok=True)
for repo, files in urls.items():
    for file in files:
        WTS_DICT[repo] = hf_hub_download(repo_id=repo, filename=file)

# Main model
model = create_model('configs/pair_diff.yaml').cpu()
model.load_state_dict(load_state_dict(WTS_DICT['PAIR/PAIR-diffusion-sdv15-coco-finetune'], location='cuda'))
save_dir = 'results/'
model = model.cuda()
ddim_sampler = DDIMSamplerSpaCFG(model)
save_memory = False
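# Note on guidance: DDIMSamplerSpaCFG takes a list of three CFG scales, assembled
# in ImageComp.process() below as [scale_s, scale_f, scale_t]. Judging from the
# three unconditional branches built there, these appear to weight structure,
# appearance, and text guidance separately.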
class ImageComp:
    def __init__(self, edit_operation):
        self.input_img = None
        self.input_pmask = None
        self.input_segmask = None
        self.input_mask = None
        self.input_points = []
        self.input_scale = 1

        self.ref_img = None
        self.ref_pmask = None
        self.ref_segmask = None
        self.ref_mask = None
        self.ref_points = []
        self.ref_scale = 1

        self.multi_modal = False
        self.H = None
        self.W = None
        self.kernel = np.ones((5, 5), np.uint8)
        self.edit_operation = edit_operation
        self.init_segmentation_model()
        os.makedirs(save_dir, exist_ok=True)
        self.base_prompt = 'A picture of {}'

    def init_segmentation_model(self, mask_model='Oneformer', segment_model='Oneformer'):
        self.segment_model_name = segment_model
        self.mask_model_name = mask_model
        self.segment_model = SEGMENT_MODEL_DICT[segment_model](WTS_DICT['shi-labs/oneformer_coco_swin_large'])

        if mask_model == 'Oneformer' and segment_model == 'Oneformer':
            self.mask_model_inp = self.segment_model
            self.mask_model_ref = self.segment_model
        else:
            self.mask_model_inp = MASK_MODEL_DICT[mask_model]()
            self.mask_model_ref = MASK_MODEL_DICT[mask_model]()
        print(f"Segmentation Models initialized with {mask_model} as mask and {segment_model} as segment")

    def init_input_canvas(self, img):
        img = HWC3(img)
        img = resize_image(img, 512)

        if self.segment_model_name == 'Oneformer':
            detected_seg = self.segment_model(img, 'semantic')
        elif self.segment_model_name == 'SAM':
            raise NotImplementedError

        if self.mask_model_name == 'Oneformer':
            detected_mask = self.mask_model_inp(img, 'panoptic')[0]
        elif self.mask_model_name == 'SAM':
            detected_mask = self.mask_model_inp(img)

        self.input_points = []
        self.input_img = img
        self.input_pmask = detected_mask
        self.input_segmask = detected_seg
        self.H = img.shape[0]
        self.W = img.shape[1]
        return img

    def init_ref_canvas(self, img):
        img = HWC3(img)
        img = resize_image(img, 512)

        if self.segment_model_name == 'Oneformer':
            detected_seg = self.segment_model(img, 'semantic')
        elif self.segment_model_name == 'SAM':
            raise NotImplementedError

        if self.mask_model_name == 'Oneformer':
            detected_mask = self.mask_model_ref(img, 'panoptic')[0]
        elif self.mask_model_name == 'SAM':
            detected_mask = self.mask_model_ref(img)

        self.ref_points = []
        print("Initialized ref", img.shape)
        self.ref_img = img
        self.ref_pmask = detected_mask
        self.ref_segmask = detected_seg
        return img

    def select_input_object(self, evt: gr.SelectData):
        idx = list(np.array(evt.index) * self.input_scale)
        self.input_points.append(idx)

        if self.mask_model_name == 'Oneformer':
            mask = self._get_mask_from_panoptic(np.array(self.input_points), self.input_pmask)
        else:
            mask = self.mask_model_inp(self.input_img, self.input_points)

        c_ids = self.input_segmask[np.array(self.input_points)[:, 1], np.array(self.input_points)[:, 0]]
        unique_ids, counts = torch.unique(c_ids, return_counts=True)
        c_id = int(unique_ids[torch.argmax(counts)].cpu().detach().numpy())
        category = self.segment_model.metadata.stuff_classes[c_id]
        # print(self.segment_model.metadata.stuff_classes)

        self.input_mask = mask
        mask = mask.cpu().numpy()
        output = mask[:, :, None] * self.input_img + (1 - mask[:, :, None]) * self.input_img * 0.2
        return output.astype(np.uint8), self.base_prompt.format(category)

    def select_ref_object(self, evt: gr.SelectData):
        idx = list(np.array(evt.index) * self.ref_scale)
        self.ref_points.append(idx)

        if self.mask_model_name == 'Oneformer':
            mask = self._get_mask_from_panoptic(np.array(self.ref_points), self.ref_pmask)
        else:
            mask = self.mask_model_ref(self.ref_img, self.ref_points)

        c_ids = self.ref_segmask[np.array(self.ref_points)[:, 1], np.array(self.ref_points)[:, 0]]
        unique_ids, counts = torch.unique(c_ids, return_counts=True)
        c_id = int(unique_ids[torch.argmax(counts)].cpu().detach().numpy())
        category = self.segment_model.metadata.stuff_classes[c_id]
        print("Category of reference object is:", category)

        self.ref_mask = mask
        mask = mask.cpu().numpy()
        output = mask[:, :, None] * self.ref_img + (1 - mask[:, :, None]) * self.ref_img * 0.2
        return output.astype(np.uint8)

    def clear_points(self):
        self.input_points = []
        self.ref_points = []
        zeros_inp = np.zeros(self.input_img.shape)
        zeros_ref = np.zeros(self.ref_img.shape)
        return zeros_inp, zeros_ref

    def return_input_img(self):
        return self.input_img

    def _get_mask_from_panoptic(self, points, panoptic_mask):
        panoptic_mask_ = panoptic_mask + 1
        ids = panoptic_mask_[points[:, 1], points[:, 0]]
        unique_ids, counts = torch.unique(ids, return_counts=True)
        mask_id = unique_ids[torch.argmax(counts)]
        final_mask = torch.zeros(panoptic_mask.shape).cuda()
        final_mask[panoptic_mask_ == mask_id] = 1
        return final_mask
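    # Example of the majority vote above: if the clicked pixels fall in segments
    # with (shifted) panoptic IDs [4, 4, 8], torch.unique(..., return_counts=True)
    # yields counts [2, 1], ID 4 wins, and the returned binary mask covers every
    # pixel of that segment.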
    def _process_mask(self, mask, panoptic_mask, segmask):
        # Majority semantic class inside the selected region; the +1/-1 shifts
        # keep region classes from clashing with the 0 background label, and
        # counts[1:] skips the background bin.
        obj_class = mask * (segmask + 1)
        unique_ids, counts = torch.unique(obj_class, return_counts=True)
        obj_class = unique_ids[torch.argmax(counts[1:]) + 1] - 1
        return mask, obj_class
    def _edit_app(self, whole_ref):
        """
        Manipulates the panoptic mask of the input image to change appearance.
        """
        input_pmask = self.input_pmask
        input_segmask = self.input_segmask
        if whole_ref:
            reference_mask = torch.ones(self.ref_pmask.shape).cuda()
        else:
            reference_mask, _ = self._process_mask(self.ref_mask, self.ref_pmask, self.ref_segmask)

        edit_mask, _ = self._process_mask(self.input_mask, self.input_pmask, self.input_segmask)
        # tmp = cv2.dilate(edit_mask.squeeze().cpu().numpy(), self.kernel, iterations=2)
        # region_mask = torch.tensor(tmp).cuda()
        region_mask = edit_mask

        ma = torch.max(input_pmask)
        input_pmask[edit_mask == 1] = ma + 1
        return reference_mask, input_pmask, input_segmask, region_mask, ma
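    # Note: assigning the edited pixels a fresh panoptic ID (ma + 1) above gives
    # the region its own appearance slot; _edit() later overwrites that slot with
    # the reference region's mean feature.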
    def _add_object(self, input_mask, dilation_fac):
        """
        Manipulates the panoptic mask of the input image for adding objects.
        Args:
            input_mask (numpy array): Region where the new object needs to be added
            dilation_fac (float): Controls the edge-merging region when adding objects
        """
        input_pmask = self.input_pmask
        input_segmask = self.input_segmask
        reference_mask, obj_class = self._process_mask(self.ref_mask, self.ref_pmask, self.ref_segmask)

        tmp = cv2.dilate(input_mask['mask'][:, :, 0], self.kernel, iterations=int(dilation_fac))
        region = torch.tensor(tmp)
        region_mask = torch.zeros_like(region).cuda()
        region_mask[region > 127] = 1

        mask_ = torch.tensor(input_mask['mask'][:, :, 0])
        edit_mask = torch.zeros_like(mask_).cuda()
        edit_mask[mask_ > 127] = 1

        ma = torch.max(input_pmask)
        input_pmask[edit_mask == 1] = ma + 1
        print(obj_class)
        input_segmask[edit_mask == 1] = obj_class.long()
        return reference_mask, input_pmask, input_segmask, region_mask, ma
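    # Sketch of the dilation step above: with the 5x5 kernel, each cv2.dilate
    # iteration grows the drawn mask by about 2 pixels, so dilation_fac widens
    # region_mask and controls how far the new object blends into its surroundings.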
    def _edit(self, input_mask, ref_mask, dilation_fac=1, whole_ref=False, inter=1):
        """
        Entry point for all appearance-editing and add-object operations. Manipulates
        the appearance vectors and the structure based on user input.
        Args:
            input_mask (numpy array): Region in the input image that needs to be edited
            dilation_fac (float): Controls the edge-merging region when adding objects
            whole_ref (bool): Whether the complete reference image should be used
            inter (float): Interpolation factor between the reference appearance and the input appearance
        """
        input_img = (self.input_img / 127.5 - 1)
        input_img = torch.from_numpy(input_img.astype(np.float32)).cuda().unsqueeze(0).permute(0, 3, 1, 2)
        reference_img = (self.ref_img / 127.5 - 1)
        reference_img = torch.from_numpy(reference_img.astype(np.float32)).cuda().unsqueeze(0).permute(0, 3, 1, 2)

        if self.edit_operation == 'add_obj':
            reference_mask, input_pmask, input_segmask, region_mask, ma = self._add_object(input_mask, dilation_fac)
        elif self.edit_operation == 'edit_app':
            reference_mask, input_pmask, input_segmask, region_mask, ma = self._edit_app(whole_ref)

        # Concatenated features
        input_pmask = input_pmask.float().cuda().unsqueeze(0).unsqueeze(1)
        _, mean_feat_inpt_conc, one_hot_inpt_conc, _ = model.get_appearance(model.appearance_net_conc, model.app_layer_conc, input_img, input_pmask, return_all=True)

        reference_mask = reference_mask.float().cuda().unsqueeze(0).unsqueeze(1)
        _, mean_feat_ref_conc, _, _ = model.get_appearance(model.appearance_net_conc, model.app_layer_conc, reference_img, reference_mask, return_all=True)

        if isinstance(mean_feat_inpt_conc, list):
            appearance_conc = []
            for i in range(len(mean_feat_inpt_conc)):
                mean_feat_inpt_conc[i][:, ma + 1] = (1 - inter) * mean_feat_inpt_conc[i][:, ma + 1] + inter * mean_feat_ref_conc[i][:, 1]
                splatted_feat_conc = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_conc[i], one_hot_inpt_conc)
                splatted_feat_conc = torch.nn.functional.normalize(splatted_feat_conc)
                splatted_feat_conc = torch.nn.functional.interpolate(splatted_feat_conc, (self.H // 8, self.W // 8))
                appearance_conc.append(splatted_feat_conc)
            appearance_conc = torch.cat(appearance_conc, dim=1)
        else:
            print("manipulating")
            mean_feat_inpt_conc[:, ma + 1] = (1 - inter) * mean_feat_inpt_conc[:, ma + 1] + inter * mean_feat_ref_conc[:, 1]
            splatted_feat_conc = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_conc, one_hot_inpt_conc)
            appearance_conc = torch.nn.functional.normalize(splatted_feat_conc)  # L2 normalize
            appearance_conc = torch.nn.functional.interpolate(appearance_conc, (self.H // 8, self.W // 8))

        # Cross-attention features
        _, mean_feat_inpt_ca, one_hot_inpt_ca, _ = model.get_appearance(model.appearance_net_ca, model.app_layer_ca, input_img, input_pmask, return_all=True)
        _, mean_feat_ref_ca, _, _ = model.get_appearance(model.appearance_net_ca, model.app_layer_ca, reference_img, reference_mask, return_all=True)

        if isinstance(mean_feat_inpt_ca, list):
            appearance_ca = []
            for i in range(len(mean_feat_inpt_ca)):
                mean_feat_inpt_ca[i][:, ma + 1] = (1 - inter) * mean_feat_inpt_ca[i][:, ma + 1] + inter * mean_feat_ref_ca[i][:, 1]
                splatted_feat_ca = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_ca[i], one_hot_inpt_ca)
                splatted_feat_ca = torch.nn.functional.normalize(splatted_feat_ca)
                splatted_feat_ca = torch.nn.functional.interpolate(splatted_feat_ca, (self.H // 8, self.W // 8))
                appearance_ca.append(splatted_feat_ca)
        else:
            print("manipulating")
            mean_feat_inpt_ca[:, ma + 1] = (1 - inter) * mean_feat_inpt_ca[:, ma + 1] + inter * mean_feat_ref_ca[:, 1]
            splatted_feat_ca = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_ca, one_hot_inpt_ca)
            appearance_ca = torch.nn.functional.normalize(splatted_feat_ca)  # L2 normalize
            appearance_ca = torch.nn.functional.interpolate(appearance_ca, (self.H // 8, self.W // 8))

        input_segmask = ((input_segmask + 1) / 127.5 - 1.0).cuda().unsqueeze(0).unsqueeze(1)
        structure = torch.nn.functional.interpolate(input_segmask, (self.H // 8, self.W // 8))
        return structure, appearance_conc, appearance_ca, region_mask, input_img
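    # How the appearance swap above works (a reading of this code, not an official
    # spec): mean_feat_* holds one mean feature vector per panoptic region, shaped
    # (n, m, c). Row ma + 1 (the edited region) is linearly interpolated toward the
    # reference region's feature (row 1), then einsum('nmc,nmhw->nchw', feats,
    # one_hot) splats the per-region vectors back onto the pixel grid, so only the
    # edited region's appearance changes.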
    def _edit_obj_var(self, input_mask, ignore_structure):
        """
        Generates free-form variations of the region under the drawn mask by giving
        it fresh panoptic IDs; when ignore_structure is set, the region is treated
        as a single flat segment for the concatenated appearance features.
        """
        input_img = (self.input_img / 127.5 - 1)
        input_img = torch.from_numpy(input_img.astype(np.float32)).cuda().unsqueeze(0).permute(0, 3, 1, 2)
        input_pmask = self.input_pmask
        input_segmask = self.input_segmask
        ma = torch.max(input_pmask)

        mask_ = torch.tensor(input_mask['mask'][:, :, 0])
        edit_mask = torch.zeros_like(mask_).cuda()
        edit_mask[mask_ > 127] = 1

        tmp = edit_mask * (input_pmask + ma + 1)
        if ignore_structure:
            tmp = edit_mask
        input_pmask = tmp * edit_mask + (1 - edit_mask) * input_pmask
        input_pmask = input_pmask.float().cuda().unsqueeze(0).unsqueeze(1)
        mask_ca_feat = self.input_pmask.float().cuda().unsqueeze(0).unsqueeze(1) if ignore_structure else input_pmask
        print(torch.unique(mask_ca_feat))

        appearance_conc, _, _, _ = model.get_appearance(model.appearance_net_conc, model.app_layer_conc, input_img, input_pmask, return_all=True)
        appearance_ca = model.get_appearance(model.appearance_net_ca, model.app_layer_ca, input_img, mask_ca_feat)

        appearance_conc = torch.nn.functional.interpolate(appearance_conc, (self.H // 8, self.W // 8))
        appearance_ca = [torch.nn.functional.interpolate(ap, (self.H // 8, self.W // 8)) for ap in appearance_ca]

        input_segmask = ((input_segmask + 1) / 127.5 - 1.0).cuda().unsqueeze(0).unsqueeze(1)
        structure = torch.nn.functional.interpolate(input_segmask, (self.H // 8, self.W // 8))

        tmp = input_mask['mask'][:, :, 0]
        region = torch.tensor(tmp)
        mask = torch.zeros_like(region).cuda()
        mask[region > 127] = 1
        return structure, appearance_conc, appearance_ca, mask, input_img
    def get_caption(self, mask):
        """
        Generates a caption from a fixed template.
        Args:
            mask (numpy array): Region of the image for which the caption is generated
        """
        mask = mask['mask'][:, :, 0]
        region = torch.tensor(mask).cuda()
        mask = torch.zeros_like(region)
        mask[region > 127] = 1
        if torch.sum(mask) == 0:
            return ""

        c_ids = self.input_segmask * mask
        unique_ids, counts = torch.unique(c_ids, return_counts=True)
        c_id = int(unique_ids[torch.argmax(counts[1:]) + 1].cpu().detach().numpy())
        category = self.segment_model.metadata.stuff_classes[c_id]
        return self.base_prompt.format(category)
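    # Example: if the majority semantic class under the drawn mask is 'dog',
    # this returns base_prompt.format('dog') == 'A picture of dog'.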
    def save_result(self, input_mask, prompt, a_prompt, n_prompt,
                    ddim_steps, scale_s, scale_f, scale_t, seed, dilation_fac=1, inter=1,
                    free_form_obj_var=False, ignore_structure=False):
        """
        Saves the current results along with all their metadata.
        """
        meta_data = {}
        meta_data['prompt'] = prompt
        meta_data['a_prompt'] = a_prompt
        meta_data['n_prompt'] = n_prompt
        meta_data['seed'] = seed
        meta_data['ddim_steps'] = ddim_steps
        meta_data['scale_s'] = scale_s
        meta_data['scale_f'] = scale_f
        meta_data['scale_t'] = scale_t
        meta_data['inter'] = inter
        meta_data['dilation_fac'] = dilation_fac
        meta_data['edit_operation'] = self.edit_operation

        uuid = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        os.makedirs(f'{save_dir}/{uuid}')
        with open(f'{save_dir}/{uuid}/meta.json', "w") as outfile:
            json.dump(meta_data, outfile)

        cv2.imwrite(f'{save_dir}/{uuid}/input.png', self.input_img[:, :, ::-1])
        cv2.imwrite(f'{save_dir}/{uuid}/ref.png', self.ref_img[:, :, ::-1])
        if self.ref_mask is not None:
            cv2.imwrite(f'{save_dir}/{uuid}/ref_mask.png', self.ref_mask.cpu().squeeze().numpy() * 200)
        for i in range(len(self.results)):
            cv2.imwrite(f'{save_dir}/{uuid}/edit{i}.png', self.results[i][:, :, ::-1])

        if self.edit_operation == 'add_obj' or free_form_obj_var:
            cv2.imwrite(f'{save_dir}/{uuid}/input_mask.png', input_mask['mask'] * 200)
        else:
            cv2.imwrite(f'{save_dir}/{uuid}/input_mask.png', self.input_mask.cpu().squeeze().numpy() * 200)
        print("Saved results at", f'{save_dir}/{uuid}')
    def process(self, input_mask, ref_mask, prompt, a_prompt, n_prompt,
                num_samples, ddim_steps, guess_mode, strength,
                scale_s, scale_f, scale_t, seed, eta, dilation_fac=1, masking=True, whole_ref=False, inter=1,
                free_form_obj_var=False, ignore_structure=False):
        print(prompt)
        if free_form_obj_var:
            print("Free form")
            structure, appearance_conc, appearance_ca, mask, img = self._edit_obj_var(input_mask, ignore_structure)
        else:
            structure, appearance_conc, appearance_ca, mask, img = self._edit(input_mask, ref_mask, dilation_fac=dilation_fac,
                                                                              whole_ref=whole_ref, inter=inter)

        input_pmask = torch.nn.functional.interpolate(self.input_pmask.cuda().unsqueeze(0).unsqueeze(1).float(), (self.H // 8, self.W // 8))
        input_pmask = input_pmask.to(memory_format=torch.contiguous_format)

        if isinstance(appearance_ca, list):
            null_appearance_ca = [torch.zeros(a.shape).cuda() for a in appearance_ca]
        null_appearance_conc = torch.zeros(appearance_conc.shape).cuda()
        null_structure = torch.zeros(structure.shape).cuda() - 1

        null_control = [torch.cat([null_structure, napp, input_pmask * 0], dim=1) for napp in null_appearance_ca]
        structure_control = [torch.cat([structure, napp, input_pmask], dim=1) for napp in null_appearance_ca]
        full_control = [torch.cat([structure, napp, input_pmask], dim=1) for napp in appearance_ca]

        null_control.append(torch.cat([null_structure, null_appearance_conc, null_structure * 0], dim=1))
        structure_control.append(torch.cat([structure, null_appearance_conc, null_structure], dim=1))
        full_control.append(torch.cat([structure, appearance_conc, input_pmask], dim=1))

        null_control = [torch.cat([nc for _ in range(num_samples)], dim=0) for nc in null_control]
        structure_control = [torch.cat([sc for _ in range(num_samples)], dim=0) for sc in structure_control]
        full_control = [torch.cat([fc for _ in range(num_samples)], dim=0) for fc in full_control]

        # Masking for local edit
        if not masking:
            mask, x0 = None, None
        else:
            x0 = model.encode_first_stage(img)
            x0 = x0.sample() if isinstance(x0, DiagonalGaussianDistribution) else x0  # todo: check if we can set random number
            x0 = x0 * model.scale_factor
            mask = 1 - torch.tensor(mask).unsqueeze(0).unsqueeze(1).cuda()
            mask = torch.nn.functional.interpolate(mask.float(), x0.shape[2:]).float()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)
        scale = [scale_s, scale_f, scale_t]
        print(scale)

        if save_memory:
            model.low_vram_shift(is_diffusing=False)

        uc_cross = model.get_learned_conditioning([n_prompt] * num_samples)
        c_cross = model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)

        cond = {"c_concat": [null_control], "c_crossattn": [c_cross]}
        un_cond = {"c_concat": None if guess_mode else [null_control], "c_crossattn": [uc_cross]}
        un_cond_struct = {"c_concat": None if guess_mode else [structure_control], "c_crossattn": [uc_cross]}
        un_cond_struct_app = {"c_concat": None if guess_mode else [full_control], "c_crossattn": [uc_cross]}
        shape = (4, self.H // 8, self.W // 8)

        if save_memory:
            model.low_vram_shift(is_diffusing=True)

        model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12 < 0.01 but 0.826**12 > 0.01
        samples, _ = ddim_sampler.sample(ddim_steps, num_samples,
                                         shape, cond, verbose=False, eta=eta,
                                         unconditional_guidance_scale=scale, mask=mask, x0=x0,
                                         unconditional_conditioning=[un_cond, un_cond_struct, un_cond_struct_app])

        if save_memory:
            model.low_vram_shift(is_diffusing=False)

        x_samples = (model.decode_first_stage(samples) + 1) * 127.5
        x_samples = einops.rearrange(x_samples, 'b c h w -> b h w c').cpu().numpy().clip(0, 255).astype(np.uint8)

        results = [x_samples[i] for i in range(num_samples)]
        self.results = results
        return [] + results
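
# --- Minimal usage sketch (illustrative only) ---
# The Gradio Blocks layout of the original demo is not part of this section, so
# the wiring below is a hedged sketch: component names, labels, and the single
# 'edit_app' ImageComp instance are assumptions, not the original app's UI.
if __name__ == '__main__':
    editor = ImageComp(edit_operation='edit_app')
    with gr.Blocks() as demo:
        with gr.Row():
            inp = gr.Image(label='Input image')
            ref = gr.Image(label='Reference image')
        sel = gr.Image(label='Selected region')
        prompt = gr.Textbox(label='Prompt')

        # Segment images on upload; highlight the clicked object on select.
        inp.upload(editor.init_input_canvas, inputs=inp, outputs=inp)
        ref.upload(editor.init_ref_canvas, inputs=ref, outputs=ref)
        inp.select(editor.select_input_object, outputs=[sel, prompt])
        ref.select(editor.select_ref_object, outputs=sel)
    demo.launch()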