import torch
import torch.nn as nn

from ovi.modules.model import WanLayerNorm, WanModel, WanRMSNorm, gradient_checkpointing, rope_apply
from ovi.modules.attention import flash_attention
from ovi.distributed_comms.communications import all_gather, all_to_all_4D
from ovi.distributed_comms.parallel_states import nccl_info, get_sequence_parallel_state

class FusionModel(nn.Module):
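    """Joint audio-video diffusion transformer.

    Wraps a video WanModel and an audio WanModel, injects cross-modal
    key/value projections into every cross-attention block, and runs the
    two stacks in lock-step so each modality can attend to the other.
    """
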
    def __init__(self, video_config=None, audio_config=None):
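        """Build the video/audio branches from their WanModel configs.

        Either config may be None, in which case the corresponding branch is
        skipped and forward() falls back to the remaining single modality.
        """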
        super().__init__()
        has_video = True
        has_audio = True

        if video_config is not None:
            self.video_model = WanModel(**video_config)
        else:
            has_video = False
            self.video_model = None
            print("Warning: No video model is provided!")

        if audio_config is not None:
            self.audio_model = WanModel(**audio_config)
        else:
            has_audio = False
            self.audio_model = None
            print("Warning: No audio model is provided!")

        if has_video and has_audio:
            # Fusion pairs blocks one-to-one, so both stacks must have the same depth.
            assert len(self.video_model.blocks) == len(self.audio_model.blocks)
            self.num_blocks = len(self.video_model.blocks)

        self.use_sp = get_sequence_parallel_state()
        if self.use_sp:
            self.sp_size = nccl_info.sp_size
            self.sp_rank = nccl_info.rank_within_group

        # NOTE: assumed default -- forward() reads this flag; training code is
        # expected to flip it on when activation checkpointing is wanted.
        self.gradient_checkpointing = False

        self.inject_cross_attention_kv_projections()
        self.init_weights()

    def inject_cross_attention_kv_projections(self):
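        """Add fusion K/V projections to every cross-attention block.

        Each video and audio block gets an extra key/value projection pair
        (plus a pre-attention LayerNorm and an optional RMSNorm on the keys)
        that is used to attend to the other modality's hidden states.
        """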
        for vid_block in self.video_model.blocks:
            vid_block.cross_attn.k_fusion = nn.Linear(vid_block.dim, vid_block.dim)
            vid_block.cross_attn.v_fusion = nn.Linear(vid_block.dim, vid_block.dim)
            vid_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(vid_block.dim, elementwise_affine=True)
            vid_block.cross_attn.norm_k_fusion = WanRMSNorm(vid_block.dim, eps=1e-6) if vid_block.qk_norm else nn.Identity()

        for audio_block in self.audio_model.blocks:
            audio_block.cross_attn.k_fusion = nn.Linear(audio_block.dim, audio_block.dim)
            audio_block.cross_attn.v_fusion = nn.Linear(audio_block.dim, audio_block.dim)
            audio_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(audio_block.dim, elementwise_affine=True)
            audio_block.cross_attn.norm_k_fusion = WanRMSNorm(audio_block.dim, eps=1e-6) if audio_block.qk_norm else nn.Identity()

    def merge_kwargs(self, vid_kwargs, audio_kwargs):
        """Merge the per-modality kwarg dicts into one, prefixing keys with 'vid_' / 'audio_'.

        Keys present in each kwarg dict:
            e
            seq_lens
            grid_sizes
            freqs
            context
            context_lens
        """
        merged_kwargs = {}
        for key in vid_kwargs:
            merged_kwargs[f"vid_{key}"] = vid_kwargs[key]
        for key in audio_kwargs:
            merged_kwargs[f"audio_{key}"] = audio_kwargs[key]
        return merged_kwargs

    def single_fusion_cross_attention_forward(
        self,
        cross_attn_block,
        src_seq,
        src_grid_sizes,
        src_freqs,
        target_seq,
        target_seq_lens,
        target_grid_sizes,
        target_freqs,
        context,
        context_lens
    ):
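        """Cross-attention of one modality against text and the other modality.

        The query comes from src_seq. It first attends to the text context
        (and, in i2v blocks, to the image keys/values derived from clip_fea),
        then attends to target_seq through the injected *_fusion projections,
        with RoPE applied to both query and fusion keys. Under sequence
        parallelism, heads are redistributed with all_to_all_4D before
        attention and gathered back afterwards.
        """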
        b, n, d = src_seq.size(0), cross_attn_block.num_heads, cross_attn_block.head_dim

        if hasattr(cross_attn_block, "k_img"):
            ## i2v block: extra image keys/values
            q, k, v, k_img, v_img = cross_attn_block.qkv_fn(src_seq, context)
        else:
            ## t2v block
            q, k, v = cross_attn_block.qkv_fn(src_seq, context)
            k_img = v_img = None

        if self.use_sp:
            q = all_to_all_4D(q, scatter_dim=2, gather_dim=1)
            k = torch.chunk(k, self.sp_size, dim=2)[self.sp_rank]
            v = torch.chunk(v, self.sp_size, dim=2)[self.sp_rank]
            if k_img is not None:
                k_img = torch.chunk(k_img, self.sp_size, dim=2)[self.sp_rank]
            if v_img is not None:
                v_img = torch.chunk(v_img, self.sp_size, dim=2)[self.sp_rank]

        x = flash_attention(q, k, v, k_lens=context_lens)
        if k_img is not None:
            img_x = flash_attention(q, k_img, v_img, k_lens=None)
            x = x + img_x

        is_vid = src_grid_sizes.shape[1] > 1

        # compute target attention
        target_seq = cross_attn_block.pre_attn_norm_fusion(target_seq)
        k_target = cross_attn_block.norm_k_fusion(cross_attn_block.k_fusion(target_seq)).view(b, -1, n, d)
        v_target = cross_attn_block.v_fusion(target_seq).view(b, -1, n, d)
        if self.use_sp:
            k_target = all_to_all_4D(k_target, scatter_dim=2, gather_dim=1)  # [B, L, H/P, C/H]
            v_target = all_to_all_4D(v_target, scatter_dim=2, gather_dim=1)  # [B, L, H/P, C/H]

        q = rope_apply(q, src_grid_sizes, src_freqs)
        k_target = rope_apply(k_target, target_grid_sizes, target_freqs)
        target_x = flash_attention(q, k_target, v_target, k_lens=target_seq_lens)
        x = x + target_x

        if self.use_sp:
            x = all_to_all_4D(x, scatter_dim=1, gather_dim=2)  # [B, L/P, H, C/H]

        x = x.flatten(2)  # [B, L/P, C]
        x = cross_attn_block.o(x)
        return x

    def single_fusion_cross_attention_ffn_forward(
        self,
        attn_block,
        src_seq,
        src_grid_sizes,
        src_freqs,
        target_seq,
        target_seq_lens,
        target_grid_sizes,
        target_freqs,
        context,
        context_lens,
        src_e
    ):
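        """Residual fusion cross-attention followed by the modulated FFN.

        Mirrors the second half of a WanModel attention block: src_seq is
        normalised (norm3), cross-attends to the text context and to
        target_seq, then passes through the FFN shifted, scaled, and gated by
        the modulation parameters in src_e.
        """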
        src_seq = src_seq + self.single_fusion_cross_attention_forward(
            attn_block.cross_attn,
            attn_block.norm3(src_seq),
            src_grid_sizes=src_grid_sizes,
            src_freqs=src_freqs,
            target_seq=target_seq,
            target_seq_lens=target_seq_lens,
            target_grid_sizes=target_grid_sizes,
            target_freqs=target_freqs,
            context=context,
            context_lens=context_lens
        )

        y = attn_block.ffn(attn_block.norm2(src_seq).bfloat16() * (1 + src_e[4].squeeze(2)) + src_e[3].squeeze(2))
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            src_seq = src_seq + y * src_e[5].squeeze(2)
        return src_seq

    def single_fusion_block_forward(
        self,
        vid_block,
        audio_block,
        vid,
        audio,
        vid_e,
        vid_seq_lens,
        vid_grid_sizes,
        vid_freqs,
        vid_context,
        vid_context_lens,
        audio_e,
        audio_seq_lens,
        audio_grid_sizes,
        audio_freqs,
        audio_context,
        audio_context_lens
    ):
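        """Run one fused transformer block (one video block + one audio block).

        Each modality goes through modulation and self-attention, then
        bidirectional cross-modal attention + FFN: audio attends to the updated
        video states, while video attends to the pre-cross-attention audio
        states (og_audio), so neither direction sees the other's
        cross-attention output within the same block.
        """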
        ## audio modulation
        assert audio_e.dtype == torch.bfloat16
        assert len(audio_e.shape) == 4 and audio_e.size(2) == 6 and audio_e.shape[1] == audio.shape[1], f"{audio_e.shape}, {audio.shape}"
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            audio_e = audio_block.modulation(audio_e).chunk(6, dim=2)
        assert audio_e[0].dtype == torch.bfloat16

        # audio self-attention
        audio_y = audio_block.self_attn(
            audio_block.norm1(audio).bfloat16() * (1 + audio_e[1].squeeze(2)) + audio_e[0].squeeze(2), audio_seq_lens, audio_grid_sizes,
            audio_freqs)
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            audio = audio + audio_y * audio_e[2].squeeze(2)

        ## video modulation
        assert len(vid_e.shape) == 4 and vid_e.size(2) == 6 and vid_e.shape[1] == vid.shape[1], f"{vid_e.shape}, {vid.shape}"
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            vid_e = vid_block.modulation(vid_e).chunk(6, dim=2)

        # video self-attention
        vid_y = vid_block.self_attn(
            vid_block.norm1(vid).bfloat16() * (1 + vid_e[1].squeeze(2)) + vid_e[0].squeeze(2), vid_seq_lens, vid_grid_sizes,
            vid_freqs)
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            vid = vid + vid_y * vid_e[2].squeeze(2)

        og_audio = audio

        # audio cross-attention (audio attends to text and to the updated video states)
        audio = self.single_fusion_cross_attention_ffn_forward(
            audio_block,
            audio,
            audio_grid_sizes,
            audio_freqs,
            vid,
            vid_seq_lens,
            vid_grid_sizes,
            vid_freqs,
            audio_context,
            audio_context_lens,
            audio_e
        )
        assert not torch.equal(og_audio, audio), "Audio should be changed after cross-attention!"

        # video cross-attention (video attends to text and to the pre-cross-attention audio states)
        vid = self.single_fusion_cross_attention_ffn_forward(
            vid_block,
            vid,
            vid_grid_sizes,
            vid_freqs,
            og_audio,
            audio_seq_lens,
            audio_grid_sizes,
            audio_freqs,
            vid_context,
            vid_context_lens,
            vid_e
        )
        return vid, audio

    def forward(
        self,
        vid,
        audio,
        t,
        vid_context,
        audio_context,
        vid_seq_len,
        audio_seq_len,
        clip_fea=None,
        clip_fea_audio=None,
        y=None,
        first_frame_is_clean=False,
        slg_layer=False
    ):
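        """Joint denoising forward pass.

        If only one modality is provided (vid or audio is None / all-None), the
        call is delegated to that modality's WanModel alone. Otherwise both
        stacks run in lock-step through the paired fusion blocks; a positive
        slg_layer skips the block at that index (skip-layer guidance).
        """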
        assert clip_fea is None
        assert y is None

        # audio-only path
        if vid is None or all([x is None for x in vid]):
            assert vid_context is None
            assert vid_seq_len is None
            assert self.audio_model is not None
            return None, self.audio_model(x=audio, t=t, context=audio_context, seq_len=audio_seq_len, clip_fea=clip_fea_audio, y=None)

        # video-only path
        if audio is None or all([x is None for x in audio]):
            assert clip_fea_audio is None
            assert audio_context is None
            assert audio_seq_len is None
            assert self.video_model is not None
            return self.video_model(x=vid, t=t, context=vid_context, seq_len=vid_seq_len, clip_fea=clip_fea, y=y, first_frame_is_clean=first_frame_is_clean), None

        vid, vid_e, vid_kwargs = self.video_model.prepare_transformer_block_kwargs(
            x=vid, t=t, context=vid_context, seq_len=vid_seq_len, clip_fea=clip_fea, y=y, first_frame_is_clean=first_frame_is_clean
        )
        audio, audio_e, audio_kwargs = self.audio_model.prepare_transformer_block_kwargs(
            x=audio, t=t, context=audio_context, seq_len=audio_seq_len, clip_fea=clip_fea_audio, y=None, first_frame_is_clean=False
        )
        kwargs = self.merge_kwargs(vid_kwargs, audio_kwargs)

        for i in range(self.num_blocks):
            # One fusion block pairs one audio block with one video block.
            if slg_layer > 0 and i == slg_layer:
                continue
            vid_block = self.video_model.blocks[i]
            audio_block = self.audio_model.blocks[i]

            vid, audio = gradient_checkpointing(
                enabled=(self.training and self.gradient_checkpointing),
                module=self.single_fusion_block_forward,
                vid_block=vid_block,
                audio_block=audio_block,
                vid=vid,
                audio=audio,
                **kwargs
            )

        vid = self.video_model.post_transformer_block_out(vid, vid_kwargs['grid_sizes'], vid_e)
        audio = self.audio_model.post_transformer_block_out(audio, audio_kwargs['grid_sizes'], audio_e)
        return vid, audio

    def init_weights(self):
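        """Initialise both sub-models, then down-weight the fusion projections.

        The injected fusion Linear layers in the video branch have their
        weights divided by 10 after initialisation.
        """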
        if self.audio_model is not None:
            self.audio_model.init_weights()
        if self.video_model is not None:
            self.video_model.init_weights()
            for name, mod in self.video_model.named_modules():
                if "fusion" in name and isinstance(mod, nn.Linear):
                    with torch.no_grad():
                        mod.weight.div_(10.0)

    def set_rope_params(self):
        self.video_model.set_rope_params()
        self.audio_model.set_rope_params()
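

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The config dicts below are
# hypothetical placeholders; real WanModel kwargs come from the Ovi configs,
# and the latents/embeddings would be produced by the VAEs and text encoders.
#
#   video_cfg = {...}   # WanModel kwargs for the video branch
#   audio_cfg = {...}   # WanModel kwargs for the audio branch
#   model = FusionModel(video_config=video_cfg, audio_config=audio_cfg).cuda()
#
#   vid_out, audio_out = model(
#       vid=vid_latents,            # list of video latent tensors
#       audio=audio_latents,        # list of audio latent tensors
#       t=timesteps,
#       vid_context=vid_text_emb,
#       audio_context=audio_text_emb,
#       vid_seq_len=vid_seq_len,
#       audio_seq_len=audio_seq_len,
#   )
# ---------------------------------------------------------------------------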