Gertie01 committed on
Commit 1ba3abf · verified · 1 Parent(s): e5b2acc

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +284 -0
app.py ADDED
@@ -0,0 +1,284 @@
To create an image remixer with Gradio, where users can drag in up to three images and provide a text prompt to generate a new image, we'll use `diffusers` pipelines: `stabilityai/stable-diffusion-xl-base-1.0` for text-to-image generation and `stabilityai/stable-diffusion-xl-refiner-1.0` for image-to-image generation.

Crucially, we will implement **ZeroGPU Ahead-of-Time (AoT) compilation** for the UNet components of both diffusion models. On Hugging Face ZeroGPU Spaces this is strongly recommended for GPU-heavy diffusion models and yields significant performance improvements.
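
For orientation, the general shape of the ZeroGPU AoT workflow provided by the `spaces` package is sketched below for a hypothetical single pipeline `pipe` that is already on `"cuda"`; the concrete two-pipeline version used by this app lives in `models.py`.

```python
import spaces
import torch

# Minimal ZeroGPU AoT sketch (assumes `pipe` is a diffusers pipeline already moved to "cuda"):
# capture the inputs of one real UNet call, export it, compile it, then swap it back in.
@spaces.GPU(duration=1500)
def compile_unet():
    with spaces.aoti_capture(pipe.unet) as call:  # records args/kwargs of one UNet forward pass
        pipe("an arbitrary warm-up prompt")       # short run so the UNet is actually called
    exported = torch.export.export(pipe.unet, args=call.args, kwargs=call.kwargs)
    return spaces.aoti_compile(exported)

compiled_unet = compile_unet()               # one-time compilation at startup
spaces.aoti_apply(compiled_unet, pipe.unet)  # apply the compiled UNet in the main process
```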

The application will be structured into multiple files for better organization:

* `app.py`: The main Gradio interface.
* `models.py`: Handles model loading, AoT compilation, and the core remixing logic.
* `requirements.txt`: Lists all Python dependencies.

---

### `app.py`
```python
import gradio as gr

from models import remix_images_inference  # Import the inference function from models.py


def main():
    with gr.Blocks(title="Image Remix with SDXL") as demo:
        gr.HTML(
            """
            <div style="text-align: center; max-width: 700px; margin: 0 auto;">
                <h1 style="font-weight: 900; font-size: 2.5em; margin-bottom: 0.5em;">
                    Image Remix with SDXL
                </h1>
                <p style="margin-bottom: 1em; font-size: 1.1em; color: #555;">
                    Drag and drop up to three images and provide a text prompt to remix them using Stable Diffusion XL.
                    If the first image is provided, it will be used as the base for image-to-image generation.
                </p>
                <p style="font-size: 0.9em; color: #777;">
                    Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank" style="color: #4CAF50; text-decoration: none;">anycoder</a>
                </p>
            </div>
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    image_input_1 = gr.Image(
                        label="Image Input 1 (Base for Img2Img)",
                        type="pil",
                        height=256,
                        width=256,
                        image_mode="RGBA",
                    )
                    image_input_2 = gr.Image(
                        label="Image Input 2 (Optional)",
                        type="pil",
                        height=256,
                        width=256,
                        image_mode="RGBA",
                    )
                    image_input_3 = gr.Image(
                        label="Image Input 3 (Optional)",
                        type="pil",
                        height=256,
                        width=256,
                        image_mode="RGBA",
                    )

                prompt_input = gr.Textbox(
                    label="Remix Prompt",
                    placeholder="A vibrant abstract painting blending elements of nature and technology",
                    lines=2,
                )
                remix_button = gr.Button("Remix Images", variant="primary")

            with gr.Column(scale=2):
                output_image = gr.Image(
                    label="Remixed Image",
                    type="pil",
                    interactive=False,
                    height=512,
                    width=512,
                )

        # Define the interaction: the four inputs map to remix_images_inference(image1, image2, image3, prompt)
        remix_button.click(
            fn=remix_images_inference,
            inputs=[image_input_1, image_input_2, image_input_3, prompt_input],
            outputs=output_image,
            queue=True,
            show_progress="full",
        )

        gr.Examples(
            examples=[
                [
                    "https://www.kasandbox.org/programming-images/avatars/spunky-sam-headphones.png",
                    None,
                    None,
                    "a robot wearing headphones, futuristic, cyberpunk art style",
                ],
                [
                    "https://gradio-docs-json.s3.us-west-2.amazonaws.com/base.png",
                    "https://gradio-docs-json.s3.us-west-2.amazonaws.com/buildings.png",
                    None,
                    "a serene landscape with ancient ruins, overgrown with lush vegetation, concept art, fantasy",
                ],
                [
                    None,
                    None,
                    None,
                    "an astronaut riding a horse on the moon, cinematic, photorealistic",
                ],
            ],
            inputs=[image_input_1, image_input_2, image_input_3, prompt_input],
            outputs=output_image,
            fn=remix_images_inference,
            cache_examples=False,  # Set to True only if inference is fast enough to cache at startup
            run_on_click=True,
        )

    demo.launch(enable_monitoring=True)


if __name__ == "__main__":
    main()
```
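
Since `remix_images_inference` is an ordinary Python function, the wiring above can also be exercised without the UI. A minimal smoke test might look like this (the path `some_photo.png` is just a placeholder, and the models from `models.py` must have loaded successfully first):

```python
# Hypothetical local smoke test for the app wiring; not part of the Space itself.
from PIL import Image

from models import remix_images_inference

base = Image.open("some_photo.png")  # placeholder path to any local image
remixed = remix_images_inference(base, None, None, "a watercolor version of this scene")
remixed.save("remixed.png")
```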

### `models.py`
```python
import os

import gradio as gr
import spaces
import torch
from diffusers import AutoPipelineForImage2Image, DiffusionPipeline
from PIL import Image

# Use a writable cache directory for models if HF_HOME is not already set,
# for a smoother experience on Spaces.
os.environ["HF_HOME"] = os.getenv("HF_HOME", "/data/hf_cache")

MODEL_ID_TEXT2IMG = "stabilityai/stable-diffusion-xl-base-1.0"
MODEL_ID_IMG2IMG = "stabilityai/stable-diffusion-xl-refiner-1.0"  # SDXL Refiner for img2img

# Load models in the main process first. fp16 weights keep memory use and inference time down.
print("Loading models (this may take a moment)...")
pipe_text2img = DiffusionPipeline.from_pretrained(
    MODEL_ID_TEXT2IMG, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
)
pipe_img2img = AutoPipelineForImage2Image.from_pretrained(
    MODEL_ID_IMG2IMG, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
)
print("Models loaded.")

# On ZeroGPU, .to("cuda") may be called in the main process; actual device placement
# happens inside @spaces.GPU functions.
pipe_text2img.to("cuda")
pipe_img2img.to("cuda")


@spaces.GPU(duration=1500)  # Long duration for the one-time compilation at startup
def compile_unets():
    """
    Ahead-of-Time (AoT) compiles the UNets of both pipelines.

    For each pipeline, the inputs of one real UNet forward pass are captured during a
    short warm-up run, exported with torch.export, and compiled with spaces.aoti_compile.
    """
    print("Compiling Text2Image UNet...")
    with spaces.aoti_capture(pipe_text2img.unet) as call_t2i:
        pipe_text2img(prompt="a warm-up prompt", num_inference_steps=2)
    exported_t2i = torch.export.export(pipe_text2img.unet, args=call_t2i.args, kwargs=call_t2i.kwargs)
    compiled_t2i = spaces.aoti_compile(exported_t2i)
    print("Text2Image UNet compiled.")

    print("Compiling Image2Image UNet...")
    with spaces.aoti_capture(pipe_img2img.unet) as call_i2i:
        pipe_img2img(
            prompt="a warm-up prompt",
            image=Image.new("RGB", (1024, 1024)),
            strength=0.75,
            num_inference_steps=4,
        )
    exported_i2i = torch.export.export(pipe_img2img.unet, args=call_i2i.args, kwargs=call_i2i.kwargs)
    compiled_i2i = spaces.aoti_compile(exported_i2i)
    print("Image2Image UNet compiled.")

    return compiled_t2i, compiled_i2i


# Compile during startup, then apply the compiled UNets in the main process so that
# every subsequent GPU call uses them.
compiled_t2i_unet, compiled_i2i_unet = compile_unets()
spaces.aoti_apply(compiled_t2i_unet, pipe_text2img.unet)
spaces.aoti_apply(compiled_i2i_unet, pipe_img2img.unet)


@spaces.GPU(duration=120)  # Allocate the GPU for inference, up to 120 seconds
def remix_images_inference(image1: Image.Image | None, image2: Image.Image | None, image3: Image.Image | None, prompt: str) -> Image.Image:
    """
    Remixes images based on a text prompt using a diffusion model.
    If image1 is provided, the image-to-image pipeline is used for remixing.
    Otherwise, it falls back to the text-to-image pipeline.

    Args:
        image1 (Image.Image | None): The first input image. If provided, used as the base for img2img.
        image2 (Image.Image | None): The second input image (currently only influences the prompt).
        image3 (Image.Image | None): The third input image (currently only influences the prompt).
        prompt (str): The text prompt to guide the remixing.

    Returns:
        Image.Image: The remixed image.
    """
    if not prompt:
        raise gr.Error("Remix prompt cannot be empty!")

    output_resolution = (1024, 1024)  # Fixed resolution, matching the shapes the UNets were compiled with

    # Build a more descriptive prompt if additional images are provided,
    # to at least acknowledge their presence in the "remix".
    extra_prompt_info = ""
    if image2 is not None and image3 is not None:
        extra_prompt_info = ", incorporating elements from other images"
    elif image2 is not None or image3 is not None:
        extra_prompt_info = ", with subtle influences from another image"

    full_prompt = f"{prompt}{extra_prompt_info}"
    print(f"Full prompt for generation: {full_prompt}")

    if image1 is not None:
        # Convert to RGB (inputs may be RGBA) and resize to the target resolution for img2img
        input_image_resized = image1.convert("RGB").resize(output_resolution, Image.LANCZOS)

        print("Performing image-to-image remixing...")
        generated_image = pipe_img2img(
            prompt=full_prompt,
            image=input_image_resized,
            strength=0.75,  # High strength allows significant transformation of the original image
            guidance_scale=8.0,
            num_inference_steps=50,
        ).images[0]
    else:
        print("Performing text-to-image generation...")
        generated_image = pipe_text2img(
            prompt=full_prompt,
            height=output_resolution[0],
            width=output_resolution[1],
            guidance_scale=8.0,
            num_inference_steps=50,
        ).images[0]

    return generated_image
```
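
In the current implementation the second and third images only nudge the prompt text. If they should contribute pixels as well, one simple option (a hypothetical helper, not wired into `remix_images_inference` above) is to average the provided images into a single base image before the img2img call:

```python
import numpy as np
from PIL import Image


def blend_inputs(images, size=(1024, 1024)):
    """Average up to three input images into one RGB base image (hypothetical helper)."""
    arrays = [
        np.asarray(img.convert("RGB").resize(size, Image.LANCZOS), dtype=np.float32)
        for img in images
        if img is not None
    ]
    if not arrays:
        return None
    return Image.fromarray(np.mean(arrays, axis=0).astype(np.uint8))
```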

### `requirements.txt`
```
gradio
spaces
torch
diffusers==0.28.0
transformers==4.41.2
accelerate
Pillow
xformers
torchvision
safetensors
```
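
Beyond the three files above, a Hugging Face Space also needs a `README.md` with YAML front matter so it starts with the Gradio SDK. A minimal sketch is shown below; the `sdk_version` value is an assumption and should match the Gradio version you actually install, and ZeroGPU hardware is selected in the Space settings rather than in this file.

```yaml
---
title: Image Remix With SDXL
emoji: 🎨
colorFrom: indigo
colorTo: pink
sdk: gradio
sdk_version: "4.44.0"  # assumption: pin to the installed Gradio version
app_file: app.py
pinned: false
---
```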