Update app.py
app.py CHANGED
@@ -10,7 +10,7 @@ import subprocess
 from typing import Optional
 
 # ---------- Fast, safe defaults ----------
-os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
+os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # faster model downloads
 os.environ.setdefault("DEEPSPEED_DISABLE_NVML", "1")  # silence NVML in headless envs
 os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
 
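Note on the flag touched above: HF_HUB_ENABLE_HF_TRANSFER only takes effect when the optional hf_transfer package is installed; if the variable is set without it, huggingface_hub errors out at download time instead of falling back. A defensive variant (not part of this commit, sketch only) would gate the setdefault on the package actually being importable:

    import importlib.util
    import os

    # Enable the Rust-based downloader only if hf_transfer is really installed,
    # so a missing optional dependency cannot break model downloads at startup.
    if importlib.util.find_spec("hf_transfer") is not None:
        os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")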
@@ -184,7 +184,9 @@ def ensure_pipe() -> DiffusionPipeline:
     pipe = _build_pipeline_cpu()
     return pipe
 
-# ----------
+# ---------- Cold-start aware duration estimator ----------
+GPU_COLD = True  # first GPU invocation will upload weights & warm kernels
+
 def _estimate_duration(prompt: str,
                        negative_prompt: str,
                        seed: int,
@@ -194,17 +196,28 @@ def _estimate_duration(prompt: str,
                        steps: int,
                        secret_token: str) -> int:
     """
-
-
+    ZeroGPU runtime budget (seconds).
+    Includes:
+      - model->GPU transfer + warmup (cold start tax)
+      - per-step cost scaled by resolution
     """
-
+    # normalize size to 1024x1024 ~= 1.0
     px_scale = (max(256, width) * max(256, height)) / (1024 * 1024)
-
-
-
+
+    # conservative costs (tuned for SDXL+LCM on H200 slice)
+    cold_tax = 22.0 if GPU_COLD else 10.0  # seconds
+    step_cost = 1.2  # sec/step at 1024^2
+    base = 6.0  # misc overhead
+
+    est = base + cold_tax + steps * step_cost * max(0.5, px_scale)
+
+    # floors: bigger images need a higher minimum
+    floor = 45 if px_scale >= 1.0 else (30 if px_scale >= 0.5 else 20)
+
+    return int(min(120, max(floor, est)))
 
 # ---------- Public generate (token gate) ----------
-@spaces.GPU(duration=_estimate_duration)  #
+@spaces.GPU(duration=_estimate_duration)  # ZeroGPU uses this to schedule a GPU window
 def generate(
     prompt: str,
     negative_prompt: str = "",
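To sanity-check the budget formula introduced above, here is a standalone sketch that mirrors its constants (the helper name `estimate` and the sample requests are illustrative, not from the app):

    def estimate(width: int, height: int, steps: int, gpu_cold: bool) -> int:
        """Mirror of _estimate_duration's arithmetic for quick offline checks."""
        px_scale = (max(256, width) * max(256, height)) / (1024 * 1024)
        cold_tax = 22.0 if gpu_cold else 10.0
        est = 6.0 + cold_tax + steps * 1.2 * max(0.5, px_scale)
        floor = 45 if px_scale >= 1.0 else (30 if px_scale >= 0.5 else 20)
        return int(min(120, max(floor, est)))

    # cold 1024x1024, 8 steps: 6 + 22 + 8*1.2*1.0 = 37.6 -> raised to the 45 s floor
    print(estimate(1024, 1024, 8, gpu_cold=True))   # 45
    # warm 512x512, 4 steps: 6 + 10 + 4*1.2*0.5 = 18.4 -> raised to the 20 s floor
    print(estimate(512, 512, 4, gpu_cold=False))    # 20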
@@ -218,7 +231,15 @@ def generate(
     if secret_token != SECRET_TOKEN:
         raise gr.Error("Invalid secret token. Set SECRET_TOKEN or pass the correct token.")
 
-
+    global GPU_COLD  # read in the log below and reset after the first GPU upload
+    # For logs: what window we asked ZeroGPU for (based on current cold/warm state)
+    try:
+        requested = _estimate_duration(prompt, negative_prompt, seed, width, height, guidance_scale, steps, secret_token)
+        log.info(f"ZeroGPU duration requested: {requested}s (cold={GPU_COLD}, size={width}x{height}, steps={steps})")
+    except Exception:
+        pass
+
+    _p = ensure_pipe()  # already built on CPU & cached weights on disk
 
     # Clamp user inputs for safety
     width = int(np.clip(width, 256, MAX_IMAGE_SIZE))
@@ -239,6 +260,10 @@ def generate(
         log.warning(f"Falling back to CPU: {e}")
         _p.to("cpu", torch.float32)
 
+    # mark that we've done our cold GPU upload for this process
+    if moved_to_cuda:
+        GPU_COLD = False
+
     try:
         device = "cuda" if moved_to_cuda else "cpu"
         gen = torch.Generator(device=device)
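One subtlety with the module-level GPU_COLD flag: generate() both reads it (in the log message) and reassigns it here, so the `global` declaration has to sit before the first use inside the function; Python 3 rejects a module global that is used before its `global` statement with a SyntaxError. A minimal sketch of the pattern, with hypothetical names:

    COLD = True  # module-level state shared across calls in this process

    def run_once() -> bool:
        global COLD          # must appear before COLD is read or assigned below
        was_cold = COLD      # read the shared flag
        COLD = False         # later calls in this process see a warm state
        return was_cold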
@@ -268,7 +293,6 @@
 def warmup():
     """Performs a minimal inference on CPU to warm up the components."""
     try:
-        # Ensure pipe is loaded, though it should be already by main()
         _p = ensure_pipe()
         _ = _p(
             prompt="minimal warmup",
@@ -318,11 +342,11 @@ def build_ui() -> gr.Blocks:
 
 # ---------- Launch ----------
 def main():
-    # ---
-    log.info("Application starting up. Pre-loading model...")
-    ensure_pipe()
+    # --- Pre-load the model on startup (downloads happen here, not in GPU window) ---
+    log.info("Application starting up. Pre-loading model on CPU...")
+    ensure_pipe()
     log.info("Model pre-loaded successfully.")
-
+
     # --- Optional: Run a single inference on CPU if WARMUP is enabled ---
     if WARMUP:
         log.info("Warmup enabled. Running a test inference on CPU.")
@@ -330,9 +354,8 @@ def main():
 
     # --- Build and launch the Gradio UI ---
     demo = build_ui()
-    # Gradio v5: queue() no longer accepts `concurrency_count`; use per-event limits.
     demo.queue(max_size=QUEUE_SIZE)
-
+
     log.info("Starting Gradio server...")
     demo.launch(
         server_name="0.0.0.0",
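The comment deleted in the last hunk refers to the Gradio 4+ queue API: `queue()` no longer accepts a global `concurrency_count`; concurrency is set per event via `concurrency_limit`, or queue-wide via `default_concurrency_limit`. A minimal sketch of that pattern, using a hypothetical handler and button rather than the actual components in build_ui():

    import gradio as gr

    def run(prompt: str) -> str:
        return prompt.upper()

    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt")
        out = gr.Textbox(label="Output")
        btn = gr.Button("Run")
        # per-event limit replaces the old global concurrency_count
        btn.click(run, inputs=prompt, outputs=out, concurrency_limit=2)

    demo.queue(max_size=32, default_concurrency_limit=1)
    demo.launch()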