import torch
import torch.nn as nn
from ovi.modules.model import WanLayerNorm, WanModel, WanRMSNorm, gradient_checkpointing, rope_apply
from ovi.modules.attention import flash_attention
from ovi.distributed_comms.communications import all_gather, all_to_all_4D
from ovi.distributed_comms.parallel_states import nccl_info, get_sequence_parallel_state
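# FusionModel pairs a video WanModel with an audio WanModel and fuses them block by block: each
# transformer block runs its own modulated self-attention, then cross-attends into its text context
# and, via the injected `*_fusion` key/value projections, into the other modality's token sequence.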
class FusionModel(nn.Module):
def __init__(self, video_config=None, audio_config=None):
super().__init__()
has_video = True
has_audio = True
if video_config is not None:
self.video_model = WanModel(**video_config)
else:
has_video = False
self.video_model = None
print("Warning: No video model is provided!")
if audio_config is not None:
self.audio_model = WanModel(**audio_config)
else:
has_audio = False
self.audio_model = None
print("Warning: No audio model is provided!")
if has_video and has_audio:
assert len(self.video_model.blocks) == len(self.audio_model.blocks)
self.num_blocks = len(self.video_model.blocks)
self.use_sp = get_sequence_parallel_state()
if self.use_sp:
self.sp_size = nccl_info.sp_size
self.sp_rank = nccl_info.rank_within_group
        # Assumption: `gradient_checkpointing` is not set elsewhere in this module; default it to False so the
        # `self.training and self.gradient_checkpointing` check in forward() cannot raise AttributeError.
        self.gradient_checkpointing = False
        self.inject_cross_attention_kv_projections()
        self.init_weights()
def inject_cross_attention_kv_projections(self):
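        # Each pre-trained cross-attention block gains an extra key/value projection pair (`k_fusion`,
        # `v_fusion`) plus a pre-attention LayerNorm and key RMSNorm, used to attend over the other
        # modality's token sequence.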
for vid_block in self.video_model.blocks:
vid_block.cross_attn.k_fusion = nn.Linear(vid_block.dim, vid_block.dim)
vid_block.cross_attn.v_fusion = nn.Linear(vid_block.dim, vid_block.dim)
vid_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(vid_block.dim, elementwise_affine=True)
vid_block.cross_attn.norm_k_fusion = WanRMSNorm(vid_block.dim, eps=1e-6) if vid_block.qk_norm else nn.Identity()
for audio_block in self.audio_model.blocks:
audio_block.cross_attn.k_fusion = nn.Linear(audio_block.dim, audio_block.dim)
audio_block.cross_attn.v_fusion = nn.Linear(audio_block.dim, audio_block.dim)
audio_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(audio_block.dim, elementwise_affine=True)
audio_block.cross_attn.norm_k_fusion = WanRMSNorm(audio_block.dim, eps=1e-6) if audio_block.qk_norm else nn.Identity()
def merge_kwargs(self, vid_kwargs, audio_kwargs):
"""
keys in each kwarg:
e
seq_lens
grid_sizes
freqs
context
context_lens
"""
merged_kwargs = {}
for key in vid_kwargs:
merged_kwargs[f"vid_{key}"] = vid_kwargs[key]
for key in audio_kwargs:
merged_kwargs[f"audio_{key}"] = audio_kwargs[key]
return merged_kwargs
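    # Example (illustrative values):
    #   merge_kwargs({"e": vid_e, "seq_lens": vid_lens}, {"e": aud_e, "seq_lens": aud_lens})
    #   -> {"vid_e": vid_e, "vid_seq_lens": vid_lens, "audio_e": aud_e, "audio_seq_lens": aud_lens}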
def single_fusion_cross_attention_forward(self,
cross_attn_block,
src_seq,
src_grid_sizes,
src_freqs,
target_seq,
target_seq_lens,
target_grid_sizes,
target_freqs,
context,
context_lens
):
b, n, d = src_seq.size(0), cross_attn_block.num_heads, cross_attn_block.head_dim
        if hasattr(cross_attn_block, "k_img"):
            # i2v-style block: qkv_fn also returns image-conditioning key/value projections
            q, k, v, k_img, v_img = cross_attn_block.qkv_fn(src_seq, context)
        else:
            # t2v-style block: text-only cross-attention, no image keys/values
            q, k, v = cross_attn_block.qkv_fn(src_seq, context)
            k_img = v_img = None
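        # Sequence parallelism: all_to_all regroups the query from sequence-sharded/all-heads to
        # full-sequence/head-sharded, while each rank keeps only its shard of the context keys/values.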
if self.use_sp:
q = all_to_all_4D(q, scatter_dim=2, gather_dim=1)
k = torch.chunk(k, self.sp_size, dim=2)[self.sp_rank]
v = torch.chunk(v, self.sp_size, dim=2)[self.sp_rank]
if k_img is not None:
k_img = torch.chunk(k_img, self.sp_size, dim=2)[self.sp_rank]
if v_img is not None:
v_img = torch.chunk(v_img, self.sp_size, dim=2)[self.sp_rank]
x = flash_attention(q, k, v, k_lens=context_lens)
if k_img is not None:
img_x = flash_attention(q, k_img, v_img, k_lens=None)
x = x + img_x
is_vid = src_grid_sizes.shape[1] > 1
        # Cross-modal ("fusion") attention: keys/values are computed from the other modality's sequence (`target_seq`).
target_seq = cross_attn_block.pre_attn_norm_fusion(target_seq)
k_target = cross_attn_block.norm_k_fusion(cross_attn_block.k_fusion(target_seq)).view(b, -1, n, d)
v_target = cross_attn_block.v_fusion(target_seq).view(b, -1, n, d)
if self.use_sp:
k_target = all_to_all_4D(k_target, scatter_dim=2, gather_dim=1) # [B, L, H/P, C/H]
v_target = all_to_all_4D(v_target, scatter_dim=2, gather_dim=1) # [B, L, H/P, C/H]
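        # RoPE uses each modality's own grid sizes and frequencies, so the query keeps its source
        # positions while the fused keys keep the target modality's positions.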
q = rope_apply(q, src_grid_sizes, src_freqs)
k_target = rope_apply(k_target, target_grid_sizes, target_freqs)
target_x = flash_attention(q, k_target, v_target, k_lens=target_seq_lens)
x = x + target_x
if self.use_sp:
x = all_to_all_4D(x, scatter_dim=1, gather_dim=2) # [B, L/P, H, C/H]
x = x.flatten(2) # [B, L/P, C]
x = cross_attn_block.o(x)
return x
def single_fusion_cross_attention_ffn_forward(self,
attn_block,
src_seq,
src_grid_sizes,
src_freqs,
target_seq,
target_seq_lens,
target_grid_sizes,
target_freqs,
context,
context_lens,
src_e):
src_seq = src_seq + self.single_fusion_cross_attention_forward(attn_block.cross_attn,
attn_block.norm3(src_seq),
src_grid_sizes=src_grid_sizes,
src_freqs=src_freqs,
target_seq=target_seq,
target_seq_lens=target_seq_lens,
target_grid_sizes=target_grid_sizes,
target_freqs=target_freqs,
context=context,
context_lens=context_lens
)
y = attn_block.ffn(attn_block.norm2(src_seq).bfloat16() * (1 + src_e[4].squeeze(2)) + src_e[3].squeeze(2))
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
src_seq = src_seq + y * src_e[5].squeeze(2)
return src_seq
def single_fusion_block_forward(self,
vid_block,
audio_block,
vid,
audio,
vid_e,
vid_seq_lens,
vid_grid_sizes,
vid_freqs,
vid_context,
vid_context_lens,
audio_e,
audio_seq_lens,
audio_grid_sizes,
audio_freqs,
audio_context,
audio_context_lens
):
## audio modulation
assert audio_e.dtype == torch.bfloat16
assert len(audio_e.shape) == 4 and audio_e.size(2) == 6 and audio_e.shape[1] == audio.shape[1], f"{audio_e.shape}, {audio.shape}"
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
audio_e = audio_block.modulation(audio_e).chunk(6, dim=2)
assert audio_e[0].dtype == torch.bfloat16
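        # `audio_e` carries six adaLN-style modulation terms per token: [0] shift and [1] scale for the
        # self-attention input, [2] gate on its output; [3]/[4]/[5] play the same roles for the FFN
        # (consumed in single_fusion_cross_attention_ffn_forward).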
# audio self-attention
audio_y = audio_block.self_attn(
audio_block.norm1(audio).bfloat16() * (1 + audio_e[1].squeeze(2)) + audio_e[0].squeeze(2), audio_seq_lens, audio_grid_sizes,
audio_freqs)
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
audio = audio + audio_y * audio_e[2].squeeze(2)
## video modulation
assert len(vid_e.shape) == 4 and vid_e.size(2) == 6 and vid_e.shape[1] == vid.shape[1], f"{vid_e.shape}, {vid.shape}"
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
vid_e = vid_block.modulation(vid_e).chunk(6, dim=2)
# video self-attention
vid_y = vid_block.self_attn(
vid_block.norm1(vid).bfloat16() * (1 + vid_e[1].squeeze(2)) + vid_e[0].squeeze(2), vid_seq_lens, vid_grid_sizes,
vid_freqs)
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
vid = vid + vid_y * vid_e[2].squeeze(2)
og_audio = audio
# audio cross-attention
audio = self.single_fusion_cross_attention_ffn_forward(
audio_block,
audio,
audio_grid_sizes,
audio_freqs,
vid,
vid_seq_lens,
vid_grid_sizes,
vid_freqs,
audio_context,
audio_context_lens,
audio_e
)
assert not torch.equal(og_audio, audio), "Audio should be changed after cross-attention!"
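        # The video branch attends to `og_audio`, the audio states saved *before* audio's cross-attention
        # and FFN, so both fusion directions read the other modality's post-self-attention states.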
# video cross-attention
vid = self.single_fusion_cross_attention_ffn_forward(
vid_block,
vid,
vid_grid_sizes,
vid_freqs,
og_audio,
audio_seq_lens,
audio_grid_sizes,
audio_freqs,
vid_context,
vid_context_lens,
vid_e
)
return vid, audio
def forward(
self,
vid,
audio,
t,
vid_context,
audio_context,
vid_seq_len,
audio_seq_len,
clip_fea=None,
clip_fea_audio=None,
y=None,
first_frame_is_clean=False,
slg_layer=False
):
assert clip_fea is None
assert y is None
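        # Single-modality fallbacks: if one branch's inputs are missing, run the other WanModel on its own.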
if vid is None or all([x is None for x in vid]):
assert vid_context is None
assert vid_seq_len is None
assert self.audio_model is not None
return None, self.audio_model(x=audio, t=t, context=audio_context, seq_len=audio_seq_len, clip_fea=clip_fea_audio, y=None)
if audio is None or all([x is None for x in audio]):
assert clip_fea_audio is None
assert audio_context is None
assert audio_seq_len is None
assert self.video_model is not None
return self.video_model(x=vid, t=t, context=vid_context, seq_len=vid_seq_len, clip_fea=clip_fea, y=y, first_frame_is_clean=first_frame_is_clean), None
vid, vid_e, vid_kwargs = self.video_model.prepare_transformer_block_kwargs(
x=vid, t=t, context=vid_context, seq_len=vid_seq_len, clip_fea=clip_fea, y=y, first_frame_is_clean=first_frame_is_clean
)
audio, audio_e, audio_kwargs = self.audio_model.prepare_transformer_block_kwargs(
x=audio, t=t, context=audio_context, seq_len=audio_seq_len, clip_fea=clip_fea_audio, y=None, first_frame_is_clean=False
)
kwargs = self.merge_kwargs(vid_kwargs, audio_kwargs)
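        # The prefixed keys line up with the keyword parameters of single_fusion_block_forward,
        # so the merged dict can be splatted directly into each fused block call below.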
for i in range(self.num_blocks):
"""
1 fusion block refers to 1 audio block with 1 video block.
"""
if slg_layer > 0 and i == slg_layer:
continue
vid_block = self.video_model.blocks[i]
audio_block = self.audio_model.blocks[i]
vid, audio = gradient_checkpointing(
enabled=(self.training and self.gradient_checkpointing),
module=self.single_fusion_block_forward,
vid_block=vid_block,
audio_block=audio_block,
vid=vid,
audio=audio,
**kwargs
)
vid = self.video_model.post_transformer_block_out(vid, vid_kwargs['grid_sizes'], vid_e)
audio = self.audio_model.post_transformer_block_out(audio, audio_kwargs['grid_sizes'], audio_e)
return vid, audio
def init_weights(self):
if self.audio_model is not None:
self.audio_model.init_weights()
if self.video_model is not None:
self.video_model.init_weights()
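        # Damp the newly injected fusion projections (only the video-side `*_fusion` Linear weights are
        # touched here) so the cross-modal attention path starts with a small contribution.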
for name, mod in self.video_model.named_modules():
if "fusion" in name and isinstance(mod, nn.Linear):
with torch.no_grad():
mod.weight.div_(10.0)
def set_rope_params(self):
self.video_model.set_rope_params()
self.audio_model.set_rope_params()
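
# Usage sketch (assumptions: `video_config` / `audio_config` are dicts of WanModel constructor kwargs
# taken from the accompanying configs, and the tensors/shapes below are purely illustrative):
#
#     model = FusionModel(video_config=video_config, audio_config=audio_config)
#     vid_out, audio_out = model(
#         vid=[vid_latent], audio=[audio_latent], t=timesteps,
#         vid_context=[vid_text_emb], audio_context=[audio_text_emb],
#         vid_seq_len=vid_seq_len, audio_seq_len=audio_seq_len,
#     )
#
# Passing `vid=None` (or `audio=None`) runs the remaining single-modality branch on its own.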