import spaces
import torch
from diffusers import DiffusionPipeline, AutoencoderKL
from ip_adapter import IPAdapter
from PIL import Image
import gradio as gr

# --- Configuration Constants ---
SDXL_BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
IP_ADAPTER_MODEL_ID = "h94/IP-Adapter-Plus-SDXL"
IP_ADAPTER_WEIGHT_NAME = "ip-adapter-plus_sdxl_vit-h.bin"

# --- Global Model Instances ---
# These are initialized and compiled once during startup.
pipe_global: DiffusionPipeline | None = None
ip_adapter_global: IPAdapter | None = None


@spaces.GPU(duration=1500)  # Allocate maximum time for startup compilation
def load_and_compile_models():
    """
    Loads the SDXL and IP-Adapter models and performs Ahead-of-Time (AoT)
    compilation of the UNet for performance optimization on ZeroGPU.
    This function is called once during application startup.
    """
    global pipe_global, ip_adapter_global

    print("🚀 Starting model loading and compilation...")

    # 1. Load the SDXL base pipeline.
    print(f"Loading SDXL base model: {SDXL_BASE_MODEL_ID}")
    pipe_global = DiffusionPipeline.from_pretrained(
        SDXL_BASE_MODEL_ID,
        torch_dtype=torch.float16,
        add_watermarker=False,  # Disable watermarking for a potential speedup
        variant="fp16",  # Use the fp16 variant if available for better performance
    )

    # Load the fp16-safe VAE separately, as recommended for SDXL fp16 inference.
    pipe_global.vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        torch_dtype=torch.float16,
    )
    pipe_global.to("cuda")
    print("SDXL base model loaded and moved to CUDA.")

    # 2. Load the IP-Adapter.
    print(f"Loading IP-Adapter from: {IP_ADAPTER_MODEL_ID}/{IP_ADAPTER_WEIGHT_NAME}")
    ip_adapter_global = IPAdapter(
        pipe_global,
        image_encoder_path=IP_ADAPTER_MODEL_ID,
        ip_ckpt=IP_ADAPTER_WEIGHT_NAME,
        device="cuda",
    )
    print("IP-Adapter loaded and integrated into the pipeline.")

    # 3. Perform AoT compilation of the UNet (the main generation component).
    print("Starting Ahead-of-Time (AoT) compilation of pipe_global.unet with IP-Adapter...")

    # Prepare dummy inputs for capturing the UNet's forward pass. We need to call
    # a function that internally uses pipe_global.unet with the IP-Adapter inputs
    # integrated; the `ip_adapter_global.generate` method is designed for this.
    # Minimal steps are used for tracing.
    dummy_prompt = "a photorealistic image of a beautiful landscape"
    dummy_ip_image = Image.new("RGB", (224, 224), color="red")  # IP-Adapter typically expects 224x224 or 256x256 inputs

    with spaces.aoti_capture(ip_adapter_global.pipe.unet) as call:
        # Execute a minimal generation using the IP-Adapter's generate method.
        # This triggers the forward pass of `pipe_global.unet` with all the
        # necessary IP-Adapter embeddings, allowing `aoti_capture` to trace it.
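        # Note (assumption about the export step): torch.export specializes the
        # captured graph to the traced input shapes unless dynamic shapes are
        # declared, so the dummy call below deliberately uses the same
        # 1024x1024 resolution and single-image batch layout as the production
        # `remix_images` path; diverging from those values at inference time
        # would fall outside the compiled graph.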
        _ = ip_adapter_global.generate(
            prompt=dummy_prompt,
            images=[dummy_ip_image],  # Provide a dummy image to trace the IP-Adapter path
            height=1024,
            width=1024,
            num_inference_steps=2,  # Use minimal steps for fast tracing
            guidance_scale=7.5,
            num_images_per_prompt=1,
            output_type="pil",
        ).images[0]

    # Export the captured UNet module.
    print("Exporting UNet...")
    exported_unet = torch.export.export(
        ip_adapter_global.pipe.unet,
        args=call.args,
        kwargs=call.kwargs,
    )

    # Compile the exported UNet module.
    print("Compiling UNet...")
    compiled_unet = spaces.aoti_compile(exported_unet)
    print("UNet compilation complete.")

    # Apply the compiled module back to the pipeline's UNet.
    spaces.aoti_apply(compiled_unet, ip_adapter_global.pipe.unet)
    print("AoT compiled UNet applied to the pipeline.")

    print("✅ Models loaded and compiled successfully!")


# Call the loading and compilation function once when this module is imported.
load_and_compile_models()


@spaces.GPU(duration=60)  # Allocate up to 60 seconds for actual image generation
def remix_images(
    prompt: str,
    image1: Image.Image | None,
    image2: Image.Image | None,
    image3: Image.Image | None,
) -> list[Image.Image]:
    """
    Generates images based on a text prompt and up to three input images
    using SDXL with IP-Adapter.

    Args:
        prompt (str): The text prompt for image generation.
        image1 (PIL.Image.Image | None): The first input image.
        image2 (PIL.Image.Image | None): The second input image.
        image3 (PIL.Image.Image | None): The third input image.

    Returns:
        list[PIL.Image.Image]: A list of generated images.
    """
    if not prompt:
        raise gr.Error("Prompt cannot be empty! Please provide a textual description.")

    # Filter out None entries to get the list of valid input images.
    input_images = [img for img in [image1, image2, image3] if img is not None]

    print(f"Generating image(s) for prompt: '{prompt}'")
    print(f"Using {len(input_images)} input images for IP-Adapter.")

    # Call the IP-Adapter's generate method. The `ip_adapter` library's `generate`
    # method is designed to handle an empty `images` list by falling back to
    # pure text-to-image generation.
    generated_images = ip_adapter_global.generate(
        prompt=prompt,
        images=input_images,  # This can be an empty list
        height=1024,
        width=1024,
        num_inference_steps=30,  # Standard number of inference steps
        guidance_scale=7.5,  # Classifier-free guidance scale
        num_images_per_prompt=1,  # Generate one image per request
        output_type="pil",  # Ensure output is PIL Image objects
        # No seed is used as per requirement
    ).images

    return generated_images
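
# --- Illustrative UI wiring (sketch) ---
# A minimal sketch of how `remix_images` could be exposed through a Gradio
# interface. The component labels, gallery output, and title below are
# assumptions for illustration, not the app's actual UI definition.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=remix_images,
        inputs=[
            gr.Textbox(label="Prompt"),
            gr.Image(type="pil", label="Image 1 (optional)"),
            gr.Image(type="pil", label="Image 2 (optional)"),
            gr.Image(type="pil", label="Image 3 (optional)"),
        ],
        outputs=gr.Gallery(label="Generated images"),
        title="SDXL + IP-Adapter Remix",
    )
    demo.launch()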