Spaces:
Running
Running
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| from models.svc.base import SVCTrainer | |
| from modules.encoder.condition_encoder import ConditionEncoder | |
| from models.svc.transformer.transformer import Transformer | |
| from models.svc.transformer.conformer import Conformer | |
| from utils.ssim import SSIM | |
class TransformerTrainer(SVCTrainer):
    """SVC trainer that maps encoded conditions to mel features.

    The acoustic mapper is either a ``Transformer`` or a ``Conformer``,
    selected by ``cfg.model.transformer.type``. The training objective is
    the sum of a masked L1 term and a masked SSIM term.
    """

    def __init__(self, args, cfg):
        """Initialise the base trainer and create the SSIM loss module.

        Args:
            args: Forwarded unchanged to ``SVCTrainer.__init__``.
            cfg: Forwarded unchanged to ``SVCTrainer.__init__``.
        """
        super().__init__(args, cfg)
        self.ssim_loss = SSIM()

    def _build_model(self):
        """Build the condition encoder and the acoustic mapper.

        Copies ``f0_min`` / ``f0_max`` from ``cfg.preprocess`` into
        ``cfg.model.condition_encoder`` before constructing the encoder, so
        the encoder config is mutated in place.

        Returns:
            A ``torch.nn.ModuleList`` holding
            ``[condition_encoder, acoustic_mapper]`` in that order.

        Raises:
            NotImplementedError: If ``cfg.model.transformer.type`` is neither
                ``"transformer"`` nor ``"conformer"``.
        """
        encoder_cfg = self.cfg.model.condition_encoder
        encoder_cfg.f0_min = self.cfg.preprocess.f0_min
        encoder_cfg.f0_max = self.cfg.preprocess.f0_max
        self.condition_encoder = ConditionEncoder(encoder_cfg)

        mapper_cfg = self.cfg.model.transformer
        if mapper_cfg.type == "transformer":
            self.acoustic_mapper = Transformer(mapper_cfg)
        elif mapper_cfg.type == "conformer":
            self.acoustic_mapper = Conformer(mapper_cfg)
        else:
            raise NotImplementedError(
                f"Unsupported acoustic mapper type: {mapper_cfg.type!r}"
            )

        return torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])

    def _forward_step(self, batch):
        """Run one forward pass and return the combined training loss.

        Args:
            batch: Mapping that provides at least ``"mel"`` (the target) and
                ``"mask"``. The whole mapping is also passed to the condition
                encoder, so any keys that encoder reads must be present.
                NOTE(review): the mask is assumed to be broadcastable against
                the mel tensor and the SSIM output - confirm with the dataset
                collator and ``utils.ssim.SSIM``.

        Returns:
            The masked L1 loss plus the masked SSIM loss.
        """
        mel = batch["mel"]
        mask = batch["mask"]
        # Both terms share this normaliser. An all-zero mask yields a
        # non-finite loss, which the _check_nan calls below are given a
        # chance to report.
        mask_sum = torch.sum(mask)

        condition = self.condition_encoder(batch)
        mel_pred = self.acoustic_mapper(condition, mask)

        # Masked mean absolute error between prediction and target.
        l1_loss = torch.sum(torch.abs(mel_pred - mel) * mask) / mask_sum
        self._check_nan(l1_loss, mel_pred, mel)

        # The SSIM module output is masked and normalised the same way.
        ssim_loss = self.ssim_loss(mel_pred, mel)
        ssim_loss = torch.sum(ssim_loss * mask) / mask_sum
        self._check_nan(ssim_loss, mel_pred, mel)

        return l1_loss + ssim_loss