{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "81e4a1db", "metadata": {}, "outputs": [], "source": [ "# !git clone -b CatVTON https://github.com/Harsh-Kesharwani/stable-diffusion.git" ] }, { "cell_type": "code", "execution_count": 2, "id": "9c89e320", "metadata": {}, "outputs": [], "source": [ "# cd stable-diffusion/" ] }, { "cell_type": "code", "execution_count": 3, "id": "ff8b706c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model already downloaded.\n" ] } ], "source": [ "# check if the model is downloaded, if not download it\n", "import os\n", "if not os.path.exists(\"sd-v1-5-inpainting.ckpt\"):\n", " !wget https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n", "else:\n", " print(\"Model already downloaded.\")\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "53095103", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Checkpoints directory already exists.\n" ] } ], "source": [ "# make output and checkpoints directories if they don't exist\n", "import os\n", "if not os.path.exists(\"checkpoints\"):\n", " os.makedirs(\"checkpoints\")\n", "else:\n", " print(\"Checkpoints directory already exists.\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "d8978b25", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "VITON-HD dataset already exists.\n", "Zip file does not exist, nothing to remove.\n" ] } ], "source": [ "import os\n", "if not os.path.exists(\"viton-hd-dataset\"):\n", " !curl -L -u harshkesherwani:7695128b407febc869a6f5b2cb0cbf26\\\n", " -o /home/mahesh/harsh/stable-diffusion/viton-hd-dataset.zip\\\n", " https://www.kaggle.com/api/v1/datasets/download/harshkesherwani/viton-hd-dataset\n", " \n", " import zipfile\n", " with zipfile.ZipFile('viton-hd-dataset.zip', 'r') as zip_ref:\n", " zip_ref.extractall('viton-hd-dataset')\n", " \n", " print(\"VITON-HD dataset downloaded and extracted.\")\n", "else:\n", " print(\"VITON-HD dataset already exists.\")\n", " \n", "import os\n", "if os.path.exists(\"viton-hd-dataset.zip\"):\n", " os.remove(\"viton-hd-dataset.zip\")\n", " print(\"Removed the zip file after extraction.\")\n", "else:\n", " print(\"Zip file does not exist, nothing to remove.\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "3aea80d9", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/mahesh/miniconda3/envs/harsh/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------------------------------------------------------------------------\n", "Loading pretrained models...\n", "Models loaded successfully.\n", "----------------------------------------------------------------------------------------------------\n", "Creating dataloader...\n", "Dataset vitonhd loaded, total 11647 pairs.\n", "Training for 50 epochs\n", "Batches per epoch: 5824\n", "----------------------------------------------------------------------------------------------------\n", "Initializing trainer...\n", "Enabling PEFT training (self-attention layers only)\n", "Total parameters: 899,226,667\n", "Trainable parameters: 49,574,080 (5.51%)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_1669505/646906096.py:71: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n", " self.scaler = torch.cuda.amp.GradScaler()\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Checkpoint loaded: ./checkpoints/checkpoint_step_40000.pth\n", "Resuming from epoch 12, step 40000\n", "Starting training...\n", "Starting training for 50 epochs\n", "Total training batches per epoch: 5824\n", "Using DREAM with lambda = 0\n", "Mixed precision: True\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 13: 0%| | 0/5824 [00:00 520\u001b[0m \u001b[43mmain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", "Cell \u001b[0;32mIn[8], line 517\u001b[0m, in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 515\u001b[0m \u001b[38;5;66;03m# Start training\u001b[39;00m\n\u001b[1;32m 516\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mStarting training...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 517\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", "Cell \u001b[0;32mIn[8], line 353\u001b[0m, in \u001b[0;36mCatVTONTrainer.train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 350\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcurrent_epoch \u001b[38;5;241m=\u001b[39m epoch\n\u001b[1;32m 352\u001b[0m \u001b[38;5;66;03m# Train one epoch\u001b[39;00m\n\u001b[0;32m--> 353\u001b[0m train_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_epoch\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 355\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_epochs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m - Train Loss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrain_loss\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.6f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 357\u001b[0m \u001b[38;5;66;03m# Save epoch checkpoint\u001b[39;00m\n", "Cell \u001b[0;32mIn[8], line 292\u001b[0m, in \u001b[0;36mCatVTONTrainer.train_epoch\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 290\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_mixed_precision:\n\u001b[1;32m 291\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mamp\u001b[38;5;241m.\u001b[39mautocast():\n\u001b[0;32m--> 292\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 294\u001b[0m \u001b[38;5;66;03m# Backward pass with scaling\u001b[39;00m\n\u001b[1;32m 295\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscaler\u001b[38;5;241m.\u001b[39mscale(loss)\u001b[38;5;241m.\u001b[39mbackward()\n", "Cell \u001b[0;32mIn[8], line 211\u001b[0m, in \u001b[0;36mCatVTONTrainer.compute_loss\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 207\u001b[0m \u001b[38;5;66;03m# timesteps = torch.randint(1, 1000, size=(1,), device=self.device)[0].long().item()\u001b[39;00m\n\u001b[1;32m 208\u001b[0m \u001b[38;5;66;03m# timesteps = torch.tensor(timesteps, device=self.device)\u001b[39;00m\n\u001b[1;32m 209\u001b[0m \u001b[38;5;66;03m# timesteps_embedding = get_time_embedding(timesteps).to(self.device, dtype=self.weight_dtype)\u001b[39;00m\n\u001b[1;32m 210\u001b[0m timesteps \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandint(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m1000\u001b[39m, size\u001b[38;5;241m=\u001b[39m(batch_size,))\n\u001b[0;32m--> 211\u001b[0m timesteps_embedding \u001b[38;5;241m=\u001b[39m \u001b[43mget_time_embedding\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimesteps\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight_dtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 213\u001b[0m \u001b[38;5;66;03m# Add noise to latents\u001b[39;00m\n\u001b[1;32m 214\u001b[0m noisy_latents \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscheduler\u001b[38;5;241m.\u001b[39madd_noise(target_latents, timesteps, noise)\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "import random\n", "import argparse\n", "from pathlib import Path\n", "from typing import Dict, Optional\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torch.optim import AdamW\n", "\n", "import numpy as np\n", "from PIL import Image\n", "from tqdm import tqdm\n", "from VITON_Dataset import VITONHDTestDataset\n", "\n", "# Import your custom modules\n", "from load_model import preload_models_from_standard_weights\n", "from ddpm import DDPMSampler\n", "from utils import check_inputs, get_time_embedding, prepare_image, prepare_mask_image, compute_vae_encodings, save_debug_visualization\n", "from diffusers.utils.torch_utils import randn_tensor\n", "\n", "class CatVTONTrainer:\n", " \"\"\"Simplified CatVTON Training Class with PEFT, CFG and DREAM support\"\"\"\n", " \n", " def __init__(\n", " self,\n", " models: Dict[str, nn.Module],\n", " train_dataloader: DataLoader,\n", " val_dataloader: Optional[DataLoader] = None,\n", " device: str = \"cuda\",\n", " learning_rate: float = 1e-5,\n", " num_epochs: int = 50,\n", " save_steps: int = 1000,\n", " output_dir: str = \"./checkpoints\",\n", " cfg_dropout_prob: float = 0.1,\n", " max_grad_norm: float = 1.0,\n", " use_peft: bool = True,\n", " dream_lambda: float = 10.0,\n", " resume_from_checkpoint: Optional[str] = None,\n", " use_mixed_precision: bool = True,\n", " height=512,\n", " width=384,\n", " ):\n", " self.training = True\n", " self.models = models\n", " self.train_dataloader = train_dataloader\n", " self.val_dataloader = val_dataloader\n", " self.device = device\n", " self.learning_rate = learning_rate\n", " self.num_epochs = num_epochs\n", " self.save_steps = save_steps\n", " self.output_dir = Path(output_dir)\n", " self.cfg_dropout_prob = cfg_dropout_prob\n", " self.max_grad_norm = max_grad_norm\n", " self.use_peft = use_peft\n", " self.dream_lambda = dream_lambda\n", " self.use_mixed_precision = use_mixed_precision\n", " self.height = height\n", " self.width = width\n", " self.generator = torch.Generator(device=device)\n", " \n", " # Create output directory\n", " self.output_dir.mkdir(parents=True, exist_ok=True)\n", "\n", " # Setup mixed precision training\n", " if self.use_mixed_precision:\n", " self.scaler = torch.cuda.amp.GradScaler()\n", "\n", " self.weight_dtype = torch.float16 if use_mixed_precision else torch.float32\n", " \n", " # Initialize scheduler and sampler\n", " self.scheduler = DDPMSampler(self.generator, num_training_steps=1000)\n", "\n", " # Resume from checkpoint if provided\n", " self.global_step = 0\n", " self.current_epoch = 0\n", " \n", " # Setup models and optimizers\n", " self._setup_training()\n", " \n", " if resume_from_checkpoint:\n", " self._load_checkpoint(resume_from_checkpoint)\n", " \n", " \n", " \n", " self.encoder = self.models.get('encoder', None)\n", " self.decoder = self.models.get('decoder', None)\n", " self.diffusion = self.models.get('diffusion', None)\n", " \n", " def _setup_training(self):\n", " \"\"\"Setup models for training with PEFT\"\"\"\n", " # Move models to device\n", " for name, model in self.models.items():\n", " model.to(self.device)\n", " \n", " # Freeze all parameters first\n", " for model in self.models.values():\n", " for param in model.parameters():\n", " param.requires_grad = False\n", " \n", " # Enable training for specific layers based on PEFT strategy\n", " if self.use_peft:\n", " self._enable_peft_training()\n", " else:\n", " # Enable full training for diffusion model\n", " for param in self.diffusion.parameters():\n", " param.requires_grad = True\n", " \n", " # Collect trainable parameters\n", " trainable_params = []\n", " total_params = 0\n", " trainable_count = 0\n", " \n", " for name, model in self.models.items():\n", " for param_name, param in model.named_parameters():\n", " total_params += param.numel()\n", " if param.requires_grad:\n", " trainable_params.append(param)\n", " trainable_count += param.numel()\n", "\n", " print(f\"Total parameters: {total_params:,}\")\n", " print(f\"Trainable parameters: {trainable_count:,} ({trainable_count/total_params*100:.2f}%)\")\n", " \n", " # Setup optimizer - AdamW as per paper\n", " self.optimizer = AdamW(\n", " trainable_params,\n", " lr=self.learning_rate,\n", " betas=(0.9, 0.999),\n", " weight_decay=1e-2,\n", " eps=1e-8\n", " )\n", " \n", " # Setup learning rate scheduler (constant)\n", " self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(\n", " self.optimizer, lr_lambda=lambda epoch: 1.0\n", " )\n", " \n", " def _enable_peft_training(self):\n", " \"\"\"Enable PEFT training - only self-attention layers\"\"\"\n", " print(\"Enabling PEFT training (self-attention layers only)\")\n", " \n", " unet = self.models['diffusion'].unet\n", " \n", " # Enable attention layers in encoders and decoders\n", " for layers in [unet.encoders, unet.decoders]:\n", " for layer in layers:\n", " for module_idx, module in enumerate(layer):\n", " for name, param in module.named_parameters():\n", " if 'attention_1' in name:\n", " param.requires_grad = True\n", " \n", " # Enable attention layers in bottleneck\n", " for layer in unet.bottleneck:\n", " for name, param in layer.named_parameters():\n", " if 'attention_1' in name:\n", " param.requires_grad = True\n", " \n", " def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:\n", " \"\"\"Compute MSE loss for denoising with DREAM strategy\"\"\"\n", " person_images = batch['person'].to(self.device)\n", " cloth_images = batch['cloth'].to(self.device)\n", " masks = batch['mask'].to(self.device)\n", " \n", " batch_size = person_images.shape[0]\n", "\n", " concat_dim = -2 # y axis concat\n", " \n", " # Prepare inputs\n", " image, condition_image, mask = check_inputs(person_images, cloth_images, masks, self.width, self.height)\n", " image = prepare_image(person_images).to(self.device, dtype=self.weight_dtype)\n", " condition_image = prepare_image(cloth_images).to(self.device, dtype=self.weight_dtype)\n", " mask = prepare_mask_image(masks).to(self.device, dtype=self.weight_dtype)\n", " \n", " # Mask image\n", " masked_image = image * (mask < 0.5)\n", "\n", " with torch.cuda.amp.autocast(enabled=self.use_mixed_precision):\n", " # VAE encoding\n", " masked_latent = compute_vae_encodings(masked_image, self.encoder)\n", " person_latent = compute_vae_encodings(person_images, self.encoder)\n", " condition_latent = compute_vae_encodings(condition_image, self.encoder)\n", " mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode=\"nearest\")\n", " \n", " \n", " del image, mask, condition_image\n", " \n", " # Apply CFG dropout to garment latent (10% chance)\n", " if self.training and random.random() < self.cfg_dropout_prob:\n", " condition_latent = torch.zeros_like(condition_latent)\n", " \n", " # Concatenate latents\n", " input_latents = torch.cat([masked_latent, condition_latent], dim=concat_dim)\n", " mask_input = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)\n", " target_latents = torch.cat([person_latent, condition_latent], dim=concat_dim)\n", "\n", " noise = randn_tensor(\n", " target_latents.shape,\n", " generator=self.generator,\n", " device=target_latents.device,\n", " dtype=self.weight_dtype,\n", " )\n", "\n", " # timesteps = torch.randint(1, 1000, size=(1,), device=self.device)[0].long().item()\n", " # timesteps = torch.tensor(timesteps, device=self.device)\n", " # timesteps_embedding = get_time_embedding(timesteps).to(self.device, dtype=self.weight_dtype)\n", " timesteps = torch.randint(1, 1000, size=(batch_size,))\n", " timesteps_embedding = get_time_embedding(timesteps).to(self.device, dtype=self.weight_dtype)\n", "\n", " # Add noise to latents\n", " noisy_latents = self.scheduler.add_noise(target_latents, timesteps, noise)\n", "\n", " # UNet(zt ⊙ Mi ⊙ Xi) where ⊙ is channel concatenation\n", " unet_input = torch.cat([\n", " input_latents, # Xi\n", " mask_input, # Mi\n", " noisy_latents, # zt\n", " ], dim=1).to(self.device, dtype=self.weight_dtype) # Channel dimension\n", " \n", "\n", " # DREAM strategy implementation\n", " if self.dream_lambda > 0:\n", " # Get initial noise prediction\n", " with torch.no_grad():\n", " epsilon_theta = self.diffusion(\n", " unet_input,\n", " timesteps_embedding\n", " )\n", " \n", " # DREAM noise combination: ε + λ*εθ\n", " dream_noise = noise + self.dream_lambda * epsilon_theta\n", " \n", " # Create new noisy latents with DREAM noise\n", " dream_noisy_latents = self.scheduler.add_noise(target_latents, timesteps, dream_noise)\n", "\n", " dream_unet_input = torch.cat([\n", " input_latents, \n", " mask_input,\n", " dream_noisy_latents\n", " ], dim=1).to(self.device, dtype=self.weight_dtype)\n", "\n", " predicted_noise = self.diffusion(\n", " dream_unet_input,\n", " timesteps_embedding\n", " )\n", " # DREAM loss: |(ε + λεθ) - εθ(ẑt, t)|²\n", " loss = F.mse_loss(predicted_noise, dream_noise)\n", " else:\n", " # Standard training without DREAM\n", " predicted_noise = self.diffusion(\n", " unet_input,\n", " timesteps_embedding,\n", " )\n", " \n", " # Standard MSE loss\n", " loss = F.mse_loss(predicted_noise, noise)\n", " \n", " if self.global_step % 1000 == 0:\n", " save_debug_visualization(\n", " person_images=person_images,\n", " cloth_images=cloth_images, \n", " masks=masks,\n", " masked_image=masked_image,\n", " noisy_latents=noisy_latents,\n", " predicted_noise=predicted_noise,\n", " target_latents=target_latents,\n", " decoder=self.decoder,\n", " global_step=self.global_step,\n", " output_dir=self.output_dir,\n", " device=self.device\n", " )\n", " return loss\n", " \n", " def train_epoch(self) -> float:\n", " \"\"\"Train for one epoch - simplified version\"\"\"\n", " self.diffusion.train()\n", " total_loss = 0.0\n", " num_batches = len(self.train_dataloader)\n", " \n", " progress_bar = tqdm(self.train_dataloader, desc=f\"Epoch {self.current_epoch+1}\")\n", " \n", " for step, batch in enumerate(progress_bar):\n", " # Zero gradients\n", " self.optimizer.zero_grad()\n", " \n", " # Forward pass with mixed precision\n", " if self.use_mixed_precision:\n", " with torch.cuda.amp.autocast():\n", " loss = self.compute_loss(batch)\n", " \n", " # Backward pass with scaling\n", " self.scaler.scale(loss).backward()\n", " \n", " # Gradient clipping and optimizer step\n", " self.scaler.unscale_(self.optimizer)\n", " torch.nn.utils.clip_grad_norm_(\n", " [p for p in self.diffusion.parameters() if p.requires_grad],\n", " self.max_grad_norm\n", " )\n", " \n", " self.scaler.step(self.optimizer)\n", " self.scaler.update()\n", " else:\n", " loss = self.compute_loss(batch)\n", " loss.backward()\n", " \n", " # Gradient clipping\n", " torch.nn.utils.clip_grad_norm_(\n", " [p for p in self.diffusion.parameters() if p.requires_grad],\n", " self.max_grad_norm\n", " )\n", " \n", " # Optimizer step\n", " self.optimizer.step()\n", " \n", " # Update learning rate\n", " self.lr_scheduler.step()\n", " self.global_step += 1\n", " \n", " total_loss += loss.item()\n", " \n", " # Update progress bar\n", " progress_bar.set_postfix({\n", " 'loss': loss.item(),\n", " 'lr': self.optimizer.param_groups[0]['lr'],\n", " 'step': self.global_step\n", " })\n", " \n", " # Save checkpoint based on steps\n", " if self.global_step % self.save_steps == 0:\n", " self._save_checkpoint()\n", " \n", " # Clear cache periodically to prevent OOM\n", " if step % 50 == 0:\n", " torch.cuda.empty_cache()\n", " \n", " return total_loss / num_batches\n", " \n", " def train(self):\n", " \"\"\"Main training loop - simplified version\"\"\"\n", " print(f\"Starting training for {self.num_epochs} epochs\")\n", " print(f\"Total training batches per epoch: {len(self.train_dataloader)}\")\n", " print(f\"Using DREAM with lambda = {self.dream_lambda}\")\n", " print(f\"Mixed precision: {self.use_mixed_precision}\")\n", " \n", " for epoch in range(self.current_epoch, self.num_epochs):\n", " self.current_epoch = epoch\n", " \n", " # Train one epoch\n", " train_loss = self.train_epoch()\n", " \n", " print(f\"Epoch {epoch+1}/{self.num_epochs} - Train Loss: {train_loss:.6f}\")\n", " \n", " # Save epoch checkpoint\n", " if (epoch + 1) % 5 == 0: # Save every 5 epochs\n", " self._save_checkpoint(epoch_checkpoint=True)\n", " \n", " # Clear cache at end of epoch\n", " torch.cuda.empty_cache()\n", " \n", " # Save final checkpoint\n", " self._save_checkpoint(is_final=True)\n", " print(\"Training completed!\")\n", " \n", " def _save_checkpoint(self, is_best: bool = False, epoch_checkpoint: bool = False, is_final: bool = False):\n", " \"\"\"Save model checkpoint\"\"\"\n", " checkpoint = {\n", " 'global_step': self.global_step,\n", " 'current_epoch': self.current_epoch,\n", " 'diffusion_state_dict': self.diffusion.state_dict(),\n", " 'optimizer_state_dict': self.optimizer.state_dict(),\n", " 'lr_scheduler_state_dict': self.lr_scheduler.state_dict(),\n", " 'dream_lambda': self.dream_lambda,\n", " 'use_peft': self.use_peft,\n", " }\n", " \n", " if self.use_mixed_precision:\n", " checkpoint['scaler_state_dict'] = self.scaler.state_dict()\n", " \n", " if is_final:\n", " checkpoint_path = self.output_dir / \"final_model.pth\"\n", " elif is_best:\n", " checkpoint_path = self.output_dir / \"best_model.pth\"\n", " elif epoch_checkpoint:\n", " checkpoint_path = self.output_dir / f\"checkpoint_epoch_{self.current_epoch+1}.pth\"\n", " else:\n", " checkpoint_path = self.output_dir / f\"checkpoint_step_{self.global_step}.pth\"\n", " \n", " torch.save(checkpoint, checkpoint_path)\n", " print(f\"Checkpoint saved: {checkpoint_path}\")\n", " \n", " def _load_checkpoint(self, checkpoint_path: str):\n", " \"\"\"Load model checkpoint\"\"\"\n", " checkpoint = torch.load(checkpoint_path, map_location=self.device)\n", " \n", " self.global_step = checkpoint['global_step']\n", " self.current_epoch = checkpoint['current_epoch']\n", " self.dream_lambda = checkpoint.get('dream_lambda', 10.0)\n", " \n", " self.models['diffusion'].load_state_dict(checkpoint['diffusion_state_dict'])\n", " self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", " self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])\n", " \n", " if self.use_mixed_precision and 'scaler_state_dict' in checkpoint:\n", " self.scaler.load_state_dict(checkpoint['scaler_state_dict'])\n", " \n", " print(f\"Checkpoint loaded: {checkpoint_path}\")\n", " print(f\"Resuming from epoch {self.current_epoch}, step {self.global_step}\")\n", "\n", "\n", "def create_dataloaders(args) -> DataLoader:\n", " \"\"\"Create training dataloader\"\"\"\n", " if args.dataset_name == \"vitonhd\":\n", " dataset = VITONHDTestDataset(args)\n", " else:\n", " raise ValueError(f\"Invalid dataset name {args.dataset_name}.\")\n", " \n", " print(f\"Dataset {args.dataset_name} loaded, total {len(dataset)} pairs.\")\n", " \n", " dataloader = DataLoader(\n", " dataset,\n", " batch_size=args.batch_size,\n", " shuffle=True,\n", " num_workers=8,\n", " pin_memory=True,\n", " persistent_workers=True,\n", " prefetch_factor=2\n", " )\n", " \n", " return dataloader\n", "\n", "\n", "def main():\n", " args = argparse.Namespace()\n", " args.__dict__ = {\n", " \"base_model_path\": \"sd-v1-5-inpainting.ckpt\",\n", " \"dataset_name\": \"vitonhd\",\n", " \"data_root_path\": \"./viton-hd-dataset\",\n", " \"output_dir\": \"./checkpoints\",\n", " \"resume_from_checkpoint\": \"./checkpoints/checkpoint_step_40000.pth\",\n", " \"seed\": 42,\n", " \"batch_size\": 2,\n", " \"width\": 384,\n", " \"height\": 384,\n", " \"repaint\": True,\n", " \"eval_pair\": True,\n", " \"concat_eval_results\": True,\n", " \"concat_axis\": 'y',\n", " \"device\": \"cuda\",\n", " \"num_epochs\": 50, \n", " \"learning_rate\": 1e-5,\n", " \"max_grad_norm\": 1.0,\n", " \"cfg_dropout_prob\": 0.1,\n", " \"dream_lambda\": 10.0,\n", " \"use_peft\": True,\n", " \"use_mixed_precision\": True,\n", " \"save_steps\": 10000,\n", " \"is_train\": True\n", " }\n", " \n", " # Set random seeds\n", " torch.manual_seed(args.seed)\n", " np.random.seed(args.seed)\n", " random.seed(args.seed)\n", " if torch.cuda.is_available():\n", " torch.cuda.manual_seed_all(args.seed)\n", " \n", " # Optimize CUDA settings\n", " torch.backends.cudnn.benchmark = True\n", " torch.backends.cuda.matmul.allow_tf32 = True \n", " torch.backends.cudnn.allow_tf32 = True \n", " torch.set_float32_matmul_precision(\"high\")\n", " \n", " print(\"-\"*100)\n", "\n", " # Load pretrained models\n", " print(\"Loading pretrained models...\")\n", " models = preload_models_from_standard_weights(args.base_model_path, args.device)\n", " print(\"Models loaded successfully.\")\n", " \n", " print(\"-\"*100)\n", " \n", " # Create dataloader\n", " print(\"Creating dataloader...\")\n", " train_dataloader = create_dataloaders(args)\n", " \n", " print(f\"Training for {args.num_epochs} epochs\")\n", " print(f\"Batches per epoch: {len(train_dataloader)}\")\n", " \n", " print(\"-\"*100)\n", " \n", " # Initialize trainer\n", " print(\"Initializing trainer...\") \n", " trainer = CatVTONTrainer(\n", " models=models,\n", " train_dataloader=train_dataloader,\n", " device=args.device,\n", " learning_rate=args.learning_rate,\n", " num_epochs=args.num_epochs,\n", " save_steps=args.save_steps,\n", " output_dir=args.output_dir,\n", " cfg_dropout_prob=args.cfg_dropout_prob,\n", " max_grad_norm=args.max_grad_norm,\n", " use_peft=args.use_peft,\n", " dream_lambda=args.dream_lambda,\n", " resume_from_checkpoint=args.resume_from_checkpoint,\n", " use_mixed_precision=args.use_mixed_precision,\n", " height=args.height,\n", " width=args.width\n", " )\n", " \n", " # Start training\n", " print(\"Starting training...\")\n", " trainer.train() \n", "\n", "if __name__ == \"__main__\":\n", " main()" ] } ], "metadata": { "kernelspec": { "display_name": "harsh", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 5 }