from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from .wavenet import WaveNet
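

# Attention-pooling reference encoder: a WaveNet backbone extracts frame-level
# features, a small set of learned latent query tokens cross-attends over those
# frames, and the attended latents are projected and averaged into a single
# fixed-size reference embedding.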
class ReferenceEncoder(WaveNet):
    def __init__(
        self,
        input_channels: Optional[int] = None,
        output_channels: Optional[int] = None,
        residual_channels: int = 512,
        residual_layers: int = 20,
        dilation_cycle: Optional[int] = 4,
        num_heads: int = 8,
        latent_len: int = 4,
    ):
        super().__init__(
            input_channels=input_channels,
            residual_channels=residual_channels,
            residual_layers=residual_layers,
            dilation_cycle=dilation_cycle,
        )
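
        # Multi-head attention geometry, plus the learned latent queries that
        # pool the variable-length sequence down to latent_len tokens.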
        self.head_dim = residual_channels // num_heads
        self.num_heads = num_heads
        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, residual_channels))
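
        # Cross-attention projections: queries come from the latents, keys and
        # values from the WaveNet features, with per-head LayerNorm on q and k.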
        self.q = nn.Linear(residual_channels, residual_channels, bias=True)
        self.kv = nn.Linear(residual_channels, residual_channels * 2, bias=True)
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)
        self.proj = nn.Linear(residual_channels, residual_channels)
        self.proj_drop = nn.Dropout(0.1)
        self.norm = nn.LayerNorm(residual_channels)
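
        # Position-wise feed-forward block applied to the pooled latents.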
        self.mlp = nn.Sequential(
            nn.Linear(residual_channels, residual_channels * 4),
            nn.SiLU(),
            nn.Linear(residual_channels * 4, residual_channels),
        )
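
        # Final projection to the requested embedding size. Although
        # output_channels defaults to None, nn.Linear requires a concrete int,
        # so callers must pass it explicitly.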
        self.output_projection_attn = nn.Linear(residual_channels, output_channels)

        torch.nn.init.trunc_normal_(self.latent, std=0.02)
        self.apply(self.init_weights)
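
    # Truncated-normal initialization (std=0.02, zero bias) for every Linear
    # layer in the module, including those inside the WaveNet backbone.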
    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)
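
    # Maps (B, input_channels, T) features to a (B, output_channels) embedding.
    # attn_mask, if given, is a (B, T) boolean tensor with True on valid frames.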
    def forward(self, x, attn_mask=None):
        x = super().forward(x).mT  # (B, C, T) -> (B, T, C)
        B, N, C = x.shape

        # Broadcast the (B, N) padding mask to the (B, num_heads, latent_len, N)
        # shape expected by scaled_dot_product_attention.
        if attn_mask is not None:
            assert attn_mask.shape == (B, N) and attn_mask.dtype == torch.bool

            attn_mask = attn_mask[:, None, None, :].expand(
                B, self.num_heads, self.latent_len, N
            )
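
        # Tile the learned latent queries across the batch; keys and values are
        # computed from the full frame sequence.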
        q_latent = self.latent.expand(B, -1, -1)
        q = (
            self.q(q_latent)
            .reshape(B, self.latent_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        kv = (
            self.kv(x)
            .reshape(B, N, 2, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)
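        # With a boolean attn_mask, True entries mark positions that may be
        # attended to; padded frames (False) are excluded from the softmax.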
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = x + self.mlp(self.norm(x))
        x = self.output_projection_attn(x)
        x = x.mean(1)  # average the latent tokens into one embedding per sample

        return x
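

def sequence_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # Hypothetical convenience helper (not part of the original module): a
    # minimal sketch of building the boolean mask that ReferenceEncoder.forward
    # expects from per-sample frame counts, with True marking valid frames.
    # E.g. sequence_mask(torch.tensor([64, 48, 32, 16]), 64) -> (4, 64) bool.
    return torch.arange(max_len, device=lengths.device)[None, :] < lengths[:, None]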


if __name__ == "__main__":
    # Smoke test: run a forward/backward pass under bfloat16 autocast on CPU.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        model = ReferenceEncoder(
            input_channels=128,
            output_channels=64,
            residual_channels=384,
            residual_layers=20,
            dilation_cycle=4,
            num_heads=8,
        )
        x = torch.randn(4, 128, 64)
        mask = torch.ones(4, 64, dtype=torch.bool)
        y = model(x, mask)
        print(y.shape)  # torch.Size([4, 64])
        loss = F.mse_loss(y, torch.randn(4, 64))
        loss.backward()