# -------------------------------
# AI Fast Image Server — ZeroGPU Ready (Gradio 5)
# -------------------------------

from __future__ import annotations
import os
import sys
import logging
import subprocess
from typing import Optional

# ---------- Fast, safe defaults ----------
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")   # faster model downloads
os.environ.setdefault("DEEPSPEED_DISABLE_NVML", "1")      # silence NVML in headless envs
os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")

# ---------- Logging ----------
logging.basicConfig(
    level=os.environ.get("LOG_LEVEL", "INFO").upper(),
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    stream=sys.stdout,
)
log = logging.getLogger("ai-fast-image-server")

# ---------- Config via ENV ----------
# MODEL_BACKEND: "sdxl_lcm_lora" (default), "sdxl_lcm_unet" (heavy), "ssd1b_lcm_lora" (light)
MODEL_BACKEND = os.getenv("MODEL_BACKEND", "sdxl_lcm_lora").lower()
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
DEFAULT_SIZE = int(os.getenv("DEFAULT_SIZE", "1024"))
SECRET_TOKEN = os.getenv("SECRET_TOKEN", "default_secret")
PORT = int(os.getenv("PORT", "7860"))
CONCURRENCY = int(os.getenv("CONCURRENCY", "2"))
QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", "32"))
ENABLE_SSR = os.getenv("ENABLE_SSR", "false").lower() == "true"  # SSR off by default for stability
WARMUP = os.getenv("WARMUP", "false").lower() == "true"          # default False for ZeroGPU
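# Example launch (a sketch; assumes this file is saved as app.py and all variables above are optional):
#   MODEL_BACKEND=ssd1b_lcm_lora SECRET_TOKEN=mytoken WARMUP=true python app.py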

# ============================================================
# Import `spaces` BEFORE any CUDA-related libs (torch/diffusers)
# ============================================================
try:
    import spaces  # real decorator on HF Spaces
except ImportError:
    # Local/dev fallback: no-op decorator so app still runs without ZeroGPU
    class _DummySpaces:
        def GPU(self, *args, **kwargs):
            # Support both @spaces.GPU and @spaces.GPU(duration=...) usage
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]
            def _wrap(f):
                return f
            return _wrap
    spaces = _DummySpaces()

# ---------- Third-party imports (safe to import after `spaces`) ----------
import warnings
warnings.filterwarnings("ignore", message="Can't initialize NVML")

import numpy as np
import torch
from PIL import Image
import gradio as gr
from diffusers import (
    DiffusionPipeline,
    UNet2DConditionModel,
    LCMScheduler,
    AutoPipelineForText2Image,
)

# ---------- Version guard: Torch 2.1 + NumPy 2.x is incompatible ----------
try:
    _np_major = int(np.__version__.split(".")[0])
    if torch.__version__.startswith("2.1") and _np_major >= 2:
        raise RuntimeError(
            f"Incompatible versions: torch=={torch.__version__} with numpy=={np.__version__}. "
            "Pin numpy==1.26.4 or upgrade torch to >=2.3."
        )
except Exception as e:
    log.error(str(e))
    raise

# ---------- Paths ----------
CURRENT_DIR = os.getcwd()
CACHE_DIR = os.path.join(CURRENT_DIR, "cache")
os.makedirs(CACHE_DIR, exist_ok=True)

# ---------- GPU info (logs only) ----------
def print_nvidia_smi() -> None:
    try:
        proc = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False)
        if proc.returncode == 0 and proc.stdout.strip():
            log.info("\n" + proc.stdout.strip())
        else:
            msg = proc.stderr.strip() if proc.stderr else "nvidia-smi not available or returned no output."
            log.info(msg)
    except FileNotFoundError:
        log.info("nvidia-smi not found on PATH.")

print_nvidia_smi()

# ---------- Global pipeline handle (kept on CPU between calls) ----------
pipe: Optional[DiffusionPipeline] = None

def _gpu_mem_efficiency(p: DiffusionPipeline) -> None:
    """Enable memory-efficient attention and VAE tiling where possible."""
    enabled = False
    try:
        p.enable_xformers_memory_efficient_attention()
        enabled = True
    except Exception:
        try:
            p.enable_attention_slicing("max")
            enabled = True
        except Exception:
            pass
    try:
        p.enable_vae_tiling()
    except Exception:
        pass
    if enabled:
        try:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.set_float32_matmul_precision("high")
        except Exception:
            pass

def _build_pipeline_cpu() -> DiffusionPipeline:
    """
    Build the pipeline on CPU with float32 to keep it stable in ZeroGPU's
    CPU-only startup environment. We'll move it to CUDA inside the GPU-decorated
    function per call and return it to CPU after.
    """
    log.info(f"Building pipeline for model backend: {MODEL_BACKEND}")
    if MODEL_BACKEND == "sdxl_lcm_unet":
        # SDXL base with LCM UNet (no LoRA required)
        unet = UNet2DConditionModel.from_pretrained(
            "latent-consistency/lcm-sdxl",
            torch_dtype=torch.float32,
            cache_dir=CACHE_DIR,
        )
        _p = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            unet=unet,
            torch_dtype=torch.float32,
            cache_dir=CACHE_DIR,
        )
    elif MODEL_BACKEND == "ssd1b_lcm_lora":
        # SSD-1B with LCM-LoRA (Diffusers backend; no PEFT required)
        _p = AutoPipelineForText2Image.from_pretrained(
            "segmind/SSD-1B",
            torch_dtype=torch.float32,
            cache_dir=CACHE_DIR,
        )
        _p.load_lora_weights(
            "latent-consistency/lcm-lora-ssd-1b",
            adapter_name="lcm",
            use_peft_backend=False,  # <-- avoid PEFT requirement
        )
        _p.fuse_lora()
    else:
        # Default: SDXL + LCM-LoRA (smaller download, great speed/quality)
        _p = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float32,
            cache_dir=CACHE_DIR,
        )
        _p.load_lora_weights(
            "latent-consistency/lcm-lora-sdxl",
            adapter_name="lcm",
            use_peft_backend=False,  # <-- avoid PEFT requirement
        )
        _p.fuse_lora()

    _p.scheduler = LCMScheduler.from_config(_p.scheduler.config)
    _p.to("cpu", torch.float32)
    try:
        _p.enable_vae_tiling()
    except Exception:
        pass

    log.info("Pipeline built successfully on CPU.")
    return _p

def ensure_pipe() -> DiffusionPipeline:
    """Initializes and returns the global pipeline object."""
    global pipe
    if pipe is None:
        pipe = _build_pipeline_cpu()
    return pipe

# ---------- Cold-start aware duration estimator ----------
GPU_COLD = True  # first GPU invocation will upload weights & warm kernels

def _estimate_duration(prompt: str,
                       negative_prompt: str,
                       seed: int,
                       width: int,
                       height: int,
                       guidance_scale: float,
                       steps: int,
                       secret_token: str) -> int:
    """
    ZeroGPU runtime budget (seconds).
    Includes:
      - model->GPU transfer + warmup (cold start tax)
      - per-step cost scaled by resolution
    """
    # normalize size to 1024x1024 ~= 1.0
    px_scale = (max(256, width) * max(256, height)) / (1024 * 1024)

    # conservative costs (tuned for SDXL+LCM on H200 slice)
    cold_tax = 22.0 if GPU_COLD else 10.0   # seconds
    step_cost = 1.2                         # sec/step at 1024^2
    base = 6.0                              # misc overhead

    est = base + cold_tax + steps * step_cost * max(0.5, px_scale)
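    # Worked example (warm GPU, 1024x1024, 4 steps):
    #   6.0 + 10.0 + 4 * 1.2 * 1.0 = 20.8s, raised to the 45s floor below.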

    # floors: bigger images need a higher minimum
    floor = 45 if px_scale >= 1.0 else (30 if px_scale >= 0.5 else 20)

    return int(min(120, max(floor, est)))

# ---------- Public generate (token gate) ----------
@spaces.GPU(duration=_estimate_duration)  # ZeroGPU uses this to schedule a GPU window
def generate(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    width: int = DEFAULT_SIZE,
    height: int = DEFAULT_SIZE,
    guidance_scale: float = 0.0,
    steps: int = 4,
    secret_token: str = "",
) -> Image.Image:
    # Declare global BEFORE any reference or assignment to GPU_COLD
    global GPU_COLD

    if secret_token != SECRET_TOKEN:
        raise gr.Error("Invalid secret token. Set SECRET_TOKEN or pass the correct token.")

    # For logs: what window we asked ZeroGPU for (based on current cold/warm state)
    try:
        requested = _estimate_duration(prompt, negative_prompt, seed, width, height, guidance_scale, steps, secret_token)
        log.info(f"ZeroGPU duration requested: {requested}s (cold={GPU_COLD}, size={width}x{height}, steps={steps})")
    except Exception:
        pass

    _p = ensure_pipe()  # already built on CPU & cached weights on disk

    # Clamp user inputs for safety
    width = int(np.clip(width, 256, MAX_IMAGE_SIZE))
    height = int(np.clip(height, 256, MAX_IMAGE_SIZE))
    steps = int(np.clip(steps, 1, 12))
    guidance_scale = float(np.clip(guidance_scale, 0.0, 2.0))

    # Try to use CUDA when available (ZeroGPU will make it available inside this call)
    moved_to_cuda = False
    try:
        if torch.cuda.is_available():
            _p.to("cuda", torch.float16)
            _gpu_mem_efficiency(_p)
            moved_to_cuda = True
        else:
            _p.to("cpu", torch.float32)
    except Exception as e:
        log.warning(f"Falling back to CPU: {e}")
        _p.to("cpu", torch.float32)

    # mark that we've done our cold GPU upload for this process
    if moved_to_cuda:
        GPU_COLD = False

    try:
        device = "cuda" if moved_to_cuda else "cpu"
        gen = torch.Generator(device=device)
        if seed is not None:
            gen = gen.manual_seed(int(seed))

        out = _p(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=gen,
            output_type="pil",
        )
        return out.images[0]
    finally:
        # Return model to CPU so the GPU can be released immediately after call
        try:
            _p.to("cpu", torch.float32)
            _p.enable_vae_tiling()
        except Exception:
            pass

# ---------- Optional warmup (CPU only for ZeroGPU) ----------
def warmup():
    """Performs a minimal inference on CPU to warm up the components."""
    try:
        _p = ensure_pipe()
        _ = _p(
            prompt="minimal warmup",
            width=256,
            height=256,
            guidance_scale=0.0,
            num_inference_steps=1,
            generator=torch.Generator(device="cpu").manual_seed(1),
            output_type="pil",
        ).images[0]
        log.info("CPU warmup inference complete.")
    except Exception as e:
        log.warning(f"Warmup skipped or failed: {e}")

# ---------- Gradio UI (v5) ----------
def build_ui() -> gr.Blocks:
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("## Image Generator (LCM) — SDXL / SSD-1B (ZeroGPU Ready)")

        with gr.Row():
            prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Describe the image…")
            negative = gr.Textbox(label="Negative Prompt", lines=2, placeholder="(optional)")

        with gr.Row():
            seed = gr.Number(label="Seed", value=0, precision=0)
            width = gr.Slider(256, MAX_IMAGE_SIZE, value=DEFAULT_SIZE, step=32, label="Width")
            height = gr.Slider(256, MAX_IMAGE_SIZE, value=DEFAULT_SIZE, step=32, label="Height")

        with gr.Row():
            guidance = gr.Slider(0.0, 2.0, value=0.0, step=0.1, label="Guidance scale")
            steps = gr.Slider(1, 12, value=4, step=1, label="Inference steps")
            token = gr.Textbox(label="Secret Token", type="password", lines=1)

        out = gr.Image(label="Result", height=DEFAULT_SIZE, width=DEFAULT_SIZE)
        run = gr.Button("Generate", variant="primary")

        inputs = [prompt, negative, seed, width, height, guidance, steps, token]
        # Per-event concurrency control (Gradio v5)
        run.click(fn=generate, inputs=inputs, outputs=out, concurrency_limit=CONCURRENCY)

        gr.Markdown(
            f"*Backend:* `{MODEL_BACKEND}` &nbsp; | &nbsp; "
            f"*ZeroGPU:* `@spaces.GPU` enabled &nbsp; | &nbsp; "
            f"*Max size:* {MAX_IMAGE_SIZE}px"
        )
    return demo

# ---------- Launch ----------
def main():
    # --- Pre-load the model on startup (downloads happen here, not in GPU window) ---
    log.info("Application starting up. Pre-loading model on CPU...")
    ensure_pipe()
    log.info("Model pre-loaded successfully.")

    # --- Optional: Run a single inference on CPU if WARMUP is enabled ---
    if WARMUP:
        log.info("Warmup enabled. Running a test inference on CPU.")
        warmup()

    # --- Build and launch the Gradio UI ---
    demo = build_ui()
    demo.queue(max_size=QUEUE_SIZE)

    log.info("Starting Gradio server...")
    demo.launch(
        server_name="0.0.0.0",
        server_port=PORT,
        show_api=True,
        ssr_mode=ENABLE_SSR,  # Off by default; enable with ENABLE_SSR=true if needed
        share=False,
        show_error=True,
    )

if __name__ == "__main__":
    main()
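
# Example client call once the server is running (a sketch; assumes Gradio derives the
# API endpoint name "/generate" from the function name and that gradio_client is installed):
#   from gradio_client import Client
#   client = Client("http://localhost:7860/")
#   result = client.predict(
#       "an astronaut riding a horse",  # prompt
#       "",                             # negative_prompt
#       0,                              # seed
#       1024, 1024,                     # width, height
#       0.0,                            # guidance_scale
#       4,                              # steps
#       "default_secret",               # secret_token (must match SECRET_TOKEN)
#       api_name="/generate",
#   )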