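"""Distributed training entry point for MMAudio.

Launched with torchrun (see the example command below) and configured through
Hydra. Each rank trains on its shard of the data; rank 0 additionally owns
TensorBoard logging and the final EMA synthesis.
"""
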
import logging
import math
import random
from datetime import timedelta
from pathlib import Path

import hydra
import numpy as np
import torch
import torch.distributed as distributed
from hydra import compose
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, open_dict
from torch.distributed.elastic.multiprocessing.errors import record

from mmaudio.data.data_setup import setup_training_datasets, setup_val_datasets
from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K
from mmaudio.runner import Runner
from mmaudio.sample import sample
from mmaudio.utils.dist_utils import info_if_rank_zero, local_rank, world_size
from mmaudio.utils.logger import TensorboardLogger
from mmaudio.utils.synthesize_ema import synthesize_ema

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

log = logging.getLogger()

# Example launch (8 GPUs):
#   OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=8 train.py exp_id=exp_1 model=large_44k
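

# NCCL is given a generous two-hour timeout: validation and evaluation passes
# can hold some ranks at a barrier well past the default 30-minute watchdog.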
def distributed_setup():
    distributed.init_process_group(backend="nccl", timeout=timedelta(hours=2))
    log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
    return local_rank, world_size


# NOTE: the @hydra.main arguments (config_path/config_name) are assumptions
# based on the repo's config layout; adjust them if your layout differs.
# Without this decorator, HydraConfig.get() and compose() below would fail,
# and train() could not be called with no arguments.
@record
@hydra.main(version_base='1.3.2', config_path='config', config_name='train_config.yaml')
def train(cfg: DictConfig):
    # initial setup
    torch.cuda.set_device(local_rank)
    torch.backends.cudnn.benchmark = cfg.cudnn_benchmark
    distributed_setup()
    num_gpus = world_size
    run_dir = HydraConfig.get().run.dir

    # compose the eval config early so that it does not depend on disk state
    # that may change during training
    eval_cfg = compose('eval_config', overrides=[f'exp_id={cfg.exp_id}'])

    # patch the data dimensions to match the model's sample rate
    if cfg.model.endswith('16k'):
        seq_cfg = CONFIG_16K
    elif cfg.model.endswith('44k'):
        seq_cfg = CONFIG_44K
    else:
        raise ValueError(f'Unknown model: {cfg.model}')
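    # open_dict temporarily lifts OmegaConf's struct mode so that keys can be
    # written even if they are not declared in the config schema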
    with open_dict(cfg):
        cfg.data_dim.latent_seq_len = seq_cfg.latent_seq_len
        cfg.data_dim.clip_seq_len = seq_cfg.clip_seq_len
        cfg.data_dim.sync_seq_len = seq_cfg.sync_seq_len

    # wrap the python logger with a tensorboard logger
    log = TensorboardLogger(cfg.exp_id,
                            run_dir,
                            logging.getLogger(),
                            is_rank0=(local_rank == 0),
                            enable_email=cfg.enable_email and not cfg.debug)

    info_if_rank_zero(log, f'All configuration: {cfg}')
    info_if_rank_zero(log, f'Number of GPUs detected: {num_gpus}')

    # number of dataloader workers
    info_if_rank_zero(log, f'Number of dataloader workers (per GPU): {cfg.num_workers}')

    # set seeds to ensure the same initialization on every rank
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    # set up configurations
    info_if_rank_zero(log, f'Training configuration: {cfg}')
    cfg.batch_size //= num_gpus
    info_if_rank_zero(log, f'Batch size (per GPU): {cfg.batch_size}')
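    # cfg.batch_size arrives as the global batch size; after the division above
    # each rank loads an equal shard, so the effective batch size is unchanged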

    # total number of training iterations
    total_iterations = cfg['num_iterations']

    # setup datasets
    dataset, sampler, loader = setup_training_datasets(cfg)
    info_if_rank_zero(log, f'Number of training samples: {len(dataset)}')
    info_if_rank_zero(log, f'Number of training batches: {len(loader)}')
    val_dataset, val_loader, eval_loader = setup_val_datasets(cfg)
    info_if_rank_zero(log, f'Number of val samples: {len(val_dataset)}')
    val_cfg = cfg.data.ExtractedVGG_val

    # compute and set mean and std
    latent_mean, latent_std = dataset.compute_latent_stats()
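    # the statistics are computed once from the training set and passed to the
    # Runner below (typically used to standardize the latents)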

    # construct the trainer
    trainer = Runner(cfg,
                     log=log,
                     run_path=run_dir,
                     for_training=True,
                     latent_mean=latent_mean,
                     latent_std=latent_std).enter_train()

    # snapshot the RNG state right after construction; validation and
    # evaluation passes restore it so that they always run with the same seed
    eval_rng_clone = trainer.rng.graphsafe_get_state()
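
    # checkpoint resolution order: explicit cfg.checkpoint -> latest checkpoint
    # in run_dir -> fresh start (optionally initialized from cfg.weights)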
    if cfg['checkpoint'] is not None:
        curr_iter = trainer.load_checkpoint(cfg['checkpoint'])
        cfg['checkpoint'] = None
        info_if_rank_zero(log, 'Model checkpoint loaded!')
    else:
        # if run_dir exists, load the latest checkpoint
        checkpoint = trainer.get_latest_checkpoint_path()
        if checkpoint is not None:
            curr_iter = trainer.load_checkpoint(checkpoint)
            info_if_rank_zero(log, 'Latest checkpoint loaded!')
        else:
            # otherwise start from iteration 0, optionally from pretrained weights
            curr_iter = 0
            if cfg['weights'] is not None:
                info_if_rank_zero(log, 'Loading weights from the disk')
                trainer.load_weights(cfg['weights'])
                cfg['weights'] = None

    # determine max epoch
    total_epoch = math.ceil(total_iterations / len(loader))
    current_epoch = curr_iter // len(loader)
    info_if_rank_zero(log, f'We will approximately use {total_epoch} epochs.')

    # training loop
    try:
        # reseed numpy so that different ranks draw different random bases
        np.random.seed(np.random.randint(2**30 - 1) + local_rank * 1000)
        while curr_iter < total_iterations:
            # crucial for randomness: give each epoch a distinct shuffle
            sampler.set_epoch(current_epoch)
            current_epoch += 1
            log.debug(f'Current epoch: {current_epoch}')
            trainer.enter_train()
            trainer.log.data_timer.start()
            for data in loader:
                trainer.train_pass(data, curr_iter)
                if (curr_iter + 1) % cfg.val_interval == 0:
                    # swap in the eval RNG state, i.e., use the same seed for
                    # every validation pass
                    train_rng_snapshot = trainer.rng.graphsafe_get_state()
                    trainer.rng.graphsafe_set_state(eval_rng_clone)
                    info_if_rank_zero(log, f'Iteration {curr_iter}: validating')
                    for data in val_loader:
                        trainer.validation_pass(data, curr_iter)
                    distributed.barrier()
                    trainer.val_integrator.finalize('val', curr_iter, ignore_timer=True)
                    trainer.rng.graphsafe_set_state(train_rng_snapshot)
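
                # periodic evaluation: run inference on the eval set, then
                # score the generated audio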
                if (curr_iter + 1) % cfg.eval_interval == 0:
                    save_eval = (curr_iter + 1) % cfg.save_eval_interval == 0
                    train_rng_snapshot = trainer.rng.graphsafe_get_state()
                    trainer.rng.graphsafe_set_state(eval_rng_clone)
                    info_if_rank_zero(log, f'Iteration {curr_iter}: evaluating')
                    for data in eval_loader:
                        audio_path = trainer.inference_pass(data,
                                                            curr_iter,
                                                            val_cfg,
                                                            save_eval=save_eval)
                    distributed.barrier()
                    trainer.rng.graphsafe_set_state(train_rng_snapshot)
                    trainer.eval(audio_path, curr_iter, val_cfg)

                curr_iter += 1
                if curr_iter >= total_iterations:
                    break
    except Exception as e:
        log.error(f'Error occurred at iteration {curr_iter}!')
        log.critical(str(e))
        raise
    finally:
        # runs even when the loop exits with an error: outside of debug mode,
        # always leave a checkpoint and weights behind
        if not cfg.debug:
            trainer.save_checkpoint(curr_iter)
            trainer.save_weights(curr_iter)
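
    # post-training: free the trainer, synthesize EMA weights on rank 0, then
    # run a final sampling pass driven by eval_cfg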
    del trainer
    torch.cuda.empty_cache()

    # synthesize EMA
    if local_rank == 0:
        log.info(f'Synthesizing EMA with sigma={cfg.ema.default_output_sigma}')
        ema_sigma = cfg.ema.default_output_sigma
        state_dict = synthesize_ema(cfg, ema_sigma, step=None)
        save_path = Path(run_dir) / f'{cfg.exp_id}_ema_final.pth'
        torch.save(state_dict, save_path)
        log.info(f'Synthesized EMA saved to {save_path}!')
    distributed.barrier()

    log.info(f'Evaluation: {eval_cfg}')
    sample(eval_cfg)

    # clean-up
    log.complete()
    distributed.barrier()
    distributed.destroy_process_group()


if __name__ == '__main__':
    train()
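

# A single-GPU debug run can be launched the same way; note that
# `model=small_16k` below is a hypothetical config name, so substitute one
# that exists in your configs:
#   OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=1 train.py exp_id=debug_1 model=small_16k debug=True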