In [1]:
# Download the SD-1.5 inpainting checkpoint once; later cells load it from this path.
import os
import urllib.request

CKPT_PATH = "sd-v1-5-inpainting.ckpt"
CKPT_URL = (
    "https://huggingface.co/sd-legacy/stable-diffusion-inpainting/"
    "resolve/main/sd-v1-5-inpainting.ckpt"
)

if not os.path.exists(CKPT_PATH):
    # Download to a temporary name and rename only on success, so an
    # interrupted download never leaves a truncated checkpoint that the
    # existence check above would later accept as complete.
    tmp_path = CKPT_PATH + ".part"
    urllib.request.urlretrieve(CKPT_URL, tmp_path)
    os.replace(tmp_path, CKPT_PATH)  # atomic on the same filesystem
else:
    print("Model already downloaded.")

Model already downloaded.


In [2]:
import load_model

# Load the base SD-1.5 inpainting checkpoint together with the fine-tuned
# weights; `models` is the mapping handed to the pipeline in a later cell.
models = load_model.preload_models_from_standard_weights(
    ckpt_path="sd-v1-5-inpainting.ckpt",
    device="cuda",
    finetune_weights_path="finetuned_weights.safetensors",
)

Loaded finetuned weights from finetuned_weights.safetensors
Loading 0.in_proj.weight
Loading 0.out_proj.weight
Loading 0.out_proj.bias
Loading 8.in_proj.weight
Loading 8.out_proj.weight
Loading 8.out_proj.bias
Loading 16.in_proj.weight
Loading 16.out_proj.weight
Loading 16.out_proj.bias
Loading 24.in_proj.weight
Loading 24.out_proj.weight
Loading 24.out_proj.bias
Loading 32.in_proj.weight
Loading 32.out_proj.weight
Loading 32.out_proj.bias
Loading 40.in_proj.weight
Loading 40.out_proj.weight
Loading 40.out_proj.bias
Loading 48.in_proj.weight
Loading 48.out_proj.weight
Loading 48.out_proj.bias
Loading 56.in_proj.weight
Loading 56.out_proj.weight
Loading 56.out_proj.bias
Loading 64.in_proj.weight
Loading 64.out_proj.weight
Loading 64.out_proj.bias
Loading 72.in_proj.weight
Loading 72.out_proj.weight
Loading 72.out_proj.bias
Loading 80.in_proj.weight
Loading 80.out_proj.weight
Loading 80.out_proj.bias
Loading 88.in_proj.weight
Loading 88.out_proj.weight
Loading 88.out_proj.bias
Loading 96

In [3]:
import inspect
import os
from typing import Union

import PIL
import numpy as np
import torch
import tqdm
from diffusers.utils.torch_utils import randn_tensor

from utils import (check_inputs, get_time_embedding, numpy_to_pil, prepare_image,
 prepare_mask_image, compute_vae_encodings)
from ddpm import DDPMSampler

class CatVTONPipeline:
    """Mask-based virtual try-on inference pipeline.

    Wraps preloaded encoder / decoder / diffusion models and runs a DDPM
    denoising loop over a latent in which the masked person latent and the
    garment (condition) latent are concatenated along ``concat_dim``.
    """

    def __init__(
        self,
        weight_dtype=torch.float32,
        device='cuda',
        compile=False,
        skip_safety_check=True,
        use_tf32=True,
        models=None,
    ):
        """Store the models and build the noise scheduler.

        Args:
            weight_dtype: dtype that inputs and latents are cast to before the models run.
            device: torch device string used for the scheduler generator and all tensors.
            compile: accepted for interface compatibility; not used by this class.
            skip_safety_check: stored on the instance; not consulted by this class.
            use_tf32: when True, enables TF32 matmuls via process-wide torch settings.
            models: mapping with optional 'encoder', 'decoder' and 'diffusion'
                entries (missing entries become None). Defaults to an empty mapping.
        """
        # A fresh dict per instance instead of a shared mutable default.
        if models is None:
            models = {}
        self.device = device
        self.weight_dtype = weight_dtype
        self.skip_safety_check = skip_safety_check
        self.models = models

        # NOTE(review): this generator is never seeded, and the `generator`
        # passed to __call__ only seeds the initial latents. Noise drawn inside
        # the scheduler (add_noise / step) presumably comes from this one —
        # confirm in ddpm.DDPMSampler if seed-controlled reproducibility matters.
        self.generator = torch.Generator(device=device)
        self.noise_scheduler = DDPMSampler(generator=self.generator)
        self.encoder = models.get('encoder', None)
        self.decoder = models.get('decoder', None)
        self.unet = models.get('diffusion', None)

        # Enable TF32 for faster matmuls on Ampere GPUs (A100, RTX 30 series).
        if use_tf32:
            torch.set_float32_matmul_precision("high")
            torch.backends.cuda.matmul.allow_tf32 = True

    @torch.no_grad()
    def __call__(
        self,
        # String annotations: `import PIL` alone does not load `PIL.Image`,
        # so evaluating it eagerly depends on some other import side effect.
        image: Union["PIL.Image.Image", torch.Tensor],
        condition_image: Union["PIL.Image.Image", torch.Tensor],
        mask: Union["PIL.Image.Image", torch.Tensor],
        num_inference_steps: int = 50,
        guidance_scale: float = 2.5,
        height: int = 1024,
        width: int = 768,
        generator=None,
        eta=1.0,
        **kwargs
    ):
        """Generate try-on images of `condition_image` on the person in `image`.

        Args:
            image: person image; pixels where ``mask >= 0.5`` are zeroed before encoding.
            condition_image: garment image, encoded and concatenated beside the person latent.
            mask: inpainting mask, downsampled (nearest) to latent resolution.
            num_inference_steps: number of scheduler timesteps.
            guidance_scale: classifier-free guidance weight; CFG runs only when > 1.0.
            height, width: target size forwarded to ``check_inputs``.
            generator: torch.Generator used for the initial latent noise.
            eta, **kwargs: accepted for interface compatibility; unused.

        Returns:
            The result of ``numpy_to_pil`` on the decoded batch (the notebook's
            caller treats it as a sequence of PIL images — confirm in utils).
        """
        # Local import: the module-level `import tqdm` binds the *module*, which
        # is not callable; this previously worked only because a later cell
        # rebound the global name via `from tqdm import tqdm`.
        from tqdm import tqdm as progress_tqdm

        concat_dim = -2  # FIXME: hard-coded y-axis concatenation

        # Prepare inputs as tensors on the target device / dtype.
        image, condition_image, mask = check_inputs(image, condition_image, mask, width, height)
        image = prepare_image(image).to(self.device, dtype=self.weight_dtype)
        condition_image = prepare_image(condition_image).to(self.device, dtype=self.weight_dtype)
        mask = prepare_mask_image(mask).to(self.device, dtype=self.weight_dtype)

        # Keep only pixels outside the mask (mask < 0.5).
        masked_image = image * (mask < 0.5)

        # VAE encoding; the mask is resized to the latent's spatial size.
        masked_latent = compute_vae_encodings(masked_image, self.encoder)
        condition_latent = compute_vae_encodings(condition_image, self.encoder)
        mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode="nearest")
        del image, mask, condition_image

        # Person and garment latents side by side along `concat_dim`;
        # the garment half gets an all-zero mask.
        masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=concat_dim)
        mask_latent_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)

        # Initial noise, drawn from `generator` when one is supplied.
        latents = randn_tensor(
            masked_latent_concat.shape,
            generator=generator,
            device=masked_latent_concat.device,
            dtype=self.weight_dtype,
        )

        # Timesteps, and noising of the start latents to the first timestep.
        self.noise_scheduler.set_inference_timesteps(num_inference_steps)
        timesteps = self.noise_scheduler.timesteps
        latents = self.noise_scheduler.add_noise(latents, timesteps[0])

        # Classifier-free guidance: batch order is [unconditional (garment
        # latent zeroed), conditional]; `chunk(2)` below relies on this order.
        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            masked_latent_concat = torch.cat(
                [
                    torch.cat([masked_latent, torch.zeros_like(condition_latent)], dim=concat_dim),
                    masked_latent_concat,
                ]
            )
            mask_latent_concat = torch.cat([mask_latent_concat] * 2)

        with progress_tqdm(total=num_inference_steps) as progress_bar:
            for t in timesteps:
                # Duplicate the latents for the two CFG branches.
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                # Inpainting model input: noisy latents, mask and masked latents stacked on dim=1.
                inpainting_latent_model_input = torch.cat(
                    [latent_model_input, mask_latent_concat, masked_latent_concat], dim=1
                ).to(self.device, dtype=self.weight_dtype)

                # One timestep value per batch element.
                timestep = t.repeat(inpainting_latent_model_input.shape[0])
                time_embedding = get_time_embedding(timestep).to(self.device, dtype=self.weight_dtype)

                noise_pred = self.unet(inpainting_latent_model_input, time_embedding)

                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # x_t -> x_{t-1}
                latents = self.noise_scheduler.step(t, latents, noise_pred)
                progress_bar.update()

        # Keep the first (person) half along `concat_dim`, then decode.
        latents = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]
        image = self.decoder(latents.to(self.device, dtype=self.weight_dtype))
        # Rescale, assuming decoder output in [-1, 1], to [0, 1].
        image = (image / 2 + 0.5).clamp(0, 1)
        # Always cast to float32: negligible overhead and compatible with bfloat16.
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = numpy_to_pil(image)

        return image


 from .autonotebook import tqdm as notebook_tqdm


In [4]:
import os
import numpy as np
import torch
import argparse
from torch.utils.data import Dataset, DataLoader
from VITON_Dataset import VITONHDTestDataset
from diffusers.image_processor import VaeImageProcessor
from tqdm import tqdm
from PIL import Image, ImageFilter

from utils import repaint, to_pil_image

@torch.no_grad()
def main():
    """Run try-on inference over the VITON-HD test set and save the results.

    Uses two notebook globals from earlier cells: `models` (the preloaded
    networks) and `CatVTONPipeline`. Each result is written to
    ``<output_dir>/<dataset_name>-<height>/<paired|unpaired>/<person_name>``.
    """
    # Hard-coded run configuration standing in for a CLI parser.
    args = argparse.Namespace(
        dataset_name="vitonhd",
        data_root_path="./sample_dataset",
        output_dir="./mask-based-output",
        seed=555,
        batch_size=1,
        num_inference_steps=50,
        guidance_scale=2.5,
        width=384,
        height=512,
        repaint=True,
        eval_pair=False,
        concat_eval_results=True,
        allow_tf32=True,
        dataloader_num_workers=4,
        mixed_precision='no',
        concat_axis='y',
        enable_condition_noise=True,
        is_train=False,
    )

    # Pipeline
    weight_dtype = {
        "no": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }[args.mixed_precision]
    pipeline = CatVTONPipeline(
        weight_dtype=weight_dtype,
        device="cuda",
        skip_safety_check=True,
        use_tf32=args.allow_tf32,  # honour the config flag rather than the constructor default
        models=models,
    )

    # Dataset
    if args.dataset_name == "vitonhd":
        dataset = VITONHDTestDataset(args)
    else:
        # Previously referenced `args.dataset`, which does not exist and
        # raised AttributeError instead of this intended ValueError.
        raise ValueError(f"Invalid dataset name {args.dataset_name}.")
    print(f"Dataset {args.dataset_name} loaded, total {len(dataset)} pairs.")
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.dataloader_num_workers,
    )

    # Inference
    generator = torch.Generator(device='cuda').manual_seed(args.seed)
    args.output_dir = os.path.join(
        args.output_dir,
        f"{args.dataset_name}-{args.height}",
        "paired" if args.eval_pair else "unpaired",
    )
    os.makedirs(args.output_dir, exist_ok=True)

    for batch in tqdm(dataloader):
        person_images = batch['person']
        cloth_images = batch['cloth']
        masks = batch['mask']

        results = pipeline(
            person_images,
            cloth_images,
            masks,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=args.guidance_scale,
            height=args.height,
            width=args.width,
            generator=generator,
        )

        if args.concat_eval_results or args.repaint:
            person_images = to_pil_image(person_images)
            cloth_images = to_pil_image(cloth_images)
            masks = to_pil_image(masks)

        for i, result in enumerate(results):
            person_name = batch['person_name'][i]
            output_path = os.path.join(args.output_dir, person_name)
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            if args.repaint:
                record = dataset.data[batch['index'][i]]
                # Original person/mask from disk, resized to the result size;
                # `with` closes the underlying files once resized copies exist.
                with Image.open(record['person']) as person_file:
                    person_image = person_file.resize(result.size, Image.LANCZOS)
                with Image.open(record['mask']) as mask_file:
                    mask = mask_file.resize(result.size, Image.NEAREST)
                result = repaint(person_image, mask, result)

            if args.concat_eval_results:
                # Side-by-side strip: person | cloth | result.
                w, h = result.size
                concated_result = Image.new('RGB', (w * 3, h))
                concated_result.paste(person_images[i], (0, 0))
                concated_result.paste(cloth_images[i], (w, 0))
                concated_result.paste(result, (w * 2, 0))
                result = concated_result

            result.save(output_path)


# In a Jupyter kernel `__name__` is "__main__", so running this cell calls main().
if __name__ == "__main__":
    main()

Dataset vitonhd loaded, total 20 pairs.


100%|██████████| 50/50 [00:07<00:00, 7.04it/s]
100%|██████████| 50/50 [00:06<00:00, 7.32it/s]
100%|██████████| 50/50 [00:07<00:00, 7.01it/s]
100%|██████████| 50/50 [00:07<00:00, 6.82it/s]
100%|██████████| 50/50 [00:07<00:00, 6.86it/s]
100%|██████████| 50/50 [00:06<00:00, 7.25it/s]
100%|██████████| 50/50 [00:06<00:00, 7.24it/s]
100%|██████████| 50/50 [00:07<00:00, 6.89it/s]
100%|██████████| 50/50 [00:07<00:00, 6.90it/s]
100%|██████████| 50/50 [00:07<00:00, 7.02it/s]
100%|██████████| 50/50 [00:06<00:00, 7.40it/s]
100%|██████████| 50/50 [00:06<00:00, 7.15it/s]
100%|██████████| 50/50 [00:07<00:00, 6.79it/s]
100%|██████████| 50/50 [00:07<00:00, 7.07it/s]
100%|██████████| 50/50 [00:07<00:00, 7.14it/s]
100%|██████████| 50/50 [00:06<00:00, 7.32it/s]
100%|██████████| 50/50 [00:07<00:00, 7.13it/s]
100%|██████████| 50/50 [00:07<00:00, 7.05it/s]
100%|██████████| 50/50 [00:07<00:00, 7.06it/s]
100%|██████████| 50/50 [00:07<00:00, 7.09it/s]
100%|██████████| 20/20 [02:28<00:00, 7.40s/it]
