import spaces  # import first so ZeroGPU can patch CUDA before torch initializes it

import os

import torch
from diffusers import DiffusionPipeline, AutoencoderKL
from huggingface_hub import hf_hub_download, snapshot_download
from ip_adapter import IPAdapterPlusXL  # reference implementation from tencent-ailab/IP-Adapter
from PIL import Image
import gradio as gr

# --- Configuration Constants ---
SDXL_BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
IP_ADAPTER_REPO_ID = "h94/IP-Adapter"  # official IP-Adapter weight repository
IP_ADAPTER_SUBFOLDER = "sdxl_models"
IP_ADAPTER_WEIGHT_NAME = "ip-adapter-plus_sdxl_vit-h.bin"
IMAGE_ENCODER_SUBFOLDER = "models/image_encoder"  # ViT-H CLIP encoder used by the *_vit-h checkpoints

# --- Global Model Instances ---
# These will be initialized and compiled during startup
pipe_global: DiffusionPipeline | None = None
ip_adapter_global: IPAdapterPlusXL | None = None

@spaces.GPU(duration=1500)  # Allocate maximum time for startup compilation
def load_and_compile_models():
    """
    Loads the SDXL and IP-Adapter models and performs Ahead-of-Time (AoT) compilation
    of the UNet for performance optimization using ZeroGPU.
    This function is called once during application startup.
    """
    global pipe_global, ip_adapter_global
    
    print("πŸš€ Starting model loading and compilation...")

    # 1. Load SDXL base pipeline
    print(f"Loading SDXL base model: {SDXL_BASE_MODEL_ID}")
    pipe_global = DiffusionPipeline.from_pretrained(
        SDXL_BASE_MODEL_ID,
        torch_dtype=torch.float16,
        add_watermarker=False,  # Disable watermarking for potential speedup
        variant="fp16" # Use fp16 variant if available for better performance
    )
    # Swap in the fp16-safe VAE: the stock SDXL VAE is prone to NaN outputs in float16
    pipe_global.vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16,
    )
    pipe_global.to("cuda")
    print("SDXL base model loaded and moved to CUDA.")

    # 2. Load IP-Adapter (Plus SDXL, ViT-H image encoder). The reference
    # implementation expects local paths, so fetch both artifacts from the Hub.
    print(f"Loading IP-Adapter weights: {IP_ADAPTER_REPO_ID}/{IP_ADAPTER_SUBFOLDER}/{IP_ADAPTER_WEIGHT_NAME}")
    ip_ckpt_path = hf_hub_download(
        repo_id=IP_ADAPTER_REPO_ID, subfolder=IP_ADAPTER_SUBFOLDER, filename=IP_ADAPTER_WEIGHT_NAME
    )
    image_encoder_path = os.path.join(
        snapshot_download(IP_ADAPTER_REPO_ID, allow_patterns=[f"{IMAGE_ENCODER_SUBFOLDER}/*"]),
        IMAGE_ENCODER_SUBFOLDER,
    )
    ip_adapter_global = IPAdapterPlusXL(
        pipe_global,
        image_encoder_path=image_encoder_path,
        ip_ckpt=ip_ckpt_path,
        device="cuda",
        num_tokens=16,  # the Plus checkpoints use 16 image tokens
    )
    print("IP-Adapter loaded and integrated into the pipeline.")

    # 3. Perform AoT compilation for the UNet (main generation component)
    print("Starting Ahead-of-Time (AoT) compilation for pipe_global.unet with IP-Adapter...")

    # Prepare dummy inputs for capturing the UNet's forward pass. We need to
    # call something that runs pipe_global.unet with the IP-Adapter embeddings
    # in place; `ip_adapter_global.generate` does exactly that, and minimal
    # steps keep the tracing pass fast.
    dummy_prompt = "a photorealistic image of a beautiful landscape"
    dummy_ip_image = Image.new("RGB", (224, 224), color="red")  # the CLIP ViT-H encoder resizes inputs to 224x224

    with spaces.aoti_capture(ip_adapter_global.pipe.unet) as call:
        # Run a minimal generation so `aoti_capture` can record the UNet call
        # with all IP-Adapter conditioning inputs attached. `generate` returns
        # a plain list of PIL images in the reference implementation.
        ip_adapter_global.generate(
            pil_image=dummy_ip_image,  # dummy image to exercise the IP-Adapter path
            prompt=dummy_prompt,
            num_samples=1,
            num_inference_steps=2,  # minimal steps: we only need one traced forward pass
            guidance_scale=7.5,
            height=1024, width=1024,  # forwarded to the underlying SDXL pipeline
        )
    
    # Export the captured UNet module
    print("Exporting UNet...")
    exported_unet = torch.export.export(
        ip_adapter_global.pipe.unet,
        args=call.args,
        kwargs=call.kwargs,
    )
    
    # Compile the exported UNet module
    print("Compiling UNet...")
    compiled_unet = spaces.aoti_compile(exported_unet)
    print("UNet compilation complete.")
    
    # Apply the compiled module back to the pipeline's UNet
    spaces.aoti_apply(compiled_unet, ip_adapter_global.pipe.unet)
    print("AoT compiled UNet applied to the pipeline.")
    print("βœ… Models loaded and compiled successfully!")

# Run once at import time. On ZeroGPU Spaces this happens during the startup
# phase, where the extended GPU duration requested above is available.
load_and_compile_models()

@spaces.GPU(duration=60)  # Allocate up to 60 seconds for actual image generation
def remix_images(
    prompt: str,
    image1: Image.Image | None,
    image2: Image.Image | None,
    image3: Image.Image | None
) -> list[Image.Image]:
    """
    Generates images based on a text prompt and up to three input images using SDXL with IP-Adapter.

    Args:
        prompt (str): The text prompt for image generation.
        image1 (PIL.Image.Image | None): The first input image.
        image2 (PIL.Image.Image | None): The second input image.
        image3 (PIL.Image.Image | None): The third input image.

    Returns:
        list[PIL.Image.Image]: A list of generated images.
    """
    if not prompt:
        raise gr.Error("Prompt cannot be empty! Please provide a textual description.")

    # Keep only the images that were actually provided
    input_images = [img for img in [image1, image2, image3] if img is not None]
    if not input_images:
        # The reference IP-Adapter implementation has no text-only fallback:
        # `generate` requires an image (or precomputed CLIP image embeddings).
        raise gr.Error("Please provide at least one reference image to remix.")

    print(f"Generating image(s) for prompt: '{prompt}'")
    print(f"Using {len(input_images)} input image(s) for IP-Adapter.")

    # Call the IP-Adapter's generate method. The reference implementation
    # treats a list of images as a batch (one output per reference image)
    # and returns a plain list of PIL images.
    generated_images = ip_adapter_global.generate(
        pil_image=input_images,
        prompt=prompt,
        height=1024, width=1024,
        num_inference_steps=30,  # standard quality/speed trade-off
        guidance_scale=7.5,      # classifier-free guidance scale
        num_samples=1,           # one sample per reference image
        # seed left unset on purpose: sampling stays non-deterministic
    )

    return generated_images
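
# --- Gradio UI wiring ---
# A minimal sketch of the interface layer, assuming this file is the app.py
# entry point of a Gradio Space; component names and layout below are
# illustrative assumptions, not part of the original code.
with gr.Blocks(title="SDXL IP-Adapter Remix") as demo:
    gr.Markdown("# SDXL Image Remix with IP-Adapter")
    with gr.Row():
        with gr.Column():
            prompt_box = gr.Textbox(label="Prompt")
            image1_in = gr.Image(label="Reference image 1", type="pil")
            image2_in = gr.Image(label="Reference image 2 (optional)", type="pil")
            image3_in = gr.Image(label="Reference image 3 (optional)", type="pil")
            generate_btn = gr.Button("Generate")
        with gr.Column():
            gallery_out = gr.Gallery(label="Generated images")
    generate_btn.click(
        fn=remix_images,
        inputs=[prompt_box, image1_in, image2_in, image3_in],
        outputs=gallery_out,
    )

if __name__ == "__main__":
    demo.launch()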