import spaces
import os
import gradio as gr
import torch
import subprocess
import importlib, site
import warnings
import logging
from huggingface_hub import hf_hub_download
from kandinsky import get_T2V_pipeline
from PIL import Image
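# NOTE: importing `spaces` before torch lets the Hugging Face Spaces runtime
# hook GPU allocation before CUDA is initialized. If this Space runs on
# ZeroGPU hardware (an assumption), this import order is required even though
# no @spaces.GPU decorator is used below (see section 7).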
# ============================================================
# 1️⃣ FlashAttention setup
# ============================================================
try:
    print("Attempting to download and install FlashAttention wheel...")
    flash_attention_wheel = hf_hub_download(
        repo_id="rahul7star/flash-attn-3",
        repo_type="model",
        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
    )
    subprocess.run(["pip", "install", flash_attention_wheel], check=True)
    site.addsitedir(site.getsitepackages()[0])
    importlib.invalidate_caches()
    print("✅ FlashAttention installed successfully.")
except Exception as e:
    print(f"⚠️ Could not install FlashAttention: {e}")
    print("Continuing without FlashAttention...")
# ============================================================
# 2️⃣ Torch + logging config
# ============================================================
warnings.filterwarnings("ignore")
logging.getLogger("torch").setLevel(logging.ERROR)
torch._logging.set_logs(
    dynamo=logging.ERROR,
    dynamic=logging.ERROR,
    aot=logging.ERROR,
    inductor=logging.ERROR,
    guards=False,
    recompiles=False,
)
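# Raises Dynamo/AOT/Inductor log levels to ERROR and disables the guard and
# recompile artifact logs, silencing torch.compile chatter in the Space logs.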
# ============================================================
# 3️⃣ Ensure models are downloaded
# ============================================================
if not os.path.exists("./models_downloaded.marker"):
    print("📦 Models not found. Running download_models.py...")
    subprocess.run(["python", "download_models.py"], check=True)
    with open("./models_downloaded.marker", "w") as f:
        f.write("done")
    print("✅ Models downloaded successfully.")
else:
    print("✅ Models already downloaded (marker found).")
# ============================================================
# 4️⃣ Load pipeline to CUDA (like Wan example)
# ============================================================
print("🔧 Loading Kandinsky 5.0 T2V pipeline to CUDA...")
try:
    pipe = get_T2V_pipeline(
        device_map={
            "dit": "cuda:0",
            "vae": "cuda:0",
            "text_embedder": "cuda:0",
        },
        conf_path="./configs/config_5s_sft.yaml",
    )
    # Explicitly move all components to CUDA
    if hasattr(pipe, "to"):
        pipe.to("cuda")
    print("✅ Pipeline successfully loaded and moved to CUDA.")
except Exception as e:
    print(f"❌ Pipeline load failed: {e}")
    pipe = None
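# The device_map pins each pipeline component (DiT transformer, VAE, text
# embedder) to the same GPU; the hasattr-guarded pipe.to("cuda") is a
# defensive second pass in case get_T2V_pipeline leaves anything on CPU.
# On failure, pipe stays None and generate_output reports that to the UI.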
# ============================================================
# 5️⃣ Generation function
# ============================================================
def generate_output(prompt, mode, duration, width, height, steps, guidance, scheduler):
    print(f"🚀 Generation request: {prompt!r} (mode={mode})")
    if pipe is None:
        return None, "❌ Pipeline not initialized."
    try:
        # Build a filesystem-safe output name from the prompt.
        safe_name = "".join(c if c.isalnum() else "_" for c in prompt)[:80] or "output"
        output_path = f"/tmp/{safe_name}.{'mp4' if mode == 'video' else 'png'}"
        if mode == "image":
            print(f"🖼️ Generating image: {prompt}")
            pipe(
                prompt,
                time_length=0,
                width=width,
                height=height,
                save_path=output_path,
            )
            return output_path, f"✅ Image saved: {output_path}"
        elif mode == "video":
            print(f"🎬 Generating {duration}s video: {prompt}")
            pipe(
                prompt,
                time_length=duration,
                width=width,
                height=height,
                num_steps=steps,
                guidance_weight=guidance,
                scheduler_scale=scheduler,
                save_path=output_path,
            )
            return output_path, f"✅ Video saved: {output_path}"
        else:
            return None, f"❌ Unknown mode: {mode}"
    except torch.cuda.OutOfMemoryError:
        return None, "⚠️ CUDA OOM – try a smaller size or shorter duration."
    except Exception as e:
        return None, f"❌ Error during generation: {e}"
# ============================================================
# 6️⃣ Gradio UI
# ============================================================
with gr.Blocks(theme=gr.themes.Soft(), title="Kandinsky 5.0 T2V Lite (CUDA)") as demo:
    gr.Markdown("## 🎞️ Kandinsky 5.0 – Text & Image to Video Generator")
    with gr.Row():
        with gr.Column(scale=2):
            mode = gr.Radio(["video", "image"], value="video", label="Mode")
            prompt = gr.Textbox(label="Prompt", value="A dog in red boots")
            duration = gr.Slider(1, 10, step=1, value=5, label="Video Duration (seconds)")
            width = gr.Radio([512, 768], value=768, label="Width (px)")
            height = gr.Radio([512, 768], value=512, label="Height (px)")
            steps = gr.Slider(10, 50, step=5, value=25, label="Sampling Steps")
            guidance = gr.Slider(0.0, 10.0, step=0.5, value=1.0, label="Guidance Weight")
            scheduler = gr.Slider(1.0, 10.0, step=0.5, value=5.0, label="Scheduler Scale")
            btn = gr.Button("🚀 Generate", variant="primary")
        with gr.Column(scale=3):
            output_display = gr.Video(label="Generated Output (Video/Image)")
            status = gr.Markdown()
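    # NOTE: gr.Video only previews the mp4 output. In "image" mode the PNG is
    # still written to /tmp and its path is reported in the status line; a
    # separate gr.Image component would be needed for an inline image preview.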
    btn.click(
        fn=generate_output,
        inputs=[prompt, mode, duration, width, height, steps, guidance, scheduler],
        outputs=[output_display, status],
    )
# ============================================================
# 7️⃣ Launch app normally (no GPU decorator)
# ============================================================
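# 0.0.0.0:7860 matches the address and port the Hugging Face Spaces proxy
# expects; binding to localhost only would leave the app unreachable from
# outside the container.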
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)