# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) # 2024 Alibaba Inc (authors: Xiang Lyu) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import shutil import logging from contextlib import nullcontext import os import torchaudio import torch import torch.distributed as dist import torchaudio from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join import datetime import sys from datetime import timedelta sys.path.append('/inspire/hdd/project/embodied-multimodality/public/lzjjin/CosyVoice/cosyvoice/utils') from file_utils import get_dataset_name_from_path class Executor: def __init__(self, gan: bool = False, ref_model: torch.nn.Module = None, dpo_loss: torch.nn.Module = None): self.gan = gan self.ref_model = ref_model self.dpo_loss = dpo_loss self.step = 0 self.epoch = 0 self.validate_interval=None self.rank = int(os.environ.get('RANK', 0)) self.device = torch.device('cuda:{}'.format(self.rank)) def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=None): ''' Train one epoch ''' lr = optimizer.param_groups[0]['lr'] logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank)) logging.info('using accumulate grad, new batch size is {} times' ' larger than before'.format(info_dict['accum_grad'])) # A context manager to be used in conjunction with an instance of # torch.nn.parallel.DistributedDataParallel to be able to train # with uneven inputs across participating processes. model.train() if self.ref_model is not None: self.ref_model.eval() model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext with model_context(): for batch_idx, batch_dict in enumerate(train_data_loader): info_dict["tag"] = "TRAIN" info_dict["step"] = self.step info_dict["epoch"] = self.epoch info_dict["batch_idx"] = batch_idx if cosyvoice_join(group_join, info_dict): break # Disable gradient synchronizations across DDP processes. # Within this context, gradients will be accumulated on module # variables, which will later be synchronized. if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0: context = model.no_sync # Used for single gpu training and DDP gradient synchronization # processes. else: context = nullcontext with context(): info_dict = batch_forward(model, batch_dict, scaler, info_dict, ref_model=self.ref_model, dpo_loss=self.dpo_loss) info_dict = batch_backward(model, scaler, info_dict) info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict) log_per_step(writer, info_dict) # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \ (batch_idx + 1) % info_dict["accum_grad"] == 0: dist.barrier() self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False) model.train() if (batch_idx + 1) % info_dict["accum_grad"] == 0: self.step += 1 dist.barrier() self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True) def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join): ''' Train one epoch ''' lr = optimizer.param_groups[0]['lr'] logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank)) logging.info('using accumulate grad, new batch size is {} times' ' larger than before'.format(info_dict['accum_grad'])) # A context manager to be used in conjunction with an instance of # torch.nn.parallel.DistributedDataParallel to be able to train # with uneven inputs across participating processes. model.train() model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext with model_context(): for batch_idx, batch_dict in enumerate(train_data_loader): info_dict["tag"] = "TRAIN" info_dict["step"] = self.step info_dict["epoch"] = self.epoch info_dict["batch_idx"] = batch_idx if cosyvoice_join(group_join, info_dict): break # Disable gradient synchronizations across DDP processes. # Within this context, gradients will be accumulated on module # variables, which will later be synchronized. if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0: context = model.no_sync # Used for single gpu training and DDP gradient synchronization # processes. else: context = nullcontext with context(): batch_dict['turn'] = 'discriminator' info_dict = batch_forward(model, batch_dict, scaler, info_dict) info_dict = batch_backward(model, scaler, info_dict) info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict) optimizer.zero_grad() log_per_step(writer, info_dict) with context(): batch_dict['turn'] = 'generator' info_dict = batch_forward(model, batch_dict, scaler, info_dict) info_dict = batch_backward(model, scaler, info_dict) info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict) optimizer_d.zero_grad() log_per_step(writer, info_dict) # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \ (batch_idx + 1) % info_dict["accum_grad"] == 0: dist.barrier() self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False) model.train() if (batch_idx + 1) % info_dict["accum_grad"] == 0: self.step += 1 dist.barrier() self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True) # def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=None): # ''' Train one epoch # ''' # lr = optimizer.param_groups[0]['lr'] # logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank)) # logging.info('using accumulate grad, new batch size is {} times' # ' larger than before'.format(info_dict['accum_grad'])) # # A context manager to be used in conjunction with an instance of # # torch.nn.parallel.DistributedDataParallel to be able to train # # with uneven inputs across participating processes. # model.train() # if self.ref_model is not None: # self.ref_model.eval() # model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext # train_loader_iter = iter(train_data_loader) # info_dict["tag"] = "TRAIN" # info_dict["epoch"] = self.epoch # with model_context(): # batch_idx = -1 # while True: # batch_idx += 1 # data_exhausted_local = False # try: # current_batch_dict = next(train_loader_iter) # except StopIteration: # data_exhausted_local = True # data_exhausted_global_signal = torch.tensor([int(data_exhausted_local)], dtype=torch.int, device=self.device) # dist.all_reduce(data_exhausted_global_signal, op=dist.ReduceOp.MAX, group=group_join) # if data_exhausted_global_signal.item() == 1: # break # batch_dict = current_batch_dict # torch.cuda.empty_cache() # info_dict["step"] = self.step # info_dict["batch_idx"] = batch_idx # # if cosyvoice_join(group_join, info_dict): # # break # if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0: # context = model.no_sync # else: # context = nullcontext # with context(): # info_dict = batch_forward(model, batch_dict, scaler, info_dict, ref_model=self.ref_model, dpo_loss=self.dpo_loss) # info_dict = batch_backward(model, scaler, info_dict) # info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict) # log_per_step(writer, info_dict) # if info_dict.get('save_per_step', 0) > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \ # (batch_idx + 1) % info_dict["accum_grad"] == 0: # dist.barrier() # self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False) # model.train() # if (batch_idx + 1) % info_dict["accum_grad"] == 0: # self.step += 1 # dist.barrier() # self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True) def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join): ''' Train one epoch ''' lr = optimizer.param_groups[0]['lr'] logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank)) logging.info('using accumulate grad, new batch size is {} times' ' larger than before'.format(info_dict['accum_grad'])) # A context manager to be used in conjunction with an instance of # torch.nn.parallel.DistributedDataParallel to be able to train # with uneven inputs across participating processes. model.train() model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext with model_context(): for batch_idx, batch_dict in enumerate(train_data_loader): import pdb pdb.set_trace() info_dict["tag"] = "TRAIN" info_dict["step"] = self.step info_dict["epoch"] = self.epoch info_dict["batch_idx"] = batch_idx if cosyvoice_join(group_join, info_dict): break # Disable gradient synchronizations across DDP processes. # Within this context, gradients will be accumulated on module # variables, which will later be synchronized. if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0: context = model.no_sync # Used for single gpu training and DDP gradient synchronization # processes. else: context = nullcontext with context(): batch_dict['turn'] = 'discriminator' info_dict = batch_forward(model, batch_dict, scaler, info_dict) info_dict = batch_backward(model, scaler, info_dict) info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict) optimizer.zero_grad() log_per_step(writer, info_dict) with context(): batch_dict['turn'] = 'generator' info_dict = batch_forward(model, batch_dict, scaler, info_dict) info_dict = batch_backward(model, scaler, info_dict) info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict) optimizer_d.zero_grad() log_per_step(writer, info_dict) # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \ (batch_idx + 1) % info_dict["accum_grad"] == 0: dist.barrier() self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False) model.train() if (batch_idx + 1) % info_dict["accum_grad"] == 0: self.step += 1 dist.barrier() self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True) @torch.inference_mode() def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True): ''' Cross validation on ''' logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank)) model.eval() total_num_utts, total_loss_dict = 0, {} # avoid division by 0 for batch_idx, batch_dict in enumerate(cv_data_loader): info_dict["tag"] = "CV" info_dict["step"] = self.step info_dict["epoch"] = self.epoch info_dict["batch_idx"] = batch_idx num_utts = len(batch_dict["utts"]) total_num_utts += num_utts if self.gan is True: batch_dict['turn'] = 'generator' info_dict = batch_forward(model, batch_dict, None, info_dict) for k, v in info_dict['loss_dict'].items(): if k not in total_loss_dict: total_loss_dict[k] = [] total_loss_dict[k].append(v.mean().item() * num_utts) log_per_step(None, info_dict) for k, v in total_loss_dict.items(): total_loss_dict[k] = sum(v) / total_num_utts info_dict['loss_dict'] = total_loss_dict log_per_save(writer, info_dict) model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1) save_model(model, model_name, info_dict) @torch.inference_mode() def generate(self, model, generate_data_loader, writer, info_dict, on_batch_end=True,hift=None, output_folder=None): ''' Cross validation on ''' logging.info('Epoch {} Step {} on_batch_end {} Start Generating'.format(self.epoch, self.step + 1, on_batch_end)) model.eval() total_num_utts, total_loss_dict = 0, {} # avoid division by 0 if output_folder==None: output_folder=info_dict['model_dir'] for batch_idx, batch_dict in enumerate(generate_data_loader): print(batch_idx) info_dict["tag"] = "GENERATE" info_dict["step"] = self.step info_dict["epoch"] = self.epoch info_dict["batch_idx"] = batch_idx num_utts = len(batch_dict["utts"]) total_num_utts += num_utts ref_wavs=batch_dict['wavs'] speech_token=batch_dict['speech_token'] speech_token_len=batch_dict['speech_token_len'] speech_feat=batch_dict['speech_feat'] speech_feat_len=batch_dict['speech_feat_len'] speech_embedding=batch_dict['embedding'] path=batch_dict['wavs'][0] name=os.path.splitext(os.path.basename(path))[0] random_ratios = torch.rand(1) * 0.5 prompt_lengths = (speech_token_len.min().float() * random_ratios).int().to(speech_feat.device) prompt_token=speech_token[:,:prompt_lengths] input_token=speech_token[:,prompt_lengths:] input_token_len=speech_token_len-prompt_lengths prompt_token_len=speech_token_len-input_token_len prompt_feat_lengths=prompt_lengths*model.module.token_mel_ratio input_feat_len=speech_feat_len-prompt_feat_lengths prompt_feat_len=speech_feat_len-input_feat_len input_feat=speech_feat[:,prompt_feat_lengths:] prompt_feat=speech_feat[:,:prompt_feat_lengths] device=model.module.encoder_proj.weight.device mel=model.module.inference(input_token.to(device),input_token_len.to(device),prompt_token.to(device),prompt_token_len.to(device),prompt_feat.to(device),prompt_feat_len.to(device),speech_embedding.to(device),streaming=True,finalize=True)[0] mel=torch.cat([prompt_feat.to(mel.device).transpose(-1,-2),mel],dim=-1) gen_speech=hift.inference(mel)[0] ref_audio_source_path = ref_wavs[0] dataset_name = get_dataset_name_from_path(ref_audio_source_path) ref_audio_output_dir = os.path.join(output_folder, 'ref', dataset_name) os.makedirs(ref_audio_output_dir, exist_ok=True) ref_audio_dest_path = os.path.join(ref_audio_output_dir, f'{name}.wav') if not os.path.exists(ref_audio_dest_path): if ref_audio_source_path.lower().endswith('.wav'): shutil.copy(ref_audio_source_path, ref_audio_dest_path) else: waveform, sample_rate = torchaudio.load(ref_audio_source_path) torchaudio.save(ref_audio_dest_path, waveform, sample_rate) generate_audio_output_dir = os.path.join(output_folder, 'generate', dataset_name) os.makedirs(generate_audio_output_dir, exist_ok=True) generated_audio_path = os.path.join(generate_audio_output_dir, f'{name}.wav') torchaudio.save(generated_audio_path, gen_speech.cpu(), 24000) # if self.gan is True: # batch_dict['turn'] = 'generator' # info_dict = batch_forward(model, batch_dict, None, info_dict) # for k, v in info_dict['loss_dict'].items(): # if k not in total_loss_dict: # total_loss_dict[k] = [] # total_loss_dict[k].append(v.mean().item() * num_utts) # log_per_step(None, info_dict) # for k, v in total_loss_dict.items(): # total_loss_dict[k] = sum(v) / total_num_utts # info_dict['loss_dict'] = total_loss_dict # log_per_save(writer, info_dict) # model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1) # save_model(model, model_name, info_dict)