# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import random
from contextlib import nullcontext

import torch
import torch.nn.functional as F
from torch import nn

from scepter.modules.model.network.ldm import LatentDiffusion
from scepter.modules.model.registry import MODELS
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.distribute import we

from ..utils.basic_utils import (
    check_list_of_list,
    pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor,
    to_device,
    unpack_tensor_into_imagelist
)


class TextEmbedding(nn.Module):
    """Zero-initialized positional embeddings for the text identifier tokens,
    filled lazily in LdmACE.cond_stage_embeddings."""
    def __init__(self, embedding_shape):
        super().__init__()
        self.pos = nn.Parameter(data=torch.zeros(embedding_shape))


@MODELS.register_class()  # assumed: registers the model with the scepter MODELS registry (the import above is otherwise unused)
class LdmACE(LatentDiffusion):
    para_dict = LatentDiffusion.para_dict
    para_dict['DECODER_BIAS'] = {'value': 0, 'description': ''}

    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger=logger)
        self.interpolate_func = lambda x: (F.interpolate(
            x.unsqueeze(0),
            scale_factor=1 / self.size_factor,
            mode='nearest-exact') if x is not None else None)
        self.text_indentifers = cfg.get('TEXT_IDENTIFIER', [])
        self.use_text_pos_embeddings = cfg.get('USE_TEXT_POS_EMBEDDINGS',
                                               False)
        if self.use_text_pos_embeddings:
            self.text_position_embeddings = TextEmbedding(
                (10, 4096)).eval().requires_grad_(False)
        else:
            self.text_position_embeddings = None
        self.logger.info(self.model)

    def encode_first_stage(self, x, **kwargs):
        return [
            self.scale_factor *
            self.first_stage_model._encode(i.unsqueeze(0).to(torch.float16))
            for i in x
        ]

    def decode_first_stage(self, z):
        return [
            self.first_stage_model._decode(1. / self.scale_factor *
                                           i.to(torch.float16)) for i in z
        ]

    def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
        # Lazily fill the text positional embeddings from the encoded
        # identifier strings (only once, while the parameter is still zero).
        if self.use_text_pos_embeddings and not torch.sum(
                self.text_position_embeddings.pos) > 0:
            identifier_cont, identifier_cont_mask = getattr(
                self.cond_stage_model, 'encode')(self.text_indentifers,
                                                 return_mask=True)
            self.text_position_embeddings.load_state_dict(
                {'pos': identifier_cont[:, 0, :]})

        # Prepend the last (current-turn) caption to each sample's caption
        # list when edit images are present; otherwise keep only the last one.
        cont_, cont_mask_ = [], []
        for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
            if isinstance(pp, list):
                cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
                cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
            else:
                raise NotImplementedError
        return cont_, cont_mask_

    def limit_batch_data(self, batch_data_list, log_num):
        if log_num and log_num > 0:
            batch_data_list_limited = []
            for sub_data in batch_data_list:
                if sub_data is not None:
                    sub_data = sub_data[:log_num]
                batch_data_list_limited.append(sub_data)
            return batch_data_list_limited
        else:
            return batch_data_list

    def forward_train(self,
                      edit_image=[],
                      edit_image_mask=[],
                      image=None,
                      image_mask=None,
                      noise=None,
                      prompt=[],
                      **kwargs):
        '''
        Args:
            edit_image: list of list of edit images
            edit_image_mask: list of list of edit image masks
            image: list of target images
            image_mask: list of target image masks
            noise: defaults to None and is generated automatically
            prompt: list of list of text instructions
            **kwargs:
        Returns:
            dict with the mean diffusion loss and probe data
        '''
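        # Illustrative (hypothetical) input layout for a batch of two samples;
        # tensor names and shapes below are placeholders, not part of the API:
        #   prompt          = [['make the sky blue'], ['add a hat', 'crop to a square']]
        #   edit_image      = [[ref_3xHxW], []]          # per-sample reference images
        #   edit_image_mask = [[ref_mask_1xHxW], []]
        #   image           = [target_3xHxW, target_3xH2xW2]
        #   image_mask      = [target_mask_1xHxW, target_mask_1xH2xW2]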
        assert check_list_of_list(prompt) and check_list_of_list(
            edit_image) and check_list_of_list(edit_image_mask)
        assert len(edit_image) == len(edit_image_mask) == len(prompt)
        assert self.cond_stage_model is not None
        gc_seg = kwargs.pop('gc_seg', [])
        gc_seg = int(gc_seg[0]) if len(gc_seg) > 0 else 0
        context = {}

        # process image
        image = to_device(image)
        x_start = self.encode_first_stage(image, **kwargs)
        x_start, x_shapes = pack_imagelist_into_tensor(x_start)  # B, C, L
        n, _, _ = x_start.shape
        t = torch.randint(0, self.num_timesteps, (n, ),
                          device=x_start.device).long()
        context['x_shapes'] = x_shapes

        # process image mask
        image_mask = to_device(image_mask, strict=False)
        context['x_mask'] = [self.interpolate_func(i) for i in image_mask
                             ] if image_mask is not None else [None] * n

        # process text
        # with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
        prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
        try:
            cont, cont_mask = getattr(self.cond_stage_model,
                                      'encode_list')(prompt_, return_mask=True)
        except Exception as e:
            print(e, prompt_)
            # re-raise after logging the offending prompts; otherwise cont and
            # cont_mask would be undefined below
            raise
        cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
                                                     cont_mask)
        context['crossattn'] = cont

        # process edit image & edit image mask
        edit_image = [to_device(i, strict=False) for i in edit_image]
        edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
        e_img, e_mask = [], []
        for u, m in zip(edit_image, edit_image_mask):
            if m is None:
                m = [None] * len(u) if u is not None else [None]
            e_img.append(
                self.encode_first_stage(u, **kwargs) if u is not None else u)
            e_mask.append([
                self.interpolate_func(i) if i is not None else None for i in m
            ])
        context['edit'], context['edit_mask'] = e_img, e_mask

        # process loss
        loss = self.diffusion.loss(
            x_0=x_start,
            t=t,
            noise=noise,
            model=self.model,
            model_kwargs={
                'cond': context,
                'mask': cont_mask,
                'gc_seg': gc_seg,
                'text_position_embeddings':
                    self.text_position_embeddings.pos if hasattr(
                        self.text_position_embeddings, 'pos') else None
            },
            **kwargs)
        loss = loss.mean()
        ret = {'loss': loss, 'probe_data': {'prompt': prompt}}
        return ret

    def forward_test(self,
                     edit_image=[],
                     edit_image_mask=[],
                     image=None,
                     image_mask=None,
                     prompt=[],
                     n_prompt=[],
                     sampler='ddim',
                     sample_steps=20,
                     guide_scale=4.5,
                     guide_rescale=0.5,
                     log_num=-1,
                     seed=2024,
                     **kwargs):
        assert check_list_of_list(prompt) and check_list_of_list(
            edit_image) and check_list_of_list(edit_image_mask)
        assert len(edit_image) == len(edit_image_mask) == len(prompt)
        assert self.cond_stage_model is not None
        # gc_seg is unused
        kwargs.pop('gc_seg', -1)

        # prepare data
        context, null_context = {}, {}
        prompt, n_prompt, image, image_mask, edit_image, edit_image_mask = self.limit_batch_data(
            [prompt, n_prompt, image, image_mask, edit_image, edit_image_mask],
            log_num)
        g = torch.Generator(device=we.device_id)
        seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
        g.manual_seed(seed)

        # The passed-in n_prompt is discarded; the negative prompt is rebuilt
        # from prompt with only the last (current-turn) instruction blanked out.
        n_prompt = copy.deepcopy(prompt)
        for nn_p_id, nn_p in enumerate(n_prompt):
            if isinstance(nn_p, str):
                n_prompt[nn_p_id] = ['']
            elif isinstance(nn_p, list):
                n_prompt[nn_p_id][-1] = ''
            else:
                raise NotImplementedError

        # process image
        image = to_device(image)
        x = self.encode_first_stage(image, **kwargs)
        noise = [
            torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
            for i in x
        ]
        noise, x_shapes = pack_imagelist_into_tensor(noise)
        context['x_shapes'] = null_context['x_shapes'] = x_shapes

        # process image mask
        image_mask = to_device(image_mask, strict=False)
        cond_mask = [self.interpolate_func(i) for i in image_mask
                     ] if image_mask is not None else [None] * len(image)
        context['x_mask'] = null_context['x_mask'] = cond_mask

        # process text
        # with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
        prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
        cont, cont_mask = getattr(self.cond_stage_model,
                                  'encode_list')(prompt_, return_mask=True)
        cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
                                                     cont_mask)
        null_cont, null_cont_mask = getattr(self.cond_stage_model,
                                            'encode_list')(n_prompt,
                                                           return_mask=True)
        null_cont, null_cont_mask = self.cond_stage_embeddings(
            prompt, edit_image, null_cont, null_cont_mask)
        context['crossattn'] = cont
        null_context['crossattn'] = null_cont

        # process edit image & edit image mask
        edit_image = [to_device(i, strict=False) for i in edit_image]
        edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
        e_img, e_mask = [], []
        for u, m in zip(edit_image, edit_image_mask):
            if u is None:
                continue
            if m is None:
                m = [None] * len(u)
            e_img.append(self.encode_first_stage(u, **kwargs))
            e_mask.append([self.interpolate_func(i) for i in m])
        null_context['edit'] = context['edit'] = e_img
        null_context['edit_mask'] = context['edit_mask'] = e_mask
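
        # Classifier-free guidance: when guide_scale > 1 the sampler receives a
        # pair of model kwargs (conditional context, null/empty-prompt context);
        # otherwise only the conditional branch is evaluated.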
        # process sample
        model = self.model_ema if self.use_ema and self.eval_ema else self.model
        embedding_context = model.no_sync if isinstance(
            model,
            torch.distributed.fsdp.FullyShardedDataParallel) else nullcontext
        with embedding_context():
            samples = self.diffusion.sample(
                sampler=sampler,
                noise=noise,
                model=model,
                model_kwargs=[{
                    'cond': context,
                    'mask': cont_mask,
                    'text_position_embeddings':
                        self.text_position_embeddings.pos if hasattr(
                            self.text_position_embeddings, 'pos') else None
                }, {
                    'cond': null_context,
                    'mask': null_cont_mask,
                    'text_position_embeddings':
                        self.text_position_embeddings.pos if hasattr(
                            self.text_position_embeddings, 'pos') else None
                }] if guide_scale is not None and guide_scale > 1 else {
                    'cond': context,
                    'mask': cont_mask,
                    'text_position_embeddings':
                        self.text_position_embeddings.pos if hasattr(
                            self.text_position_embeddings, 'pos') else None
                },
                steps=sample_steps,
                guide_scale=guide_scale,
                guide_rescale=guide_rescale,
                show_progress=True,
                **kwargs)

        samples = unpack_tensor_into_imagelist(samples, x_shapes)
        x_samples = self.decode_first_stage(samples)
        outputs = list()
        for i in range(len(prompt)):
            rec_img = torch.clamp(
                (x_samples[i] + 1.0) / 2.0 + self.decoder_bias / 255,
                min=0.0,
                max=1.0)
            rec_img = rec_img.squeeze(0)
            edit_imgs, edit_img_masks = [], []
            if edit_image is not None and edit_image[i] is not None:
                if edit_image_mask[i] is None:
                    edit_image_mask[i] = [None] * len(edit_image[i])
                for edit_img, edit_mask in zip(edit_image[i],
                                               edit_image_mask[i]):
                    edit_img = torch.clamp((edit_img + 1.0) / 2.0,
                                           min=0.0,
                                           max=1.0)
                    edit_imgs.append(edit_img.squeeze(0))
                    if edit_mask is None:
                        edit_mask = torch.ones_like(edit_img[[0], :, :])
                    edit_img_masks.append(edit_mask)
            one_tup = {
                'reconstruct_image': rec_img,
                'instruction': prompt[i],
                'edit_image': edit_imgs if len(edit_imgs) > 0 else None,
                'edit_mask': edit_img_masks if len(edit_imgs) > 0 else None
            }
            if image is not None:
                if image_mask is None:
                    image_mask = [None] * len(image)
                ori_img = torch.clamp((image[i] + 1.0) / 2.0, min=0.0, max=1.0)
                one_tup['target_image'] = ori_img.squeeze(0)
                one_tup['target_mask'] = image_mask[i] if image_mask[
                    i] is not None else torch.ones_like(ori_img[[0], :, :])
            outputs.append(one_tup)
        return outputs

    @staticmethod
    def get_config_template():
        return dict_to_yaml('MODEL',
                            __class__.__name__,
                            LdmACE.para_dict,
                            set_name=True)
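

# Minimal usage sketch (assumptions: the surrounding scepter framework builds the
# model from a YAML config that fills LdmACE.para_dict via the registry; tensor
# names below are hypothetical placeholders, and images are expected in [-1, 1]):
#
#   model = MODELS.build(cfg.MODEL, logger=logger)  # registry build call assumed
#   outputs = model.forward_test(prompt=[['make the sky blue']],
#                                image=[target_3xHxW],
#                                image_mask=[mask_1xHxW],
#                                edit_image=[[ref_3xHxW]],
#                                edit_image_mask=[[ref_mask_1xHxW]],
#                                sampler='ddim',
#                                sample_steps=20,
#                                guide_scale=4.5)
#   rec = outputs[0]['reconstruct_image']  # float tensor in [0, 1]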