ruslanmv committed
Commit 8888e64 · 1 Parent(s): 3025d81

Update app.py

Files changed (1)
  1. app.py +40 -17
app.py CHANGED
@@ -10,7 +10,7 @@ import subprocess
 from typing import Optional
 
 # ---------- Fast, safe defaults ----------
-os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # faster model downloads
+os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # faster model downloads
 os.environ.setdefault("DEEPSPEED_DISABLE_NVML", "1")  # silence NVML in headless envs
 os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
 
@@ -184,7 +184,9 @@ def ensure_pipe() -> DiffusionPipeline:
         pipe = _build_pipeline_cpu()
     return pipe
 
-# ---------- Duration model for ZeroGPU (match decorated function signature) ----------
+# ---------- Cold-start aware duration estimator ----------
+GPU_COLD = True  # first GPU invocation will upload weights & warm kernels
+
 def _estimate_duration(prompt: str,
                        negative_prompt: str,
                        seed: int,
@@ -194,17 +196,28 @@ def _estimate_duration(prompt: str,
                        steps: int,
                        secret_token: str) -> int:
     """
-    Rough estimate (seconds) to inform ZeroGPU scheduler for better queuing.
-    Scale by pixel count and steps. Conservative upper bound.
+    ZeroGPU runtime budget (seconds).
+    Includes:
+      - model->GPU transfer + warmup (cold start tax)
+      - per-step cost scaled by resolution
     """
-    base = 3.0
+    # normalize size to 1024x1024 ~= 1.0
     px_scale = (max(256, width) * max(256, height)) / (1024 * 1024)
-    step_cost = 0.85  # ~0.85s/step @1024^2 (H200 slice; tune as needed)
-    est = base + steps * step_cost * max(0.5, px_scale)
-    return int(min(120, max(10, est)))
+
+    # conservative costs (tuned for SDXL+LCM on H200 slice)
+    cold_tax = 22.0 if GPU_COLD else 10.0  # seconds
+    step_cost = 1.2  # sec/step at 1024^2
+    base = 6.0  # misc overhead
+
+    est = base + cold_tax + steps * step_cost * max(0.5, px_scale)
+
+    # floors: bigger images need a higher minimum
+    floor = 45 if px_scale >= 1.0 else (30 if px_scale >= 0.5 else 20)
+
+    return int(min(120, max(floor, est)))
 
 # ---------- Public generate (token gate) ----------
-@spaces.GPU(duration=_estimate_duration)  # <- MUST decorate the function Gradio calls
+@spaces.GPU(duration=_estimate_duration)  # ZeroGPU uses this to schedule a GPU window
 def generate(
     prompt: str,
     negative_prompt: str = "",
@@ -218,7 +231,14 @@ def generate(
     if secret_token != SECRET_TOKEN:
         raise gr.Error("Invalid secret token. Set SECRET_TOKEN or pass the correct token.")
 
-    _p = ensure_pipe()  # This will now return the pre-loaded pipe
+    # For logs: what window we asked ZeroGPU for (based on current cold/warm state)
+    try:
+        requested = _estimate_duration(prompt, negative_prompt, seed, width, height, guidance_scale, steps, secret_token)
+        log.info(f"ZeroGPU duration requested: {requested}s (cold={GPU_COLD}, size={width}x{height}, steps={steps})")
+    except Exception:
+        pass
+
+    _p = ensure_pipe()  # already built on CPU & cached weights on disk
 
     # Clamp user inputs for safety
     width = int(np.clip(width, 256, MAX_IMAGE_SIZE))
@@ -239,6 +259,11 @@ def generate(
         log.warning(f"Falling back to CPU: {e}")
         _p.to("cpu", torch.float32)
 
+    # mark that we've done our cold GPU upload for this process
+    global GPU_COLD
+    if moved_to_cuda:
+        GPU_COLD = False
+
     try:
         device = "cuda" if moved_to_cuda else "cpu"
         gen = torch.Generator(device=device)
@@ -268,7 +293,6 @@ def generate(
 def warmup():
     """Performs a minimal inference on CPU to warm up the components."""
     try:
-        # Ensure pipe is loaded, though it should be already by main()
         _p = ensure_pipe()
         _ = _p(
             prompt="minimal warmup",
@@ -318,11 +342,11 @@ def build_ui() -> gr.Blocks:
 
 # ---------- Launch ----------
 def main():
-    # --- FIX: Pre-load the model on startup ---
-    log.info("Application starting up. Pre-loading model...")
-    ensure_pipe()  # This will download and build the pipeline on the CPU
+    # --- Pre-load the model on startup (downloads happen here, not in GPU window) ---
+    log.info("Application starting up. Pre-loading model on CPU...")
+    ensure_pipe()
     log.info("Model pre-loaded successfully.")
-
+
     # --- Optional: Run a single inference on CPU if WARMUP is enabled ---
     if WARMUP:
         log.info("Warmup enabled. Running a test inference on CPU.")
@@ -330,9 +354,8 @@ def main():
 
     # --- Build and launch the Gradio UI ---
    demo = build_ui()
-    # Gradio v5: queue() no longer accepts `concurrency_count`; use per-event limits.
     demo.queue(max_size=QUEUE_SIZE)
-
+
     log.info("Starting Gradio server...")
     demo.launch(
         server_name="0.0.0.0",
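For reference, the budget math introduced in the new _estimate_duration can be checked in isolation. The sketch below is a minimal standalone reproduction of the formula from this diff (the constants are copied from the commit; the sample resolutions and step counts are purely illustrative) and prints the cold-start and warm-start budgets side by side.

# Standalone check of the duration formula from this commit.
# Constants mirror the diff; the sample inputs are illustrative only.

def estimate(width: int, height: int, steps: int, cold: bool) -> int:
    px_scale = (max(256, width) * max(256, height)) / (1024 * 1024)
    cold_tax = 22.0 if cold else 10.0   # model->GPU upload + warmup
    step_cost = 1.2                     # sec/step at 1024x1024
    base = 6.0                          # misc overhead
    est = base + cold_tax + steps * step_cost * max(0.5, px_scale)
    floor = 45 if px_scale >= 1.0 else (30 if px_scale >= 0.5 else 20)
    return int(min(120, max(floor, est)))

if __name__ == "__main__":
    for w, h, steps in [(1024, 1024, 8), (768, 768, 8), (512, 512, 4)]:
        print(f"{w}x{h}, {steps} steps -> "
              f"cold: {estimate(w, h, steps, True)}s, "
              f"warm: {estimate(w, h, steps, False)}s")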
 
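The commit's broader pattern, building the pipeline on CPU at startup, moving it to CUDA only inside the GPU-decorated handler, falling back to CPU if the move fails, and clearing the cold-start flag after the first successful move, can be sketched independently of diffusers. The outline below is an assumption-laden sketch using a plain torch module: load_model, handler, and the placeholder model are hypothetical stand-ins for the app's real objects, and the @spaces.GPU(duration=...) decorator from the diff is omitted so the snippet runs anywhere.

import torch
import torch.nn as nn

GPU_COLD = True   # first CUDA move in this process pays the upload cost
_model = None     # cached CPU-resident model

def load_model() -> nn.Module:
    """Build (or return the cached) model on CPU so downloads happen at startup."""
    global _model
    if _model is None:
        _model = nn.Linear(16, 16)   # placeholder for the real pipeline
    return _model

def handler(x: torch.Tensor) -> torch.Tensor:
    """Per-request entry point: try GPU, fall back to CPU, then run inference."""
    global GPU_COLD
    model = load_model()
    moved_to_cuda = False
    try:
        if torch.cuda.is_available():
            model.to("cuda", torch.float16)
            moved_to_cuda = True
    except Exception:
        model.to("cpu", torch.float32)   # CPU fallback, mirroring the diff
    if moved_to_cuda:
        GPU_COLD = False                 # later calls can request the smaller warm budget
    device = "cuda" if moved_to_cuda else "cpu"
    return model(x.to(device, model.weight.dtype))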