rahul7star committed on
Commit
9cd917b
Β·
verified Β·
1 Parent(s): c3c58ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -100
app.py CHANGED
@@ -1,137 +1,244 @@
1
- import spaces
2
-
3
  import os
 
 
4
  import warnings
5
  import logging
6
- import time
7
  import tempfile
 
8
  from pathlib import Path
 
9
 
10
- import gradio as gr
11
  import torch
 
 
12
  from huggingface_hub import hf_hub_download
13
 
14
- # GPU management for Hugging Face Spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # ==========================================================
17
- # 1️⃣ Install FlashAttention dynamically
18
  # ==========================================================
19
- try:
20
- print("Attempting to download and install FlashAttention wheel...")
21
- flash_attention_wheel = hf_hub_download(
22
- repo_id="rahul7star/flash-attn-3",
23
- repo_type="model",
24
- filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
25
- )
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- os.system(f"pip install {flash_attention_wheel}")
28
 
29
- import importlib, site
30
- site.addsitedir(site.getsitepackages()[0])
31
- importlib.invalidate_caches()
 
 
 
 
 
32
 
33
- print("βœ… FlashAttention installed successfully.")
34
- except Exception as e:
35
- print(f"⚠️ Could not install FlashAttention: {e}")
36
- print("Continuing without FlashAttention...")
 
 
 
 
 
 
 
 
 
37
 
38
 
39
  # ==========================================================
40
- # 2️⃣ Kandinsky Import & Setup
41
  # ==========================================================
42
- warnings.filterwarnings("ignore")
43
- logging.getLogger("torch").setLevel(logging.ERROR)
44
-
45
- from kandinsky import get_T2V_pipeline
 
 
 
 
 
 
 
46
 
47
- # Preload model (config path should exist in the repo)
48
- CONFIG_PATH = "./configs/config_5s_sft.yaml"
49
 
50
- # Load pipeline on GPU
51
- print("πŸ”„ Loading Kandinsky T2V pipeline...")
52
- pipe = get_T2V_pipeline(
53
- device_map={"dit": "cuda:0", "vae": "cuda:0", "text_embedder": "cuda:0"},
54
- conf_path=CONFIG_PATH,
55
- offload=False,
56
- magcache=False,
57
- )
58
- print("βœ… Kandinsky T2V pipeline loaded successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
 
61
  # ==========================================================
62
- # 3️⃣ Generation Function
63
  # ==========================================================
64
- @spaces.GPU(duration = 30)
65
- def generate_video(prompt, negative_prompt, image=None, width=768, height=512, duration=5, steps=None, guidance=None, scheduler_scale=5.0, expand_prompt=1):
66
- """Generate a video using Kandinsky 5 T2V pipeline"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  try:
68
- if (width, height) not in [(512, 512), (512, 768), (768, 512)]:
69
- raise ValueError(f"Unsupported resolution: ({width}x{height}). Supported: 512x512, 512x768, 768x512")
70
-
71
- output_path = Path(tempfile.gettempdir()) / f"kandinsky_{int(time.time())}.mp4"
72
-
73
- start = time.perf_counter()
74
-
75
- # Run pipeline (image optional)
76
- result = pipe(
77
- prompt,
78
- time_length=duration,
79
- width=width,
80
- height=height,
81
- num_steps=steps,
82
- guidance_weight=guidance,
83
- scheduler_scale=scheduler_scale,
84
- expand_prompts=expand_prompt,
85
- save_path=str(output_path),
86
- image=image,
87
  )
88
-
89
- elapsed = time.perf_counter() - start
90
- print(f"βœ… Generated video in {elapsed:.2f}s: {output_path}")
91
- return str(output_path)
92
  except Exception as e:
93
- print(f"❌ Generation failed: {e}")
94
- return None
95
 
96
 
97
  # ==========================================================
98
- # 4️⃣ Gradio UI
99
  # ==========================================================
100
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
101
- gr.Markdown("## 🎬 Kandinsky 5.0 T2V Lite β€” Text & Image to Video Generator")
102
-
103
- with gr.Row():
104
- with gr.Column(scale=2):
105
- prompt = gr.Textbox(label="Prompt", placeholder="A dog in a red hat running in the snow")
106
- negative_prompt = gr.Textbox(
107
- label="Negative Prompt",
108
- value="Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
109
- )
110
- image = gr.Image(label="Optional Input Image", type="filepath")
111
-
112
- with gr.Row():
113
- width = gr.Radio(choices=[512, 768], value=768, label="Width")
114
- height = gr.Radio(choices=[512, 768], value=512, label="Height")
115
-
116
- duration = gr.Slider(1, 10, value=5, step=1, label="Video Duration (seconds)")
117
- steps = gr.Slider(1, 50, value=None, step=1, label="Sampling Steps (optional)")
118
- guidance = gr.Slider(1.0, 10.0, value=None, step=0.5, label="Guidance Weight (optional)")
119
- scheduler_scale = gr.Slider(1.0, 10.0, value=5.0, step=0.5, label="Scheduler Scale")
120
- expand_prompt = gr.Checkbox(value=True, label="Expand Prompt")
121
-
122
- generate_btn = gr.Button("πŸš€ Generate Video", variant="primary")
123
-
124
- with gr.Column(scale=1):
125
- video_output = gr.Video(label="Generated Video")
 
 
 
 
 
 
 
 
126
 
127
- generate_btn.click(
128
- fn=generate_video,
129
- inputs=[prompt, negative_prompt, image, width, height, duration, steps, guidance, scheduler_scale, expand_prompt],
130
- outputs=[video_output]
131
- )
132
 
133
  # ==========================================================
134
- # 5️⃣ Launch
135
  # ==========================================================
136
  if __name__ == "__main__":
137
- demo.launch()
 
 
 
 
 
1
+ # app.py
 
2
  import os
3
+ import sys
4
+ import time
5
  import warnings
6
  import logging
 
7
  import tempfile
8
+ import shutil
9
  from pathlib import Path
10
+ import subprocess
11
 
 
12
  import torch
13
+ import gradio as gr
14
+ import spaces
15
  from huggingface_hub import hf_hub_download
16
 
17
+ # ==========================================================
18
+ # Helper: shell command runner
19
+ # ==========================================================
20
def sh(cmd, check=True):
    """Run *cmd* in a shell, echoing its output.

    Prints the command, its captured stdout (to stdout) and stderr
    (to sys.stderr). Returns ``(returncode, stdout)``. When *check*
    is true, a non-zero exit status raises CalledProcessError.
    """
    print(f"RUN: {cmd}")
    # NOTE(review): shell=True with an interpolated string — only pass
    # trusted, internally-built commands through this helper.
    proc = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    out, err = proc.stdout, proc.stderr
    if out:
        print(out)
    if err:
        print(err, file=sys.stderr)
    if check and proc.returncode != 0:
        raise subprocess.CalledProcessError(proc.returncode, cmd)
    return proc.returncode, out
30
+
31
 
32
  # ==========================================================
33
+ # Install FlashAttention (best effort)
34
  # ==========================================================
35
def try_install_flash_attention():
    """Best-effort install of a prebuilt FlashAttention wheel.

    Downloads the wheel from the Hub, pip-installs it, and refreshes the
    import machinery so the new package becomes visible in this process.
    Returns True on success, False on any failure — the app keeps running
    without FlashAttention either way.
    """
    try:
        print("Attempting to download and install FlashAttention wheel...")
        wheel_path = hf_hub_download(
            repo_id="rahul7star/flash-attn-3",
            repo_type="model",
            filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
        )
        sh(f"pip install {wheel_path}")
        # Make the freshly installed package importable without restarting.
        import site
        import importlib
        site.addsitedir(site.getsitepackages()[0])
        importlib.invalidate_caches()
        print("βœ… FlashAttention installed successfully.")
    except Exception as e:
        print(f"⚠️ Could not install FlashAttention: {e}")
        print("Continuing without FlashAttention...")
        return False
    return True
53
 
 
54
 
55
+ # ==========================================================
56
+ # Ensure models downloaded before UI starts
57
+ # ==========================================================
58
def ensure_models_downloaded(marker_file=".models_ready"):
    """Run download_models.py once, using *marker_file* as a done-flag.

    Returns True when the models are (already) present, False when the
    download script is missing or fails. On success the marker file is
    written so subsequent calls become a cheap no-op.
    """
    marker_path = Path(marker_file)
    if marker_path.exists():
        print("βœ… Models already downloaded (marker found).")
        return True

    if not Path("download_models.py").exists():
        print("❌ download_models.py not found. Please include it in repo.")
        return False

    print("πŸ“¦ Running download_models.py to fetch model artifacts...")
    try:
        sh(f"{sys.executable} download_models.py", check=True)
        marker_path.write_text("ok")
        print("βœ… Model download complete. Marker created.")
        return True
    except Exception as e:
        print(f"❌ Model download failed: {e}")
        return False
77
 
78
 
79
  # ==========================================================
80
+ # Import Kandinsky pipeline (after models ready)
81
  # ==========================================================
82
def disable_warnings():
    """Silence noisy warnings and verbose torch logging.

    Filters all Python warnings, raises the ``torch`` logger threshold to
    ERROR, and — best effort — turns down torch's dynamo/inductor log
    streams. That last step uses the private ``torch._logging`` API, whose
    availability and signature vary across torch releases, so any failure
    there is swallowed: log configuration must never break generation.
    """
    warnings.filterwarnings("ignore")
    logging.getLogger("torch").setLevel(logging.ERROR)
    try:
        # Private API — guard so an incompatible torch version is harmless.
        torch._logging.set_logs(
            dynamo=logging.ERROR,
            dynamic=logging.ERROR,
            aot=logging.ERROR,
            inductor=logging.ERROR,
            guards=False,
            recompiles=False,
        )
    except Exception as e:
        print(f"⚠️ Could not adjust torch log levels: {e}")
93
 
 
 
94
 
95
+ # ==========================================================
96
+ # Video generation merged from test.py
97
+ # ==========================================================
98
# Pipelines are expensive to build; cache them per (config, offload, magcache)
# so repeated generations with the same settings reuse the loaded model.
_PIPE_CACHE = {}


def _resolve_t2v_config():
    """Return the path to config_5s_sft.yaml, searching the tree as a fallback."""
    conf_path = Path("configs/config_5s_sft.yaml")
    if conf_path.exists():
        return conf_path
    candidates = list(Path(".").rglob("config_5s_sft.yaml"))
    if not candidates:
        raise FileNotFoundError("config_5s_sft.yaml not found.")
    return candidates[0]


def generate_kandinsky_video(prompt, negative_prompt, width, height, video_duration,
                             expand_prompt, sample_steps, guidance_weight,
                             scheduler_scale, output_filename, offload=False, magcache=False):
    """Generate a video with the Kandinsky 5.0 T2V pipeline and save it.

    Parameters
    ----------
    prompt : str
        Text description of the scene.
    negative_prompt : str
        Currently UNUSED — the pipeline call below does not forward it.
        NOTE(review): confirm whether the kandinsky pipeline accepts a
        negative prompt and wire it through if so.
    width, height : int
        Frame size; only 512x512, 512x768 and 768x512 are accepted.
    video_duration : int
        Clip length in seconds (passed as ``time_length``).
    expand_prompt : int
        Forwarded to ``expand_prompts`` (prompt-expansion toggle).
    sample_steps, guidance_weight : int | float | None
        Sampler overrides; ``None`` defers to the pipeline defaults.
    scheduler_scale : float
        Forwarded to the pipeline scheduler.
    output_filename : str
        Destination path for the rendered .mp4.
    offload, magcache : bool
        Pipeline memory/speed options; part of the cache key.

    Returns
    -------
    str
        ``output_filename`` on success.

    Raises
    ------
    ValueError
        For an unsupported (width, height) combination.
    FileNotFoundError
        When config_5s_sft.yaml cannot be located.
    """
    from kandinsky import get_T2V_pipeline

    # Validate resolution up front with a clear error.
    if (width, height) not in [(512, 512), (512, 768), (768, 512)]:
        raise ValueError(f"Unsupported video size ({width}x{height}). Use 512x512, 512x768, or 768x512.")

    conf_path = _resolve_t2v_config()
    print(f"πŸ”§ Using config: {conf_path}")

    # Load the pipeline only on a cache miss — loading dominates call time.
    cache_key = (str(conf_path), bool(offload), bool(magcache))
    pipe = _PIPE_CACHE.get(cache_key)
    if pipe is None:
        print("πŸ”„ Loading Kandinsky 5.0 T2V pipeline...")
        pipe = get_T2V_pipeline(
            device_map={"dit": "cuda:0", "vae": "cuda:0", "text_embedder": "cuda:0"},
            conf_path=str(conf_path),
            offload=offload,
            magcache=magcache,
        )
        _PIPE_CACHE[cache_key] = pipe
        print("βœ… Pipeline loaded successfully.")

    # Generate and time the run; the pipeline writes the file itself.
    start_time = time.perf_counter()
    pipe(
        prompt,
        time_length=video_duration,
        width=width,
        height=height,
        num_steps=sample_steps,
        guidance_weight=guidance_weight,
        scheduler_scale=scheduler_scale,
        expand_prompts=expand_prompt,
        save_path=output_filename,
    )
    elapsed = time.perf_counter() - start_time
    print(f"βœ… Video generated: {output_filename} (took {elapsed:.2f}s)")
    return output_filename
144
 
145
 
146
  # ==========================================================
147
+ # Gradio callback
148
  # ==========================================================
149
@spaces.GPU(duration=60)
def generate(prompt, negative_prompt, width, height, duration, expand_prompt, sample_steps,
             guidance_weight, scheduler_scale, install_flash, force_download):
    """Gradio callback: prepare the environment, then render one video.

    Returns a ``(video_path_or_None, log_text)`` pair matching the two
    UI outputs. All progress/problem messages are accumulated as log lines.
    """
    log_lines = []
    disable_warnings()

    # Optional FlashAttention install, on user request.
    if install_flash:
        installed = try_install_flash_attention()
        log_lines.append(f"FlashAttention install: {'OK' if installed else 'FAILED'}")

    # Optionally force a re-download by dropping the done-marker.
    if force_download:
        marker = Path(".models_ready")
        if marker.exists():
            marker.unlink()
            log_lines.append("Removed model marker for re-download.")

    ok_models = ensure_models_downloaded()
    log_lines.append(f"Models ready: {ok_models}")
    if not ok_models:
        return None, "\n".join(log_lines)

    try:
        out_dir = Path("outputs")
        out_dir.mkdir(exist_ok=True)
        out_path = out_dir / f"kandinsky_{int(time.time())}.mp4"
        result_path = generate_kandinsky_video(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=int(width),
            height=int(height),
            video_duration=int(duration),
            expand_prompt=int(expand_prompt),
            sample_steps=int(sample_steps) if sample_steps else None,
            guidance_weight=float(guidance_weight) if guidance_weight else None,
            scheduler_scale=float(scheduler_scale),
            output_filename=str(out_path),
        )
        log_lines.append(f"Video generated at {result_path}")
        return str(result_path), "\n".join(log_lines)
    except Exception as e:
        log_lines.append(f"❌ Generation failed: {e}")
        return None, "\n".join(log_lines)
195
 
196
 
197
  # ==========================================================
198
+ # Gradio UI
199
  # ==========================================================
200
def build_ui():
    """Assemble and return the Gradio Blocks interface."""
    with gr.Blocks(title="Kandinsky 5.0 β€” Text/Image to Video", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎬 Kandinsky 5.0 β€” Text + Image β†’ Video Generator (Spaces GPU)")

        with gr.Row():
            # Left column: all generation controls.
            with gr.Column(scale=3):
                prompt_box = gr.Textbox(label="Prompt", value="A dog in a red hat", placeholder="Describe your scene")
                negative_box = gr.Textbox(
                    label="Negative Prompt",
                    value="Static, 2D cartoon, worst quality, deformed"
                )
                # NOTE(review): 768x768 is selectable here but is rejected by
                # the generator, which only allows 512x512, 512x768, 768x512.
                width_radio = gr.Radio([512, 768], value=768, label="Width")
                height_radio = gr.Radio([512, 768], value=512, label="Height")
                duration_slider = gr.Slider(1, 10, value=5, step=1, label="Video duration (seconds)")
                expand_box = gr.Checkbox(label="Expand Prompt", value=True)
                steps_num = gr.Number(label="Sample Steps (optional)", value=None)
                guidance_num = gr.Number(label="Guidance Weight (optional)", value=None)
                sched_num = gr.Number(label="Scheduler Scale", value=5.0)
                flash_box = gr.Checkbox(label="Install FlashAttention (optional)", value=False)
                redl_box = gr.Checkbox(label="Force re-download models", value=False)

                run_btn = gr.Button("πŸš€ Generate Video", variant="primary")
                status_box = gr.Textbox(label="Logs / Status", interactive=False, lines=10)
            # Right column: the rendered result.
            with gr.Column(scale=2):
                video_out = gr.Video(label="Generated Video")

        # Wire the button to the GPU-backed callback; input order must match
        # the generate() signature.
        run_btn.click(
            fn=generate,
            inputs=[prompt_box, negative_box, width_radio, height_radio, duration_slider,
                    expand_box, steps_num, guidance_num, sched_num,
                    flash_box, redl_box],
            outputs=[video_out, status_box]
        )
    return demo
234
 
 
 
 
 
 
235
 
236
  # ==========================================================
237
+ # Main entrypoint
238
  # ==========================================================
239
# Script entrypoint: pre-create the output directory, fetch the models once,
# then build and serve the Gradio UI.
if __name__ == "__main__":
    print("πŸš€ Starting Kandinsky T2V Gradio App...")
    Path("outputs").mkdir(exist_ok=True)  # destination for generated .mp4 files
    # Best effort at startup: a failed download is logged by the helper but
    # does not prevent the UI from launching (the callback retries it).
    ensure_models_downloaded()
    demo = build_ui()
    # Bind on all interfaces; Spaces/containers supply PORT, default 7860.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))