diff --git a/.gitignore b/.gitignore index 4a7c07ef837da1f33e9829dab9c9b42d9fe28f75..7d5c8833df2de7411771e3b3588bfe7180cfd019 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,12 @@ *sd-v1-5-inpainting.ckpt *zalando-hd-resized.zip +# *viton-hd-dataset.zip +viton-hd-dataset/ +checkpoints/ + +*finetuned_weights.safetensors + # Byte-compiled / optimized / DLL files __pycache__/ **/__pycache__/ diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000000000000000000000000000000000..52cc640f4533906ddecc96feaf43bb3d3aa88522 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "subProcess": true + } + ] +} \ No newline at end of file diff --git a/interface.py b/interface.py deleted file mode 100644 index 66d8dcf358c25ce43d6a95fedeaafb00b09f7f39..0000000000000000000000000000000000000000 --- a/interface.py +++ /dev/null @@ -1,151 +0,0 @@ -import gradio as gr -import torch -from PIL import Image -from transformers import CLIPTokenizer - -# Import your existing model and pipeline modules -import load_model -import pipeline - -# Device Configuration -ALLOW_CUDA = True -ALLOW_MPS = False - -def determine_device(): - if torch.cuda.is_available() and ALLOW_CUDA: - return "cuda" - elif (torch.backends.mps.is_built() or torch.backends.mps.is_available()) and ALLOW_MPS: - return "mps" - return "cpu" - -DEVICE = determine_device() -print(f"Using device: {DEVICE}") - -# Load tokenizer and models -tokenizer = CLIPTokenizer("vocab.json", merges_file="merges.txt") -model_file = "inkpunk-diffusion-v1.ckpt" -models = load_model.preload_models_from_standard_weights(model_file, DEVICE) -# models=None - -def generate_image( - prompt, - uncond_prompt="", - do_cfg=True, - cfg_scale=8, - sampler="ddpm", - num_inference_steps=50, - seed=42, - input_image=None, - strength=1.0 -): - """ - Generate an image using the Stable Diffusion pipeline - - Args: - - prompt (str): Text description of the image to generate - - uncond_prompt (str, optional): Negative prompt to guide generation - - do_cfg (bool): Whether to use classifier-free guidance - - cfg_scale (float): Classifier-free guidance scale - - sampler (str): Sampling method - - num_inference_steps (int): Number of denoising steps - - seed (int): Random seed for reproducibility - - input_image (PIL.Image, optional): Input image for image-to-image generation - - strength (float): Strength of image transformation (0-1) - - Returns: - - PIL.Image: Generated image - """ - try: - # Ensure input_image is None if not provided - if input_image is None: - strength = 1.0 - - # Generate the image - output_image = pipeline.generate( - prompt=prompt, - uncond_prompt=uncond_prompt, - input_image=input_image, - strength=strength, - do_cfg=do_cfg, - cfg_scale=cfg_scale, - sampler_name=sampler, - n_inference_steps=num_inference_steps, - seed=seed, - models=models, - device=DEVICE, - idle_device="cuda", - tokenizer=tokenizer, - ) - - # Convert numpy array to PIL Image - return Image.fromarray(output_image) - - except Exception as e: - print(f"Error generating image: {e}") - return None - -def launch_gradio_interface(): - """ - Create and launch Gradio interface 
for Stable Diffusion - """ - with gr.Blocks(title="Stable Diffusion Image Generator") as demo: - gr.Markdown("# 🎨 Stable Diffusion Image Generator") - - with gr.Row(): - with gr.Column(): - # Text Inputs - prompt = gr.Textbox(label="Prompt", - placeholder="Describe the image you want to generate...") - uncond_prompt = gr.Textbox(label="Negative Prompt (Optional)", - placeholder="Describe what you don't want in the image...") - - # Generation Parameters - with gr.Accordion("Advanced Settings", open=False): - do_cfg = gr.Checkbox(label="Use Classifier-Free Guidance", value=True) - cfg_scale = gr.Slider(minimum=1, maximum=14, value=8, label="CFG Scale") - sampler = gr.Dropdown( - choices=["ddpm", "ddim", "pndm"], # Add more samplers if available - value="ddpm", - label="Sampling Method" - ) - num_inference_steps = gr.Slider( - minimum=10, - maximum=100, - value=50, - label="Number of Inference Steps" - ) - seed = gr.Number(value=42, label="Random Seed") - - # Image-to-Image Section - with gr.Accordion("Image-to-Image", open=False): - input_image = gr.Image(type="pil", label="Input Image (Optional)") - strength = gr.Slider( - minimum=0, - maximum=1, - value=0.8, - label="Image Transformation Strength" - ) - - # Generate Button - generate_btn = gr.Button("Generate Image", variant="primary") - - with gr.Row(): - # Output Image - output_image = gr.Image(label="Generated Image") - - # Connect Button to Generation Function - generate_btn.click( - fn=generate_image, - inputs=[ - prompt, uncond_prompt, do_cfg, cfg_scale, - sampler, num_inference_steps, seed, - input_image, strength - ], - outputs=output_image - ) - - # Launch the interface - demo.launch(server_name="0.0.0.0", server_port=7860) - -if __name__ == "__main__": - launch_gradio_interface() \ No newline at end of file diff --git a/load_model.py b/load_model.py index cbdc7af95bdd44847a669f77cd1bc45e040735c3..a78123c3afff8e45e8cbb34643c85bf837a09788 100644 --- a/load_model.py +++ b/load_model.py @@ -1,11 +1,81 @@ from clip import CLIP from encoder import VAE_Encoder from decoder import VAE_Decoder -from diffusion import Diffusion - +from diffusion import Diffusion, UNET_AttentionBlock +from safetensors.torch import load_file import model_converter import torch +def load_finetuned_attention_weights(finetune_weights_path, diffusion, device): + updated_loaded_data = load_file(finetune_weights_path, device=device) + print(f"Loaded finetuned weights from {finetune_weights_path}") + + unet= diffusion.unet + idx = 0 + # Iterate through the attention layers in the encoders + for layers in unet.encoders: + for layer in layers: + if isinstance(layer, UNET_AttentionBlock): + # Get the parameters from the loaded data for this block + in_proj_weight_key = f"{idx}.in_proj.weight" + out_proj_weight_key = f"{idx}.out_proj.weight" + out_proj_bias_key = f"{idx}.out_proj.bias" + + # Load the weights if they exist in the loaded data + if in_proj_weight_key in updated_loaded_data: + print(f"Loading {in_proj_weight_key}") + layer.attention_1.in_proj.weight.data.copy_(updated_loaded_data[in_proj_weight_key]) + if out_proj_weight_key in updated_loaded_data: + print(f"Loading {out_proj_weight_key}") + layer.attention_1.out_proj.weight.data.copy_(updated_loaded_data[out_proj_weight_key]) + if out_proj_bias_key in updated_loaded_data: + print(f"Loading {out_proj_bias_key}") + layer.attention_1.out_proj.bias.data.copy_(updated_loaded_data[out_proj_bias_key]) + idx += 8 + + # Move to the next attention block index in the loaded data + + + # Iterate through the attention 
layers in the decoders + for layers in unet.decoders: + for layer in layers: + if isinstance(layer, UNET_AttentionBlock): + in_proj_weight_key = f"{idx}.in_proj.weight" + out_proj_weight_key = f"{idx}.out_proj.weight" + out_proj_bias_key = f"{idx}.out_proj.bias" + + if in_proj_weight_key in updated_loaded_data: + print(f"Loading {in_proj_weight_key}") + layer.attention_1.in_proj.weight.data.copy_(updated_loaded_data[in_proj_weight_key]) + if out_proj_weight_key in updated_loaded_data: + print(f"Loading {out_proj_weight_key}") + layer.attention_1.out_proj.weight.data.copy_(updated_loaded_data[out_proj_weight_key]) + if out_proj_bias_key in updated_loaded_data: + print(f"Loading {out_proj_bias_key}") + layer.attention_1.out_proj.bias.data.copy_(updated_loaded_data[out_proj_bias_key]) + idx += 8 + + + # Iterate through the attention layers in the bottleneck + for layer in unet.bottleneck: + if isinstance(layer, UNET_AttentionBlock): + in_proj_weight_key = f"{idx}.in_proj.weight" + out_proj_weight_key = f"{idx}.out_proj.weight" + out_proj_bias_key = f"{idx}.out_proj.bias" + + if in_proj_weight_key in updated_loaded_data: + print(f"Loading {in_proj_weight_key}") + layer.attention_1.in_proj.weight.data.copy_(updated_loaded_data[in_proj_weight_key]) + if out_proj_weight_key in updated_loaded_data: + print(f"Loading {out_proj_weight_key}") + layer.attention_1.out_proj.weight.data.copy_(updated_loaded_data[out_proj_weight_key]) + if out_proj_bias_key in updated_loaded_data: + print(f"Loading {out_proj_bias_key}") + layer.attention_1.out_proj.bias.data.copy_(updated_loaded_data[out_proj_bias_key]) + idx += 8 + + print("\nAttention module weights loaded from {finetune_weights_path} successfully.") + def preload_models_from_standard_weights(ckpt_path, device, finetune_weights_path=None): # CatVTON parameters in_channels = 9 @@ -14,12 +84,10 @@ def preload_models_from_standard_weights(ckpt_path, device, finetune_weights_pat state_dict=model_converter.load_from_standard_weights(ckpt_path, device) diffusion=Diffusion(in_channels=in_channels, out_channels=out_channels).to(device) - + diffusion.load_state_dict(state_dict['diffusion'], strict=True) + if finetune_weights_path != None: - checkpoint = torch.load(finetune_weights_path, map_location=device) - diffusion.load_state_dict(checkpoint['diffusion_state_dict'], strict=True) - else: - diffusion.load_state_dict(state_dict['diffusion'], strict=True) + load_finetuned_attention_weights(finetune_weights_path, diffusion, device) encoder=VAE_Encoder().to(device) encoder.load_state_dict(state_dict['encoder'], strict=True) diff --git a/logs.txt b/logs.txt new file mode 100644 index 0000000000000000000000000000000000000000..dc165a120a0dae19bda48279019baa4f5fff16bf --- /dev/null +++ b/logs.txt @@ -0,0 +1,29 @@ +/home/mahesh/harsh/stable-diffusion/training.py:84: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead. + self.scaler = torch.cuda.amp.GradScaler() +---------------------------------------------------------------------------------------------------- +Loading pretrained models... +Models loaded successfully. +---------------------------------------------------------------------------------------------------- +Creating dataloader... +Dataset vitonhd loaded, total 11647 pairs. +Training for 50 epochs +Batches per epoch: 5824 +---------------------------------------------------------------------------------------------------- +Initializing trainer... 
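The three loops in load_finetuned_attention_weights above differ only in the container they walk (unet.encoders, unet.decoders, unet.bottleneck). Below is a minimal sketch of how that logic could be factored into a single pass; the helper name is ours, while the flat key layout (indices striding by 8, in_proj/out_proj parameters on attention_1) mirrors the diff.

from diffusion import UNET_AttentionBlock
from safetensors.torch import load_file

def copy_attention_weights(finetune_weights_path, diffusion, device):
    # Hypothetical refactor of load_finetuned_attention_weights: one pass over
    # encoders -> decoders -> bottleneck, same key scheme and stride as the diff.
    loaded = load_file(finetune_weights_path, device=device)
    unet = diffusion.unet
    layers = [layer for block in (*unet.encoders, *unet.decoders) for layer in block]
    layers += list(unet.bottleneck)
    idx = 0
    for layer in layers:
        if not isinstance(layer, UNET_AttentionBlock):
            continue
        for suffix, param in (
            ("in_proj.weight", layer.attention_1.in_proj.weight),
            ("out_proj.weight", layer.attention_1.out_proj.weight),
            ("out_proj.bias", layer.attention_1.out_proj.bias),
        ):
            key = f"{idx}.{suffix}"
            if key in loaded:
                param.data.copy_(loaded[key])
        idx += 8  # each attention block advances the flat index by 8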
+Enabling PEFT training (self-attention layers only) +Total parameters: 899,226,667 +Trainable parameters: 49,574,080 (5.51%) +Checkpoint loaded: ./checkpoints/checkpoint_step_50000.pth +Resuming from epoch 13, step 50000 +Starting training... +Starting training for 50 epochs +Total training batches per epoch: 5824 +Using DREAM with lambda = 0 +Mixed precision: True +/home/mahesh/harsh/stable-diffusion/training.py:304: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead. + with torch.cuda.amp.autocast(): +/home/mahesh/harsh/stable-diffusion/training.py:194: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead. + with torch.cuda.amp.autocast(enabled=self.use_mixed_precision): +/home/mahesh/harsh/stable-diffusion/utils.py:491: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead. + with torch.cuda.amp.autocast(enabled=False): + \ No newline at end of file diff --git a/output/vitonhd-512/unpaired/00654_00.jpg b/output/vitonhd-512/unpaired/00654_00.jpg index a7671c5ab2ce1c9eb42844a9f63284b20f1f8628..7e27d95ae152cd3f624677b76b6e2eaf7f75b781 100644 Binary files a/output/vitonhd-512/unpaired/00654_00.jpg and b/output/vitonhd-512/unpaired/00654_00.jpg differ diff --git a/output/vitonhd-512/unpaired/01265_00.jpg b/output/vitonhd-512/unpaired/01265_00.jpg index 4675777a22e8768d4a6a1c8915fddb90515f9f99..3fda22c37d6e23414866bc3da33c9dc41b17ebc6 100644 Binary files a/output/vitonhd-512/unpaired/01265_00.jpg and b/output/vitonhd-512/unpaired/01265_00.jpg differ diff --git a/output/vitonhd-512/unpaired/01985_00.jpg b/output/vitonhd-512/unpaired/01985_00.jpg index 7836c09f2f3c7387839c3a9c72e41941313e5f90..e55704b7621c5ab1c0f5c143cc9fe1014f0ad8c3 100644 Binary files a/output/vitonhd-512/unpaired/01985_00.jpg and b/output/vitonhd-512/unpaired/01985_00.jpg differ diff --git a/output/vitonhd-512/unpaired/02023_00.jpg b/output/vitonhd-512/unpaired/02023_00.jpg index 16fb43aa58c56493cea88110b9ca35e28a353d1e..942d19f37f74abd30467fb33c280c084259106be 100644 Binary files a/output/vitonhd-512/unpaired/02023_00.jpg and b/output/vitonhd-512/unpaired/02023_00.jpg differ diff --git a/output/vitonhd-512/unpaired/02532_00.jpg b/output/vitonhd-512/unpaired/02532_00.jpg index 09cdd883d610b405aba1f5dcb42fc3b5aca410b3..9838fb9cab9d9ed6de355e851bb22f5916bae963 100644 Binary files a/output/vitonhd-512/unpaired/02532_00.jpg and b/output/vitonhd-512/unpaired/02532_00.jpg differ diff --git a/output/vitonhd-512/unpaired/02944_00.jpg b/output/vitonhd-512/unpaired/02944_00.jpg index 808a386ea3d7ab1974d76fd2b4a6f05f3d5657a5..f53fdb026f4227fe00e1cc5e53cc0f8aaea721ba 100644 Binary files a/output/vitonhd-512/unpaired/02944_00.jpg and b/output/vitonhd-512/unpaired/02944_00.jpg differ diff --git a/output/vitonhd-512/unpaired/03191_00.jpg b/output/vitonhd-512/unpaired/03191_00.jpg index 1770687fe2b3caf890ef9199cbf7d73b1df0a025..2c3148d3d9aa7eca91ccedd8a2c5471e234c6b79 100644 Binary files a/output/vitonhd-512/unpaired/03191_00.jpg and b/output/vitonhd-512/unpaired/03191_00.jpg differ diff --git a/output/vitonhd-512/unpaired/03921_00.jpg b/output/vitonhd-512/unpaired/03921_00.jpg index 47295a276462a99a99e76b9db7647e9099b1a0b3..52c09f7d541a9939692fcf0df5d07aca27c1e0c6 100644 Binary files a/output/vitonhd-512/unpaired/03921_00.jpg and b/output/vitonhd-512/unpaired/03921_00.jpg differ diff --git a/output/vitonhd-512/unpaired/05006_00.jpg 
b/output/vitonhd-512/unpaired/05006_00.jpg index 346fdb42d30fcb23bc6d720d8206f06e71c067b9..754e1c7d497ab7413ea27ef39dd13b13dd414956 100644 Binary files a/output/vitonhd-512/unpaired/05006_00.jpg and b/output/vitonhd-512/unpaired/05006_00.jpg differ diff --git a/output/vitonhd-512/unpaired/05378_00.jpg b/output/vitonhd-512/unpaired/05378_00.jpg index 0f36112940927444e101f1ce0d8b2452943e03c4..c7f501df0478ac4bb2d28e57c46dca457585f307 100644 Binary files a/output/vitonhd-512/unpaired/05378_00.jpg and b/output/vitonhd-512/unpaired/05378_00.jpg differ diff --git a/output/vitonhd-512/unpaired/07342_00.jpg b/output/vitonhd-512/unpaired/07342_00.jpg index cbe8753f54a35e2b507b0aacff1c543e62b85365..13230cde259b03182e2db93c3e092722e99bc4ee 100644 Binary files a/output/vitonhd-512/unpaired/07342_00.jpg and b/output/vitonhd-512/unpaired/07342_00.jpg differ diff --git a/output/vitonhd-512/unpaired/08088_00.jpg b/output/vitonhd-512/unpaired/08088_00.jpg index 066ec2b18dea80babe056cdb85a178ef25ea7557..2742cdecefe10dac2f2d48206477bc8ba414c827 100644 Binary files a/output/vitonhd-512/unpaired/08088_00.jpg and b/output/vitonhd-512/unpaired/08088_00.jpg differ diff --git a/output/vitonhd-512/unpaired/08239_00.jpg b/output/vitonhd-512/unpaired/08239_00.jpg index ddc7217626ad877957835455592bdb119d0d8a8b..84f46a4b9467c27016c75768051eb3464df32dfc 100644 Binary files a/output/vitonhd-512/unpaired/08239_00.jpg and b/output/vitonhd-512/unpaired/08239_00.jpg differ diff --git a/output/vitonhd-512/unpaired/08650_00.jpg b/output/vitonhd-512/unpaired/08650_00.jpg index d80ae16e44107a6ac017b63cc6fc42229b50555a..2fb43ef527b2559095dd51d64e9b0be50ea843c6 100644 Binary files a/output/vitonhd-512/unpaired/08650_00.jpg and b/output/vitonhd-512/unpaired/08650_00.jpg differ diff --git a/output/vitonhd-512/unpaired/08839_00.jpg b/output/vitonhd-512/unpaired/08839_00.jpg index be5a0fa97aa7bc39d6571e8acd0175eec58ccf6d..0b3f443b20f229ef108405e434c88f3bcc6ad2f4 100644 Binary files a/output/vitonhd-512/unpaired/08839_00.jpg and b/output/vitonhd-512/unpaired/08839_00.jpg differ diff --git a/output/vitonhd-512/unpaired/11085_00.jpg b/output/vitonhd-512/unpaired/11085_00.jpg index 959327da21600b92e60fee3264a41a65bdc6346b..ae15133748e12d8245b918c12a0edc7666f05a9a 100644 Binary files a/output/vitonhd-512/unpaired/11085_00.jpg and b/output/vitonhd-512/unpaired/11085_00.jpg differ diff --git a/output/vitonhd-512/unpaired/12345_00.jpg b/output/vitonhd-512/unpaired/12345_00.jpg index 85b29984f61932e80065f09b79946e1320c142c5..7701349fc341ef36ec75974579fe2499d7a8333d 100644 Binary files a/output/vitonhd-512/unpaired/12345_00.jpg and b/output/vitonhd-512/unpaired/12345_00.jpg differ diff --git a/output/vitonhd-512/unpaired/12419_00.jpg b/output/vitonhd-512/unpaired/12419_00.jpg index 49e1ce6972282da408dcea6fc1fa40f7a1cb67f5..46b2c296ee025b4a26359bab04abca2c1f87d2ef 100644 Binary files a/output/vitonhd-512/unpaired/12419_00.jpg and b/output/vitonhd-512/unpaired/12419_00.jpg differ diff --git a/output/vitonhd-512/unpaired/12562_00.jpg b/output/vitonhd-512/unpaired/12562_00.jpg index 0eb8da4cd84c82d19ead5898c1e644ffce38e4a9..221edd200d578170e97f865dde2f5c5245110831 100644 Binary files a/output/vitonhd-512/unpaired/12562_00.jpg and b/output/vitonhd-512/unpaired/12562_00.jpg differ diff --git a/output/vitonhd-512/unpaired/14651_00.jpg b/output/vitonhd-512/unpaired/14651_00.jpg index 67cb3f1d3393928b71a93461ef9264eff182863f..8d6bba121b05fe62e2577cae424f359e465a66ef 100644 Binary files a/output/vitonhd-512/unpaired/14651_00.jpg and 
b/output/vitonhd-512/unpaired/14651_00.jpg differ diff --git a/pipeline.py b/pipeline.py deleted file mode 100644 index 9d0f07b701aeaa8976bbd92ce8e0c2241e467c42..0000000000000000000000000000000000000000 --- a/pipeline.py +++ /dev/null @@ -1,314 +0,0 @@ -import math -from typing import List, Union -import PIL -import torch -import numpy as np -from tqdm import tqdm -from ddpm import DDPMSampler -from PIL import Image -import load_model -from utils import check_inputs, prepare_image, prepare_mask_image - -WIDTH = 512 -HEIGHT = 512 -LATENTS_WIDTH = WIDTH // 8 -LATENTS_HEIGHT = HEIGHT // 8 - -def repaint_result(result, person_image, mask_image): - result, person, mask = np.array(result), np.array(person_image), np.array(mask_image) - # expand the mask to 3 channels & to 0~1 - mask = np.expand_dims(mask, axis=2) - mask = mask / 255.0 - # mask for result, ~mask for person - result_ = result * mask + person * (1 - mask) - return Image.fromarray(result_.astype(np.uint8)) - -def numpy_to_pil(images): - """ - Convert a numpy image or a batch of images to a PIL image. - """ - if images.ndim == 3: - images = images[None, ...] - images = (images * 255).round().astype("uint8") - if images.shape[-1] == 1: - # special case for grayscale (single channel) images - pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] - else: - pil_images = [Image.fromarray(image) for image in images] - - return pil_images - - -def tensor_to_image(tensor: torch.Tensor): - """ - Converts a torch tensor to PIL Image. - """ - assert tensor.dim() == 3, "Input tensor should be 3-dimensional." - assert tensor.dtype == torch.float32, "Input tensor should be float32." - assert ( - tensor.min() >= 0 and tensor.max() <= 1 - ), "Input tensor should be in range [0, 1]." - tensor = tensor.cpu() - tensor = tensor * 255 - tensor = tensor.permute(1, 2, 0) - tensor = tensor.numpy().astype(np.uint8) - image = Image.fromarray(tensor) - return image - - -def concat_images(images: List[Image.Image], divider: int = 4, cols: int = 4): - """ - Concatenates images horizontally and with - """ - widths = [image.size[0] for image in images] - heights = [image.size[1] for image in images] - total_width = cols * max(widths) - total_width += divider * (cols - 1) - # `col` images each row - rows = math.ceil(len(images) / cols) - total_height = max(heights) * rows - # add divider between rows - total_height += divider * (len(heights) // cols - 1) - - # all black image - concat_image = Image.new("RGB", (total_width, total_height), (0, 0, 0)) - - x_offset = 0 - y_offset = 0 - for i, image in enumerate(images): - concat_image.paste(image, (x_offset, y_offset)) - x_offset += image.size[0] + divider - if (i + 1) % cols == 0: - x_offset = 0 - y_offset += image.size[1] + divider - - return concat_image - -def compute_vae_encodings(image_tensor, encoder, device): - """Encode image using VAE encoder""" - # Generate random noise for encoding - encoder_noise = torch.randn( - (image_tensor.shape[0], 4, image_tensor.shape[2] // 8, image_tensor.shape[3] // 8), - device=device, - ) - - # Encode using your custom encoder - latent = encoder(image_tensor, encoder_noise) - return latent - - -def generate( - image: Union[PIL.Image.Image, torch.Tensor], - condition_image: Union[PIL.Image.Image, torch.Tensor], - mask: Union[PIL.Image.Image, torch.Tensor], - num_inference_steps: int = 50, - guidance_scale: float = 2.5, - height: int = 1024, - width: int = 768, - models={}, - sampler_name="ddpm", - seed=None, - device=None, - idle_device=None, - **kwargs 
-): - with torch.no_grad(): - if idle_device: - to_idle = lambda x: x.to(idle_device) - else: - to_idle = lambda x: x - - # Initialize random number generator according to the seed specified - generator = torch.Generator(device=device) - if seed is None: - generator.seed() - else: - generator.manual_seed(seed) - - concat_dim = -1 # FIXME: y axis concat - - # Prepare inputs to Tensor - image, condition_image, mask = check_inputs(image, condition_image, mask, width, height) - # print(f"Input image shape: {image.shape}, condition image shape: {condition_image.shape}, mask shape: {mask.shape}") - image = prepare_image(image).to(device) - condition_image = prepare_image(condition_image).to(device) - mask = prepare_mask_image(mask).to(device) - - print(f"Prepared image shape: {image.shape}, condition image shape: {condition_image.shape}, mask shape: {mask.shape}") - # Mask image - masked_image = image * (mask < 0.5) - - print(f"Masked image shape: {masked_image.shape}") - - # VAE encoding - encoder = models.get('encoder', None) - if encoder is None: - raise ValueError("Encoder model not found in models dictionary") - - encoder.to(device) - masked_latent = compute_vae_encodings(masked_image, encoder, device) - condition_latent = compute_vae_encodings(condition_image, encoder, device) - to_idle(encoder) - - print(f"Masked latent shape: {masked_latent.shape}, condition latent shape: {condition_latent.shape}") - - # Concatenate latents - masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=concat_dim) - - print(f"Masked Person latent + garment latent: {masked_latent_concat.shape}") - - mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode="nearest") - del image, mask, condition_image - mask_latent_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim) - - print(f"Mask latent concat shape: {mask_latent_concat.shape}") - - # Initialize latents - latents = torch.randn( - masked_latent_concat.shape, - generator=generator, - device=masked_latent_concat.device, - dtype=masked_latent_concat.dtype - ) - - print(f"Latents shape: {latents.shape}") - - # Prepare timesteps - if sampler_name == "ddpm": - sampler = DDPMSampler(generator) - sampler.set_inference_timesteps(num_inference_steps) - else: - raise ValueError("Unknown sampler value %s. 
" % sampler_name) - - timesteps = sampler.timesteps - # latents = sampler.add_noise(latents, timesteps[0]) - - # Classifier-Free Guidance - do_classifier_free_guidance = guidance_scale > 1.0 - if do_classifier_free_guidance: - masked_latent_concat = torch.cat( - [ - torch.cat([masked_latent, torch.zeros_like(condition_latent)], dim=concat_dim), - masked_latent_concat, - ] - ) - mask_latent_concat = torch.cat([mask_latent_concat] * 2) - - print(f"Masked latent concat for classifier-free guidance: {masked_latent_concat.shape}, mask latent concat: {mask_latent_concat.shape}") - - - # Denoising loop - Fixed: removed self references and incorrect scheduler calls - num_warmup_steps = 0 # For simple DDPM, no warmup needed - - with tqdm(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - non_inpainting_latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents) - - # print(f"Non-inpainting latent model input shape: {non_inpainting_latent_model_input.shape}") - - # prepare the input for the inpainting model - inpainting_latent_model_input = torch.cat([non_inpainting_latent_model_input, mask_latent_concat, masked_latent_concat], dim=1) - - # print(f"Inpainting latent model input shape: {inpainting_latent_model_input.shape}") - - # predict the noise residual - diffusion = models.get('diffusion', None) - if diffusion is None: - raise ValueError("Diffusion model not found in models dictionary") - - diffusion.to(device) - - # Create time embedding for the current timestep - time_embedding = get_time_embedding(t.item()).to(device) - # print(f"Time embedding shape: {time_embedding.shape}") - - if do_classifier_free_guidance: - time_embedding = torch.cat([time_embedding] * 2) - - noise_pred = diffusion( - inpainting_latent_model_input, - time_embedding - ) - - to_idle(diffusion) - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - - # compute the previous noisy sample x_t -> x_t-1 - latents = sampler.step(t, latents, noise_pred) - - # Update progress bar - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps): - progress_bar.update() - - # Decode the final latents - latents = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0] - - decoder = models.get('decoder', None) - if decoder is None: - raise ValueError("Decoder model not found in models dictionary") - - decoder.to(device) - - image = decoder(latents.to(device)) - # image = rescale(image, (-1, 1), (0, 255), clamp=True) - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - image = numpy_to_pil(image) - - to_idle(decoder) - - return image - - -def rescale(x, old_range, new_range, clamp=False): - old_min, old_max = old_range - new_min, new_max = new_range - x -= old_min - x *= (new_max - new_min) / (old_max - old_min) - x += new_min - if clamp: - x = x.clamp(new_min, new_max) - return x - -def get_time_embedding(timestep): - # Shape: (160,) - freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160) - # Shape: (1, 160) - x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None] - # Shape: (1, 160 * 2) -> (1, 320) - return torch.cat([torch.cos(x), torch.sin(x)], 
dim=-1) - -if __name__ == "__main__": - # Example usage - image = Image.open("person.jpg").convert("RGB") - condition_image = Image.open("image.png").convert("RGB") - mask = Image.open("agnostic_mask.png").convert("L") - - # Load models - models=load_model.preload_models_from_standard_weights("sd-v1-5-inpainting.ckpt", device="cuda") - - # Generate image - generated_image = generate( - image=image, - condition_image=condition_image, - mask=mask, - num_inference_steps=50, - guidance_scale=2.5, - width=WIDTH, - height=HEIGHT, - models=models, - sampler_name="ddpm", - seed=42, - device="cuda" # or "cpu" - ) - - generated_image[0].save("generated_image.png") - \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index c82823951a587ca4e7d5f253d1238f721aaea61f..e2aa599dac9716c009c00e7885d7d2ccfdd4cf81 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,53 +1,39 @@ -aiohappyeyeballs==2.6.1 -aiohttp==3.11.18 -aiosignal==1.3.2 -annotated-types==0.7.0 +accelerate==1.9.0 asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work -attrs==25.3.0 -certifi==2025.4.26 +beautifulsoup4==4.13.4 +certifi==2025.7.14 charset-normalizer==3.4.2 -click==8.2.0 comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work contourpy==1.3.2 cycler==0.12.1 -datasets==3.6.0 -debugpy @ file:///home/conda/feedstock_root/build_artifacts/debugpy_1744321233760/work +debugpy @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_debugpy_1752827112/work decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work -dill==0.3.8 -docker-pycreds==0.4.0 +diffusers==0.34.0 exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1746947292760/work executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1745502089858/work filelock==3.18.0 -fonttools==4.58.0 -frozenlist==1.6.0 -fsspec==2025.3.0 -gitdb==4.0.12 -GitPython==3.1.44 -hf-xet==1.1.0 -huggingface-hub==0.31.1 +fonttools==4.59.0 +fsspec==2025.7.0 +gdown==5.2.0 +hf-xet==1.1.5 +huggingface-hub==0.33.4 idna==3.10 -importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1737420181517/work +importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_importlib-metadata_1747934053/work ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work -ipython @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_ipython_1745672166/work -ipython_pygments_lexers @ file:///home/conda/feedstock_root/build_artifacts/ipython_pygments_lexers_1737123620466/work -ipywidgets==8.1.7 +ipython @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_ipython_1748711175/work jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work Jinja2==3.1.6 jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work -jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1727163409502/work -jupyterlab_widgets==3.0.15 +jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1748333051527/work +kagglehub==0.3.12 kiwisolver==1.4.8 -lightning==2.5.1.post0 -lightning-utilities==0.14.3 MarkupSafe==3.0.2 matplotlib==3.10.3 matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work mpmath==1.3.0 -multidict==6.4.3 -multiprocess==0.70.16 nest_asyncio @ 
file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work networkx==3.4.2 -numpy==2.2.5 +numpy==2.2.6 nvidia-cublas-cu12==12.6.4.1 nvidia-cuda-cupti-cu12==12.6.80 nvidia-cuda-nvrtc-cu12==12.6.77 @@ -62,54 +48,43 @@ nvidia-cusparselt-cu12==0.6.3 nvidia-nccl-cu12==2.26.2 nvidia-nvjitlink-cu12==12.6.85 nvidia-nvtx-cu12==12.6.77 -packaging==24.2 -pandas==2.2.3 +packaging @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_packaging_1745345660/work +pandas==2.3.1 parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1733301927746/work pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work -pillow==11.2.1 +pillow==11.3.0 platformdirs @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_platformdirs_1746710438/work prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1744724089886/work -propcache==0.3.1 -protobuf==6.30.2 -psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1740663149797/work +psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1740663128538/work ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1733302279685/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=92c32ff62b5fd8cf325bec5ab90d7be3d2a8ca8c8a3813ff487a8d2002630d1f pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work -pyarrow==20.0.0 -pydantic==2.11.4 -pydantic_core==2.33.2 -Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1736243443484/work +Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1750615794071/work pyparsing==3.2.3 -python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1733215673016/work -pytorch-lightning==2.5.1.post0 +PySocks==1.7.1 +python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_python-dateutil_1751104122/work pytz==2025.2 PyYAML==6.0.2 -pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1743831245578/work +pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1749898457097/work regex==2024.11.6 -requests==2.32.3 +requests==2.32.4 safetensors==0.5.3 -sentry-sdk==2.27.0 -setproctitle==1.3.6 six @ file:///home/conda/feedstock_root/build_artifacts/six_1733380938961/work -smmap==5.0.2 +soupsieve==2.7 stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work sympy==1.14.0 -tokenizers==0.21.1 -torch==2.7.0 -torchmetrics==1.7.1 -torchvision==0.22.0 -tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1732615904614/work +tokenizers==0.21.2 +torch==2.7.1 +torchsummary==1.5.1 +torchvision==0.22.1 +tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1748003301700/work tqdm==4.67.1 traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work -transformers==4.51.3 -triton==3.3.0 -typing-inspection==0.4.0 -typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1744302253/work +transformers==4.53.2 +triton==3.3.1 +typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1751643513/work tzdata==2025.2 -urllib3==2.4.0 -wandb==0.19.11 +unzip==1.0.0 +urllib3==2.5.0 wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1733231326287/work -widgetsnbextension==4.0.14 
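Both the deleted pipeline.py above and the notebook pipeline that follows condition the inpainting UNet by concatenating the masked person latent with the garment latent spatially, then stacking noise, mask, and masked latents on the channel axis to form the 9-channel input configured in load_model.py (in_channels=9). A toy-shape sketch, assuming 4-channel VAE latents and the notebook's height-axis concat (dim=-2):

import torch

masked_latent = torch.randn(1, 4, 48, 48)     # VAE latent of the masked person image
condition_latent = torch.randn(1, 4, 48, 48)  # VAE latent of the garment image
mask_latent = torch.rand(1, 1, 48, 48)        # inpainting mask resized to latent size

concat_dim = -2  # height axis, as in the notebook (the deleted pipeline.py used -1)
masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=concat_dim)
# zeros on the garment half: nothing there is marked for inpainting
mask_latent_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)

latents = torch.randn_like(masked_latent_concat)
model_input = torch.cat([latents, mask_latent_concat, masked_latent_concat], dim=1)
print(model_input.shape)  # torch.Size([1, 9, 96, 48]) -> 4 + 1 + 4 = 9 input channels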
-xxhash==3.5.0 -yarl==1.20.0 -zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1732827521216/work +zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1749421620841/work diff --git a/sample_inference.ipynb b/sample_inference.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..96ba030b2c23de804ecedda41beb1f2a6077756d --- /dev/null +++ b/sample_inference.ipynb @@ -0,0 +1,435 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "237f5cbf", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model already downloaded.\n" + ] + } + ], + "source": [ + "# !wget https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n", + "\n", + "# check if the model is downloaded, if not download it\n", + "import os\n", + "if not os.path.exists(\"sd-v1-5-inpainting.ckpt\"):\n", + " !wget https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n", + "else:\n", + " print(\"Model already downloaded.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bab24c29", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/mahesh/miniconda3/envs/harsh/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import inspect\n", + "import os\n", + "from typing import Union\n", + "\n", + "import PIL\n", + "import numpy as np\n", + "import torch\n", + "import tqdm\n", + "from diffusers.utils.torch_utils import randn_tensor\n", + "\n", + "from utils import (check_inputs, get_time_embedding, numpy_to_pil, prepare_image,\n", + " prepare_mask_image, compute_vae_encodings)\n", + "from ddpm import DDPMSampler\n", + "\n", + "class CatVTONPipeline:\n", + " def __init__(\n", + " self, \n", + " weight_dtype=torch.float32,\n", + " device='cuda',\n", + " compile=False,\n", + " skip_safety_check=True,\n", + " use_tf32=True,\n", + " models={},\n", + " ):\n", + " self.device = device\n", + " self.weight_dtype = weight_dtype\n", + " self.skip_safety_check = skip_safety_check\n", + " self.models = models\n", + "\n", + " self.generator = torch.Generator(device=device)\n", + " self.noise_scheduler = DDPMSampler(generator=self.generator)\n", + " # self.vae = AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\").to(device, dtype=weight_dtype)\n", + " self.encoder= models.get('encoder', None)\n", + " self.decoder= models.get('decoder', None)\n", + " \n", + " self.unet=models.get('diffusion', None) \n", + " # # Enable TF32 for faster training on Ampere GPUs (A100 and RTX 30 series).\n", + " if use_tf32:\n", + " torch.set_float32_matmul_precision(\"high\")\n", + " torch.backends.cuda.matmul.allow_tf32 = True\n", + "\n", + " @torch.no_grad()\n", + " def __call__(\n", + " self, \n", + " image: Union[PIL.Image.Image, torch.Tensor],\n", + " condition_image: Union[PIL.Image.Image, torch.Tensor],\n", + " mask: Union[PIL.Image.Image, torch.Tensor],\n", + " num_inference_steps: int = 50,\n", + " guidance_scale: float = 2.5,\n", + " height: int = 1024,\n", + " width: int = 768,\n", + " generator=None,\n", + " eta=1.0,\n", + " **kwargs\n", + " ):\n", + " concat_dim = -2 # FIXME: y axis concat\n", + " # Prepare inputs to Tensor\n", + " image, condition_image, mask = 
check_inputs(image, condition_image, mask, width, height)\n", + " image = prepare_image(image).to(self.device, dtype=self.weight_dtype)\n", + " condition_image = prepare_image(condition_image).to(self.device, dtype=self.weight_dtype)\n", + " mask = prepare_mask_image(mask).to(self.device, dtype=self.weight_dtype)\n", + " # Mask image\n", + " masked_image = image * (mask < 0.5)\n", + " # VAE encoding\n", + " masked_latent = compute_vae_encodings(masked_image, self.encoder)\n", + " condition_latent = compute_vae_encodings(condition_image, self.encoder)\n", + " mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode=\"nearest\")\n", + " del image, mask, condition_image\n", + " # Concatenate latents\n", + " masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=concat_dim)\n", + " mask_latent_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)\n", + " # Prepare noise\n", + " latents = randn_tensor(\n", + " masked_latent_concat.shape,\n", + " generator=generator,\n", + " device=masked_latent_concat.device,\n", + " dtype=self.weight_dtype,\n", + " )\n", + " # Prepare timesteps\n", + " self.noise_scheduler.set_inference_timesteps(num_inference_steps)\n", + " timesteps = self.noise_scheduler.timesteps\n", + " # latents = latents * self.noise_scheduler.init_noise_sigma\n", + " latents = self.noise_scheduler.add_noise(latents, timesteps[0])\n", + " \n", + " # Classifier-Free Guidance\n", + " if do_classifier_free_guidance := (guidance_scale > 1.0):\n", + " masked_latent_concat = torch.cat(\n", + " [\n", + " torch.cat([masked_latent, torch.zeros_like(condition_latent)], dim=concat_dim),\n", + " masked_latent_concat,\n", + " ]\n", + " )\n", + " mask_latent_concat = torch.cat([mask_latent_concat] * 2)\n", + "\n", + " num_warmup_steps = 0 # For simple DDPM, no warmup needed\n", + " with tqdm(total=num_inference_steps) as progress_bar:\n", + " for i, t in enumerate(timesteps):\n", + " # expand the latents if we are doing classifier free guidance\n", + " non_inpainting_latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents)\n", + " # non_inpainting_latent_model_input = self.noise_scheduler.scale_model_input(non_inpainting_latent_model_input, t)\n", + " # prepare the input for the inpainting model\n", + " inpainting_latent_model_input = torch.cat([non_inpainting_latent_model_input, mask_latent_concat, masked_latent_concat], dim=1).to(self.device, dtype=self.weight_dtype)\n", + " # predict the noise residual\n", + " \n", + " timestep = t.repeat(inpainting_latent_model_input.shape[0])\n", + " time_embedding = get_time_embedding(timestep).to(self.device, dtype=self.weight_dtype)\n", + "\n", + " noise_pred = self.unet(\n", + " inpainting_latent_model_input,\n", + " time_embedding\n", + " )\n", + " # perform guidance\n", + " if do_classifier_free_guidance:\n", + " noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n", + " noise_pred = noise_pred_uncond + guidance_scale * (\n", + " noise_pred_text - noise_pred_uncond\n", + " )\n", + " # compute the previous noisy sample x_t -> x_t-1\n", + " latents = self.noise_scheduler.step(\n", + " t, latents, noise_pred\n", + " )\n", + " # call the callback, if provided\n", + " if i == len(timesteps) - 1 or (\n", + " (i + 1) > num_warmup_steps\n", + " ):\n", + " progress_bar.update()\n", + "\n", + " # Decode the final latents\n", + " latents = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]\n", + " # latents = 1 / 
self.vae.config.scaling_factor * latents\n", + " # image = self.vae.decode(latents.to(self.device, dtype=self.weight_dtype)).sample\n", + " image = self.decoder(latents.to(self.device, dtype=self.weight_dtype))\n", + " image = (image / 2 + 0.5).clamp(0, 1)\n", + " # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n", + " image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n", + " image = numpy_to_pil(image)\n", + " \n", + " return image\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a069151e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded finetuned weights from finetuned_weights.safetensors\n", + "Loading 0.in_proj.weight\n", + "Loading 0.out_proj.weight\n", + "Loading 0.out_proj.bias\n", + "Loading 8.in_proj.weight\n", + "Loading 8.out_proj.weight\n", + "Loading 8.out_proj.bias\n", + "Loading 16.in_proj.weight\n", + "Loading 16.out_proj.weight\n", + "Loading 16.out_proj.bias\n", + "Loading 24.in_proj.weight\n", + "Loading 24.out_proj.weight\n", + "Loading 24.out_proj.bias\n", + "Loading 32.in_proj.weight\n", + "Loading 32.out_proj.weight\n", + "Loading 32.out_proj.bias\n", + "Loading 40.in_proj.weight\n", + "Loading 40.out_proj.weight\n", + "Loading 40.out_proj.bias\n", + "Loading 48.in_proj.weight\n", + "Loading 48.out_proj.weight\n", + "Loading 48.out_proj.bias\n", + "Loading 56.in_proj.weight\n", + "Loading 56.out_proj.weight\n", + "Loading 56.out_proj.bias\n", + "Loading 64.in_proj.weight\n", + "Loading 64.out_proj.weight\n", + "Loading 64.out_proj.bias\n", + "Loading 72.in_proj.weight\n", + "Loading 72.out_proj.weight\n", + "Loading 72.out_proj.bias\n", + "Loading 80.in_proj.weight\n", + "Loading 80.out_proj.weight\n", + "Loading 80.out_proj.bias\n", + "Loading 88.in_proj.weight\n", + "Loading 88.out_proj.weight\n", + "Loading 88.out_proj.bias\n", + "Loading 96.in_proj.weight\n", + "Loading 96.out_proj.weight\n", + "Loading 96.out_proj.bias\n", + "Loading 104.in_proj.weight\n", + "Loading 104.out_proj.weight\n", + "Loading 104.out_proj.bias\n", + "Loading 112.in_proj.weight\n", + "Loading 112.out_proj.weight\n", + "Loading 112.out_proj.bias\n", + "Loading 120.in_proj.weight\n", + "Loading 120.out_proj.weight\n", + "Loading 120.out_proj.bias\n", + "\n", + "Attention module weights loaded from {finetune_weights_path} successfully.\n" + ] + } + ], + "source": [ + "import load_model\n", + "\n", + "models=load_model.preload_models_from_standard_weights(ckpt_path=\"sd-v1-5-inpainting.ckpt\", device=\"cuda\", finetune_weights_path=\"finetuned_weights.safetensors\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "a729bf46", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset vitonhd loaded, total 20 pairs.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.39it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.39it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.42it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.44it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.44it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.40it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.43it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.42it/s]\n", + 
"100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.41it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.43it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.43it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.41it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.40it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.43it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.29it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.46it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.45it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.46it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.47it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:07<00:00, 6.45it/s]\n", + "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 20/20 [02:43<00:00, 8.15s/it]\n" + ] + } + ], + "source": [ + "import os\n", + "import numpy as np\n", + "import torch\n", + "import argparse\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from VITON_Dataset import VITONHDTestDataset\n", + "from diffusers.image_processor import VaeImageProcessor\n", + "from tqdm import tqdm\n", + "from PIL import Image, ImageFilter\n", + "\n", + "from utils import repaint, to_pil_image\n", + "\n", + "@torch.no_grad()\n", + "def main():\n", + " args=argparse.Namespace()\n", + " args.__dict__= {\n", + " \"dataset_name\": \"vitonhd\",\n", + " \"data_root_path\": \"./sample_dataset\",\n", + " \"output_dir\": \"./trained_output\",\n", + " \"seed\": 555,\n", + " \"batch_size\": 1,\n", + " \"num_inference_steps\": 50,\n", + " \"guidance_scale\": 2.5,\n", + " \"width\": 384,\n", + " \"height\": 384,\n", + " \"repaint\": True,\n", + " \"eval_pair\": False,\n", + " \"concat_eval_results\": True,\n", + " \"allow_tf32\": True,\n", + " \"dataloader_num_workers\": 4,\n", + " \"mixed_precision\": 'no',\n", + " \"concat_axis\": 'y',\n", + " \"enable_condition_noise\": True,\n", + " \"is_train\": False\n", + " }\n", + "\n", + " # Pipeline\n", + " pipeline = CatVTONPipeline(\n", + " weight_dtype={\n", + " \"no\": torch.float32,\n", + " \"fp16\": torch.float16,\n", + " \"bf16\": torch.bfloat16,\n", + " }[args.mixed_precision],\n", + " device=\"cuda\",\n", + " skip_safety_check=True,\n", + " models=models,\n", + " )\n", + " # Dataset\n", + " if args.dataset_name == \"vitonhd\":\n", + " dataset = VITONHDTestDataset(args)\n", + " else:\n", + " raise ValueError(f\"Invalid dataset name {args.dataset}.\")\n", + " print(f\"Dataset {args.dataset_name} loaded, total {len(dataset)} pairs.\")\n", + " dataloader = DataLoader(\n", + " dataset,\n", + " batch_size=args.batch_size,\n", + " shuffle=False,\n", + " num_workers=args.dataloader_num_workers\n", + " )\n", + " \n", + " # Inference\n", + " generator = torch.Generator(device='cuda').manual_seed(args.seed)\n", + " args.output_dir = os.path.join(args.output_dir, f\"{args.dataset_name}-{args.height}\", \"paired\" if args.eval_pair else \"unpaired\")\n", + " if not os.path.exists(args.output_dir):\n", + " os.makedirs(args.output_dir)\n", + " \n", + " for batch in tqdm(dataloader):\n", + " person_images = batch['person']\n", + " cloth_images = batch['cloth']\n", + " masks = batch['mask']\n", + "\n", + " results = pipeline(\n", + " person_images,\n", + " cloth_images,\n", + " masks,\n", + " num_inference_steps=args.num_inference_steps,\n", + " guidance_scale=args.guidance_scale,\n", + " height=args.height,\n", + " 
width=args.width,\n", + " generator=generator,\n", + " )\n", + " \n", + " if args.concat_eval_results or args.repaint:\n", + " person_images = to_pil_image(person_images)\n", + " cloth_images = to_pil_image(cloth_images)\n", + " masks = to_pil_image(masks)\n", + " for i, result in enumerate(results):\n", + " person_name = batch['person_name'][i]\n", + " output_path = os.path.join(args.output_dir, person_name)\n", + " if not os.path.exists(os.path.dirname(output_path)):\n", + " os.makedirs(os.path.dirname(output_path))\n", + " if args.repaint:\n", + " person_path, mask_path = dataset.data[batch['index'][i]]['person'], dataset.data[batch['index'][i]]['mask']\n", + " person_image= Image.open(person_path).resize(result.size, Image.LANCZOS)\n", + " mask = Image.open(mask_path).resize(result.size, Image.NEAREST)\n", + " result = repaint(person_image, mask, result)\n", + " if args.concat_eval_results:\n", + " w, h = result.size\n", + " concated_result = Image.new('RGB', (w*3, h))\n", + " concated_result.paste(person_images[i], (0, 0))\n", + " concated_result.paste(cloth_images[i], (w, 0)) \n", + " concated_result.paste(result, (w*2, 0))\n", + " result = concated_result\n", + " result.save(output_path)\n", + "\n", + "if __name__ == \"__main__\":\n", + " main()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "55d88911", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "harsh", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/test.ipynb b/test.ipynb deleted file mode 100644 index a9e9e02dcc365906bcf51f57bf1c5f25fe6b8fc0..0000000000000000000000000000000000000000 --- a/test.ipynb +++ /dev/null @@ -1,1430 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "6387c9e1", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ca9233f0", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'/kaggle/working'" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "3d2f98af", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[0m\u001b[01;34mtest\u001b[0m/ test_pairs.txt \u001b[01;34mtrain\u001b[0m/ train_pairs.txt\n" - ] - } - ], - "source": [ - "ls /kaggle/input/viton-hd-dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "dc0f36f4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Cloning into 'stable-diffusion'...\n", - "remote: Enumerating objects: 156, done.\u001b[K\n", - "remote: Counting objects: 100% (156/156), done.\u001b[K\n", - "remote: Compressing objects: 100% (129/129), done.\u001b[K\n", - "remote: Total 156 (delta 41), reused 141 (delta 27), pack-reused 0 (from 0)\u001b[K\n", - "Receiving objects: 100% (156/156), 9.12 MiB | 41.53 MiB/s, done.\n", - "Resolving deltas: 100% (41/41), done.\n" - ] - } - ], - "source": [ - "!git clone -b CatVTON https://github.com/Harsh-Kesharwani/stable-diffusion.git" - ] - }, - { - "cell_type": "code", - 
"execution_count": 3, - "id": "a0bf01ab", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/kaggle/working/stable-diffusion\n" - ] - } - ], - "source": [ - "cd stable-diffusion/" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "1401cd56", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--2025-06-15 18:33:59-- https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n", - "Resolving huggingface.co (huggingface.co)... 3.171.171.128, 3.171.171.6, 3.171.171.104, ...\n", - "Connecting to huggingface.co (huggingface.co)|3.171.171.128|:443... connected.\n", - "HTTP request sent, awaiting response... 307 Temporary Redirect\n", - "Location: /stable-diffusion-v1-5/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt [following]\n", - "--2025-06-15 18:33:59-- https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n", - "Reusing existing connection to huggingface.co:443.\n", - "HTTP request sent, awaiting response... 302 Found\n", - "Location: https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=1750014781&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc1MDAxNDc4MX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mNi81Ni9mNjU2ZjBmYTNiOGE0MGFjNzZkMjk3ZmEyYTRiMDBmOTgxZThlYjEyNjE5NjM0NjA3NjRlN2RkM2IzNWVjOTdmL2M2YmJjMTVlMzIyNGU2OTczNDU5YmE3OGRlNDk5OGI4MGI1MDExMmIwYWU1YjVjNjcxMTNkNTZiNGUzNjZiMTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=J6qT-n9PY34qz09a9caWcpc8-GaaGi%7EUu6AemCTMk48YsJgF9fjibpdUC-yTeIAJxbF4KxuFDt-5T6tXkQXgDaNakqUiTcxxJKpVNQYG9MlJ%7E3xeXE-WfBpwE9BbXkksCDStzHYqWV5ni5q0t2gPUqfwbmEFdfvZbQPol1oKH1ldWgCa3XusvR%7EUfdcxtci8gCgLXIrbNu7AG2lepj0AqpxkO5hsIBIhqUOTDXG7okdVLhepoAwnmJkc4neFV5LcR1Tt70My-1jdSFExn6c3yMLmWprMm3UMv6h5MyMifZWw4RdBrBWDjm0TPDwVMuwhgKiT6F9WTnUZvl1F0KKXFQ__&Key-Pair-Id=K3RPWS32NSSJCE [following]\n", - "--2025-06-15 18:33:59-- https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=1750014781&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc1MDAxNDc4MX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mNi81Ni9mNjU2ZjBmYTNiOGE0MGFjNzZkMjk3ZmEyYTRiMDBmOTgxZThlYjEyNjE5NjM0NjA3NjRlN2RkM2IzNWVjOTdmL2M2YmJjMTVlMzIyNGU2OTczNDU5YmE3OGRlNDk5OGI4MGI1MDExMmIwYWU1YjVjNjcxMTNkNTZiNGUzNjZiMTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=J6qT-n9PY34qz09a9caWcpc8-GaaGi%7EUu6AemCTMk48YsJgF9fjibpdUC-yTeIAJxbF4KxuFDt-5T6tXkQXgDaNakqUiTcxxJKpVNQYG9MlJ%7E3xeXE-WfBpwE9BbXkksCDStzHYqWV5ni5q0t2gPUqfwbmEFdfvZbQPol1oKH1ldWgCa3XusvR%7EUfdcxtci8gCgLXIrbNu7AG2lepj0AqpxkO5hsIBIhqUOTDXG7okdVLhepoAwnmJkc4neFV5LcR1Tt70My-1jdSFExn6c3yMLmWprMm3UMv6h5MyMifZWw4RdBrBWDjm0TPDwVMuwhgKiT6F9WTnUZvl1F0KKXFQ__&Key-Pair-Id=K3RPWS32NSSJCE\n", - "Resolving cdn-lfs.hf.co (cdn-lfs.hf.co)... 
18.160.78.87, 18.160.78.43, 18.160.78.83, ...\n", - "Connecting to cdn-lfs.hf.co (cdn-lfs.hf.co)|18.160.78.87|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 4265437280 (4.0G) [binary/octet-stream]\n", - "Saving to: β€˜sd-v1-5-inpainting.ckpt’\n", - "\n", - "sd-v1-5-inpainting. 100%[===================>] 3.97G 306MB/s in 13s \n", - "\n", - "2025-06-15 18:34:13 (302 MB/s) - β€˜sd-v1-5-inpainting.ckpt’ saved [4265437280/4265437280]\n", - "\n" - ] - } - ], - "source": [ - "!wget https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f7450c55", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--2025-06-11 10:33:19-- https://huggingface.co/zhengchong/CatVTON/resolve/main/vitonhd-16k-512/attention/model.safetensors?download=true\n", - "Resolving huggingface.co (huggingface.co)... 3.163.189.37, 3.163.189.90, 3.163.189.114, ...\n", - "Connecting to huggingface.co (huggingface.co)|3.163.189.37|:443... connected.\n", - "HTTP request sent, awaiting response... 302 Found\n", - "Location: https://cdn-lfs-us-1.hf.co/repos/49/48/4948d897acaa287a14cc261fb60bfdb3ff0e6571ca16a0b5fa38cec3cfebdc34/915df7bf19a33bee36a28d5f9ceaef1e2267c47526f98ca9e4c49e90ae5f0fd0?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27model.safetensors%3B+filename%3D%22model.safetensors%22%3B&Expires=1749641599&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0OTY0MTU5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmhmLmNvL3JlcG9zLzQ5LzQ4LzQ5NDhkODk3YWNhYTI4N2ExNGNjMjYxZmI2MGJmZGIzZmYwZTY1NzFjYTE2YTBiNWZhMzhjZWMzY2ZlYmRjMzQvOTE1ZGY3YmYxOWEzM2JlZTM2YTI4ZDVmOWNlYWVmMWUyMjY3YzQ3NTI2Zjk4Y2E5ZTRjNDllOTBhZTVmMGZkMD9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=iN3Lw7GVk22rlaKenmmcr3VTvG2wC9AFWTNHUmdS8DOVyKF2fUSnjW3QnGTm6P15luwwy2xs-43aiE22hmdjFm9AOV9v67mBvhUe3Gjp9k2DC-KIY%7ES6YuRPUUMLHSriK2bN6GfVpl6e-XN%7Ew6mEHiyUah9plAkKGidYjfaUXrODQr34siqAmTjDDD8wRyHAbLFiCMB-zUbllG4YjEO-rJkilkVtUEriayspO1uEKe%7EtAjW27n5Te68FqKTX%7Etj77fPDKGNV4p%7EUIvRtPx4jdtb1Mll7ga5C-YMwpNCKDX4bvWDMrnf2NNs9EIouNdjMZdBpPHUH2EpQGfEASUX0eg__&Key-Pair-Id=K24J24Z295AEI9 [following]\n", - "--2025-06-11 10:33:19-- https://cdn-lfs-us-1.hf.co/repos/49/48/4948d897acaa287a14cc261fb60bfdb3ff0e6571ca16a0b5fa38cec3cfebdc34/915df7bf19a33bee36a28d5f9ceaef1e2267c47526f98ca9e4c49e90ae5f0fd0?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27model.safetensors%3B+filename%3D%22model.safetensors%22%3B&Expires=1749641599&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0OTY0MTU5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmhmLmNvL3JlcG9zLzQ5LzQ4LzQ5NDhkODk3YWNhYTI4N2ExNGNjMjYxZmI2MGJmZGIzZmYwZTY1NzFjYTE2YTBiNWZhMzhjZWMzY2ZlYmRjMzQvOTE1ZGY3YmYxOWEzM2JlZTM2YTI4ZDVmOWNlYWVmMWUyMjY3YzQ3NTI2Zjk4Y2E5ZTRjNDllOTBhZTVmMGZkMD9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=iN3Lw7GVk22rlaKenmmcr3VTvG2wC9AFWTNHUmdS8DOVyKF2fUSnjW3QnGTm6P15luwwy2xs-43aiE22hmdjFm9AOV9v67mBvhUe3Gjp9k2DC-KIY%7ES6YuRPUUMLHSriK2bN6GfVpl6e-XN%7Ew6mEHiyUah9plAkKGidYjfaUXrODQr34siqAmTjDDD8wRyHAbLFiCMB-zUbllG4YjEO-rJkilkVtUEriayspO1uEKe%7EtAjW27n5Te68FqKTX%7Etj77fPDKGNV4p%7EUIvRtPx4jdtb1Mll7ga5C-YMwpNCKDX4bvWDMrnf2NNs9EIouNdjMZdBpPHUH2EpQGfEASUX0eg__&Key-Pair-Id=K24J24Z295AEI9\n", - "Resolving cdn-lfs-us-1.hf.co (cdn-lfs-us-1.hf.co)... 
18.238.238.75, 18.238.238.106, 18.238.238.119, ...\n", - "Connecting to cdn-lfs-us-1.hf.co (cdn-lfs-us-1.hf.co)|18.238.238.75|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 198303368 (189M) [binary/octet-stream]\n", - "Saving to: β€˜model.safetensors?download=true’\n", - "\n", - "model.safetensors?d 100%[===================>] 189.12M 298MB/s in 0.6s \n", - "\n", - "2025-06-11 10:33:20 (298 MB/s) - β€˜model.safetensors?download=true’ saved [198303368/198303368]\n", - "\n" - ] - } - ], - "source": [ - "# !wget https://huggingface.co/zhengchong/CatVTON/resolve/main/vitonhd-16k-512/attention/model.safetensors?download=true " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ca20c487", - "metadata": {}, - "outputs": [], - "source": [ - "# mv 'model.safetensors?download=true' model.safetensors" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "6d0a1287", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "attention.py encoder.py\t model.safetensors sd-v1-5-inpainting.ckpt\n", - "clip.py interface.py\t pipeline.py\t test.ipynb\n", - "ddpm.py merges.txt\t README.md\t vocab.json\n", - "decoder.py model_converter.py requirements.txt\n", - "diffusion.py model.py\t\t sample_dataset\n" - ] - } - ], - "source": [ - "!ls" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "8f11470e", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/kaggle/working/stable-diffusion/CatVTON\n" - ] - } - ], - "source": [ - "cd .." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "cb794cb3", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "app_flux.py eval.py preprocess_agnostic_mask.py \u001b[0m\u001b[01;34mstable-diffusion\u001b[0m/\n", - "app_p2p.py index.html \u001b[01;34m__pycache__\u001b[0m/ utils.py\n", - "app.py inference.py README.md\n", - "\u001b[01;34mdensepose\u001b[0m/ LICENSE requirements.txt\n", - "\u001b[01;34mdetectron2\u001b[0m/ \u001b[01;34mmodel\u001b[0m/ \u001b[01;34mresource\u001b[0m/\n" - ] - } - ], - "source": [ - "ls" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "b6af145b", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import shutil\n", - "\n", - "src_dir = \"./stable-diffusion\"\n", - "dst_dir = \".\"\n", - "\n", - "for filename in os.listdir(src_dir):\n", - " src_path = os.path.join(src_dir, filename)\n", - " dst_path = os.path.join(dst_dir, filename)\n", - " if os.path.isfile(src_path):\n", - " shutil.move(src_path, dst_path)\n", - " elif os.path.isdir(src_path):\n", - " shutil.move(src_path, dst_path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "63ee438c", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "60598bd3", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 229, - "id": "192a649c", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import gc\n", - "\n", - "# Clear CUDA cache and collect garbage\n", - "torch.cuda.empty_cache()\n", - "gc.collect()\n", - "\n", - "# Delete all user-defined variables except for built-ins and modules\n", - "for var in list(globals()):\n", - " if not var.startswith(\"__\") and var not in [\"torch\", \"gc\"]:\n", - " del globals()[var]\n", - "\n", - "gc.collect()\n", - 
"torch.cuda.empty_cache()" - ] - }, - { - "cell_type": "code", - "execution_count": 245, - "id": "a3a4a5dc", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import gc\n", - "\n", - "# Clear CUDA cache and collect garbage\n", - "torch.cuda.empty_cache()\n", - "gc.collect()\n", - "\n", - "# Delete all user-defined variables except for built-ins and modules\n", - "for var_name in list(globals()):\n", - " if not var_name.startswith(\"__\") and var_name not in [\"torch\", \"gc\"]:\n", - " del globals()[var_name]\n", - "\n", - "gc.collect()\n", - "torch.cuda.empty_cache()\n", - "\n", - "import tensorflow as tf\n", - "tf.keras.backend.clear_session()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "91ef7a4e", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import torch\n", - "import gc\n", - "\n", - "torch.cuda.empty_cache() # Release unused GPU memory\n", - "gc.collect() # Run Python garbage collector" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "08f29055", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "GPU memory used: 0.00 MB / 16269.25 MB\n" - ] - } - ], - "source": [ - "import torch\n", - "\n", - "if torch.cuda.is_available():\n", - " used = torch.cuda.memory_allocated() / 1024 ** 2 # in MB\n", - " total = torch.cuda.get_device_properties(0).total_memory / 1024 ** 2 # in MB\n", - " print(f\"GPU memory used: {used:.2f} MB / {total:.2f} MB\")\n", - "else:\n", - " print(\"CUDA is not available.\")" - ] - }, - { - "cell_type": "code", - "execution_count": 197, - "id": "6fbde810", - "metadata": {}, - "outputs": [], - "source": [ - "# rm -rf output" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "37335c1e", - "metadata": {}, - "outputs": [], - "source": [ - "def compute_vae_encodings(image_tensor, encoder, device=\"cuda\"):\n", - " \"\"\"Encode image using VAE encoder\"\"\"\n", - " # Generate random noise for encoding\n", - " encoder_noise = torch.randn(\n", - " (image_tensor.shape[0], 4, image_tensor.shape[2] // 8, image_tensor.shape[3] // 8),\n", - " device=device,\n", - " )\n", - " \n", - " # Encode using your custom encoder\n", - " latent = encoder(image_tensor, encoder_noise)\n", - " return latent" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "35d98b83", - "metadata": {}, - "outputs": [], - "source": [ - "def get_trainable_module(unet, trainable_module_name):\n", - " if trainable_module_name == \"unet\":\n", - " return unet\n", - " elif trainable_module_name == \"transformer\":\n", - " trainable_modules = torch.nn.ModuleList()\n", - " for blocks in [unet.encoders, unet.bottleneck, unet.decoders]:\n", - " if hasattr(blocks, \"attentions\"):\n", - " trainable_modules.append(blocks.attentions)\n", - " else:\n", - " for block in blocks:\n", - " if hasattr(block, \"attentions\"):\n", - " trainable_modules.append(block.attentions)\n", - " return trainable_modules\n", - " elif trainable_module_name == \"attention\":\n", - " attn_blocks = torch.nn.ModuleList()\n", - " for name, param in unet.named_modules():\n", - " if \"attention_1\" in name:\n", - " attn_blocks.append(param)\n", - " return attn_blocks\n", - " else:\n", - " raise ValueError(f\"Unknown trainable_module_name: {trainable_module_name}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "d7ff094a", - "metadata": {}, - 
"outputs": [], - "source": [ - "from torch.nn import functional as F\n", - "import torch\n", - "# from flash_attn import flash_attn_func\n", - "\n", - "class SkipAttnProcessor(torch.nn.Module):\n", - " def __init__(self, *args, **kwargs) -> None:\n", - " super().__init__()\n", - "\n", - " def __call__(\n", - " self,\n", - " attn,\n", - " hidden_states,\n", - " encoder_hidden_states=None,\n", - " attention_mask=None,\n", - " temb=None,\n", - " ):\n", - " return hidden_states\n", - "\n", - "class AttnProcessor2_0(torch.nn.Module):\n", - " r\"\"\"\n", - " Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " hidden_size=None,\n", - " cross_attention_dim=None,\n", - " **kwargs\n", - " ):\n", - " super().__init__()\n", - " if not hasattr(F, \"scaled_dot_product_attention\"):\n", - " raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n", - "\n", - " def __call__(\n", - " self,\n", - " attn,\n", - " hidden_states,\n", - " encoder_hidden_states=None,\n", - " attention_mask=None,\n", - " temb=None,\n", - " *args,\n", - " **kwargs,\n", - " ):\n", - " residual = hidden_states\n", - "\n", - " if attn.spatial_norm is not None:\n", - " hidden_states = attn.spatial_norm(hidden_states, temb)\n", - "\n", - " input_ndim = hidden_states.ndim\n", - "\n", - " if input_ndim == 4:\n", - " batch_size, channel, height, width = hidden_states.shape\n", - " hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n", - "\n", - " batch_size, sequence_length, _ = (\n", - " hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n", - " )\n", - "\n", - " if attention_mask is not None:\n", - " attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n", - " # scaled_dot_product_attention expects attention_mask shape to be\n", - " # (batch, heads, source_length, target_length)\n", - " attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n", - "\n", - " if attn.group_norm is not None:\n", - " hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n", - "\n", - " query = attn.to_q(hidden_states)\n", - "\n", - " if encoder_hidden_states is None:\n", - " encoder_hidden_states = hidden_states\n", - " elif attn.norm_cross:\n", - " encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n", - "\n", - " key = attn.to_k(encoder_hidden_states)\n", - " value = attn.to_v(encoder_hidden_states)\n", - "\n", - " inner_dim = key.shape[-1]\n", - " head_dim = inner_dim // attn.heads\n", - "\n", - " query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n", - "\n", - " key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n", - " value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n", - "\n", - " # the output of sdp = (batch, num_heads, seq_len, head_dim)\n", - " # TODO: add support for attn.scale when we move to Torch 2.1\n", - " \n", - " hidden_states = F.scaled_dot_product_attention(\n", - " query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n", - " )\n", - " # hidden_states = flash_attn_func(\n", - " # query, key, value, dropout_p=0.0, causal=False\n", - " # )\n", - "\n", - " hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n", - " hidden_states = 
hidden_states.to(query.dtype)\n", - "\n", - " # linear proj\n", - " hidden_states = attn.to_out[0](hidden_states)\n", - " # dropout\n", - " hidden_states = attn.to_out[1](hidden_states)\n", - "\n", - " if input_ndim == 4:\n", - " hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n", - "\n", - " if attn.residual_connection:\n", - " hidden_states = hidden_states + residual\n", - "\n", - " hidden_states = hidden_states / attn.rescale_output_factor\n", - "\n", - " return hidden_states\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "84a7fa87", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import json\n", - "import torch\n", - "\n", - "def init_adapter(unet, \n", - " cross_attn_cls=SkipAttnProcessor,\n", - " self_attn_cls=None,\n", - " cross_attn_dim=None, \n", - " **kwargs):\n", - " if cross_attn_dim is None:\n", - " cross_attn_dim = unet.config.cross_attention_dim\n", - " attn_procs = {}\n", - " for name in unet.attn_processors.keys():\n", - " cross_attention_dim = None if name.endswith(\"attn1.processor\") else cross_attn_dim\n", - " if name.startswith(\"mid_block\"):\n", - " hidden_size = unet.config.block_out_channels[-1]\n", - " elif name.startswith(\"up_blocks\"):\n", - " block_id = int(name[len(\"up_blocks.\")])\n", - " hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n", - " elif name.startswith(\"down_blocks\"):\n", - " block_id = int(name[len(\"down_blocks.\")])\n", - " hidden_size = unet.config.block_out_channels[block_id]\n", - " if cross_attention_dim is None:\n", - " if self_attn_cls is not None:\n", - " attn_procs[name] = self_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)\n", - " else:\n", - " # retain the original attn processor\n", - " attn_procs[name] = AttnProcessor2_0(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)\n", - " else:\n", - " attn_procs[name] = cross_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)\n", - " \n", - " unet.set_attn_processor(attn_procs)\n", - " adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())\n", - " return adapter_modules\n", - "\n", - "def init_diffusion_model(diffusion_model_name_or_path, unet_class=None):\n", - " from diffusers import AutoencoderKL\n", - " from transformers import CLIPTextModel, CLIPTokenizer\n", - "\n", - " text_encoder = CLIPTextModel.from_pretrained(diffusion_model_name_or_path, subfolder=\"text_encoder\")\n", - " vae = AutoencoderKL.from_pretrained(diffusion_model_name_or_path, subfolder=\"vae\")\n", - " tokenizer = CLIPTokenizer.from_pretrained(diffusion_model_name_or_path, subfolder=\"tokenizer\")\n", - " try:\n", - " unet_folder = os.path.join(diffusion_model_name_or_path, \"unet\")\n", - " unet_configs = json.load(open(os.path.join(unet_folder, \"config.json\"), \"r\"))\n", - " unet = unet_class(**unet_configs)\n", - " unet.load_state_dict(torch.load(os.path.join(unet_folder, \"diffusion_pytorch_model.bin\"), map_location=\"cpu\"), strict=True)\n", - " except:\n", - " unet = None\n", - " return text_encoder, vae, tokenizer, unet\n", - "\n", - "def attn_of_unet(unet):\n", - " attn_blocks = torch.nn.ModuleList()\n", - " for name, param in unet.named_modules():\n", - " if \"attn1\" in name:\n", - " attn_blocks.append(param)\n", - " return attn_blocks\n", - "\n", - "def get_trainable_module(unet, trainable_module_name):\n", - " if trainable_module_name == \"unet\":\n", - " return unet\n", 
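To make the two attention processors above concrete: `SkipAttnProcessor` returns its input unchanged, so the UNet's text cross-attention layers contribute nothing, while `AttnProcessor2_0` boils down to a multi-head `F.scaled_dot_product_attention` call. A small self-contained sketch with illustrative shapes (names and sizes are not taken from the repo):

```python
import torch
import torch.nn.functional as F

batch, seq, heads, head_dim = 2, 96, 8, 40
hidden = torch.randn(batch, seq, heads * head_dim)

# SkipAttnProcessor: identity on the hidden states, i.e. the cross-attention
# block is effectively removed from the forward pass.
skip_out = hidden

# Core of AttnProcessor2_0: split into heads, run SDPA, merge heads back.
def sdpa_self_attention(x):
    q = k = v = x.view(batch, -1, heads, head_dim).transpose(1, 2)   # (B, heads, seq, head_dim)
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    return out.transpose(1, 2).reshape(batch, -1, heads * head_dim)  # back to (B, seq, dim)

out = sdpa_self_attention(hidden)
assert out.shape == hidden.shape
```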
- " elif trainable_module_name == \"transformer\":\n", - " trainable_modules = torch.nn.ModuleList()\n", - " for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:\n", - " if hasattr(blocks, \"attentions\"):\n", - " trainable_modules.append(blocks.attentions)\n", - " else:\n", - " for block in blocks:\n", - " if hasattr(block, \"attentions\"):\n", - " trainable_modules.append(block.attentions)\n", - " return trainable_modules\n", - " elif trainable_module_name == \"attention\":\n", - " attn_blocks = torch.nn.ModuleList()\n", - " for name, param in unet.named_modules():\n", - " if \"attn1\" in name:\n", - " attn_blocks.append(param)\n", - " return attn_blocks\n", - " else:\n", - " raise ValueError(f\"Unknown trainable_module_name: {trainable_module_name}\")\n", - "\n", - " \n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "6028381d", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2025-06-15 18:35:15.189276: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", - "E0000 00:00:1750012515.396602 73 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "E0000 00:00:1750012515.456784 73 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" - ] - } - ], - "source": [ - "import inspect\n", - "import os\n", - "from typing import Union\n", - "\n", - "import PIL\n", - "import numpy as np\n", - "import torch\n", - "import tqdm\n", - "from accelerate import load_checkpoint_in_model\n", - "from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel\n", - "from diffusers.pipelines.stable_diffusion.safety_checker import \\\n", - " StableDiffusionSafetyChecker\n", - "from diffusers.utils.torch_utils import randn_tensor\n", - "from huggingface_hub import snapshot_download\n", - "from transformers import CLIPImageProcessor\n", - "\n", - "from utils import (check_inputs, get_time_embedding, numpy_to_pil, prepare_image,\n", - " prepare_mask_image, resize_and_crop, resize_and_padding)\n", - "from ddpm import DDPMSampler\n", - "\n", - "class CatVTONPipeline:\n", - " def __init__(\n", - " self, \n", - " base_ckpt, \n", - " attn_ckpt, \n", - " attn_ckpt_version=\"mix\",\n", - " weight_dtype=torch.float32,\n", - " device='cuda',\n", - " compile=False,\n", - " skip_safety_check=True,\n", - " use_tf32=True,\n", - " models={},\n", - " ):\n", - " self.device = device\n", - " self.weight_dtype = weight_dtype\n", - " self.skip_safety_check = skip_safety_check\n", - " self.models = models\n", - "\n", - " self.generator = torch.Generator(device=device)\n", - " self.noise_scheduler = DDPMSampler(generator=self.generator)\n", - " # self.vae = AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\").to(device, dtype=weight_dtype)\n", - " self.encoder= models.get('encoder', None)\n", - " self.decoder= models.get('decoder', None)\n", - " if not skip_safety_check:\n", - " self.feature_extractor = CLIPImageProcessor.from_pretrained(base_ckpt, subfolder=\"feature_extractor\")\n", - " self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(base_ckpt, subfolder=\"safety_checker\").to(device, dtype=weight_dtype)\n", - " 
self.unet = UNet2DConditionModel.from_pretrained(base_ckpt, subfolder=\"unet\").to(device, dtype=weight_dtype)\n", - " # self.unet=models.get('diffusion', None)\n", - " init_adapter(self.unet, cross_attn_cls=SkipAttnProcessor) # Skip Cross-Attention\n", - " self.attn_modules = get_trainable_module(self.unet, \"attention\")\n", - " self.auto_attn_ckpt_load(attn_ckpt, attn_ckpt_version)\n", - " # Pytorch 2.0 Compile\n", - " # if compile:\n", - " # self.unet = torch.compile(self.unet)\n", - " # self.vae = torch.compile(self.vae, mode=\"reduce-overhead\")\n", - " \n", - " # # Enable TF32 for faster training on Ampere GPUs (A100 and RTX 30 series).\n", - " if use_tf32:\n", - " torch.set_float32_matmul_precision(\"high\")\n", - " torch.backends.cuda.matmul.allow_tf32 = True\n", - "\n", - " def auto_attn_ckpt_load(self, attn_ckpt, version):\n", - " sub_folder = {\n", - " \"mix\": \"mix-48k-1024\",\n", - " \"vitonhd\": \"vitonhd-16k-512\",\n", - " \"dresscode\": \"dresscode-16k-512\",\n", - " }[version]\n", - " if os.path.exists(attn_ckpt):\n", - " load_checkpoint_in_model(self.attn_modules, os.path.join(attn_ckpt, sub_folder, 'attention'))\n", - " else:\n", - " repo_path = snapshot_download(repo_id=attn_ckpt)\n", - " print(f\"Downloaded {attn_ckpt} to {repo_path}\")\n", - " load_checkpoint_in_model(self.attn_modules, os.path.join(repo_path, sub_folder, 'attention'))\n", - " \n", - " def run_safety_checker(self, image):\n", - " if self.safety_checker is None:\n", - " has_nsfw_concept = None\n", - " else:\n", - " safety_checker_input = self.feature_extractor(image, return_tensors=\"pt\").to(self.device)\n", - " image, has_nsfw_concept = self.safety_checker(\n", - " images=image, clip_input=safety_checker_input.pixel_values.to(self.weight_dtype)\n", - " )\n", - " return image, has_nsfw_concept\n", - " \n", - " def prepare_extra_step_kwargs(self, generator, eta):\n", - " # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n", - " # eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n", - " # eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502\n", - " # and should be between [0, 1]\n", - "\n", - " accepts_eta = \"eta\" in set(\n", - " inspect.signature(self.noise_scheduler.step).parameters.keys()\n", - " )\n", - " extra_step_kwargs = {}\n", - " if accepts_eta:\n", - " extra_step_kwargs[\"eta\"] = eta\n", - "\n", - " # check if the scheduler accepts generator\n", - " accepts_generator = \"generator\" in set(\n", - " inspect.signature(self.noise_scheduler.step).parameters.keys()\n", - " )\n", - " if accepts_generator:\n", - " extra_step_kwargs[\"generator\"] = generator\n", - " return extra_step_kwargs\n", - "\n", - " @torch.no_grad()\n", - " def __call__(\n", - " self, \n", - " image: Union[PIL.Image.Image, torch.Tensor],\n", - " condition_image: Union[PIL.Image.Image, torch.Tensor],\n", - " mask: Union[PIL.Image.Image, torch.Tensor],\n", - " num_inference_steps: int = 50,\n", - " guidance_scale: float = 2.5,\n", - " height: int = 1024,\n", - " width: int = 768,\n", - " generator=None,\n", - " eta=1.0,\n", - " **kwargs\n", - " ):\n", - " concat_dim = -2 # FIXME: y axis concat\n", - " # Prepare inputs to Tensor\n", - " image, condition_image, mask = check_inputs(image, condition_image, mask, width, height)\n", - " image = prepare_image(image).to(self.device, dtype=self.weight_dtype)\n", - " condition_image = prepare_image(condition_image).to(self.device, dtype=self.weight_dtype)\n", - " mask = 
prepare_mask_image(mask).to(self.device, dtype=self.weight_dtype)\n", - " # Mask image\n", - " masked_image = image * (mask < 0.5)\n", - " # VAE encoding\n", - " masked_latent = compute_vae_encodings(masked_image, self.encoder)\n", - " condition_latent = compute_vae_encodings(condition_image, self.encoder)\n", - " mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode=\"nearest\")\n", - " del image, mask, condition_image\n", - " # Concatenate latents\n", - " masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=concat_dim)\n", - " mask_latent_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)\n", - " # Prepare noise\n", - " latents = randn_tensor(\n", - " masked_latent_concat.shape,\n", - " generator=generator,\n", - " device=masked_latent_concat.device,\n", - " dtype=self.weight_dtype,\n", - " )\n", - " # Prepare timesteps\n", - " self.noise_scheduler.set_inference_timesteps(num_inference_steps)\n", - " timesteps = self.noise_scheduler.timesteps\n", - " # latents = latents * self.noise_scheduler.init_noise_sigma\n", - " latents = self.noise_scheduler.add_noise(latents, timesteps[0])\n", - " # Classifier-Free Guidance\n", - " if do_classifier_free_guidance := (guidance_scale > 1.0):\n", - " masked_latent_concat = torch.cat(\n", - " [\n", - " torch.cat([masked_latent, torch.zeros_like(condition_latent)], dim=concat_dim),\n", - " masked_latent_concat,\n", - " ]\n", - " )\n", - " mask_latent_concat = torch.cat([mask_latent_concat] * 2)\n", - "\n", - " # Denoising loop\n", - " # extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n", - " # num_warmup_steps = (len(timesteps) - num_inference_steps * self.noise_scheduler.order)\n", - " num_warmup_steps = 0 # For simple DDPM, no warmup needed\n", - " with tqdm(total=num_inference_steps) as progress_bar:\n", - " for i, t in enumerate(timesteps):\n", - " # expand the latents if we are doing classifier free guidance\n", - " non_inpainting_latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents)\n", - " # non_inpainting_latent_model_input = self.noise_scheduler.scale_model_input(non_inpainting_latent_model_input, t)\n", - " # prepare the input for the inpainting model\n", - " inpainting_latent_model_input = torch.cat([non_inpainting_latent_model_input, mask_latent_concat, masked_latent_concat], dim=1).to(self.device, dtype=self.weight_dtype)\n", - " # predict the noise residual\n", - " # time_embedding = get_time_embedding(t.item())\n", - " # time_embedding = time_embedding.repeat(inpainting_latent_model_input.shape[0], 1).to(self.device, dtype=self.weight_dtype)\n", - " noise_pred= self.unet(\n", - " inpainting_latent_model_input,\n", - " # time_embedding\n", - " t.to(self.device),\n", - " encoder_hidden_states=None, # FIXME\n", - " return_dict=False,\n", - " )[0]\n", - " # perform guidance\n", - " if do_classifier_free_guidance:\n", - " noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n", - " noise_pred = noise_pred_uncond + guidance_scale * (\n", - " noise_pred_text - noise_pred_uncond\n", - " )\n", - " # compute the previous noisy sample x_t -> x_t-1\n", - " latents = self.noise_scheduler.step(\n", - " t, latents, noise_pred\n", - " )\n", - " # call the callback, if provided\n", - " if i == len(timesteps) - 1 or (\n", - " (i + 1) > num_warmup_steps\n", - " ):\n", - " progress_bar.update()\n", - "\n", - " # Decode the final latents\n", - " latents = latents.split(latents.shape[concat_dim] // 2, 
dim=concat_dim)[0]\n", - " # latents = 1 / self.vae.config.scaling_factor * latents\n", - " # image = self.vae.decode(latents.to(self.device, dtype=self.weight_dtype)).sample\n", - " image = self.decoder(latents.to(self.device, dtype=self.weight_dtype))\n", - " image = (image / 2 + 0.5).clamp(0, 1)\n", - " # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n", - " image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n", - " image = numpy_to_pil(image)\n", - " \n", - " # Safety Check\n", - " if not self.skip_safety_check:\n", - " current_script_directory = os.path.dirname(os.path.realpath(__file__))\n", - " nsfw_image = os.path.join(os.path.dirname(current_script_directory), 'resource', 'img', 'NSFW.jpg')\n", - " nsfw_image = PIL.Image.open(nsfw_image).resize(image[0].size)\n", - " image_np = np.array(image)\n", - " _, has_nsfw_concept = self.run_safety_checker(image=image_np)\n", - " for i, not_safe in enumerate(has_nsfw_concept):\n", - " if not_safe:\n", - " image[i] = nsfw_image\n", - " return image\n" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "94e19198", - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "594e5184ce094185bf75cb38118c1867", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "config.json: 0%| | 0.00/748 [00:00=\"\n", - " \" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the\"\n", - " \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n", - " ),\n", - " )\n", - "\n", - " parser.add_argument(\n", - " \"--concat_axis\",\n", - " type=str,\n", - " choices=[\"x\", \"y\", 'random'],\n", - " default=\"y\",\n", - " help=\"The axis to concat the cloth feature, select from ['x', 'y', 'random'].\",\n", - " )\n", - " parser.add_argument(\n", - " \"--enable_condition_noise\",\n", - " action=\"store_true\",\n", - " default=True,\n", - " help=\"Whether or not to enable condition noise.\",\n", - " )\n", - " \n", - " args = parser.parse_args()\n", - " env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n", - " if env_local_rank != -1 and env_local_rank != args.local_rank:\n", - " args.local_rank = env_local_rank\n", - "\n", - " return args\n", - "\n", - "@torch.no_grad()\n", - "def main():\n", - " # args = parse_args()\n", - "\n", - " # Replace with your actual data root and output directory paths\n", - " # !CUDA_VISIBLE_DEVICES=0 python inference.py \\\n", - " # --dataset vitonhd \\\n", - " # --data_root_path /kaggle/input/viton-hd-dataset \\\n", - " # --output_dir ./output \\\n", - " # --dataloader_num_workers 8 \\\n", - " # --batch_size 8 \\\n", - " # --seed 555 \\\n", - " # --mixed_precision no \\\n", - " # --allow_tf32 \\\n", - " # --repaint \\\n", - " # --eval_pair\n", - " \n", - " args=argparse.Namespace()\n", - " args.__dict__= {\n", - " \"base_model_path\": \"booksforcharlie/stable-diffusion-inpainting\",\n", - " \"resume_path\": \"zhengchong/CatVTON\",\n", - " \"dataset_name\": \"vitonhd\",\n", - " # \"data_root_path\": \"/kaggle/input/viton-hd-dataset\",\n", - " \"data_root_path\": \"/kaggle/working/stable-diffusion/sample_dataset\",\n", - " \"output_dir\": \"./output\",\n", - " \"seed\": 555,\n", - " \"batch_size\": 2,\n", - " \"num_inference_steps\": 50,\n", - " \"guidance_scale\": 2.5,\n", - " \"width\": 384,\n", - " \"height\": 512,\n", - " \"repaint\": True,\n", - " \"eval_pair\": False,\n", - " 
\"concat_eval_results\": True,\n", - " \"allow_tf32\": True,\n", - " \"dataloader_num_workers\": 4,\n", - " \"mixed_precision\": 'no',\n", - " \"concat_axis\": 'y',\n", - " \"enable_condition_noise\": True,\n", - " \"is_train\": False\n", - " }\n", - "\n", - " models=load_model.preload_models_from_standard_weights(ckpt_path=\"sd-v1-5-inpainting.ckpt\", device=\"cuda\", finetune_weights_path=None)\n", - "\n", - " # Pipeline\n", - " pipeline = CatVTONPipeline(\n", - " attn_ckpt_version=args.dataset_name,\n", - " attn_ckpt=args.resume_path,\n", - " base_ckpt=args.base_model_path,\n", - " weight_dtype={\n", - " \"no\": torch.float32,\n", - " \"fp16\": torch.float16,\n", - " \"bf16\": torch.bfloat16,\n", - " }[args.mixed_precision],\n", - " device=\"cuda\",\n", - " skip_safety_check=True,\n", - " models=models,\n", - " )\n", - " # Dataset\n", - " if args.dataset_name == \"vitonhd\":\n", - " dataset = VITONHDTestDataset(args)\n", - " else:\n", - " raise ValueError(f\"Invalid dataset name {args.dataset}.\")\n", - " print(f\"Dataset {args.dataset_name} loaded, total {len(dataset)} pairs.\")\n", - " dataloader = DataLoader(\n", - " dataset,\n", - " batch_size=args.batch_size,\n", - " shuffle=False,\n", - " num_workers=args.dataloader_num_workers\n", - " )\n", - " # Inference\n", - " generator = torch.Generator(device='cuda').manual_seed(args.seed)\n", - " args.output_dir = os.path.join(args.output_dir, f\"{args.dataset_name}-{args.height}\", \"paired\" if args.eval_pair else \"unpaired\")\n", - " if not os.path.exists(args.output_dir):\n", - " os.makedirs(args.output_dir)\n", - " \n", - " for batch in tqdm(dataloader):\n", - " person_images = batch['person']\n", - " cloth_images = batch['cloth']\n", - " masks = batch['mask']\n", - "\n", - " results = pipeline(\n", - " person_images,\n", - " cloth_images,\n", - " masks,\n", - " num_inference_steps=args.num_inference_steps,\n", - " guidance_scale=args.guidance_scale,\n", - " height=args.height,\n", - " width=args.width,\n", - " generator=generator,\n", - " )\n", - " \n", - " if args.concat_eval_results or args.repaint:\n", - " person_images = to_pil_image(person_images)\n", - " cloth_images = to_pil_image(cloth_images)\n", - " masks = to_pil_image(masks)\n", - " for i, result in enumerate(results):\n", - " person_name = batch['person_name'][i]\n", - " output_path = os.path.join(args.output_dir, person_name)\n", - " if not os.path.exists(os.path.dirname(output_path)):\n", - " os.makedirs(os.path.dirname(output_path))\n", - " if args.repaint:\n", - " person_path, mask_path = dataset.data[batch['index'][i]]['person'], dataset.data[batch['index'][i]]['mask']\n", - " person_image= Image.open(person_path).resize(result.size, Image.LANCZOS)\n", - " mask = Image.open(mask_path).resize(result.size, Image.NEAREST)\n", - " result = repaint(person_image, mask, result)\n", - " if args.concat_eval_results:\n", - " w, h = result.size\n", - " concated_result = Image.new('RGB', (w*3, h))\n", - " concated_result.paste(person_images[i], (0, 0))\n", - " concated_result.paste(cloth_images[i], (w, 0)) \n", - " concated_result.paste(result, (w*2, 0))\n", - " result = concated_result\n", - " result.save(output_path)\n", - "\n", - "if __name__ == \"__main__\":\n", - " main()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5c2d9f98", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "143d0ef9", - "metadata": {}, - "outputs": [], - "source": [ - "# rm -rf output" - ] - }, - { - 
"cell_type": "code", - "execution_count": null, - "id": "e417edb7", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1c86c58d", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/trained_output/vitonhd-384/unpaired/00654_00.jpg b/trained_output/vitonhd-384/unpaired/00654_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..af72ec3eab502993cec65a9315c6ef4e8c202051 Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/00654_00.jpg differ diff --git a/trained_output/vitonhd-384/unpaired/01265_00.jpg b/trained_output/vitonhd-384/unpaired/01265_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..93455e95a7318fbfa8060381b794a0db1e126395 Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/01265_00.jpg differ diff --git a/trained_output/vitonhd-384/unpaired/01985_00.jpg b/trained_output/vitonhd-384/unpaired/01985_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..54deed3e859becf96ab7c903bd707fbeec008b4e Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/01985_00.jpg differ diff --git a/trained_output/vitonhd-384/unpaired/02023_00.jpg b/trained_output/vitonhd-384/unpaired/02023_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3c9fe685911d5690d1546d17347135a3ffe11e3a Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/02023_00.jpg differ diff --git a/trained_output/vitonhd-384/unpaired/02532_00.jpg b/trained_output/vitonhd-384/unpaired/02532_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..900e40acf86778ddeaa179ce058cb18d8dd2ffed Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/02532_00.jpg differ diff --git a/trained_output/vitonhd-384/unpaired/02944_00.jpg b/trained_output/vitonhd-384/unpaired/02944_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..37d9f5778a62187f1f56210d56bed4e4052e3d00 Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/02944_00.jpg differ diff --git a/trained_output/vitonhd-384/unpaired/03191_00.jpg b/trained_output/vitonhd-384/unpaired/03191_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..aca0c2744243856b18e96df4417ce1ee50fd0b64 Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/03191_00.jpg differ diff --git a/trained_output/vitonhd-384/unpaired/03921_00.jpg b/trained_output/vitonhd-384/unpaired/03921_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d44da966c33527f373c83a944473e0c72e519d25 Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/03921_00.jpg differ diff --git a/trained_output/vitonhd-384/unpaired/05006_00.jpg b/trained_output/vitonhd-384/unpaired/05006_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..afe8a541cb068f98a8398058338a83333d12c907 Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/05006_00.jpg differ diff --git a/trained_output/vitonhd-384/unpaired/05378_00.jpg b/trained_output/vitonhd-384/unpaired/05378_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3b39dd47947a9726fa8724ba099c264a673c30db Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/05378_00.jpg differ diff --git 
a/trained_output/vitonhd-384/unpaired/07342_00.jpg b/trained_output/vitonhd-384/unpaired/07342_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..732332fee96ddab461cb074a5812ef8ddec71468 Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/07342_00.jpg differ diff --git a/trained_output/vitonhd-384/unpaired/08088_00.jpg b/trained_output/vitonhd-384/unpaired/08088_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e0fd840920516a954831ee4a257cbaa3a9beb93e Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/08088_00.jpg differ diff --git a/trained_output/vitonhd-384/unpaired/08239_00.jpg b/trained_output/vitonhd-384/unpaired/08239_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fcf4fd4d21496ffa2ea92415172244b2a6fd9466 Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/08239_00.jpg differ diff --git a/trained_output/vitonhd-384/unpaired/08650_00.jpg b/trained_output/vitonhd-384/unpaired/08650_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..157223532b98b7e857f5bd1a20bb8a66692fa74a Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/08650_00.jpg differ diff --git a/trained_output/vitonhd-384/unpaired/08839_00.jpg b/trained_output/vitonhd-384/unpaired/08839_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..638a83fad39dc524704535e1553779fed775772c Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/08839_00.jpg differ diff --git a/trained_output/vitonhd-384/unpaired/11085_00.jpg b/trained_output/vitonhd-384/unpaired/11085_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f0f7fc4faeecd2dddca85e9b347e4eacd305a152 Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/11085_00.jpg differ diff --git a/trained_output/vitonhd-384/unpaired/12345_00.jpg b/trained_output/vitonhd-384/unpaired/12345_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..22a7f7a33d70e7d1ccced3909259c914a286f2b8 Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/12345_00.jpg differ diff --git a/trained_output/vitonhd-384/unpaired/12419_00.jpg b/trained_output/vitonhd-384/unpaired/12419_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..691dac95b51ffaff13db6c5242cb723c1e23c383 Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/12419_00.jpg differ diff --git a/trained_output/vitonhd-384/unpaired/12562_00.jpg b/trained_output/vitonhd-384/unpaired/12562_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7f8c0272c95ac01a33c5bfcd2bb62655b98d7884 Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/12562_00.jpg differ diff --git a/trained_output/vitonhd-384/unpaired/14651_00.jpg b/trained_output/vitonhd-384/unpaired/14651_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..70492475142842becabc0802a4cd2fbfaadcc9b8 Binary files /dev/null and b/trained_output/vitonhd-384/unpaired/14651_00.jpg differ diff --git a/training.ipynb b/training.ipynb index d686758832e35837cd08abd1a1ac7dde605bb2f5..5ec36b2f8f1e97e5dcbd50efeeee878a6ec73643 100644 --- a/training.ipynb +++ b/training.ipynb @@ -5,23 +5,9 @@ "execution_count": 1, "id": "81e4a1db", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Cloning into 'stable-diffusion'...\n", - "remote: Enumerating objects: 184, done.\u001b[K\n", - "remote: Counting objects: 100% (184/184), 
done.\u001b[K\n", - "remote: Compressing objects: 100% (156/156), done.\u001b[K\n", - "remote: Total 184 (delta 44), reused 165 (delta 26), pack-reused 0 (from 0)\u001b[K\n", - "Receiving objects: 100% (184/184), 9.94 MiB | 37.02 MiB/s, done.\n", - "Resolving deltas: 100% (44/44), done.\n" - ] - } - ], + "outputs": [], "source": [ - "!git clone -b CatVTON https://github.com/Harsh-Kesharwani/stable-diffusion.git" + "# !git clone -b CatVTON https://github.com/Harsh-Kesharwani/stable-diffusion.git" ] }, { @@ -29,443 +15,370 @@ "execution_count": 2, "id": "9c89e320", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/kaggle/working/stable-diffusion\n" - ] - } - ], + "outputs": [], "source": [ - "cd stable-diffusion/" + "# cd stable-diffusion/" ] }, { "cell_type": "code", "execution_count": 3, - "id": "8b304af3", + "id": "ff8b706c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Already up to date.\n" + "Model already downloaded.\n" ] } ], "source": [ - "!git pull" + "# check if the model is downloaded, if not download it\n", + "import os\n", + "if not os.path.exists(\"sd-v1-5-inpainting.ckpt\"):\n", + " !wget https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n", + "else:\n", + " print(\"Model already downloaded.\")\n" ] }, { "cell_type": "code", "execution_count": 4, - "id": "ff8b706c", + "id": "53095103", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "--2025-06-17 08:50:15-- https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n", - "Resolving huggingface.co (huggingface.co)... 3.171.171.104, 3.171.171.128, 3.171.171.6, ...\n", - "Connecting to huggingface.co (huggingface.co)|3.171.171.104|:443... connected.\n", - "HTTP request sent, awaiting response... 307 Temporary Redirect\n", - "Location: /stable-diffusion-v1-5/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt [following]\n", - "--2025-06-17 08:50:15-- https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n", - "Reusing existing connection to huggingface.co:443.\n", - "HTTP request sent, awaiting response... 
302 Found\n", - "Location: https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=1750153142&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc1MDE1MzE0Mn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mNi81Ni9mNjU2ZjBmYTNiOGE0MGFjNzZkMjk3ZmEyYTRiMDBmOTgxZThlYjEyNjE5NjM0NjA3NjRlN2RkM2IzNWVjOTdmL2M2YmJjMTVlMzIyNGU2OTczNDU5YmE3OGRlNDk5OGI4MGI1MDExMmIwYWU1YjVjNjcxMTNkNTZiNGUzNjZiMTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=kAea10Cu%7EhNLABWiXI0i%7E5gAtwsQUUM6CIZczAEWsswZur-XllSQvXEoKksmPdojVE654r7s-CxII8r%7EQ52to%7EQMLbjsjw-JmXq4duiq91qz6U5aenByAXSpOO1ihAoCmCkP02e7L5Wcbs%7EhaV26W9Q%7EAfbwyQ1mn9ta%7EHIDiE7AuNuHgkEEA2IP45ao25b9zsaFw6fIUlBy93Meuf82zwzsw8CJPWV9QEwj-oPVeSDyv3ZhfxS3iCgGSYS320Vs7NcK%7EqJxPfttpTHG9m6zAnfxOpWjYVQfre6HnHUt3VHOy4QdDvpyfljgEQoH4LxRBWI%7Ev72YjOJZDEgSPoTi1Q__&Key-Pair-Id=K3RPWS32NSSJCE [following]\n", - "--2025-06-17 08:50:15-- https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=1750153142&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc1MDE1MzE0Mn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mNi81Ni9mNjU2ZjBmYTNiOGE0MGFjNzZkMjk3ZmEyYTRiMDBmOTgxZThlYjEyNjE5NjM0NjA3NjRlN2RkM2IzNWVjOTdmL2M2YmJjMTVlMzIyNGU2OTczNDU5YmE3OGRlNDk5OGI4MGI1MDExMmIwYWU1YjVjNjcxMTNkNTZiNGUzNjZiMTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=kAea10Cu%7EhNLABWiXI0i%7E5gAtwsQUUM6CIZczAEWsswZur-XllSQvXEoKksmPdojVE654r7s-CxII8r%7EQ52to%7EQMLbjsjw-JmXq4duiq91qz6U5aenByAXSpOO1ihAoCmCkP02e7L5Wcbs%7EhaV26W9Q%7EAfbwyQ1mn9ta%7EHIDiE7AuNuHgkEEA2IP45ao25b9zsaFw6fIUlBy93Meuf82zwzsw8CJPWV9QEwj-oPVeSDyv3ZhfxS3iCgGSYS320Vs7NcK%7EqJxPfttpTHG9m6zAnfxOpWjYVQfre6HnHUt3VHOy4QdDvpyfljgEQoH4LxRBWI%7Ev72YjOJZDEgSPoTi1Q__&Key-Pair-Id=K3RPWS32NSSJCE\n", - "Resolving cdn-lfs.hf.co (cdn-lfs.hf.co)... 18.160.78.83, 18.160.78.87, 18.160.78.43, ...\n", - "Connecting to cdn-lfs.hf.co (cdn-lfs.hf.co)|18.160.78.83|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 4265437280 (4.0G) [binary/octet-stream]\n", - "Saving to: β€˜sd-v1-5-inpainting.ckpt’\n", - "\n", - "sd-v1-5-inpainting. 
100%[===================>] 3.97G 324MB/s in 12s \n", - "\n", - "2025-06-17 08:50:27 (341 MB/s) - β€˜sd-v1-5-inpainting.ckpt’ saved [4265437280/4265437280]\n", - "\n" + "Checkpoints directory already exists.\n" ] } ], "source": [ - "!wget https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt" + "# make output and checkpoints directories if they don't exist\n", + "import os\n", + "if not os.path.exists(\"checkpoints\"):\n", + " os.makedirs(\"checkpoints\")\n", + "else:\n", + " print(\"Checkpoints directory already exists.\")" ] }, { "cell_type": "code", "execution_count": 5, - "id": "4c5198ca", + "id": "d8978b25", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "attention.py interface.py\t README.md\t\t utils.py\n", - "clip.py load_model.py\t requirements.txt\t VITON_Dataset.py\n", - "ddpm.py merges.txt\t sample_dataset\t vocab.json\n", - "decoder.py model_converter.py sd-v1-5-inpainting.ckpt\n", - "diffusion.py output\t\t test.ipynb\n", - "encoder.py pipeline.py\t training.ipynb\n" + "VITON-HD dataset already exists.\n", + "Zip file does not exist, nothing to remove.\n" ] } ], "source": [ - "!ls" + "import os\n", + "if not os.path.exists(\"viton-hd-dataset\"):\n", + " !curl -L -u harshkesherwani:7695128b407febc869a6f5b2cb0cbf26\\\n", + " -o /home/mahesh/harsh/stable-diffusion/viton-hd-dataset.zip\\\n", + " https://www.kaggle.com/api/v1/datasets/download/harshkesherwani/viton-hd-dataset\n", + " \n", + " import zipfile\n", + " with zipfile.ZipFile('viton-hd-dataset.zip', 'r') as zip_ref:\n", + " zip_ref.extractall('viton-hd-dataset')\n", + " \n", + " print(\"VITON-HD dataset downloaded and extracted.\")\n", + "else:\n", + " print(\"VITON-HD dataset already exists.\")\n", + " \n", + "import os\n", + "if os.path.exists(\"viton-hd-dataset.zip\"):\n", + " os.remove(\"viton-hd-dataset.zip\")\n", + " print(\"Removed the zip file after extraction.\")\n", + "else:\n", + " print(\"Zip file does not exist, nothing to remove.\")\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "9041f108", + "id": "3aea80d9", "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/mahesh/miniconda3/envs/harsh/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Found existing installation: gdown 5.2.0\n", - "Uninstalling gdown-5.2.0:\n", - " Successfully uninstalled gdown-5.2.0\n" + "----------------------------------------------------------------------------------------------------\n", + "Loading pretrained models...\n", + "Models loaded successfully.\n", + "----------------------------------------------------------------------------------------------------\n", + "Creating dataloader...\n", + "Dataset vitonhd loaded, total 11647 pairs.\n", + "Training for 50 epochs\n", + "Batches per epoch: 5824\n", + "----------------------------------------------------------------------------------------------------\n", + "Initializing trainer...\n", + "Enabling PEFT training (self-attention layers only)\n", + "Total parameters: 899,226,667\n", + "Trainable parameters: 49,574,080 (5.51%)\n" ] - } - ], - "source": [ - "# !pip uninstall gdown -y" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a9c7b968", - "metadata": {}, - "outputs": [ + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_1669505/646906096.py:71: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n", + " self.scaler = torch.cuda.amp.GradScaler()\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Collecting gdown\n", - " Downloading gdown-5.2.0-py3-none-any.whl.metadata (5.8 kB)\n", - "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.11/dist-packages (from gdown) (4.13.3)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from gdown) (3.18.0)\n", - "Requirement already satisfied: requests[socks] in /usr/local/lib/python3.11/dist-packages (from gdown) (2.32.3)\n", - "Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from gdown) (4.67.1)\n", - "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->gdown) (2.6)\n", - "Requirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->gdown) (4.13.1)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests[socks]->gdown) (3.4.1)\n", - "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests[socks]->gdown) (3.10)\n", - "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests[socks]->gdown) (2.3.0)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests[socks]->gdown) (2025.1.31)\n", - "Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /usr/local/lib/python3.11/dist-packages (from requests[socks]->gdown) (1.7.1)\n", - "Downloading gdown-5.2.0-py3-none-any.whl (18 kB)\n", - "Installing collected packages: gdown\n", - "Successfully installed gdown-5.2.0\n" + "Checkpoint loaded: ./checkpoints/checkpoint_step_40000.pth\n", + "Resuming from epoch 12, step 40000\n", + "Starting training...\n", + "Starting training for 50 epochs\n", + "Total training batches per epoch: 5824\n", + "Using DREAM with lambda = 0\n", + "Mixed precision: True\n" ] - } - ], - "source": [ - "# !pip install -U --no-cache-dir 
gdown --pre" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2653ceca", - "metadata": {}, - "outputs": [ + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 13: 0%| | 0/5824 [00:00\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mempty_cache\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Release unused GPU memory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mgc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Run Python garbage collector\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/IPython/core/displayhook.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, result)\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite_output_prompt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0mformat_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_format_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 263\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_user_ns\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 264\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfill_exec_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mformat_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/IPython/core/displayhook.py\u001b[0m in \u001b[0;36mupdate_user_ns\u001b[0;34m(self, result)\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[0;31m# Avoid recursive reference when displaying _oh/Out\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 201\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_size\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_ns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'_oh'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 202\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_ns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'_oh'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_size\u001b[0m \u001b[0;32mand\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdo_full_cache\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcull_cache\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyError\u001b[0m: '_oh'" + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 13: 52%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 3000/5824 [26:32<24:32, 1.92it/s, loss=0.0144, lr=1e-5, step=43001] " ] - } - ], - "source": [ - "import torch\n", - "import gc\n", - "\n", - "torch.cuda.empty_cache() # Release unused GPU memory\n", - "gc.collect() # Run Python garbage collector" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "id": "5a57d765", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import gc\n", - "\n", - "# Clear CUDA cache and collect garbage\n", - "torch.cuda.empty_cache()\n", - "gc.collect()\n", - "\n", - "# Delete all user-defined variables except for built-ins and modules\n", - "for var in list(globals()):\n", - " if not var.startswith(\"__\") and var not in [\"torch\", \"gc\"]:\n", - " del globals()[var]\n", - "\n", - "gc.collect()\n", - "torch.cuda.empty_cache()" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "id": "5957ec57", - "metadata": {}, - "outputs": [], - "source": [ - "import tensorflow as tf\n", - "tf.keras.backend.clear_session()" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "796e8ef7", - "metadata": {}, - "outputs": [ + }, { "name": "stdout", "output_type": "stream", "text": [ - "GPU memory used: 8.12 MB / 16269.25 MB\n" + "Debug visualization saved: checkpoints/debug_viz/debug_step_043000.jpg\n" ] - } - ], - "source": [ - "import torch\n", - "\n", - "if torch.cuda.is_available():\n", - " used = torch.cuda.memory_allocated() / 1024 ** 2 # in MB\n", - " total = torch.cuda.get_device_properties(0).total_memory / 1024 ** 2 # in MB\n", - " print(f\"GPU memory used: {used:.2f} MB / {total:.2f} MB\")\n", - "else:\n", - " print(\"CUDA is not available.\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "32ed173e", - "metadata": {}, - "outputs": [ + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 13: 69%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š | 4001/5824 [35:19<18:14, 1.67it/s, loss=0.0233, lr=1e-5, step=44001] " + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Total RAM: 31.35 GB\n", - "Available RAM: 24.16 GB\n" + "Debug visualization saved: checkpoints/debug_viz/debug_step_044000.jpg\n" ] - } - ], - "source": [ - "import psutil\n", - "\n", - "mem = psutil.virtual_memory()\n", - "total_ram = mem.total / (1024 ** 3) # in GB\n", - "available_ram = mem.available / (1024 ** 3) # in GB\n", - "print(f\"Total RAM: {total_ram:.2f} GB\")\n", - "print(f\"Available RAM: {available_ram:.2f} GB\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d13441b5", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 39, - "id": "3ce888b6", - "metadata": {}, - "outputs": [], - "source": [ - "def compute_vae_encodings(image_tensor, encoder, device=\"cuda\"):\n", - " \"\"\"Encode image using VAE encoder\"\"\"\n", - " # Generate random noise for encoding\n", - " encoder_noise = torch.randn(\n", - " (image_tensor.shape[0], 4, image_tensor.shape[2] // 8, image_tensor.shape[3] // 8),\n", - " device=device,\n", - " )\n", - " with torch.no_grad(): # 
VAE encoding doesn't need gradients\n", - " return encoder(image_tensor, encoder_noise)" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "id": "3aea80d9", - "metadata": {}, - "outputs": [ + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 13: 86%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 5000/5824 [44:07<07:16, 1.89it/s, loss=0.0609, lr=1e-5, step=45001] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Debug visualization saved: checkpoints/debug_viz/debug_step_045000.jpg\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 13: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5824/5824 [51:31<00:00, 1.88it/s, loss=0.00715, lr=1e-5, step=45824]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13/50 - Train Loss: 0.030487\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 14: 3%|β–Ž | 177/5824 [01:33<56:36, 1.66it/s, loss=0.0409, lr=1e-5, step=46001] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Debug visualization saved: checkpoints/debug_viz/debug_step_046000.jpg\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 14: 20%|β–ˆβ–ˆ | 1177/5824 [10:19<46:38, 1.66it/s, loss=0.00494, lr=1e-5, step=47001]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Debug visualization saved: checkpoints/debug_viz/debug_step_047000.jpg\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 14: 37%|β–ˆβ–ˆβ–ˆβ–‹ | 2177/5824 [19:07<36:55, 1.65it/s, loss=0.0527, lr=1e-5, step=48001] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Debug visualization saved: checkpoints/debug_viz/debug_step_048000.jpg\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 14: 55%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 3177/5824 [27:52<26:30, 1.66it/s, loss=0.0266, lr=1e-5, step=49001] " + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Loading pretrained models...\n" + "Debug visualization saved: checkpoints/debug_viz/debug_step_049000.jpg\n" ] }, { - "ename": "OutOfMemoryError", - "evalue": "CUDA out of memory. Tried to allocate 58.00 MiB. GPU 0 has a total capacity of 15.89 GiB of which 29.12 MiB is free. Process 3907 has 15.85 GiB memory in use. Of the allocated memory 15.49 GiB is allocated by PyTorch, and 62.82 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. 
See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)", + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 14: 72%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 4176/5824 [36:39<41:17, 1.50s/it, loss=0.0227, lr=1e-5, step=5e+4] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checkpoint saved: checkpoints/checkpoint_step_50000.pth\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 14: 72%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 4177/5824 [36:40<35:22, 1.29s/it, loss=0.0152, lr=1e-5, step=5e+4]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Debug visualization saved: checkpoints/debug_viz/debug_step_050000.jpg\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 14: 72%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 4211/5824 [36:58<14:09, 1.90it/s, loss=0.0351, lr=1e-5, step=5e+4] \n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_69/1468414648.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 502\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 503\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"__main__\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 504\u001b[0;31m \u001b[0mmain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/tmp/ipykernel_69/1468414648.py\u001b[0m in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 467\u001b[0m \u001b[0;31m# Load pretrained models\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 468\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Loading pretrained models...\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 469\u001b[0;31m \u001b[0mmodels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpreload_models_from_standard_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbase_model_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 470\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 471\u001b[0m \u001b[0;31m# Create dataloader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/kaggle/working/stable-diffusion/load_model.py\u001b[0m in \u001b[0;36mpreload_models_from_standard_weights\u001b[0;34m(ckpt_path, device, finetune_weights_path)\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0mstate_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodel_converter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_from_standard_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mckpt_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m 
\u001b[0mdiffusion\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mDiffusion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0min_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0min_channels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mout_channels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfinetune_weights_path\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mto\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1338\u001b[0m \u001b[0;32mraise\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1339\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1340\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1341\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1342\u001b[0m def register_full_backward_pre_hook(\n", - "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 898\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrecurse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 899\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 900\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 901\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 902\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 898\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrecurse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 899\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 900\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 901\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 902\u001b[0m 
\u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 898\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrecurse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 899\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 900\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 901\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 902\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 898\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrecurse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 899\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 900\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 901\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 902\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 898\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrecurse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 899\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 900\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 901\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 902\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 925\u001b[0m \u001b[0;31m# `with torch.no_grad():`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 926\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 927\u001b[0;31m \u001b[0mparam_applied\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 928\u001b[0m \u001b[0mp_should_use_set_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 929\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 1324\u001b[0m \u001b[0mmemory_format\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconvert_to_format\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1325\u001b[0m )\n\u001b[0;32m-> 1326\u001b[0;31m return t.to(\n\u001b[0m\u001b[1;32m 1327\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1328\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_complex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 58.00 MiB. GPU 0 has a total capacity of 15.89 GiB of which 29.12 MiB is free. Process 3907 has 15.85 GiB memory in use. Of the allocated memory 15.49 GiB is allocated by PyTorch, and 62.82 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. 
See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)" + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[8], line 520\u001b[0m\n\u001b[1;32m 517\u001b[0m trainer\u001b[38;5;241m.\u001b[39mtrain() \n\u001b[1;32m 519\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;18m__name__\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m__main__\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m--> 520\u001b[0m \u001b[43mmain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[8], line 517\u001b[0m, in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 515\u001b[0m \u001b[38;5;66;03m# Start training\u001b[39;00m\n\u001b[1;32m 516\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mStarting training...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 517\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[8], line 353\u001b[0m, in \u001b[0;36mCatVTONTrainer.train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 350\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcurrent_epoch \u001b[38;5;241m=\u001b[39m epoch\n\u001b[1;32m 352\u001b[0m \u001b[38;5;66;03m# Train one epoch\u001b[39;00m\n\u001b[0;32m--> 353\u001b[0m train_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_epoch\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 355\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_epochs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m - Train Loss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrain_loss\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.6f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 357\u001b[0m \u001b[38;5;66;03m# Save epoch checkpoint\u001b[39;00m\n", + "Cell \u001b[0;32mIn[8], line 292\u001b[0m, in \u001b[0;36mCatVTONTrainer.train_epoch\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 290\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_mixed_precision:\n\u001b[1;32m 291\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mamp\u001b[38;5;241m.\u001b[39mautocast():\n\u001b[0;32m--> 292\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 294\u001b[0m \u001b[38;5;66;03m# Backward pass with scaling\u001b[39;00m\n\u001b[1;32m 295\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscaler\u001b[38;5;241m.\u001b[39mscale(loss)\u001b[38;5;241m.\u001b[39mbackward()\n", + "Cell \u001b[0;32mIn[8], line 211\u001b[0m, in \u001b[0;36mCatVTONTrainer.compute_loss\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 207\u001b[0m \u001b[38;5;66;03m# timesteps = torch.randint(1, 1000, size=(1,), device=self.device)[0].long().item()\u001b[39;00m\n\u001b[1;32m 
208\u001b[0m \u001b[38;5;66;03m# timesteps = torch.tensor(timesteps, device=self.device)\u001b[39;00m\n\u001b[1;32m 209\u001b[0m \u001b[38;5;66;03m# timesteps_embedding = get_time_embedding(timesteps).to(self.device, dtype=self.weight_dtype)\u001b[39;00m\n\u001b[1;32m 210\u001b[0m timesteps \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandint(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m1000\u001b[39m, size\u001b[38;5;241m=\u001b[39m(batch_size,))\n\u001b[0;32m--> 211\u001b[0m timesteps_embedding \u001b[38;5;241m=\u001b[39m \u001b[43mget_time_embedding\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimesteps\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight_dtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 213\u001b[0m \u001b[38;5;66;03m# Add noise to latents\u001b[39;00m\n\u001b[1;32m 214\u001b[0m noisy_latents \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscheduler\u001b[38;5;241m.\u001b[39madd_noise(target_latents, timesteps, noise)\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ - "import os\n", - "import json\n", "import random\n", "import argparse\n", "from pathlib import Path\n", - "from typing import Dict, List, Tuple, Optional\n", + "from typing import Dict, Optional\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", - "from torch.utils.data import Dataset, DataLoader\n", + "from torch.utils.data import DataLoader\n", "from torch.optim import AdamW\n", - "from torch.optim.lr_scheduler import CosineAnnealingLR\n", "\n", "import numpy as np\n", "from PIL import Image\n", @@ -475,7 +388,7 @@ "# Import your custom modules\n", "from load_model import preload_models_from_standard_weights\n", "from ddpm import DDPMSampler\n", - "from utils import check_inputs, get_time_embedding, prepare_image, prepare_mask_image\n", + "from utils import check_inputs, get_time_embedding, prepare_image, prepare_mask_image, compute_vae_encodings, save_debug_visualization\n", "from diffusers.utils.torch_utils import randn_tensor\n", "\n", "class CatVTONTrainer:\n", @@ -533,15 +446,18 @@ " # Resume from checkpoint if provided\n", " self.global_step = 0\n", " self.current_epoch = 0\n", + " \n", + " # Setup models and optimizers\n", + " self._setup_training()\n", + " \n", " if resume_from_checkpoint:\n", " self._load_checkpoint(resume_from_checkpoint)\n", + " \n", + " \n", " \n", " self.encoder = self.models.get('encoder', None)\n", " self.decoder = self.models.get('decoder', None)\n", " self.diffusion = self.models.get('diffusion', None)\n", - "\n", - " # Setup models and optimizers\n", - " self._setup_training()\n", " \n", " def _setup_training(self):\n", " \"\"\"Setup models for training with PEFT\"\"\"\n", @@ -595,7 +511,7 @@ " \"\"\"Enable PEFT training - only self-attention layers\"\"\"\n", " print(\"Enabling PEFT training (self-attention layers only)\")\n", " \n", - " unet = self.diffusion.unet\n", + " unet = self.models['diffusion'].unet\n", " \n", " # Enable attention layers in encoders and decoders\n", " for layers in [unet.encoders, unet.decoders]:\n", @@ -610,19 +526,14 @@ " for name, param in layer.named_parameters():\n", " if 'attention_1' in name:\n", " 
param.requires_grad = True\n", - " \n", - " def _apply_cfg_dropout(self, garment_latent: torch.Tensor) -> torch.Tensor:\n", - " \"\"\"Apply classifier-free guidance dropout (10% chance)\"\"\"\n", - " if self.training and random.random() < self.cfg_dropout_prob:\n", - " # Replace with zero tensor for unconditional generation\n", - " return torch.zeros_like(garment_latent)\n", - " return garment_latent\n", " \n", " def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:\n", " \"\"\"Compute MSE loss for denoising with DREAM strategy\"\"\"\n", " person_images = batch['person'].to(self.device)\n", " cloth_images = batch['cloth'].to(self.device)\n", " masks = batch['mask'].to(self.device)\n", + " \n", + " batch_size = person_images.shape[0]\n", "\n", " concat_dim = -2 # y axis concat\n", " \n", @@ -642,67 +553,65 @@ " condition_latent = compute_vae_encodings(condition_image, self.encoder)\n", " mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode=\"nearest\")\n", " \n", + " \n", " del image, mask, condition_image\n", - "\n", - " # Apply CFG dropout to garment latent\n", - " condition_latent = self._apply_cfg_dropout(condition_latent)\n", + " \n", + " # Apply CFG dropout to garment latent (10% chance)\n", + " if self.training and random.random() < self.cfg_dropout_prob:\n", + " condition_latent = torch.zeros_like(condition_latent)\n", " \n", " # Concatenate latents\n", - " masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=concat_dim)\n", - " mask_latent_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)\n", + " input_latents = torch.cat([masked_latent, condition_latent], dim=concat_dim)\n", + " mask_input = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)\n", " target_latents = torch.cat([person_latent, condition_latent], dim=concat_dim)\n", "\n", " noise = randn_tensor(\n", - " masked_latent_concat.shape,\n", + " target_latents.shape,\n", " generator=self.generator,\n", - " device=masked_latent_concat.device,\n", + " device=target_latents.device,\n", " dtype=self.weight_dtype,\n", " )\n", "\n", - " timesteps = torch.randint(1, 1000, size=(1,), device=self.device)[0].long().item()\n", - " timesteps = torch.tensor(timesteps, device=self.device)\n", + " # timesteps = torch.randint(1, 1000, size=(1,), device=self.device)[0].long().item()\n", + " # timesteps = torch.tensor(timesteps, device=self.device)\n", + " # timesteps_embedding = get_time_embedding(timesteps).to(self.device, dtype=self.weight_dtype)\n", + " timesteps = torch.randint(1, 1000, size=(batch_size,))\n", " timesteps_embedding = get_time_embedding(timesteps).to(self.device, dtype=self.weight_dtype)\n", "\n", " # Add noise to latents\n", " noisy_latents = self.scheduler.add_noise(target_latents, timesteps, noise)\n", "\n", - " inpainting_latent_model_input = torch.cat([ \n", - " masked_latent_concat,\n", - " mask_latent_concat,\n", - " noisy_latents\n", - " ], dim=1).to(self.device, dtype=self.weight_dtype)\n", + " # UNet(zt βŠ™ Mi βŠ™ Xi) where βŠ™ is channel concatenation\n", + " unet_input = torch.cat([\n", + " input_latents, # Xi\n", + " mask_input, # Mi\n", + " noisy_latents, # zt\n", + " ], dim=1).to(self.device, dtype=self.weight_dtype) # Channel dimension\n", + " \n", "\n", " # DREAM strategy implementation\n", " if self.dream_lambda > 0:\n", " # Get initial noise prediction\n", " with torch.no_grad():\n", " epsilon_theta = self.diffusion(\n", - " inpainting_latent_model_input,\n", + " unet_input,\n", " 
timesteps_embedding\n", " )\n", " \n", - " # Apply DREAM: zΛ†t = √αt*z0 + √(1-Ξ±t)*(Ξ΅ + Ξ»*Ρθ)\n", - " alphas_cumprod = self.scheduler.alphas_cumprod.to(device=self.device, dtype=self.weight_dtype)\n", - " sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5\n", - " sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5\n", - " \n", - " # Reshape for broadcasting\n", - " sqrt_alpha_prod = sqrt_alpha_prod.view(-1, 1, 1, 1)\n", - " sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.view(-1, 1, 1, 1)\n", - " \n", - " # DREAM noise combination\n", + " # DREAM noise combination: Ξ΅ + Ξ»*Ρθ\n", " dream_noise = noise + self.dream_lambda * epsilon_theta\n", + " \n", + " # Create new noisy latents with DREAM noise\n", + " dream_noisy_latents = self.scheduler.add_noise(target_latents, timesteps, dream_noise)\n", "\n", - " dream_noisy_latents = sqrt_alpha_prod * target_latents + sqrt_one_minus_alpha_prod * dream_noise\n", - "\n", - " dream_model_input = torch.cat([\n", - " dream_noisy_latents, \n", - " mask_latent_concat, \n", - " masked_latent_concat\n", - " ], dim=1)\n", + " dream_unet_input = torch.cat([\n", + " input_latents, \n", + " mask_input,\n", + " dream_noisy_latents\n", + " ], dim=1).to(self.device, dtype=self.weight_dtype)\n", "\n", " predicted_noise = self.diffusion(\n", - " dream_model_input,\n", + " dream_unet_input,\n", " timesteps_embedding\n", " )\n", " # DREAM loss: |(Ξ΅ + λΡθ) - Ρθ(αΊ‘t, t)|Β²\n", @@ -710,13 +619,27 @@ " else:\n", " # Standard training without DREAM\n", " predicted_noise = self.diffusion(\n", - " inpainting_latent_model_input,\n", + " unet_input,\n", " timesteps_embedding,\n", " )\n", - "\n", + " \n", " # Standard MSE loss\n", " loss = F.mse_loss(predicted_noise, noise)\n", " \n", + " if self.global_step % 1000 == 0:\n", + " save_debug_visualization(\n", + " person_images=person_images,\n", + " cloth_images=cloth_images, \n", + " masks=masks,\n", + " masked_image=masked_image,\n", + " noisy_latents=noisy_latents,\n", + " predicted_noise=predicted_noise,\n", + " target_latents=target_latents,\n", + " decoder=self.decoder,\n", + " global_step=self.global_step,\n", + " output_dir=self.output_dir,\n", + " device=self.device\n", + " )\n", " return loss\n", " \n", " def train_epoch(self) -> float:\n", @@ -883,13 +806,13 @@ " args.__dict__ = {\n", " \"base_model_path\": \"sd-v1-5-inpainting.ckpt\",\n", " \"dataset_name\": \"vitonhd\",\n", - " \"data_root_path\": \"/kaggle/input/viton-hd-dataset\",\n", + " \"data_root_path\": \"./viton-hd-dataset\",\n", " \"output_dir\": \"./checkpoints\",\n", - " \"resume_from_checkpoint\": None,\n", + " \"resume_from_checkpoint\": \"./checkpoints/checkpoint_step_40000.pth\",\n", " \"seed\": 42,\n", - " \"batch_size\": 1,\n", + " \"batch_size\": 2,\n", " \"width\": 384,\n", - " \"height\": 512,\n", + " \"height\": 384,\n", " \"repaint\": True,\n", " \"eval_pair\": True,\n", " \"concat_eval_results\": True,\n", @@ -899,10 +822,10 @@ " \"learning_rate\": 1e-5,\n", " \"max_grad_norm\": 1.0,\n", " \"cfg_dropout_prob\": 0.1,\n", - " \"dream_lambda\": 0,\n", + " \"dream_lambda\": 10.0,\n", " \"use_peft\": True,\n", " \"use_mixed_precision\": True,\n", - " \"save_steps\": 1000,\n", + " \"save_steps\": 10000,\n", " \"is_train\": True\n", " }\n", " \n", @@ -918,10 +841,15 @@ " torch.backends.cuda.matmul.allow_tf32 = True \n", " torch.backends.cudnn.allow_tf32 = True \n", " torch.set_float32_matmul_precision(\"high\")\n", + " \n", + " print(\"-\"*100)\n", "\n", " # Load pretrained models\n", " print(\"Loading pretrained 
models...\")\n", " models = preload_models_from_standard_weights(args.base_model_path, args.device)\n", + " print(\"Models loaded successfully.\")\n", + " \n", + " print(\"-\"*100)\n", " \n", " # Create dataloader\n", " print(\"Creating dataloader...\")\n", @@ -930,6 +858,8 @@ " print(f\"Training for {args.num_epochs} epochs\")\n", " print(f\"Batches per epoch: {len(train_dataloader)}\")\n", " \n", + " print(\"-\"*100)\n", + " \n", " # Initialize trainer\n", " print(\"Initializing trainer...\") \n", " trainer = CatVTONTrainer(\n", @@ -954,31 +884,14 @@ " print(\"Starting training...\")\n", " trainer.train() \n", "\n", - "\n", "if __name__ == \"__main__\":\n", " main()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "77892d6a", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b3917d76", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "harsh", "language": "python", "name": "python3" }, @@ -992,7 +905,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.11" + "version": "3.10.18" } }, "nbformat": 4, diff --git a/training.py b/training.py new file mode 100644 index 0000000000000000000000000000000000000000..d1ae334e9aed5565d5c35bc6e13424394ac4483d --- /dev/null +++ b/training.py @@ -0,0 +1,518 @@ +import torch +import os +import random +import argparse +from pathlib import Path +from typing import Dict, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch.optim import AdamW + +import numpy as np +from PIL import Image +from tqdm import tqdm +from VITON_Dataset import VITONHDTestDataset + +# Import your custom modules +from load_model import preload_models_from_standard_weights +from ddpm import DDPMSampler +from utils import check_inputs, get_time_embedding, prepare_image, prepare_mask_image, save_debug_visualization, compute_vae_encodings +from diffusers.utils.torch_utils import randn_tensor + + +class CatVTONTrainer: + """Simplified CatVTON Training Class with PEFT, CFG and DREAM support""" + + def __init__( + self, + models: Dict[str, nn.Module], + train_dataloader: DataLoader, + val_dataloader: Optional[DataLoader] = None, + device: str = "cuda", + learning_rate: float = 1e-5, + num_epochs: int = 50, + save_steps: int = 1000, + output_dir: str = "./checkpoints", + cfg_dropout_prob: float = 0.1, + max_grad_norm: float = 1.0, + use_peft: bool = True, + dream_lambda: float = 10.0, + resume_from_checkpoint: Optional[str] = None, + use_mixed_precision: bool = True, + height=512, + width=384, + ): + self.training = True + self.models = models + self.train_dataloader = train_dataloader + self.val_dataloader = val_dataloader + self.device = device + self.learning_rate = learning_rate + self.num_epochs = num_epochs + self.save_steps = save_steps + self.output_dir = Path(output_dir) + self.cfg_dropout_prob = cfg_dropout_prob + self.max_grad_norm = max_grad_norm + self.use_peft = use_peft + self.dream_lambda = dream_lambda + self.use_mixed_precision = use_mixed_precision + self.height = height + self.width = width + self.generator = torch.Generator(device=device) + + # Create output directory + self.output_dir.mkdir(parents=True, exist_ok=True) + + # Setup mixed precision training + if self.use_mixed_precision: + self.scaler = torch.cuda.amp.GradScaler() + + self.weight_dtype = 
torch.float16 if use_mixed_precision else torch.float32 + + # Initialize scheduler and sampler + self.scheduler = DDPMSampler(self.generator, num_training_steps=1000) + + # Resume from checkpoint if provided + self.global_step = 0 + self.current_epoch = 0 + + # Setup models and optimizers + self._setup_training() + + if resume_from_checkpoint: + self._load_checkpoint(resume_from_checkpoint) + + + + self.encoder = self.models.get('encoder', None) + self.decoder = self.models.get('decoder', None) + self.diffusion = self.models.get('diffusion', None) + + def _setup_training(self): + """Setup models for training with PEFT""" + # Move models to device + for name, model in self.models.items(): + model.to(self.device) + + # Freeze all parameters first + for model in self.models.values(): + for param in model.parameters(): + param.requires_grad = False + + # Enable training for specific layers based on PEFT strategy + if self.use_peft: + self._enable_peft_training() + else: + # Enable full training for diffusion model + for param in self.diffusion.parameters(): + param.requires_grad = True + + # Collect trainable parameters + trainable_params = [] + total_params = 0 + trainable_count = 0 + + for name, model in self.models.items(): + for param_name, param in model.named_parameters(): + total_params += param.numel() + if param.requires_grad: + trainable_params.append(param) + trainable_count += param.numel() + + print(f"Total parameters: {total_params:,}") + print(f"Trainable parameters: {trainable_count:,} ({trainable_count/total_params*100:.2f}%)") + + # Setup optimizer - AdamW as per paper + self.optimizer = AdamW( + trainable_params, + lr=self.learning_rate, + betas=(0.9, 0.999), + weight_decay=1e-2, + eps=1e-8 + ) + + # Setup learning rate scheduler (constant) + self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, lr_lambda=lambda epoch: 1.0 + ) + + def _enable_peft_training(self): + """Enable PEFT training - only self-attention layers""" + print("Enabling PEFT training (self-attention layers only)") + + unet = self.models['diffusion'].unet + + # Enable attention layers in encoders and decoders + for layers in [unet.encoders, unet.decoders]: + for layer in layers: + for module_idx, module in enumerate(layer): + for name, param in module.named_parameters(): + if 'attention_1' in name: + param.requires_grad = True + + # Enable attention layers in bottleneck + for layer in unet.bottleneck: + for name, param in layer.named_parameters(): + if 'attention_1' in name: + param.requires_grad = True + + def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + """Compute MSE loss for denoising with DREAM strategy""" + person_images = batch['person'].to(self.device) + cloth_images = batch['cloth'].to(self.device) + masks = batch['mask'].to(self.device) + + batch_size = person_images.shape[0] + + concat_dim = -2 # y axis concat + + # Prepare inputs + image, condition_image, mask = check_inputs(person_images, cloth_images, masks, self.width, self.height) + image = prepare_image(person_images).to(self.device, dtype=self.weight_dtype) + condition_image = prepare_image(cloth_images).to(self.device, dtype=self.weight_dtype) + mask = prepare_mask_image(masks).to(self.device, dtype=self.weight_dtype) + + # Mask image + masked_image = image * (mask < 0.5) + + with torch.cuda.amp.autocast(enabled=self.use_mixed_precision): + # VAE encoding + masked_latent = compute_vae_encodings(masked_image, self.encoder) + person_latent = compute_vae_encodings(person_images, self.encoder) + 
condition_latent = compute_vae_encodings(condition_image, self.encoder) + mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode="nearest") + + + del image, mask, condition_image + + # Apply CFG dropout to garment latent (10% chance) + if self.training and random.random() < self.cfg_dropout_prob: + condition_latent = torch.zeros_like(condition_latent) + + # Concatenate latents + input_latents = torch.cat([masked_latent, condition_latent], dim=concat_dim) + mask_input = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim) + target_latents = torch.cat([person_latent, condition_latent], dim=concat_dim) + + noise = randn_tensor( + target_latents.shape, + generator=self.generator, + device=target_latents.device, + dtype=self.weight_dtype, + ) + + # timesteps = torch.randint(1, 1000, size=(1,), device=self.device)[0].long().item() + # timesteps = torch.tensor(timesteps, device=self.device) + # timesteps_embedding = get_time_embedding(timesteps).to(self.device, dtype=self.weight_dtype) + timesteps = torch.randint(1, 1000, size=(batch_size,)) + timesteps_embedding = get_time_embedding(timesteps).to(self.device, dtype=self.weight_dtype) + + # Add noise to latents + noisy_latents = self.scheduler.add_noise(target_latents, timesteps, noise) + + # UNet(zt βŠ™ Mi βŠ™ Xi) where βŠ™ is channel concatenation + unet_input = torch.cat([ + input_latents, # Xi + mask_input, # Mi + noisy_latents, # zt + ], dim=1).to(self.device, dtype=self.weight_dtype) # Channel dimension + + + # DREAM strategy implementation + if self.dream_lambda > 0: + # Get initial noise prediction + with torch.no_grad(): + epsilon_theta = self.diffusion( + unet_input, + timesteps_embedding + ) + + # DREAM noise combination: Ξ΅ + Ξ»*Ρθ + dream_noise = noise + self.dream_lambda * epsilon_theta + + # Create new noisy latents with DREAM noise + dream_noisy_latents = self.scheduler.add_noise(target_latents, timesteps, dream_noise) + + dream_unet_input = torch.cat([ + input_latents, + mask_input, + dream_noisy_latents + ], dim=1).to(self.device, dtype=self.weight_dtype) + + predicted_noise = self.diffusion( + dream_unet_input, + timesteps_embedding + ) + # DREAM loss: |(Ξ΅ + λΡθ) - Ρθ(αΊ‘t, t)|Β² + loss = F.mse_loss(predicted_noise, dream_noise) + else: + # Standard training without DREAM + predicted_noise = self.diffusion( + unet_input, + timesteps_embedding, + ) + + # Standard MSE loss + loss = F.mse_loss(predicted_noise, noise) + + if self.global_step % 1000 == 0: + save_debug_visualization( + person_images=person_images, + cloth_images=cloth_images, + masks=masks, + masked_image=masked_image, + noisy_latents=noisy_latents, + predicted_noise=predicted_noise, + target_latents=target_latents, + decoder=self.decoder, + global_step=self.global_step, + output_dir=self.output_dir, + device=self.device + ) + return loss + + def train_epoch(self) -> float: + """Train for one epoch - simplified version""" + self.diffusion.train() + total_loss = 0.0 + num_batches = len(self.train_dataloader) + + # progress_bar = tqdm(self.train_dataloader, desc=f"Epoch {self.current_epoch+1}") + + for step, batch in enumerate(self.train_dataloader): + # Zero gradients + self.optimizer.zero_grad() + + # Forward pass with mixed precision + if self.use_mixed_precision: + with torch.cuda.amp.autocast(): + loss = self.compute_loss(batch) + + # Backward pass with scaling + self.scaler.scale(loss).backward() + + # Gradient clipping and optimizer step + self.scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_( 
+ [p for p in self.diffusion.parameters() if p.requires_grad], + self.max_grad_norm + ) + + self.scaler.step(self.optimizer) + self.scaler.update() + else: + loss = self.compute_loss(batch) + loss.backward() + + # Gradient clipping + torch.nn.utils.clip_grad_norm_( + [p for p in self.diffusion.parameters() if p.requires_grad], + self.max_grad_norm + ) + + # Optimizer step + self.optimizer.step() + + # Update learning rate + self.lr_scheduler.step() + self.global_step += 1 + + total_loss += loss.item() + + # Update progress bar + # progress_bar.set_postfix({ + # 'loss': loss.item(), + # 'lr': self.optimizer.param_groups[0]['lr'], + # 'step': self.global_step + # }) + + # Save checkpoint based on steps + if self.global_step % self.save_steps == 0: + self._save_checkpoint() + + # Clear cache periodically to prevent OOM + if step % 50 == 0: + torch.cuda.empty_cache() + + return total_loss / num_batches + + def train(self): + """Main training loop - simplified version""" + print(f"Starting training for {self.num_epochs} epochs") + print(f"Total training batches per epoch: {len(self.train_dataloader)}") + print(f"Using DREAM with lambda = {self.dream_lambda}") + print(f"Mixed precision: {self.use_mixed_precision}") + + for epoch in range(self.current_epoch, self.num_epochs): + self.current_epoch = epoch + + # Train one epoch + train_loss = self.train_epoch() + + print(f"Epoch {epoch+1}/{self.num_epochs} - Train Loss: {train_loss:.6f}") + + # Save epoch checkpoint + if (epoch + 1) % 5 == 0: # Save every 5 epochs + self._save_checkpoint(epoch_checkpoint=True) + + # Clear cache at end of epoch + torch.cuda.empty_cache() + + # Save final checkpoint + self._save_checkpoint(is_final=True) + print("Training completed!") + + def _save_checkpoint(self, is_best: bool = False, epoch_checkpoint: bool = False, is_final: bool = False): + """Save model checkpoint""" + checkpoint = { + 'global_step': self.global_step, + 'current_epoch': self.current_epoch, + 'diffusion_state_dict': self.diffusion.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'lr_scheduler_state_dict': self.lr_scheduler.state_dict(), + 'dream_lambda': self.dream_lambda, + 'use_peft': self.use_peft, + } + + if self.use_mixed_precision: + checkpoint['scaler_state_dict'] = self.scaler.state_dict() + + if is_final: + checkpoint_path = self.output_dir / "final_model.pth" + elif is_best: + checkpoint_path = self.output_dir / "best_model.pth" + elif epoch_checkpoint: + checkpoint_path = self.output_dir / f"checkpoint_epoch_{self.current_epoch+1}.pth" + else: + checkpoint_path = self.output_dir / f"checkpoint_step_{self.global_step}.pth" + + torch.save(checkpoint, checkpoint_path) + print(f"Checkpoint saved: {checkpoint_path}") + + def _load_checkpoint(self, checkpoint_path: str): + """Load model checkpoint""" + checkpoint = torch.load(checkpoint_path, map_location=self.device) + + self.global_step = checkpoint['global_step'] + self.current_epoch = checkpoint['current_epoch'] + self.dream_lambda = checkpoint.get('dream_lambda', 10.0) + + self.models['diffusion'].load_state_dict(checkpoint['diffusion_state_dict']) + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict']) + + if self.use_mixed_precision and 'scaler_state_dict' in checkpoint: + self.scaler.load_state_dict(checkpoint['scaler_state_dict']) + + print(f"Checkpoint loaded: {checkpoint_path}") + print(f"Resuming from epoch {self.current_epoch}, step {self.global_step}") + + +def 
create_dataloaders(args) -> DataLoader: + """Create training dataloader""" + if args.dataset_name == "vitonhd": + dataset = VITONHDTestDataset(args) + else: + raise ValueError(f"Invalid dataset name {args.dataset_name}.") + + print(f"Dataset {args.dataset_name} loaded, total {len(dataset)} pairs.") + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=8, + pin_memory=True, + persistent_workers=True, + prefetch_factor=2 + ) + return dataloader + +def main(): + args = argparse.Namespace() + args.__dict__ = { + "base_model_path": "sd-v1-5-inpainting.ckpt", + "dataset_name": "vitonhd", + "data_root_path": "./viton-hd-dataset", + "output_dir": "./checkpoints", + "resume_from_checkpoint": "./checkpoints/checkpoint_step_50000.pth", + "seed": 42, + "batch_size": 2, + "width": 384, + "height": 384, + "repaint": True, + "eval_pair": True, + "concat_eval_results": True, + "concat_axis": 'y', + "device": "cuda", + "num_epochs": 50, + "learning_rate": 1e-5, + "max_grad_norm": 1.0, + "cfg_dropout_prob": 0.1, + "dream_lambda": 10.0, + "use_peft": True, + "use_mixed_precision": True, + "save_steps": 10000, + "is_train": True + } + + # Set random seeds + torch.manual_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + # Optimize CUDA settings + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + + print("-"*100) + + # Load pretrained models + print("Loading pretrained models...") + models = preload_models_from_standard_weights(args.base_model_path, args.device) + print("Models loaded successfully.") + + print("-"*100) + + # Create dataloader + print("Creating dataloader...") + train_dataloader = create_dataloaders(args) + + print(f"Training for {args.num_epochs} epochs") + print(f"Batches per epoch: {len(train_dataloader)}") + + print("-"*100) + + # Initialize trainer + print("Initializing trainer...") + trainer = CatVTONTrainer( + models=models, + train_dataloader=train_dataloader, + device=args.device, + learning_rate=args.learning_rate, + num_epochs=args.num_epochs, + save_steps=args.save_steps, + output_dir=args.output_dir, + cfg_dropout_prob=args.cfg_dropout_prob, + max_grad_norm=args.max_grad_norm, + use_peft=args.use_peft, + dream_lambda=args.dream_lambda, + resume_from_checkpoint=args.resume_from_checkpoint, + use_mixed_precision=args.use_mixed_precision, + height=args.height, + width=args.width + ) + + # Start training + print("Starting training...") + trainer.train() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/utils.py b/utils.py index 74ef89e3f6e524ab6538fb84a6454529accd6772..72f397a47368b5313df8f66019c30020a600573b 100644 --- a/utils.py +++ b/utils.py @@ -1,5 +1,6 @@ import os - +import torchvision.transforms as transforms +from PIL import Image import math import PIL import numpy as np @@ -13,15 +14,18 @@ from typing import List, Optional, Tuple, Set from tqdm import tqdm from PIL import Image, ImageFilter -def get_time_embedding(timestep): +def get_time_embedding(timesteps): + # Handle both scalar and batch inputs + if timesteps.dim() == 0: + timesteps = timesteps.unsqueeze(0) + # Shape: (160,) freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160) - # Shape: (1, 160) - x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None] - # Shape: (1, 160 * 2) -> (1, 
320) + # Shape: (B, 160) + x = timesteps.float()[:, None] * freqs[None] + # Shape: (B, 320) return torch.cat([torch.cos(x), torch.sin(x)], dim=-1) - def repaint(person, mask, result): _, h = result.size kernal_size = h // 50 @@ -48,63 +52,6 @@ def to_pil_image(images): pil_images = [Image.fromarray(image) for image in images] return pil_images - -# Compute DREAM and update latents for diffusion sampling -# def compute_dream_and_update_latents_for_inpaint( -# unet: UNet2DConditionModel, -# noise_scheduler: SchedulerMixin, -# timesteps: torch.Tensor, -# noise: torch.Tensor, -# noisy_latents: torch.Tensor, -# target: torch.Tensor, -# encoder_hidden_states: torch.Tensor, -# dream_detail_preservation: float = 1.0, -# ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: -# """ -# Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from http://arxiv.org/abs/2312.00210. -# DREAM helps align training with sampling to help training be more efficient and accurate at the cost of an extra -# forward step without gradients. - -# Args: -# `unet`: The state unet to use to make a prediction. -# `noise_scheduler`: The noise scheduler used to add noise for the given timestep. -# `timesteps`: The timesteps for the noise_scheduler to user. -# `noise`: A tensor of noise in the shape of noisy_latents. -# `noisy_latents`: Previously noise latents from the training loop. -# `target`: The ground-truth tensor to predict after eps is removed. -# `encoder_hidden_states`: Text embeddings from the text model. -# `dream_detail_preservation`: A float value that indicates detail preservation level. -# See reference. - -# Returns: -# `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target. -# """ -# alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None] -# sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 - -# # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments. -# dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation - -# pred = None # b, 4, h, w -# with torch.no_grad(): -# pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - -# noisy_latents_no_condition = noisy_latents[:, :4] -# _noisy_latents, _target = (None, None) -# if noise_scheduler.config.prediction_type == "epsilon": -# predicted_noise = pred -# delta_noise = (noise - predicted_noise).detach() -# delta_noise.mul_(dream_lambda) -# _noisy_latents = noisy_latents_no_condition.add(sqrt_one_minus_alphas_cumprod * delta_noise) -# _target = target.add(delta_noise) -# elif noise_scheduler.config.prediction_type == "v_prediction": -# raise NotImplementedError("DREAM has not been implemented for v-prediction") -# else: -# raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - -# _noisy_latents = torch.cat([_noisy_latents, noisy_latents[:, 4:]], dim=1) -# return _noisy_latents, _target - # Prepare the input for inpainting model. 
def prepare_inpainting_input( noisy_latents: torch.Tensor, @@ -132,76 +79,18 @@ def prepare_inpainting_input( return noisy_latents # Compute VAE encodings -def compute_vae_encodings(image: torch.Tensor, vae: torch.nn.Module) -> torch.Tensor: - """ - Args: - images (torch.Tensor): image to be encoded - vae (torch.nn.Module): vae model - - Returns: - torch.Tensor: latent encoding of the image - """ - pixel_values = image.to(memory_format=torch.contiguous_format).float() - pixel_values = pixel_values.to(vae.device, dtype=vae.dtype) - with torch.no_grad(): - model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = model_input * vae.config.scaling_factor - return model_input - - -# Init Accelerator -from accelerate import Accelerator, DistributedDataParallelKwargs -from accelerate.utils import ProjectConfiguration - -def init_accelerator(config): - accelerator_project_config = ProjectConfiguration( - project_dir=config.project_name, - logging_dir=os.path.join(config.project_name, "logs"), +def compute_vae_encodings(image_tensor, encoder, device="cuda"): + """Encode image using VAE encoder""" + # Generate random noise for encoding + encoder_noise = torch.randn( + (image_tensor.shape[0], 4, image_tensor.shape[2] // 8, image_tensor.shape[3] // 8), + device=device, ) - accelerator_ddp_config = DistributedDataParallelKwargs(find_unused_parameters=True) - accelerator = Accelerator( - mixed_precision=config.mixed_precision, - log_with=config.report_to, - project_config=accelerator_project_config, - kwargs_handlers=[accelerator_ddp_config], - gradient_accumulation_steps=config.gradient_accumulation_steps, - ) - # Disable AMP for MPS. - if torch.backends.mps.is_available(): - accelerator.native_amp = False - - if accelerator.is_main_process: - accelerator.init_trackers( - project_name=config.project_name, - config={ - "learning_rate": config.learning_rate, - "train_batch_size": config.train_batch_size, - "image_size": f"{config.width}x{config.height}", - }, - ) - - return accelerator - - -def init_weight_dtype(wight_dtype): - return { - "no": torch.float32, - "fp16": torch.float16, - "bf16": torch.bfloat16, - }[wight_dtype] - + + # Encode using your custom encoder + latent = encoder(image_tensor, encoder_noise) + return latent -def init_add_item_id(config): - return torch.tensor( - [ - config.height, - config.width * 2, - 0, - 0, - config.height, - config.width * 2, - ] - ).repeat(config.train_batch_size, 1) def check_inputs(image, condition_image, mask, width, height): if isinstance(image, torch.Tensor) and isinstance(condition_image, torch.Tensor) and isinstance(mask, torch.Tensor): @@ -312,86 +201,6 @@ def tensor_to_image(tensor: torch.Tensor): image = Image.fromarray(tensor) return image - -def concat_images(images: List[Image.Image], divider: int = 4, cols: int = 4): - """ - Concatenates images horizontally and with - """ - widths = [image.size[0] for image in images] - heights = [image.size[1] for image in images] - total_width = cols * max(widths) - total_width += divider * (cols - 1) - # `col` images each row - rows = math.ceil(len(images) / cols) - total_height = max(heights) * rows - # add divider between rows - total_height += divider * (len(heights) // cols - 1) - - # all black image - concat_image = Image.new("RGB", (total_width, total_height), (0, 0, 0)) - - x_offset = 0 - y_offset = 0 - for i, image in enumerate(images): - concat_image.paste(image, (x_offset, y_offset)) - x_offset += image.size[0] + divider - if (i + 1) % cols == 0: - x_offset = 0 - y_offset += 
image.size[1] + divider - - return concat_image - - -def read_prompt_file(prompt_file: str): - if prompt_file is not None and os.path.isfile(prompt_file): - with open(prompt_file, "r") as sample_prompt_file: - sample_prompts = sample_prompt_file.readlines() - sample_prompts = [sample_prompt.strip() for sample_prompt in sample_prompts] - else: - sample_prompts = [] - return sample_prompts - - -def save_tensors_to_npz(tensors: torch.Tensor, paths: List[str]): - assert len(tensors) == len(paths), "Length of tensors and paths should be the same!" - for tensor, path in zip(tensors, paths): - np.savez_compressed(path, latent=tensor.cpu().numpy()) - - -def deepspeed_zero_init_disabled_context_manager(): - """ - returns either a context list that includes one that will disable zero.Init or an empty context list - """ - deepspeed_plugin = ( - AcceleratorState().deepspeed_plugin - if accelerate.state.is_initialized() - else None - ) - if deepspeed_plugin is None: - return [] - - return [deepspeed_plugin.zero3_init_context_manager(enable=False)] - - -def is_xformers_available(): - try: - import xformers - - xformers_version = version.parse(xformers.__version__) - if xformers_version == version.parse("0.0.16"): - print( - "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, " - "please update xFormers to at least 0.0.17. " - "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." - ) - return True - except ImportError: - raise ValueError( - "xformers is not available. Make sure it is installed correctly" - ) - - - def resize_and_crop(image, size): # Crop to size ratio w, h = image.size @@ -426,19 +235,124 @@ def resize_and_padding(image, size): padding.paste(image, ((target_w - new_w) // 2, (target_h - new_h) // 2)) return padding - -def scan_files_in_dir(directory, postfix: Set[str] = None, progress_bar: tqdm = None) -> list: - file_list = [] - progress_bar = tqdm(total=0, desc=f"Scanning", ncols=100) if progress_bar is None else progress_bar - for entry in os.scandir(directory): - if entry.is_file(): - if postfix is None or os.path.splitext(entry.path)[1] in postfix: - file_list.append(entry) - progress_bar.total += 1 - progress_bar.update(1) - elif entry.is_dir(): - file_list += scan_files_in_dir(entry.path, postfix=postfix, progress_bar=progress_bar) - return file_list - -if __name__ == "__main__": - ... \ No newline at end of file +def save_debug_visualization( + person_images, cloth_images, masks, masked_image, + noisy_latents, predicted_noise, target_latents, + decoder, global_step, output_dir, device="cuda" +): + """ + Simple debug visualization function to save training progress images. 
+ + Args: + person_images: Original person images [B, 3, H, W] + cloth_images: Cloth/garment images [B, 3, H, W] + masks: Mask images [B, 1, H, W] + masked_image: Person image with mask applied [B, 3, H, W] + noisy_latents: Noisy latents fed to model [B, C, h, w] + predicted_noise: Model's predicted noise [B, C, h, w] + target_latents: Ground truth latents [B, C, h, w] + decoder: VAE decoder model + global_step: Current training step + output_dir: Directory to save images + device: Device to use + """ + + try: + with torch.no_grad(): + # Take first sample from batch + person_img = person_images[0:1] # [1, 3, H, W] + cloth_img = cloth_images[0:1] + mask_img = masks[0:1] + masked_img = masked_image[0:1] + + # Split concatenated latents if needed (assuming concat on height dim) + if target_latents.shape[-2] > noisy_latents.shape[-2] // 2: + # Latents are concatenated, split them + h = target_latents.shape[-2] // 2 + noisy_person_latent = noisy_latents[0:1, :, :h, :] + predicted_person_latent = (noisy_person_latent - predicted_noise[0:1, :, :h, :]) + target_person_latent = target_latents[0:1, :, :h, :] + else: + noisy_person_latent = noisy_latents[0:1] + predicted_person_latent = (noisy_person_latent - predicted_noise[0:1]) + target_person_latent = target_latents[0:1] + + + # Decode latents to images + with torch.cuda.amp.autocast(enabled=False): + noisy_decoded = decoder(noisy_person_latent.float()) + predicted_decoded = decoder(predicted_person_latent.float()) + target_decoded = decoder(target_person_latent.float()) + + # Convert to PIL images + def tensor_to_pil(tensor): + # tensor: [1, 3, H, W] in range [-1, 1] or [0, 1] + tensor = tensor.squeeze(0) # [3, H, W] + tensor = torch.clamp((tensor + 1.0) / 2.0, 0, 1) # Normalize to [0,1] + tensor = tensor.cpu() + transform = transforms.ToPILImage() + return transform(tensor) + + # Convert mask to PIL (single channel) + def mask_to_pil(tensor): + tensor = tensor.squeeze() # Remove batch and channel dims + tensor = torch.clamp(tensor, 0, 1) + tensor = tensor.cpu() + # Convert to 3-channel for visualization + tensor_3ch = tensor.unsqueeze(0).repeat(3, 1, 1) + transform = transforms.ToPILImage() + return transform(tensor_3ch) + + # Convert all tensors to PIL images + person_pil = tensor_to_pil(person_img) + cloth_pil = tensor_to_pil(cloth_img) + mask_pil = mask_to_pil(mask_img) + masked_pil = tensor_to_pil(masked_img) + noisy_pil = tensor_to_pil(noisy_decoded) + predicted_pil = tensor_to_pil(predicted_decoded) + target_pil = tensor_to_pil(target_decoded) + + # Create labels + labels = ['Person', 'Cloth', 'Mask', 'Masked', 'Noisy', 'Predicted', 'Target'] + images = [person_pil, cloth_pil, mask_pil, masked_pil, noisy_pil, predicted_pil, target_pil] + + # Get dimensions + width, height = person_pil.size + + # Create combined image (horizontal layout) + combined_width = width * len(images) + combined_height = height + 30 # Extra space for labels + + combined_img = Image.new('RGB', (combined_width, combined_height), 'white') + + # Paste images side by side with labels + from PIL import ImageDraw, ImageFont + draw = ImageDraw.Draw(combined_img) + + try: + # Try to use a default font + font = ImageFont.load_default() + except: + font = None + + for i, (img, label) in enumerate(zip(images, labels)): + x_offset = i * width + combined_img.paste(img, (x_offset, 30)) + + # Add label + if font: + draw.text((x_offset + 5, 5), label, fill='black', font=font) + else: + draw.text((x_offset + 5, 5), label, fill='black') + + # Save the combined image + debug_dir = 
os.path.join(output_dir, 'debug_viz') + os.makedirs(debug_dir, exist_ok=True) + + save_path = os.path.join(debug_dir, f'debug_step_{global_step:06d}.jpg') + combined_img.save(save_path, 'JPEG', quality=95) + + print(f"Debug visualization saved: {save_path}") + + except Exception as e: + print(f"Error in debug visualization: {e}")
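
Note on the DREAM branch of `compute_loss` above: it is easiest to follow in isolation. The sketch below reproduces just that objective — the sampled noise Ξ΅ is combined with a no-grad prediction Ρθ as Ξ΅ + λΡθ, the clean latents are re-noised with that combined noise, and the loss is the MSE between a second prediction and the combined noise. The `denoiser` stand-in, the linear beta schedule, and the tensor sizes are illustrative assumptions only, not the repo's `Diffusion` UNet or `DDPMSampler`.

```python
import torch
import torch.nn.functional as F

def dream_loss(denoiser, z0, timesteps, alphas_cumprod, dream_lambda=10.0):
    """|(eps + lambda*eps_theta) - eps_theta(z_hat_t, t)|^2 for clean latents z0."""
    noise = torch.randn_like(z0)                              # eps
    a = alphas_cumprod[timesteps].view(-1, 1, 1, 1)           # alpha_bar_t per sample
    z_t = a.sqrt() * z0 + (1 - a).sqrt() * noise              # standard forward diffusion
    with torch.no_grad():                                     # extra pass, no gradients
        eps_theta = denoiser(z_t, timesteps)                  # eps_theta(z_t, t)
    dream_noise = noise + dream_lambda * eps_theta            # eps + lambda * eps_theta
    z_hat_t = a.sqrt() * z0 + (1 - a).sqrt() * dream_noise    # re-noised z_hat_t
    return F.mse_loss(denoiser(z_hat_t, timesteps), dream_noise)

# Illustrative usage with a toy conv as the denoiser (hypothetical stand-in).
if __name__ == "__main__":
    conv = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1)
    denoiser = lambda z, t: conv(z)                           # ignores t, for brevity
    betas = torch.linspace(1e-4, 0.02, 1000)                  # assumed DDPM-style schedule
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    z0 = torch.randn(2, 4, 96, 48)                            # e.g. y-concatenated latents
    t = torch.randint(1, 1000, (2,))
    print(dream_loss(denoiser, z0, t, alphas_cumprod).item())
```

With `dream_lambda=0` the combined noise collapses back to Ξ΅ and this reduces to the plain Ξ΅-prediction MSE, mirroring the `else` branch of `compute_loss` in the trainer.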