File size: 5,856 Bytes
2d9f4b6
191fe1d
72d4212
 
bdf16fd
 
fef2666
 
 
bdf16fd
 
72d4212
bdf16fd
 
 
72d4212
 
 
 
 
 
9cd917b
72d4212
 
 
 
 
 
bdf16fd
72d4212
bdf16fd
 
 
72d4212
 
 
 
 
 
 
 
 
 
 
bdf16fd
 
 
72d4212
 
 
 
 
 
 
 
 
bdf16fd
 
 
 
72d4212
9cd917b
bdf16fd
 
 
 
 
72d4212
9cd917b
bdf16fd
 
 
 
 
 
72d4212
bdf16fd
72d4212
 
bdf16fd
 
 
2d9f4b6
bdf16fd
d09a395
eedb16d
72d4212
 
9cd917b
191fe1d
bdf16fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72d4212
bdf16fd
 
191fe1d
bdf16fd
72d4212
bdf16fd
 
 
 
 
72d4212
 
 
bdf16fd
 
 
72d4212
 
bdf16fd
 
 
 
72d4212
 
bdf16fd
72d4212
 
 
bdf16fd
 
 
72d4212
fef2666
bdf16fd
 
 
191fe1d
bdf16fd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import spaces
import os
import gradio as gr
import torch
import subprocess
import importlib, site
import warnings
import logging
from huggingface_hub import hf_hub_download
from kandinsky import get_T2V_pipeline
from PIL import Image

# ============================================================
# 1️⃣ FlashAttention setup
# ============================================================
# Best-effort: download a prebuilt FlashAttention 3 wheel from the Hub and
# install it at startup. Any failure is reported and the app keeps running
# without FlashAttention.
import sys  # needed to run pip with the interpreter executing this app

try:
    print("Attempting to download and install FlashAttention wheel...")
    flash_attention_wheel = hf_hub_download(
        repo_id="rahul7star/flash-attn-3",
        repo_type="model",
        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
    )
    # `<this interpreter> -m pip` installs into the environment this process
    # imports from. A bare `pip` found on PATH may belong to another Python.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", flash_attention_wheel],
        check=True,
    )
    # Re-scan site-packages and drop stale finder caches so the freshly
    # installed package can be imported by this same process.
    site.addsitedir(site.getsitepackages()[0])
    importlib.invalidate_caches()
    print("✅ FlashAttention installed successfully.")
except Exception as e:
    # Deliberately broad: FlashAttention is optional, so any download or
    # install error is logged and startup continues.
    print(f"⚠️ Could not install FlashAttention: {e}")
    print("Continuing without FlashAttention...")

# ============================================================
# 2️⃣ Torch + logging config
# ============================================================
warnings.filterwarnings("ignore")
logging.getLogger("torch").setLevel(logging.ERROR)
torch._logging.set_logs(
    dynamo=logging.ERROR,
    dynamic=logging.ERROR,
    aot=logging.ERROR,
    inductor=logging.ERROR,
    guards=False,
    recompiles=False,
)

# ============================================================
# 3️⃣ Ensure models are downloaded
# ============================================================
# Run download_models.py once per working directory, using a marker file to
# remember success. The marker is written only after the script exits with
# status 0 (check=True raises otherwise), so a failed download is retried on
# the next start. Both paths are relative to the current working directory.
import sys  # run the helper with the same interpreter as this app

_MODELS_MARKER = "./models_downloaded.marker"

if not os.path.exists(_MODELS_MARKER):
    print("📦 Models not found. Running download_models.py...")
    # sys.executable instead of "python": the bare name may be missing from
    # PATH or may resolve to a different interpreter/environment.
    subprocess.run([sys.executable, "download_models.py"], check=True)
    with open(_MODELS_MARKER, "w", encoding="utf-8") as f:
        f.write("done")
    print("✅ Models downloaded successfully.")
else:
    print("✅ Models already downloaded (marker found).")

# ============================================================
# 4️⃣ Load pipeline to CUDA (like Wan example)
# ============================================================
# Build the Kandinsky 5.0 T2V pipeline with the dit, vae and text_embedder
# components on "cuda:0". If anything fails, `pipe` is set to None so the UI
# can still start; generate_output() checks for None and reports the error.
import traceback

print("🔧 Loading Kandinsky 5.0 T2V pipeline to CUDA...")
try:
    pipe = get_T2V_pipeline(
        device_map={
            "dit": "cuda:0",
            "vae": "cuda:0",
            "text_embedder": "cuda:0",
        },
        conf_path="./configs/config_5s_sft.yaml",
    )

    # Explicitly move all components to CUDA, but only if the returned
    # pipeline object exposes a .to() method.
    if hasattr(pipe, "to"):
        pipe.to("cuda")
    print("✅ Pipeline successfully loaded and moved to CUDA.")

except Exception as e:
    print(f"❌ Pipeline load failed: {e}")
    # str(e) alone is often too terse to diagnose config, checkpoint or CUDA
    # problems, so keep the full traceback in the logs.
    traceback.print_exc()
    pipe = None

# ============================================================
# 5️⃣ Generation function
# ============================================================
@spaces.GPU(duration=40)
def generate_output(prompt, mode, duration, width, height, steps, guidance, scheduler):
    """Render ``prompt`` to a video (``.mp4``) or a single image (``.png``).

    Wrapped with ``spaces.GPU(duration=40)`` so each call runs on a Spaces GPU.

    Args:
        prompt: Text prompt passed to the Kandinsky pipeline.
        mode: ``"video"`` or ``"image"``.
        duration: Video length in seconds (video mode only).
        width: Output width in pixels.
        height: Output height in pixels.
        steps: Number of sampling steps (video mode only).
        guidance: Guidance weight (video mode only).
        scheduler: Scheduler scale (video mode only).

    Returns:
        A ``(output_path, status_message)`` tuple. ``output_path`` is None when
        the pipeline is unavailable, ``mode`` is unknown, or generation fails.
    """
    import re
    import uuid

    print(f"📝 Generation request ({mode}): {prompt}")

    if pipe is None:
        return None, "❌ Pipeline not initialized."
    if mode not in ("image", "video"):
        # Return an explicit error pair rather than an implicit None, which
        # Gradio could not map onto the two output components.
        return None, f"❌ Unknown mode: {mode!r}"

    # `prompt` is untrusted user input. Keep only word characters and dashes,
    # so no "/" or ".." can escape /tmp. Cap the length to stay within
    # filename limits. Add a random suffix so two requests with the same
    # prompt do not overwrite each other's file.
    slug = re.sub(r"[^\w-]+", "_", prompt or "").strip("_")[:64] or "output"
    extension = "mp4" if mode == "video" else "png"
    output_path = f"/tmp/{slug}_{uuid.uuid4().hex[:8]}.{extension}"

    try:
        if mode == "image":
            # time_length=0 asks the pipeline for a single still image.
            print(f"🖼️ Generating image: {prompt}")
            pipe(
                prompt,
                time_length=0,
                width=width,
                height=height,
                save_path=output_path,
            )
            return output_path, f"✅ Image saved: {output_path}"

        print(f"🎬 Generating {duration}s video: {prompt}")
        pipe(
            prompt,
            time_length=duration,
            width=width,
            height=height,
            num_steps=steps,
            guidance_weight=guidance,
            scheduler_scale=scheduler,
            save_path=output_path,
        )
        return output_path, f"✅ Video saved: {output_path}"

    except torch.cuda.OutOfMemoryError:
        return None, "⚠️ CUDA OOM — try smaller size or shorter duration."
    except Exception as e:
        return None, f"❌ Error during generation: {e}"

# ============================================================
# 6️⃣ Gradio UI
# ============================================================
def _generate_for_ui(prompt, mode, duration, width, height, steps, guidance, scheduler):
    """Call generate_output() and route its file to the matching preview.

    generate_output() saves an ``.mp4`` in video mode and a ``.png`` in image
    mode. gr.Video is meant for video files, so image results go to a
    separate gr.Image preview instead.

    Returns:
        ``(video_path_or_None, image_path_or_None, status_message)``.
    """
    path, message = generate_output(
        prompt, mode, duration, width, height, steps, guidance, scheduler
    )
    if mode == "image":
        return None, path, message
    return path, None, message


with gr.Blocks(theme=gr.themes.Soft(), title="Kandinsky 5.0 T2V Lite (CUDA)") as demo:
    gr.Markdown("## 🎞️ Kandinsky 5.0 — Text & Image to Video Generator")

    with gr.Row():
        # Left column: generation controls.
        with gr.Column(scale=2):
            mode = gr.Radio(["video", "image"], value="video", label="Mode")
            prompt = gr.Textbox(label="Prompt", value="A dog in red boots")
            duration = gr.Slider(1, 10, step=1, value=5, label="Video Duration (seconds)")
            width = gr.Radio([512, 768], value=768, label="Width (px)")
            height = gr.Radio([512, 768], value=512, label="Height (px)")
            steps = gr.Slider(10, 50, step=5, value=25, label="Sampling Steps")
            guidance = gr.Slider(0.0, 10.0, step=0.5, value=1.0, label="Guidance Weight")
            scheduler = gr.Slider(1.0, 10.0, step=0.5, value=5.0, label="Scheduler Scale")
            btn = gr.Button("🚀 Generate", variant="primary")

        # Right column: one preview per output type, plus a status line.
        with gr.Column(scale=3):
            output_display = gr.Video(label="Generated Video")
            image_display = gr.Image(label="Generated Image")
            status = gr.Markdown()

    btn.click(
        fn=_generate_for_ui,
        inputs=[prompt, mode, duration, width, height, steps, guidance, scheduler],
        outputs=[output_display, image_display, status],
    )

# ============================================================
# 7️⃣ Entry point (the GPU decorator lives on generate_output)
# ============================================================
# Start the Gradio server only when this file is executed directly.
# Binding to 0.0.0.0 listens on every network interface, port 7860.
if __name__ == "__main__":
    demo.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )