diff --git a/.gitignore b/.gitignore index 7d5c8833df2de7411771e3b3588bfe7180cfd019..8b73bce0495bebe68e851e5551bdf5a7c05b87dd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,9 @@ *inkpunk-diffusion-v1.ckpt +*instruct-pix2pix* *sd-v1-5-inpainting.ckpt *zalando-hd-resized.zip - +*finetuned_weights.safetensors +*maskfree_finetuned_weights.safetensors # *viton-hd-dataset.zip viton-hd-dataset/ checkpoints/ diff --git a/.gradio/certificate.pem b/.gradio/certificate.pem new file mode 100644 index 0000000000000000000000000000000000000000..b85c8037f6b60976b2546fdbae88312c5246d9a3 --- /dev/null +++ b/.gradio/certificate.pem @@ -0,0 +1,31 @@ +-----BEGIN CERTIFICATE----- +MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw +TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh +cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 +WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu +ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY +MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc +h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ +0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U +A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW +T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH +B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC +B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv +KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn +OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn +jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw +qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI +rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV +HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq +hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL +ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ +3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK +NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5 +ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur +TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC +jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc +oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq +4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA +mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d +emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc= +-----END CERTIFICATE----- diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..44475ab8ce3d4510a8f1e3ace822b08f4cf3b202 --- /dev/null +++ b/app.py @@ -0,0 +1,430 @@ +import os +import torch +import gradio as gr +from PIL import Image +import numpy as np +from typing import Optional + +# Import your custom modules +from load_model import preload_models_from_standard_weights +from utils import to_pil_image + +import inspect +import os +from typing import Union + +import PIL +import numpy as np +import torch +import tqdm +from diffusers.utils.torch_utils import randn_tensor + +from utils import (check_inputs_maskfree, get_time_embedding, numpy_to_pil, prepare_image, compute_vae_encodings) +from ddpm import DDPMSampler + + +class CatVTONPix2PixPipeline: + def __init__( + self, + weight_dtype=torch.float32, + device='cuda', + compile=False, + skip_safety_check=True, + use_tf32=True, + models={}, + ): + self.device 
= device + self.weight_dtype = weight_dtype + self.skip_safety_check = skip_safety_check + self.models = models + + self.generator = torch.Generator(device=device) + self.noise_scheduler = DDPMSampler(generator=self.generator) + self.encoder= models.get('encoder', None) + self.decoder= models.get('decoder', None) + self.unet=models.get('diffusion', None) + + # Enable TF32 for faster training on Ampere GPUs + if use_tf32: + torch.set_float32_matmul_precision("high") + torch.backends.cuda.matmul.allow_tf32 = True + + @torch.no_grad() + def __call__( + self, + image: Union[PIL.Image.Image, torch.Tensor], + condition_image: Union[PIL.Image.Image, torch.Tensor], + num_inference_steps: int = 50, + guidance_scale: float = 2.5, + height: int = 1024, + width: int = 768, + generator=None, + eta=1.0, + **kwargs + ): + concat_dim = -1 # FIXME: y axis concat + # Prepare inputs to Tensor + image, condition_image = check_inputs_maskfree(image, condition_image, width, height) + + # Ensure consistent dtype for all tensors + image = prepare_image(image).to(self.device, dtype=self.weight_dtype) + condition_image = prepare_image(condition_image).to(self.device, dtype=self.weight_dtype) + + # Encode the image + image_latent = compute_vae_encodings(image, self.encoder) + condition_latent = compute_vae_encodings(condition_image, self.encoder) + + del image, condition_image + + # Concatenate latents + condition_latent_concat = torch.cat([image_latent, condition_latent], dim=concat_dim) + + # Prepare noise + latents = randn_tensor( + condition_latent_concat.shape, + generator=generator, + device=condition_latent_concat.device, + dtype=self.weight_dtype, + ) + + # Prepare timesteps + self.noise_scheduler.set_inference_timesteps(num_inference_steps) + timesteps = self.noise_scheduler.timesteps + latents = self.noise_scheduler.add_noise(latents, timesteps[0]) + + # Classifier-Free Guidance + if do_classifier_free_guidance := (guidance_scale > 1.0): + condition_latent_concat = torch.cat( + [ + torch.cat([image_latent, torch.zeros_like(condition_latent)], dim=concat_dim), + condition_latent_concat, + ] + ) + + num_warmup_steps = 0 # For simple DDPM, no warmup needed + with tqdm.tqdm(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents) + + # prepare the input for the inpainting model + p2p_latent_model_input = torch.cat([latent_model_input, condition_latent_concat], dim=1) + + # predict the noise residual + timestep = t.repeat(p2p_latent_model_input.shape[0]) + time_embedding = get_time_embedding(timestep).to(self.device, dtype=self.weight_dtype) + + noise_pred = self.unet( + p2p_latent_model_input, + time_embedding + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.noise_scheduler.step( + t, latents, noise_pred + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps): + progress_bar.update() + + # Decode the final latents + latents = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0] + image = self.decoder(latents.to(self.device, dtype=self.weight_dtype)) + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 
1).float().numpy() + image = numpy_to_pil(image) + + return image + +def load_models(): + try: + print("🚀 Starting model loading process...") + + # Check CUDA availability + cuda_available = torch.cuda.is_available() + print(f"CUDA available: {cuda_available}") + if cuda_available: + print(f"CUDA device: {torch.cuda.get_device_name()}") + free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0) + print(f"Available CUDA memory: {free_memory / 1e9:.2f} GB") + + device = "cuda" if cuda_available else "cpu" + + # Check if model files exist + ckpt_path = "instruct-pix2pix-00-22000.ckpt" + finetune_path = "maskfree_finetuned_weights.safetensors" + + if not os.path.exists(ckpt_path): + print(f"❌ Checkpoint file not found: {ckpt_path}") + return None, None + + if not os.path.exists(finetune_path): + print(f"❌ Finetune weights file not found: {finetune_path}") + return None, None + + print("📦 Loading models from weights...") + + models = preload_models_from_standard_weights( + ckpt_path=ckpt_path, + device=device, + finetune_weights_path=finetune_path + ) + + if not models: + print("❌ Failed to load models") + return None, None + + # Convert all models to consistent dtype to avoid mixed precision issues + weight_dtype = torch.float32 # Use float32 to avoid dtype mismatch + print(f"Converting models to {weight_dtype}...") + + # Ensure all models use the same dtype + for model_name, model in models.items(): + if model is not None: + try: + model = model.to(dtype=weight_dtype) + models[model_name] = model + print(f"✅ {model_name} converted to {weight_dtype}") + except Exception as e: + print(f"⚠️ Could not convert {model_name} to {weight_dtype}: {e}") + + print("🔧 Initializing pipeline...") + + pipeline = CatVTONPix2PixPipeline( + weight_dtype=weight_dtype, + device=device, + skip_safety_check=True, + models=models, + ) + + print("✅ Models and pipeline loaded successfully!") + return models, pipeline + + except Exception as e: + print(f"❌ Error in load_models: {e}") + import traceback + traceback.print_exc() + return None, None + +def person_example_fn(image_path): + """Handle person image examples""" + if image_path: + return image_path + return None + +def create_demo(pipeline=None): + """Create the Gradio interface""" + + def submit_function_p2p( + person_image_path: Optional[str], + cloth_image_path: Optional[str], + num_inference_steps: int = 50, + guidance_scale: float = 2.5, + seed: int = 42, + ) -> Optional[Image.Image]: + """Process virtual try-on inference""" + + try: + if not person_image_path or not cloth_image_path: + gr.Warning("Please upload both person and cloth images!") + return None + + if not os.path.exists(person_image_path): + gr.Error("Person image file not found!") + return None + + if not os.path.exists(cloth_image_path): + gr.Error("Cloth image file not found!") + return None + + if pipeline is None: + gr.Error("Models not loaded! 
Please restart the application.") + return None + + # Load images + try: + person_image = Image.open(person_image_path).convert('RGB') + cloth_image = Image.open(cloth_image_path).convert('RGB') + except Exception as e: + gr.Error(f"Error loading images: {str(e)}") + return None + + # Set up generator + generator = torch.Generator(device=pipeline.device) + if seed != -1: + generator.manual_seed(seed) + + print("🔄 Processing virtual try-on...") + + # Run inference + with torch.no_grad(): + results = pipeline( + person_image, + cloth_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + height=512, + width=384, + generator=generator, + ) + + # Process results + if isinstance(results, list) and len(results) > 0: + result = results[0] + else: + result = results + + return result + + except Exception as e: + print(f"❌ Error in submit_function_p2p: {e}") + import traceback + traceback.print_exc() + gr.Error(f"Error during inference: {str(e)}") + return None + + # Custom CSS for better styling + css = """ + .gradio-container { + max-width: 1200px !important; + } + .image-container { + max-height: 600px; + } + """ + + with gr.Blocks(css=css, title="Virtual Try-On") as demo: + gr.HTML(""" +
+        <div style="text-align: center;">
+            <h1>🧥 Virtual Try-On with CatVTON</h1>
+            <p>Upload a person image and a clothing item to see how they look together!</p>
+        </div>
+ """) + + with gr.Tab("Mask-Free Virtual Try-On"): + with gr.Row(): + with gr.Column(scale=1, min_width=350): + with gr.Row(): + image_path_p2p = gr.Image( + type="filepath", + interactive=True, + visible=False, + ) + person_image_p2p = gr.Image( + interactive=True, + label="Person Image", + type="filepath", + elem_classes=["image-container"] + ) + + with gr.Row(): + cloth_image_p2p = gr.Image( + interactive=True, + label="Clothing Image", + type="filepath", + elem_classes=["image-container"] + ) + + submit_p2p = gr.Button("✨ Generate Try-On", variant="primary", size="lg") + + gr.Markdown( + '
⚠️ Click only once and wait for processing!
' + ) + + with gr.Accordion("🔧 Advanced Options", open=False): + num_inference_steps_p2p = gr.Slider( + label="Inference Steps", + minimum=10, + maximum=100, + step=5, + value=50, + info="More steps = better quality but slower" + ) + guidance_scale_p2p = gr.Slider( + label="Guidance Scale", + minimum=0.0, + maximum=7.5, + step=0.5, + value=2.5, + info="Higher values = stronger conditioning" + ) + seed_p2p = gr.Slider( + label="Seed", + minimum=-1, + maximum=10000, + step=1, + value=42, + info="Use -1 for random seed" + ) + + with gr.Column(scale=2, min_width=500): + result_image_p2p = gr.Image( + interactive=False, + label="Result (Person | Clothing | Generated)", + elem_classes=["image-container"] + ) + + gr.Markdown(""" + ### 📋 Instructions: + 1. Upload a **person image** (front-facing works best) + 2. Upload a **clothing item** you want to try on + 3. Adjust advanced settings if needed + 4. Click "Generate Try-On" and wait + + ### 💡 Tips: + - Use clear, high-resolution images + - Person should be facing forward + - Clothing items work best when laid flat or on a model + - Try different seeds if you're not satisfied with results + """) + + # Event handlers + image_path_p2p.change( + person_example_fn, + inputs=image_path_p2p, + outputs=person_image_p2p + ) + + submit_p2p.click( + submit_function_p2p, + inputs=[ + person_image_p2p, + cloth_image_p2p, + num_inference_steps_p2p, + guidance_scale_p2p, + seed_p2p, + ], + outputs=result_image_p2p, + ) + + return demo + +def app_gradio(): + """Main application function""" + + # Load models at startup + print("🚀 Loading models...") + models, pipeline = load_models() + if not models or not pipeline: + print("❌ Failed to load models. Please check your model files.") + return + + # Create and launch demo + demo = create_demo(pipeline=pipeline) + demo.launch( + share=True, + show_error=True, + server_name="0.0.0.0", + server_port=7860 + ) + +if __name__ == "__main__": + app_gradio() \ No newline at end of file diff --git a/load_model.py b/load_model.py index a78123c3afff8e45e8cbb34643c85bf837a09788..e49bf97e2f66c968eec2ad12aad737cb7f69cd8f 100644 --- a/load_model.py +++ b/load_model.py @@ -78,7 +78,12 @@ def load_finetuned_attention_weights(finetune_weights_path, diffusion, device): def preload_models_from_standard_weights(ckpt_path, device, finetune_weights_path=None): # CatVTON parameters + # in_channels: 8 for instruct-pix2pix (masked free), 9 for sd-v1-5-inpainting (masked based) in_channels = 9 + + if 'maskfree' in finetune_weights_path or 'mask_free' in finetune_weights_path: + in_channels = 8 + out_channels = 4 state_dict=model_converter.load_from_standard_weights(ckpt_path, device) diff --git a/mask-based-output/vitonhd-512/unpaired/00654_00.jpg b/mask-based-output/vitonhd-512/unpaired/00654_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..426a7db299da61f2b3ef5f63ededfefb4d7d11d5 Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/00654_00.jpg differ diff --git a/mask-based-output/vitonhd-512/unpaired/01265_00.jpg b/mask-based-output/vitonhd-512/unpaired/01265_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bb9fd57dc6a06337999739f38422e062eea07f40 Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/01265_00.jpg differ diff --git a/mask-based-output/vitonhd-512/unpaired/01985_00.jpg b/mask-based-output/vitonhd-512/unpaired/01985_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..203e13a647c5b24539d1204a3a23a84c01d2e7cb Binary files 
/dev/null and b/mask-based-output/vitonhd-512/unpaired/01985_00.jpg differ diff --git a/mask-based-output/vitonhd-512/unpaired/02023_00.jpg b/mask-based-output/vitonhd-512/unpaired/02023_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2c2e344617f59b0b4cdaa75792ad8bf94c2e09b8 Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/02023_00.jpg differ diff --git a/mask-based-output/vitonhd-512/unpaired/02532_00.jpg b/mask-based-output/vitonhd-512/unpaired/02532_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3d9ab08223986e133915cab6bfb57f042d2355fe Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/02532_00.jpg differ diff --git a/mask-based-output/vitonhd-512/unpaired/02944_00.jpg b/mask-based-output/vitonhd-512/unpaired/02944_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b876388b36613dc98ab504dfddbe188c212a9435 Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/02944_00.jpg differ diff --git a/mask-based-output/vitonhd-512/unpaired/03191_00.jpg b/mask-based-output/vitonhd-512/unpaired/03191_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..33ff2fa436b793b1da60b85a9c6f0b7f811ecd65 Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/03191_00.jpg differ diff --git a/mask-based-output/vitonhd-512/unpaired/03921_00.jpg b/mask-based-output/vitonhd-512/unpaired/03921_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9a3c3dd63d7fdc642a01491fdae77e3dc799e0c8 Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/03921_00.jpg differ diff --git a/mask-based-output/vitonhd-512/unpaired/05006_00.jpg b/mask-based-output/vitonhd-512/unpaired/05006_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8d0ede99acdf085dfc0f40c0d230bd8ff8e2bd5e Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/05006_00.jpg differ diff --git a/mask-based-output/vitonhd-512/unpaired/05378_00.jpg b/mask-based-output/vitonhd-512/unpaired/05378_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..02152c7544054dc8a9835352af01d6694db9e154 Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/05378_00.jpg differ diff --git a/mask-based-output/vitonhd-512/unpaired/07342_00.jpg b/mask-based-output/vitonhd-512/unpaired/07342_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1dcb08eb5102e82f00371556d617aa8e2a6c5c3c Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/07342_00.jpg differ diff --git a/mask-based-output/vitonhd-512/unpaired/08088_00.jpg b/mask-based-output/vitonhd-512/unpaired/08088_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9123c591eed312567598f03a3ff7812ab24832eb Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/08088_00.jpg differ diff --git a/mask-based-output/vitonhd-512/unpaired/08239_00.jpg b/mask-based-output/vitonhd-512/unpaired/08239_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..94dee6120137d9ad8ce7934aedcecc0edfa5ec1d Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/08239_00.jpg differ diff --git a/mask-based-output/vitonhd-512/unpaired/08650_00.jpg b/mask-based-output/vitonhd-512/unpaired/08650_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6b7dc2c9653cf81c54c593ef7ca60685ab91fbe7 Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/08650_00.jpg 
differ diff --git a/mask-based-output/vitonhd-512/unpaired/08839_00.jpg b/mask-based-output/vitonhd-512/unpaired/08839_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..877f8b46e6a0b13e1eddc2e291346493e13712fa Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/08839_00.jpg differ diff --git a/mask-based-output/vitonhd-512/unpaired/11085_00.jpg b/mask-based-output/vitonhd-512/unpaired/11085_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..094449ed1949592950abacd98f28894c11097e35 Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/11085_00.jpg differ diff --git a/mask-based-output/vitonhd-512/unpaired/12345_00.jpg b/mask-based-output/vitonhd-512/unpaired/12345_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..acf474fdfc13b130e59efdf4bfc6db59a24c7fd0 Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/12345_00.jpg differ diff --git a/mask-based-output/vitonhd-512/unpaired/12419_00.jpg b/mask-based-output/vitonhd-512/unpaired/12419_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0d72b620505d34f870d389756819b639d25d3d4e Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/12419_00.jpg differ diff --git a/mask-based-output/vitonhd-512/unpaired/12562_00.jpg b/mask-based-output/vitonhd-512/unpaired/12562_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a6b05a57420b6935c89cca0ba4383ee67f62893a Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/12562_00.jpg differ diff --git a/mask-based-output/vitonhd-512/unpaired/14651_00.jpg b/mask-based-output/vitonhd-512/unpaired/14651_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ac8736d4a918a60de597152c50c64c2e48bfb90a Binary files /dev/null and b/mask-based-output/vitonhd-512/unpaired/14651_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/00654_00.jpg b/mask-free-output/vitonhd-512/unpaired/00654_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a0257c01dcb4c1b2c43f4304fb1bf83cd171209e Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/00654_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/01265_00.jpg b/mask-free-output/vitonhd-512/unpaired/01265_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6925069276f205361044cebe1bd2e8902478ee51 Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/01265_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/01985_00.jpg b/mask-free-output/vitonhd-512/unpaired/01985_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..048f1cf92344730db8f4f49f60c99cf92de707c9 Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/01985_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/02023_00.jpg b/mask-free-output/vitonhd-512/unpaired/02023_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4540be4c51feb07d15436e3786a7025d6bbe2fea Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/02023_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/02532_00.jpg b/mask-free-output/vitonhd-512/unpaired/02532_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ed1c6e7de0df52ce017ef466450c3966eb97814c Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/02532_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/02944_00.jpg 
b/mask-free-output/vitonhd-512/unpaired/02944_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0d9dc65abdc38eb363483fef47610939c280a817 Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/02944_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/03191_00.jpg b/mask-free-output/vitonhd-512/unpaired/03191_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6a82ce367105012bf0234b14b96dfd7c7a9a45e0 Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/03191_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/03921_00.jpg b/mask-free-output/vitonhd-512/unpaired/03921_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9ac62a3e26cc9951aeec90f831814ac6e2ad0caf Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/03921_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/05006_00.jpg b/mask-free-output/vitonhd-512/unpaired/05006_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4c71c960b70e1e1e87115fcd8ef848cf22875f80 Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/05006_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/05378_00.jpg b/mask-free-output/vitonhd-512/unpaired/05378_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b15f1a8b2a2d6ded57ada390df7e850f31e2d46d Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/05378_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/07342_00.jpg b/mask-free-output/vitonhd-512/unpaired/07342_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e3e40ed8260ac7c922ebc85380b4b1092b965753 Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/07342_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/08088_00.jpg b/mask-free-output/vitonhd-512/unpaired/08088_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cb49ffaa155c938ce5000e6d7433aef3d0096cfc Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/08088_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/08239_00.jpg b/mask-free-output/vitonhd-512/unpaired/08239_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c38883abe653dc52c7d93fe6ec0e53db2eb790c5 Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/08239_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/08650_00.jpg b/mask-free-output/vitonhd-512/unpaired/08650_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..095f733d4ac8bad22c6deb984f8f2223d397d4d0 Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/08650_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/08839_00.jpg b/mask-free-output/vitonhd-512/unpaired/08839_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a7760ea59219e19f3eb36374a88f4cf06ceaf85b Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/08839_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/11085_00.jpg b/mask-free-output/vitonhd-512/unpaired/11085_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4a6bab62aa0728cfb95262949678abee14686fee Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/11085_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/12345_00.jpg b/mask-free-output/vitonhd-512/unpaired/12345_00.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..d7ec5289c266b5fbb4f9353a25ba1a8084801bac Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/12345_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/12419_00.jpg b/mask-free-output/vitonhd-512/unpaired/12419_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cdbe728f7e377c160829578fdde31b4627f2250d Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/12419_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/12562_00.jpg b/mask-free-output/vitonhd-512/unpaired/12562_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..62533abdc0f4a83a91a9ecf50ac7e1fc9f0e008d Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/12562_00.jpg differ diff --git a/mask-free-output/vitonhd-512/unpaired/14651_00.jpg b/mask-free-output/vitonhd-512/unpaired/14651_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..24c49def22c6c399e4594debf56310936bebfad0 Binary files /dev/null and b/mask-free-output/vitonhd-512/unpaired/14651_00.jpg differ diff --git a/sample_inference.ipynb b/mask_based_inference.ipynb similarity index 90% rename from sample_inference.ipynb rename to mask_based_inference.ipynb index 0f6d011cfacc0e4e6a29eb3c817b1977b0d8fd3f..cbdca6f4a4499a7901e6d4b257cc20e8ecbd255c 100644 --- a/sample_inference.ipynb +++ b/mask_based_inference.ipynb @@ -28,6 +28,76 @@ { "cell_type": "code", "execution_count": 2, + "id": "24bd99d5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded finetuned weights from finetuned_weights.safetensors\n", + "Loading 0.in_proj.weight\n", + "Loading 0.out_proj.weight\n", + "Loading 0.out_proj.bias\n", + "Loading 8.in_proj.weight\n", + "Loading 8.out_proj.weight\n", + "Loading 8.out_proj.bias\n", + "Loading 16.in_proj.weight\n", + "Loading 16.out_proj.weight\n", + "Loading 16.out_proj.bias\n", + "Loading 24.in_proj.weight\n", + "Loading 24.out_proj.weight\n", + "Loading 24.out_proj.bias\n", + "Loading 32.in_proj.weight\n", + "Loading 32.out_proj.weight\n", + "Loading 32.out_proj.bias\n", + "Loading 40.in_proj.weight\n", + "Loading 40.out_proj.weight\n", + "Loading 40.out_proj.bias\n", + "Loading 48.in_proj.weight\n", + "Loading 48.out_proj.weight\n", + "Loading 48.out_proj.bias\n", + "Loading 56.in_proj.weight\n", + "Loading 56.out_proj.weight\n", + "Loading 56.out_proj.bias\n", + "Loading 64.in_proj.weight\n", + "Loading 64.out_proj.weight\n", + "Loading 64.out_proj.bias\n", + "Loading 72.in_proj.weight\n", + "Loading 72.out_proj.weight\n", + "Loading 72.out_proj.bias\n", + "Loading 80.in_proj.weight\n", + "Loading 80.out_proj.weight\n", + "Loading 80.out_proj.bias\n", + "Loading 88.in_proj.weight\n", + "Loading 88.out_proj.weight\n", + "Loading 88.out_proj.bias\n", + "Loading 96.in_proj.weight\n", + "Loading 96.out_proj.weight\n", + "Loading 96.out_proj.bias\n", + "Loading 104.in_proj.weight\n", + "Loading 104.out_proj.weight\n", + "Loading 104.out_proj.bias\n", + "Loading 112.in_proj.weight\n", + "Loading 112.out_proj.weight\n", + "Loading 112.out_proj.bias\n", + "Loading 120.in_proj.weight\n", + "Loading 120.out_proj.weight\n", + "Loading 120.out_proj.bias\n", + "\n", + "Attention module weights loaded from {finetune_weights_path} successfully.\n" + ] + } + ], + "source": [ + "import load_model\n", + "\n", + "models=load_model.preload_models_from_standard_weights(ckpt_path=\"sd-v1-5-inpainting.ckpt\", device=\"cuda\", 
finetune_weights_path=\"finetuned_weights.safetensors\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, "id": "bab24c29", "metadata": {}, "outputs": [ @@ -183,77 +253,7 @@ }, { "cell_type": "code", - "execution_count": 3, - "id": "a069151e", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loaded finetuned weights from finetuned_weights.safetensors\n", - "Loading 0.in_proj.weight\n", - "Loading 0.out_proj.weight\n", - "Loading 0.out_proj.bias\n", - "Loading 8.in_proj.weight\n", - "Loading 8.out_proj.weight\n", - "Loading 8.out_proj.bias\n", - "Loading 16.in_proj.weight\n", - "Loading 16.out_proj.weight\n", - "Loading 16.out_proj.bias\n", - "Loading 24.in_proj.weight\n", - "Loading 24.out_proj.weight\n", - "Loading 24.out_proj.bias\n", - "Loading 32.in_proj.weight\n", - "Loading 32.out_proj.weight\n", - "Loading 32.out_proj.bias\n", - "Loading 40.in_proj.weight\n", - "Loading 40.out_proj.weight\n", - "Loading 40.out_proj.bias\n", - "Loading 48.in_proj.weight\n", - "Loading 48.out_proj.weight\n", - "Loading 48.out_proj.bias\n", - "Loading 56.in_proj.weight\n", - "Loading 56.out_proj.weight\n", - "Loading 56.out_proj.bias\n", - "Loading 64.in_proj.weight\n", - "Loading 64.out_proj.weight\n", - "Loading 64.out_proj.bias\n", - "Loading 72.in_proj.weight\n", - "Loading 72.out_proj.weight\n", - "Loading 72.out_proj.bias\n", - "Loading 80.in_proj.weight\n", - "Loading 80.out_proj.weight\n", - "Loading 80.out_proj.bias\n", - "Loading 88.in_proj.weight\n", - "Loading 88.out_proj.weight\n", - "Loading 88.out_proj.bias\n", - "Loading 96.in_proj.weight\n", - "Loading 96.out_proj.weight\n", - "Loading 96.out_proj.bias\n", - "Loading 104.in_proj.weight\n", - "Loading 104.out_proj.weight\n", - "Loading 104.out_proj.bias\n", - "Loading 112.in_proj.weight\n", - "Loading 112.out_proj.weight\n", - "Loading 112.out_proj.bias\n", - "Loading 120.in_proj.weight\n", - "Loading 120.out_proj.weight\n", - "Loading 120.out_proj.bias\n", - "\n", - "Attention module weights loaded from {finetune_weights_path} successfully.\n" - ] - } - ], - "source": [ - "import load_model\n", - "\n", - "models=load_model.preload_models_from_standard_weights(ckpt_path=\"sd-v1-5-inpainting.ckpt\", device=\"cuda\", finetune_weights_path=\"finetuned_weights.safetensors\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "a729bf46", "metadata": {}, "outputs": [ @@ -268,15 +268,27 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 50/50 [00:11<00:00, 4.48it/s]\n", - "100%|██████████| 50/50 [00:10<00:00, 4.55it/s]\n", - "100%|██████████| 50/50 [00:11<00:00, 4.43it/s]\n", - "100%|██████████| 50/50 [00:11<00:00, 4.53it/s]\n", - "100%|██████████| 50/50 [00:11<00:00, 4.53it/s]\n", - "100%|██████████| 50/50 [00:11<00:00, 4.51it/s]\n", - "100%|██████████| 50/50 [00:10<00:00, 4.57it/s]\n", - "100%|██████████| 50/50 [00:11<00:00, 4.51it/s]\n", - " 40%|████ | 8/20 [01:32<02:17, 11.49s/it]" + "100%|██████████| 50/50 [00:07<00:00, 7.04it/s]\n", + "100%|██████████| 50/50 [00:06<00:00, 7.32it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 7.01it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 6.82it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 6.86it/s]\n", + "100%|██████████| 50/50 [00:06<00:00, 7.25it/s]\n", + "100%|██████████| 50/50 [00:06<00:00, 7.24it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 6.89it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 6.90it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 7.02it/s]\n", 
+ "100%|██████████| 50/50 [00:06<00:00, 7.40it/s]\n", + "100%|██████████| 50/50 [00:06<00:00, 7.15it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 6.79it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 7.07it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 7.14it/s]\n", + "100%|██████████| 50/50 [00:06<00:00, 7.32it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 7.13it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 7.05it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 7.06it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 7.09it/s]\n", + "100%|██████████| 20/20 [02:28<00:00, 7.40s/it]\n" ] } ], @@ -299,7 +311,7 @@ " args.__dict__= {\n", " \"dataset_name\": \"vitonhd\",\n", " \"data_root_path\": \"./sample_dataset\",\n", - " \"output_dir\": \"./output\",\n", + " \"output_dir\": \"./mask-based-output\",\n", " \"seed\": 555,\n", " \"batch_size\": 1,\n", " \"num_inference_steps\": 50,\n", diff --git a/mask_free_inference.ipynb b/mask_free_inference.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..65ca8ee668bfbe980f16f97214aad22b48359e37 --- /dev/null +++ b/mask_free_inference.ipynb @@ -0,0 +1,449 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "6d50f66c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model already downloaded.\n" + ] + } + ], + "source": [ + "# check if the model is downloaded, if not download it\n", + "import os\n", + "if not os.path.exists(\"instruct-pix2pix-00-22000.ckpt\"):\n", + " !wget https://huggingface.co/timbrooks/instruct-pix2pix/resolve/main/instruct-pix2pix-00-22000.ckpt\n", + "else:\n", + " print(\"Model already downloaded.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "3598a305", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded finetuned weights from maskfree_finetuned_weights.safetensors\n", + "Loading 0.in_proj.weight\n", + "Loading 0.out_proj.weight\n", + "Loading 0.out_proj.bias\n", + "Loading 8.in_proj.weight\n", + "Loading 8.out_proj.weight\n", + "Loading 8.out_proj.bias\n", + "Loading 16.in_proj.weight\n", + "Loading 16.out_proj.weight\n", + "Loading 16.out_proj.bias\n", + "Loading 24.in_proj.weight\n", + "Loading 24.out_proj.weight\n", + "Loading 24.out_proj.bias\n", + "Loading 32.in_proj.weight\n", + "Loading 32.out_proj.weight\n", + "Loading 32.out_proj.bias\n", + "Loading 40.in_proj.weight\n", + "Loading 40.out_proj.weight\n", + "Loading 40.out_proj.bias\n", + "Loading 48.in_proj.weight\n", + "Loading 48.out_proj.weight\n", + "Loading 48.out_proj.bias\n", + "Loading 56.in_proj.weight\n", + "Loading 56.out_proj.weight\n", + "Loading 56.out_proj.bias\n", + "Loading 64.in_proj.weight\n", + "Loading 64.out_proj.weight\n", + "Loading 64.out_proj.bias\n", + "Loading 72.in_proj.weight\n", + "Loading 72.out_proj.weight\n", + "Loading 72.out_proj.bias\n", + "Loading 80.in_proj.weight\n", + "Loading 80.out_proj.weight\n", + "Loading 80.out_proj.bias\n", + "Loading 88.in_proj.weight\n", + "Loading 88.out_proj.weight\n", + "Loading 88.out_proj.bias\n", + "Loading 96.in_proj.weight\n", + "Loading 96.out_proj.weight\n", + "Loading 96.out_proj.bias\n", + "Loading 104.in_proj.weight\n", + "Loading 104.out_proj.weight\n", + "Loading 104.out_proj.bias\n", + "Loading 112.in_proj.weight\n", + "Loading 112.out_proj.weight\n", + "Loading 112.out_proj.bias\n", + "Loading 120.in_proj.weight\n", + "Loading 120.out_proj.weight\n", + "Loading 120.out_proj.bias\n", + "\n", + "Attention module 
weights loaded from {finetune_weights_path} successfully.\n" + ] + } + ], + "source": [ + "import load_model\n", + "\n", + "models=load_model.preload_models_from_standard_weights(ckpt_path=\"instruct-pix2pix-00-22000.ckpt\", device=\"cuda\", finetune_weights_path=\"maskfree_finetuned_weights.safetensors\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "78e3d8b9", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/mahesh/miniconda3/envs/harsh/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import inspect\n", + "import os\n", + "from typing import Union\n", + "\n", + "import PIL\n", + "import numpy as np\n", + "import torch\n", + "import tqdm\n", + "from diffusers.utils.torch_utils import randn_tensor\n", + "\n", + "from utils import (check_inputs_maskfree, get_time_embedding, numpy_to_pil, prepare_image, compute_vae_encodings)\n", + "from ddpm import DDPMSampler\n", + "\n", + "class CatVTONPix2PixPipeline:\n", + " def __init__(\n", + " self, \n", + " weight_dtype=torch.float32,\n", + " device='cuda',\n", + " compile=False,\n", + " skip_safety_check=True,\n", + " use_tf32=True,\n", + " models={},\n", + " ):\n", + " self.device = device\n", + " self.weight_dtype = weight_dtype\n", + " self.skip_safety_check = skip_safety_check\n", + " self.models = models\n", + "\n", + " self.generator = torch.Generator(device=device)\n", + " self.noise_scheduler = DDPMSampler(generator=self.generator)\n", + " # self.vae = AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\").to(device, dtype=weight_dtype)\n", + " self.encoder= models.get('encoder', None)\n", + " self.decoder= models.get('decoder', None)\n", + " \n", + " self.unet=models.get('diffusion', None) \n", + " # # Enable TF32 for faster training on Ampere GPUs (A100 and RTX 30 series).\n", + " if use_tf32:\n", + " torch.set_float32_matmul_precision(\"high\")\n", + " torch.backends.cuda.matmul.allow_tf32 = True\n", + "\n", + " @torch.no_grad()\n", + " def __call__(\n", + " self, \n", + " image: Union[PIL.Image.Image, torch.Tensor],\n", + " condition_image: Union[PIL.Image.Image, torch.Tensor],\n", + " num_inference_steps: int = 50,\n", + " guidance_scale: float = 2.5,\n", + " height: int = 1024,\n", + " width: int = 768,\n", + " generator=None,\n", + " eta=1.0,\n", + " **kwargs\n", + " ):\n", + " concat_dim = -1 # FIXME: y axis concat\n", + " # Prepare inputs to Tensor\n", + " image, condition_image = check_inputs_maskfree(image, condition_image, width, height)\n", + " \n", + " image = prepare_image(image).to(self.device, dtype=self.weight_dtype)\n", + " \n", + " condition_image = prepare_image(condition_image).to(self.device, dtype=self.weight_dtype)\n", + " \n", + " # Encode the image\n", + " image_latent = compute_vae_encodings(image, self.encoder)\n", + " condition_latent = compute_vae_encodings(condition_image, self.encoder)\n", + " \n", + " del image, condition_image\n", + " # Concatenate latents\n", + " # Concatenate latents\n", + " condition_latent_concat = torch.cat([image_latent, condition_latent], dim=concat_dim)\n", + " # Prepare noise\n", + " latents = randn_tensor(\n", + " condition_latent_concat.shape,\n", + " generator=generator,\n", + " device=condition_latent_concat.device,\n", + " dtype=self.weight_dtype,\n", + " )\n", + " # 
Prepare timesteps\n", + " self.noise_scheduler.set_inference_timesteps(num_inference_steps)\n", + " timesteps = self.noise_scheduler.timesteps\n", + " # latents = latents * self.noise_scheduler.init_noise_sigma\n", + " latents = self.noise_scheduler.add_noise(latents, timesteps[0])\n", + " \n", + " # Classifier-Free Guidance\n", + " if do_classifier_free_guidance := (guidance_scale > 1.0):\n", + " condition_latent_concat = torch.cat(\n", + " [\n", + " torch.cat([image_latent, torch.zeros_like(condition_latent)], dim=concat_dim),\n", + " condition_latent_concat,\n", + " ]\n", + " )\n", + "\n", + " num_warmup_steps = 0 # For simple DDPM, no warmup needed\n", + " with tqdm(total=num_inference_steps) as progress_bar:\n", + " for i, t in enumerate(timesteps):\n", + " # expand the latents if we are doing classifier free guidance\n", + " \n", + " latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents)\n", + "\n", + " # prepare the input for the inpainting model\n", + " \n", + " p2p_latent_model_input = torch.cat([latent_model_input, condition_latent_concat], dim=1)\n", + " # predict the noise residual\n", + " \n", + " timestep = t.repeat(p2p_latent_model_input.shape[0])\n", + " time_embedding = get_time_embedding(timestep).to(self.device, dtype=self.weight_dtype)\n", + "\n", + " noise_pred = self.unet(\n", + " p2p_latent_model_input,\n", + " time_embedding\n", + " )\n", + " # perform guidance\n", + " if do_classifier_free_guidance:\n", + " noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n", + " noise_pred = noise_pred_uncond + guidance_scale * (\n", + " noise_pred_text - noise_pred_uncond\n", + " )\n", + " # compute the previous noisy sample x_t -> x_t-1\n", + " latents = self.noise_scheduler.step(\n", + " t, latents, noise_pred\n", + " )\n", + " # call the callback, if provided\n", + " if i == len(timesteps) - 1 or (\n", + " (i + 1) > num_warmup_steps\n", + " ):\n", + " progress_bar.update()\n", + "\n", + " # Decode the final latents\n", + " latents = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]\n", + " # latents = 1 / self.vae.config.scaling_factor * latents\n", + " # image = self.vae.decode(latents.to(self.device, dtype=self.weight_dtype)).sample\n", + " image = self.decoder(latents.to(self.device, dtype=self.weight_dtype))\n", + " image = (image / 2 + 0.5).clamp(0, 1)\n", + " # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n", + " image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n", + " image = numpy_to_pil(image)\n", + " \n", + " return image\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "5627b2d2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset vitonhd loaded, total 20 pairs.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 50/50 [00:07<00:00, 7.12it/s]\n", + "100%|██████████| 50/50 [00:06<00:00, 7.31it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 7.09it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 6.98it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 7.01it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 7.13it/s]\n", + "100%|██████████| 50/50 [00:06<00:00, 7.28it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 7.13it/s]\n", + "100%|██████████| 50/50 [00:06<00:00, 7.17it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 6.97it/s]\n", + "100%|██████████| 50/50 [00:06<00:00, 7.17it/s]\n", + "100%|██████████| 50/50 [00:06<00:00, 
7.38it/s]\n", + "100%|██████████| 50/50 [00:06<00:00, 7.20it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 6.92it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 6.71it/s]\n", + "100%|██████████| 50/50 [00:06<00:00, 7.25it/s]\n", + "100%|██████████| 50/50 [00:06<00:00, 7.49it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 6.87it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 6.89it/s]\n", + "100%|██████████| 50/50 [00:07<00:00, 6.92it/s]\n", + "100%|██████████| 20/20 [02:26<00:00, 7.35s/it]\n" + ] + } + ], + "source": [ + "import os\n", + "import torch\n", + "import argparse\n", + "from torch.utils.data import DataLoader\n", + "from VITON_Dataset import VITONHDTestDataset\n", + "from tqdm import tqdm\n", + "from PIL import Image\n", + "\n", + "from utils import to_pil_image\n", + "\n", + "@torch.no_grad()\n", + "def main():\n", + " args=argparse.Namespace()\n", + " args.__dict__= {\n", + " \"dataset_name\": \"vitonhd\",\n", + " \"data_root_path\": \"./sample_dataset\",\n", + " \"output_dir\": \"./mask-free-output\",\n", + " \"seed\": 555,\n", + " \"batch_size\": 1,\n", + " \"num_inference_steps\": 50,\n", + " \"guidance_scale\": 2.5,\n", + " \"width\": 384,\n", + " \"height\": 512,\n", + " \"eval_pair\": False,\n", + " \"concat_eval_results\": True,\n", + " \"allow_tf32\": True,\n", + " \"dataloader_num_workers\": 4,\n", + " \"mixed_precision\": 'no',\n", + " \"concat_axis\": 'y',\n", + " \"enable_condition_noise\": True,\n", + " \"is_train\": False\n", + " }\n", + "\n", + " # Pipeline\n", + " pipeline = CatVTONPix2PixPipeline(\n", + " weight_dtype={\n", + " \"no\": torch.float32,\n", + " \"fp16\": torch.float16,\n", + " \"bf16\": torch.bfloat16,\n", + " }[args.mixed_precision],\n", + " device=\"cuda\",\n", + " skip_safety_check=True,\n", + " models=models,\n", + " )\n", + " # Dataset\n", + " if args.dataset_name == \"vitonhd\":\n", + " dataset = VITONHDTestDataset(args)\n", + " else:\n", + " raise ValueError(f\"Invalid dataset name {args.dataset}.\")\n", + " print(f\"Dataset {args.dataset_name} loaded, total {len(dataset)} pairs.\")\n", + " dataloader = DataLoader(\n", + " dataset,\n", + " batch_size=args.batch_size,\n", + " shuffle=False,\n", + " num_workers=args.dataloader_num_workers\n", + " )\n", + " \n", + " # Inference\n", + " generator = torch.Generator(device='cuda').manual_seed(args.seed)\n", + " args.output_dir = os.path.join(args.output_dir, f\"{args.dataset_name}-{args.height}\", \"paired\" if args.eval_pair else \"unpaired\")\n", + " if not os.path.exists(args.output_dir):\n", + " os.makedirs(args.output_dir)\n", + " \n", + " for batch in tqdm(dataloader):\n", + " person_images = batch['person']\n", + " cloth_images = batch['cloth']\n", + "\n", + " results = pipeline(\n", + " person_images,\n", + " cloth_images,\n", + " num_inference_steps=args.num_inference_steps,\n", + " guidance_scale=args.guidance_scale,\n", + " height=args.height,\n", + " width=args.width,\n", + " generator=generator,\n", + " )\n", + " \n", + " if args.concat_eval_results:\n", + " person_images = to_pil_image(person_images)\n", + " cloth_images = to_pil_image(cloth_images)\n", + " for i, result in enumerate(results):\n", + " person_name = batch['person_name'][i]\n", + " output_path = os.path.join(args.output_dir, person_name)\n", + " if not os.path.exists(os.path.dirname(output_path)):\n", + " os.makedirs(os.path.dirname(output_path))\n", + " if args.concat_eval_results:\n", + " w, h = result.size\n", + " concated_result = Image.new('RGB', (w*3, h))\n", + " concated_result.paste(person_images[i], (0, 
0))\n", + " concated_result.paste(cloth_images[i], (w, 0)) \n", + " concated_result.paste(result, (w*2, 0))\n", + " result = concated_result\n", + " result.save(output_path)\n", + "\n", + "if __name__ == \"__main__\":\n", + " main()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39537851", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22fb6113", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c374cc6", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bddce5df", + "metadata": { + "vscode": { + "languageId": "markdown" + } + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "harsh", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/output/vitonhd-512/unpaired/00654_00.jpg b/output/vitonhd-512/unpaired/00654_00.jpg deleted file mode 100644 index beea9eae3f8bc99b552a2d490f250cfedb03bc73..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/00654_00.jpg and /dev/null differ diff --git a/output/vitonhd-512/unpaired/01265_00.jpg b/output/vitonhd-512/unpaired/01265_00.jpg deleted file mode 100644 index 82c42b72e955cbdc3be36b874aa63a8dff2f0383..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/01265_00.jpg and /dev/null differ diff --git a/output/vitonhd-512/unpaired/01985_00.jpg b/output/vitonhd-512/unpaired/01985_00.jpg deleted file mode 100644 index 33d1b52d4b9c371e830bec1315acbd1e00cd71af..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/01985_00.jpg and /dev/null differ diff --git a/output/vitonhd-512/unpaired/02023_00.jpg b/output/vitonhd-512/unpaired/02023_00.jpg deleted file mode 100644 index 4b9c91de6ecfbe21877932bfaacf4b0ad34155d2..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/02023_00.jpg and /dev/null differ diff --git a/output/vitonhd-512/unpaired/02532_00.jpg b/output/vitonhd-512/unpaired/02532_00.jpg deleted file mode 100644 index d98925454eb55ca11f2346c13b21728004b54083..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/02532_00.jpg and /dev/null differ diff --git a/output/vitonhd-512/unpaired/02944_00.jpg b/output/vitonhd-512/unpaired/02944_00.jpg deleted file mode 100644 index c72ae4a41746ae43ac5b7699bf3cfa2ff1b5b22e..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/02944_00.jpg and /dev/null differ diff --git a/output/vitonhd-512/unpaired/03191_00.jpg b/output/vitonhd-512/unpaired/03191_00.jpg deleted file mode 100644 index 0c7132663301f60ceccce0f3345a94ee53d990ce..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/03191_00.jpg and /dev/null differ diff --git a/output/vitonhd-512/unpaired/03921_00.jpg b/output/vitonhd-512/unpaired/03921_00.jpg deleted file mode 100644 index 40c20bb982c4fad8c456d09967377d3f8aa72e85..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/03921_00.jpg and /dev/null differ diff --git 
a/output/vitonhd-512/unpaired/05006_00.jpg b/output/vitonhd-512/unpaired/05006_00.jpg deleted file mode 100644 index f64062f354191a62c470516f008851e3272a70a4..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/05006_00.jpg and /dev/null differ diff --git a/output/vitonhd-512/unpaired/05378_00.jpg b/output/vitonhd-512/unpaired/05378_00.jpg deleted file mode 100644 index 65f81e92a0434564627b80f8c10d610eace26813..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/05378_00.jpg and /dev/null differ diff --git a/output/vitonhd-512/unpaired/07342_00.jpg b/output/vitonhd-512/unpaired/07342_00.jpg deleted file mode 100644 index fdfdcfa7c1b694d52102cb0f0fab0301a2d5bd74..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/07342_00.jpg and /dev/null differ diff --git a/output/vitonhd-512/unpaired/08088_00.jpg b/output/vitonhd-512/unpaired/08088_00.jpg deleted file mode 100644 index 4a904a651d1bc42bf6396499645ef24bfbff145f..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/08088_00.jpg and /dev/null differ diff --git a/output/vitonhd-512/unpaired/08239_00.jpg b/output/vitonhd-512/unpaired/08239_00.jpg deleted file mode 100644 index 5111327812216399b3d707e6296fa645f217d489..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/08239_00.jpg and /dev/null differ diff --git a/output/vitonhd-512/unpaired/08650_00.jpg b/output/vitonhd-512/unpaired/08650_00.jpg deleted file mode 100644 index aee30d8cabe59bacd98ed946f6a658c887e70a1f..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/08650_00.jpg and /dev/null differ diff --git a/output/vitonhd-512/unpaired/08839_00.jpg b/output/vitonhd-512/unpaired/08839_00.jpg deleted file mode 100644 index 21dc8c1fa420d984a8ba461744a39887cf3336cd..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/08839_00.jpg and /dev/null differ diff --git a/output/vitonhd-512/unpaired/11085_00.jpg b/output/vitonhd-512/unpaired/11085_00.jpg deleted file mode 100644 index 9937e679570c67d1b5f448d078399c0b492553f2..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/11085_00.jpg and /dev/null differ diff --git a/output/vitonhd-512/unpaired/12345_00.jpg b/output/vitonhd-512/unpaired/12345_00.jpg deleted file mode 100644 index fa9d559d97feedf8492339dee79aa57c2b390646..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/12345_00.jpg and /dev/null differ diff --git a/output/vitonhd-512/unpaired/12419_00.jpg b/output/vitonhd-512/unpaired/12419_00.jpg deleted file mode 100644 index 06417a7a0bfbaa69f8381deb53febfe11baa7265..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/12419_00.jpg and /dev/null differ diff --git a/output/vitonhd-512/unpaired/12562_00.jpg b/output/vitonhd-512/unpaired/12562_00.jpg deleted file mode 100644 index dbcfb7a75ab6036f40cba5a5fbf40819f7b54553..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/12562_00.jpg and /dev/null differ diff --git a/output/vitonhd-512/unpaired/14651_00.jpg b/output/vitonhd-512/unpaired/14651_00.jpg deleted file mode 100644 index 40c939137e392a1f3fc8bbb2bc7484b0691065ff..0000000000000000000000000000000000000000 Binary files a/output/vitonhd-512/unpaired/14651_00.jpg and /dev/null differ diff --git a/requirements.txt b/requirements.txt index 
e2aa599dac9716c009c00e7885d7d2ccfdd4cf81..e1602c3e2746793c384cf9232d445702e3520dd5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,90 +1,44 @@ -accelerate==1.9.0 -asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work -beautifulsoup4==4.13.4 -certifi==2025.7.14 -charset-normalizer==3.4.2 -comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work -contourpy==1.3.2 -cycler==0.12.1 -debugpy @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_debugpy_1752827112/work -decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work -diffusers==0.34.0 -exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1746947292760/work -executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1745502089858/work -filelock==3.18.0 -fonttools==4.59.0 -fsspec==2025.7.0 -gdown==5.2.0 -hf-xet==1.1.5 -huggingface-hub==0.33.4 -idna==3.10 -importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_importlib-metadata_1747934053/work -ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work -ipython @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_ipython_1748711175/work -jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work -Jinja2==3.1.6 -jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work -jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1748333051527/work -kagglehub==0.3.12 -kiwisolver==1.4.8 -MarkupSafe==3.0.2 -matplotlib==3.10.3 -matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work -mpmath==1.3.0 -nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work -networkx==3.4.2 -numpy==2.2.6 -nvidia-cublas-cu12==12.6.4.1 -nvidia-cuda-cupti-cu12==12.6.80 -nvidia-cuda-nvrtc-cu12==12.6.77 -nvidia-cuda-runtime-cu12==12.6.77 -nvidia-cudnn-cu12==9.5.1.17 -nvidia-cufft-cu12==11.3.0.4 -nvidia-cufile-cu12==1.11.1.6 -nvidia-curand-cu12==10.3.7.77 -nvidia-cusolver-cu12==11.7.1.2 -nvidia-cusparse-cu12==12.5.4.2 -nvidia-cusparselt-cu12==0.6.3 -nvidia-nccl-cu12==2.26.2 -nvidia-nvjitlink-cu12==12.6.85 -nvidia-nvtx-cu12==12.6.77 -packaging @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_packaging_1745345660/work -pandas==2.3.1 -parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work -pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1733301927746/work -pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work -pillow==11.3.0 -platformdirs @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_platformdirs_1746710438/work -prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1744724089886/work -psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1740663128538/work -ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1733302279685/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=92c32ff62b5fd8cf325bec5ab90d7be3d2a8ca8c8a3813ff487a8d2002630d1f -pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work -Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1750615794071/work -pyparsing==3.2.3 -PySocks==1.7.1 -python-dateutil @ 
file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_python-dateutil_1751104122/work -pytz==2025.2 -PyYAML==6.0.2 -pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1749898457097/work -regex==2024.11.6 -requests==2.32.4 -safetensors==0.5.3 -six @ file:///home/conda/feedstock_root/build_artifacts/six_1733380938961/work -soupsieve==2.7 -stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work -sympy==1.14.0 -tokenizers==0.21.2 -torch==2.7.1 -torchsummary==1.5.1 -torchvision==0.22.1 -tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1748003301700/work -tqdm==4.67.1 -traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work -transformers==4.53.2 -triton==3.3.1 -typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1751643513/work -tzdata==2025.2 -unzip==1.0.0 -urllib3==2.5.0 -wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1733231326287/work -zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1749421620841/work +usage: conda [-h] [-v] [--no-plugins] [-V] COMMAND ... + +conda is a tool for managing and deploying applications, environments and packages. + +options: + -h, --help Show this help message and exit. + -v, --verbose Can be used multiple times. Once for detailed output, + twice for INFO logging, thrice for DEBUG logging, four + times for TRACE logging. + --no-plugins Disable all plugins that are not built into conda. + -V, --version Show the conda version number and exit. + +commands: + The following built-in and plugins subcommands are available. + + COMMAND + activate Activate a conda environment. + clean Remove unused packages and caches. + commands List all available conda subcommands (including those + from plugins). Generally only used by tab-completion. + compare Compare packages between conda environments. + config Modify configuration values in .condarc. + content-trust Signing and verification tools for Conda + create Create a new conda environment from a list of specified + packages. + deactivate Deactivate the current active conda environment. + doctor Display a health report for your environment. + export Export a given environment + info Display information about current conda install. + init Initialize conda for shell interaction. + install Install a list of packages into a specified conda + environment. + list List installed packages in a conda environment. + notices Retrieve latest channel notifications. + package Create low-level conda packages. (EXPERIMENTAL) + remove (uninstall) + Remove a list of packages from a specified conda + environment. + rename Rename an existing environment. + repoquery Advanced search for repodata. + run Run an executable in a conda environment. + search Search for packages and display associated information + using the MatchSpec format. + update (upgrade) Update conda packages to the latest compatible version. 
diff --git a/utils.py b/utils.py index 72f397a47368b5313df8f66019c30020a600573b..2ae520dc3d861db006d05a27fcb3ba7409bdd60d 100644 --- a/utils.py +++ b/utils.py @@ -20,7 +20,7 @@ def get_time_embedding(timesteps): timesteps = timesteps.unsqueeze(0) # Shape: (160,) - freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160) + freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32, device=timesteps.device) / 160) # Shape: (B, 160) x = timesteps.float()[:, None] * freqs[None] # Shape: (B, 320) @@ -101,6 +101,12 @@ def check_inputs(image, condition_image, mask, width, height): condition_image = resize_and_padding(condition_image, (width, height)) return image, condition_image, mask +def check_inputs_maskfree(image, condition_image, width, height): + if isinstance(image, torch.Tensor) and isinstance(condition_image, torch.Tensor): + return image, condition_image + image = resize_and_crop(image, (width, height)) + condition_image = resize_and_padding(condition_image, (width, height)) + return image, condition_image def repaint_result(result, person_image, mask_image): result, person, mask = np.array(result), np.array(person_image), np.array(mask_image)
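
Note (illustration, not part of the patch): the utils.py hunk above only moves the frequency table onto timesteps.device so CPU-only runs avoid a device mismatch. For context, a minimal sketch of the whole embedding is below; the final cos/sin concatenation is inferred from the existing "(B, 320)" shape comment rather than shown in this hunk.

    import torch

    def get_time_embedding(timesteps: torch.Tensor) -> torch.Tensor:
        # Accept a 0-d timestep as well as a batch of timesteps.
        if timesteps.dim() == 0:
            timesteps = timesteps.unsqueeze(0)
        # (160,) frequency table, built on the same device as the timesteps.
        freqs = torch.pow(
            10000,
            -torch.arange(start=0, end=160, dtype=torch.float32, device=timesteps.device) / 160,
        )
        x = timesteps.float()[:, None] * freqs[None]            # (B, 160)
        return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)  # (B, 320), cos half then sin half (assumed)

    # emb = get_time_embedding(torch.tensor([999, 500]))  -> torch.Size([2, 320])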
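Note (illustration, not part of the patch): the load_model.py change selects 8 UNet input channels for the mask-free (instruct-pix2pix) backbone and keeps 9 for the mask-based (sd-v1-5-inpainting) one. A rough sketch of the channel bookkeeping, assuming the 512x384 resolution used in the notebooks and the width-axis person|garment concatenation from the pipeline:

    import torch

    batch = 1
    # 512x384 inputs give 64x48 latents; person and garment latents are
    # concatenated along the width axis (concat_dim = -1), hence 64x96.
    noisy_latents     = torch.randn(batch, 4, 64, 96)
    condition_latents = torch.randn(batch, 4, 64, 96)

    # Mask-free path (instruct-pix2pix backbone): 4 + 4 = 8 input channels.
    p2p_input = torch.cat([noisy_latents, condition_latents], dim=1)
    assert p2p_input.shape[1] == 8

    # Mask-based path (sd-v1-5-inpainting backbone) additionally takes a
    # 1-channel mask alongside the masked-image latents: 4 + 1 + 4 = 9.
    mask = torch.zeros(batch, 1, 64, 96)
    inpaint_input = torch.cat([noisy_latents, mask, condition_latents], dim=1)
    assert inpaint_input.shape[1] == 9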
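Note (illustration, not part of the patch): both the app.py pipeline and the notebook pipeline apply classifier-free guidance by stacking an unconditional and a conditional batch and blending the two noise predictions. A self-contained sketch of that blend (shapes are illustrative only):

    import torch

    def apply_cfg(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
        # The batch holds the unconditional half first, then the conditional half,
        # matching torch.cat([latents] * 2) earlier in the sampling loop.
        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
        return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

    guided = apply_cfg(torch.randn(2, 4, 64, 96), guidance_scale=2.5)
    print(guided.shape)  # torch.Size([1, 4, 64, 96])

With guidance_scale > 1.0 the result is pushed toward the conditional prediction, which is why the Gradio slider treats higher values as "stronger conditioning".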