import logging
import os
from argparse import ArgumentParser
from datetime import timedelta
from pathlib import Path

import pandas as pd
import tensordict as td
import torch
import torch.distributed as distributed
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from mmaudio.data.data_setup import error_avoidance_collate
from mmaudio.data.extraction.vgg_sound import VGGSound
from mmaudio.model.utils.features_utils import FeaturesUtils
from mmaudio.utils.dist_utils import local_rank, world_size

# allow TF32 matmuls/convolutions (Ampere and newer GPUs): faster extraction at a
# negligible precision cost for feature extraction
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# for the 16kHz model
# SAMPLING_RATE = 16000
# DURATION_SEC = 8.0
# NUM_SAMPLES = 128000
# vae_path = './ext_weights/v1-16.pth'
# bigvgan_path = './ext_weights/best_netG.pt'
# mode = '16k'

# for the 44.1kHz model
| """ | |
| NOTE: 352800 (8*44100) is not divisible by (STFT hop size * VAE downsampling ratio) which is 1024. | |
| 353280 is the next integer divisible by 1024. | |
| """ | |
| SAMPLING_RATE = 44100 | |
| DURATION_SEC = 8.0 | |
| NUM_SAMPLES = 353280 | |
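# 8 s * 44100 Hz = 352800 samples; the next multiple of 1024 is 345 * 1024 = 353280 (~8.01 s)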
# DURATION_SEC = 4.02
# NUM_SAMPLES = 177152
# DURATION_SEC = 2.02
# NUM_SAMPLES = 89088

vae_path = './ext_weights/v1-44.pth'
bigvgan_path = None
mode = '44k'
synchformer_ckpt = './ext_weights/synchformer_state_dict.pth'

# per-GPU
BATCH_SIZE = 24
NUM_WORKERS = 8
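# BATCH_SIZE is per process; with torchrun each of the world_size ranks loads its own
# batches, so lower BATCH_SIZE if extraction runs out of GPU memory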

log = logging.getLogger()
log.setLevel(logging.INFO)

# uncomment the train/test/val sets to extract latents for them
data_cfg = {
    # 'example': {
    #     'root': './training/example_videos',
    #     'subset_name': './training/example_video.tsv',
    #     'normalize_audio': True,
    # },
    'train': {
        'root': '/ailab-train/speech/zhanghaomin/animation_dataset_v2a/0205audio/',
        'subset_name': '/ailab-train/speech/zhanghaomin/animation_dataset_v2a/train.scp',
        'normalize_audio': True,
    },
    'test': {
        'root': '/ailab-train/speech/zhanghaomin/animation_dataset_v2a/0205audio/',
        'subset_name': '/ailab-train/speech/zhanghaomin/animation_dataset_v2a/test.scp',
        'normalize_audio': False,
    },
    # 'val': {
    #     'root': '../data/video',
    #     'subset_name': './sets/vgg3-val.tsv',
    #     'normalize_audio': False,
    # },
}
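
# Each 'subset_name' file is passed to VGGSound as tsv_path; see
# mmaudio/data/extraction/vgg_sound.py for the exact columns it expects.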

# Environment setup and example launch (one process per GPU):
# export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512
# export LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libffi.so.7"
# torchrun --standalone --nproc_per_node=8 training/extract_video_training_latents.py
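# torchrun sets LOCAL_RANK/RANK/WORLD_SIZE in each process's environment (the imported
# local_rank/world_size are derived from these); set --nproc_per_node to the number of
# available GPUs.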


def distributed_setup():
    distributed.init_process_group(backend="nccl", timeout=timedelta(hours=1))
    log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
    print(f'Initialized: local_rank={local_rank}, world_size={world_size}')
    return local_rank, world_size


def setup_dataset(split: str):
    dataset = VGGSound(
        data_cfg[split]['root'],
        tsv_path=data_cfg[split]['subset_name'],
        sample_rate=SAMPLING_RATE,
        duration_sec=DURATION_SEC,
        audio_samples=NUM_SAMPLES,
        normalize_audio=data_cfg[split]['normalize_audio'],
    )
    # no shuffling: each rank processes a deterministic, non-overlapping shard
    sampler = DistributedSampler(dataset, rank=local_rank, shuffle=False)
    # error_avoidance_collate (see mmaudio.data.data_setup) is meant to skip samples
    # that failed to load instead of crashing the whole batch
    loader = DataLoader(dataset,
                        batch_size=BATCH_SIZE,
                        num_workers=NUM_WORKERS,
                        sampler=sampler,
                        drop_last=False,
                        collate_fn=error_avoidance_collate)
    return dataset, loader


# gradients are never needed during extraction; inference_mode avoids building autograd state
@torch.inference_mode()
def extract():
    # initial setup
    distributed_setup()
    parser = ArgumentParser()
    parser.add_argument('--latent_dir',
                        type=Path,
                        default='./training/example_output/video-latents')
    parser.add_argument('--output_dir', type=Path, default='./training/example_output/memmap')
    args = parser.parse_args()

    latent_dir = args.latent_dir
    output_dir = args.output_dir

    # cuda setup
    torch.cuda.set_device(local_rank)
    # FeaturesUtils bundles the audio VAE plus the CLIP, Synchformer, and text encoders
    # used below
    feature_extractor = FeaturesUtils(tod_vae_ckpt=vae_path,
                                      enable_conditions=True,
                                      bigvgan_vocoder_ckpt=bigvgan_path,
                                      synchformer_ckpt=synchformer_ckpt,
                                      mode=mode).eval().cuda()

    for split in data_cfg.keys():
        print(f'Extracting latents for the {split} split')

        this_latent_dir = latent_dir / split
        this_latent_dir.mkdir(parents=True, exist_ok=True)

        # setup datasets
        dataset, loader = setup_dataset(split)
        log.info(f'Number of samples: {len(dataset)}')
        log.info(f'Number of batches: {len(loader)}')

        # extraction: each rank writes one .pth shard per batch
        for curr_iter, data in enumerate(tqdm(loader)):
            output = {
                'id': data['id'],
                'caption': data['caption'],
            }

            # audio waveform -> VAE latent distribution (mean/std stored on CPU)
            audio = data['audio'].cuda()
            dist = feature_extractor.encode_audio(audio)
            output['mean'] = dist.mean.detach().cpu().transpose(1, 2)
            output['std'] = dist.std.detach().cpu().transpose(1, 2)

            # CLIP visual features
            clip_video = data['clip_video'].cuda()
            clip_features = feature_extractor.encode_video_with_clip(clip_video)
            output['clip_features'] = clip_features.detach().cpu()

            # Synchformer features for audio-visual synchronization
            sync_video = data['sync_video'].cuda()
            sync_features = feature_extractor.encode_video_with_sync(sync_video)
            output['sync_features'] = sync_features.detach().cpu()

            # caption text features
            caption = data['caption']
            text_features = feature_extractor.encode_text(caption)
            output['text_features'] = text_features.detach().cpu()

            torch.save(output, this_latent_dir / f'r{local_rank}_{curr_iter}.pth')

        # wait until every rank has written all of its shards
        distributed.barrier()

        # combine the results on rank 0
        if local_rank == 0:
            print('Extraction done. Combining the results.')
            used_id = set()
            list_of_ids_and_labels = []
            output_data = {
                'mean': [],
                'std': [],
                'clip_features': [],
                'sync_features': [],
                'text_features': [],
            }

            for t in tqdm(sorted(os.listdir(this_latent_dir))):
                data = torch.load(this_latent_dir / t, weights_only=True)
                bs = len(data['id'])

                for bi in range(bs):
                    this_id = data['id'][bi]
                    this_caption = data['caption'][bi]
                    # skip duplicates (DistributedSampler may repeat samples to make
                    # the dataset size divisible by world_size)
                    if this_id in used_id:
                        print('Duplicate id:', this_id)
                        continue

                    list_of_ids_and_labels.append({'id': this_id, 'label': this_caption})
                    used_id.add(this_id)
                    output_data['mean'].append(data['mean'][bi])
                    output_data['std'].append(data['std'][bi])
                    output_data['clip_features'].append(data['clip_features'][bi])
                    output_data['sync_features'].append(data['sync_features'][bi])
                    output_data['text_features'].append(data['text_features'][bi])

            output_dir.mkdir(parents=True, exist_ok=True)
            output_df = pd.DataFrame(list_of_ids_and_labels)
            output_df.to_csv(output_dir / f'vgg-{split}.tsv', sep='\t', index=False)
            print(f'Output: {len(output_df)}')

            # stack per-sample tensors and write them to a memory-mapped TensorDict
            output_data = {k: torch.stack(v) for k, v in output_data.items()}
            td.TensorDict(output_data).memmap_(output_dir / f'vgg-{split}')


if __name__ == '__main__':
    extract()
    distributed.destroy_process_group()
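
# Quick sanity check of the combined output (paths assume the default --output_dir and the
# 'train' split), e.g. in a single-process Python shell:
#   import tensordict as td
#   data = td.TensorDict.load_memmap('./training/example_output/memmap/vgg-train')
#   print(data['mean'].shape, data['clip_features'].shape)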