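"""Hugging Face Space: remix up to three input images with SDXL + IP-Adapter.

Models are loaded once at startup and the UNet is Ahead-of-Time (AoT) compiled
for ZeroGPU; `remix_images` is the generation callback exposed through Gradio.
"""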
import spaces
import torch
from diffusers import DiffusionPipeline, AutoencoderKL
from ip_adapter import IPAdapter
from PIL import Image
import gradio as gr

# --- Configuration Constants ---
SDXL_BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
IP_ADAPTER_MODEL_ID = "h94/IP-Adapter-Plus-SDXL"
IP_ADAPTER_WEIGHT_NAME = "ip-adapter-plus_sdxl_vit-h.bin"

# --- Global Model Instances ---
# These will be initialized and compiled during startup
pipe_global: DiffusionPipeline | None = None
ip_adapter_global: IPAdapter | None = None


# Allocate maximum time for startup compilation
# (the duration value below is an assumption; use the maximum your Space allows)
@spaces.GPU(duration=1500)
def load_and_compile_models():
    """
    Loads the SDXL and IP-Adapter models and performs Ahead-of-Time (AoT) compilation
    of the UNet for performance optimization using ZeroGPU.
    This function is called once during application startup.
    """
    global pipe_global, ip_adapter_global
    print("🚀 Starting model loading and compilation...")

    # 1. Load SDXL base pipeline
    print(f"Loading SDXL base model: {SDXL_BASE_MODEL_ID}")
    pipe_global = DiffusionPipeline.from_pretrained(
        SDXL_BASE_MODEL_ID,
        torch_dtype=torch.float16,
        add_watermarker=False,  # Disable watermarking for a potential speedup
        variant="fp16",  # Use the fp16 variant if available for better performance
    )

    # Load the VAE separately, as recommended for stabilityai models
    pipe_global.vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16,
    )
    pipe_global.to("cuda")
    print("SDXL base model loaded and moved to CUDA.")

    # 2. Load IP-Adapter
    print(f"Loading IP-Adapter from: {IP_ADAPTER_MODEL_ID}/{IP_ADAPTER_WEIGHT_NAME}")
    ip_adapter_global = IPAdapter(
        pipe_global,
        image_encoder_path=IP_ADAPTER_MODEL_ID,
        ip_ckpt=IP_ADAPTER_WEIGHT_NAME,
        device="cuda",
    )
    print("IP-Adapter loaded and integrated into the pipeline.")

    # 3. Perform AoT compilation for the UNet (main generation component)
    print("Starting Ahead-of-Time (AoT) compilation for pipe_global.unet with IP-Adapter...")

    # Prepare dummy inputs for capturing the UNet's forward pass.
    # We need to call a function that internally uses pipe_global.unet
    # and has the IP-Adapter inputs integrated. The `ip_adapter_global.generate`
    # method is designed for this. We use minimal steps for tracing.
    dummy_prompt = "a photorealistic image of a beautiful landscape"
    dummy_ip_image = Image.new("RGB", (224, 224), color="red")  # IP-Adapter typically uses a 224x224 or 256x256 input

    with spaces.aoti_capture(ip_adapter_global.pipe.unet) as call:
        # Execute a minimal generation using the IP-Adapter's generate method.
        # This triggers the forward pass of `pipe_global.unet` with
        # all the necessary IP-Adapter embeddings, allowing `aoti_capture` to trace it.
        _ = ip_adapter_global.generate(
            prompt=dummy_prompt,
            images=[dummy_ip_image],  # Provide a dummy image to trace the IP-Adapter path
            height=1024, width=1024,
            num_inference_steps=2,  # Use minimal steps for fast tracing
            guidance_scale=7.5,
            num_images_per_prompt=1,
            output_type="pil",
        ).images[0]

    # Export the captured UNet module
    print("Exporting UNet...")
    exported_unet = torch.export.export(
        ip_adapter_global.pipe.unet,
        args=call.args,
        kwargs=call.kwargs,
    )

    # Compile the exported UNet module
    print("Compiling UNet...")
    compiled_unet = spaces.aoti_compile(exported_unet)
    print("UNet compilation complete.")

    # Apply the compiled module back to the pipeline's UNet
    spaces.aoti_apply(compiled_unet, ip_adapter_global.pipe.unet)
    print("AoT compiled UNet applied to the pipeline.")
| print("β Models loaded and compiled successfully!") | |


# Call the loading and compilation function once when this module is imported
load_and_compile_models()


# Allocate up to 60 seconds for actual image generation
@spaces.GPU(duration=60)
def remix_images(
    prompt: str,
    image1: Image.Image | None,
    image2: Image.Image | None,
    image3: Image.Image | None,
) -> list[Image.Image]:
| """ | |
| Generates images based on a text prompt and up to three input images using SDXL with IP-Adapter. | |
| Args: | |
| prompt (str): The text prompt for image generation. | |
| image1 (PIL.Image.Image | None): The first input image. | |
| image2 (PIL.Image.Image | None): The second input image. | |
| image3 (PIL.Image.Image | None): The third input image. | |
| Returns: | |
| list[PIL.Image.Image]: A list of generated images. | |
| """ | |
| if not prompt: | |
| raise gr.Error("Prompt cannot be empty! Please provide a textual description.") | |
| # Filter out None images to create a list of valid input images | |
| input_images = [img for img in [image1, image2, image3] if img is not None] | |
| print(f"Generating image(s) for prompt: '{prompt}'") | |
| print(f"Using {len(input_images)} input images for IP-Adapter.") | |
| # Call the IP-Adapter's generate method. | |
| # The `ip-adapter` library's `generate` method is designed to handle | |
| # an empty `images` list by falling back to pure text-to-image generation. | |
| generated_images = ip_adapter_global.generate( | |
| prompt=prompt, | |
| images=input_images, # This can be an empty list | |
| height=1024, width=1024, | |
| num_inference_steps=30, # Standard number of inference steps | |
| guidance_scale=7.5, # Classifier-free guidance scale | |
| num_images_per_prompt=1, # Generate one image per request | |
| output_type="pil", # Ensure output is PIL Image objects | |
| # No seed is used as per requirement | |
| ).images | |
| return generated_images |
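

# --- Gradio UI ---
# The interface wiring is not part of the listing above, so the block below is a
# minimal sketch of how `remix_images` could be hooked up with Gradio Blocks.
# Component labels, layout, and the launch call are assumptions, not the original UI.
with gr.Blocks(title="SDXL IP-Adapter Image Remix") as demo:
    gr.Markdown("## Remix up to three images with SDXL + IP-Adapter")
    with gr.Row():
        image1 = gr.Image(label="Image 1", type="pil")
        image2 = gr.Image(label="Image 2", type="pil")
        image3 = gr.Image(label="Image 3", type="pil")
    prompt = gr.Textbox(label="Prompt", placeholder="Describe how to remix the images")
    generate_btn = gr.Button("Generate")
    gallery = gr.Gallery(label="Generated images")

    # Wire the button to the generation callback defined above
    generate_btn.click(
        fn=remix_images,
        inputs=[prompt, image1, image2, image3],
        outputs=gallery,
    )

if __name__ == "__main__":
    demo.launch()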