from typing import List, Optional

from torch import nn

from poetry_diacritizer.models.seq2seq import Seq2Seq, Decoder as Seq2SeqDecoder
from poetry_diacritizer.modules.tacotron_modules import CBHG, Prenet
class Tacotron(Seq2Seq):
    """Tacotron diacritization model.

    Inherits the complete encoder-decoder behavior from ``Seq2Seq``
    unchanged; this subclass exists only to give the architecture its
    conventional name.
    """
class Encoder(nn.Module):
    """Tacotron-style text encoder: token embedding -> optional Prenet -> CBHG.

    Args:
        inp_vocab_size: size of the input token vocabulary.
        embedding_dim: dimensionality of the token embeddings.
        use_prenet: when True, pass embeddings through a Prenet before the CBHG.
        prenet_sizes: hidden-layer sizes of the Prenet; defaults to [256, 128].
        cbhg_gru_units: number of GRU units in the CBHG.
        cbhg_filters: number of convolution-bank filter sets (K) in the CBHG.
        cbhg_projections: projection layer sizes of the CBHG; defaults to
            [128, 128].
        padding_idx: embedding index treated as padding (maps to a zero
            vector and receives no gradient).
    """

    def __init__(
        self,
        inp_vocab_size: int,
        embedding_dim: int = 512,
        use_prenet: bool = True,
        prenet_sizes: Optional[List[int]] = None,
        cbhg_gru_units: int = 128,
        cbhg_filters: int = 16,
        cbhg_projections: Optional[List[int]] = None,
        padding_idx: int = 0,
    ):
        super().__init__()
        # None sentinels avoid the shared-mutable-default-argument pitfall;
        # the effective defaults match the original [256, 128] / [128, 128].
        if prenet_sizes is None:
            prenet_sizes = [256, 128]
        if cbhg_projections is None:
            cbhg_projections = [128, 128]
        self.use_prenet = use_prenet
        self.embedding = nn.Embedding(
            inp_vocab_size, embedding_dim, padding_idx=padding_idx
        )
        if use_prenet:
            self.prenet = Prenet(embedding_dim, prenet_depth=prenet_sizes)
        # CBHG input width depends on whether the Prenet narrowed the
        # features: last Prenet layer size if enabled, else raw embedding_dim.
        self.cbhg = CBHG(
            prenet_sizes[-1] if use_prenet else embedding_dim,
            cbhg_gru_units,
            K=cbhg_filters,
            projections=cbhg_projections,
        )

    def forward(self, inputs, input_lengths=None):
        """Encode a batch of token-id tensors.

        Args:
            inputs: LongTensor of token ids to embed.
            input_lengths: optional per-sequence lengths forwarded to the CBHG.

        Returns:
            The CBHG output for the (optionally Prenet-processed) embeddings.
        """
        outputs = self.embedding(inputs)
        if self.use_prenet:
            outputs = self.prenet(outputs)
        return self.cbhg(outputs, input_lengths)
class Decoder(Seq2SeqDecoder):
    """Tacotron decoder.

    A thin alias over the shared ``Seq2SeqDecoder``; no behavior is added
    or overridden here.
    """