ruslanmv committed
Commit a8d4cbb
1 Parent(s): 23c2f20

Update app.py

Files changed (1):
  app.py  +25 -12
app.py CHANGED
@@ -10,7 +10,7 @@ import subprocess
 from typing import Optional
 
 # ---------- Fast, safe defaults ----------
-os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # faster model downloads
+os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # faster model downloads
 os.environ.setdefault("DEEPSPEED_DISABLE_NVML", "1")  # silence NVML in headless envs
 os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
 
@@ -126,7 +126,7 @@ def _build_pipeline_cpu() -> DiffusionPipeline:
     CPU-only startup environment. We'll move it to CUDA inside the GPU-decorated
     function per call and return it to CPU after.
     """
-    log.info(f"Loading model backend: {MODEL_BACKEND}")
+    log.info(f"Building pipeline for model backend: {MODEL_BACKEND}")
     if MODEL_BACKEND == "sdxl_lcm_unet":
         # SDXL base with LCM UNet (no LoRA required)
         unet = UNet2DConditionModel.from_pretrained(
@@ -150,7 +150,7 @@ def _build_pipeline_cpu() -> DiffusionPipeline:
         _p.load_lora_weights(
             "latent-consistency/lcm-lora-ssd-1b",
             adapter_name="lcm",
-            use_peft_backend=False,  # <-- avoid PEFT requirement
+            use_peft_backend=False,  # <-- avoid PEFT requirement
         )
         _p.fuse_lora()
     else:
@@ -163,7 +163,7 @@ def _build_pipeline_cpu() -> DiffusionPipeline:
         _p.load_lora_weights(
             "latent-consistency/lcm-lora-sdxl",
             adapter_name="lcm",
-            use_peft_backend=False,  # <-- avoid PEFT requirement
+            use_peft_backend=False,  # <-- avoid PEFT requirement
         )
         _p.fuse_lora()
 
@@ -174,10 +174,11 @@ def _build_pipeline_cpu() -> DiffusionPipeline:
     except Exception:
         pass
 
-    log.info("Pipeline built on CPU.")
+    log.info("Pipeline built successfully on CPU.")
     return _p
 
 def ensure_pipe() -> DiffusionPipeline:
+    """Initializes and returns the global pipeline object."""
     global pipe
     if pipe is None:
         pipe = _build_pipeline_cpu()
@@ -217,7 +218,7 @@ def generate(
     if secret_token != SECRET_TOKEN:
         raise gr.Error("Invalid secret token. Set SECRET_TOKEN or pass the correct token.")
 
-    _p = ensure_pipe()
+    _p = ensure_pipe()  # This will now return the pre-loaded pipe
 
     # Clamp user inputs for safety
     width = int(np.clip(width, 256, MAX_IMAGE_SIZE))
@@ -265,9 +266,11 @@
 
 # ---------- Optional warmup (CPU only for ZeroGPU) ----------
 def warmup():
+    """Performs a minimal inference on CPU to warm up the components."""
     try:
-        ensure_pipe()
-        _ = pipe(
+        # Ensure pipe is loaded, though it should be already by main()
+        _p = ensure_pipe()
+        _ = _p(
             prompt="minimal warmup",
             width=256,
             height=256,
@@ -276,13 +279,10 @@ def warmup():
             generator=torch.Generator(device="cpu").manual_seed(1),
             output_type="pil",
         ).images[0]
-        log.info("CPU warmup complete.")
+        log.info("CPU warmup inference complete.")
     except Exception as e:
         log.warning(f"Warmup skipped or failed: {e}")
 
-if WARMUP:
-    warmup()
-
 # ---------- Gradio UI (v5) ----------
 def build_ui() -> gr.Blocks:
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
@@ -318,9 +318,22 @@ def build_ui() -> gr.Blocks:
 
 # ---------- Launch ----------
 def main():
+    # --- FIX: Pre-load the model on startup ---
+    log.info("Application starting up. Pre-loading model...")
+    ensure_pipe()  # This will download and build the pipeline on the CPU
+    log.info("Model pre-loaded successfully.")
+
+    # --- Optional: Run a single inference on CPU if WARMUP is enabled ---
+    if WARMUP:
+        log.info("Warmup enabled. Running a test inference on CPU.")
+        warmup()
+
+    # --- Build and launch the Gradio UI ---
     demo = build_ui()
     # Gradio v5: queue() no longer accepts `concurrency_count`; use per-event limits.
     demo.queue(max_size=QUEUE_SIZE)
+
+    log.info("Starting Gradio server...")
     demo.launch(
         server_name="0.0.0.0",
         server_port=PORT,
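
The `_build_pipeline_cpu` docstring in the diff describes the ZeroGPU pattern: build the pipeline on CPU at startup, then move it to CUDA only inside the GPU-decorated handler and move it back afterwards. A minimal sketch of that pattern, assuming the `spaces` helper package used on ZeroGPU Spaces; the checkpoint id, `duration`, step count, and guidance value below are illustrative and not taken from this commit:

import spaces                      # helper package available on ZeroGPU Spaces
import torch
from diffusers import DiffusionPipeline

# Built once on CPU at startup (mirrors ensure_pipe() in app.py).
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # illustrative SDXL base checkpoint
    torch_dtype=torch.float16,
)

@spaces.GPU(duration=60)  # duration value is an assumption, not from the commit
def generate(prompt: str):
    try:
        pipe.to("cuda")   # borrow the GPU only for this call
        return pipe(prompt=prompt, num_inference_steps=4, guidance_scale=1.0).images[0]
    finally:
        pipe.to("cpu")    # hand the GPU back so other ZeroGPU calls can use it

Moving the weights back to CPU in a `finally` block keeps the GPU allocation short even when generation raises.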
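
The comment in `main()` about Gradio v5 dropping `concurrency_count` from `queue()` refers to per-event concurrency limits. A minimal sketch of that replacement, with a placeholder handler and illustrative limit values that are not part of this commit:

import gradio as gr

def echo(prompt: str) -> str:
    # Placeholder handler standing in for the app's generate() function.
    return prompt

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    output = gr.Textbox(label="Output")
    run = gr.Button("Run")
    # Gradio 4/5: the limit is set per event listener instead of on queue().
    run.click(echo, inputs=prompt, outputs=output, concurrency_limit=1)

demo.queue(max_size=20)  # queue() still controls how many requests may wait
demo.launch()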