# Copyright 2025 ASLP Lab and Xiaomi Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
from torch import nn
from torchdiffeq import odeint
from torch.nn.attention.flex_attention import create_block_mask

from .backbones.dit import DiT
from .cache_utils import BlockFlowMatchingCache


def all_mask(b, h, q_idx, kv_idx):
    # Always-true mask function, kept for the commented-out
    # flex-attention `create_block_mask` path below.
    return q_idx == q_idx


class CFM(nn.Module):
    def __init__(
        self,
        transformer: DiT,
        sigma=0.0,
        odeint_kwargs: dict = dict(
            # atol=1e-5,
            # rtol=1e-5,
            method="euler"  # 'midpoint'
            # method="adaptive_heun"
        ),
        odeint_options: dict = dict(
            min_step=0.05
        ),
        num_channels=None,
        block_size=None,
        num_history_block=None,
    ):
        super().__init__()
        self.num_channels = num_channels

        # transformer backbone
        self.transformer = transformer
        dim = transformer.dim
        self.dim = dim

        # conditional flow related
        self.sigma = sigma

        # sampling related
        self.odeint_kwargs = odeint_kwargs
        print(f"ODE SOLVER: {self.odeint_kwargs['method']}")
        self.odeint_options = odeint_options

        self.block_size = block_size
        # A non-positive history size means "attend to the full history".
        self.num_history_block = num_history_block
        if self.num_history_block is not None and self.num_history_block <= 0:
            self.num_history_block = None
        print(f"block_size: {self.block_size}; num_history_block: {self.num_history_block}")

    @property
    def device(self):
        return next(self.parameters()).device

    @torch.no_grad()
    def sample_block_cache(
        self,
        text,
        duration,  # noqa: F821
        style_prompt,
        steps=32,
        cfg_strength=1.0,
        odeint_method="euler",
    ):
        self.eval()
        batch = text.shape[0]
        device = self.device
        # Ceiling division: number of fixed-size blocks covering `duration` frames.
        num_blocks = duration // self.block_size + (duration % self.block_size > 0)

        text_emb = self.transformer.text_embed(text)
        cfg_text_emb = self.transformer.text_embed(torch.zeros_like(text))
        text_lens = torch.LongTensor([text_emb.shape[1]]).to(device)

        # Stream of already-generated (clean) latent frames, grown block by block.
        clean_emb_stream = torch.zeros(batch, 0, self.num_channels, device=device, dtype=text_emb.dtype)
        noisy_lens = torch.LongTensor([self.block_size]).to(device)
        block_iterator = range(num_blocks)

        # create caches (one for the conditional pass, one for the CFG null pass)
        kv_cache = BlockFlowMatchingCache(text_lengths=text_lens, num_history_block=self.num_history_block)
        cfg_kv_cache = BlockFlowMatchingCache(text_lengths=text_lens, num_history_block=self.num_history_block)
        cache_time = torch.tensor([1], device=device)[:, None].repeat(batch, self.block_size).to(style_prompt.dtype)

        # generate text cache; time = -1 marks text positions
        text_time = torch.tensor([-1], device=device)[:, None].repeat(batch, text_emb.shape[1]).to(style_prompt.dtype)
        text_position_ids = torch.arange(0, text_emb.shape[1], device=device)[None, :].repeat(batch, 1)
        text_attn_mask = torch.ones(batch, 1, text_emb.shape[1], text_emb.shape[1], device=device).bool()
        # text_attn_mask = create_block_mask(
        #     all_mask,
        #     B=batch,
        #     H=None,
        #     Q_LEN=text_emb.shape[1],
        #     KV_LEN=text_emb.shape[1],
        # )
        if text_emb.shape[1] != 0:
            with kv_cache.cache_text():
                _, _, kv_cache = self.transformer(
                    x=text_emb,
                    time=text_time,
                    attn_mask=text_attn_mask,
                    position_ids=text_position_ids,
                    style_prompt=style_prompt,
                    use_cache=True,
                    past_key_value=kv_cache,
                )
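            # Prime a second, unconditional cache for classifier-free guidance:
            # the text is zeroed and the style prompt replaced by zeros, so each
            # ODE step can form `pred + (pred - null_pred) * cfg_strength`
            # without re-encoding the prompt.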
            with cfg_kv_cache.cache_text():
                _, _, cfg_kv_cache = self.transformer(
                    x=cfg_text_emb,
                    time=text_time,
                    attn_mask=text_attn_mask,
                    position_ids=text_position_ids,
                    style_prompt=torch.zeros_like(style_prompt),
                    use_cache=True,
                    past_key_value=cfg_kv_cache,
                )

        end_pos = 0
        for bid in block_iterator:
            clean_lens = torch.LongTensor([clean_emb_stream.shape[1]]).to(device)
            # print(text_lens, clean_lens, noisy_lens, clean_emb_stream.shape, flush=True)

            # all-ones mask: the new block attends to text, history, and itself
            attn_mask = torch.ones(
                batch, 1, noisy_lens.max(), (text_lens + clean_lens + noisy_lens).max(), device=device
            ).bool()  # [B, 1, Q, KV]
            # attn_mask = create_block_mask(
            #     all_mask,
            #     B=batch,
            #     H=None,
            #     Q_LEN=noisy_lens.max(),
            #     KV_LEN=(text_lens + clean_lens + noisy_lens).max(),
            # )

            # generate position ids for the new block, continuing after the history
            position_ids = torch.arange(0, (clean_lens + noisy_lens).max(), device=device)[None, :].repeat(batch, 1)
            position_ids = position_ids[:, -noisy_lens.max():]

            # core sample fn: the velocity field integrated by the ODE solver
            def fn(t, x):
                noisy_embed = self.transformer.latent_embed(x)
                if t.ndim == 0:
                    t = t.repeat(batch)
                time = t[:, None].repeat(1, noisy_lens.max())
                pred, *_ = self.transformer(
                    x=noisy_embed,
                    time=time,
                    attn_mask=attn_mask,
                    position_ids=position_ids,
                    style_prompt=style_prompt,
                    use_cache=True,
                    past_key_value=kv_cache,
                )
                if cfg_strength < 1e-5:
                    return pred
                null_pred, *_ = self.transformer(
                    x=noisy_embed,
                    time=time,
                    attn_mask=attn_mask,
                    position_ids=position_ids,
                    style_prompt=torch.zeros_like(style_prompt),
                    use_cache=True,
                    past_key_value=cfg_kv_cache,
                )
                # classifier-free guidance
                return pred + (pred - null_pred) * cfg_strength

            # generate time grid from noise (t=0) to data (t=1)
            noisy_emb = torch.randn(batch, self.block_size, self.num_channels, device=device, dtype=style_prompt.dtype)
            t_start = 0
            t_set = torch.linspace(t_start, 1, steps, device=device, dtype=noisy_emb.dtype)

            # sampling
            outputs = odeint(fn, noisy_emb, t_set, method=odeint_method)
            sampled = outputs[-1]

            # append the finished block to both caches (time = 1 marks clean frames)
            cache_embed = self.transformer.latent_embed(sampled)
            with kv_cache.cache_context():
                _, _, kv_cache = self.transformer(
                    x=cache_embed,
                    time=cache_time,
                    attn_mask=attn_mask,
                    position_ids=position_ids,
                    style_prompt=style_prompt,
                    use_cache=True,
                    past_key_value=kv_cache,
                )
            with cfg_kv_cache.cache_context():
                _, _, cfg_kv_cache = self.transformer(
                    x=cache_embed,
                    time=cache_time,
                    attn_mask=attn_mask,
                    position_ids=position_ids,
                    style_prompt=torch.zeros_like(style_prompt),
                    use_cache=True,
                    past_key_value=cfg_kv_cache,
                )

            # push new block
            clean_emb_stream = torch.cat([clean_emb_stream, sampled], dim=1)

            # EOS detection: trim trailing frames close to the all-ones EOS
            # marker (`last_kl` is an MSE despite the name).
            pos = -1
            curr_frame = clean_emb_stream[:, pos, :]
            eos = torch.ones_like(curr_frame)
            last_kl = torch.nn.functional.mse_loss(curr_frame, eos)
            if last_kl.abs() <= 0.05:
                while last_kl.abs() <= 0.05 and abs(pos) < clean_emb_stream.shape[1]:
                    pos -= 1
                    curr_frame = clean_emb_stream[:, pos, :]
                    last_kl = torch.nn.functional.mse_loss(curr_frame, eos)
                end_pos = clean_emb_stream.shape[1] + pos
                break
            else:
                end_pos = clean_emb_stream.shape[1]

        clean_emb_stream = clean_emb_stream[:, :end_pos, :]
        return clean_emb_stream
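    # Streaming variant of `sample_block_cache`: the same block-wise CFM
    # sampling, but implemented as a generator that decodes and yields audio
    # chunks as latents become available. Every `chunk_size` new frames, it
    # decodes a window that starts `overlap` frames before the last decoded
    # position and trims `overlap_frame * 9600` samples from the waveform;
    # the factor 9600 appears to be the decoder's audio-samples-per-latent-
    # frame upsampling ratio.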
    def sample_cache_stream(
        self,
        decoder,
        text,
        duration,  # noqa: F821
        style_prompt,
        steps=32,
        cfg_strength=1.0,
        seed: int | None = None,  # currently unused
        chunk_size=10,
        overlap=2,
        odeint_method="euler",
    ):
        self.eval()
        batch = text.shape[0]
        device = self.device
        num_blocks = duration // self.block_size + (duration % self.block_size > 0)

        text_emb = self.transformer.text_embed(text)
        cfg_text_emb = self.transformer.text_embed(torch.zeros_like(text))
        text_lens = torch.LongTensor([text_emb.shape[1]]).to(device)

        clean_emb_stream = torch.zeros(batch, 0, self.num_channels, device=device, dtype=text_emb.dtype)
        noisy_lens = torch.LongTensor([self.block_size]).to(device)
        block_iterator = range(num_blocks)

        # create caches (conditional and CFG null), as in sample_block_cache
        kv_cache = BlockFlowMatchingCache(text_lengths=text_lens, num_history_block=self.num_history_block)
        cfg_kv_cache = BlockFlowMatchingCache(text_lengths=text_lens, num_history_block=self.num_history_block)
        cache_time = torch.tensor([1], device=device)[:, None].repeat(batch, self.block_size).to(style_prompt.dtype)

        # generate text cache
        text_time = torch.tensor([-1], device=device)[:, None].repeat(batch, text_emb.shape[1]).to(style_prompt.dtype)
        text_position_ids = torch.arange(0, text_emb.shape[1], device=device)[None, :].repeat(batch, 1)
        text_attn_mask = torch.ones(batch, 1, text_emb.shape[1], text_emb.shape[1], device=device).bool()
        if text_emb.shape[1] != 0:
            with kv_cache.cache_text():
                _, _, kv_cache = self.transformer(
                    x=text_emb,
                    time=text_time,
                    attn_mask=text_attn_mask,
                    position_ids=text_position_ids,
                    style_prompt=style_prompt,
                    use_cache=True,
                    past_key_value=kv_cache,
                )
            with cfg_kv_cache.cache_text():
                _, _, cfg_kv_cache = self.transformer(
                    x=cfg_text_emb,
                    time=text_time,
                    attn_mask=text_attn_mask,
                    position_ids=text_position_ids,
                    style_prompt=torch.zeros_like(style_prompt),
                    use_cache=True,
                    past_key_value=cfg_kv_cache,
                )

        end_pos = 0
        last_decoder_pos = 0
        for bid in block_iterator:
            clean_lens = torch.LongTensor([clean_emb_stream.shape[1]]).to(device)
            # print(text_lens, clean_lens, noisy_lens, clean_emb_stream.shape, flush=True)

            # all-ones mask
            attn_mask = torch.ones(
                batch, 1, noisy_lens.max(), (text_lens + clean_lens + noisy_lens).max(), device=device
            ).bool()  # [B, 1, Q, KV]

            # generate position ids
            position_ids = torch.arange(0, (clean_lens + noisy_lens).max(), device=device)[None, :].repeat(batch, 1)
            position_ids = position_ids[:, -noisy_lens.max():]

            # core sample fn
            def fn(t, x):
                noisy_embed = self.transformer.latent_embed(x)
                if t.ndim == 0:
                    t = t.repeat(batch)
                time = t[:, None].repeat(1, noisy_lens.max())
                pred, *_ = self.transformer(
                    x=noisy_embed,
                    time=time,
                    attn_mask=attn_mask,
                    position_ids=position_ids,
                    style_prompt=style_prompt,
                    use_cache=True,
                    past_key_value=kv_cache,
                )
                if cfg_strength < 1e-5:
                    return pred
                null_pred, *_ = self.transformer(
                    x=noisy_embed,
                    time=time,
                    attn_mask=attn_mask,
                    position_ids=position_ids,
                    style_prompt=torch.zeros_like(style_prompt),
                    use_cache=True,
                    past_key_value=cfg_kv_cache,
                )
                return pred + (pred - null_pred) * cfg_strength

            # generate time grid
            noisy_emb = torch.randn(batch, self.block_size, self.num_channels, device=device, dtype=style_prompt.dtype)
            t_start = 0
            t_set = torch.linspace(t_start, 1, steps, device=device, dtype=noisy_emb.dtype)

            # sampling
            outputs = odeint(fn, noisy_emb, t_set, method=odeint_method)
            sampled = outputs[-1]

            # generate next kv cache
            cache_embed = self.transformer.latent_embed(sampled)
            with kv_cache.cache_context():
                _, _, kv_cache = self.transformer(
                    x=cache_embed,
                    time=cache_time,
                    attn_mask=attn_mask,
                    position_ids=position_ids,
                    style_prompt=style_prompt,
                    use_cache=True,
                    past_key_value=kv_cache,
                )
            with cfg_kv_cache.cache_context():
                _, _, cfg_kv_cache = self.transformer(
                    x=cache_embed,
                    time=cache_time,
                    attn_mask=attn_mask,
                    position_ids=position_ids,
                    style_prompt=torch.zeros_like(style_prompt),
                    use_cache=True,
                    past_key_value=cfg_kv_cache,
                )

            # push new block
            clean_emb_stream = torch.cat([clean_emb_stream, sampled], dim=1)
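            # EOS detection: despite the name, `last_kl` is a mean-squared
            # error to the all-ones EOS frame, not a KL divergence. If the
            # newest frame looks like EOS (MSE <= 0.05), scan backwards to
            # find where the EOS run begins, record it as `end_pos`, and
            # stop generating.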
            pos = -1
            curr_frame = clean_emb_stream[:, pos, :]
            eos = torch.ones_like(curr_frame)
            last_kl = torch.nn.functional.mse_loss(curr_frame, eos)
            if last_kl.abs() <= 0.05:
                while last_kl.abs() <= 0.05 and abs(pos) < clean_emb_stream.shape[1]:
                    pos -= 1
                    curr_frame = clean_emb_stream[:, pos, :]
                    last_kl = torch.nn.functional.mse_loss(curr_frame, eos)
                end_pos = clean_emb_stream.shape[1] + pos
                break
            else:
                end_pos = clean_emb_stream.shape[1]

            # Decode and yield a chunk once enough new frames have accumulated.
            # The decode window starts `overlap` frames early; the overlapping
            # samples are cut from the waveform so they are not emitted twice.
            if end_pos - last_decoder_pos >= chunk_size:
                start = max(0, last_decoder_pos - overlap)
                overlap_frame = max(0, last_decoder_pos - start)
                latent = clean_emb_stream[:, start:end_pos, :]
                audio = decoder.decoder(latent.transpose(1, 2))  # [B, C, T]
                # print(last_decoder_pos, start, end_pos, latent.shape, audio.shape, clean_emb_stream.shape, chunk_size, overlap_frame, last_decoder_pos - overlap, last_decoder_pos - start)
                audio = audio[:, :, overlap_frame * 9600:]
                print(audio.shape)
                yield audio
                last_decoder_pos = end_pos

        # Final flush: decode whatever remains after the loop ends.
        clean_emb_stream = clean_emb_stream[:, :end_pos, :]
        start = max(0, last_decoder_pos - overlap)
        overlap_frame = max(0, last_decoder_pos - start)  # avoid shadowing the `overlap` argument
        latent = clean_emb_stream[:, start:end_pos, :]
        audio = decoder.decoder(latent.transpose(1, 2))  # [B, C, T]
        audio = audio[:, :, overlap_frame * 9600:]
        print("last", audio.shape)
        # Pad a few zero samples onto the tail of the stream.
        audio = torch.cat(
            [audio, torch.zeros(audio.shape[0], audio.shape[1], 5, device=audio.device, dtype=audio.dtype)],
            dim=-1,
        )
        print(audio.shape)
        yield audio
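
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the DiT arguments and the `vocoder` /
# `play_or_write` names below are placeholders, not part of this module):
#
#     model = CFM(
#         transformer=DiT(...),          # backbone built from the project config
#         num_channels=64,               # latent channel count (assumed value)
#         block_size=32,                 # frames generated per block (assumed)
#         num_history_block=4,           # or None to attend to the full history
#     )
#     latents = model.sample_block_cache(
#         text, duration, style_prompt, steps=32, cfg_strength=1.0
#     )
#     # Streaming variant: consume audio chunks as they are produced.
#     for chunk in model.sample_cache_stream(
#         vocoder, text, duration, style_prompt, chunk_size=10, overlap=2
#     ):
#         play_or_write(chunk)  # hypothetical consumer of streamed audio
# ---------------------------------------------------------------------------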