import torch
import torch.nn as nn
from ovi.modules.model import WanLayerNorm, WanModel, WanRMSNorm, gradient_checkpointing, rope_apply
from ovi.modules.attention import flash_attention
from ovi.distributed_comms.communications import all_gather, all_to_all_4D
from ovi.distributed_comms.parallel_states import nccl_info, get_sequence_parallel_state
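# FusionModel pairs a WanModel video tower with a WanModel audio tower and fuses
# them block by block: each paired block runs its own modulated self-attention
# and text/image cross-attention, then a bidirectional cross-modal attention in
# which audio queries attend to video tokens and video queries attend to audio
# tokens, with optional sequence parallelism via all-to-all head/sequence swaps.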
class FusionModel(nn.Module):
def __init__(self, video_config=None, audio_config=None):
super().__init__()
has_video = True
has_audio = True
if video_config is not None:
self.video_model = WanModel(**video_config)
else:
has_video = False
self.video_model = None
print("Warning: No video model is provided!")
if audio_config is not None:
self.audio_model = WanModel(**audio_config)
else:
has_audio = False
self.audio_model = None
print("Warning: No audio model is provided!")
        if has_video and has_audio:
            assert len(self.video_model.blocks) == len(self.audio_model.blocks), \
                "video and audio towers must have the same number of transformer blocks"
            self.num_blocks = len(self.video_model.blocks)
        self.use_sp = get_sequence_parallel_state()
        if self.use_sp:
            self.sp_size = nccl_info.sp_size
            self.sp_rank = nccl_info.rank_within_group
        # referenced in forward(); off by default, toggled externally for training
        self.gradient_checkpointing = False
        if has_video and has_audio:
            self.inject_cross_attention_kv_projections()
        self.init_weights()
def inject_cross_attention_kv_projections(self):
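        """
        Attach fusion key/value projections (plus a pre-attention LayerNorm and
        an optional RMSNorm on keys) to the cross-attention module of every
        block in both towers; these parameters are injected on top of the
        stock WanModel blocks.
        """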
        # identical projections for both towers
        for block in [*self.video_model.blocks, *self.audio_model.blocks]:
            block.cross_attn.k_fusion = nn.Linear(block.dim, block.dim)
            block.cross_attn.v_fusion = nn.Linear(block.dim, block.dim)
            block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(block.dim, elementwise_affine=True)
            block.cross_attn.norm_k_fusion = WanRMSNorm(block.dim, eps=1e-6) if block.qk_norm else nn.Identity()
def merge_kwargs(self, vid_kwargs, audio_kwargs):
"""
keys in each kwarg:
e
seq_lens
grid_sizes
freqs
context
context_lens
"""
merged_kwargs = {}
for key in vid_kwargs:
merged_kwargs[f"vid_{key}"] = vid_kwargs[key]
for key in audio_kwargs:
merged_kwargs[f"audio_{key}"] = audio_kwargs[key]
return merged_kwargs
def single_fusion_cross_attention_forward(self,
cross_attn_block,
src_seq,
src_grid_sizes,
src_freqs,
target_seq,
target_seq_lens,
target_grid_sizes,
target_freqs,
context,
context_lens
):
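        """
        Cross-attention for one modality ("src") inside a fusion block:
        (1) standard text (and, for i2v, image) cross-attention, then
        (2) fusion attention in which RoPE-aligned src queries attend to the
        other modality ("target"); the two results are summed before the
        output projection.
        """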
b, n, d = src_seq.size(0), cross_attn_block.num_heads, cross_attn_block.head_dim
if hasattr(cross_attn_block, "k_img"):
## means is i2v block
q, k, v, k_img, v_img = cross_attn_block.qkv_fn(src_seq, context)
else:
## means is t2v block
q, k, v = cross_attn_block.qkv_fn(src_seq, context)
k_img = v_img = None
if self.use_sp:
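            # sequence parallel: q arrives sharded along the sequence dim, so swap
            # shards to gather the full sequence while scattering heads,
            # [B, L/P, H, D] -> [B, L, H/P, D]; the context-derived k/v are
            # replicated, so each rank just keeps its own shard of heads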
q = all_to_all_4D(q, scatter_dim=2, gather_dim=1)
k = torch.chunk(k, self.sp_size, dim=2)[self.sp_rank]
v = torch.chunk(v, self.sp_size, dim=2)[self.sp_rank]
if k_img is not None:
k_img = torch.chunk(k_img, self.sp_size, dim=2)[self.sp_rank]
if v_img is not None:
v_img = torch.chunk(v_img, self.sp_size, dim=2)[self.sp_rank]
x = flash_attention(q, k, v, k_lens=context_lens)
if k_img is not None:
img_x = flash_attention(q, k_img, v_img, k_lens=None)
x = x + img_x
        # fusion attention: this modality's queries attend to the other modality's tokens
        target_seq = cross_attn_block.pre_attn_norm_fusion(target_seq)
k_target = cross_attn_block.norm_k_fusion(cross_attn_block.k_fusion(target_seq)).view(b, -1, n, d)
v_target = cross_attn_block.v_fusion(target_seq).view(b, -1, n, d)
if self.use_sp:
k_target = all_to_all_4D(k_target, scatter_dim=2, gather_dim=1) # [B, L, H/P, C/H]
v_target = all_to_all_4D(v_target, scatter_dim=2, gather_dim=1) # [B, L, H/P, C/H]
q = rope_apply(q, src_grid_sizes, src_freqs)
k_target = rope_apply(k_target, target_grid_sizes, target_freqs)
target_x = flash_attention(q, k_target, v_target, k_lens=target_seq_lens)
x = x + target_x
if self.use_sp:
x = all_to_all_4D(x, scatter_dim=1, gather_dim=2) # [B, L/P, H, C/H]
x = x.flatten(2) # [B, L/P, C]
x = cross_attn_block.o(x)
return x
def single_fusion_cross_attention_ffn_forward(self,
attn_block,
src_seq,
src_grid_sizes,
src_freqs,
target_seq,
target_seq_lens,
target_grid_sizes,
target_freqs,
context,
context_lens,
src_e):
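        """
        Residual cross-attention (text + cross-modal fusion) on src_seq,
        followed by the FFN with AdaLN-style modulation taken from src_e
        (e[3] = shift, e[4] = scale, e[5] = gate).
        """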
src_seq = src_seq + self.single_fusion_cross_attention_forward(attn_block.cross_attn,
attn_block.norm3(src_seq),
src_grid_sizes=src_grid_sizes,
src_freqs=src_freqs,
target_seq=target_seq,
target_seq_lens=target_seq_lens,
target_grid_sizes=target_grid_sizes,
target_freqs=target_freqs,
context=context,
context_lens=context_lens
)
y = attn_block.ffn(attn_block.norm2(src_seq).bfloat16() * (1 + src_e[4].squeeze(2)) + src_e[3].squeeze(2))
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
src_seq = src_seq + y * src_e[5].squeeze(2)
return src_seq
def single_fusion_block_forward(self,
vid_block,
audio_block,
vid,
audio,
vid_e,
vid_seq_lens,
vid_grid_sizes,
vid_freqs,
vid_context,
vid_context_lens,
audio_e,
audio_seq_lens,
audio_grid_sizes,
audio_freqs,
audio_context,
audio_context_lens
):
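        """
        One fusion block = one audio WanModel block + one video WanModel block:
        per-modality AdaLN modulation and self-attention, then bidirectional
        cross-attention where audio attends to the self-attended video tokens
        and video attends to the audio tokens as they were before the audio
        cross-attention update (og_audio).
        """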
## audio modulation
assert audio_e.dtype == torch.bfloat16
assert len(audio_e.shape) == 4 and audio_e.size(2) == 6 and audio_e.shape[1] == audio.shape[1], f"{audio_e.shape}, {audio.shape}"
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
audio_e = audio_block.modulation(audio_e).chunk(6, dim=2)
assert audio_e[0].dtype == torch.bfloat16
# audio self-attention
audio_y = audio_block.self_attn(
audio_block.norm1(audio).bfloat16() * (1 + audio_e[1].squeeze(2)) + audio_e[0].squeeze(2), audio_seq_lens, audio_grid_sizes,
audio_freqs)
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
audio = audio + audio_y * audio_e[2].squeeze(2)
## video modulation
assert len(vid_e.shape) == 4 and vid_e.size(2) == 6 and vid_e.shape[1] == vid.shape[1], f"{vid_e.shape}, {vid.shape}"
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
vid_e = vid_block.modulation(vid_e).chunk(6, dim=2)
# video self-attention
vid_y = vid_block.self_attn(
vid_block.norm1(vid).bfloat16() * (1 + vid_e[1].squeeze(2)) + vid_e[0].squeeze(2), vid_seq_lens, vid_grid_sizes,
vid_freqs)
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
vid = vid + vid_y * vid_e[2].squeeze(2)
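        # snapshot the audio tokens after self-attention so the video branch
        # attends to the same audio state the audio cross-attention starts from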
og_audio = audio
# audio cross-attention
audio = self.single_fusion_cross_attention_ffn_forward(
audio_block,
audio,
audio_grid_sizes,
audio_freqs,
vid,
vid_seq_lens,
vid_grid_sizes,
vid_freqs,
audio_context,
audio_context_lens,
audio_e
)
assert not torch.equal(og_audio, audio), "Audio should be changed after cross-attention!"
# video cross-attention
vid = self.single_fusion_cross_attention_ffn_forward(
vid_block,
vid,
vid_grid_sizes,
vid_freqs,
og_audio,
audio_seq_lens,
audio_grid_sizes,
audio_freqs,
vid_context,
vid_context_lens,
vid_e
)
return vid, audio
def forward(
self,
vid,
audio,
t,
vid_context,
audio_context,
vid_seq_len,
audio_seq_len,
clip_fea=None,
clip_fea_audio=None,
y=None,
first_frame_is_clean=False,
        slg_layer=-1  # index of the fusion block to skip (skip-layer guidance); <= 0 disables
):
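        """
        Joint denoising forward. Falls back to the single-tower WanModel
        forward when either modality is absent; otherwise prepares both
        towers' transformer-block kwargs, runs the paired fusion blocks, and
        returns the (vid, audio) outputs.
        """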
assert clip_fea is None
assert y is None
        if vid is None or all(x is None for x in vid):
assert vid_context is None
assert vid_seq_len is None
assert self.audio_model is not None
return None, self.audio_model(x=audio, t=t, context=audio_context, seq_len=audio_seq_len, clip_fea=clip_fea_audio, y=None)
        if audio is None or all(x is None for x in audio):
assert clip_fea_audio is None
assert audio_context is None
assert audio_seq_len is None
assert self.video_model is not None
return self.video_model(x=vid, t=t, context=vid_context, seq_len=vid_seq_len, clip_fea=clip_fea, y=y, first_frame_is_clean=first_frame_is_clean), None
vid, vid_e, vid_kwargs = self.video_model.prepare_transformer_block_kwargs(
x=vid, t=t, context=vid_context, seq_len=vid_seq_len, clip_fea=clip_fea, y=y, first_frame_is_clean=first_frame_is_clean
)
audio, audio_e, audio_kwargs = self.audio_model.prepare_transformer_block_kwargs(
x=audio, t=t, context=audio_context, seq_len=audio_seq_len, clip_fea=clip_fea_audio, y=None, first_frame_is_clean=False
)
kwargs = self.merge_kwargs(vid_kwargs, audio_kwargs)
        for i in range(self.num_blocks):
            # one fusion block = one audio block paired with one video block
            if slg_layer > 0 and i == slg_layer:
                # skip-layer guidance (SLG): drop this fusion block entirely
                continue
vid_block = self.video_model.blocks[i]
audio_block = self.audio_model.blocks[i]
vid, audio = gradient_checkpointing(
enabled=(self.training and self.gradient_checkpointing),
module=self.single_fusion_block_forward,
vid_block=vid_block,
audio_block=audio_block,
vid=vid,
audio=audio,
**kwargs
)
vid = self.video_model.post_transformer_block_out(vid, vid_kwargs['grid_sizes'], vid_e)
audio = self.audio_model.post_transformer_block_out(audio, audio_kwargs['grid_sizes'], audio_e)
return vid, audio
    def init_weights(self):
        if self.audio_model is not None:
            self.audio_model.init_weights()
        if self.video_model is not None:
            self.video_model.init_weights()
            # shrink the injected fusion projections so fusion starts near-neutral
            for name, mod in self.video_model.named_modules():
                if "fusion" in name and isinstance(mod, nn.Linear):
                    with torch.no_grad():
                        mod.weight.div_(10.0)
    def set_rope_params(self):
        if self.video_model is not None:
            self.video_model.set_rope_params()
        if self.audio_model is not None:
            self.audio_model.set_rope_params()
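# Minimal usage sketch (illustrative only): the contents of video_config /
# audio_config are WanModel constructor kwargs and depend on the checkpoint,
# so the placeholders below are hypothetical, not working values.
#
#     video_config = {...}   # WanModel kwargs for the video tower
#     audio_config = {...}   # WanModel kwargs for the audio tower
#     model = FusionModel(video_config=video_config, audio_config=audio_config)
#     vid_out, audio_out = model(
#         vid=vid_latents, audio=audio_latents, t=timesteps,
#         vid_context=vid_text_emb, audio_context=audio_text_emb,
#         vid_seq_len=vid_seq_len, audio_seq_len=audio_seq_len,
#     )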