{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "237f5cbf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model already downloaded.\n" ] } ], "source": [ "# !wget https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n", "\n", "# check if the model is downloaded, if not download it\n", "import os\n", "if not os.path.exists(\"sd-v1-5-inpainting.ckpt\"):\n", " !wget https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n", "else:\n", " print(\"Model already downloaded.\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "24bd99d5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loaded finetuned weights from finetuned_weights.safetensors\n", "Loading 0.in_proj.weight\n", "Loading 0.out_proj.weight\n", "Loading 0.out_proj.bias\n", "Loading 8.in_proj.weight\n", "Loading 8.out_proj.weight\n", "Loading 8.out_proj.bias\n", "Loading 16.in_proj.weight\n", "Loading 16.out_proj.weight\n", "Loading 16.out_proj.bias\n", "Loading 24.in_proj.weight\n", "Loading 24.out_proj.weight\n", "Loading 24.out_proj.bias\n", "Loading 32.in_proj.weight\n", "Loading 32.out_proj.weight\n", "Loading 32.out_proj.bias\n", "Loading 40.in_proj.weight\n", "Loading 40.out_proj.weight\n", "Loading 40.out_proj.bias\n", "Loading 48.in_proj.weight\n", "Loading 48.out_proj.weight\n", "Loading 48.out_proj.bias\n", "Loading 56.in_proj.weight\n", "Loading 56.out_proj.weight\n", "Loading 56.out_proj.bias\n", "Loading 64.in_proj.weight\n", "Loading 64.out_proj.weight\n", "Loading 64.out_proj.bias\n", "Loading 72.in_proj.weight\n", "Loading 72.out_proj.weight\n", "Loading 72.out_proj.bias\n", "Loading 80.in_proj.weight\n", "Loading 80.out_proj.weight\n", "Loading 80.out_proj.bias\n", "Loading 88.in_proj.weight\n", "Loading 88.out_proj.weight\n", "Loading 88.out_proj.bias\n", "Loading 96.in_proj.weight\n", "Loading 96.out_proj.weight\n", "Loading 96.out_proj.bias\n", "Loading 104.in_proj.weight\n", "Loading 104.out_proj.weight\n", "Loading 104.out_proj.bias\n", "Loading 112.in_proj.weight\n", "Loading 112.out_proj.weight\n", "Loading 112.out_proj.bias\n", "Loading 120.in_proj.weight\n", "Loading 120.out_proj.weight\n", "Loading 120.out_proj.bias\n", "\n", "Attention module weights loaded from {finetune_weights_path} successfully.\n" ] } ], "source": [ "import load_model\n", "\n", "models=load_model.preload_models_from_standard_weights(ckpt_path=\"sd-v1-5-inpainting.ckpt\", device=\"cuda\", finetune_weights_path=\"finetuned_weights.safetensors\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "bab24c29", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/mahesh/miniconda3/envs/harsh/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
"/home/mahesh/miniconda3/envs/harsh/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
 "  from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [
 "from typing import Union\n",
 "\n",
 "import PIL\n",
 "import torch\n",
 "from tqdm import tqdm\n",
 "from diffusers.utils.torch_utils import randn_tensor\n",
 "\n",
 "from utils import (check_inputs, get_time_embedding, numpy_to_pil, prepare_image,\n",
 "                   prepare_mask_image, compute_vae_encodings)\n",
 "from ddpm import DDPMSampler\n",
 "\n",
 "class CatVTONPipeline:\n",
 "    def __init__(\n",
 "        self,\n",
 "        weight_dtype=torch.float32,\n",
 "        device='cuda',\n",
 "        compile=False,\n",
 "        skip_safety_check=True,\n",
 "        use_tf32=True,\n",
 "        models={},\n",
 "    ):\n",
 "        self.device = device\n",
 "        self.weight_dtype = weight_dtype\n",
 "        self.skip_safety_check = skip_safety_check\n",
 "        self.models = models\n",
 "\n",
 "        self.generator = torch.Generator(device=device)\n",
 "        self.noise_scheduler = DDPMSampler(generator=self.generator)\n",
 "        # The VAE encoder/decoder and the inpainting UNet come from the\n",
 "        # preloaded checkpoint rather than diffusers' AutoencoderKL.\n",
 "        self.encoder = models.get('encoder', None)\n",
 "        self.decoder = models.get('decoder', None)\n",
 "        self.unet = models.get('diffusion', None)\n",
 "        # Enable TF32 for faster matmuls on Ampere GPUs (A100 and RTX 30 series).\n",
 "        if use_tf32:\n",
 "            torch.set_float32_matmul_precision(\"high\")\n",
 "            torch.backends.cuda.matmul.allow_tf32 = True\n",
 "\n",
 "    @torch.no_grad()\n",
 "    def __call__(\n",
 "        self,\n",
 "        image: Union[PIL.Image.Image, torch.Tensor],\n",
 "        condition_image: Union[PIL.Image.Image, torch.Tensor],\n",
 "        mask: Union[PIL.Image.Image, torch.Tensor],\n",
 "        num_inference_steps: int = 50,\n",
 "        guidance_scale: float = 2.5,\n",
 "        height: int = 1024,\n",
 "        width: int = 768,\n",
 "        generator=None,\n",
 "        eta=1.0,\n",
 "        **kwargs\n",
 "    ):\n",
 "        concat_dim = -2  # concatenate person and garment along the height (y) axis\n",
 "        # Prepare inputs as tensors\n",
 "        image, condition_image, mask = check_inputs(image, condition_image, mask, width, height)\n",
 "        image = prepare_image(image).to(self.device, dtype=self.weight_dtype)\n",
 "        condition_image = prepare_image(condition_image).to(self.device, dtype=self.weight_dtype)\n",
 "        mask = prepare_mask_image(mask).to(self.device, dtype=self.weight_dtype)\n",
 "        # Mask out the try-on region of the person image\n",
 "        masked_image = image * (mask < 0.5)\n",
 "        # VAE encoding\n",
 "        masked_latent = compute_vae_encodings(masked_image, self.encoder)\n",
 "        condition_latent = compute_vae_encodings(condition_image, self.encoder)\n",
 "        mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode=\"nearest\")\n",
 "        del image, mask, condition_image\n",
 "        # Concatenate person and garment latents spatially\n",
 "        masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=concat_dim)\n",
 "        mask_latent_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)\n",
 "        # Prepare noise\n",
 "        latents = randn_tensor(\n",
 "            masked_latent_concat.shape,\n",
 "            generator=generator,\n",
 "            device=masked_latent_concat.device,\n",
 "            dtype=self.weight_dtype,\n",
 "        )\n",
 "        # Prepare timesteps\n",
 "        self.noise_scheduler.set_inference_timesteps(num_inference_steps)\n",
 "        timesteps = self.noise_scheduler.timesteps\n",
 "        latents = self.noise_scheduler.add_noise(latents, timesteps[0])\n",
 "\n",
 "        # Classifier-free guidance: the unconditional branch zeroes out the garment latent\n",
 "        if do_classifier_free_guidance := (guidance_scale > 1.0):\n",
 "            masked_latent_concat = torch.cat(\n",
 "                [\n",
 "                    torch.cat([masked_latent, torch.zeros_like(condition_latent)], dim=concat_dim),\n",
 "                    masked_latent_concat,\n",
 "                ]\n",
 "            )\n",
 "            mask_latent_concat = torch.cat([mask_latent_concat] * 2)\n",
 "\n",
 "        with tqdm(total=num_inference_steps) as progress_bar:\n",
 "            for i, t in enumerate(timesteps):\n",
 "                # expand the latents if we are doing classifier-free guidance\n",
 "                non_inpainting_latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents)\n",
 "                # prepare the 9-channel input (latent + mask + masked latent) for the inpainting model\n",
 "                inpainting_latent_model_input = torch.cat([non_inpainting_latent_model_input, mask_latent_concat, masked_latent_concat], dim=1).to(self.device, dtype=self.weight_dtype)\n",
 "                # predict the noise residual\n",
 "                timestep = t.repeat(inpainting_latent_model_input.shape[0])\n",
 "                time_embedding = get_time_embedding(timestep).to(self.device, dtype=self.weight_dtype)\n",
 "                noise_pred = self.unet(\n",
 "                    inpainting_latent_model_input,\n",
 "                    time_embedding\n",
 "                )\n",
 "                # perform guidance\n",
 "                if do_classifier_free_guidance:\n",
 "                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n",
 "                    noise_pred = noise_pred_uncond + guidance_scale * (\n",
 "                        noise_pred_text - noise_pred_uncond\n",
 "                    )\n",
 "                # compute the previous noisy sample x_t -> x_{t-1}\n",
 "                latents = self.noise_scheduler.step(\n",
 "                    t, latents, noise_pred\n",
 "                )\n",
 "                # plain DDPM needs no warmup offset, so update the bar every step\n",
 "                progress_bar.update()\n",
 "\n",
 "        # Decode: keep only the person half of the concatenated latent\n",
 "        latents = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]\n",
 "        image = self.decoder(latents.to(self.device, dtype=self.weight_dtype))\n",
 "        image = (image / 2 + 0.5).clamp(0, 1)\n",
 "        # cast to float32: negligible overhead and compatible with bfloat16\n",
 "        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n",
 "        image = numpy_to_pil(image)\n",
 "\n",
 "        return image\n" ] },
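 { "cell_type": "code", "execution_count": null, "id": "f00dcaff", "metadata": {}, "outputs": [], "source": [
 "# Added reference sketch: the pipeline above calls utils.get_time_embedding to\n",
 "# turn the scalar timestep into the vector the UNet consumes. In SD 1.x this is\n",
 "# typically the standard 320-dim sinusoidal embedding sketched below; treat it\n",
 "# as an illustration of the idea, not a copy of the utils.py implementation.\n",
 "import torch\n",
 "\n",
 "def sinusoidal_time_embedding(timesteps: torch.Tensor, dim: int = 320) -> torch.Tensor:\n",
 "    half = dim // 2\n",
 "    # geometric frequency ladder from 1 down to 1/10000, as in DDPM/transformers\n",
 "    freqs = torch.pow(10000, -torch.arange(half, dtype=torch.float32) / half)\n",
 "    angles = timesteps.float()[:, None] * freqs[None, :]\n",
 "    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)\n",
 "\n",
 "print(sinusoidal_time_embedding(torch.tensor([999])).shape)  # -> torch.Size([1, 320])" ] },
 { "cell_type": "code", "execution_count": 4, "id": "a729bf46", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset vitonhd loaded, total 20 pairs.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [
 "100%|██████████| 50/50 [00:07<00:00, 7.04it/s]\n",
 "100%|██████████| 50/50 [00:06<00:00, 7.32it/s]\n",
 "100%|██████████| 50/50 [00:07<00:00, 7.01it/s]\n",
 "100%|██████████| 50/50 [00:07<00:00, 6.82it/s]\n",
 "100%|██████████| 50/50 [00:07<00:00, 6.86it/s]\n",
 "100%|██████████| 50/50 [00:06<00:00, 7.25it/s]\n",
 "100%|██████████| 50/50 [00:06<00:00, 7.24it/s]\n",
 "100%|██████████| 50/50 [00:07<00:00, 6.89it/s]\n",
 "100%|██████████| 50/50 [00:07<00:00, 6.90it/s]\n",
 "100%|██████████| 50/50 [00:07<00:00, 7.02it/s]\n",
 "100%|██████████| 50/50 [00:06<00:00, 7.40it/s]\n",
 "100%|██████████| 50/50 [00:06<00:00, 7.15it/s]\n",
 "100%|██████████| 50/50 [00:07<00:00, 6.79it/s]\n",
 "100%|██████████| 50/50 [00:07<00:00, 7.07it/s]\n",
 "100%|██████████| 50/50 [00:07<00:00, 7.14it/s]\n",
 "100%|██████████| 50/50 [00:06<00:00, 7.32it/s]\n",
 "100%|██████████| 50/50 [00:07<00:00, 7.13it/s]\n",
 "100%|██████████| 50/50 [00:07<00:00, 7.05it/s]\n",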
[00:07<00:00, 7.06it/s]\n", "100%|██████████| 50/50 [00:07<00:00, 7.09it/s]\n", "100%|██████████| 20/20 [02:28<00:00, 7.40s/it]\n" ] } ], "source": [ "import os\n", "import numpy as np\n", "import torch\n", "import argparse\n", "from torch.utils.data import Dataset, DataLoader\n", "from VITON_Dataset import VITONHDTestDataset\n", "from diffusers.image_processor import VaeImageProcessor\n", "from tqdm import tqdm\n", "from PIL import Image, ImageFilter\n", "\n", "from utils import repaint, to_pil_image\n", "\n", "@torch.no_grad()\n", "def main():\n", " args=argparse.Namespace()\n", " args.__dict__= {\n", " \"dataset_name\": \"vitonhd\",\n", " \"data_root_path\": \"./sample_dataset\",\n", " \"output_dir\": \"./mask-based-output\",\n", " \"seed\": 555,\n", " \"batch_size\": 1,\n", " \"num_inference_steps\": 50,\n", " \"guidance_scale\": 2.5,\n", " \"width\": 384,\n", " \"height\": 512,\n", " \"repaint\": True,\n", " \"eval_pair\": False,\n", " \"concat_eval_results\": True,\n", " \"allow_tf32\": True,\n", " \"dataloader_num_workers\": 4,\n", " \"mixed_precision\": 'no',\n", " \"concat_axis\": 'y',\n", " \"enable_condition_noise\": True,\n", " \"is_train\": False\n", " }\n", "\n", " # Pipeline\n", " pipeline = CatVTONPipeline(\n", " weight_dtype={\n", " \"no\": torch.float32,\n", " \"fp16\": torch.float16,\n", " \"bf16\": torch.bfloat16,\n", " }[args.mixed_precision],\n", " device=\"cuda\",\n", " skip_safety_check=True,\n", " models=models,\n", " )\n", " # Dataset\n", " if args.dataset_name == \"vitonhd\":\n", " dataset = VITONHDTestDataset(args)\n", " else:\n", " raise ValueError(f\"Invalid dataset name {args.dataset}.\")\n", " print(f\"Dataset {args.dataset_name} loaded, total {len(dataset)} pairs.\")\n", " dataloader = DataLoader(\n", " dataset,\n", " batch_size=args.batch_size,\n", " shuffle=False,\n", " num_workers=args.dataloader_num_workers\n", " )\n", " \n", " # Inference\n", " generator = torch.Generator(device='cuda').manual_seed(args.seed)\n", " args.output_dir = os.path.join(args.output_dir, f\"{args.dataset_name}-{args.height}\", \"paired\" if args.eval_pair else \"unpaired\")\n", " if not os.path.exists(args.output_dir):\n", " os.makedirs(args.output_dir)\n", " \n", " for batch in tqdm(dataloader):\n", " person_images = batch['person']\n", " cloth_images = batch['cloth']\n", " masks = batch['mask']\n", "\n", " results = pipeline(\n", " person_images,\n", " cloth_images,\n", " masks,\n", " num_inference_steps=args.num_inference_steps,\n", " guidance_scale=args.guidance_scale,\n", " height=args.height,\n", " width=args.width,\n", " generator=generator,\n", " )\n", " \n", " if args.concat_eval_results or args.repaint:\n", " person_images = to_pil_image(person_images)\n", " cloth_images = to_pil_image(cloth_images)\n", " masks = to_pil_image(masks)\n", " for i, result in enumerate(results):\n", " person_name = batch['person_name'][i]\n", " output_path = os.path.join(args.output_dir, person_name)\n", " if not os.path.exists(os.path.dirname(output_path)):\n", " os.makedirs(os.path.dirname(output_path))\n", " if args.repaint:\n", " person_path, mask_path = dataset.data[batch['index'][i]]['person'], dataset.data[batch['index'][i]]['mask']\n", " person_image= Image.open(person_path).resize(result.size, Image.LANCZOS)\n", " mask = Image.open(mask_path).resize(result.size, Image.NEAREST)\n", " result = repaint(person_image, mask, result)\n", " if args.concat_eval_results:\n", " w, h = result.size\n", " concated_result = Image.new('RGB', (w*3, h))\n", " 
 { "cell_type": "code", "execution_count": null, "id": "55d88911", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "harsh", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 5 }