# Copyright 2025 ASLP Lab and Xiaomi Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import torch
from torch import nn
from tqdm import tqdm
from torchdiffeq import odeint

from .backbones.dit import DiT
from .cache_utils import BlockFlowMatchingCache
from torch.nn.attention.flex_attention import create_block_mask


def all_mask(b, h, q_idx, kv_idx):
    # flex_attention mask predicate: always True, i.e. dense (full) attention.
    return q_idx == q_idx
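

# A minimal usage sketch for `all_mask` (illustrative only): it is the
# `mask_mod` callback that `create_block_mask` expects, mirroring the
# commented-out calls inside the sampling methods below. The dense boolean
# masks actually used there are equivalent:
#
#     block_mask = create_block_mask(all_mask, B=batch, H=None, Q_LEN=q_len, KV_LEN=kv_len)
#     dense_mask = torch.ones(batch, 1, q_len, kv_len, dtype=torch.bool)  # same effect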


class CFM(nn.Module):
    def __init__(
        self,
        transformer: DiT,
        sigma=0.0,
        odeint_kwargs: dict = dict(
            # atol=1e-5,
            # rtol=1e-5,
            method="euler",  # 'midpoint'
            # method="adaptive_heun"
        ),
        odeint_options: dict = dict(
            min_step=0.05
        ),
        num_channels=None,
        block_size=None,
        num_history_block=None
    ):
        super().__init__()
        self.num_channels = num_channels

        # transformer
        self.transformer = transformer
        dim = transformer.dim
        self.dim = dim

        # conditional flow related
        self.sigma = sigma

        # sampling related
        self.odeint_kwargs = odeint_kwargs
        print(f"ODE SOLVER: {self.odeint_kwargs['method']}")
        self.odeint_options = odeint_options

        self.block_size = block_size
        self.num_history_block = num_history_block
        if self.num_history_block is not None and self.num_history_block <= 0:
            self.num_history_block = None
        print(f"block_size: {self.block_size}; num_history_block: {self.num_history_block}")

    @property
    def device(self):
        return next(self.parameters()).device

    def sample_block_cache(
        self,
        text,
        duration,  # noqa: F821
        style_prompt,
        steps=32,
        cfg_strength=1.0,
        odeint_method='euler'
    ):
        self.eval()
        batch = text.shape[0]
        device = self.device
        # number of fixed-size blocks needed to cover `duration` frames (ceiling division)
        num_blocks = duration // self.block_size + (duration % self.block_size > 0)
        text_emb = self.transformer.text_embed(text)
        cfg_text_emb = self.transformer.text_embed(torch.zeros_like(text))
        text_lens = torch.LongTensor([text_emb.shape[1]]).to(device)
        clean_emb_stream = torch.zeros(batch, 0, self.num_channels, device=device, dtype=text_emb.dtype)
        noisy_lens = torch.LongTensor([self.block_size]).to(device)
        block_iterator = range(num_blocks)

        # create caches for the conditional and the unconditional (CFG) passes
        kv_cache = BlockFlowMatchingCache(text_lengths=text_lens, num_history_block=self.num_history_block)
        cfg_kv_cache = BlockFlowMatchingCache(text_lengths=text_lens, num_history_block=self.num_history_block)
        # cached clean blocks carry flow time t = 1 (fully denoised)
        cache_time = torch.tensor([1], device=device)[:, None].repeat(batch, self.block_size).to(style_prompt.dtype)

        # generate text cache; text tokens use the sentinel time t = -1
        text_time = torch.tensor([-1], device=device)[:, None].repeat(batch, text_emb.shape[1]).to(style_prompt.dtype)
        text_position_ids = torch.arange(0, text_emb.shape[1], device=device)[None, :].repeat(batch, 1)
        text_attn_mask = torch.ones(batch, 1, text_emb.shape[1], text_emb.shape[1], device=device).bool()
        # text_attn_mask = create_block_mask(
        #     all_mask,
        #     B=batch,
        #     H=None,
        #     Q_LEN=text_emb.shape[1],
        #     KV_LEN=text_emb.shape[1]
        # )
        if text_emb.shape[1] != 0:
            with kv_cache.cache_text():
                _, _, kv_cache = self.transformer(
                    x=text_emb,
                    time=text_time,
                    attn_mask=text_attn_mask,
                    position_ids=text_position_ids,
                    style_prompt=style_prompt,
                    use_cache=True,
                    past_key_value=kv_cache
                )
            with cfg_kv_cache.cache_text():
                _, _, cfg_kv_cache = self.transformer(
                    x=cfg_text_emb,
                    time=text_time,
                    attn_mask=text_attn_mask,
                    position_ids=text_position_ids,
                    style_prompt=torch.zeros_like(style_prompt),
                    use_cache=True,
                    past_key_value=cfg_kv_cache
                )

        end_pos = 0
        for bid in block_iterator:
            clean_lens = torch.LongTensor([clean_emb_stream.shape[1]]).to(device)
            # print(text_lens, clean_lens, noisy_lens, clean_emb_stream.shape, flush=True)

            # all-ones mask: every noisy query attends to text, clean history, and the current block
            attn_mask = torch.ones(batch, 1, noisy_lens.max(), (text_lens + clean_lens + noisy_lens).max(), device=device).bool()  # [B, 1, Q, KV]
            # attn_mask = create_block_mask(
            #     all_mask,
            #     B=batch,
            #     H=None,
            #     Q_LEN=noisy_lens.max(),
            #     KV_LEN=(text_lens + clean_lens + noisy_lens).max()
            # )

            # generate position ids for the current block, continuing after the clean history
            position_ids = torch.arange(0, (clean_lens + noisy_lens).max(), device=device)[None, :].repeat(batch, 1)
            position_ids = position_ids[:, -noisy_lens.max():]

            # core sample fn: the ODE right-hand side v(t, x), with classifier-free guidance
            def fn(t, x):
                noisy_embed = self.transformer.latent_embed(x)
                if t.ndim == 0:
                    t = t.repeat(batch)
                time = t[:, None].repeat(1, noisy_lens.max())
                pred, *_ = self.transformer(
                    x=noisy_embed,
                    time=time,
                    attn_mask=attn_mask,
                    position_ids=position_ids,
                    style_prompt=style_prompt,
                    use_cache=True,
                    past_key_value=kv_cache
                )
                if cfg_strength < 1e-5:
                    return pred
                null_pred, *_ = self.transformer(
                    x=noisy_embed,
                    time=time,
                    attn_mask=attn_mask,
                    position_ids=position_ids,
                    style_prompt=torch.zeros_like(style_prompt),
                    use_cache=True,
                    past_key_value=cfg_kv_cache
                )
                # classifier-free guidance: (1 + w) * pred - w * null_pred, with w = cfg_strength
                return pred + (pred - null_pred) * cfg_strength

            # generate the time grid and integrate the flow from noise (t = 0) to data (t = 1)
            noisy_emb = torch.randn(batch, self.block_size, self.num_channels, device=device, dtype=style_prompt.dtype)
            t_start = 0
            t_set = torch.linspace(t_start, 1, steps, device=device, dtype=noisy_emb.dtype)
            # sampling
            outputs = odeint(fn, noisy_emb, t_set, method=odeint_method)
            sampled = outputs[-1]

            # generate next kv cache: re-encode the finished block at t = 1 so later blocks can attend to it
            cache_embed = self.transformer.latent_embed(sampled)
            with kv_cache.cache_context():
                _, _, kv_cache = self.transformer(
                    x=cache_embed,
                    time=cache_time,
                    attn_mask=attn_mask,
                    position_ids=position_ids,
                    style_prompt=style_prompt,
                    use_cache=True,
                    past_key_value=kv_cache
                )
            with cfg_kv_cache.cache_context():
                _, _, cfg_kv_cache = self.transformer(
                    x=cache_embed,
                    time=cache_time,
                    attn_mask=attn_mask,
                    position_ids=position_ids,
                    style_prompt=torch.zeros_like(style_prompt),
                    use_cache=True,
                    past_key_value=cfg_kv_cache
                )

            # push new block
            clean_emb_stream = torch.cat([clean_emb_stream, sampled], dim=1)

            # end-of-sequence detection: frames near the all-ones vector mark the end;
            # scan backwards while frames stay within 0.05 MSE of all-ones and truncate
            # there (`last_kl` is, despite the name, an MSE distance)
            pos = -1
            curr_frame = clean_emb_stream[:, pos, :]
            eos = torch.ones_like(curr_frame)
            last_kl = torch.nn.functional.mse_loss(curr_frame, eos)
            if last_kl.abs() <= 0.05:
                while last_kl.abs() <= 0.05 and abs(pos) < clean_emb_stream.shape[1]:
                    pos -= 1
                    curr_frame = clean_emb_stream[:, pos, :]
                    last_kl = torch.nn.functional.mse_loss(curr_frame, eos)
                end_pos = clean_emb_stream.shape[1] + pos
                break
            else:
                end_pos = clean_emb_stream.shape[1]
        clean_emb_stream = clean_emb_stream[:, :end_pos, :]
        return clean_emb_stream
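
    # Usage sketch for `sample_block_cache` (illustrative; the tensor shapes and
    # constructor arguments are assumptions inferred from the code above, not a
    # documented API):
    #
    #     model = CFM(transformer=dit, num_channels=64, block_size=30)  # hypothetical sizes
    #     latents = model.sample_block_cache(
    #         text=token_ids,        # [B, T_text] LongTensor; zeros act as the null (CFG) text
    #         duration=600,          # target length in latent frames
    #         style_prompt=style,    # [B, D_style] float tensor
    #         steps=32,
    #         cfg_strength=1.0,
    #     )                          # -> [B, T <= duration, num_channels]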

    def sample_cache_stream(
        self,
        decoder,
        text,
        duration,  # noqa: F821
        style_prompt,
        steps=32,
        cfg_strength=1.0,
        seed: int | None = None,
        chunk_size=10,
        overlap=2,
        odeint_method='euler'
    ):
        self.eval()
        batch = text.shape[0]
        device = self.device
        num_blocks = duration // self.block_size + (duration % self.block_size > 0)
        text_emb = self.transformer.text_embed(text)
        cfg_text_emb = self.transformer.text_embed(torch.zeros_like(text))
        text_lens = torch.LongTensor([text_emb.shape[1]]).to(device)
        clean_emb_stream = torch.zeros(batch, 0, self.num_channels, device=device, dtype=text_emb.dtype)
        noisy_lens = torch.LongTensor([self.block_size]).to(device)
        block_iterator = range(num_blocks)

        # create cache
        kv_cache = BlockFlowMatchingCache(text_lengths=text_lens, num_history_block=self.num_history_block)
        cfg_kv_cache = BlockFlowMatchingCache(text_lengths=text_lens, num_history_block=self.num_history_block)
        cache_time = torch.tensor([1], device=device)[:, None].repeat(batch, self.block_size).to(style_prompt.dtype)

        # generate text cache
        text_time = torch.tensor([-1], device=device)[:, None].repeat(batch, text_emb.shape[1]).to(style_prompt.dtype)
        text_position_ids = torch.arange(0, text_emb.shape[1], device=device)[None, :].repeat(batch, 1)
        text_attn_mask = torch.ones(batch, 1, text_emb.shape[1], text_emb.shape[1], device=device).bool()
        if text_emb.shape[1] != 0:
            with kv_cache.cache_text():
                _, _, kv_cache = self.transformer(
                    x=text_emb,
                    time=text_time,
                    attn_mask=text_attn_mask,
                    position_ids=text_position_ids,
                    style_prompt=style_prompt,
                    use_cache=True,
                    past_key_value=kv_cache
                )
            with cfg_kv_cache.cache_text():
                _, _, cfg_kv_cache = self.transformer(
                    x=cfg_text_emb,
                    time=text_time,
                    attn_mask=text_attn_mask,
                    position_ids=text_position_ids,
                    style_prompt=torch.zeros_like(style_prompt),
                    use_cache=True,
                    past_key_value=cfg_kv_cache
                )

        end_pos = 0
        last_decoder_pos = 0  # number of latent frames already decoded to audio
        decode_audio = []
        for bid in block_iterator:
            clean_lens = torch.LongTensor([clean_emb_stream.shape[1]]).to(device)
            # print(text_lens, clean_lens, noisy_lens, clean_emb_stream.shape, flush=True)

            # all-ones mask: every noisy query attends to text, clean history, and the current block
            attn_mask = torch.ones(batch, 1, noisy_lens.max(), (text_lens + clean_lens + noisy_lens).max(), device=device).bool()  # [B, 1, Q, KV]

            # generate position ids for the current block, continuing after the clean history
            position_ids = torch.arange(0, (clean_lens + noisy_lens).max(), device=device)[None, :].repeat(batch, 1)
            position_ids = position_ids[:, -noisy_lens.max():]

            # core sample fn: the ODE right-hand side v(t, x), with classifier-free guidance
            def fn(t, x):
                noisy_embed = self.transformer.latent_embed(x)
                if t.ndim == 0:
                    t = t.repeat(batch)
                time = t[:, None].repeat(1, noisy_lens.max())
                pred, *_ = self.transformer(
                    x=noisy_embed,
                    time=time,
                    attn_mask=attn_mask,
                    position_ids=position_ids,
                    style_prompt=style_prompt,
                    use_cache=True,
                    past_key_value=kv_cache
                )
                if cfg_strength < 1e-5:
                    return pred
                null_pred, *_ = self.transformer(
                    x=noisy_embed,
                    time=time,
                    attn_mask=attn_mask,
                    position_ids=position_ids,
                    style_prompt=torch.zeros_like(style_prompt),
                    use_cache=True,
                    past_key_value=cfg_kv_cache
                )
                # classifier-free guidance: (1 + w) * pred - w * null_pred, with w = cfg_strength
                return pred + (pred - null_pred) * cfg_strength

            # generate the time grid and integrate the flow from noise (t = 0) to data (t = 1)
            noisy_emb = torch.randn(batch, self.block_size, self.num_channels, device=device, dtype=style_prompt.dtype)
            t_start = 0
            t_set = torch.linspace(t_start, 1, steps, device=device, dtype=noisy_emb.dtype)
            # sampling
            outputs = odeint(fn, noisy_emb, t_set, method=odeint_method)
            sampled = outputs[-1]

            # generate next kv cache: re-encode the finished block at t = 1 so later blocks can attend to it
            cache_embed = self.transformer.latent_embed(sampled)
            with kv_cache.cache_context():
                _, _, kv_cache = self.transformer(
                    x=cache_embed,
                    time=cache_time,
                    attn_mask=attn_mask,
                    position_ids=position_ids,
                    style_prompt=style_prompt,
                    use_cache=True,
                    past_key_value=kv_cache
                )
            with cfg_kv_cache.cache_context():
                _, _, cfg_kv_cache = self.transformer(
                    x=cache_embed,
                    time=cache_time,
                    attn_mask=attn_mask,
                    position_ids=position_ids,
                    style_prompt=torch.zeros_like(style_prompt),
                    use_cache=True,
                    past_key_value=cfg_kv_cache
                )

            # push new block
            clean_emb_stream = torch.cat([clean_emb_stream, sampled], dim=1)

            # end-of-sequence detection: frames near the all-ones vector mark the end;
            # scan backwards while frames stay within 0.05 MSE of all-ones and truncate
            # there (`last_kl` is, despite the name, an MSE distance)
            pos = -1
            curr_frame = clean_emb_stream[:, pos, :]
            eos = torch.ones_like(curr_frame)
            last_kl = torch.nn.functional.mse_loss(curr_frame, eos)
            if last_kl.abs() <= 0.05:
                while last_kl.abs() <= 0.05 and abs(pos) < clean_emb_stream.shape[1]:
                    pos -= 1
                    curr_frame = clean_emb_stream[:, pos, :]
                    last_kl = torch.nn.functional.mse_loss(curr_frame, eos)
                end_pos = clean_emb_stream.shape[1] + pos
                break
            else:
                end_pos = clean_emb_stream.shape[1]

            # once at least `chunk_size` new frames are final, decode them to audio,
            # re-decoding `overlap` frames of context and dropping the re-decoded
            # overlap from the waveform to avoid boundary artifacts
            if end_pos - last_decoder_pos >= chunk_size:
                start = max(0, last_decoder_pos - overlap)
                overlap_frame = max(0, last_decoder_pos - start)
                latent = clean_emb_stream[:, start:end_pos, :]
                audio = decoder.decoder(latent.transpose(1, 2))  # [B, C, T]
                # print(last_decoder_pos, start, end_pos, latent.shape, audio.shape, clean_emb_stream.shape, chunk_size, overlap_frame, last_decoder_pos - overlap, last_decoder_pos - start)
                audio = audio[:, :, overlap_frame * 9600:]  # 9600 output samples per latent frame
                print(audio.shape)
                yield audio
                last_decoder_pos = end_pos

        # final flush: trim trailing EOS frames and decode whatever remains
        clean_emb_stream = clean_emb_stream[:, :end_pos, :]
        start = max(0, last_decoder_pos - overlap)
        overlap_frame = max(0, last_decoder_pos - start)
        latent = clean_emb_stream[:, start:end_pos, :]
        audio = decoder.decoder(latent.transpose(1, 2))  # [B, C, T]
        audio = audio[:, :, overlap_frame * 9600:]
        print("last", audio.shape)
        # append a 5-sample zero pad to the final chunk
        audio = torch.cat([audio, torch.zeros(audio.shape[0], audio.shape[1], 5, device=audio.device, dtype=audio.dtype)], dim=-1)
        print(audio.shape)
        yield audio
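

# Streaming usage sketch (illustrative; `decoder` is assumed to expose a
# `.decoder` module mapping latents [B, num_channels, T] to waveforms, as the
# calls above imply, and `dit`, `vae`, `token_ids`, `style` are hypothetical):
#
#     model = CFM(transformer=dit, num_channels=64, block_size=30)
#     for audio_chunk in model.sample_cache_stream(
#         decoder=vae,
#         text=token_ids,
#         duration=600,
#         style_prompt=style,
#         steps=32,
#         cfg_strength=1.0,
#         chunk_size=10,
#         overlap=2,
#     ):
#         sink.write(audio_chunk)  # hypothetical audio sink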