import os import torch import gradio as gr from PIL import Image import numpy as np from typing import Optional # Import your custom modules from load_model import preload_models_from_standard_weights from utils import to_pil_image from CatVTON_model import CatVTONPix2PixPipeline import os import torch import urllib.request def load_models(): try: print("🚀 Starting model loading process...") # Check and download model files if missing ckpt_path = "instruct-pix2pix-00-22000.ckpt" finetune_path = "maskfree_finetuned_weights.safetensors" if not os.path.exists(ckpt_path): print(f"⬇️ Downloading {ckpt_path}...") url = "https://huggingface.co/timbrooks/instruct-pix2pix/resolve/main/instruct-pix2pix-00-22000.ckpt" urllib.request.urlretrieve(url, ckpt_path) print("✅ Download complete.") else: print("✅ Checkpoint already exists.") if not os.path.exists(finetune_path): print(f"❌ Finetune weights file not found: {finetune_path}") return None, None # Check CUDA availability cuda_available = torch.cuda.is_available() print(f"CUDA available: {cuda_available}") if cuda_available: print(f"CUDA device: {torch.cuda.get_device_name()}") free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0) print(f"Available CUDA memory: {free_memory / 1e9:.2f} GB") device = "cuda" if cuda_available else "cpu" print("📦 Loading models from weights...") models = preload_models_from_standard_weights( ckpt_path=ckpt_path, device=device, finetune_weights_path=finetune_path ) if not models: print("❌ Failed to load models") return None, None weight_dtype = torch.float32 print(f"Converting models to {weight_dtype}...") for model_name, model in models.items(): if model is not None: try: model = model.to(dtype=weight_dtype) models[model_name] = model print(f"✅ {model_name} converted to {weight_dtype}") except Exception as e: print(f"⚠️ Could not convert {model_name} to {weight_dtype}: {e}") print("🔧 Initializing pipeline...") pipeline = CatVTONPix2PixPipeline( weight_dtype=weight_dtype, device=device, skip_safety_check=True, models=models, ) print("✅ Models and pipeline loaded successfully!") return models, pipeline except Exception as e: print(f"❌ Error in load_models: {e}") import traceback traceback.print_exc() return None, None def person_example_fn(image_path): """Handle person image examples""" if image_path: return image_path return None def create_demo(pipeline=None): """Create the Gradio interface""" def submit_function_p2p( person_image_path: Optional[str], cloth_image_path: Optional[str], num_inference_steps: int = 50, guidance_scale: float = 2.5, seed: int = 42, ) -> Optional[Image.Image]: """Process virtual try-on inference""" try: if not person_image_path or not cloth_image_path: gr.Warning("Please upload both person and cloth images!") return None if not os.path.exists(person_image_path): gr.Error("Person image file not found!") return None if not os.path.exists(cloth_image_path): gr.Error("Cloth image file not found!") return None if pipeline is None: gr.Error("Models not loaded! Please restart the application.") return None # Load images try: person_image = Image.open(person_image_path).convert('RGB') cloth_image = Image.open(cloth_image_path).convert('RGB') except Exception as e: gr.Error(f"Error loading images: {str(e)}") return None # Set up generator generator = torch.Generator(device=pipeline.device) if seed != -1: generator.manual_seed(seed) print("🔄 Processing virtual try-on...") # Run inference with torch.no_grad(): results = pipeline( person_image, cloth_image, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, height=512, width=384, generator=generator, ) # Process results if isinstance(results, list) and len(results) > 0: result = results[0] else: result = results return result except Exception as e: print(f"❌ Error in submit_function_p2p: {e}") import traceback traceback.print_exc() gr.Error(f"Error during inference: {str(e)}") return None # Custom CSS for better styling css = """ .gradio-container { max-width: 1200px !important; } .image-container { max-height: 600px; } """ with gr.Blocks(css=css, title="Virtual Try-On") as demo: gr.HTML("""
Upload a person image and a clothing item to see how they look together!