from utils.wan_wrapper import WanDiffusionWrapper
from utils.scheduler import SchedulerInterface
from typing import List, Optional
import torch
import torch.distributed as dist


class RollingForcingTrainingPipeline:
    def __init__(self,
                 denoising_step_list: List[int],
                 scheduler: SchedulerInterface,
                 generator: WanDiffusionWrapper,
                 num_frame_per_block=3,
                 independent_first_frame: bool = False,
                 same_step_across_blocks: bool = False,
                 last_step_only: bool = False,
                 num_max_frames: int = 21,
                 context_noise: int = 0,
                 **kwargs):
        super().__init__()
        self.scheduler = scheduler
        self.generator = generator
        self.denoising_step_list = denoising_step_list
        if self.denoising_step_list[-1] == 0:
            self.denoising_step_list = self.denoising_step_list[:-1]  # remove the zero timestep for inference

        # Wan-specific hyperparameters
        self.num_transformer_blocks = 30
        self.frame_seq_length = 1560

        self.num_frame_per_block = num_frame_per_block
        self.context_noise = context_noise
        self.i2v = False
        self.kv_cache_clean = None
        self.kv_cache2 = None
        self.independent_first_frame = independent_first_frame
        self.same_step_across_blocks = same_step_across_blocks
        self.last_step_only = last_step_only
        self.kv_cache_size = num_max_frames * self.frame_seq_length

    def generate_and_sync_list(self, num_blocks, num_denoising_steps, device):
        rank = dist.get_rank() if dist.is_initialized() else 0

        if rank == 0:
            # Generate random indices
            indices = torch.randint(
                low=0,
                high=num_denoising_steps,
                size=(num_blocks,),
                device=device
            )
            if self.last_step_only:
                indices = torch.ones_like(indices) * (num_denoising_steps - 1)
        else:
            indices = torch.empty(num_blocks, dtype=torch.long, device=device)

        if dist.is_initialized():
            dist.broadcast(indices, src=0)  # Broadcast the random indices to all ranks
        return indices.tolist()

    def generate_list(self, num_blocks, num_denoising_steps, device):
        # Generate random indices
        indices = torch.randint(
            low=0,
            high=num_denoising_steps,
            size=(num_blocks,),
            device=device
        )
        if self.last_step_only:
            indices = torch.ones_like(indices) * (num_denoising_steps - 1)
        return indices.tolist()

    def inference_with_rolling_forcing(
        self,
        noise: torch.Tensor,
        initial_latent: Optional[torch.Tensor] = None,
        return_sim_step: bool = False,
        **conditional_dict
    ) -> torch.Tensor:
        batch_size, num_frames, num_channels, height, width = noise.shape
        if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
            # If the first frame is independent and the first frame is provided, then the number of frames in the
            # noise should still be a multiple of num_frame_per_block
            assert num_frames % self.num_frame_per_block == 0
            num_blocks = num_frames // self.num_frame_per_block
        else:
            # Using a [1, 4, 4, 4, 4, 4, ...] model to generate a video without image conditioning
            assert (num_frames - 1) % self.num_frame_per_block == 0
            num_blocks = (num_frames - 1) // self.num_frame_per_block
        num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
        num_output_frames = num_frames + num_input_frames  # add the initial latent frames
        output = torch.zeros(
            [batch_size, num_output_frames, num_channels, height, width],
            device=noise.device,
            dtype=noise.dtype
        )

        # Step 1: Initialize KV cache to all zeros
        self._initialize_kv_cache(
            batch_size=batch_size, dtype=noise.dtype, device=noise.device
        )
        self._initialize_crossattn_cache(
            batch_size=batch_size, dtype=noise.dtype, device=noise.device
        )

        # Implementing rolling forcing: construct the rolling windows over frame blocks
        num_denoising_steps = len(self.denoising_step_list)
        rolling_window_length_blocks = num_denoising_steps
        window_start_blocks = []
        window_end_blocks = []
        window_num = num_blocks + rolling_window_length_blocks - 1
        for window_index in range(window_num):
            start_block = max(0, window_index - rolling_window_length_blocks + 1)
            end_block = min(num_blocks - 1, window_index)
            window_start_blocks.append(start_block)
            window_end_blocks.append(end_block)
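        # Illustrative window schedule (assumed values, not from the original code): with
        # num_blocks = 7 and a 4-step denoising_step_list, rolling_window_length_blocks = 4 and
        # window_num = 10, so the (start_block, end_block) pairs are:
        #   (0,0) (0,1) (0,2) (0,3) (1,4) (2,5) (3,6) (4,6) (5,6) (6,6)
        # i.e. every block is visited by exactly 4 consecutive windows, once per noise level.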

        # exit_flag indicates the window at which the model will backpropagate gradients.
        exit_flag = torch.randint(high=rolling_window_length_blocks, device=noise.device, size=())
        start_gradient_frame_index = num_output_frames - 21

        # Initialize the noisy cache
        noisy_cache = torch.zeros(
            [batch_size, num_output_frames, num_channels, height, width],
            device=noise.device,
            dtype=noise.dtype
        )

        # Initialize the denoising timesteps, shared across windows
        shared_timestep = torch.ones(
            [batch_size, rolling_window_length_blocks * self.num_frame_per_block],
            device=noise.device,
            dtype=torch.float32)
        for index, current_timestep in enumerate(reversed(self.denoising_step_list)):  # from clean to noisy
            shared_timestep[:, index * self.num_frame_per_block:(index + 1) * self.num_frame_per_block] *= current_timestep
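        # Layout of shared_timestep over a full window (illustrative values, assuming
        # denoising_step_list = [1000, 750, 500, 250] and num_frame_per_block = 3):
        #   frames 0-2  -> 250   (oldest block in the window, nearly clean)
        #   frames 3-5  -> 500
        #   frames 6-8  -> 750
        #   frames 9-11 -> 1000  (newest block, pure noise)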

        # Denoising loop with rolling forcing
        for window_index in range(window_num):
            start_block = window_start_blocks[window_index]
            end_block = window_end_blocks[window_index]  # inclusive
            current_start_frame = start_block * self.num_frame_per_block
            current_end_frame = (end_block + 1) * self.num_frame_per_block  # exclusive
            current_num_frames = current_end_frame - current_start_frame

            # noisy_input: previously denoised noisy frames plus fresh noise; only the last block is pure noise
            if current_num_frames == rolling_window_length_blocks * self.num_frame_per_block or current_start_frame == 0:
                noisy_input = torch.cat([
                    noisy_cache[:, current_start_frame:current_end_frame - self.num_frame_per_block],
                    noise[:, current_end_frame - self.num_frame_per_block:current_end_frame]
                ], dim=1)
            else:  # at the end of the video, no new block enters the window
                noisy_input = noisy_cache[:, current_start_frame:current_end_frame].clone()

            # Select the denoising timesteps for this window
            if current_num_frames == rolling_window_length_blocks * self.num_frame_per_block:
                current_timestep = shared_timestep
            elif current_start_frame == 0:
                current_timestep = shared_timestep[:, -current_num_frames:]
            elif current_end_frame == num_frames:
                current_timestep = shared_timestep[:, :current_num_frames]
            else:
                raise ValueError(
                    "The window must either span rolling_window_length_blocks * num_frame_per_block frames "
                    "or be a truncated window at the start or end of the video."
                )

            require_grad = window_index % rolling_window_length_blocks == exit_flag
            if current_end_frame <= start_gradient_frame_index:
                require_grad = False

            # Call the DiT generator
            if not require_grad:
                with torch.no_grad():
                    _, denoised_pred = self.generator(
                        noisy_image_or_video=noisy_input,
                        conditional_dict=conditional_dict,
                        timestep=current_timestep,
                        kv_cache=self.kv_cache_clean,
                        crossattn_cache=self.crossattn_cache,
                        current_start=current_start_frame * self.frame_seq_length
                    )
            else:
                _, denoised_pred = self.generator(
                    noisy_image_or_video=noisy_input,
                    conditional_dict=conditional_dict,
                    timestep=current_timestep,
                    kv_cache=self.kv_cache_clean,
                    crossattn_cache=self.crossattn_cache,
                    current_start=current_start_frame * self.frame_seq_length
                )

            output[:, current_start_frame:current_end_frame] = denoised_pred

            # Update noisy_cache, which is detached from the computation graph
            with torch.no_grad():
                for block_idx in range(start_block, end_block + 1):
                    block_time_step = current_timestep[
                        :, (block_idx - start_block) * self.num_frame_per_block:
                        (block_idx - start_block + 1) * self.num_frame_per_block].mean().item()
                    matches = torch.abs(self.denoising_step_list - block_time_step) < 1e-4
                    block_timestep_index = torch.nonzero(matches, as_tuple=True)[0]
                    if block_timestep_index == len(self.denoising_step_list) - 1:
                        continue
                    next_timestep = self.denoising_step_list[block_timestep_index + 1].to(noise.device)
                    renoised = self.scheduler.add_noise(
                        denoised_pred.flatten(0, 1),
                        torch.randn_like(denoised_pred.flatten(0, 1)),
                        next_timestep * torch.ones(
                            [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
                    ).unflatten(0, denoised_pred.shape[:2])
                    noisy_cache[:, block_idx * self.num_frame_per_block:(block_idx + 1) * self.num_frame_per_block] = \
                        renoised[:, (block_idx - start_block) * self.num_frame_per_block:
                                 (block_idx - start_block + 1) * self.num_frame_per_block]
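            # Each block that has not yet reached the final denoising step is re-noised to the
            # next (lower-noise) timestep in denoising_step_list, so that when the window slides
            # forward the block re-enters the model at its next noise level; blocks already at the
            # last step keep their current noisy_cache entry and are effectively finished.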

            # Rerun with the context timestep to update the clean cache, also detached from the computation graph
            with torch.no_grad():
                context_timestep = torch.ones_like(current_timestep) * self.context_noise
                # # add context noise
                # denoised_pred = self.scheduler.add_noise(
                #     denoised_pred.flatten(0, 1),
                #     torch.randn_like(denoised_pred.flatten(0, 1)),
                #     context_timestep * torch.ones(
                #         [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
                # ).unflatten(0, denoised_pred.shape[:2])

                # Only cache the first block of the window
                denoised_pred = denoised_pred[:, :self.num_frame_per_block]
                context_timestep = context_timestep[:, :self.num_frame_per_block]
                self.generator(
                    noisy_image_or_video=denoised_pred,
                    conditional_dict=conditional_dict,
                    timestep=context_timestep,
                    kv_cache=self.kv_cache_clean,
                    crossattn_cache=self.crossattn_cache,
                    current_start=current_start_frame * self.frame_seq_length,
                    updating_cache=True,
                )
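            # Only the first (oldest) block of the window is written into the clean KV cache at
            # this window's current_start; while a block remains the oldest one across the early
            # windows, its cache entry is simply overwritten with the newer, cleaner prediction.
            # Gradients never flow through this cache update.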

        # Return the denoised timestep range (not used by this pipeline)
        denoised_timestep_from, denoised_timestep_to = None, None

        return output, denoised_timestep_from, denoised_timestep_to

    def inference_with_self_forcing(
        self,
        noise: torch.Tensor,
        initial_latent: Optional[torch.Tensor] = None,
        return_sim_step: bool = False,
        **conditional_dict
    ) -> torch.Tensor:
        batch_size, num_frames, num_channels, height, width = noise.shape
        if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
            # If the first frame is independent and the first frame is provided, then the number of frames in the
            # noise should still be a multiple of num_frame_per_block
            assert num_frames % self.num_frame_per_block == 0
            num_blocks = num_frames // self.num_frame_per_block
        else:
            # Using a [1, 4, 4, 4, 4, 4, ...] model to generate a video without image conditioning
            assert (num_frames - 1) % self.num_frame_per_block == 0
            num_blocks = (num_frames - 1) // self.num_frame_per_block
        num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
        num_output_frames = num_frames + num_input_frames  # add the initial latent frames
        output = torch.zeros(
            [batch_size, num_output_frames, num_channels, height, width],
            device=noise.device,
            dtype=noise.dtype
        )

        # Step 1: Initialize KV cache to all zeros
        self._initialize_kv_cache(
            batch_size=batch_size, dtype=noise.dtype, device=noise.device
        )
        self._initialize_crossattn_cache(
            batch_size=batch_size, dtype=noise.dtype, device=noise.device
        )

        # if self.kv_cache_clean is None:
        #     self._initialize_kv_cache(
        #         batch_size=batch_size,
        #         dtype=noise.dtype,
        #         device=noise.device,
        #     )
        #     self._initialize_crossattn_cache(
        #         batch_size=batch_size,
        #         dtype=noise.dtype,
        #         device=noise.device
        #     )
        # else:
        #     # reset cross attn cache
        #     for block_index in range(self.num_transformer_blocks):
        #         self.crossattn_cache[block_index]["is_init"] = False
        #     # reset kv cache
        #     for block_index in range(len(self.kv_cache_clean)):
        #         self.kv_cache_clean[block_index]["global_end_index"] = torch.tensor(
        #             [0], dtype=torch.long, device=noise.device)
        #         self.kv_cache_clean[block_index]["local_end_index"] = torch.tensor(
        #             [0], dtype=torch.long, device=noise.device)

        # Step 2: Cache context feature
        current_start_frame = 0
        if initial_latent is not None:
            timestep = torch.ones([batch_size, 1], device=noise.device, dtype=torch.int64) * 0
            # Assume num_input_frames is 1 + self.num_frame_per_block * num_input_blocks
            output[:, :1] = initial_latent
            with torch.no_grad():
                self.generator(
                    noisy_image_or_video=initial_latent,
                    conditional_dict=conditional_dict,
                    timestep=timestep * 0,
                    kv_cache=self.kv_cache_clean,
                    crossattn_cache=self.crossattn_cache,
                    current_start=current_start_frame * self.frame_seq_length
                )
            current_start_frame += 1

        # Step 3: Temporal denoising loop
        all_num_frames = [self.num_frame_per_block] * num_blocks
        if self.independent_first_frame and initial_latent is None:
            all_num_frames = [1] + all_num_frames
        num_denoising_steps = len(self.denoising_step_list)
        exit_flags = self.generate_and_sync_list(len(all_num_frames), num_denoising_steps, device=noise.device)
        start_gradient_frame_index = num_output_frames - 21
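        # exit_flags gives, per block, the index of the denoising step at which the spatial loop
        # below exits (and may backpropagate); start_gradient_frame_index restricts gradients to
        # the last 21 output latent frames, which appears to match the default num_max_frames
        # KV-cache window.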

        # for block_index in range(num_blocks):
        for block_index, current_num_frames in enumerate(all_num_frames):
            noisy_input = noise[
                :, current_start_frame - num_input_frames:current_start_frame + current_num_frames - num_input_frames]

            # Step 3.1: Spatial denoising loop
            for index, current_timestep in enumerate(self.denoising_step_list):
                if self.same_step_across_blocks:
                    exit_flag = (index == exit_flags[0])
                else:
                    # Only backprop at the randomly selected timestep (consistent across all ranks)
                    exit_flag = (index == exit_flags[block_index])
                timestep = torch.ones(
                    [batch_size, current_num_frames],
                    device=noise.device,
                    dtype=torch.int64) * current_timestep

                if not exit_flag:
                    with torch.no_grad():
                        _, denoised_pred = self.generator(
                            noisy_image_or_video=noisy_input,
                            conditional_dict=conditional_dict,
                            timestep=timestep,
                            kv_cache=self.kv_cache_clean,
                            crossattn_cache=self.crossattn_cache,
                            current_start=current_start_frame * self.frame_seq_length
                        )
                        next_timestep = self.denoising_step_list[index + 1]
                        noisy_input = self.scheduler.add_noise(
                            denoised_pred.flatten(0, 1),
                            torch.randn_like(denoised_pred.flatten(0, 1)),
                            next_timestep * torch.ones(
                                [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
                        ).unflatten(0, denoised_pred.shape[:2])
                else:
                    # Exit step: produce the actual output for this block
                    # with torch.set_grad_enabled(current_start_frame >= start_gradient_frame_index):
                    if current_start_frame < start_gradient_frame_index:
                        with torch.no_grad():
                            _, denoised_pred = self.generator(
                                noisy_image_or_video=noisy_input,
                                conditional_dict=conditional_dict,
                                timestep=timestep,
                                kv_cache=self.kv_cache_clean,
                                crossattn_cache=self.crossattn_cache,
                                current_start=current_start_frame * self.frame_seq_length
                            )
                    else:
                        _, denoised_pred = self.generator(
                            noisy_image_or_video=noisy_input,
                            conditional_dict=conditional_dict,
                            timestep=timestep,
                            kv_cache=self.kv_cache_clean,
                            crossattn_cache=self.crossattn_cache,
                            current_start=current_start_frame * self.frame_seq_length
                        )
                    break

            # Step 3.2: record the model's output
            output[:, current_start_frame:current_start_frame + current_num_frames] = denoised_pred

            # Step 3.3: rerun with the context timestep to update the cache
            context_timestep = torch.ones_like(timestep) * self.context_noise
            # add context noise
            denoised_pred = self.scheduler.add_noise(
                denoised_pred.flatten(0, 1),
                torch.randn_like(denoised_pred.flatten(0, 1)),
                context_timestep * torch.ones(
                    [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
            ).unflatten(0, denoised_pred.shape[:2])
            with torch.no_grad():
                self.generator(
                    noisy_image_or_video=denoised_pred,
                    conditional_dict=conditional_dict,
                    timestep=context_timestep,
                    kv_cache=self.kv_cache_clean,
                    crossattn_cache=self.crossattn_cache,
                    current_start=current_start_frame * self.frame_seq_length,
                    updating_cache=True,
                )

            # Step 3.4: update the start and end frame indices
            current_start_frame += current_num_frames

        # Step 3.5: Return the denoised timestep range
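        # The expressions below map the selected entries of denoising_step_list onto indices of
        # the scheduler's timestep schedule (assuming `self.scheduler.timesteps` is a 1000-entry
        # tensor), giving the noise interval that was actually simulated for the exit step, for
        # downstream use by the caller.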
        if not self.same_step_across_blocks:
            denoised_timestep_from, denoised_timestep_to = None, None
        elif exit_flags[0] == len(self.denoising_step_list) - 1:
            denoised_timestep_to = 0
            denoised_timestep_from = 1000 - torch.argmin(
                (self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0]].cuda()).abs(), dim=0).item()
        else:
            denoised_timestep_to = 1000 - torch.argmin(
                (self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0] + 1].cuda()).abs(), dim=0).item()
            denoised_timestep_from = 1000 - torch.argmin(
                (self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0]].cuda()).abs(), dim=0).item()

        if return_sim_step:
            return output, denoised_timestep_from, denoised_timestep_to, exit_flags[0] + 1

        return output, denoised_timestep_from, denoised_timestep_to

    def _initialize_kv_cache(self, batch_size, dtype, device):
        """
        Initialize a per-GPU KV cache for the Wan model.
        """
        kv_cache_clean = []
        for _ in range(self.num_transformer_blocks):
            kv_cache_clean.append({
                "k": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
                "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                "local_end_index": torch.tensor([0], dtype=torch.long, device=device)
            })
        self.kv_cache_clean = kv_cache_clean  # always store the clean cache

    def _initialize_crossattn_cache(self, batch_size, dtype, device):
        """
        Initialize a per-GPU cross-attention cache for the Wan model.
        """
        crossattn_cache = []
        for _ in range(self.num_transformer_blocks):
            crossattn_cache.append({
                "k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
                "is_init": False
            })
        self.crossattn_cache = crossattn_cache
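

# Minimal usage sketch (illustrative only: the scheduler/generator construction, the conditioning
# dict, the step schedule, and the latent shape below are assumptions, not part of this file):
#
#     pipeline = RollingForcingTrainingPipeline(
#         denoising_step_list=torch.tensor([1000, 750, 500, 250]),  # assumed step schedule
#         scheduler=scheduler,    # a SchedulerInterface implementation providing add_noise()
#         generator=generator,    # a WanDiffusionWrapper exposing the kv_cache call signature used above
#         num_frame_per_block=3,
#     )
#     noise = torch.randn(1, 21, 16, 60, 104, device="cuda", dtype=torch.bfloat16)  # assumed latent shape
#     latents, t_from, t_to = pipeline.inference_with_rolling_forcing(noise, **text_conditioning)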