Spaces:
Running
Running
| import torch | |
| import os | |
| import random | |
| import argparse | |
| from pathlib import Path | |
| from typing import Dict, Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
| from torch.optim import AdamW | |
| import numpy as np | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from VITON_Dataset import VITONHDTestDataset | |
| # Import your custom modules | |
| from load_model import preload_models_from_standard_weights | |
| from ddpm import DDPMSampler | |
| from utils import check_inputs, get_time_embedding, prepare_image, prepare_mask_image, save_debug_visualization, compute_vae_encodings | |
| from diffusers.utils.torch_utils import randn_tensor | |
| class CatVTONTrainer: | |
| """Simplified CatVTON Training Class with PEFT, CFG and DREAM support""" | |
| def __init__( | |
| self, | |
| models: Dict[str, nn.Module], | |
| train_dataloader: DataLoader, | |
| val_dataloader: Optional[DataLoader] = None, | |
| device: str = "cuda", | |
| learning_rate: float = 1e-5, | |
| num_epochs: int = 50, | |
| save_steps: int = 1000, | |
| output_dir: str = "./checkpoints", | |
| cfg_dropout_prob: float = 0.1, | |
| max_grad_norm: float = 1.0, | |
| use_peft: bool = True, | |
| dream_lambda: float = 10.0, | |
| resume_from_checkpoint: Optional[str] = None, | |
| use_mixed_precision: bool = True, | |
| height=512, | |
| width=384, | |
| ): | |
| self.training = True | |
| self.models = models | |
| self.train_dataloader = train_dataloader | |
| self.val_dataloader = val_dataloader | |
| self.device = device | |
| self.learning_rate = learning_rate | |
| self.num_epochs = num_epochs | |
| self.save_steps = save_steps | |
| self.output_dir = Path(output_dir) | |
| self.cfg_dropout_prob = cfg_dropout_prob | |
| self.max_grad_norm = max_grad_norm | |
| self.use_peft = use_peft | |
| self.dream_lambda = dream_lambda | |
| self.use_mixed_precision = use_mixed_precision | |
| self.height = height | |
| self.width = width | |
| self.generator = torch.Generator(device=device) | |
| # Create output directory | |
| self.output_dir.mkdir(parents=True, exist_ok=True) | |
| # Setup mixed precision training | |
| if self.use_mixed_precision: | |
| self.scaler = torch.cuda.amp.GradScaler() | |
| self.weight_dtype = torch.float16 if use_mixed_precision else torch.float32 | |
| # Initialize scheduler and sampler | |
| self.scheduler = DDPMSampler(self.generator, num_training_steps=1000) | |
| # Resume from checkpoint if provided | |
| self.global_step = 0 | |
| self.current_epoch = 0 | |
| # Setup models and optimizers | |
| self._setup_training() | |
| if resume_from_checkpoint: | |
| self._load_checkpoint(resume_from_checkpoint) | |
| self.encoder = self.models.get('encoder', None) | |
| self.decoder = self.models.get('decoder', None) | |
| self.diffusion = self.models.get('diffusion', None) | |
| def _setup_training(self): | |
| """Setup models for training with PEFT""" | |
| # Move models to device | |
| for name, model in self.models.items(): | |
| model.to(self.device) | |
| # Freeze all parameters first | |
| for model in self.models.values(): | |
| for param in model.parameters(): | |
| param.requires_grad = False | |
| # Enable training for specific layers based on PEFT strategy | |
| if self.use_peft: | |
| self._enable_peft_training() | |
| else: | |
| # Enable full training for diffusion model | |
| for param in self.diffusion.parameters(): | |
| param.requires_grad = True | |
| # Collect trainable parameters | |
| trainable_params = [] | |
| total_params = 0 | |
| trainable_count = 0 | |
| for name, model in self.models.items(): | |
| for param_name, param in model.named_parameters(): | |
| total_params += param.numel() | |
| if param.requires_grad: | |
| trainable_params.append(param) | |
| trainable_count += param.numel() | |
| print(f"Total parameters: {total_params:,}") | |
| print(f"Trainable parameters: {trainable_count:,} ({trainable_count/total_params*100:.2f}%)") | |
| # Setup optimizer - AdamW as per paper | |
| self.optimizer = AdamW( | |
| trainable_params, | |
| lr=self.learning_rate, | |
| betas=(0.9, 0.999), | |
| weight_decay=1e-2, | |
| eps=1e-8 | |
| ) | |
| # Setup learning rate scheduler (constant) | |
| self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR( | |
| self.optimizer, lr_lambda=lambda epoch: 1.0 | |
| ) | |
| def _enable_peft_training(self): | |
| """Enable PEFT training - only self-attention layers""" | |
| print("Enabling PEFT training (self-attention layers only)") | |
| unet = self.models['diffusion'].unet | |
| # Enable attention layers in encoders and decoders | |
| for layers in [unet.encoders, unet.decoders]: | |
| for layer in layers: | |
| for module_idx, module in enumerate(layer): | |
| for name, param in module.named_parameters(): | |
| if 'attention_1' in name: | |
| param.requires_grad = True | |
| # Enable attention layers in bottleneck | |
| for layer in unet.bottleneck: | |
| for name, param in layer.named_parameters(): | |
| if 'attention_1' in name: | |
| param.requires_grad = True | |
| def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: | |
| """Compute MSE loss for denoising with DREAM strategy""" | |
| person_images = batch['person'].to(self.device) | |
| cloth_images = batch['cloth'].to(self.device) | |
| masks = batch['mask'].to(self.device) | |
| batch_size = person_images.shape[0] | |
| concat_dim = -2 # y axis concat | |
| # Prepare inputs | |
| image, condition_image, mask = check_inputs(person_images, cloth_images, masks, self.width, self.height) | |
| image = prepare_image(person_images).to(self.device, dtype=self.weight_dtype) | |
| condition_image = prepare_image(cloth_images).to(self.device, dtype=self.weight_dtype) | |
| mask = prepare_mask_image(masks).to(self.device, dtype=self.weight_dtype) | |
| # Mask image | |
| masked_image = image * (mask < 0.5) | |
| with torch.cuda.amp.autocast(enabled=self.use_mixed_precision): | |
| # VAE encoding | |
| masked_latent = compute_vae_encodings(masked_image, self.encoder) | |
| person_latent = compute_vae_encodings(person_images, self.encoder) | |
| condition_latent = compute_vae_encodings(condition_image, self.encoder) | |
| mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode="nearest") | |
| del image, mask, condition_image | |
| # Apply CFG dropout to garment latent (10% chance) | |
| if self.training and random.random() < self.cfg_dropout_prob: | |
| condition_latent = torch.zeros_like(condition_latent) | |
| # Concatenate latents | |
| input_latents = torch.cat([masked_latent, condition_latent], dim=concat_dim) | |
| mask_input = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim) | |
| target_latents = torch.cat([person_latent, condition_latent], dim=concat_dim) | |
| noise = randn_tensor( | |
| target_latents.shape, | |
| generator=self.generator, | |
| device=target_latents.device, | |
| dtype=self.weight_dtype, | |
| ) | |
| # timesteps = torch.randint(1, 1000, size=(1,), device=self.device)[0].long().item() | |
| # timesteps = torch.tensor(timesteps, device=self.device) | |
| # timesteps_embedding = get_time_embedding(timesteps).to(self.device, dtype=self.weight_dtype) | |
| timesteps = torch.randint(1, 1000, size=(batch_size,)) | |
| timesteps_embedding = get_time_embedding(timesteps).to(self.device, dtype=self.weight_dtype) | |
| # Add noise to latents | |
| noisy_latents = self.scheduler.add_noise(target_latents, timesteps, noise) | |
| # UNet(zt ⊙ Mi ⊙ Xi) where ⊙ is channel concatenation | |
| unet_input = torch.cat([ | |
| input_latents, # Xi | |
| mask_input, # Mi | |
| noisy_latents, # zt | |
| ], dim=1).to(self.device, dtype=self.weight_dtype) # Channel dimension | |
| # DREAM strategy implementation | |
| if self.dream_lambda > 0: | |
| # Get initial noise prediction | |
| with torch.no_grad(): | |
| epsilon_theta = self.diffusion( | |
| unet_input, | |
| timesteps_embedding | |
| ) | |
| # DREAM noise combination: ε + λ*εθ | |
| dream_noise = noise + self.dream_lambda * epsilon_theta | |
| # Create new noisy latents with DREAM noise | |
| dream_noisy_latents = self.scheduler.add_noise(target_latents, timesteps, dream_noise) | |
| dream_unet_input = torch.cat([ | |
| input_latents, | |
| mask_input, | |
| dream_noisy_latents | |
| ], dim=1).to(self.device, dtype=self.weight_dtype) | |
| predicted_noise = self.diffusion( | |
| dream_unet_input, | |
| timesteps_embedding | |
| ) | |
| # DREAM loss: |(ε + λεθ) - εθ(ẑt, t)|² | |
| loss = F.mse_loss(predicted_noise, dream_noise) | |
| else: | |
| # Standard training without DREAM | |
| predicted_noise = self.diffusion( | |
| unet_input, | |
| timesteps_embedding, | |
| ) | |
| # Standard MSE loss | |
| loss = F.mse_loss(predicted_noise, noise) | |
| if self.global_step % 1000 == 0: | |
| save_debug_visualization( | |
| person_images=person_images, | |
| cloth_images=cloth_images, | |
| masks=masks, | |
| masked_image=masked_image, | |
| noisy_latents=noisy_latents, | |
| predicted_noise=predicted_noise, | |
| target_latents=target_latents, | |
| decoder=self.decoder, | |
| global_step=self.global_step, | |
| output_dir=self.output_dir, | |
| device=self.device | |
| ) | |
| return loss | |
| def train_epoch(self) -> float: | |
| """Train for one epoch - simplified version""" | |
| self.diffusion.train() | |
| total_loss = 0.0 | |
| num_batches = len(self.train_dataloader) | |
| # progress_bar = tqdm(self.train_dataloader, desc=f"Epoch {self.current_epoch+1}") | |
| for step, batch in enumerate(self.train_dataloader): | |
| # Zero gradients | |
| self.optimizer.zero_grad() | |
| # Forward pass with mixed precision | |
| if self.use_mixed_precision: | |
| with torch.cuda.amp.autocast(): | |
| loss = self.compute_loss(batch) | |
| # Backward pass with scaling | |
| self.scaler.scale(loss).backward() | |
| # Gradient clipping and optimizer step | |
| self.scaler.unscale_(self.optimizer) | |
| torch.nn.utils.clip_grad_norm_( | |
| [p for p in self.diffusion.parameters() if p.requires_grad], | |
| self.max_grad_norm | |
| ) | |
| self.scaler.step(self.optimizer) | |
| self.scaler.update() | |
| else: | |
| loss = self.compute_loss(batch) | |
| loss.backward() | |
| # Gradient clipping | |
| torch.nn.utils.clip_grad_norm_( | |
| [p for p in self.diffusion.parameters() if p.requires_grad], | |
| self.max_grad_norm | |
| ) | |
| # Optimizer step | |
| self.optimizer.step() | |
| # Update learning rate | |
| self.lr_scheduler.step() | |
| self.global_step += 1 | |
| total_loss += loss.item() | |
| # Update progress bar | |
| # progress_bar.set_postfix({ | |
| # 'loss': loss.item(), | |
| # 'lr': self.optimizer.param_groups[0]['lr'], | |
| # 'step': self.global_step | |
| # }) | |
| # Save checkpoint based on steps | |
| if self.global_step % self.save_steps == 0: | |
| self._save_checkpoint() | |
| # Clear cache periodically to prevent OOM | |
| if step % 50 == 0: | |
| torch.cuda.empty_cache() | |
| return total_loss / num_batches | |
| def train(self): | |
| """Main training loop - simplified version""" | |
| print(f"Starting training for {self.num_epochs} epochs") | |
| print(f"Total training batches per epoch: {len(self.train_dataloader)}") | |
| print(f"Using DREAM with lambda = {self.dream_lambda}") | |
| print(f"Mixed precision: {self.use_mixed_precision}") | |
| for epoch in range(self.current_epoch, self.num_epochs): | |
| self.current_epoch = epoch | |
| # Train one epoch | |
| train_loss = self.train_epoch() | |
| print(f"Epoch {epoch+1}/{self.num_epochs} - Train Loss: {train_loss:.6f}") | |
| # Save epoch checkpoint | |
| if (epoch + 1) % 5 == 0: # Save every 5 epochs | |
| self._save_checkpoint(epoch_checkpoint=True) | |
| # Clear cache at end of epoch | |
| torch.cuda.empty_cache() | |
| # Save final checkpoint | |
| self._save_checkpoint(is_final=True) | |
| print("Training completed!") | |
| def _save_checkpoint(self, is_best: bool = False, epoch_checkpoint: bool = False, is_final: bool = False): | |
| """Save model checkpoint""" | |
| checkpoint = { | |
| 'global_step': self.global_step, | |
| 'current_epoch': self.current_epoch, | |
| 'diffusion_state_dict': self.diffusion.state_dict(), | |
| 'optimizer_state_dict': self.optimizer.state_dict(), | |
| 'lr_scheduler_state_dict': self.lr_scheduler.state_dict(), | |
| 'dream_lambda': self.dream_lambda, | |
| 'use_peft': self.use_peft, | |
| } | |
| if self.use_mixed_precision: | |
| checkpoint['scaler_state_dict'] = self.scaler.state_dict() | |
| if is_final: | |
| checkpoint_path = self.output_dir / "final_model.pth" | |
| elif is_best: | |
| checkpoint_path = self.output_dir / "best_model.pth" | |
| elif epoch_checkpoint: | |
| checkpoint_path = self.output_dir / f"checkpoint_epoch_{self.current_epoch+1}.pth" | |
| else: | |
| checkpoint_path = self.output_dir / f"checkpoint_step_{self.global_step}.pth" | |
| torch.save(checkpoint, checkpoint_path) | |
| print(f"Checkpoint saved: {checkpoint_path}") | |
| def _load_checkpoint(self, checkpoint_path: str): | |
| """Load model checkpoint""" | |
| checkpoint = torch.load(checkpoint_path, map_location=self.device) | |
| self.global_step = checkpoint['global_step'] | |
| self.current_epoch = checkpoint['current_epoch'] | |
| self.dream_lambda = checkpoint.get('dream_lambda', 10.0) | |
| self.models['diffusion'].load_state_dict(checkpoint['diffusion_state_dict']) | |
| self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) | |
| self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict']) | |
| if self.use_mixed_precision and 'scaler_state_dict' in checkpoint: | |
| self.scaler.load_state_dict(checkpoint['scaler_state_dict']) | |
| print(f"Checkpoint loaded: {checkpoint_path}") | |
| print(f"Resuming from epoch {self.current_epoch}, step {self.global_step}") | |
| def create_dataloaders(args) -> DataLoader: | |
| """Create training dataloader""" | |
| if args.dataset_name == "vitonhd": | |
| dataset = VITONHDTestDataset(args) | |
| else: | |
| raise ValueError(f"Invalid dataset name {args.dataset_name}.") | |
| print(f"Dataset {args.dataset_name} loaded, total {len(dataset)} pairs.") | |
| dataloader = DataLoader( | |
| dataset, | |
| batch_size=args.batch_size, | |
| shuffle=True, | |
| num_workers=8, | |
| pin_memory=True, | |
| persistent_workers=True, | |
| prefetch_factor=2 | |
| ) | |
| return dataloader | |
| def main(): | |
| args = argparse.Namespace() | |
| args.__dict__ = { | |
| "base_model_path": "sd-v1-5-inpainting.ckpt", | |
| "dataset_name": "vitonhd", | |
| "data_root_path": "./viton-hd-dataset", | |
| "output_dir": "./checkpoints", | |
| "resume_from_checkpoint": "./checkpoints/checkpoint_step_50000.pth", | |
| "seed": 42, | |
| "batch_size": 2, | |
| "width": 384, | |
| "height": 384, | |
| "repaint": True, | |
| "eval_pair": True, | |
| "concat_eval_results": True, | |
| "concat_axis": 'y', | |
| "device": "cuda", | |
| "num_epochs": 50, | |
| "learning_rate": 1e-5, | |
| "max_grad_norm": 1.0, | |
| "cfg_dropout_prob": 0.1, | |
| "dream_lambda": 10.0, | |
| "use_peft": True, | |
| "use_mixed_precision": True, | |
| "save_steps": 10000, | |
| "is_train": True | |
| } | |
| # Set random seeds | |
| torch.manual_seed(args.seed) | |
| np.random.seed(args.seed) | |
| random.seed(args.seed) | |
| if torch.cuda.is_available(): | |
| torch.cuda.manual_seed_all(args.seed) | |
| # Optimize CUDA settings | |
| torch.backends.cudnn.benchmark = True | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.allow_tf32 = True | |
| torch.set_float32_matmul_precision("high") | |
| print("-"*100) | |
| # Load pretrained models | |
| print("Loading pretrained models...") | |
| models = preload_models_from_standard_weights(args.base_model_path, args.device) | |
| print("Models loaded successfully.") | |
| print("-"*100) | |
| # Create dataloader | |
| print("Creating dataloader...") | |
| train_dataloader = create_dataloaders(args) | |
| print(f"Training for {args.num_epochs} epochs") | |
| print(f"Batches per epoch: {len(train_dataloader)}") | |
| print("-"*100) | |
| # Initialize trainer | |
| print("Initializing trainer...") | |
| trainer = CatVTONTrainer( | |
| models=models, | |
| train_dataloader=train_dataloader, | |
| device=args.device, | |
| learning_rate=args.learning_rate, | |
| num_epochs=args.num_epochs, | |
| save_steps=args.save_steps, | |
| output_dir=args.output_dir, | |
| cfg_dropout_prob=args.cfg_dropout_prob, | |
| max_grad_norm=args.max_grad_norm, | |
| use_peft=args.use_peft, | |
| dream_lambda=args.dream_lambda, | |
| resume_from_checkpoint=args.resume_from_checkpoint, | |
| use_mixed_precision=args.use_mixed_precision, | |
| height=args.height, | |
| width=args.width | |
| ) | |
| # Start training | |
| print("Starting training...") | |
| trainer.train() | |
| if __name__ == "__main__": | |
| main() |