rahul7star commited on
Commit
fef2666
·
verified ·
1 Parent(s): f68fbe5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -173
app.py CHANGED
@@ -1,188 +1,136 @@
1
- # app.py
2
- import spaces
3
  import os
4
- import sys
 
5
  import time
6
- import subprocess
7
  import tempfile
8
- import shutil
9
  from pathlib import Path
10
- from huggingface_hub import hf_hub_download
11
  import gradio as gr
 
 
12
 
13
- # ====================================
14
- # Helper utilities
15
- # ====================================
16
 
17
- def sh(cmd, check=True, env=None):
18
- """Shell helper that prints output live."""
19
- print(f"RUN: {cmd}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  try:
21
- completed = subprocess.run(cmd, shell=True, check=check, capture_output=True, text=True, env=env)
22
- print(completed.stdout)
23
- if completed.stderr:
24
- print("ERR:", completed.stderr, file=sys.stderr)
25
- return completed.returncode, completed.stdout
26
- except subprocess.CalledProcessError as e:
27
- print("Command failed:", e, file=sys.stderr)
28
- print(e.stdout)
29
- print(e.stderr, file=sys.stderr)
30
- return e.returncode, e.stdout if hasattr(e, "stdout") else ""
31
-
32
-
33
- # ====================================
34
- # FlashAttention install (startup)
35
- # ====================================
36
-
37
- def try_install_flash_attention():
38
- """Download and install FlashAttention wheel from rahul7star/flash-attn-3 repo."""
39
- try:
40
- print("πŸ”Ή Attempting to install FlashAttention...")
41
- wheel = hf_hub_download(
42
- repo_id="rahul7star/flash-attn-3",
43
- repo_type="model",
44
- filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
45
  )
46
- print(f"βœ… Wheel downloaded: {wheel}")
47
- sh(f"pip install {wheel}")
48
- import importlib, site
49
- site.addsitedir(site.getsitepackages()[0])
50
- importlib.invalidate_caches()
51
- print("βœ… FlashAttention installed successfully.")
52
- except Exception as e:
53
- print(f"⚠️ Could not install FlashAttention: {e}")
54
- print("Continuing without it...")
55
-
56
-
57
- # ====================================
58
- # Model download (startup)
59
- # ====================================
60
 
61
- def ensure_models_downloaded():
62
- """Run download_models.py once at startup to fetch model weights."""
63
- marker = Path(".models_ready")
64
- if marker.exists():
65
- print("βœ… Models already downloaded (marker found).")
66
- return True
67
-
68
- if not Path("download_models.py").exists():
69
- print("❌ Missing download_models.py in repo. Please include it.")
70
- return False
71
-
72
- print("⬇️ Downloading model weights via download_models.py ...")
73
- try:
74
- rc, _ = sh(f"{sys.executable} download_models.py", check=True)
75
- marker.write_text("ok")
76
- print("βœ… Model download complete.")
77
- return True
78
  except Exception as e:
79
- print(f"❌ Model download failed: {e}")
80
- return False
81
-
82
-
83
- # ====================================
84
- # Inference runner (text/image β†’ video)
85
- # ====================================
86
-
87
- def run_inference(prompt: str, image_path: str | None = None):
88
- """Run test.py with prompt + optional image. Returns path to video."""
89
- workdir = os.getcwd()
90
- out_video = Path(workdir) / "output.mp4"
91
- if out_video.exists():
92
- out_video.unlink(missing_ok=True)
93
-
94
- cmd = [sys.executable, "test.py", "--prompt", f"\"{prompt}\""]
95
- if image_path:
96
- cmd += ["--image_path", f"\"{image_path}\""]
97
-
98
- cmd_str = " ".join(cmd)
99
- print(f"πŸš€ Running inference: {cmd_str}")
100
-
101
- try:
102
- proc = subprocess.run(cmd_str, shell=True, capture_output=True, text=True, check=True)
103
- print(proc.stdout)
104
- if proc.stderr:
105
- print(proc.stderr, file=sys.stderr)
106
- except subprocess.CalledProcessError as e:
107
- print("❌ Inference failed:", e)
108
- print(e.stdout)
109
- print(e.stderr)
110
  return None
111
 
112
- # Find the resulting .mp4
113
- if out_video.exists():
114
- return str(out_video)
115
- vids = sorted(Path(workdir).glob("*.mp4"), key=lambda p: p.stat().st_mtime, reverse=True)
116
- return str(vids[0]) if vids else None
117
-
118
-
119
- # ====================================
120
- # Gradio callback
121
- # ====================================
122
-
123
- @spaces.GPU(duration=50)
124
- def generate(prompt, image):
125
- """Main Gradio callback for generating video."""
126
- status = []
127
- temp_img_path = None
128
-
129
- if image is not None:
130
- tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
131
- image.save(tmp, format="PNG")
132
- tmp.close()
133
- temp_img_path = tmp.name
134
- status.append(f"πŸ“Έ Saved image: {temp_img_path}")
135
-
136
- try:
137
- video_path = run_inference(prompt, image_path=temp_img_path)
138
- if not video_path:
139
- status.append("❌ No video produced. Check test.py output.")
140
- return None, "\n".join(status)
141
- except Exception as e:
142
- status.append(f"❌ Inference failed: {e}")
143
- return None, "\n".join(status)
144
-
145
- dest_dir = Path("outputs"); dest_dir.mkdir(exist_ok=True)
146
- ts = int(time.time())
147
- dest = dest_dir / f"t2v_output_{ts}.mp4"
148
- shutil.copy(video_path, dest)
149
- status.append(f"βœ… Video generated: {dest}")
150
- return str(dest), "\n".join(status)
151
-
152
-
153
- # ====================================
154
- # UI builder
155
- # ====================================
156
-
157
- def build_ui():
158
- with gr.Blocks(title="Text+Image β†’ Video (Spaces GPU)") as demo:
159
- gr.Markdown("## 🎬 Kandinsky / T2V Video Generator\nProvide a text prompt and optional image to generate short video clips using GPU inference.")
160
- with gr.Row():
161
- with gr.Column(scale=3):
162
- prompt = gr.Textbox(label="Prompt", placeholder="A dog in a red hat, cinematic lighting", value="A dog in a red hat")
163
- image_in = gr.Image(label="Optional input image", type="pil")
164
- generate_btn = gr.Button("πŸŽ₯ Generate Video", variant="primary")
165
- status = gr.Textbox(label="Logs", lines=8)
166
- with gr.Column(scale=2):
167
- out_video = gr.Video(label="Output video")
168
-
169
- generate_btn.click(fn=generate, inputs=[prompt, image_in], outputs=[out_video, status])
170
- return demo
171
-
172
-
173
- # ====================================
174
- # App startup
175
- # ====================================
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  if __name__ == "__main__":
178
- print("πŸš€ Starting Text+Image β†’ Video Gradio App")
179
- print("Python:", sys.executable)
180
- print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES", "(not set)"))
181
-
182
- # Install FlashAttention + download models ONCE at startup
183
- try_install_flash_attention()
184
- ensure_models_downloaded()
185
-
186
- Path("outputs").mkdir(exist_ok=True)
187
- demo = build_ui()
188
- demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
 
 
 
1
  import os
2
+ import warnings
3
+ import logging
4
  import time
 
5
  import tempfile
 
6
  from pathlib import Path
7
+
8
  import gradio as gr
9
+ import torch
10
+ from huggingface_hub import hf_hub_download
11
 
12
+ # GPU management for Hugging Face Spaces
13
+ import spaces
 
14
 
15
+ # ==========================================================
16
+ # 1️⃣ Install FlashAttention dynamically
17
+ # ==========================================================
18
+ try:
19
+ print("Attempting to download and install FlashAttention wheel...")
20
+ flash_attention_wheel = hf_hub_download(
21
+ repo_id="rahul7star/flash-attn-3",
22
+ repo_type="model",
23
+ filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
24
+ )
25
+
26
+ os.system(f"pip install {flash_attention_wheel}")
27
+
28
+ import importlib, site
29
+ site.addsitedir(site.getsitepackages()[0])
30
+ importlib.invalidate_caches()
31
+
32
+ print("βœ… FlashAttention installed successfully.")
33
+ except Exception as e:
34
+ print(f"⚠️ Could not install FlashAttention: {e}")
35
+ print("Continuing without FlashAttention...")
36
+
37
+
38
+ # ==========================================================
39
+ # 2️⃣ Kandinsky Import & Setup
40
+ # ==========================================================
41
+ warnings.filterwarnings("ignore")
42
+ logging.getLogger("torch").setLevel(logging.ERROR)
43
+
44
+ from kandinsky import get_T2V_pipeline
45
+
46
+ # Preload model (config path should exist in the repo)
47
+ CONFIG_PATH = "./configs/config_5s_sft.yaml"
48
+
49
+ # Load pipeline on GPU
50
+ print("πŸ”„ Loading Kandinsky T2V pipeline...")
51
+ pipe = get_T2V_pipeline(
52
+ device_map={"dit": "cuda:0", "vae": "cuda:0", "text_embedder": "cuda:0"},
53
+ conf_path=CONFIG_PATH,
54
+ offload=False,
55
+ magcache=False,
56
+ )
57
+ print("βœ… Kandinsky T2V pipeline loaded successfully.")
58
+
59
+
60
+ # ==========================================================
61
+ # 3️⃣ Generation Function
62
+ # ==========================================================
63
+ @spaces.GPU
64
+ def generate_video(prompt, negative_prompt, image=None, width=768, height=512, duration=5, steps=None, guidance=None, scheduler_scale=5.0, expand_prompt=1):
65
+ """Generate a video using Kandinsky 5 T2V pipeline"""
66
  try:
67
+ if (width, height) not in [(512, 512), (512, 768), (768, 512)]:
68
+ raise ValueError(f"Unsupported resolution: ({width}x{height}). Supported: 512x512, 512x768, 768x512")
69
+
70
+ output_path = Path(tempfile.gettempdir()) / f"kandinsky_{int(time.time())}.mp4"
71
+
72
+ start = time.perf_counter()
73
+
74
+ # Run pipeline (image optional)
75
+ result = pipe(
76
+ prompt,
77
+ time_length=duration,
78
+ width=width,
79
+ height=height,
80
+ num_steps=steps,
81
+ guidance_weight=guidance,
82
+ scheduler_scale=scheduler_scale,
83
+ expand_prompts=expand_prompt,
84
+ save_path=str(output_path),
85
+ image=image,
 
 
 
 
 
86
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
+ elapsed = time.perf_counter() - start
89
+ print(f"βœ… Generated video in {elapsed:.2f}s: {output_path}")
90
+ return str(output_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  except Exception as e:
92
+ print(f"❌ Generation failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  return None
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ # ==========================================================
97
+ # 4️⃣ Gradio UI
98
+ # ==========================================================
99
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
100
+ gr.Markdown("## 🎬 Kandinsky 5.0 T2V Lite β€” Text & Image to Video Generator")
101
+
102
+ with gr.Row():
103
+ with gr.Column(scale=2):
104
+ prompt = gr.Textbox(label="Prompt", placeholder="A dog in a red hat running in the snow")
105
+ negative_prompt = gr.Textbox(
106
+ label="Negative Prompt",
107
+ value="Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
108
+ )
109
+ image = gr.Image(label="Optional Input Image", type="filepath")
110
+
111
+ with gr.Row():
112
+ width = gr.Radio(choices=[512, 768], value=768, label="Width")
113
+ height = gr.Radio(choices=[512, 768], value=512, label="Height")
114
+
115
+ duration = gr.Slider(1, 10, value=5, step=1, label="Video Duration (seconds)")
116
+ steps = gr.Slider(1, 50, value=None, step=1, label="Sampling Steps (optional)")
117
+ guidance = gr.Slider(1.0, 10.0, value=None, step=0.5, label="Guidance Weight (optional)")
118
+ scheduler_scale = gr.Slider(1.0, 10.0, value=5.0, step=0.5, label="Scheduler Scale")
119
+ expand_prompt = gr.Checkbox(value=True, label="Expand Prompt")
120
+
121
+ generate_btn = gr.Button("πŸš€ Generate Video", variant="primary")
122
+
123
+ with gr.Column(scale=1):
124
+ video_output = gr.Video(label="Generated Video")
125
+
126
+ generate_btn.click(
127
+ fn=generate_video,
128
+ inputs=[prompt, negative_prompt, image, width, height, duration, steps, guidance, scheduler_scale, expand_prompt],
129
+ outputs=[video_output]
130
+ )
131
+
132
+ # ==========================================================
133
+ # 5️⃣ Launch
134
+ # ==========================================================
135
  if __name__ == "__main__":
136
+ demo.launch()