rahul7star committed on
Commit 246b21d · verified · 1 Parent(s): d09a395

Create app_torch.py

Files changed (1):
  1. app_torch.py +268 -0

app_torch.py ADDED
@@ -0,0 +1,268 @@
# app_torch.py
import os
import sys
import subprocess
import importlib
import site
import warnings
import logging
import time
from pathlib import Path

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
import spaces
# ---------------------------
# Environment flags (reduce fusion/compilation); set early
# ---------------------------
# These help avoid some torchinductor/flash-attn fusion issues that provoke guard errors.
os.environ.setdefault("TORCHINDUCTOR_DISABLE", "1")
os.environ.setdefault("TORCHINDUCTOR_FUSION", "0")
os.environ.setdefault("USE_FLASH_ATTENTION", "0")
# Some environments check this; safe to set.
os.environ.setdefault("XLA_IGNORE_ENV_VARS", "1")
# ---------------------------
# FlashAttention install (best-effort)
# ---------------------------
def try_install_flash_attention():
    try:
        print("Attempting to download and install FlashAttention wheel...")
        wheel = hf_hub_download(
            repo_id="rahul7star/flash-attn-3",
            repo_type="model",
            filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
        )
        subprocess.run([sys.executable, "-m", "pip", "install", wheel], check=True)
        # Refresh site-packages so the freshly installed wheel is importable in this process.
        site.addsitedir(site.getsitepackages()[0])
        importlib.invalidate_caches()
        print("✅ FlashAttention installed.")
        return True
    except Exception as e:
        print(f"⚠️ FlashAttention install failed: {e}")
        return False
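# Optional post-install sanity check. A minimal sketch: the top-level import
# name below is an assumption (the wheel is flash_attn_3, so the actual module
# name may differ); adjust it to whatever the wheel really provides.
def flash_attention_available(module_name="flash_attn"):
    import importlib.util
    return importlib.util.find_spec(module_name) is not None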
# ---------------------------
# Torch logging / warnings
# ---------------------------
warnings.filterwarnings("ignore")
logging.getLogger("torch").setLevel(logging.ERROR)
# Reduce torch's verbose compile/guard logging.
try:
    torch._logging.set_logs(
        dynamo=logging.ERROR,
        dynamic=logging.ERROR,
        aot=logging.ERROR,
        inductor=logging.ERROR,
        guards=False,
        recompiles=False,
    )
except Exception:
    pass
# Make Dynamo tolerant initially (we'll disable it entirely if loading fails).
try:
    import torch._dynamo as _dynamo
    _dynamo.config.suppress_errors = True
    _dynamo.config.cache_size_limit = 0  # 0 = no compiled-frame cache; frames fall back to eager
except Exception:
    _dynamo = None
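# Note: on most PyTorch builds, TORCHDYNAMO_DISABLE=1 is an alternative kill
# switch, but it is read when torch initializes, so it only takes effect if
# exported before `import torch` runs (e.g. in the Space's environment settings).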
# ---------------------------
# Download models if needed
# ---------------------------
def ensure_models_downloaded(marker_file=".models_ready"):
    marker = Path(marker_file)
    if marker.exists():
        print("Models already downloaded (marker found).")
        return True
    if not Path("download_models.py").exists():
        print("download_models.py not found in repo.")
        return False
    try:
        print("Running download_models.py ...")
        subprocess.run([sys.executable, "download_models.py"], check=True)
        marker.write_text("ok")
        print("Models download finished.")
        return True
    except Exception as e:
        print("Model download failed:", e)
        return False
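# The marker file makes the download step idempotent across restarts. To force
# a fresh download, delete it first, e.g.:
#   Path(".models_ready").unlink(missing_ok=True)
#   ensure_models_downloaded()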
# ---------------------------
# Load Kandinsky pipeline with smart Dynamo handling
# ---------------------------
def load_pipeline(conf_path="./configs/config_5s_sft.yaml", move_to_cuda_if_available=True):
    """
    Attempt to load the pipeline normally. If Dynamo/guard errors are raised,
    disable torch._dynamo and reload in eager mode.
    Returns the pipeline or raises.
    """
    from kandinsky import get_T2V_pipeline  # import inside the function to respect env changes

    def _do_load():
        print("Loading pipeline with device_map pointing to CUDA if available...")
        if torch.cuda.is_available():
            # Let the pipeline place modules onto CUDA via device_map.
            device_map = {"dit": "cuda:0", "vae": "cuda:0", "text_embedder": "cuda:0"}
        else:
            device_map = "cpu"
        pipe = get_T2V_pipeline(device_map=device_map, conf_path=conf_path, offload=False, magcache=False)
        # If the pipeline supports .to and CUDA is available, move it.
        if move_to_cuda_if_available and torch.cuda.is_available() and hasattr(pipe, "to"):
            try:
                pipe.to("cuda")
            except Exception as e:
                # Fallback: ignore and continue (some pipelines manage their own device_map).
                print("Warning while moving pipeline to CUDA:", e)
        return pipe

    try:
        # Try a normal load first (Dynamo may be enabled, but errors are suppressed).
        pipe = _do_load()
        print("Pipeline loaded successfully (initial try).")
        return pipe
    except Exception as e:
        # Heuristic: look for Dynamo/guard-related signatures in the message, then fall back.
        msg = str(e).lower()
        if "dynamo" in msg or "guard" in msg or "attributeerror" in msg or "caught" in msg:
            print("⚠️ Dynamo/guard-related error detected while loading pipeline:", e)
            # Disable torch._dynamo globally and try again in eager mode.
            try:
                if _dynamo is not None:
                    print("Disabling torch._dynamo and retrying load in eager mode...")
                    # config.disable is the global kill switch; a bare _dynamo.disable()
                    # call only builds a decorator and does not change global state.
                    _dynamo.config.disable = True
                    _dynamo.reset()  # drop any cached compiled frames
                else:
                    print("torch._dynamo not available; proceeding to retry load.")
            except Exception as ex_disable:
                print("Error disabling torch._dynamo:", ex_disable)
            # Retry the load.
            try:
                pipe = _do_load()
                print("Pipeline loaded successfully after disabling torch._dynamo.")
                return pipe
            except Exception as e2:
                print("Failed to load pipeline even after disabling torch._dynamo:", e2)
                raise
        else:
            # Not obviously a Dynamo issue; re-raise.
            raise
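# Usage note: load_pipeline() is invoked once at startup below. If the first
# attempt trips a Dynamo guard, the retry runs fully eager: typically slower
# per step, but far less likely to fail on model code Dynamo cannot trace.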
# ---------------------------
# Startup sequence
# ---------------------------
print("=== startup: installing optional FlashAttention (best-effort) ===")
try_install_flash_attention()

print("=== startup: ensuring models ===")
if not ensure_models_downloaded():
    print("Models not available; app may fail at inference. Proceeding anyway.")

print("=== startup: loading pipeline (smart) ===")
pipe = None
try:
    pipe = load_pipeline(conf_path="./configs/config_5s_sft.yaml", move_to_cuda_if_available=True)
except Exception as e:
    print("Pipeline load ultimately failed:", e)
    pipe = None
# ---------------------------
# Helper: ensure the pipeline is on CUDA at generation time
# ---------------------------
def ensure_pipe_on_cuda(pipeline):
    if pipeline is None:
        raise RuntimeError("Pipeline is None")
    # If CUDA is not available, raise early.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available on this machine")
    # If the pipeline supports .to, move it.
    if hasattr(pipeline, "to"):
        try:
            pipeline.to("cuda")
        except Exception as e:
            # Some pipelines use device_map placement; ignore move failures.
            print("Warning: pipeline.to('cuda') raised:", e)
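# A small debugging aid (a sketch; assumes the submodule is a torch.nn.Module):
# report where a module's weights actually live after placement.
def report_device(module):
    try:
        return next(module.parameters()).device
    except (AttributeError, StopIteration):
        return "unknown"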
# ---------------------------
# Generation function (runs on GPU when used)
# ---------------------------
@spaces.GPU(duration=60)
def generate_output(prompt, mode, duration, width, height, steps, guidance, scheduler):
    """
    Assumes the pipeline was already loaded at startup (the module-level `pipe`).
    Returns a helpful error message if it wasn't.
    """
    if pipe is None:
        return None, "❌ Pipeline not initialized at startup. Check logs."

    # Ensure CUDA is available before touching the pipeline.
    if not torch.cuda.is_available():
        return None, "❌ CUDA not available on this host."

    try:
        # Move the pipeline onto the GPU that this request was granted.
        ensure_pipe_on_cuda(pipe)

        # If Dynamo is still enabled and might cause trouble during forward,
        # flip its global kill switch before inference to be safe.
        try:
            if _dynamo is not None:
                _dynamo.config.disable = True
        except Exception:
            pass

        # Build the output path from a sanitized prompt slug (avoids '/' and
        # other unsafe characters in filenames).
        slug = "".join(c if c.isalnum() else "_" for c in prompt)[:60]
        out_name = f"/tmp/{int(time.time())}_{slug}.{'mp4' if mode == 'video' else 'png'}"

        if mode == "image":
            pipe(prompt, time_length=0, width=width, height=height, save_path=out_name)
            return out_name, f"✅ Image saved to {out_name}"

        # Video path
        pipe(prompt,
             time_length=duration,
             width=width,
             height=height,
             num_steps=steps if steps else None,
             guidance_weight=guidance if guidance else None,
             scheduler_scale=scheduler if scheduler else None,
             save_path=out_name)
        return out_name, f"✅ Video saved to {out_name}"

    except torch.cuda.OutOfMemoryError:
        return None, "⚠️ CUDA OOM; try reducing resolution/duration/steps."
    except Exception as e:
        return None, f"❌ Generation error: {e}"
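# Local smoke test (a sketch; bypasses the UI and assumes the pipeline loaded):
#   path, msg = generate_output("A dog in red boots", "image", 0, 768, 512, 25, 8.0, 5.0)
#   print(msg)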
# ---------------------------
# Gradio UI
# ---------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="Kandinsky 5.0 T2V (robust load)") as demo:
    gr.Markdown("## Kandinsky 5.0: Robust pipeline loader (smart Dynamo fallback)")

    with gr.Row():
        with gr.Column(scale=2):
            mode = gr.Radio(["video", "image"], value="video", label="Mode")
            prompt = gr.Textbox(label="Prompt", value="A dog in red boots")
            duration = gr.Slider(1, 10, step=1, value=5, label="Duration (s)")
            width = gr.Radio([512, 768], value=768, label="Width")
            height = gr.Radio([512, 768], value=512, label="Height")
            steps = gr.Slider(4, 50, step=1, value=25, label="Sampling Steps")
            guidance = gr.Slider(0.0, 20.0, step=0.5, value=8.0, label="Guidance Weight")
            scheduler = gr.Slider(1.0, 10.0, step=0.5, value=5.0, label="Scheduler Scale")
            btn = gr.Button("Generate", variant="primary")

        with gr.Column(scale=3):
            # Note: image mode returns a PNG path, which a gr.Video component
            # will not render; in that mode, rely on the status text below.
            out_video = gr.Video(label="Output")
            status = gr.Textbox(label="Status", lines=6)

    btn.click(fn=generate_output,
              inputs=[prompt, mode, duration, width, height, steps, guidance, scheduler],
              outputs=[out_video, status])
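# Optional hardening (a judgment call): enable Gradio's request queue so that
# concurrent users serialize GPU work instead of colliding on one device.
# demo.queue(max_size=8)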
# ---------------------------
# Launch
# ---------------------------
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))