# app.py
import os
import sys
import subprocess
import importlib
import site
import warnings
import logging
import time
import random
from pathlib import Path

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
import spaces

# ---------------------------
# Environment flags (reduce fusion/compilation) - set early
# ---------------------------
# These help avoid some torchinductor/flash-attn fusion issues that provoke guard errors.
os.environ.setdefault("TORCHINDUCTOR_DISABLE", "1")
os.environ.setdefault("TORCHINDUCTOR_FUSION", "0")
os.environ.setdefault("USE_FLASH_ATTENTION", "0")
# Some environments check this; safe to set
os.environ.setdefault("XLA_IGNORE_ENV_VARS", "1")


# ---------------------------
# FlashAttention install (best-effort)
# ---------------------------
def try_install_flash_attention():
    try:
        print("Attempting to download and install FlashAttention wheel...")
        wheel = hf_hub_download(
            repo_id="rahul7star/flash-attn-3",
            repo_type="model",
            filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
        )
        subprocess.run([sys.executable, "-m", "pip", "install", wheel], check=True)
        # Refresh site-packages so the freshly installed wheel is importable
        site.addsitedir(site.getsitepackages()[0])
        importlib.invalidate_caches()
        print("✅ FlashAttention installed.")
        return True
    except Exception as e:
        print(f"⚠️ FlashAttention install failed: {e}")
        return False


# ---------------------------
# Torch logging / warnings
# ---------------------------
warnings.filterwarnings("ignore")
logging.getLogger("torch").setLevel(logging.ERROR)

# Reduce torch verbose logging
try:
    torch._logging.set_logs(
        dynamo=logging.ERROR,
        dynamic=logging.ERROR,
        aot=logging.ERROR,
        inductor=logging.ERROR,
        guards=False,
        recompiles=False,
    )
except Exception:
    pass

# Make Dynamo tolerant initially (we disable it entirely if loading fails)
try:
    import torch._dynamo as _dynamo

    _dynamo.config.suppress_errors = True
    _dynamo.config.cache_size_limit = 0  # avoid large guard caches
except Exception:
    _dynamo = None


# ---------------------------
# Download models if needed
# ---------------------------
def ensure_models_downloaded(marker_file=".models_ready"):
    marker = Path(marker_file)
    if marker.exists():
        print("Models already downloaded (marker found).")
        return True
    if not Path("download_models.py").exists():
        print("download_models.py not found in repo.")
        return False
    try:
        print("Running download_models.py ...")
        subprocess.run([sys.executable, "download_models.py"], check=True)
        marker.write_text("ok")
        print("Models download finished.")
        return True
    except Exception as e:
        print("Model download failed:", e)
        return False


# ---------------------------
# Load Kandinsky pipeline with smart Dynamo handling
# ---------------------------
def load_pipeline(conf_path="./configs/config_5s_sft.yaml", move_to_cuda_if_available=True):
    """
    Attempt to load the pipeline normally. If Dynamo/guard errors are raised,
    disable torch._dynamo and reload in eager mode.

    Returns the pipeline or raises.
""" from kandinsky import get_T2V_pipeline # import inside function to respect env changes def _do_load(): print("Loading pipeline with device_map pointing to cuda if available...") device_map = None if torch.cuda.is_available(): # let the pipeline place modules onto CUDA by device_map device_map = {"dit": "cuda:0", "vae": "cuda:0", "text_embedder": "cuda:0"} else: device_map = "cpu" pipe = get_T2V_pipeline(device_map=device_map, conf_path=conf_path, offload=False, magcache=False) # If pipeline has .to and CUDA is available, move it if move_to_cuda_if_available and torch.cuda.is_available() and hasattr(pipe, "to"): try: pipe.to("cuda") except Exception as e: # fallback: ignore and continue (some pipelines handle own device_map) print("Warning while moving pipeline to CUDA:", e) return pipe try: # Try normal load first (Dynamo may be enabled but we've suppressed errors) pipe = _do_load() print("Pipeline loaded successfully (initial try).") return pipe except Exception as e: # Detect Dynamo/guard-related signatures and fallback msg = str(e).lower() if "dynamo" in msg or "guard" in msg or "attributeerror" in msg or "caught" in msg: print("⚠️ Dynamo/guard-related error detected while loading pipeline:", e) # Disable torch dynamo and try again try: if _dynamo is not None: print("Disabling torch._dynamo and retrying load in eager mode...") _dynamo.disable() else: print("torch._dynamo not available; proceeding to retry load.") except Exception as ex_disable: print("Error disabling torch._dynamo:", ex_disable) # Retry load try: pipe = _do_load() print("Pipeline loaded successfully after disabling torch._dynamo.") return pipe except Exception as e2: print("Failed to load pipeline even after disabling torch._dynamo:", e2) raise else: # Not obviously a Dynamo issue — re-raise raise # --------------------------- # Startup sequence # --------------------------- print("=== startup: installing optional FlashAttention (best-effort) ===") try_install_flash_attention() print("=== startup: ensuring models ===") if not ensure_models_downloaded(): print("Models not available; app may fail at inference. Proceeding anyway.") print("=== startup: loading pipeline (smart) ===") pipe = None try: pipe = load_pipeline(conf_path="./configs/config_5s_sft.yaml", move_to_cuda_if_available=True) except Exception as e: print("Pipeline load ultimately failed:", e) pipe = None # --------------------------- # Helper: ensure pipeline is on CUDA at generation time # --------------------------- def ensure_pipe_on_cuda(pipeline): if pipeline is None: raise RuntimeError("Pipeline is None") # If CUDA not available, raise early if not torch.cuda.is_available(): raise RuntimeError("CUDA not available on this machine") # If pipeline supports .to, move it if hasattr(pipeline, "to"): try: pipeline.to("cuda") except Exception as e: # Some pipelines use device_map placement — ignore move failure print("Warning: pipeline.to('cuda') raised:", e) # --------------------------- # Generation function (runs on GPU when used) # --------------------------- @spaces.GPU(duration=60) def generate_output(prompt, mode, duration, width, height, steps, guidance, scheduler): """ This generation function assumes the pipeline is already loaded (pipe variable). It will raise a helpful error if the pipeline wasn't loaded at startup. """ print(prompt) if pipe is None: return None, "❌ Pipeline not initialized at startup. Check logs." # Ensure CUDA available and pipeline on CUDA if not torch.cuda.is_available(): return None, "❌ CUDA not available on this host." 
    try:
        # If Dynamo is still enabled and might cause trouble during forward,
        # disable it before inference to be safe.
        try:
            if _dynamo is not None:
                _dynamo.disable()
        except Exception:
            pass

        out_name = f"/tmp/{int(time.time())}_{random.randint(100, 999)}.{'mp4' if mode == 'video' else 'png'}"

        if mode == "image":
            pipe(prompt, time_length=0, width=width, height=height, save_path=out_name)
            return out_name, f"✅ Image saved to {out_name}"

        # Video path
        pipe(
            prompt,
            time_length=duration,
            width=width,
            height=height,
            num_steps=steps if steps else None,
            guidance_weight=guidance if guidance else None,
            scheduler_scale=scheduler if scheduler else None,
            save_path=out_name,
        )
        return out_name, f"✅ Video saved to {out_name}"

    except torch.cuda.OutOfMemoryError:
        return None, "⚠️ CUDA OOM - try reducing resolution/duration/steps."
    except Exception as e:
        return None, f"❌ Generation error: {e}"


# ---------------------------
# Gradio UI
# ---------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="Kandinsky 5.0 T2V (robust load)") as demo:
    gr.Markdown("## Kandinsky 5.0 - Robust pipeline loader (smart Dynamo fallback)")
    with gr.Row():
        with gr.Column(scale=2):
            mode = gr.Radio(["video", "image"], value="video", label="Mode")
            prompt = gr.Textbox(label="Prompt", value="A dog in red boots")
            duration = gr.Slider(1, 10, step=1, value=2, label="Duration (s)")
            width = gr.Radio([512, 768], value=768, label="Width")
            height = gr.Radio([512, 768], value=512, label="Height")
            steps = gr.Slider(4, 50, step=1, value=10, label="Sampling Steps")
            guidance = gr.Slider(0.0, 20.0, step=0.5, value=8.0, label="Guidance Weight")
            scheduler = gr.Slider(1.0, 10.0, step=0.5, value=5.0, label="Scheduler Scale")
            btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=3):
            out_video = gr.Video(label="Output")
            status = gr.Textbox(label="Status", lines=6)

    btn.click(
        fn=generate_output,
        inputs=[prompt, mode, duration, width, height, steps, guidance, scheduler],
        outputs=[out_video, status],
    )

# ---------------------------
# Launch
# ---------------------------
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))