# -------------------------------
# AI Fast Image Server (Production)
# -------------------------------
from __future__ import annotations

import os
import sys
import logging
import subprocess
from typing import Optional

# ---------- Early, safe env defaults ----------
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # faster model downloads
os.environ.setdefault("DEEPSPEED_DISABLE_NVML", "1")     # silence NVML in headless envs
os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")

# ---------- Logging ----------
logging.basicConfig(
    level=os.environ.get("LOG_LEVEL", "INFO").upper(),
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    stream=sys.stdout,
)
log = logging.getLogger("ai-fast-image-server")
# ---------- Config via ENV ----------
# MODEL_BACKEND: sdxl_lcm_unet (heavy), sdxl_lcm_lora (light), ssd1b_lcm_lora (light)
MODEL_BACKEND = os.getenv("MODEL_BACKEND", "sdxl_lcm_lora").lower()
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
DEFAULT_SIZE = int(os.getenv("DEFAULT_SIZE", "1024"))
SECRET_TOKEN = os.getenv("SECRET_TOKEN", "default_secret")
PORT = int(os.getenv("PORT", "7860"))
CONCURRENCY = int(os.getenv("CONCURRENCY", "2"))
QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", "32"))
ENABLE_SSR = os.getenv("ENABLE_SSR", "false").lower() == "true"  # SSR can be flaky; default off
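# Example configuration (illustrative values only, not the only supported setup):
#   export MODEL_BACKEND=sdxl_lcm_lora
#   export SECRET_TOKEN=change-me
#   export MAX_IMAGE_SIZE=1024 DEFAULT_SIZE=1024
#   export CONCURRENCY=2 QUEUE_SIZE=32 PORT=7860 WARMUP=true ENABLE_SSR=false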
# ---------- Imports that require deps ----------
import warnings

warnings.filterwarnings("ignore", message="Can't initialize NVML")

import numpy as np
import torch
from PIL import Image
import gradio as gr
from diffusers import (
    DiffusionPipeline,
    UNet2DConditionModel,
    LCMScheduler,
    AutoPipelineForText2Image,
)

# ---------- Version guard: Torch 2.1 + NumPy 2.x is incompatible ----------
try:
    _np_major = int(np.__version__.split(".")[0])
    if torch.__version__.startswith("2.1") and _np_major >= 2:
        raise RuntimeError(
            f"Incompatible versions: torch=={torch.__version__} with numpy=={np.__version__}. "
            "Pin numpy==1.26.4 or upgrade torch to >=2.3."
        )
except Exception as e:
    log.error(str(e))
    raise
# ---------- Paths ----------
CURRENT_DIR = os.getcwd()
CACHE_DIR = os.path.join(CURRENT_DIR, "cache")
os.makedirs(CACHE_DIR, exist_ok=True)

# ---------- GPU info (logs only) ----------
def print_nvidia_smi() -> None:
    try:
        proc = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False)
        if proc.returncode == 0 and proc.stdout.strip():
            log.info("\n" + proc.stdout.strip())
        else:
            msg = proc.stderr.strip() if proc.stderr else "nvidia-smi not available or returned no output."
            log.info(msg)
    except FileNotFoundError:
        log.info("nvidia-smi not found on PATH.")

print_nvidia_smi()

IS_GPU = torch.cuda.is_available()
DEVICE = torch.device("cuda") if IS_GPU else torch.device("cpu")
DTYPE = torch.float16 if IS_GPU else torch.float32
log.info(f"CUDA available: {IS_GPU} | device={DEVICE} | dtype={DTYPE}")

# ---------- Torch perf knobs ----------
try:
    if IS_GPU:
        torch.backends.cuda.matmul.allow_tf32 = True  # safe perf on Ampere+
        torch.set_float32_matmul_precision("high")
except Exception:
    pass
# ---------- Helpers ----------
def _variant_kwargs() -> dict:
    # Use fp16 repo variants only on GPU.
    return {"variant": "fp16"} if IS_GPU else {}


def _cpu_safety_settings(pipe: DiffusionPipeline) -> None:
    # Reduce RAM usage and avoid giant VAE allocations on CPU.
    try:
        pipe.enable_vae_tiling()
    except Exception:
        pass


def _gpu_memory_efficiency(pipe: DiffusionPipeline) -> None:
    # Enable memory-efficient attention when available, falling back to attention slicing.
    enabled = False
    try:
        pipe.enable_xformers_memory_efficient_attention()
        enabled = True
    except Exception:
        try:
            pipe.enable_attention_slicing("max")
            enabled = True
        except Exception:
            pass
    if enabled:
        try:
            pipe.enable_vae_tiling()
        except Exception:
            pass
# ---------- Model loading ----------
pipe: Optional[DiffusionPipeline] = None


def load_pipeline() -> DiffusionPipeline:
    """
    Load the selected backend with sensible defaults.
    - sdxl_lcm_unet: SDXL base + full LCM UNet (heavy, high VRAM)
    - sdxl_lcm_lora: SDXL base + LCM-LoRA (light, recommended)
    - ssd1b_lcm_lora: SSD-1B + LCM-LoRA (light)
    """
    log.info(f"Loading model backend: {MODEL_BACKEND}")

    if MODEL_BACKEND == "sdxl_lcm_unet":
        # Heavy: downloads ~10 GB UNet; best quality/speed on big GPUs
        unet = UNet2DConditionModel.from_pretrained(
            "latent-consistency/lcm-sdxl",
            torch_dtype=DTYPE,
            cache_dir=CACHE_DIR,
            **_variant_kwargs(),
        )
        _pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            unet=unet,
            torch_dtype=DTYPE,
            cache_dir=CACHE_DIR,
            **_variant_kwargs(),
        )
    elif MODEL_BACKEND == "ssd1b_lcm_lora":
        _pipe = AutoPipelineForText2Image.from_pretrained(
            "segmind/SSD-1B",
            torch_dtype=DTYPE,
            cache_dir=CACHE_DIR,
            **_variant_kwargs(),
        )
        _pipe.load_lora_weights("latent-consistency/lcm-lora-ssd-1b")
        _pipe.fuse_lora()
    else:
        # Default & recommended: SDXL + LCM-LoRA (smaller downloads, good quality)
        _pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=DTYPE,
            cache_dir=CACHE_DIR,
            **_variant_kwargs(),
        )
        _pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
        _pipe.fuse_lora()

    # Use the LCM scheduler
    _pipe.scheduler = LCMScheduler.from_config(_pipe.scheduler.config)

    # Device & memory efficiency
    _pipe.to(DEVICE)
    if IS_GPU:
        _gpu_memory_efficiency(_pipe)
    else:
        _cpu_safety_settings(_pipe)

    log.info("Pipeline loaded.")
    return _pipe


# Create the pipeline lazily (on first use)
def ensure_pipe() -> DiffusionPipeline:
    global pipe
    if pipe is None:
        pipe = load_pipeline()
    return pipe
# ---------- HF Spaces GPU decorator (fixes “No @spaces.GPU function detected”) ----------
try:
    import spaces  # type: ignore

    GPU_DECORATOR = spaces.GPU
    log.info("`spaces` package detected. GPU-decorating inference function.")
except Exception:
    GPU_DECORATOR = lambda f: f  # no-op outside HF Spaces
# ---------- Inference ----------
# Decorate the heavy path so HF Spaces (ZeroGPU) detects a GPU entrypoint.
@GPU_DECORATOR
def generate_image_internal(
    prompt: str,
    negative_prompt: str = "",
    seed: Optional[int] = 0,
    width: int = DEFAULT_SIZE,
    height: int = DEFAULT_SIZE,
    guidance_scale: float = 0.0,
    num_inference_steps: int = 4,
) -> Image.Image:
    _pipe = ensure_pipe()

    # Clamp to safe bounds
    width = int(np.clip(width, 256, MAX_IMAGE_SIZE))
    height = int(np.clip(height, 256, MAX_IMAGE_SIZE))
    num_inference_steps = int(np.clip(num_inference_steps, 1, 12))
    guidance_scale = float(np.clip(guidance_scale, 0.0, 2.0))

    # Deterministic generator
    generator = torch.Generator(device=DEVICE)
    if seed is not None:
        generator = generator.manual_seed(int(seed))

    result = _pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,  # LCM prefers low/no guidance
        num_inference_steps=num_inference_steps,
        generator=generator,
        output_type="pil",
    )
    return result.images[0]
# Thin wrapper that enforces the token (kept out of the GPU-decorated function)
def generate(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    width: int = DEFAULT_SIZE,
    height: int = DEFAULT_SIZE,
    guidance_scale: float = 0.0,
    num_inference_steps: int = 4,
    secret_token: str = "",
) -> Image.Image:
    if secret_token != SECRET_TOKEN:
        raise gr.Error("Invalid secret token. Set SECRET_TOKEN or pass the correct token.")
    return generate_image_internal(
        prompt=prompt,
        negative_prompt=negative_prompt,
        seed=seed,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )
# ---------- Optional warmup at startup ----------
def warmup():
    try:
        ensure_pipe()
        _ = generate_image_internal(
            prompt="A quick warmup prompt, minimal style", seed=42, width=512, height=512, num_inference_steps=2
        )
        log.info("Warmup complete.")
    except Exception as e:
        log.warning(f"Warmup skipped or failed: {e}")


if os.getenv("WARMUP", "true").lower() == "true":
    # Don't block too long on CPU
    if IS_GPU:
        warmup()
# ---------- Gradio UI (v5) ----------
def build_ui() -> gr.Blocks:
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("## Image Generator (LCM) — SDXL / SSD-1B")
        with gr.Row():
            prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Describe the image...")
            negative = gr.Textbox(label="Negative Prompt", lines=2, placeholder="(optional)")
        with gr.Row():
            seed = gr.Number(label="Seed", value=0, precision=0)
            width = gr.Slider(256, MAX_IMAGE_SIZE, value=DEFAULT_SIZE, step=32, label="Width")
            height = gr.Slider(256, MAX_IMAGE_SIZE, value=DEFAULT_SIZE, step=32, label="Height")
        with gr.Row():
            guidance = gr.Slider(0.0, 2.0, value=0.0, step=0.1, label="Guidance scale")
            steps = gr.Slider(1, 12, value=4, step=1, label="Inference steps")
        token = gr.Textbox(label="Secret Token", type="password", lines=1)
        out = gr.Image(label="Result", height=DEFAULT_SIZE, width=DEFAULT_SIZE)
        run = gr.Button("Generate", variant="primary")

        inputs = [prompt, negative, seed, width, height, guidance, steps, token]
        run.click(fn=generate, inputs=inputs, outputs=out, concurrency_limit=CONCURRENCY)

        # Simple health info
        gr.Markdown(
            f"*Backend:* `{MODEL_BACKEND}` | "
            f"*Device:* `{DEVICE}` | "
            f"*dtype:* `{DTYPE}`"
        )
    return demo
# ---------- Launch ----------
def main():
    demo = build_ui()
    # Queue for backpressure and concurrency control.
    # Gradio 4+/5 removed `concurrency_count`; use `default_concurrency_limit` instead.
    demo.queue(max_size=QUEUE_SIZE, default_concurrency_limit=CONCURRENCY)
    demo.launch(
        server_name="0.0.0.0",
        server_port=PORT,
        show_api=True,
        ssr_mode=ENABLE_SSR,  # SSR off by default (can be flaky on Spaces)
        share=False,
        show_error=True,
    )


if __name__ == "__main__":
    main()
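# Example client call, as a rough sketch. It assumes Gradio's default API naming,
# where the endpoint name is derived from the event function name (here `generate`,
# i.e. api_name="/generate"); check the Space's "Use via API" page for the exact name.
# gradio_client downloads image outputs and returns a local file path.
#
#   from gradio_client import Client
#
#   client = Client("http://localhost:7860")
#   image_path = client.predict(
#       "a cozy cabin in the woods, golden hour",  # prompt
#       "",                 # negative_prompt
#       42,                 # seed
#       1024,               # width
#       1024,               # height
#       0.0,                # guidance_scale (LCM works best near 0)
#       4,                  # num_inference_steps
#       "default_secret",   # secret_token (must match SECRET_TOKEN)
#       api_name="/generate",
#   )
#   print(image_path)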