Spaces:
Running
Running
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch.nn as nn | |
| from modules.diffusion import BiDilConv | |
| from modules.encoder.position_encoder import PositionEncoder | |
| class DiffusionWrapper(nn.Module): | |
| def __init__(self, cfg): | |
| super().__init__() | |
| self.cfg = cfg | |
| self.diff_cfg = cfg.model.diffusion | |
| self.diff_encoder = PositionEncoder( | |
| d_raw_emb=self.diff_cfg.step_encoder.dim_raw_embedding, | |
| d_out=self.diff_cfg.bidilconv.base_channel, | |
| d_mlp=self.diff_cfg.step_encoder.dim_hidden_layer, | |
| activation_function=self.diff_cfg.step_encoder.activation, | |
| n_layer=self.diff_cfg.step_encoder.num_layer, | |
| max_period=self.diff_cfg.step_encoder.max_period, | |
| ) | |
| # FIXME: Only support BiDilConv now for debug | |
| if self.diff_cfg.model_type.lower() == "bidilconv": | |
| self.neural_network = BiDilConv( | |
| input_channel=self.cfg.preprocess.n_mel, **self.diff_cfg.bidilconv | |
| ) | |
| else: | |
| raise ValueError( | |
| f"Unsupported diffusion model type: {self.diff_cfg.model_type}" | |
| ) | |
| def forward(self, x, t, c): | |
| """ | |
| Args: | |
| x: [N, T, mel_band] of mel spectrogram | |
| t: Diffusion time step with shape of [N] | |
| c: [N, T, conditioner_size] of conditioner | |
| Returns: | |
| [N, T, mel_band] of mel spectrogram | |
| """ | |
| assert ( | |
| x.size()[:-1] == c.size()[:-1] | |
| ), "x mismatch with c, got \n x: {} \n c: {}".format(x.size(), c.size()) | |
| assert x.size(0) == t.size( | |
| 0 | |
| ), "x mismatch with t, got \n x: {} \n t: {}".format(x.size(), t.size()) | |
| assert t.dim() == 1, "t must be 1D tensor, got {}".format(t.dim()) | |
| N, T, mel_band = x.size() | |
| x = x.transpose(1, 2).contiguous() # [N, mel_band, T] | |
| c = c.transpose(1, 2).contiguous() # [N, conditioner_size, T] | |
| t = self.diff_encoder(t).contiguous() # [N, base_channel] | |
| h = self.neural_network(x, t, c) | |
| h = h.transpose(1, 2).contiguous() # [N, T, mel_band] | |
| assert h.size() == ( | |
| N, | |
| T, | |
| mel_band, | |
| ), "h mismatch with input x, got \n h: {} \n x: {}".format( | |
| h.size(), (N, T, mel_band) | |
| ) | |
| return h | |