Spaces:
Running
Running
| import logging | |
| import os | |
| from pathlib import Path | |
| import hydra | |
| import torch | |
| import torch.distributed as distributed | |
| import torchaudio | |
| from hydra.core.hydra_config import HydraConfig | |
| from omegaconf import DictConfig | |
| from tqdm import tqdm | |
| from mmaudio.data.data_setup import setup_eval_dataset | |
| from mmaudio.eval_utils import ModelConfig, all_model_cfg, generate | |
| from mmaudio.model.flow_matching import FlowMatching | |
| from mmaudio.model.networks import MMAudio, get_my_mmaudio | |
| from mmaudio.model.utils.features_utils import FeaturesUtils | |
# Permit TF32 math for CUDA matmuls and cuDNN ops; per PyTorch's numerics notes,
# this trades some float32 precision for throughput on supporting GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Process placement exported by the launcher (e.g. torchrun). LOCAL_RANK is read
# first; a missing variable raises KeyError at import time, so the script must be
# launched with both variables set.
local_rank, world_size = (int(os.environ[key]) for key in ('LOCAL_RANK', 'WORLD_SIZE'))

# Root logger (getLogger() with no name).
log = logging.getLogger()
# NOTE(review): config_path/config_name assume a `config/eval_config.yaml` next to
# this script — confirm against the repository layout.
@torch.inference_mode()
@hydra.main(version_base='1.3.2', config_path='config', config_name='eval_config.yaml')
def main(cfg: DictConfig):
    """Generate audio for every sample of an evaluation dataset and save it as FLAC.

    ``cfg`` is injected by ``hydra.main``. That is why the entry point can call
    ``main()`` with no arguments, and why ``HydraConfig.get()`` below (which only
    works inside a running Hydra application) can resolve the run directory.
    ``torch.inference_mode`` disables autograd tracking for the whole run, since
    nothing here needs gradients.

    Keys read from ``cfg``: ``model``, ``output_name``, ``dataset``, ``duration_s``,
    ``seed``, ``sampling.min_sigma``, ``sampling.method``, ``sampling.num_steps``,
    ``compile``, ``amp``, ``cfg_strength``. The whole ``cfg`` is also forwarded to
    ``setup_eval_dataset``.

    Raises:
        ValueError: if ``cfg.model`` is not a key of ``all_model_cfg``.
    """
    device = 'cuda'
    # Bind this process to its GPU (local_rank comes from the launcher env vars).
    torch.cuda.set_device(local_rank)

    if cfg.model not in all_model_cfg:
        raise ValueError(f'Unknown model variant: {cfg.model}')
    model: ModelConfig = all_model_cfg[cfg.model]
    model.download_if_needed()
    seq_cfg = model.seq_cfg

    # Outputs go under the Hydra run directory, in a folder named after the
    # dataset (suffixed with output_name when one is given).
    run_dir = Path(HydraConfig.get().run.dir)
    if cfg.output_name is None:
        output_dir = run_dir / cfg.dataset
    else:
        output_dir = run_dir / f'{cfg.dataset}-{cfg.output_name}'
    output_dir.mkdir(parents=True, exist_ok=True)

    # load a pretrained model
    # NOTE(review): this mutates the shared ModelConfig's seq_cfg. Presumably the
    # *_seq_len attributes read below are derived from `duration` — confirm.
    seq_cfg.duration = cfg.duration_s
    net: MMAudio = get_my_mmaudio(cfg.model).to(device).eval()
    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
    log.info(f'Loaded weights from {model.model_path}')
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
    log.info(f'Latent seq len: {seq_cfg.latent_seq_len}')
    log.info(f'Clip seq len: {seq_cfg.clip_seq_len}')
    log.info(f'Sync seq len: {seq_cfg.sync_seq_len}')

    # misc setup
    # Seeded CUDA generator passed to generate(), so sampling is reproducible per seed.
    rng = torch.Generator(device=device)
    rng.manual_seed(cfg.seed)
    fm = FlowMatching(cfg.sampling.min_sigma,
                      inference_mode=cfg.sampling.method,
                      num_steps=cfg.sampling.num_steps)

    feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
                                  synchformer_ckpt=model.synchformer_ckpt,
                                  enable_conditions=True,
                                  mode=model.mode,
                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
                                  need_vae_encoder=False)
    feature_utils = feature_utils.to(device).eval()

    if cfg.compile:
        net.preprocess_conditions = torch.compile(net.preprocess_conditions)
        net.predict_flow = torch.compile(net.predict_flow)
        feature_utils.compile()

    # Only the loader is used; the dataset object itself is not needed here.
    _, loader = setup_eval_dataset(cfg.dataset, cfg)

    with torch.amp.autocast(enabled=cfg.amp, dtype=torch.bfloat16, device_type=device):
        for batch in tqdm(loader):
            # Missing modalities are passed as None.
            audios = generate(batch.get('clip_video'),
                              batch.get('sync_video'),
                              batch.get('caption'),
                              feature_utils=feature_utils,
                              net=net,
                              fm=fm,
                              rng=rng,
                              cfg_strength=cfg.cfg_strength,
                              clip_batch_size_multiplier=64,
                              sync_batch_size_multiplier=64)
            # Cast back to float32 on CPU before writing (autocast may yield bfloat16).
            audios = audios.float().cpu()
            names = batch['name']
            for audio, name in zip(audios, names):
                torchaudio.save(output_dir / f'{name}.flac', audio, seq_cfg.sampling_rate)
def distributed_setup():
    """Initialise the default NCCL process group and log this process's placement.

    No ``init_method`` is passed, so PyTorch's default (``env://``) rendezvous applies.

    Returns:
        tuple[int, int]: ``(rank, world_size)`` as reported by ``torch.distributed``.
        ``get_rank()`` gives the rank within the default (global) group. That equals
        the node-local rank only on single-node runs, even though the log line
        labels it ``local_rank``.
    """
    distributed.init_process_group(backend="nccl")
    rank, size = distributed.get_rank(), distributed.get_world_size()
    log.info(f'Initialized: local_rank={rank}, world_size={size}')
    return rank, size
if __name__ == '__main__':
    distributed_setup()
    try:
        # NOTE(review): main(cfg) is called with no arguments. This only works
        # when main is wrapped by a config-injecting decorator such as @hydra.main.
        main()
    finally:
        # clean-up: release the NCCL process group even when main() raises
        # (including SystemExit), so failed runs do not leak communicator resources.
        distributed.destroy_process_group()