ruslanmv committed
Commit 35220ff · Parent: 74942a4

Update app.py

Files changed (1):
  app.py  (+20, -33)
app.py CHANGED
@@ -1,5 +1,5 @@
 # -------------------------------
-# AI Fast Image Server — ZeroGPU Ready
+# AI Fast Image Server — ZeroGPU Ready (Gradio 5)
 # -------------------------------
 
 from __future__ import annotations
@@ -7,7 +7,7 @@ import os
 import sys
 import logging
 import subprocess
-from typing import Optional, Callable
+from typing import Optional
 
 # ---------- Fast, safe defaults ----------
 os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # faster model downloads
@@ -55,7 +55,6 @@ try:
 except Exception:
     class _DummySpaces:
         def GPU(self, *args, **kwargs):
-            # identity decorator if not on Spaces
            def _wrap(f):
                return f
            return _wrap
@@ -112,17 +111,12 @@ def _gpu_mem_efficiency(p: DiffusionPipeline) -> None:
     except Exception:
         pass
     if enabled:
-        # faster matmul on Ampere+
         try:
             torch.backends.cuda.matmul.allow_tf32 = True
             torch.set_float32_matmul_precision("high")
         except Exception:
             pass
 
-def _variant_kwargs() -> dict:
-    # Use fp16 repo variants only when on GPU (avoid oddities on CPU)
-    return {"variant": "fp16"}
-
 def _build_pipeline_cpu() -> DiffusionPipeline:
     """
     Build the pipeline on CPU with float32 to keep it stable in ZeroGPU's
@@ -131,12 +125,10 @@ def _build_pipeline_cpu() -> DiffusionPipeline:
     """
     log.info(f"Loading model backend: {MODEL_BACKEND}")
     if MODEL_BACKEND == "sdxl_lcm_unet":
-        # Heavy: full LCM UNet (~10GB). Use only if you have big VRAM.
         unet = UNet2DConditionModel.from_pretrained(
             "latent-consistency/lcm-sdxl",
             torch_dtype=torch.float32,
             cache_dir=CACHE_DIR,
-            # no variant on CPU
         )
         _p = DiffusionPipeline.from_pretrained(
             "stabilityai/stable-diffusion-xl-base-1.0",
@@ -162,13 +154,10 @@ def _build_pipeline_cpu() -> DiffusionPipeline:
         _p.load_lora_weights("latent-consistency/lcm-lora-sdxl")
         _p.fuse_lora()
 
-    # Use LCM scheduler
     _p.scheduler = LCMScheduler.from_config(_p.scheduler.config)
-
-    # Stay on CPU by default (ZeroGPU will give us CUDA only during calls)
     _p.to("cpu", torch.float32)
     try:
-        _p.enable_vae_tiling()  # also fine on CPU
+        _p.enable_vae_tiling()
     except Exception:
         pass
 
@@ -181,23 +170,26 @@ def ensure_pipe() -> DiffusionPipeline:
         pipe = _build_pipeline_cpu()
     return pipe
 
-# ---------- Duration model for ZeroGPU ----------
-def _estimate_duration(prompt: str, negative_prompt: str, seed: int,
-                       width: int, height: int, guidance_scale: float, steps: int,
-                       secret_token: str) -> int:
+# ---------- Duration model for ZeroGPU (match decorated function signature) ----------
+def _estimate_duration(prompt: str,
+                       negative_prompt: str,
+                       seed: int,
+                       width: int,
+                       height: int,
+                       guidance_scale: float,
+                       steps: int) -> int:
     """
     Rough estimate (seconds) to inform ZeroGPU scheduler for better queuing.
     Scale by pixel count and steps. Conservative upper bound.
     """
-    base = 3.0  # pipeline dispatch + overhead
+    base = 3.0
     px_scale = (max(256, width) * max(256, height)) / (1024 * 1024)
     step_cost = 0.85  # ~0.85s/step @1024^2 (H200 slice; tune as needed)
    est = base + steps * step_cost * max(0.5, px_scale)
-    # Clamp between 10 and 120 seconds
    return int(min(120, max(10, est)))
 
 # ---------- GPU-decorated inference (Spaces detects this) ----------
-@spaces.GPU(duration=_estimate_duration)  # dynamic duration; no-op outside Spaces
+@spaces.GPU(duration=_estimate_duration)  # no-op outside Spaces
 def _generate_gpu_call(
     prompt: str,
     negative_prompt: str,
@@ -212,19 +204,15 @@ def _generate_gpu_call(
     start and back to CPU at the end so that it remains usable when GPU is released.
     """
     _p = ensure_pipe()
-
-    # Move to CUDA with half precision (safe with LCM)
     _p.to("cuda", torch.float16)
     _gpu_mem_efficiency(_p)
 
     try:
-        # Clamp inputs
         width = int(np.clip(width, 256, MAX_IMAGE_SIZE))
         height = int(np.clip(height, 256, MAX_IMAGE_SIZE))
         steps = int(np.clip(steps, 1, 12))
         guidance_scale = float(np.clip(guidance_scale, 0.0, 2.0))
 
-        # Deterministic generator on CUDA
         gen = torch.Generator(device="cuda")
         if seed is not None:
             gen = gen.manual_seed(int(seed))
@@ -234,21 +222,20 @@ def _generate_gpu_call(
             negative_prompt=negative_prompt,
             width=width,
             height=height,
-            guidance_scale=guidance_scale,  # LCM prefers low guidance
+            guidance_scale=guidance_scale,
             num_inference_steps=steps,
             generator=gen,
             output_type="pil",
         )
         return out.images[0]
     finally:
-        # Always return pipeline to CPU so next non-GPU context is safe
         try:
             _p.to("cpu", torch.float32)
             _p.enable_vae_tiling()
         except Exception:
             pass
 
-# ---------- Public generate (token gate kept outside GPU context) ----------
+# ---------- Public generate (token gate) ----------
 def generate(
     prompt: str,
     negative_prompt: str = "",
@@ -261,7 +248,6 @@ def generate(
 ) -> Image.Image:
     if secret_token != SECRET_TOKEN:
         raise gr.Error("Invalid secret token. Set SECRET_TOKEN or pass the correct token.")
-
     return _generate_gpu_call(
         prompt=prompt,
         negative_prompt=negative_prompt,
@@ -272,11 +258,10 @@ def generate(
         steps=num_inference_steps,
     )
 
-# ---------- Optional warmup (CPU only by default for ZeroGPU) ----------
+# ---------- Optional warmup (CPU only for ZeroGPU) ----------
 def warmup():
     try:
         ensure_pipe()
-        # Tiny CPU warmup to load weights into RAM/cache
         _ = pipe(
             prompt="minimal warmup",
             width=256,
@@ -316,6 +301,7 @@ def build_ui() -> gr.Blocks:
         run = gr.Button("Generate", variant="primary")
 
         inputs = [prompt, negative, seed, width, height, guidance, steps, token]
+        # Per-event concurrency control (Gradio v5)
         run.click(fn=generate, inputs=inputs, outputs=out, concurrency_limit=CONCURRENCY)
 
         gr.Markdown(
@@ -328,12 +314,13 @@ def build_ui() -> gr.Blocks:
 # ---------- Launch ----------
 def main():
     demo = build_ui()
-    demo.queue(max_size=QUEUE_SIZE, concurrency_count=CONCURRENCY)
+    # Gradio v5: queue() no longer accepts `concurrency_count`; use per-event limits.
+    demo.queue(max_size=QUEUE_SIZE)
     demo.launch(
         server_name="0.0.0.0",
         server_port=PORT,
         show_api=True,
-        ssr_mode=ENABLE_SSR,  # Off by default; turn on with ENABLE_SSR=true if needed
+        ssr_mode=ENABLE_SSR,  # Off by default; enable with ENABLE_SSR=true if needed
         share=False,
         show_error=True,
     )
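For context on the `_estimate_duration` change: ZeroGPU supports a dynamic duration, where `spaces.GPU(duration=...)` is given a callable that is invoked with the same arguments as the decorated function and returns the number of seconds to reserve. The old estimator still listed `secret_token`, which `_generate_gpu_call` never receives; the commit aligns the parameter list and keeps the token check in `generate()`, outside the GPU context. A minimal sketch of the pattern, using illustrative names (`estimate_duration`, `run`) rather than the app's own functions:

try:
    import spaces  # provided on Hugging Face Spaces (ZeroGPU)
except ImportError:
    class _DummySpaces:  # identity decorator off-Spaces, mirroring app.py's fallback
        def GPU(self, *args, **kwargs):
            def _wrap(f):
                return f
            return _wrap
    spaces = _DummySpaces()

def estimate_duration(prompt: str, steps: int) -> int:
    # Must accept the same parameters as the decorated function below.
    return int(min(120, max(10, 3 + 0.85 * steps)))

@spaces.GPU(duration=estimate_duration)
def run(prompt: str, steps: int) -> str:
    # GPU work would go here; ZeroGPU reserves roughly estimate_duration(prompt, steps) seconds.
    return f"{prompt!r} rendered in {steps} steps"

Outside Spaces the dummy decorator makes `spaces.GPU` a no-op, matching the fallback already present in app.py.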
 
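The `demo.queue(...)` change reflects the queue API in Gradio 4 and later (including the Gradio 5 this commit targets): `queue()` no longer accepts `concurrency_count`; concurrency is set per event with `concurrency_limit` (or globally with `default_concurrency_limit`), while `queue(max_size=...)` still caps how many requests wait in line. A minimal sketch with placeholder components, not the app's real UI:

import gradio as gr

def echo(text: str) -> str:
    return text

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Result")
    btn = gr.Button("Run")
    # Per-event concurrency replaces the old queue(concurrency_count=...) knob.
    btn.click(fn=echo, inputs=inp, outputs=out, concurrency_limit=2)

demo.queue(max_size=20)  # queue() now only controls queue-level settings such as max_size

if __name__ == "__main__":
    demo.launch()

Keeping the limit on the `click` event mirrors what app.py now does with `concurrency_limit=CONCURRENCY`.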