# app_torch.py
import os
import sys
import subprocess
import importlib
import site
import warnings
import logging
import time
import random
from pathlib import Path

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
import spaces
# ---------------------------
# Environment flags (reduce fusion/compilation) - set early
# ---------------------------
# These help avoid some torchinductor/flash-attn fusion issues that provoke guard errors.
os.environ.setdefault("TORCHINDUCTOR_DISABLE", "1")
os.environ.setdefault("TORCHINDUCTOR_FUSION", "0")
os.environ.setdefault("USE_FLASH_ATTENTION", "0")
# Some environments check this; safe to set
os.environ.setdefault("XLA_IGNORE_ENV_VARS", "1")
# ---------------------------
# FlashAttention install (best-effort)
# ---------------------------
def try_install_flash_attention():
    try:
        print("Attempting to download and install FlashAttention wheel...")
        wheel = hf_hub_download(
            repo_id="rahul7star/flash-attn-3",
            repo_type="model",
            filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
        )
        subprocess.run([sys.executable, "-m", "pip", "install", wheel], check=True)
        # refresh site-packages
        site.addsitedir(site.getsitepackages()[0])
        importlib.invalidate_caches()
        print("✅ FlashAttention installed.")
        return True
    except Exception as e:
        print(f"⚠️ FlashAttention install failed: {e}")
        return False
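# The install is best-effort: USE_FLASH_ATTENTION is set to "0" above, so the
# pipeline is expected to run without the wheel; a failed install is not fatal.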
# ---------------------------
# Torch logging / warnings
# ---------------------------
warnings.filterwarnings("ignore")
logging.getLogger("torch").setLevel(logging.ERROR)
# reduce torch verbose logging
try:
    torch._logging.set_logs(
        dynamo=logging.ERROR,
        dynamic=logging.ERROR,
        aot=logging.ERROR,
        inductor=logging.ERROR,
        guards=False,
        recompiles=False,
    )
except Exception:
    pass
# Make Dynamo tolerant initially (we'll disable if it fails)
try:
    import torch._dynamo as _dynamo
    _dynamo.config.suppress_errors = True
    _dynamo.config.cache_size_limit = 0  # avoid large guard caches
except Exception:
    _dynamo = None
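# With suppress_errors set, Dynamo falls back to eager execution on compilation
# failures instead of raising; cache_size_limit = 0 keeps guard caches from
# accumulating across recompiles.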
# ---------------------------
# Download models if needed
# ---------------------------
def ensure_models_downloaded(marker_file=".models_ready"):
    marker = Path(marker_file)
    if marker.exists():
        print("Models already downloaded (marker found).")
        return True
    if not Path("download_models.py").exists():
        print("download_models.py not found in repo.")
        return False
    try:
        print("Running download_models.py ...")
        subprocess.run([sys.executable, "download_models.py"], check=True)
        marker.write_text("ok")
        print("Models download finished.")
        return True
    except Exception as e:
        print("Model download failed:", e)
        return False
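# Deleting the .models_ready marker forces download_models.py to run again on the
# next startup; the marker is only written after the script exits successfully.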
# ---------------------------
# Load Kandinsky pipeline with smart Dynamo handling
# ---------------------------
def load_pipeline(conf_path="./configs/config_5s_sft.yaml", move_to_cuda_if_available=True):
    """
    Attempt to load the pipeline normally. If Dynamo/guard errors are raised,
    disable torch._dynamo and reload in eager mode.
    Returns the pipeline or raises.
    """
    from kandinsky import get_T2V_pipeline  # import inside function to respect env changes

    def _do_load():
        print("Loading pipeline with device_map pointing to cuda if available...")
        if torch.cuda.is_available():
            # let the pipeline place modules onto CUDA via device_map
            device_map = {"dit": "cuda:0", "vae": "cuda:0", "text_embedder": "cuda:0"}
        else:
            device_map = "cpu"
        pipe = get_T2V_pipeline(device_map=device_map, conf_path=conf_path, offload=False, magcache=False)
        # If the pipeline has .to and CUDA is available, move it
        if move_to_cuda_if_available and torch.cuda.is_available() and hasattr(pipe, "to"):
            try:
                pipe.to("cuda")
            except Exception as e:
                # fallback: ignore and continue (some pipelines handle their own device_map)
                print("Warning while moving pipeline to CUDA:", e)
        return pipe

    try:
        # Try a normal load first (Dynamo may be enabled, but errors are suppressed)
        pipe = _do_load()
        print("Pipeline loaded successfully (initial try).")
        return pipe
    except Exception as e:
        # Detect Dynamo/guard-related signatures and fall back
        msg = str(e).lower()
        if "dynamo" in msg or "guard" in msg or "attributeerror" in msg or "caught" in msg:
            print("⚠️ Dynamo/guard-related error detected while loading pipeline:", e)
            # Disable torch._dynamo and try again
            try:
                if _dynamo is not None:
                    print("Disabling torch._dynamo and retrying load in eager mode...")
                    # _dynamo.disable() without arguments only returns a decorator,
                    # so flip the config flag and clear caches to force eager mode.
                    _dynamo.config.disable = True
                    _dynamo.reset()
                else:
                    print("torch._dynamo not available; proceeding to retry load.")
            except Exception as ex_disable:
                print("Error disabling torch._dynamo:", ex_disable)
            # Retry load
            try:
                pipe = _do_load()
                print("Pipeline loaded successfully after disabling torch._dynamo.")
                return pipe
            except Exception as e2:
                print("Failed to load pipeline even after disabling torch._dynamo:", e2)
                raise
        else:
            # Not obviously a Dynamo issue - re-raise
            raise
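# Sketch of an alternative low-memory load, assuming get_T2V_pipeline treats the
# offload flag as CPU offloading (not confirmed here):
#     pipe = get_T2V_pipeline(device_map="cpu", conf_path="./configs/config_5s_sft.yaml",
#                             offload=True, magcache=False)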
# ---------------------------
# Startup sequence
# ---------------------------
print("=== startup: installing optional FlashAttention (best-effort) ===")
try_install_flash_attention()
print("=== startup: ensuring models ===")
if not ensure_models_downloaded():
print("Models not available; app may fail at inference. Proceeding anyway.")
print("=== startup: loading pipeline (smart) ===")
pipe = None
try:
pipe = load_pipeline(conf_path="./configs/config_5s_sft.yaml", move_to_cuda_if_available=True)
except Exception as e:
print("Pipeline load ultimately failed:", e)
pipe = None
# ---------------------------
# Helper: ensure pipeline is on CUDA at generation time
# ---------------------------
def ensure_pipe_on_cuda(pipeline):
    if pipeline is None:
        raise RuntimeError("Pipeline is None")
    # If CUDA not available, raise early
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available on this machine")
    # If pipeline supports .to, move it
    if hasattr(pipeline, "to"):
        try:
            pipeline.to("cuda")
        except Exception as e:
            # Some pipelines use device_map placement - ignore move failure
            print("Warning: pipeline.to('cuda') raised:", e)
# ---------------------------
# Generation function (runs on GPU when used)
# ---------------------------
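# Note: @spaces.GPU(duration=60) requests roughly a 60-second GPU window per call
# on ZeroGPU hardware; long videos or many sampling steps may need a larger value.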
@spaces.GPU(duration=60)
def generate_output(prompt, mode, duration, width, height, steps, guidance, scheduler):
    """
    This generation function assumes the pipeline was already loaded at startup
    (the module-level `pipe`). It returns a helpful error message if it was not.
    """
    print(prompt)
    if pipe is None:
        return None, "❌ Pipeline not initialized at startup. Check logs."
    # Ensure CUDA is available and the pipeline is on CUDA
    if not torch.cuda.is_available():
        return None, "❌ CUDA not available on this host."
    try:
        ensure_pipe_on_cuda(pipe)
        # If Dynamo is still enabled and could cause trouble during forward,
        # force eager mode before running inference, to be safe.
        try:
            if _dynamo is not None:
                _dynamo.config.disable = True
                _dynamo.reset()
        except Exception:
            pass
        out_name = f"/tmp/{int(time.time())}_{random.randint(100, 999)}.{'mp4' if mode == 'video' else 'png'}"
        if mode == "image":
            pipe(prompt, time_length=0, width=width, height=height, save_path=out_name)
            return out_name, f"✅ Image saved to {out_name}"
        # video path
        pipe(prompt,
             time_length=duration,
             width=width,
             height=height,
             num_steps=steps if steps else None,
             guidance_weight=guidance if guidance else None,
             scheduler_scale=scheduler if scheduler else None,
             save_path=out_name)
        return out_name, f"✅ Video saved to {out_name}"
    except torch.cuda.OutOfMemoryError:
        return None, "⚠️ CUDA OOM - try reducing resolution/duration/steps."
    except Exception as e:
        return None, f"❌ Generation error: {e}"
# ---------------------------
# Gradio UI
# ---------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="Kandinsky 5.0 T2V (robust load)") as demo:
    gr.Markdown("## Kandinsky 5.0 - Robust pipeline loader (smart Dynamo fallback)")
    with gr.Row():
        with gr.Column(scale=2):
            mode = gr.Radio(["video", "image"], value="video", label="Mode")
            prompt = gr.Textbox(label="Prompt", value="A dog in red boots")
            duration = gr.Slider(1, 10, step=1, value=2, label="Duration (s)")
            width = gr.Radio([512, 768], value=768, label="Width")
            height = gr.Radio([512, 768], value=512, label="Height")
            steps = gr.Slider(4, 50, step=1, value=10, label="Sampling Steps")
            guidance = gr.Slider(0.0, 20.0, step=0.5, value=8.0, label="Guidance Weight")
            scheduler = gr.Slider(1.0, 10.0, step=0.5, value=5.0, label="Scheduler Scale")
            btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=3):
            out_video = gr.Video(label="Output")
            status = gr.Textbox(label="Status", lines=6)

    btn.click(fn=generate_output,
              inputs=[prompt, mode, duration, width, height, steps, guidance, scheduler],
              outputs=[out_video, status])
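# Note: the single gr.Video component receives both output types; in "image" mode
# the saved PNG path may not display there, though the status box still reports it.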
# ---------------------------
# Launch
# ---------------------------
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))