import logging
import os
from argparse import ArgumentParser
from pathlib import Path

import pandas as pd
import tensordict as td
import torch
import torch.distributed as distributed
import torch.nn.functional as F
from open_clip import create_model_from_pretrained
from torch.utils.data import DataLoader
from tqdm import tqdm

from mmaudio.data.data_setup import error_avoidance_collate
from mmaudio.data.extraction.wav_dataset import WavTextClipsDataset
from mmaudio.ext.autoencoder import AutoEncoderModule
from mmaudio.ext.mel_converter import get_mel_converter

log = logging.getLogger()

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])

# 16k mode: 8-second clips at 16 kHz
SAMPLE_RATE = 16_000
NUM_SAMPLES = 16_000 * 8
tod_vae_ckpt = './ext_weights/v1-16.pth'
bigvgan_vocoder_ckpt = './ext_weights/best_netG.pt'
mode = '16k'

# 44k mode
"""
NOTE: 352800 samples (8 s * 44100 Hz) is not divisible by the product of the
STFT hop size and the VAE downsampling ratio, which is 1024. 353280 is the
next integer that is divisible by 1024.
"""
# SAMPLE_RATE = 44100
# NUM_SAMPLES = 353280
# tod_vae_ckpt = './ext_weights/v1-44.pth'
# bigvgan_vocoder_ckpt = None
# mode = '44k'


def distributed_setup():
    distributed.init_process_group(backend='nccl')
    local_rank = distributed.get_rank()
    world_size = distributed.get_world_size()
    print(f'Initialized: local_rank={local_rank}, world_size={world_size}')
    return local_rank, world_size


def main():
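    # The rank/world size returned below are unused here; the per-rank output
    # filenames later in this function use the module-level local_rank that
    # was read from the environment.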
    distributed_setup()

    parser = ArgumentParser()
    parser.add_argument('--data_dir', type=Path, default='./training/example_audios/')
    parser.add_argument('--captions_tsv', type=Path, default='./training/example_audio.tsv')
    parser.add_argument('--clips_tsv', type=Path, default='./training/example_output/clips.tsv')
    parser.add_argument('--latent_dir',
                        type=Path,
                        default='./training/example_output/audio-latents')
    parser.add_argument('--output_dir',
                        type=Path,
                        default='./training/example_output/memmap/audio-example')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_workers', type=int, default=8)
    args = parser.parse_args()

    data_dir = args.data_dir
    captions_tsv = args.captions_tsv
    clips_tsv = args.clips_tsv
    latent_dir = args.latent_dir
    output_dir = args.output_dir
    batch_size = args.batch_size
    num_workers = args.num_workers

    clip_model = create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384',
                                              return_transform=False).eval().cuda()

    # a hack to make it output last hidden states
    def new_encode_text(self, text, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()
        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.to(cast_dtype)
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        return F.normalize(x, dim=-1) if normalize else x
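    # __get__ binds the plain function to this model instance, so the patched
    # encode_text receives self automatically, like a normal bound method.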
    clip_model.encode_text = new_encode_text.__get__(clip_model)

    tod = AutoEncoderModule(vae_ckpt_path=tod_vae_ckpt,
                            vocoder_ckpt_path=bigvgan_vocoder_ckpt,
                            mode=mode).eval().cuda()
    mel_converter = get_mel_converter(mode).eval().cuda()

    dataset = WavTextClipsDataset(data_dir,
                                  captions_tsv=captions_tsv,
                                  clips_tsv=clips_tsv,
                                  sample_rate=SAMPLE_RATE,
                                  num_samples=NUM_SAMPLES,
                                  normalize_audio=True,
                                  reject_silent=True)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers,
                            collate_fn=error_avoidance_collate)

    latent_dir.mkdir(exist_ok=True, parents=True)

    # extraction
    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        ids = batch['id']
        waveforms = batch['waveform'].cuda()
        tokens = batch['tokens'].cuda()

        text_features = clip_model.encode_text(tokens, normalize=True)
        mel = mel_converter(waveforms)
        dist = tod.encode(mel)

        a_mean = dist.mean.detach().cpu().transpose(1, 2)
        a_std = dist.std.detach().cpu().transpose(1, 2)
        text_features = text_features.detach().cpu()
        ids = list(ids)
        captions = list(batch['caption'])

        data = {
            'id': ids,
            'caption': captions,
            'mean': a_mean,
            'std': a_std,
            'text_features': text_features,
        }
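        # Each rank writes its own shards: the rank prefix keeps filenames
        # collision-free and the zero-padded batch index keeps them sorted.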
        torch.save(data, latent_dir / f'r{local_rank}_{i:05d}.pth')

    distributed.barrier()

    # combine the results
    if local_rank == 0:
        print('Extraction done. Combining the results.')
        list_of_ids_and_labels = []
        output_data = {
            'mean': [],
            'std': [],
            'text_features': [],
        }

        latents = sorted(os.listdir(latent_dir))
        latents = [l for l in latents if l.endswith('.pth')]
        for t in tqdm(latents):
            data = torch.load(latent_dir / t, weights_only=True)
            bs = len(data['id'])
            for bi in range(bs):
                this_id = data['id'][bi]
                this_caption = data['caption'][bi]
                list_of_ids_and_labels.append({'id': this_id, 'caption': this_caption})
                output_data['mean'].append(data['mean'][bi])
                output_data['std'].append(data['std'][bi])
                output_data['text_features'].append(data['text_features'][bi])

        output_df = pd.DataFrame(list_of_ids_and_labels)
        output_dir.mkdir(exist_ok=True, parents=True)
        output_name = output_dir.stem
        output_df.to_csv(output_dir.parent / f'{output_name}.tsv', sep='\t', index=False)
        print(f'Output: {len(output_df)}')
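        # Stack the per-sample tensors into single batch tensors and write
        # them as a memory-mapped TensorDict for fast loading during training.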
        output_data = {k: torch.stack(v) for k, v in output_data.items()}
        td.TensorDict(output_data).memmap_(output_dir)


if __name__ == '__main__':
    main()
    distributed.destroy_process_group()