# -------------------------------
# AI Fast Image Server — ZeroGPU Ready (Gradio 5)
# -------------------------------
from __future__ import annotations
import os
import sys
import logging
import subprocess
from typing import Optional
# ---------- Fast, safe defaults ----------
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1") # faster model downloads
os.environ.setdefault("DEEPSPEED_DISABLE_NVML", "1") # silence NVML in headless envs
os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
# ---------- Logging ----------
logging.basicConfig(
level=os.environ.get("LOG_LEVEL", "INFO").upper(),
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
stream=sys.stdout,
)
log = logging.getLogger("ai-fast-image-server")
# ---------- Config via ENV ----------
# MODEL_BACKEND: "sdxl_lcm_lora" (default), "sdxl_lcm_unet" (heavy), "ssd1b_lcm_lora" (light)
MODEL_BACKEND = os.getenv("MODEL_BACKEND", "sdxl_lcm_lora").lower()
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
DEFAULT_SIZE = int(os.getenv("DEFAULT_SIZE", "1024"))
SECRET_TOKEN = os.getenv("SECRET_TOKEN", "default_secret")
PORT = int(os.getenv("PORT", "7860"))
CONCURRENCY = int(os.getenv("CONCURRENCY", "2"))
QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", "32"))
ENABLE_SSR = os.getenv("ENABLE_SSR", "false").lower() == "true" # SSR off by default for stability
WARMUP = os.getenv("WARMUP", "false").lower() == "true" # default False for ZeroGPU
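# Illustrative launch overriding the defaults above (example values only):
#   MODEL_BACKEND=ssd1b_lcm_lora MAX_IMAGE_SIZE=768 SECRET_TOKEN=my-token WARMUP=true python app.py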
# ============================================================
# Import `spaces` BEFORE any CUDA-related libs (torch/diffusers)
# ============================================================
try:
import spaces # real decorator on HF Spaces
except ImportError:
# Local/dev fallback: a no-op decorator so the app still runs without ZeroGPU
class _DummySpaces:
def GPU(self, *args, **kwargs):
def _wrap(f):
return f
return _wrap
spaces = _DummySpaces()
# ---------- Third-party imports (safe to import after `spaces`) ----------
import warnings
warnings.filterwarnings("ignore", message="Can't initialize NVML")
import numpy as np
import torch
from PIL import Image
import gradio as gr
from diffusers import (
DiffusionPipeline,
UNet2DConditionModel,
LCMScheduler,
AutoPipelineForText2Image,
)
# ---------- Version guard: Torch 2.1 + NumPy 2.x is incompatible ----------
try:
_np_major = int(np.__version__.split(".")[0])
if torch.__version__.startswith("2.1") and _np_major >= 2:
raise RuntimeError(
f"Incompatible versions: torch=={torch.__version__} with numpy=={np.__version__}. "
"Pin numpy==1.26.4 or upgrade torch to >=2.3."
)
except Exception as e:
log.error(str(e))
raise
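# If this guard trips, one fix (per the message above) is to pin in requirements.txt:
#   numpy==1.26.4    # or upgrade: torch>=2.3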
# ---------- Paths ----------
CURRENT_DIR = os.getcwd()
CACHE_DIR = os.path.join(CURRENT_DIR, "cache")
os.makedirs(CACHE_DIR, exist_ok=True)
# ---------- GPU info (logs only) ----------
def print_nvidia_smi() -> None:
try:
proc = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False)
if proc.returncode == 0 and proc.stdout.strip():
log.info("\n" + proc.stdout.strip())
else:
msg = proc.stderr.strip() if proc.stderr else "nvidia-smi not available or returned no output."
log.info(msg)
except FileNotFoundError:
log.info("nvidia-smi not found on PATH.")
print_nvidia_smi()
# ---------- Global pipeline handle (kept on CPU between calls) ----------
pipe: Optional[DiffusionPipeline] = None
def _gpu_mem_efficiency(p: DiffusionPipeline) -> None:
"""Enable memory-efficient attention and VAE tiling where possible."""
enabled = False
try:
p.enable_xformers_memory_efficient_attention()
enabled = True
except Exception:
try:
p.enable_attention_slicing("max")
enabled = True
except Exception:
pass
try:
p.enable_vae_tiling()
except Exception:
pass
if enabled:
try:
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")
except Exception:
pass
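# Note: the helper above prefers xformers memory-efficient attention, falls back to
# attention slicing, and treats VAE tiling and TF32 matmul as best-effort extras.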
def _build_pipeline_cpu() -> DiffusionPipeline:
"""
Build the pipeline on CPU with float32 to keep it stable in ZeroGPU's
CPU-only startup environment. We'll move it to CUDA inside the GPU-decorated
function per call and return it to CPU after.
"""
log.info(f"Loading model backend: {MODEL_BACKEND}")
if MODEL_BACKEND == "sdxl_lcm_unet":
# SDXL base with LCM UNet (no LoRA required)
unet = UNet2DConditionModel.from_pretrained(
"latent-consistency/lcm-sdxl",
torch_dtype=torch.float32,
cache_dir=CACHE_DIR,
)
_p = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
unet=unet,
torch_dtype=torch.float32,
cache_dir=CACHE_DIR,
)
elif MODEL_BACKEND == "ssd1b_lcm_lora":
# SSD-1B with LCM-LoRA (Diffusers backend; no PEFT required)
_p = AutoPipelineForText2Image.from_pretrained(
"segmind/SSD-1B",
torch_dtype=torch.float32,
cache_dir=CACHE_DIR,
)
_p.load_lora_weights(
"latent-consistency/lcm-lora-ssd-1b",
adapter_name="lcm",
use_peft_backend=False, # <-- avoid PEFT requirement
)
_p.fuse_lora()
else:
# Default: SDXL + LCM-LoRA (smaller download, great speed/quality)
_p = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float32,
cache_dir=CACHE_DIR,
)
_p.load_lora_weights(
"latent-consistency/lcm-lora-sdxl",
adapter_name="lcm",
use_peft_backend=False, # <-- avoid PEFT requirement
)
_p.fuse_lora()
_p.scheduler = LCMScheduler.from_config(_p.scheduler.config)
_p.to("cpu", torch.float32)
try:
_p.enable_vae_tiling()
except Exception:
pass
log.info("Pipeline built on CPU.")
return _p
def ensure_pipe() -> DiffusionPipeline:
global pipe
if pipe is None:
pipe = _build_pipeline_cpu()
return pipe
# ---------- Duration model for ZeroGPU (match decorated function signature) ----------
def _estimate_duration(prompt: str,
negative_prompt: str,
seed: int,
width: int,
height: int,
guidance_scale: float,
steps: int,
secret_token: str) -> int:
"""
Rough estimate (seconds) to inform ZeroGPU scheduler for better queuing.
Scale by pixel count and steps. Conservative upper bound.
"""
base = 3.0
px_scale = (max(256, width) * max(256, height)) / (1024 * 1024)
step_cost = 0.85 # ~0.85s/step @1024^2 (H200 slice; tune as needed)
est = base + steps * step_cost * max(0.5, px_scale)
return int(min(120, max(10, est)))
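# Worked example: 1024x1024 at 4 steps -> 3.0 + 4 * 0.85 * 1.0 = 6.4 s,
# which the [10, 120] clamp above raises to a 10 s reservation.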
# ---------- Public generate (token gate) ----------
@spaces.GPU(duration=_estimate_duration) # <- MUST decorate the function Gradio calls
def generate(
prompt: str,
negative_prompt: str = "",
seed: int = 0,
width: int = DEFAULT_SIZE,
height: int = DEFAULT_SIZE,
guidance_scale: float = 0.0,
steps: int = 4,
secret_token: str = "",
) -> Image.Image:
if secret_token != SECRET_TOKEN:
raise gr.Error("Invalid secret token. Set SECRET_TOKEN or pass the correct token.")
_p = ensure_pipe()
# Clamp user inputs for safety
width = int(np.clip(width, 256, MAX_IMAGE_SIZE))
height = int(np.clip(height, 256, MAX_IMAGE_SIZE))
steps = int(np.clip(steps, 1, 12))
guidance_scale = float(np.clip(guidance_scale, 0.0, 2.0))
# Try to use CUDA when available (ZeroGPU will make it available inside this call)
moved_to_cuda = False
try:
if torch.cuda.is_available():
_p.to("cuda", torch.float16)
_gpu_mem_efficiency(_p)
moved_to_cuda = True
else:
_p.to("cpu", torch.float32)
except Exception as e:
log.warning(f"Falling back to CPU: {e}")
_p.to("cpu", torch.float32)
try:
device = "cuda" if moved_to_cuda else "cpu"
gen = torch.Generator(device=device)
if seed is not None:
gen = gen.manual_seed(int(seed))
out = _p(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
guidance_scale=guidance_scale,
num_inference_steps=steps,
generator=gen,
output_type="pil",
)
return out.images[0]
finally:
# Return the model to CPU so the GPU can be released immediately after the call
try:
_p.to("cpu", torch.float32)
_p.enable_vae_tiling()
except Exception:
pass
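# Illustrative direct call for local testing (assumes SECRET_TOKEN is still "default_secret"):
#   img = generate("a watercolor fox in the snow", steps=4, width=768, height=768,
#                  secret_token="default_secret")
#   img.save("fox.png")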
# ---------- Optional warmup (CPU only for ZeroGPU) ----------
def warmup():
try:
ensure_pipe()
_ = pipe(
prompt="minimal warmup",
width=256,
height=256,
guidance_scale=0.0,
num_inference_steps=1,
generator=torch.Generator(device="cpu").manual_seed(1),
output_type="pil",
).images[0]
log.info("CPU warmup complete.")
except Exception as e:
log.warning(f"Warmup skipped or failed: {e}")
if WARMUP:
warmup()
# ---------- Gradio UI (v5) ----------
def build_ui() -> gr.Blocks:
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("## Image Generator (LCM) — SDXL / SSD-1B (ZeroGPU Ready)")
with gr.Row():
prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Describe the image…")
negative = gr.Textbox(label="Negative Prompt", lines=2, placeholder="(optional)")
with gr.Row():
seed = gr.Number(label="Seed", value=0, precision=0)
width = gr.Slider(256, MAX_IMAGE_SIZE, value=DEFAULT_SIZE, step=32, label="Width")
height = gr.Slider(256, MAX_IMAGE_SIZE, value=DEFAULT_SIZE, step=32, label="Height")
with gr.Row():
guidance = gr.Slider(0.0, 2.0, value=0.0, step=0.1, label="Guidance scale")
steps = gr.Slider(1, 12, value=4, step=1, label="Inference steps")
token = gr.Textbox(label="Secret Token", type="password", lines=1)
out = gr.Image(label="Result", height=DEFAULT_SIZE, width=DEFAULT_SIZE)
run = gr.Button("Generate", variant="primary")
inputs = [prompt, negative, seed, width, height, guidance, steps, token]
# Per-event concurrency control (Gradio v5)
run.click(fn=generate, inputs=inputs, outputs=out, concurrency_limit=CONCURRENCY)
gr.Markdown(
f"*Backend:* `{MODEL_BACKEND}` &nbsp; | &nbsp; "
f"*ZeroGPU:* `@spaces.GPU` enabled &nbsp; | &nbsp; "
f"*Max size:* {MAX_IMAGE_SIZE}px"
)
return demo
# ---------- Launch ----------
def main():
demo = build_ui()
# Gradio v5: queue() no longer accepts `concurrency_count`; use per-event limits.
demo.queue(max_size=QUEUE_SIZE)
demo.launch(
server_name="0.0.0.0",
server_port=PORT,
show_api=True,
ssr_mode=ENABLE_SSR, # Off by default; enable with ENABLE_SSR=true if needed
share=False,
show_error=True,
)
if __name__ == "__main__":
main()
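# Illustrative remote call with gradio_client (api_name assumed to default to the function
# name, i.e. "/generate" -- confirm via the app's "Use via API" page):
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   result = client.predict(
#       "a misty mountain lake", "", 0, 1024, 1024, 0.0, 4, "default_secret",
#       api_name="/generate",
#   )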