# Kandinsky / app.py
# Author: rahul7star · commit d09a395 ("Update app.py", verified)
import spaces
import os
import gradio as gr
import torch
import subprocess
import importlib, site
import warnings
import logging
from huggingface_hub import hf_hub_download
from kandinsky import get_T2V_pipeline
from PIL import Image
# ============================================================
# 1️⃣ FlashAttention setup
# ============================================================
# Best-effort step: download a prebuilt FlashAttention-3 wheel from the Hub and
# install it into the running interpreter. Any failure is reported and the app
# continues without FlashAttention.
import sys

try:
    print("Attempting to download and install FlashAttention wheel...")
    flash_attention_wheel = hf_hub_download(
        repo_id="rahul7star/flash-attn-3",
        repo_type="model",
        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
    )
    # Invoke pip through *this* interpreter ("python -m pip") instead of the
    # first "pip" on PATH, so the wheel is installed into the same environment
    # whose site-packages is added to sys.path just below.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", flash_attention_wheel],
        check=True,
    )
    # Make the freshly installed package importable without a restart.
    site.addsitedir(site.getsitepackages()[0])
    importlib.invalidate_caches()
    print("✅ FlashAttention installed successfully.")
except Exception as e:
    # Deliberately non-fatal: the app can run without FlashAttention.
    print(f"⚠️ Could not install FlashAttention: {e}")
    print("Continuing without FlashAttention...")
# ============================================================
# 2️⃣ Torch + logging config
# ============================================================
warnings.filterwarnings("ignore")
logging.getLogger("torch").setLevel(logging.ERROR)
torch._logging.set_logs(
dynamo=logging.ERROR,
dynamic=logging.ERROR,
aot=logging.ERROR,
inductor=logging.ERROR,
guards=False,
recompiles=False,
)
# ============================================================
# 3️⃣ Ensure models are downloaded
# ============================================================
# download_models.py runs once. A marker file records the successful run so
# later restarts skip the download.
import sys

_MODELS_MARKER = "./models_downloaded.marker"

if not os.path.exists(_MODELS_MARKER):
    print("📦 Models not found. Running download_models.py...")
    # Run the script with the current interpreter, not whichever "python" is on PATH.
    subprocess.run([sys.executable, "download_models.py"], check=True)
    # Only reached when the script exited with status 0 (check=True raises
    # otherwise), so a failed download never leaves a marker behind.
    with open(_MODELS_MARKER, "w", encoding="utf-8") as f:
        f.write("done")
    print("✅ Models downloaded successfully.")
else:
    print("✅ Models already downloaded (marker found).")
# ============================================================
# 4️⃣ Load pipeline to CUDA (like Wan example)
# ============================================================
# Build the Kandinsky 5.0 T2V pipeline once at import time, with every
# component on cuda:0. On failure `pipe` is set to None so the UI still starts
# and each request reports that the pipeline is not initialised.
import traceback

print("🔧 Loading Kandinsky 5.0 T2V pipeline to CUDA...")
try:
    pipe = get_T2V_pipeline(
        device_map={
            "dit": "cuda:0",
            "vae": "cuda:0",
            "text_embedder": "cuda:0",
        },
        conf_path="./configs/config_5s_sft.yaml",
    )
    # Explicitly move all components to CUDA, when the pipeline object supports .to().
    if hasattr(pipe, "to"):
        pipe.to("cuda")
    print("✅ Pipeline successfully loaded and moved to CUDA.")
except Exception as e:
    print(f"❌ Pipeline load failed: {e}")
    # Keep the full stack trace in the logs; the message alone is rarely enough to debug.
    traceback.print_exc()
    pipe = None
# ============================================================
# 5️⃣ Generation function
# ============================================================
def _safe_output_path(prompt, extension, max_stem_length=48):
    """Return a unique, filesystem-safe path under /tmp for a generation result.

    The prompt is user-controlled, so it is reduced to word characters and
    ``-``. This removes path separators such as ``/`` and therefore ``..``
    traversal. The result is truncated so the filename stays well below the
    common 255-byte limit, even for multi-byte UTF-8 text. A random suffix keeps
    concurrent requests with the same prompt from overwriting each other's file.

    Args:
        prompt: Raw prompt text from the UI.
        extension: File extension without the dot (``"mp4"`` or ``"png"``).
        max_stem_length: Maximum number of prompt-derived characters to keep.

    Returns:
        A path of the form ``/tmp/<stem>_<8 hex chars>.<extension>``.
    """
    import re
    import uuid

    stem = re.sub(r"[^\w-]+", "_", prompt.strip()).strip("_")[:max_stem_length]
    return f"/tmp/{stem or 'output'}_{uuid.uuid4().hex[:8]}.{extension}"


@spaces.GPU(duration=40)
def generate_output(prompt, mode, duration, width, height, steps, guidance, scheduler):
    """Run the global Kandinsky pipeline for one UI request and save the result.

    Args:
        prompt: Text prompt.
        mode: ``"video"`` or ``"image"``; any other value returns an error pair.
        duration: Video length in seconds. Unused in image mode, where
            ``time_length=0`` is passed instead.
        width: Output width in pixels.
        height: Output height in pixels.
        steps: Sampling steps (video mode only).
        guidance: Guidance weight (video mode only).
        scheduler: Scheduler scale (video mode only). Image mode does not pass
            steps, guidance or scheduler, so the pipeline's own defaults apply.

    Returns:
        ``(path, status)``. ``path`` is the saved ``.mp4``/``.png`` file, or
        ``None`` on failure. ``status`` is a Markdown message for the UI.
    """
    print(f"📥 Generation request ({mode}): {prompt!r}")
    if pipe is None:
        return None, "❌ Pipeline not initialized."
    if not prompt or not prompt.strip():
        return None, "⚠️ Please enter a prompt."
    if mode not in ("image", "video"):
        # Without this guard the function would fall through and return a bare
        # None to a handler that has two outputs.
        return None, f"❌ Unknown mode: {mode!r}"

    output_path = _safe_output_path(prompt, "mp4" if mode == "video" else "png")
    # UI sliders/radios may deliver floats; the pipeline uses these as sizes/counts.
    width, height = int(width), int(height)
    try:
        if mode == "image":
            print(f"🖼️ Generating image: {prompt}")
            pipe(
                prompt,
                time_length=0,
                width=width,
                height=height,
                save_path=output_path,
            )
            return output_path, f"✅ Image saved: {output_path}"

        print(f"🎬 Generating {duration}s video: {prompt}")
        pipe(
            prompt,
            time_length=int(duration),
            width=width,
            height=height,
            num_steps=int(steps),
            guidance_weight=guidance,
            scheduler_scale=scheduler,
            save_path=output_path,
        )
        return output_path, f"✅ Video saved: {output_path}"
    except torch.cuda.OutOfMemoryError:
        # Release cached blocks so the next (smaller) request has a chance to fit.
        torch.cuda.empty_cache()
        return None, "⚠️ CUDA OOM — try smaller size or shorter duration."
    except Exception as e:
        return None, f"❌ Error during generation: {e}"
# ============================================================
# 6️⃣ Gradio UI
# ============================================================
def _route_output(prompt, mode, duration, width, height, steps, guidance, scheduler):
    """Call generate_output and send its file to the matching display component.

    generate_output returns a single path for both modes. PNG results go to
    the image component and everything else goes to the video component, so a
    still image is never pushed through gr.Video.

    Returns:
        ``(video_path_or_None, image_path_or_None, status_markdown)``.
    """
    path, message = generate_output(
        prompt, mode, duration, width, height, steps, guidance, scheduler
    )
    if path is not None and path.lower().endswith(".png"):
        return None, path, message
    return path, None, message


with gr.Blocks(theme=gr.themes.Soft(), title="Kandinsky 5.0 T2V Lite (CUDA)") as demo:
    gr.Markdown("## 🎞️ Kandinsky 5.0 — Text & Image to Video Generator")
    with gr.Row():
        with gr.Column(scale=2):
            mode = gr.Radio(["video", "image"], value="video", label="Mode")
            prompt = gr.Textbox(label="Prompt", value="A dog in red boots")
            duration = gr.Slider(1, 10, step=1, value=5, label="Video Duration (seconds)")
            width = gr.Radio([512, 768], value=768, label="Width (px)")
            height = gr.Radio([512, 768], value=512, label="Height (px)")
            steps = gr.Slider(10, 50, step=5, value=25, label="Sampling Steps")
            guidance = gr.Slider(0.0, 10.0, step=0.5, value=1.0, label="Guidance Weight")
            scheduler = gr.Slider(1.0, 10.0, step=0.5, value=5.0, label="Scheduler Scale")
            btn = gr.Button("🚀 Generate", variant="primary")
        with gr.Column(scale=3):
            output_display = gr.Video(label="Generated Video")
            image_display = gr.Image(label="Generated Image")
            status = gr.Markdown()
    btn.click(
        fn=_route_output,
        inputs=[prompt, mode, duration, width, height, steps, guidance, scheduler],
        outputs=[output_display, image_display, status],
    )
# ============================================================
# 7️⃣ Launch app normally (no GPU decorator)
# ============================================================
# Entry point when run as a script: serve the Gradio app on every network
# interface at port 7860.
if __name__ == "__main__":
    host, port = "0.0.0.0", 7860
    demo.launch(server_name=host, server_port=port)