import spaces
import torch
from diffusers import DiffusionPipeline, AutoencoderKL
from ip_adapter import IPAdapter
from PIL import Image
import gradio as gr
# --- Configuration Constants ---
SDXL_BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
IP_ADAPTER_MODEL_ID = "h94/IP-Adapter-Plus-SDXL"
IP_ADAPTER_WEIGHT_NAME = "ip-adapter-plus_sdxl_vit-h.bin"
# --- Global Model Instances ---
# These will be initialized and compiled during startup
pipe_global: DiffusionPipeline | None = None
ip_adapter_global: IPAdapter | None = None
@spaces.GPU(duration=1500) # Allocate maximum time for startup compilation
def load_and_compile_models():
"""
Loads the SDXL and IP-Adapter models and performs Ahead-of-Time (AoT) compilation
of the UNet for performance optimization using ZeroGPU.
This function is called once during application startup.
"""
global pipe_global, ip_adapter_global
print("π Starting model loading and compilation...")
# 1. Load SDXL base pipeline
print(f"Loading SDXL base model: {SDXL_BASE_MODEL_ID}")
pipe_global = DiffusionPipeline.from_pretrained(
SDXL_BASE_MODEL_ID,
torch_dtype=torch.float16,
add_watermarker=False, # Disable watermarking for potential speedup
variant="fp16" # Use fp16 variant if available for better performance
)
# Load VAE separately as recommended for stabilityai models
pipe_global.vae = AutoencoderKL.from_pretrained(
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16,
)
pipe_global.to("cuda")
print("SDXL base model loaded and moved to CUDA.")
# 2. Load IP-Adapter
print(f"Loading IP-Adapter from: {IP_ADAPTER_MODEL_ID}/{IP_ADAPTER_WEIGHT_NAME}")
ip_adapter_global = IPAdapter(
pipe_global,
image_encoder_path=IP_ADAPTER_MODEL_ID,
ip_ckpt=IP_ADAPTER_WEIGHT_NAME,
device="cuda"
)
print("IP-Adapter loaded and integrated into the pipeline.")
# 3. Perform AoT compilation for the UNet (main generation component)
print("Starting Ahead-of-Time (AoT) compilation for pipe_global.unet with IP-Adapter...")
# Prepare dummy inputs for capturing UNet's forward pass.
# We need to call a function that internally uses pipe_global.unet
# and has IP-Adapter inputs integrated. The `ip_adapter_global.generate` method
# is designed for this. We use minimal steps for tracing.
dummy_prompt = "a photorealistic image of a beautiful landscape"
    dummy_ip_image = Image.new('RGB', (224, 224), color='red')  # IP-Adapter typically uses 224x224 or 256x256 input
with spaces.aoti_capture(ip_adapter_global.pipe.unet) as call:
# Execute a minimal generation using the IP-Adapter's generate method.
# This will trigger the forward pass of `pipe_global.unet` with
# all the necessary IP-Adapter embeddings, allowing `aoti_capture` to trace it.
_ = ip_adapter_global.generate(
prompt=dummy_prompt,
images=[dummy_ip_image], # Provide a dummy image to trace the IP-Adapter path
height=1024, width=1024,
num_inference_steps=2, # Use minimal steps for fast tracing
guidance_scale=7.5,
num_images_per_prompt=1,
output_type="pil",
).images[0]
# Export the captured UNet module
print("Exporting UNet...")
exported_unet = torch.export.export(
ip_adapter_global.pipe.unet,
args=call.args,
kwargs=call.kwargs,
)
# Compile the exported UNet module
print("Compiling UNet...")
compiled_unet = spaces.aoti_compile(exported_unet)
print("UNet compilation complete.")
# Apply the compiled module back to the pipeline's UNet
spaces.aoti_apply(compiled_unet, ip_adapter_global.pipe.unet)
print("AoT compiled UNet applied to the pipeline.")
print("β
Models loaded and compiled successfully!")
# Call the loading and compilation function once when this module is imported
load_and_compile_models()
@spaces.GPU(duration=60) # Allocate up to 60 seconds for actual image generation
def remix_images(
prompt: str,
image1: Image.Image | None,
image2: Image.Image | None,
image3: Image.Image | None
) -> list[Image.Image]:
"""
Generates images based on a text prompt and up to three input images using SDXL with IP-Adapter.
Args:
prompt (str): The text prompt for image generation.
image1 (PIL.Image.Image | None): The first input image.
image2 (PIL.Image.Image | None): The second input image.
image3 (PIL.Image.Image | None): The third input image.
Returns:
list[PIL.Image.Image]: A list of generated images.
"""
if not prompt:
raise gr.Error("Prompt cannot be empty! Please provide a textual description.")
# Filter out None images to create a list of valid input images
input_images = [img for img in [image1, image2, image3] if img is not None]
print(f"Generating image(s) for prompt: '{prompt}'")
print(f"Using {len(input_images)} input images for IP-Adapter.")
# Call the IP-Adapter's generate method.
# The `ip-adapter` library's `generate` method is designed to handle
# an empty `images` list by falling back to pure text-to-image generation.
generated_images = ip_adapter_global.generate(
prompt=prompt,
images=input_images, # This can be an empty list
height=1024, width=1024,
num_inference_steps=30, # Standard number of inference steps
guidance_scale=7.5, # Classifier-free guidance scale
num_images_per_prompt=1, # Generate one image per request
output_type="pil", # Ensure output is PIL Image objects
# No seed is used as per requirement
).images
    return generated_images
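
# --- Gradio UI (illustrative sketch) ---
# A minimal Blocks layout wiring `remix_images` into an interface. The component
# names, labels, and demo title below are assumptions for illustration, not
# necessarily the Space's actual UI definition.
with gr.Blocks(title="SDXL Image Remixer") as demo:
    gr.Markdown("# SDXL Image Remixer (IP-Adapter)")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", placeholder="Describe the image to generate...")
            image_input_1 = gr.Image(label="Reference image 1 (optional)", type="pil")
            image_input_2 = gr.Image(label="Reference image 2 (optional)", type="pil")
            image_input_3 = gr.Image(label="Reference image 3 (optional)", type="pil")
            generate_button = gr.Button("Generate", variant="primary")
        with gr.Column():
            output_gallery = gr.Gallery(label="Generated images")
    # Wire the button to the GPU-decorated generation function.
    generate_button.click(
        fn=remix_images,
        inputs=[prompt_input, image_input_1, image_input_2, image_input_3],
        outputs=output_gallery,
    )

if __name__ == "__main__":
    demo.launch()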