# -------------------------------
# AI Fast Image Server (Production)
# -------------------------------
from __future__ import annotations
import os
import sys
import logging
import subprocess
from typing import Optional
# ---------- Early, safe env defaults ----------
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1") # faster model downloads
os.environ.setdefault("DEEPSPEED_DISABLE_NVML", "1") # silence NVML in headless envs
os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
# ---------- Logging ----------
logging.basicConfig(
level=os.environ.get("LOG_LEVEL", "INFO").upper(),
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
stream=sys.stdout,
)
log = logging.getLogger("ai-fast-image-server")
# ---------- Config via ENV ----------
# MODEL_BACKEND: sdxl_lcm_unet (heavy), sdxl_lcm_lora (light), ssd1b_lcm_lora (light)
MODEL_BACKEND = os.getenv("MODEL_BACKEND", "sdxl_lcm_lora").lower()
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
DEFAULT_SIZE = int(os.getenv("DEFAULT_SIZE", "1024"))
SECRET_TOKEN = os.getenv("SECRET_TOKEN", "default_secret")
PORT = int(os.getenv("PORT", "7860"))
CONCURRENCY = int(os.getenv("CONCURRENCY", "2"))
QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", "32"))
ENABLE_SSR = os.getenv("ENABLE_SSR", "false").lower() == "true" # SSR can be flaky; default off
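# Illustrative usage (example values only, not enforced defaults): a lightweight,
# CPU-friendly configuration might be launched as
#   MODEL_BACKEND=sdxl_lcm_lora MAX_IMAGE_SIZE=768 WARMUP=false \
#   SECRET_TOKEN=my-token PORT=7860 python app.py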
# ---------- Imports that require deps ----------
import warnings
warnings.filterwarnings("ignore", message="Can't initialize NVML")
import numpy as np
import torch
from PIL import Image
import gradio as gr
from diffusers import (
DiffusionPipeline,
UNet2DConditionModel,
LCMScheduler,
AutoPipelineForText2Image,
)
# ---------- Version guard: Torch 2.1 is incompatible with NumPy 2.x ----------
try:
_np_major = int(np.__version__.split(".")[0])
if torch.__version__.startswith("2.1") and _np_major >= 2:
raise RuntimeError(
f"Incompatible versions: torch=={torch.__version__} with numpy=={np.__version__}. "
"Pin numpy==1.26.4 or upgrade torch to >=2.3."
)
except Exception as e:
log.error(str(e))
raise
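# For reference, the pin implied by the message above could be expressed in
# requirements.txt as either of the following (choose one for your base image):
#   numpy==1.26.4
#   torch>=2.3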
# ---------- Paths ----------
CURRENT_DIR = os.getcwd()
CACHE_DIR = os.path.join(CURRENT_DIR, "cache")
os.makedirs(CACHE_DIR, exist_ok=True)
# ---------- GPU info (logs only) ----------
def print_nvidia_smi() -> None:
try:
proc = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False)
if proc.returncode == 0 and proc.stdout.strip():
log.info("\n" + proc.stdout.strip())
else:
msg = proc.stderr.strip() if proc.stderr else "nvidia-smi not available or returned no output."
log.info(msg)
except FileNotFoundError:
log.info("nvidia-smi not found on PATH.")
print_nvidia_smi()
IS_GPU = torch.cuda.is_available()
DEVICE = torch.device("cuda") if IS_GPU else torch.device("cpu")
DTYPE = torch.float16 if IS_GPU else torch.float32
log.info(f"CUDA available: {IS_GPU} | device={DEVICE} | dtype={DTYPE}")
# ---------- Torch perf knobs ----------
try:
if IS_GPU:
torch.backends.cuda.matmul.allow_tf32 = True # safe perf on Ampere+
torch.set_float32_matmul_precision("high")
except Exception:
pass
# ---------- Helpers ----------
def _variant_kwargs() -> dict:
# use fp16 repo variants only on GPU
return {"variant": "fp16"} if IS_GPU else {}
def _cpu_safety_settings(pipe: DiffusionPipeline) -> None:
# reduce RAM usage and avoid giant VAE allocations on CPU
try:
pipe.enable_vae_tiling()
except Exception:
pass
def _gpu_memory_efficiency(pipe: DiffusionPipeline) -> None:
# enable memory-efficient attention when available
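# Note: enable_xformers_memory_efficient_attention() requires the optional
# `xformers` package; when it is missing (or the call fails), we fall back to
# attention slicing below.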
enabled = False
try:
pipe.enable_xformers_memory_efficient_attention()
enabled = True
except Exception:
try:
pipe.enable_attention_slicing("max")
enabled = True
except Exception:
pass
if enabled:
try:
pipe.enable_vae_tiling()
except Exception:
pass
# ---------- Model loading ----------
pipe: Optional[DiffusionPipeline] = None
def load_pipeline() -> DiffusionPipeline:
"""
Load the selected backend with sensible defaults.
- sdxl_lcm_unet: SDXL base + full LCM UNet (heavy, high VRAM)
- sdxl_lcm_lora: SDXL base + LCM-LoRA (light, recommended)
- ssd1b_lcm_lora: SSD-1B + LCM-LoRA (light)
"""
log.info(f"Loading model backend: {MODEL_BACKEND}")
if MODEL_BACKEND == "sdxl_lcm_unet":
# Heavy: downloads ~10 GB UNet; best quality/speed on big GPUs
unet = UNet2DConditionModel.from_pretrained(
"latent-consistency/lcm-sdxl",
torch_dtype=DTYPE,
cache_dir=CACHE_DIR,
**_variant_kwargs(),
)
_pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
unet=unet,
torch_dtype=DTYPE,
cache_dir=CACHE_DIR,
**_variant_kwargs(),
)
elif MODEL_BACKEND == "ssd1b_lcm_lora":
_pipe = AutoPipelineForText2Image.from_pretrained(
"segmind/SSD-1B",
torch_dtype=DTYPE,
cache_dir=CACHE_DIR,
**_variant_kwargs(),
)
_pipe.load_lora_weights("latent-consistency/lcm-lora-ssd-1b")
_pipe.fuse_lora()
else:
# Default & recommended: SDXL + LCM-LoRA (smaller downloads, good quality)
_pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=DTYPE,
cache_dir=CACHE_DIR,
**_variant_kwargs(),
)
_pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
_pipe.fuse_lora()
# Use LCM scheduler
_pipe.scheduler = LCMScheduler.from_config(_pipe.scheduler.config)
# Device & memory efficiency
_pipe.to(DEVICE)
if IS_GPU:
_gpu_memory_efficiency(_pipe)
else:
_cpu_safety_settings(_pipe)
log.info("Pipeline loaded.")
return _pipe
# warmup lazily
def ensure_pipe() -> DiffusionPipeline:
global pipe
if pipe is None:
pipe = load_pipeline()
return pipe
# ---------- HF Spaces GPU decorator (fixes “No @spaces.GPU function detected”) ----------
try:
import spaces # type: ignore
GPU_DECORATOR = spaces.GPU
log.info("`spaces` package detected. GPU-decorating inference function.")
except Exception:
GPU_DECORATOR = lambda f: f # no-op
# ---------- Inference ----------
@GPU_DECORATOR
def generate_image_internal(
prompt: str,
negative_prompt: str = "",
seed: Optional[int] = 0,
width: int = DEFAULT_SIZE,
height: int = DEFAULT_SIZE,
guidance_scale: float = 0.0,
num_inference_steps: int = 4,
) -> Image.Image:
_pipe = ensure_pipe()
# Clamp to safe bounds
width = int(np.clip(width, 256, MAX_IMAGE_SIZE))
height = int(np.clip(height, 256, MAX_IMAGE_SIZE))
num_inference_steps = int(np.clip(num_inference_steps, 1, 12))
guidance_scale = float(np.clip(guidance_scale, 0.0, 2.0))
# Deterministic generator
generator = torch.Generator(device=DEVICE)
if seed is not None:
generator = generator.manual_seed(int(seed))
result = _pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
guidance_scale=guidance_scale, # LCM prefers low/no guidance
num_inference_steps=num_inference_steps,
generator=generator,
output_type="pil",
)
return result.images[0]
# thin wrapper that enforces the token (kept out of the GPU-decorated function)
def generate(
prompt: str,
negative_prompt: str = "",
seed: int = 0,
width: int = DEFAULT_SIZE,
height: int = DEFAULT_SIZE,
guidance_scale: float = 0.0,
num_inference_steps: int = 4,
secret_token: str = "",
) -> Image.Image:
if secret_token != SECRET_TOKEN:
raise gr.Error("Invalid secret token. Set SECRET_TOKEN or pass the correct token.")
return generate_image_internal(
prompt=prompt,
negative_prompt=negative_prompt,
seed=seed,
width=width,
height=height,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
)
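# Client-side sketch (not executed here): one way to call this endpoint with
# gradio_client. The api_name and argument order are assumptions inferred from
# the Blocks wiring below; check the running app's "Use via API" page for the
# exact endpoint name.
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   image_path = client.predict(
#       "a watercolor fox in a misty forest",  # prompt
#       "",                                    # negative_prompt
#       0,                                     # seed
#       1024, 1024,                            # width, height
#       0.0,                                   # guidance_scale
#       4,                                     # num_inference_steps
#       "default_secret",                      # secret_token (must match SECRET_TOKEN)
#       api_name="/generate",
#   )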
# ---------- Optional warmup at startup ----------
def warmup():
try:
ensure_pipe()
_ = generate_image_internal(
prompt="A quick warmup prompt, minimal style", seed=42, width=512, height=512, num_inference_steps=2
)
log.info("Warmup complete.")
except Exception as e:
log.warning(f"Warmup skipped or failed: {e}")
if os.getenv("WARMUP", "true").lower() == "true":
# Don't block too long on CPU
if IS_GPU:
warmup()
# ---------- Gradio UI (v5) ----------
def build_ui() -> gr.Blocks:
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("## Image Generator (LCM) — SDXL / SSD-1B")
with gr.Row():
prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Describe the image...")
negative = gr.Textbox(label="Negative Prompt", lines=2, placeholder="(optional)")
with gr.Row():
seed = gr.Number(label="Seed", value=0, precision=0)
width = gr.Slider(256, MAX_IMAGE_SIZE, value=DEFAULT_SIZE, step=32, label="Width")
height = gr.Slider(256, MAX_IMAGE_SIZE, value=DEFAULT_SIZE, step=32, label="Height")
with gr.Row():
guidance = gr.Slider(0.0, 2.0, value=0.0, step=0.1, label="Guidance scale")
steps = gr.Slider(1, 12, value=4, step=1, label="Inference steps")
token = gr.Textbox(label="Secret Token", type="password", lines=1)
out = gr.Image(label="Result", height=DEFAULT_SIZE, width=DEFAULT_SIZE)
run = gr.Button("Generate", variant="primary")
inputs = [prompt, negative, seed, width, height, guidance, steps, token]
run.click(fn=generate, inputs=inputs, outputs=out, concurrency_limit=CONCURRENCY)
# Simple health info
gr.Markdown(
f"*Backend:* `{MODEL_BACKEND}`   |   "
f"*Device:* `{DEVICE}`   |   "
f"*dtype:* `{DTYPE}`"
)
return demo
# ---------- Launch ----------
def main():
demo = build_ui()
# Queue for backpressure and concurrency control
demo.queue(max_size=QUEUE_SIZE, default_concurrency_limit=CONCURRENCY)
demo.launch(
server_name="0.0.0.0",
server_port=PORT,
show_api=True,
ssr_mode=ENABLE_SSR, # SSR off by default (can be flaky on Spaces)
share=False,
show_error=True,
)
if __name__ == "__main__":
main()
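# Local run sketch (assumes the dependencies imported above are installed; GPU optional):
#   SECRET_TOKEN=my-token python app.py
# then open http://localhost:7860 and paste the same token into the "Secret Token" field.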