import inspect
import os
from typing import Union
import PIL
import numpy as np
import torch
from diffusers.utils.torch_utils import randn_tensor
from utils import (check_inputs_maskfree, get_time_embedding, numpy_to_pil, prepare_image, compute_vae_encodings)
from ddpm import DDPMSampler
from tqdm import tqdm
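

# CatVTON-style, mask-free pix2pix try-on pipeline: the person image and the
# garment image are VAE-encoded, concatenated side by side along the width
# axis, denoised by a diffusion UNet with classifier-free guidance, and the
# person half of the result is decoded back to RGB.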
class CatVTONPix2PixPipeline:
    def __init__(
        self,
        weight_dtype=torch.float32,
        device='cuda',
        compile=False,
        skip_safety_check=True,
        use_tf32=True,
        models=None,
    ):
        self.device = device
        self.weight_dtype = weight_dtype
        self.skip_safety_check = skip_safety_check
        self.models = models if models is not None else {}
        self.generator = torch.Generator(device=device)
        self.noise_scheduler = DDPMSampler(generator=self.generator)
        # self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device, dtype=weight_dtype)
        self.encoder = self.models.get('encoder', None)
        self.decoder = self.models.get('decoder', None)
        self.unet = self.models.get('diffusion', None)
        # Enable TF32 for faster matmuls on Ampere GPUs (A100 and RTX 30 series).
        if use_tf32:
            torch.set_float32_matmul_precision("high")
            torch.backends.cuda.matmul.allow_tf32 = True

    @torch.no_grad()
    def __call__(
        self,
        image: Union[PIL.Image.Image, torch.Tensor],
        condition_image: Union[PIL.Image.Image, torch.Tensor],
        num_inference_steps: int = 50,
        guidance_scale: float = 2.5,
        height: int = 1024,
        width: int = 768,
        generator=None,
        eta=1.0,
        **kwargs
    ):
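        """Mask-free virtual try-on: `image` is the person photo and
        `condition_image` is the garment to transfer; returns a list of PIL images."""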
        concat_dim = -1  # concatenate along the width axis
        # Prepare inputs as tensors
        image, condition_image = check_inputs_maskfree(image, condition_image, width, height)
        image = prepare_image(image).to(self.device, dtype=self.weight_dtype)
        condition_image = prepare_image(condition_image).to(self.device, dtype=self.weight_dtype)
        # Encode both images into VAE latent space
        image_latent = compute_vae_encodings(image, self.encoder)
        condition_latent = compute_vae_encodings(condition_image, self.encoder)
        del image, condition_image
        # Concatenate latents side by side
        condition_latent_concat = torch.cat([image_latent, condition_latent], dim=concat_dim)
        # Prepare noise
        latents = randn_tensor(
            condition_latent_concat.shape,
            generator=generator,
            device=condition_latent_concat.device,
            dtype=self.weight_dtype,
        )
        # Prepare timesteps
        self.noise_scheduler.set_inference_timesteps(num_inference_steps)
        timesteps = self.noise_scheduler.timesteps
        # latents = latents * self.noise_scheduler.init_noise_sigma
        latents = self.noise_scheduler.add_noise(latents, timesteps[0])
        # Classifier-Free Guidance
        if do_classifier_free_guidance := (guidance_scale > 1.0):
            condition_latent_concat = torch.cat(
                [
                    torch.cat([image_latent, torch.zeros_like(condition_latent)], dim=concat_dim),
                    condition_latent_concat,
                ]
            )

        num_warmup_steps = 0  # For simple DDPM, no warmup needed
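        # The loop below runs the UNet on an [unconditional, conditional] batch when CFG
        # is enabled; the unconditional branch sees the garment latent zeroed out above,
        # so guidance steers the prediction toward the garment condition.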
        with tqdm(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier-free guidance
                latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents)
                # concatenate the conditioning latents along the channel dimension for the pix2pix model
                p2p_latent_model_input = torch.cat([latent_model_input, condition_latent_concat], dim=1)
                # predict the noise residual
                timestep = t.repeat(p2p_latent_model_input.shape[0])
                time_embedding = get_time_embedding(timestep).to(self.device, dtype=self.weight_dtype)
                noise_pred = self.unet(
                    p2p_latent_model_input,
                    time_embedding
                )
                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )
                # compute the previous noisy sample x_t -> x_t-1
                latents = self.noise_scheduler.step(
                    t, latents, noise_pred
                )
                # update the progress bar
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps
                ):
                    progress_bar.update()
        # Decode the final latents: keep the person half of the concatenated canvas
        latents = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]
        # latents = 1 / self.vae.config.scaling_factor * latents
        # image = self.vae.decode(latents.to(self.device, dtype=self.weight_dtype)).sample
        image = self.decoder(latents.to(self.device, dtype=self.weight_dtype))
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = numpy_to_pil(image)
        return image
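

# Minimal usage sketch (not part of the pipeline itself). It assumes the `models`
# dict is populated with VAE encoder/decoder and diffusion UNet modules constructed
# elsewhere in the repo; `build_models` below is a hypothetical placeholder, not a
# real helper in this codebase.
#
#   from PIL import Image
#
#   models = build_models()  # -> {'encoder': ..., 'decoder': ..., 'diffusion': ...}
#   pipe = CatVTONPix2PixPipeline(weight_dtype=torch.float16, device="cuda", models=models)
#   person = Image.open("person.jpg").convert("RGB")
#   garment = Image.open("garment.jpg").convert("RGB")
#   result = pipe(person, garment, num_inference_steps=50, guidance_scale=2.5)[0]
#   result.save("tryon.png")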