Files changed (1)
  1. app.py +168 -42
app.py CHANGED
@@ -10,6 +10,8 @@ from PIL import Image
 import random
 import numpy as np
 import spaces
+import cv2
+import tempfile
 
 import wan
 from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
@@ -61,7 +63,52 @@ pipeline = wan.WanTI2V(
 )
 print("Pipeline initialized and ready.")
 
-# --- Helper Functions (from Wan 2.1 Fast demo) ---
+# --- Helper Functions ---
+
+def extract_first_frame_from_video(video_path):
+    """
+    Extract the first frame from a video file.
+
+    Args:
+        video_path: Path to the video file
+
+    Returns:
+        PIL Image of the first frame, or None if extraction fails
+    """
+    try:
+        cap = cv2.VideoCapture(video_path)
+        ret, frame = cap.read()
+        cap.release()
+
+        if ret:
+            # Convert BGR to RGB
+            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            return Image.fromarray(frame_rgb)
+        return None
+    except Exception as e:
+        print(f"Error extracting frame from video: {e}")
+        return None
+
+def get_video_dimensions(video_path):
+    """
+    Get the dimensions of a video file.
+
+    Args:
+        video_path: Path to the video file
+
+    Returns:
+        Tuple of (width, height) or None if extraction fails
+    """
+    try:
+        cap = cv2.VideoCapture(video_path)
+        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        cap.release()
+        return width, height
+    except Exception as e:
+        print(f"Error getting video dimensions: {e}")
+        return None
+
 def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
                                   min_slider_h, max_slider_h,
                                   min_slider_w, max_slider_w,
@@ -83,38 +130,65 @@ def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
 
     return new_h, new_w
 
-def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
+def handle_media_upload_for_dims_wan(uploaded_media, current_h_val, current_w_val):
     """
-    Handle image upload and calculate appropriate dimensions for video generation.
+    Handle image or video upload and calculate appropriate dimensions.
 
     Args:
-        uploaded_pil_image: The uploaded image (PIL Image or numpy array)
+        uploaded_media: The uploaded file (can be image or video path)
         current_h_val: Current height slider value
         current_w_val: Current width slider value
 
     Returns:
-        Tuple of gr.update objects for height and width sliders
+        Tuple of (gr.update for height, gr.update for width, first frame as numpy array or None)
     """
-    if uploaded_pil_image is None:
-        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
+    if uploaded_media is None:
+        return (gr.update(value=DEFAULT_H_SLIDER_VALUE),
+                gr.update(value=DEFAULT_W_SLIDER_VALUE),
+                None)
+
     try:
-        # Convert numpy array to PIL Image if needed
-        if hasattr(uploaded_pil_image, 'shape'): # numpy array
-            pil_image = Image.fromarray(uploaded_pil_image).convert("RGB")
-        else: # already PIL Image
-            pil_image = uploaded_pil_image
-
+        pil_image = None
+
+        # Check if it's a video file
+        if isinstance(uploaded_media, str) and uploaded_media.lower().endswith(('.mp4', '.avi', '.mov', '.mkv', '.webm')):
+            # Extract first frame from video
+            pil_image = extract_first_frame_from_video(uploaded_media)
+            if pil_image is None:
+                gr.Warning("Could not extract frame from video")
+                return (gr.update(value=DEFAULT_H_SLIDER_VALUE),
+                        gr.update(value=DEFAULT_W_SLIDER_VALUE),
+                        None)
+        else:
+            # Handle as image
+            if hasattr(uploaded_media, 'shape'): # numpy array
+                pil_image = Image.fromarray(uploaded_media).convert("RGB")
+            elif isinstance(uploaded_media, str): # file path
+                pil_image = Image.open(uploaded_media).convert("RGB")
+            else: # PIL Image
+                pil_image = uploaded_media
+
+        # Calculate dimensions
         new_h, new_w = _calculate_new_dimensions_wan(
             pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
            SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
            DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
        )
-        return gr.update(value=new_h), gr.update(value=new_w)
+
+        # Convert PIL image to numpy array for display
+        display_image = np.array(pil_image)
+
+        return gr.update(value=new_h), gr.update(value=new_w), display_image
+
     except Exception as e:
-        gr.Warning("Error attempting to calculate new dimensions")
-        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
+        print(f"Error in handle_media_upload_for_dims_wan: {e}")
+        gr.Warning("Error processing uploaded file")
+        return (gr.update(value=DEFAULT_H_SLIDER_VALUE),
+                gr.update(value=DEFAULT_W_SLIDER_VALUE),
+                None)
 
-def get_duration(image,
+def get_duration(video_input,
+                 image_preview,
                  prompt,
                  height,
                  width,
@@ -130,7 +204,8 @@ def get_duration(image,
 # --- 2. Gradio Inference Function ---
 @spaces.GPU(duration=get_duration)
 def generate_video(
-    image,
+    video_input,
+    image_preview,
     prompt,
     height,
    width,
@@ -142,10 +217,11 @@
     progress=gr.Progress(track_tqdm=True)
 ):
     """
-    Generate a video from text prompt and optional image using the Wan 2.2 TI2V model.
+    Generate a video from text prompt and optional image/video using the Wan 2.2 TI2V model.
 
     Args:
-        image: Optional input image (numpy array) for image-to-video generation
+        video_input: Optional input video file path
+        image_preview: Preview image (numpy array) extracted from video or uploaded image
         prompt: Text prompt describing the desired video
         height: Target video height in pixels
         width: Target video width in pixels
@@ -167,9 +243,21 @@
     target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
 
     input_image = None
-    if image is not None:
-        input_image = Image.fromarray(image).convert("RGB")
-        # Resize image to match target dimensions
+
+    # Process video input if provided
+    if video_input is not None:
+        if isinstance(video_input, str) and video_input.lower().endswith(('.mp4', '.avi', '.mov', '.mkv', '.webm')):
+            input_image = extract_first_frame_from_video(video_input)
+        else:
+            # Fallback to image preview
+            if image_preview is not None:
+                input_image = Image.fromarray(image_preview).convert("RGB")
+    elif image_preview is not None:
+        # Use image preview if no video input
+        input_image = Image.fromarray(image_preview).convert("RGB")
+
+    # Resize image to match target dimensions if we have an input image
+    if input_image is not None:
         input_image = input_image.resize((target_w, target_h))
 
     # Calculate number of frames based on duration
@@ -183,7 +271,7 @@
         img=input_image, # Pass None for T2V, Image for I2V
         size=SIZE_CONFIGS.get(size_str, (target_h, target_w)),
         max_area=MAX_AREA_CONFIGS.get(size_str, target_h * target_w),
-        frame_num=num_frames, # Use calculated frames instead of cfg.frame_num
+        frame_num=num_frames,
         shift=shift,
         sample_solver='unipc',
         sampling_steps=int(sampling_steps),
@@ -206,16 +294,29 @@
 
 
 # --- 3. Gradio Interface ---
-css = ".gradio-container {max-width: 1100px !important; margin: 0 auto} #output_video {height: 500px;} #input_image {height: 500px;}"
+css = ".gradio-container {max-width: 1200px !important; margin: 0 auto} #output_video {height: 500px;} #image_preview {height: 400px;}"
 
 with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
-    gr.Markdown("# Wan 2.2 TI2V 5B")
-    gr.Markdown("generate high quality videos using **Wan 2.2 5B Text-Image-to-Video model**,[[model]](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B),[[paper]](https://arxiv.org/abs/2503.20314)")
+    gr.Markdown("# Wan 2.2 TI2V 5B - Video/Image to Video")
+    gr.Markdown("Generate high quality videos using **Wan 2.2 5B Text-Image-to-Video model** with support for video input. [[model]](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B), [[paper]](https://arxiv.org/abs/2503.20314)")
 
     with gr.Row():
         with gr.Column(scale=2):
-            image_input = gr.Image(type="numpy", label="Optional (blank = text-to-image)", elem_id="input_image")
-            prompt_input = gr.Textbox(label="Prompt", value="A beautiful waterfall in a lush jungle, cinematic.", lines=3)
+            video_input = gr.Video(
+                label="Upload Video or Image (optional - blank for text-to-video)",
+                sources=["upload"],
+            )
+            image_preview = gr.Image(
+                type="numpy",
+                label="Preview (first frame will be extracted from video)",
+                elem_id="image_preview",
+                interactive=False
+            )
+            prompt_input = gr.Textbox(
+                label="Prompt",
+                value="A beautiful waterfall in a lush jungle, cinematic.",
+                lines=3
+            )
             duration_input = gr.Slider(
                 minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1),
                 maximum=round(MAX_FRAMES_MODEL/FIXED_FPS, 1),
@@ -227,8 +328,20 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
 
             with gr.Accordion("Advanced Settings", open=False):
                 with gr.Row():
-                    height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height (multiple of {MOD_VALUE})")
-                    width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width (multiple of {MOD_VALUE})")
+                    height_input = gr.Slider(
+                        minimum=SLIDER_MIN_H,
+                        maximum=SLIDER_MAX_H,
+                        step=MOD_VALUE,
+                        value=DEFAULT_H_SLIDER_VALUE,
+                        label=f"Output Height (multiple of {MOD_VALUE})"
+                    )
+                    width_input = gr.Slider(
+                        minimum=SLIDER_MIN_W,
+                        maximum=SLIDER_MAX_W,
+                        step=MOD_VALUE,
+                        value=DEFAULT_W_SLIDER_VALUE,
+                        label=f"Output Width (multiple of {MOD_VALUE})"
+                    )
                 steps_input = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=38, step=1)
                 scale_input = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=cfg.sample_guide_scale, step=0.1)
                 shift_input = gr.Slider(label="Sample Shift", minimum=1.0, maximum=20.0, value=cfg.sample_shift, step=0.1)
@@ -238,17 +351,19 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
            video_output = gr.Video(label="Generated Video", elem_id="output_video")
            run_button = gr.Button("Generate Video", variant="primary")
 
-    # Add image upload handler
-    image_input.upload(
-        fn=handle_image_upload_for_dims_wan,
-        inputs=[image_input, height_input, width_input],
-        outputs=[height_input, width_input]
+    # Add video/image upload handler
+    video_input.upload(
+        fn=handle_media_upload_for_dims_wan,
+        inputs=[video_input, height_input, width_input],
+        outputs=[height_input, width_input, image_preview]
     )
 
-    image_input.clear(
-        fn=handle_image_upload_for_dims_wan,
-        inputs=[image_input, height_input, width_input],
-        outputs=[height_input, width_input]
+    video_input.clear(
+        fn=lambda: (gr.update(value=DEFAULT_H_SLIDER_VALUE),
+                    gr.update(value=DEFAULT_W_SLIDER_VALUE),
+                    None),
+        inputs=[],
+        outputs=[height_input, width_input, image_preview]
     )
 
     example_image_path = os.path.join(os.path.dirname(__file__), "examples/i2v_input.JPG")
@@ -258,7 +373,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
            [None, "A cinematic shot of a boat sailing on a calm sea at sunset.", 704, 1280, 2.0],
            [None, "Drone footage flying over a futuristic city with flying cars.", 704, 1280, 2.0],
        ],
-        inputs=[image_input, prompt_input, height_input, width_input, duration_input],
+        inputs=[video_input, prompt_input, height_input, width_input, duration_input],
        outputs=video_output,
        fn=generate_video,
        cache_examples="lazy",
@@ -266,7 +381,18 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
 
    run_button.click(
        fn=generate_video,
-        inputs=[image_input, prompt_input, height_input, width_input, duration_input, steps_input, scale_input, shift_input, seed_input],
+        inputs=[
+            video_input,
+            image_preview,
+            prompt_input,
+            height_input,
+            width_input,
+            duration_input,
+            steps_input,
+            scale_input,
+            shift_input,
+            seed_input
+        ],
        outputs=video_output
    )
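The new OpenCV path can be sanity-checked outside the Space. Below is a minimal sketch of the frame-extraction plus dimension-snapping flow; the file name `sample.mp4` and `MOD_VALUE = 32` are placeholders I'm assuming for illustration (the real `MOD_VALUE` is defined earlier in app.py and not shown in this diff):

```python
# Standalone sketch of the PR's frame-extraction + dimension-snapping flow.
# Assumptions: "sample.mp4" is any local video file; MOD_VALUE = 32 is a
# placeholder for the app's real constant (defined outside this diff).
import cv2
from PIL import Image

MOD_VALUE = 32  # assumed placeholder

def extract_first_frame_from_video(video_path):
    # Same logic as the helper added in this PR
    cap = cv2.VideoCapture(video_path)
    ret, frame = cap.read()
    cap.release()
    if ret:
        # OpenCV decodes frames as BGR; convert to RGB for PIL
        return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    return None

frame = extract_first_frame_from_video("sample.mp4")
if frame is not None:
    # Snap both sides down to multiples of MOD_VALUE, mirroring the
    # target_h/target_w computation in generate_video()
    w, h = frame.size
    target_w = max(MOD_VALUE, (w // MOD_VALUE) * MOD_VALUE)
    target_h = max(MOD_VALUE, (h // MOD_VALUE) * MOD_VALUE)
    print(f"first frame {w}x{h} -> model input {target_w}x{target_h}")
```

Note that the upload handler goes further than this sketch: it rescales toward `NEW_FORMULA_MAX_AREA` via `_calculate_new_dimensions_wan` before updating the sliders, while the snapping above only mirrors the final resize in `generate_video`.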