rahul7star committed on
Commit
f68fbe5
·
verified ·
1 Parent(s): 6ebaddb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -182
app.py CHANGED
@@ -10,10 +10,12 @@ from pathlib import Path
10
  from huggingface_hub import hf_hub_download
11
  import gradio as gr
12
 
13
- # ---------- Helper utilities ----------
 
 
14
 
15
  def sh(cmd, check=True, env=None):
16
- """Shell helper that streams output to stdout/stderr and returns (returncode, stdout)."""
17
  print(f"RUN: {cmd}")
18
  try:
19
  completed = subprocess.run(cmd, shell=True, check=check, capture_output=True, text=True, env=env)
@@ -27,235 +29,160 @@ def sh(cmd, check=True, env=None):
27
  print(e.stderr, file=sys.stderr)
28
  return e.returncode, e.stdout if hasattr(e, "stdout") else ""
29
 
30
- # ---------- FlashAttention install (best-effort) ----------
 
 
 
 
31
  def try_install_flash_attention():
32
- """
33
- Attempt to download and install the FlashAttention wheel from HF repo rahul7star/flash-attn-3
34
- Path in repo: 128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl (as provided).
35
- This is a best-effort install; failures are non-fatal.
36
- """
37
- flash_attention_installed = False
38
  try:
39
- print("Attempting to download and install FlashAttention wheel...")
40
  wheel = hf_hub_download(
41
  repo_id="rahul7star/flash-attn-3",
42
  repo_type="model",
43
  filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
44
  )
45
- print("Downloaded wheel:", wheel)
46
- rc, out = sh(f"pip install {wheel}")
47
- # refresh site-packages so Python can see newly-installed extension
48
- try:
49
- import importlib, site
50
- # add site-packages dir (first one) and invalidate caches
51
- site.addsitedir(site.getsitepackages()[0])
52
- importlib.invalidate_caches()
53
- except Exception as e:
54
- print("Could not update site-packages cache:", e)
55
- flash_attention_installed = True
56
- print("FlashAttention installed successfully.")
57
  except Exception as e:
58
  print(f"⚠️ Could not install FlashAttention: {e}")
59
- print("Continuing without FlashAttention...")
60
- return flash_attention_installed
61
-
62
- # ---------- Model downloader ----------
63
- def ensure_models_downloaded(marker_file=".models_ready"):
64
- """
65
- Run download_models.py if models haven't been downloaded yet.
66
- This creates a small marker file after success to avoid repeated downloads.
67
- """
68
- marker = Path(marker_file)
69
  if marker.exists():
70
- print("Models already downloaded (marker found).")
71
  return True
72
 
73
  if not Path("download_models.py").exists():
74
- print("Warning: download_models.py not found in repo. Please add it or run model download manually.")
75
  return False
76
 
 
77
  try:
78
- print("Running download_models.py to fetch model artifacts...")
79
- # Try to call the script directly. Use same python executable.
80
- rc, out = sh(f"{sys.executable} download_models.py", check=True)
81
- # If it completes without exception, create marker
82
  marker.write_text("ok")
83
- print("download_models.py finished. Marker created.")
84
  return True
85
  except Exception as e:
86
- print("Failed to run download_models.py:", e)
87
  return False
88
 
89
- # ---------- Inference runner ----------
90
- def run_inference(prompt: str, image_path: str | None, seed: int | None = None, duration: float | None = None, workdir: str | None = None):
91
- """
92
- Run test.py with prompt and optional image. Expect test.py to produce a video file (e.g. output.mp4)
93
- Returns path to produced video or None on failure.
94
- """
95
- workdir = workdir or os.getcwd()
96
- out_video = Path(workdir) / "output.mp4"
97
 
98
- # remove old output if present
99
- if out_video.exists():
100
- try:
101
- out_video.unlink()
102
- except Exception:
103
- pass
104
 
105
- if not Path("test.py").exists():
106
- raise FileNotFoundError("test.py not found in repo. Place the repo's test.py in the same folder as app.py.")
 
 
 
 
107
 
108
  cmd = [sys.executable, "test.py", "--prompt", f"\"{prompt}\""]
109
  if image_path:
110
  cmd += ["--image_path", f"\"{image_path}\""]
111
- if seed is not None:
112
- cmd += ["--seed", str(seed)]
113
- if duration is not None:
114
- # If the test.py uses a --duration flag; adapt if your script uses different arg name.
115
- cmd += ["--duration", str(duration)]
116
 
117
- # Join to single command string to ensure shell wildcard expansion if needed
118
  cmd_str = " ".join(cmd)
119
- print("Inference command:", cmd_str)
120
 
121
  try:
122
- # We stream output and check for completion
123
- proc = subprocess.run(cmd_str, shell=True, check=True, capture_output=True, text=True, env=os.environ)
124
- print("Inference stdout:", proc.stdout)
125
  if proc.stderr:
126
- print("Inference stderr:", proc.stderr, file=sys.stderr)
127
  except subprocess.CalledProcessError as e:
128
- print("Inference failed:", e, file=sys.stderr)
129
- print(e.stdout if hasattr(e, "stdout") else "")
130
- print(e.stderr if hasattr(e, "stderr") else "", file=sys.stderr)
131
  return None
132
 
133
- # locate output video
134
  if out_video.exists():
135
  return str(out_video)
136
- # fallback: find any recent mp4 in workdir
137
- candidates = sorted(Path(workdir).glob("*.mp4"), key=lambda p: p.stat().st_mtime, reverse=True)
138
- if candidates:
139
- return str(candidates[0])
140
- return None
141
-
142
- # ---------- Gradio app callbacks ----------
143
- @spaces.GPU(duration = 50)
144
- def generate(prompt, image, seed, duration, install_flash, force_download_models):
145
- """
146
- Main callback for Gradio "Generate" button.
147
- - install_flash: boolean, whether to attempt flash-attn install this run
148
- - force_download_models: boolean to re-run download_models.py even if marker exists
149
- Returns (video_file, status_text)
150
- """
151
- status_msgs = []
152
- # Convert image (gradio gives a PIL Image or None) to a temp file if provided
153
- temp_image_path = None
154
  if image is not None:
155
  tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
156
- try:
157
- image.save(tmp, format="PNG")
158
- tmp.flush()
159
- temp_image_path = tmp.name
160
- tmp.close()
161
- status_msgs.append(f"Saved input image to {temp_image_path}")
162
- except Exception as e:
163
- status_msgs.append(f"Failed to save uploaded image: {e}")
164
- temp_image_path = None
165
-
166
- # Optionally install flash attention
167
- if install_flash:
168
- ok = try_install_flash_attention()
169
- status_msgs.append(f"Attempted FlashAttention install: {'OK' if ok else 'FAILED'}")
170
- else:
171
- status_msgs.append("Skipped FlashAttention install (checkbox unchecked).")
172
-
173
- # Ensure models downloaded
174
- if force_download_models:
175
- # remove marker if present so we re-download
176
- marker = Path(".models_ready")
177
- if marker.exists():
178
- try:
179
- marker.unlink()
180
- status_msgs.append("Removed existing model marker to force re-download.")
181
- except Exception as e:
182
- status_msgs.append(f"Could not remove marker file: {e}")
183
-
184
- ok_models = ensure_models_downloaded()
185
- status_msgs.append(f"Models ready: {'yes' if ok_models else 'no'}")
186
- if not ok_models:
187
- status_msgs.append("Warning: models not ready. Inference will probably fail.")
188
-
189
- # Run inference
190
- status_msgs.append("Starting inference (this may take time on GPU).")
191
  try:
192
- video_path = run_inference(prompt=prompt, image_path=temp_image_path, seed=seed, duration=duration)
 
 
 
193
  except Exception as e:
194
- status_msgs.append(f"Inference runner raised an exception: {e}")
195
- return None, "\n".join(status_msgs)
196
-
197
- if video_path:
198
- status_msgs.append(f"Video created: {video_path}")
199
- # Move to /tmp or keep in repo for Gradio to serve
200
- # We'll copy to a stable path that Gradio can serve e.g. ./outputs/output_{timestamp}.mp4
201
- dest_dir = Path("outputs")
202
- dest_dir.mkdir(exist_ok=True)
203
- ts = int(time.time())
204
- dest = dest_dir / f"t2v_output_{ts}.mp4"
205
- try:
206
- shutil.copy(video_path, dest)
207
- status_msgs.append(f"Video copied to {dest}")
208
- return str(dest), "\n".join(status_msgs)
209
- except Exception as e:
210
- status_msgs.append(f"Could not copy video to outputs/: {e}")
211
- # still try to return original path
212
- return str(video_path), "\n".join(status_msgs)
213
- else:
214
- status_msgs.append("No video produced by test.py (output not found). Check logs.")
215
- return None, "\n".join(status_msgs)
216
-
217
- # ---------- Build Gradio interface ----------
218
- def build_ui():
219
- with gr.Blocks(title="Text+Image β†’ Video (Spaces GPU)", css="""
220
- .output-video { max-width: 800px; }
221
- """) as demo:
222
- gr.Markdown("# Text + (Optional) Image β†’ Video\nSimple UI to run Kandinsky/Wan T2V `test.py` in this Space (GPU required).")
223
 
 
 
 
 
 
 
 
 
224
  with gr.Row():
225
  with gr.Column(scale=3):
226
- prompt = gr.Textbox(label="Prompt", placeholder="A dog in a red hat, cinematic, 5s", value="A dog in a red hat")
227
- image_in = gr.Image(label="Optional reference image (still)", type="pil")
228
- with gr.Row():
229
- seed = gr.Number(value=42, label="Seed (optional)", precision=0)
230
- duration = gr.Number(value=5.0, label="Duration (seconds, optional)", precision=2)
231
- install_flash = gr.Checkbox(label="Attempt FlashAttention install before running (best-effort)", value=False)
232
- force_download = gr.Checkbox(label="Force run download_models.py (re-download models)", value=False)
233
- generate_btn = gr.Button("Generate Video", variant="primary")
234
- status = gr.Textbox(label="Status / Logs", interactive=False, lines=10)
235
  with gr.Column(scale=2):
236
- out_video = gr.Video(label="Output video", elem_classes="output-video")
237
- gr.Markdown("**Notes**:\n- Ensure `download_models.py` and `test.py` are present and compatible.\n- `test.py` should produce an mp4 named `output.mp4` in the repo root or an mp4 somewhere in the working dir.\n- Long-running jobs may hit Space runtime limits if very long.")
238
-
239
- # wire up
240
- generate_btn.click(fn=generate,
241
- inputs=[prompt, image_in, seed, duration, install_flash, force_download],
242
- outputs=[out_video, status])
243
 
 
244
  return demo
245
 
246
- # ---------- Main entrypoint ----------
 
 
 
 
247
  if __name__ == "__main__":
248
- # Quick environment checks
249
- print("Starting T2V Gradio app. Python:", sys.executable)
250
- print("CUDA available?", os.environ.get("CUDA_VISIBLE_DEVICES", "(not set)"))
251
- # Attempt to install flash-attn automatically? We default to not attempting until user requests in UI.
252
- # Pre-check models: create marker if download_models.py has already run previously
253
- if not Path(".models_ready").exists() and Path("download_models.py").exists():
254
- # we do NOT force downloading on startup automatically to avoid long startup delays on Spaces.
255
- print("download_models.py exists. Models not yet marked as downloaded. Use the UI to run download (or set force flag).")
256
-
257
- # Create outputs dir
258
- Path("outputs").mkdir(exist_ok=True)
259
 
 
260
  demo = build_ui()
261
  demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
 
10
  from huggingface_hub import hf_hub_download
11
  import gradio as gr
12
 
13
+ # ====================================
14
+ # Helper utilities
15
+ # ====================================
16
 
17
  def sh(cmd, check=True, env=None):
18
+ """Shell helper that prints output live."""
19
  print(f"RUN: {cmd}")
20
  try:
21
  completed = subprocess.run(cmd, shell=True, check=check, capture_output=True, text=True, env=env)
 
29
  print(e.stderr, file=sys.stderr)
30
  return e.returncode, e.stdout if hasattr(e, "stdout") else ""
31
 
32
+
33
+ # ====================================
34
+ # FlashAttention install (startup)
35
+ # ====================================
36
+
37
def try_install_flash_attention():
    """Download and install the FlashAttention wheel from rahul7star/flash-attn-3 repo.

    Best-effort: any failure is logged and swallowed so the app keeps
    running without FlashAttention.

    Returns:
        bool: True if the wheel was downloaded and pip install succeeded.
    """
    try:
        print("🔹 Attempting to install FlashAttention...")
        wheel = hf_hub_download(
            repo_id="rahul7star/flash-attn-3",
            repo_type="model",
            filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
        )
        print(f"✅ Wheel downloaded: {wheel}")
        # sh() catches CalledProcessError internally and returns the exit
        # code, so we must check rc ourselves -- otherwise a failed pip
        # install would still be reported as a success.
        rc, _ = sh(f"pip install {wheel}")
        if rc != 0:
            raise RuntimeError(f"pip install exited with code {rc}")
        # Refresh site-packages so the newly installed extension becomes
        # importable in this process without a restart.
        import importlib
        import site
        site.addsitedir(site.getsitepackages()[0])
        importlib.invalidate_caches()
        print("✅ FlashAttention installed successfully.")
        return True
    except Exception as e:
        print(f"⚠️ Could not install FlashAttention: {e}")
        print("Continuing without it...")
        return False
55
+
56
+
57
+ # ====================================
58
+ # Model download (startup)
59
+ # ====================================
60
+
61
def ensure_models_downloaded(marker_file=".models_ready"):
    """Run download_models.py once at startup to fetch model weights.

    A small marker file is written after a successful download so repeated
    calls (and app restarts) skip the expensive step.

    Args:
        marker_file: Path of the marker file recording a completed download.

    Returns:
        bool: True if the models are ready (marker present or download OK).
    """
    marker = Path(marker_file)
    if marker.exists():
        print("✅ Models already downloaded (marker found).")
        return True

    if not Path("download_models.py").exists():
        print("❌ Missing download_models.py in repo. Please include it.")
        return False

    print("⬇️ Downloading model weights via download_models.py ...")
    try:
        # sh() catches CalledProcessError and returns the exit code, so a
        # non-zero rc -- not an exception -- is the failure signal here.
        # Checking it prevents writing the marker after a failed download.
        rc, _ = sh(f"{sys.executable} download_models.py", check=True)
        if rc != 0:
            print(f"❌ Model download failed with exit code {rc}")
            return False
        marker.write_text("ok")
        print("✅ Model download complete.")
        return True
    except Exception as e:
        print(f"❌ Model download failed: {e}")
        return False
81
 
 
 
 
 
 
 
 
 
82
 
83
+ # ====================================
84
+ # Inference runner (text/image β†’ video)
85
+ # ====================================
 
 
 
86
 
87
+ def run_inference(prompt: str, image_path: str | None = None):
88
+ """Run test.py with prompt + optional image. Returns path to video."""
89
+ workdir = os.getcwd()
90
+ out_video = Path(workdir) / "output.mp4"
91
+ if out_video.exists():
92
+ out_video.unlink(missing_ok=True)
93
 
94
  cmd = [sys.executable, "test.py", "--prompt", f"\"{prompt}\""]
95
  if image_path:
96
  cmd += ["--image_path", f"\"{image_path}\""]
 
 
 
 
 
97
 
 
98
  cmd_str = " ".join(cmd)
99
+ print(f"πŸš€ Running inference: {cmd_str}")
100
 
101
  try:
102
+ proc = subprocess.run(cmd_str, shell=True, capture_output=True, text=True, check=True)
103
+ print(proc.stdout)
 
104
  if proc.stderr:
105
+ print(proc.stderr, file=sys.stderr)
106
  except subprocess.CalledProcessError as e:
107
+ print("❌ Inference failed:", e)
108
+ print(e.stdout)
109
+ print(e.stderr)
110
  return None
111
 
112
+ # Find the resulting .mp4
113
  if out_video.exists():
114
  return str(out_video)
115
+ vids = sorted(Path(workdir).glob("*.mp4"), key=lambda p: p.stat().st_mtime, reverse=True)
116
+ return str(vids[0]) if vids else None
117
+
118
+
119
+ # ====================================
120
+ # Gradio callback
121
+ # ====================================
122
+
123
@spaces.GPU(duration=50)
def generate(prompt, image):
    """Main Gradio callback for generating video.

    Args:
        prompt: Text prompt from the UI textbox.
        image: Optional PIL image from the UI, or None.

    Returns:
        tuple[str | None, str]: Path to the output video (or None on
        failure) and the accumulated status log shown in the UI.
    """
    status = []
    temp_img_path = None

    # Persist the uploaded PIL image to a temp file so test.py can read it.
    if image is not None:
        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        try:
            image.save(tmp, format="PNG")
            tmp.close()
            temp_img_path = tmp.name
            status.append(f"📸 Saved image: {temp_img_path}")
        except Exception as e:
            # A bad upload shouldn't crash the callback; run text-only.
            tmp.close()
            temp_img_path = None
            status.append(f"⚠️ Failed to save uploaded image: {e}")

    try:
        video_path = run_inference(prompt, image_path=temp_img_path)
        if not video_path:
            status.append("❌ No video produced. Check test.py output.")
            return None, "\n".join(status)
    except Exception as e:
        status.append(f"❌ Inference failed: {e}")
        return None, "\n".join(status)
    finally:
        # Clean up the temp image so repeated clicks don't leak files
        # (NamedTemporaryFile was created with delete=False).
        if temp_img_path:
            Path(temp_img_path).unlink(missing_ok=True)

    dest_dir = Path("outputs")
    dest_dir.mkdir(exist_ok=True)
    ts = int(time.time())
    dest = dest_dir / f"t2v_output_{ts}.mp4"
    try:
        shutil.copy(video_path, dest)
        status.append(f"✅ Video generated: {dest}")
        return str(dest), "\n".join(status)
    except Exception as e:
        # Fall back to serving the original path if the copy fails.
        status.append(f"⚠️ Could not copy video to outputs/: {e}")
        return str(video_path), "\n".join(status)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
+
153
+ # ====================================
154
+ # UI builder
155
+ # ====================================
156
+
157
def build_ui():
    """Construct the Gradio Blocks UI and wire the generate button to it."""
    with gr.Blocks(title="Text+Image → Video (Spaces GPU)") as demo:
        gr.Markdown(
            "## 🎬 Kandinsky / T2V Video Generator\n"
            "Provide a text prompt and optional image to generate short video clips using GPU inference."
        )
        with gr.Row():
            # Left column: all inputs plus the run log.
            with gr.Column(scale=3):
                prompt_tb = gr.Textbox(
                    label="Prompt",
                    placeholder="A dog in a red hat, cinematic lighting",
                    value="A dog in a red hat",
                )
                ref_image = gr.Image(label="Optional input image", type="pil")
                run_btn = gr.Button("🎥 Generate Video", variant="primary")
                log_box = gr.Textbox(label="Logs", lines=8)
            # Right column: the rendered result.
            with gr.Column(scale=2):
                video_out = gr.Video(label="Output video")

        run_btn.click(fn=generate, inputs=[prompt_tb, ref_image], outputs=[video_out, log_box])
    return demo
171
 
172
+
173
+ # ====================================
174
+ # App startup
175
+ # ====================================
176
+
177
if __name__ == "__main__":
    # Basic environment diagnostics for the Space startup logs.
    print("🚀 Starting Text+Image → Video Gradio App")
    print("Python:", sys.executable)
    print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES", "(not set)"))

    # Install FlashAttention + download models ONCE at startup
    # (both helpers are best-effort: failures print warnings and return,
    # so the UI still launches without them).
    try_install_flash_attention()
    ensure_models_downloaded()

    # Directory the generate() callback copies finished videos into.
    Path("outputs").mkdir(exist_ok=True)
    demo = build_ui()
    # Bind to all interfaces; Spaces injects PORT (falls back to 7860).
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))