# RollingForcing/pipeline/rolling_forcing_inference.py
from typing import List, Optional
import torch
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
class CausalInferencePipeline(torch.nn.Module):
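    """
    Causal (block-autoregressive) video diffusion inference pipeline with rolling forcing:
    a sliding window of latent blocks is denoised jointly, with each block held at a
    different noise level, while a KV cache of clean context frames conditions the blocks
    that follow.
    """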
def __init__(
self,
args,
device,
generator=None,
text_encoder=None,
vae=None
):
super().__init__()
# Step 1: Initialize all models
self.generator = WanDiffusionWrapper(
**getattr(args, "model_kwargs", {}), is_causal=True) if generator is None else generator
self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
self.vae = WanVAEWrapper() if vae is None else vae
        # Step 2: Initialize all causal hyperparameters
self.scheduler = self.generator.get_scheduler()
self.denoising_step_list = torch.tensor(
args.denoising_step_list, dtype=torch.long)
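        # Optionally warp the integer denoising steps onto the scheduler's timestep schedule:
        # step s is replaced by the scheduler timestep at index 1000 - s (with 0 appended as
        # the final, fully clean step)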
if args.warp_denoising_step:
timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
        self.num_transformer_blocks = 30  # number of transformer blocks in the Wan DiT
        self.frame_seq_length = 1560  # number of tokens per latent frame
self.kv_cache_clean = None
self.args = args
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
self.independent_first_frame = args.independent_first_frame
self.local_attn_size = self.generator.model.local_attn_size
print(f"KV inference with {self.num_frame_per_block} frames per block")
if self.num_frame_per_block > 1:
self.generator.model.num_frame_per_block = self.num_frame_per_block
def inference_rolling_forcing(
self,
noise: torch.Tensor,
text_prompts: List[str],
initial_latent: Optional[torch.Tensor] = None,
return_latents: bool = False,
profile: bool = False
) -> torch.Tensor:
"""
Perform inference on the given noise and text prompts.
Inputs:
noise (torch.Tensor): The input noise tensor of shape
(batch_size, num_output_frames, num_channels, height, width).
text_prompts (List[str]): The list of text prompts.
initial_latent (torch.Tensor): The initial latent tensor of shape
(batch_size, num_input_frames, num_channels, height, width).
If num_input_frames is 1, perform image to video.
If num_input_frames is greater than 1, perform video extension.
            return_latents (bool): Whether to also return the raw latents.
            profile (bool): Whether to print CUDA timing for the initialization,
                denoising, and VAE decoding stages.
        Outputs:
            video (torch.Tensor): The generated video tensor of shape
                (batch_size, num_output_frames, num_channels, height, width).
                It is normalized to be in the range [0, 1]. If return_latents
                is True, a tuple (video, latents) is returned instead.
"""
batch_size, num_frames, num_channels, height, width = noise.shape
if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
            # Either the first frame is not modeled independently, or it is independent but
            # already provided via initial_latent; in both cases the number of noise frames
            # must be a multiple of num_frame_per_block
assert num_frames % self.num_frame_per_block == 0
num_blocks = num_frames // self.num_frame_per_block
else:
# Using a [1, 4, 4, 4, 4, 4, ...] model to generate a video without image conditioning
assert (num_frames - 1) % self.num_frame_per_block == 0
num_blocks = (num_frames - 1) // self.num_frame_per_block
num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
num_output_frames = num_frames + num_input_frames # add the initial latent frames
conditional_dict = self.text_encoder(
text_prompts=text_prompts
)
output = torch.zeros(
[batch_size, num_output_frames, num_channels, height, width],
device=noise.device,
dtype=noise.dtype
)
# Set up profiling if requested
if profile:
init_start = torch.cuda.Event(enable_timing=True)
init_end = torch.cuda.Event(enable_timing=True)
diffusion_start = torch.cuda.Event(enable_timing=True)
diffusion_end = torch.cuda.Event(enable_timing=True)
vae_start = torch.cuda.Event(enable_timing=True)
vae_end = torch.cuda.Event(enable_timing=True)
block_times = []
block_start = torch.cuda.Event(enable_timing=True)
block_end = torch.cuda.Event(enable_timing=True)
init_start.record()
# Step 1: Initialize KV cache to all zeros
if self.kv_cache_clean is None:
self._initialize_kv_cache(
batch_size=batch_size,
dtype=noise.dtype,
device=noise.device
)
self._initialize_crossattn_cache(
batch_size=batch_size,
dtype=noise.dtype,
device=noise.device
)
else:
# reset cross attn cache
for block_index in range(self.num_transformer_blocks):
self.crossattn_cache[block_index]["is_init"] = False
# reset kv cache
for block_index in range(len(self.kv_cache_clean)):
self.kv_cache_clean[block_index]["global_end_index"] = torch.tensor(
[0], dtype=torch.long, device=noise.device)
self.kv_cache_clean[block_index]["local_end_index"] = torch.tensor(
[0], dtype=torch.long, device=noise.device)
        # Step 2: Cache context features from the initial latent frames (if provided)
        current_start_frame = 0
if initial_latent is not None:
timestep = torch.ones([batch_size, 1], device=noise.device, dtype=torch.int64) * 0
if self.independent_first_frame:
# Assume num_input_frames is 1 + self.num_frame_per_block * num_input_blocks
assert (num_input_frames - 1) % self.num_frame_per_block == 0
num_input_blocks = (num_input_frames - 1) // self.num_frame_per_block
output[:, :1] = initial_latent[:, :1]
self.generator(
noisy_image_or_video=initial_latent[:, :1],
conditional_dict=conditional_dict,
timestep=timestep * 0,
kv_cache=self.kv_cache_clean,
crossattn_cache=self.crossattn_cache,
current_start=current_start_frame * self.frame_seq_length,
)
current_start_frame += 1
else:
# Assume num_input_frames is self.num_frame_per_block * num_input_blocks
assert num_input_frames % self.num_frame_per_block == 0
num_input_blocks = num_input_frames // self.num_frame_per_block
for _ in range(num_input_blocks):
current_ref_latents = \
initial_latent[:, current_start_frame:current_start_frame + self.num_frame_per_block]
output[:, current_start_frame:current_start_frame + self.num_frame_per_block] = current_ref_latents
self.generator(
noisy_image_or_video=current_ref_latents,
conditional_dict=conditional_dict,
timestep=timestep * 0,
kv_cache=self.kv_cache_clean,
crossattn_cache=self.crossattn_cache,
current_start=current_start_frame * self.frame_seq_length,
)
current_start_frame += self.num_frame_per_block
if profile:
init_end.record()
torch.cuda.synchronize()
diffusion_start.record()
        # Rolling forcing: denoise a sliding window of blocks in which each block sits at a
        # different noise level (cleanest at the head of the window, pure noise at the tail).
        # The window slides forward by one block per iteration, so every block is denoised
        # through the full schedule before it leaves the window.
        # Construct the per-window block ranges.
num_denoising_steps = len(self.denoising_step_list)
rolling_window_length_blocks = num_denoising_steps
window_start_blocks = []
window_end_blocks = []
window_num = num_blocks + rolling_window_length_blocks - 1
for window_index in range(window_num):
start_block = max(0, window_index - rolling_window_length_blocks + 1)
end_block = min(num_blocks - 1, window_index)
window_start_blocks.append(start_block)
window_end_blocks.append(end_block)
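        # Worked example: with num_blocks = 5 and a 4-block window, window_num = 8 and the
        # (start_block, end_block) pairs are:
        #   window 0 -> (0, 0)   window 2 -> (0, 2)   window 4 -> (1, 4)   window 6 -> (3, 4)
        #   window 1 -> (0, 1)   window 3 -> (0, 3)   window 5 -> (2, 4)   window 7 -> (4, 4)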
        # Initialize the noisy cache: for each frame it holds the latest prediction,
        # re-noised to the timestep that frame will carry in the next window
noisy_cache = torch.zeros(
[batch_size, num_output_frames, num_channels, height, width],
device=noise.device,
dtype=noise.dtype
)
        # Initialize the per-frame denoising timesteps, shared across windows
shared_timestep = torch.ones(
[batch_size, rolling_window_length_blocks * self.num_frame_per_block],
device=noise.device,
dtype=torch.float32)
for index, current_timestep in enumerate(reversed(self.denoising_step_list)): # from clean to noisy
shared_timestep[:, index * self.num_frame_per_block:(index + 1) * self.num_frame_per_block] *= current_timestep
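        # For example, with denoising_step_list = [1000, 750, 500, 250] and num_frame_per_block = 3
        # (illustrative values), shared_timestep per frame is
        # [250, 250, 250, 500, 500, 500, 750, 750, 750, 1000, 1000, 1000]:
        # the oldest block in the window is the cleanest and the newest block is the noisiest.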
        # Step 3: Denoising loop with rolling forcing
for window_index in range(window_num):
if profile:
block_start.record()
print('window_index:', window_index)
start_block = window_start_blocks[window_index]
            end_block = window_end_blocks[window_index]  # inclusive
print(f"start_block: {start_block}, end_block: {end_block}")
current_start_frame = start_block * self.num_frame_per_block
            current_end_frame = (end_block + 1) * self.num_frame_per_block  # exclusive
current_num_frames = current_end_frame - current_start_frame
            # noisy_input: previously re-noised frames from the cache, plus pure noise for the
            # newest block (only the last block of the window starts from pure noise)
if current_num_frames == rolling_window_length_blocks * self.num_frame_per_block or current_start_frame == 0:
noisy_input = torch.cat([
noisy_cache[:, current_start_frame : current_end_frame - self.num_frame_per_block],
noise[:, current_end_frame - self.num_frame_per_block : current_end_frame ]
], dim=1)
else: # at the end of the video
noisy_input = noisy_cache[:, current_start_frame:current_end_frame]
            # Select the denoising timesteps for this window
if current_num_frames == rolling_window_length_blocks * self.num_frame_per_block:
current_timestep = shared_timestep
elif current_start_frame == 0:
current_timestep = shared_timestep[:,-current_num_frames:]
elif current_end_frame == num_frames:
current_timestep = shared_timestep[:,:current_num_frames]
else:
                raise ValueError("Each window must either span the full rolling window or be a partial window at the start or end of the video.")
# calling DiT
_, denoised_pred = self.generator(
noisy_image_or_video=noisy_input,
conditional_dict=conditional_dict,
timestep=current_timestep,
kv_cache=self.kv_cache_clean,
crossattn_cache=self.crossattn_cache,
current_start=current_start_frame * self.frame_seq_length
)
output[:, current_start_frame:current_end_frame] = denoised_pred
            # Update the noisy cache (detached from the computation graph): re-noise each block's
            # denoised prediction to its next, less noisy timestep for reuse in the following window
with torch.no_grad():
for block_idx in range(start_block, end_block + 1):
block_time_step = current_timestep[:,
(block_idx - start_block)*self.num_frame_per_block :
(block_idx - start_block+1)*self.num_frame_per_block].mean().item()
matches = torch.abs(self.denoising_step_list - block_time_step) < 1e-4
block_timestep_index = torch.nonzero(matches, as_tuple=True)[0]
if block_timestep_index == len(self.denoising_step_list) - 1:
continue
next_timestep = self.denoising_step_list[block_timestep_index + 1].to(noise.device)
noisy_cache[:, block_idx * self.num_frame_per_block:
(block_idx+1) * self.num_frame_per_block] = \
self.scheduler.add_noise(
denoised_pred.flatten(0, 1),
torch.randn_like(denoised_pred.flatten(0, 1)),
next_timestep * torch.ones(
[batch_size * current_num_frames], device=noise.device, dtype=torch.long)
).unflatten(0, denoised_pred.shape[:2])[:, (block_idx - start_block)*self.num_frame_per_block:
(block_idx - start_block+1)*self.num_frame_per_block]
            # Re-run the oldest block at the context-noise timestep to commit it to the clean
            # KV cache; this pass is also detached from the computation graph
with torch.no_grad():
context_timestep = torch.ones_like(current_timestep) * self.args.context_noise
# # add context noise
# denoised_pred = self.scheduler.add_noise(
# denoised_pred.flatten(0, 1),
# torch.randn_like(denoised_pred.flatten(0, 1)),
# context_timestep * torch.ones(
# [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
# ).unflatten(0, denoised_pred.shape[:2])
                # Only the first (oldest) block of the window is committed to the clean KV cache
denoised_pred = denoised_pred[:,:self.num_frame_per_block]
context_timestep = context_timestep[:,:self.num_frame_per_block]
self.generator(
noisy_image_or_video=denoised_pred,
conditional_dict=conditional_dict,
timestep=context_timestep,
kv_cache=self.kv_cache_clean,
crossattn_cache=self.crossattn_cache,
current_start=current_start_frame * self.frame_seq_length,
updating_cache=True,
)
if profile:
block_end.record()
torch.cuda.synchronize()
block_time = block_start.elapsed_time(block_end)
block_times.append(block_time)
if profile:
# End diffusion timing and synchronize CUDA
diffusion_end.record()
torch.cuda.synchronize()
diffusion_time = diffusion_start.elapsed_time(diffusion_end)
init_time = init_start.elapsed_time(init_end)
vae_start.record()
# Step 4: Decode the output
video = self.vae.decode_to_pixel(output, use_cache=False)
video = (video * 0.5 + 0.5).clamp(0, 1)
if profile:
# End VAE timing and synchronize CUDA
vae_end.record()
torch.cuda.synchronize()
vae_time = vae_start.elapsed_time(vae_end)
total_time = init_time + diffusion_time + vae_time
print("Profiling results:")
print(f" - Initialization/caching time: {init_time:.2f} ms ({100 * init_time / total_time:.2f}%)")
print(f" - Diffusion generation time: {diffusion_time:.2f} ms ({100 * diffusion_time / total_time:.2f}%)")
for i, block_time in enumerate(block_times):
print(f" - Block {i} generation time: {block_time:.2f} ms ({100 * block_time / diffusion_time:.2f}% of diffusion)")
print(f" - VAE decoding time: {vae_time:.2f} ms ({100 * vae_time / total_time:.2f}%)")
print(f" - Total time: {total_time:.2f} ms")
if return_latents:
return video, output
else:
return video
def _initialize_kv_cache(self, batch_size, dtype, device):
"""
Initialize a Per-GPU KV cache for the Wan model.
"""
kv_cache_clean = []
# if self.local_attn_size != -1:
# # Use the local attention size to compute the KV cache size
# kv_cache_size = self.local_attn_size * self.frame_seq_length
# else:
# # Use the default KV cache size
        kv_cache_size = 1560 * 24  # cache up to 24 latent frames of 1560 tokens each
for _ in range(self.num_transformer_blocks):
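            # Per-block self-attention K/V buffers: [batch, cached_tokens, num_heads, head_dim]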
kv_cache_clean.append({
"k": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
"v": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
"global_end_index": torch.tensor([0], dtype=torch.long, device=device),
"local_end_index": torch.tensor([0], dtype=torch.long, device=device)
})
self.kv_cache_clean = kv_cache_clean # always store the clean cache
def _initialize_crossattn_cache(self, batch_size, dtype, device):
"""
Initialize a Per-GPU cross-attention cache for the Wan model.
"""
crossattn_cache = []
for _ in range(self.num_transformer_blocks):
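            # Per-block cross-attention K/V buffers for the text tokens: [batch, text_tokens, num_heads, head_dim]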
crossattn_cache.append({
"k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
"v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
"is_init": False
})
self.crossattn_cache = crossattn_cache
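

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file): shows how the
# pipeline might be constructed and called. The argument values, latent shape,
# and prompt below are assumptions; the real settings come from the project's
# configs and the Wan checkpoint in use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        denoising_step_list=[1000, 750, 500, 250],  # assumed 4-step schedule
        warp_denoising_step=True,
        independent_first_frame=False,
        num_frame_per_block=3,                      # assumed block size
        context_noise=0,                            # assumed context-noise timestep
        model_kwargs={},
    )
    device = torch.device("cuda")
    pipeline = CausalInferencePipeline(args, device).to(device=device, dtype=torch.bfloat16)
    # (batch, frames, channels, height, width); the latent shape depends on the checkpoint
    # and target resolution, so the values below are placeholders.
    noise = torch.randn(1, 21, 16, 60, 104, device=device, dtype=torch.bfloat16)
    video = pipeline.inference_rolling_forcing(noise, text_prompts=["a corgi running on the beach"])
    print(video.shape)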