from typing import Union

import PIL.Image
import torch
from diffusers.utils.torch_utils import randn_tensor
from tqdm import tqdm

from ddpm import DDPMSampler
from utils import (
    check_inputs_maskfree,
    get_time_embedding,
    numpy_to_pil,
    prepare_image,
    compute_vae_encodings,
)

class CatVTONPix2PixPipeline:
    def __init__(
        self,
        weight_dtype=torch.float32,
        device='cpu',
        compile=False,  # accepted but currently unused
        skip_safety_check=True,
        use_tf32=True,
        models=None,  # avoid a mutable default argument
    ):
        self.device = device
        self.weight_dtype = weight_dtype
        self.skip_safety_check = skip_safety_check
        self.models = models if models is not None else {}
        self.generator = torch.Generator(device=device)
        self.noise_scheduler = DDPMSampler(generator=self.generator)
        # self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device, dtype=weight_dtype)
        self.encoder = self.models.get('encoder', None)
        self.decoder = self.models.get('decoder', None)
        self.unet = self.models.get('diffusion', None)
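        # Expected interfaces for the three modules, inferred from how they are
        # used in __call__ below (the concrete classes live elsewhere in this repo):
        #   encoder:   image tensor (B, 3, H, W)      -> latents (B, 4, H/8, W/8)
        #   decoder:   latents (B, 4, H/8, W/8)       -> image tensor (B, 3, H, W)
        #   diffusion: (latents, time_embedding)      -> predicted noise residual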
        # Enable TF32 for faster matmuls on Ampere GPUs (A100 and RTX 30 series).
        if use_tf32:
            torch.set_float32_matmul_precision("high")
            torch.backends.cuda.matmul.allow_tf32 = True

    def __call__(
        self,
        image: Union[PIL.Image.Image, torch.Tensor],
        condition_image: Union[PIL.Image.Image, torch.Tensor],
        num_inference_steps: int = 50,
        guidance_scale: float = 2.5,
        height: int = 1024,
        width: int = 768,
        generator=None,
        eta=1.0,
        **kwargs
    ):
        concat_dim = -1  # concatenate person and garment latents along the width (last) axis

        # Prepare inputs as tensors
        image, condition_image = check_inputs_maskfree(image, condition_image, width, height)
        image = prepare_image(image).to(self.device, dtype=self.weight_dtype)
        condition_image = prepare_image(condition_image).to(self.device, dtype=self.weight_dtype)

        # Encode both images to VAE latents
        image_latent = compute_vae_encodings(image, self.encoder)
        condition_latent = compute_vae_encodings(condition_image, self.encoder)
        del image, condition_image

        # Concatenate latents side by side
        condition_latent_concat = torch.cat([image_latent, condition_latent], dim=concat_dim)
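        # Shape sketch, assuming the 8x-downsampling, 4-channel SD VAE and the
        # default height=1024, width=768:
        #   image_latent, condition_latent: (B, 4, 128, 96)
        #   condition_latent_concat:        (B, 4, 128, 192)  # person | garment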

        # Prepare noise
        latents = randn_tensor(
            condition_latent_concat.shape,
            generator=generator,
            device=condition_latent_concat.device,
            dtype=self.weight_dtype,
        )

        # Prepare timesteps
        self.noise_scheduler.set_inference_timesteps(num_inference_steps)
        timesteps = self.noise_scheduler.timesteps
        # latents = latents * self.noise_scheduler.init_noise_sigma
        # Scale the Gaussian noise to the terminal timestep's noise level
        latents = self.noise_scheduler.add_noise(latents, timesteps[0])
        # Classifier-Free Guidance: build an unconditional branch in which the
        # garment latent is zeroed out, batched ahead of the conditional branch.
        if do_classifier_free_guidance := (guidance_scale > 1.0):
            condition_latent_concat = torch.cat(
                [
                    torch.cat([image_latent, torch.zeros_like(condition_latent)], dim=concat_dim),
                    condition_latent_concat,
                ]
            )
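        # At every denoising step the model then runs on both branches in one
        # batch, and the two predictions are blended as
        #   noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
        # so guidance_scale > 1 pushes the result toward the garment condition.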
        num_warmup_steps = 0  # for the simple DDPM sampler, no warmup is needed
        with tqdm(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Expand the latents if we are doing classifier-free guidance
                latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents)
                # Prepare the input for the pix2pix model: noisy latents
                # channel-concatenated with the (clean) condition latents
                p2p_latent_model_input = torch.cat([latent_model_input, condition_latent_concat], dim=1)
                # Predict the noise residual
                timestep = t.repeat(p2p_latent_model_input.shape[0])
                time_embedding = get_time_embedding(timestep).to(self.device, dtype=self.weight_dtype)
                noise_pred = self.unet(
                    p2p_latent_model_input,
                    time_embedding,
                )
                # Perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )
                # Compute the previous noisy sample x_t -> x_t-1
                latents = self.noise_scheduler.step(
                    t, latents, noise_pred
                )
                # Update the progress bar (with num_warmup_steps == 0 this
                # condition is always true)
                if i == len(timesteps) - 1 or (i + 1) > num_warmup_steps:
                    progress_bar.update()
        # Keep only the person half of the latents, then decode
        latents = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]
        # latents = 1 / self.vae.config.scaling_factor * latents
        # image = self.vae.decode(latents.to(self.device, dtype=self.weight_dtype)).sample
        image = self.decoder(latents.to(self.device, dtype=self.weight_dtype))
        image = (image / 2 + 0.5).clamp(0, 1)
        # We always cast to float32, as this does not cause significant overhead
        # and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = numpy_to_pil(image)
        return image
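

# ---------------------------------------------------------------------------
# Usage sketch: a minimal example of driving the pipeline end to end. The
# `load_models` helper and the file paths below are hypothetical stand-ins;
# building the actual 'encoder', 'decoder', and 'diffusion' modules is done
# elsewhere in this repo.
if __name__ == "__main__":
    models = load_models()  # hypothetical helper returning the three modules
    pipe = CatVTONPix2PixPipeline(
        weight_dtype=torch.float16,
        device="cuda",
        models=models,
    )
    person = PIL.Image.open("person.jpg").convert("RGB")    # hypothetical input
    garment = PIL.Image.open("garment.jpg").convert("RGB")  # hypothetical input
    result = pipe(
        person,
        garment,
        num_inference_steps=50,
        guidance_scale=2.5,
        height=1024,
        width=768,
        generator=torch.Generator(device="cuda").manual_seed(42),
    )
    result[0].save("tryon.png")  # __call__ returns a list of PIL images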