# NOTE: `spaces` is kept as the very first import, as in the original;
# ZeroGPU expects it to be imported before torch initializes CUDA
# (verify against the HF Spaces ZeroGPU docs if reordering).
import spaces

import importlib
import logging
import os
import re
import site
import subprocess
import sys
import uuid
import warnings

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from kandinsky import get_T2V_pipeline
from PIL import Image

# ============================================================
# 1️⃣ FlashAttention setup
# ============================================================
# Best-effort: any failure is reported and the app continues without it.
try:
    print("Attempting to download and install FlashAttention wheel...")
    flash_attention_wheel = hf_hub_download(
        repo_id="rahul7star/flash-attn-3",
        repo_type="model",
        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
    )
    # Use the running interpreter's pip so the wheel is installed into the
    # same environment this process imports from (a bare "pip" on PATH may
    # belong to a different Python).
    subprocess.run(
        [sys.executable, "-m", "pip", "install", flash_attention_wheel],
        check=True,
    )
    # Make the freshly installed package importable in this process.
    site.addsitedir(site.getsitepackages()[0])
    importlib.invalidate_caches()
    print("✅ FlashAttention installed successfully.")
except Exception as e:
    print(f"⚠️ Could not install FlashAttention: {e}")
    print("Continuing without FlashAttention...")

# ============================================================
# 2️⃣ Torch + logging config
# ============================================================
warnings.filterwarnings("ignore")
logging.getLogger("torch").setLevel(logging.ERROR)
torch._logging.set_logs(
    dynamo=logging.ERROR,
    dynamic=logging.ERROR,
    aot=logging.ERROR,
    inductor=logging.ERROR,
    guards=False,
    recompiles=False,
)

# ============================================================
# 3️⃣ Ensure models are downloaded
# ============================================================
# The marker file is written only after download_models.py exits with
# status 0 (check=True raises otherwise), so a failed download is retried
# on the next start.
MODELS_MARKER = "./models_downloaded.marker"

if not os.path.exists(MODELS_MARKER):
    print("📦 Models not found. Running download_models.py...")
    subprocess.run([sys.executable, "download_models.py"], check=True)
    with open(MODELS_MARKER, "w") as f:
        f.write("done")
    print("✅ Models downloaded successfully.")
else:
    print("✅ Models already downloaded (marker found).")

# ============================================================
# 4️⃣ Load pipeline to CUDA (like Wan example)
# ============================================================
print("🔧 Loading Kandinsky 5.0 T2V pipeline to CUDA...")
try:
    pipe = get_T2V_pipeline(
        device_map={
            "dit": "cuda:0",
            "vae": "cuda:0",
            "text_embedder": "cuda:0",
        },
        conf_path="./configs/config_5s_sft.yaml",
    )
    # Explicitly move all components to CUDA (only if the pipeline object
    # exposes a .to() method).
    if hasattr(pipe, "to"):
        pipe.to("cuda")
    print("✅ Pipeline successfully loaded and moved to CUDA.")
except Exception as e:
    print(f"❌ Pipeline load failed: {e}")
    # generate_output checks for None and reports the failure to the user.
    pipe = None

# ============================================================
# 5️⃣ Generation function
# ============================================================
# Characters allowed in generated filenames; every other run of characters
# collapses to a single "_". This removes "/" so a prompt cannot escape the
# output directory.
_UNSAFE_FILENAME_CHARS = re.compile(r"[^A-Za-z0-9._-]+")
# Keep the prompt-derived part of the filename well under the common
# 255-byte filename limit.
_MAX_PROMPT_STEM = 80


def _build_output_path(prompt, mode, out_dir="/tmp"):
    """Return a unique, filesystem-safe output path derived from *prompt*.

    Args:
        prompt: Free-form text from the web UI (untrusted).
        mode: "video" selects an ``.mp4`` extension; anything else ``.png``.
        out_dir: Directory the file is placed in.

    Returns:
        A path inside *out_dir*. It has the shape
        ``<sanitized-prompt>_<8 hex chars>.<ext>``. The random suffix keeps
        concurrent requests with the same prompt from overwriting each other.
    """
    stem = _UNSAFE_FILENAME_CHARS.sub("_", prompt).strip("._")[:_MAX_PROMPT_STEM]
    if not stem:
        # Entirely non-ASCII or punctuation-only prompts sanitize to nothing.
        stem = "output"
    ext = "mp4" if mode == "video" else "png"
    return os.path.join(out_dir, f"{stem}_{uuid.uuid4().hex[:8]}.{ext}")


@spaces.GPU(duration=40)
def generate_output(prompt, mode, duration, width, height, steps, guidance, scheduler):
    """Generate an image or a video for *prompt* with the global pipeline.

    Args:
        prompt: Text prompt.
        mode: "image" or "video".
        duration: Video length passed to the pipeline as ``time_length``
            (ignored in image mode, which passes ``time_length=0``).
        width, height: Output size in pixels.
        steps, guidance, scheduler: Passed as ``num_steps``,
            ``guidance_weight`` and ``scheduler_scale``. These are used in
            video mode only, matching the original behavior.

    Returns:
        ``(output_path, status_message)`` on success, or
        ``(None, status_message)`` on any failure. Errors are reported in the
        message and never raised to the UI.
    """
    print(f"📝 Received prompt: {prompt}")
    if pipe is None:
        return None, "❌ Pipeline not initialized."
    if not prompt or not prompt.strip():
        return None, "⚠️ Please enter a prompt."
    if mode not in ("image", "video"):
        # Previously fell through and returned a bare None, which breaks the
        # two-output Gradio callback.
        return None, f"❌ Unknown mode: {mode!r}"

    try:
        output_path = _build_output_path(prompt, mode)

        if mode == "image":
            print(f"🖼️ Generating image: {prompt}")
            pipe(
                prompt,
                time_length=0,
                width=width,
                height=height,
                save_path=output_path,
            )
            return output_path, f"✅ Image saved: {output_path}"

        print(f"🎬 Generating {duration}s video: {prompt}")
        pipe(
            prompt,
            time_length=duration,
            width=width,
            height=height,
            num_steps=steps,
            guidance_weight=guidance,
            scheduler_scale=scheduler,
            save_path=output_path,
        )
        return output_path, f"✅ Video saved: {output_path}"

    except torch.cuda.OutOfMemoryError:
        return None, "⚠️ CUDA OOM — try smaller size or shorter duration."
    except Exception as e:
        return None, f"❌ Error during generation: {e}"


def _generate_for_ui(prompt, mode, duration, width, height, steps, guidance, scheduler):
    """Call generate_output and route its file to the matching output widget.

    Returns:
        ``(video_path, image_path, status)``. The widget for the mode that was
        not requested receives None, so a PNG is never fed to the video player.
    """
    path, message = generate_output(
        prompt, mode, duration, width, height, steps, guidance, scheduler
    )
    if mode == "image":
        return None, path, message
    return path, None, message


# ============================================================
# 6️⃣ Gradio UI
# ============================================================
with gr.Blocks(theme=gr.themes.Soft(), title="Kandinsky 5.0 T2V Lite (CUDA)") as demo:
    gr.Markdown("## 🎞️ Kandinsky 5.0 — Text & Image to Video Generator")

    with gr.Row():
        with gr.Column(scale=2):
            mode = gr.Radio(["video", "image"], value="video", label="Mode")
            prompt = gr.Textbox(label="Prompt", value="A dog in red boots")
            duration = gr.Slider(1, 10, step=1, value=5, label="Video Duration (seconds)")
            width = gr.Radio([512, 768], value=768, label="Width (px)")
            height = gr.Radio([512, 768], value=512, label="Height (px)")
            steps = gr.Slider(10, 50, step=5, value=25, label="Sampling Steps")
            guidance = gr.Slider(0.0, 10.0, step=0.5, value=1.0, label="Guidance Weight")
            scheduler = gr.Slider(1.0, 10.0, step=0.5, value=5.0, label="Scheduler Scale")
            btn = gr.Button("🚀 Generate", variant="primary")

        with gr.Column(scale=3):
            output_display = gr.Video(label="Generated Video")
            image_display = gr.Image(label="Generated Image", type="filepath")
            status = gr.Markdown()

    btn.click(
        fn=_generate_for_ui,
        inputs=[prompt, mode, duration, width, height, steps, guidance, scheduler],
        outputs=[output_display, image_display, status],
    )

# ============================================================
# 7️⃣ Launch app normally (no GPU decorator)
# ============================================================
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)