Spaces:
Paused
Paused
File size: 5,856 Bytes
2d9f4b6 191fe1d 72d4212 bdf16fd fef2666 bdf16fd 72d4212 bdf16fd 72d4212 9cd917b 72d4212 bdf16fd 72d4212 bdf16fd 72d4212 bdf16fd 72d4212 bdf16fd 72d4212 9cd917b bdf16fd 72d4212 9cd917b bdf16fd 72d4212 bdf16fd 72d4212 bdf16fd 2d9f4b6 bdf16fd d09a395 eedb16d 72d4212 9cd917b 191fe1d bdf16fd 72d4212 bdf16fd 191fe1d bdf16fd 72d4212 bdf16fd 72d4212 bdf16fd 72d4212 bdf16fd 72d4212 bdf16fd 72d4212 bdf16fd 72d4212 fef2666 bdf16fd 191fe1d bdf16fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import importlib
import logging
import os
import site
import subprocess
import sys
import warnings

import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from kandinsky import get_T2V_pipeline
# ============================================================
# 1. FlashAttention setup (best effort — the app runs without it)
# ============================================================
try:
    print("Attempting to download and install FlashAttention wheel...")
    # Fetch a prebuilt wheel from the Hub; avoids compiling flash-attn on Spaces.
    flash_attention_wheel = hf_hub_download(
        repo_id="rahul7star/flash-attn-3",
        repo_type="model",
        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
    )
    # Install into THIS interpreter's environment (bare "pip" may target another one).
    subprocess.run(
        [sys.executable, "-m", "pip", "install", flash_attention_wheel],
        check=True,
    )
    # Make the freshly installed package importable in the current process.
    site.addsitedir(site.getsitepackages()[0])
    importlib.invalidate_caches()
    print("FlashAttention installed successfully.")
except Exception as e:
    # Non-fatal: inference still works, just without the fast attention kernels.
    print(f"Could not install FlashAttention: {e}")
    print("Continuing without FlashAttention...")

# ============================================================
# 2. Torch + logging config — silence noisy compile/dynamo logs
# ============================================================
warnings.filterwarnings("ignore")
logging.getLogger("torch").setLevel(logging.ERROR)
torch._logging.set_logs(
    dynamo=logging.ERROR,
    dynamic=logging.ERROR,
    aot=logging.ERROR,
    inductor=logging.ERROR,
    guards=False,
    recompiles=False,
)
# ============================================================
# 3. Ensure models are downloaded (marker file prevents re-download)
# ============================================================
if not os.path.exists("./models_downloaded.marker"):
    print("Models not found. Running download_models.py...")
    # Use the current interpreter so the script runs in the same environment.
    subprocess.run([sys.executable, "download_models.py"], check=True)
    # Write the marker only after the download succeeded (check=True raised otherwise).
    with open("./models_downloaded.marker", "w") as f:
        f.write("done")
    print("Models downloaded successfully.")
else:
    print("Models already downloaded (marker found).")
# ============================================================
# 4. Load pipeline to CUDA (like the Wan example)
# ============================================================
print("Loading Kandinsky 5.0 T2V pipeline to CUDA...")
try:
    pipe = get_T2V_pipeline(
        # Pin every component to the first GPU.
        device_map={
            "dit": "cuda:0",
            "vae": "cuda:0",
            "text_embedder": "cuda:0",
        },
        conf_path="./configs/config_5s_sft.yaml",
    )
    # Explicitly move all components to CUDA when the pipeline supports it.
    if hasattr(pipe, "to"):
        pipe.to("cuda")
    print("Pipeline successfully loaded and moved to CUDA.")
except Exception as e:
    # Keep the app alive; generate_output reports the missing pipeline to the user.
    print(f"Pipeline load failed: {e}")
    pipe = None
# ============================================================
# 5. Generation function
# ============================================================
@spaces.GPU(duration=40)
def generate_output(prompt, mode, duration, width, height, steps, guidance, scheduler):
    """Generate an image or a video from *prompt* using the global pipeline.

    Args:
        prompt: Text prompt for generation.
        mode: "image" or "video".
        duration: Video length in seconds (ignored for images).
        width, height: Output resolution in pixels.
        steps: Sampling steps (video mode only).
        guidance: Guidance weight (video mode only).
        scheduler: Scheduler scale (video mode only).

    Returns:
        (output_path, status_message); output_path is None on failure.
    """
    # Original code printed "Pipeline load failed" here for every request — fixed.
    print(f"Generation request: {prompt}")
    if pipe is None:
        return None, "Pipeline not initialized."
    try:
        # Build a filesystem-safe filename: raw prompts may contain '/', quotes, etc.
        safe_name = "".join(c if c.isalnum() else "_" for c in prompt) or "output"
        output_path = f"/tmp/{safe_name}.{'mp4' if mode == 'video' else 'png'}"
        if mode == "image":
            print(f"Generating image: {prompt}")
            pipe(
                prompt,
                time_length=0,  # 0 seconds -> single frame (image)
                width=width,
                height=height,
                save_path=output_path,
            )
            return output_path, f"Image saved: {output_path}"
        elif mode == "video":
            print(f"Generating {duration}s video: {prompt}")
            pipe(
                prompt,
                time_length=duration,
                width=width,
                height=height,
                num_steps=steps,
                guidance_weight=guidance,
                scheduler_scale=scheduler,
                save_path=output_path,
            )
            return output_path, f"Video saved: {output_path}"
        # Explicit failure for unexpected mode instead of an implicit None return.
        return None, f"Unknown mode: {mode}"
    except torch.cuda.OutOfMemoryError:
        return None, "CUDA OOM - try a smaller size or shorter duration."
    except Exception as e:
        return None, f"Error during generation: {e}"
# ============================================================
# 6. Gradio UI — two-column layout: controls on the left, output on the right
# ============================================================
with gr.Blocks(theme=gr.themes.Soft(), title="Kandinsky 5.0 T2V Lite (CUDA)") as demo:
    gr.Markdown("## ποΈ Kandinsky 5.0 β Text & Image to Video Generator")
    with gr.Row():
        # Left column: all generation controls.
        with gr.Column(scale=2):
            mode = gr.Radio(["video", "image"], value="video", label="Mode")
            prompt = gr.Textbox(label="Prompt", value="A dog in red boots")
            # duration/steps/guidance/scheduler are only consumed in video mode.
            duration = gr.Slider(1, 10, step=1, value=5, label="Video Duration (seconds)")
            width = gr.Radio([512, 768], value=768, label="Width (px)")
            height = gr.Radio([512, 768], value=512, label="Height (px)")
            steps = gr.Slider(10, 50, step=5, value=25, label="Sampling Steps")
            guidance = gr.Slider(0.0, 10.0, step=0.5, value=1.0, label="Guidance Weight")
            scheduler = gr.Slider(1.0, 10.0, step=0.5, value=5.0, label="Scheduler Scale")
            btn = gr.Button("π Generate", variant="primary")
        # Right column: result display + status text.
        with gr.Column(scale=3):
            # NOTE(review): image mode returns a .png path into a gr.Video component —
            # it likely won't render; consider a gr.Image or gr.File output. Verify.
            output_display = gr.Video(label="Generated Output (Video/Image)")
            status = gr.Markdown()
    # Wire the button to generate_output; outputs map to (media, status message).
    btn.click(
        fn=generate_output,
        inputs=[prompt, mode, duration, width, height, steps, guidance, scheduler],
        outputs=[output_display, status],
    )
# ============================================================
# 7. Launch app normally (GPU work happens inside the decorated function)
# ============================================================
if __name__ == "__main__":
    # Bind on all interfaces, default Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)
|