import spaces
import torch
from diffusers import DiffusionPipeline, AutoencoderKL
from ip_adapter import IPAdapter
from PIL import Image
import gradio as gr

# --- Configuration Constants ---
SDXL_BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
IP_ADAPTER_MODEL_ID = "h94/IP-Adapter-Plus-SDXL"
IP_ADAPTER_WEIGHT_NAME = "ip-adapter-plus_sdxl_vit-h.bin"

# --- Global Model Instances ---
# These are initialized and AoT-compiled once during application startup.
pipe_global: DiffusionPipeline | None = None
ip_adapter_global: IPAdapter | None = None


@spaces.GPU(duration=1500) # Allocate maximum time for startup compilation
def load_and_compile_models():
"""
Loads the SDXL and IP-Adapter models and performs Ahead-of-Time (AoT) compilation
of the UNet for performance optimization using ZeroGPU.
This function is called once during application startup.
"""
global pipe_global, ip_adapter_global
print("πŸš€ Starting model loading and compilation...")
# 1. Load SDXL base pipeline
print(f"Loading SDXL base model: {SDXL_BASE_MODEL_ID}")
pipe_global = DiffusionPipeline.from_pretrained(
SDXL_BASE_MODEL_ID,
torch_dtype=torch.float16,
add_watermarker=False, # Disable watermarking for potential speedup
variant="fp16" # Use fp16 variant if available for better performance
)

    # Use the fp16-safe VAE (madebyollin/sdxl-vae-fp16-fix): the stock SDXL VAE
    # can produce NaNs / black images when run in float16.
pipe_global.vae = AutoencoderKL.from_pretrained(
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16,
)
pipe_global.to("cuda")
print("SDXL base model loaded and moved to CUDA.")
# 2. Load IP-Adapter
print(f"Loading IP-Adapter from: {IP_ADAPTER_MODEL_ID}/{IP_ADAPTER_WEIGHT_NAME}")
ip_adapter_global = IPAdapter(
pipe_global,
image_encoder_path=IP_ADAPTER_MODEL_ID,
ip_ckpt=IP_ADAPTER_WEIGHT_NAME,
device="cuda"
)
print("IP-Adapter loaded and integrated into the pipeline.")

    # 3. Ahead-of-Time (AoT) compile the UNet (the main generation workload)
    print("Starting Ahead-of-Time (AoT) compilation for pipe_global.unet with IP-Adapter...")

    # Prepare dummy inputs for capturing the UNet's forward pass. We need a call
    # that drives pipe_global.unet with the IP-Adapter embeddings wired in;
    # `ip_adapter_global.generate` does exactly that, so we run it with minimal
    # steps purely for tracing.
    dummy_prompt = "a photorealistic image of a beautiful landscape"
    dummy_ip_image = Image.new("RGB", (224, 224), color="red")  # IP-Adapter typically uses 224x224 or 256x256 inputs

with spaces.aoti_capture(ip_adapter_global.pipe.unet) as call:
# Execute a minimal generation using the IP-Adapter's generate method.
# This will trigger the forward pass of `pipe_global.unet` with
# all the necessary IP-Adapter embeddings, allowing `aoti_capture` to trace it.
_ = ip_adapter_global.generate(
prompt=dummy_prompt,
images=[dummy_ip_image], # Provide a dummy image to trace the IP-Adapter path
height=1024, width=1024,
num_inference_steps=2, # Use minimal steps for fast tracing
guidance_scale=7.5,
num_images_per_prompt=1,
output_type="pil",
).images[0]
# Export the captured UNet module
print("Exporting UNet...")
exported_unet = torch.export.export(
ip_adapter_global.pipe.unet,
args=call.args,
kwargs=call.kwargs,
)
# Compile the exported UNet module
print("Compiling UNet...")
compiled_unet = spaces.aoti_compile(exported_unet)
print("UNet compilation complete.")
# Apply the compiled module back to the pipeline's UNet
spaces.aoti_apply(compiled_unet, ip_adapter_global.pipe.unet)
print("AoT compiled UNet applied to the pipeline.")
print("βœ… Models loaded and compiled successfully!")


# Run loading and AoT compilation once at import time (i.e. at Space startup),
# while the long @spaces.GPU duration declared above is available.
load_and_compile_models()


@spaces.GPU(duration=60) # Allocate up to 60 seconds for actual image generation
def remix_images(
prompt: str,
image1: Image.Image | None,
image2: Image.Image | None,
image3: Image.Image | None
) -> list[Image.Image]:
"""
Generates images based on a text prompt and up to three input images using SDXL with IP-Adapter.
Args:
prompt (str): The text prompt for image generation.
image1 (PIL.Image.Image | None): The first input image.
image2 (PIL.Image.Image | None): The second input image.
image3 (PIL.Image.Image | None): The third input image.
Returns:
list[PIL.Image.Image]: A list of generated images.
"""
if not prompt:
raise gr.Error("Prompt cannot be empty! Please provide a textual description.")
# Filter out None images to create a list of valid input images
input_images = [img for img in [image1, image2, image3] if img is not None]
print(f"Generating image(s) for prompt: '{prompt}'")
print(f"Using {len(input_images)} input images for IP-Adapter.")
    # Call the IP-Adapter's generate method. When `images` is an empty list, the
    # call is expected to fall back to pure text-to-image generation.
generated_images = ip_adapter_global.generate(
prompt=prompt,
images=input_images, # This can be an empty list
height=1024, width=1024,
num_inference_steps=30, # Standard number of inference steps
guidance_scale=7.5, # Classifier-free guidance scale
num_images_per_prompt=1, # Generate one image per request
output_type="pil", # Ensure output is PIL Image objects
# No seed is used as per requirement
).images
return generated_images
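

# --- Example wiring (illustrative sketch) ---
# A minimal sketch of how `remix_images` might be hooked up to a Gradio UI if
# this module were run directly. The Space is assumed to build its real
# interface elsewhere (e.g. in app.py); the component labels below are
# illustrative, not part of this module's API.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=remix_images,
        inputs=[
            gr.Textbox(label="Prompt"),
            gr.Image(type="pil", label="Image 1 (optional)"),
            gr.Image(type="pil", label="Image 2 (optional)"),
            gr.Image(type="pil", label="Image 3 (optional)"),
        ],
        outputs=gr.Gallery(label="Generated images"),
    )
    demo.launch()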