import os
import sys

import joblib
import numpy as np
import torch
import einops
import torchaudio.transforms as T
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModel
def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """
    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
    Returns:
      A 2-D bool tensor, where masked (padding) positions are filled with
      `True` and non-masked positions are filled with `False`.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False,  True,  True,  True,  True],
            [False, False, False,  True,  True],
            [False, False,  True,  True,  True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    max_len = lengths.max()
    n = lengths.size(0)
    expanded_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
    return expanded_lengths >= lengths.unsqueeze(1)
class KmeansQuantizer(nn.Module):
    """Maps continuous features to their nearest k-means centroid index, and
    cluster indices back to the corresponding centroid vectors."""

    def __init__(self, centroids) -> None:
        super().__init__()
        if isinstance(centroids, np.ndarray):
            centroids = torch.from_numpy(centroids)
        # self.clusters = nn.Embedding(n_cluster, feature_dim)
        self.clusters = nn.Parameter(centroids)

    @classmethod
    def from_pretrained(cls, km_path):
        # Load a fitted scikit-learn KMeans model and take its centroids.
        km_model = joblib.load(km_path)
        centroids = km_model.cluster_centers_
        return cls(centroids)

    @property
    def n_cluster(self) -> int:
        return self.clusters.shape[0]

    @property
    def feature_dim(self) -> int:
        return self.clusters.shape[1]

    def forward(self, inp: torch.Tensor):
        if inp.ndim == 3 and inp.shape[-1] == self.feature_dim:
            return self.feat2indice(inp)
        elif inp.ndim < 3:
            return self.indice2feat(inp)
        else:
            raise NotImplementedError

    def feat2indice(self, feat):
        '''
        feat: B, T, D
        '''
        # Euclidean distance from every frame to every centroid; pick the closest.
        batched_cluster_centers = einops.repeat(self.clusters, 'c d -> b c d', b=feat.shape[0])
        dists = torch.cdist(feat, batched_cluster_centers, p=2)
        indices = dists.argmin(dim=-1)
        return indices

    def indice2feat(self, indice):
        '''
        indice: B, T
        '''
        return nn.functional.embedding(input=indice, weight=self.clusters)
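
# Minimal usage sketch for KmeansQuantizer (illustrative only, not part of the
# original pipeline): round-trips a random feature batch through feat2indice /
# indice2feat. The cluster count and feature dimension below are arbitrary
# assumptions chosen for the demonstration.
def _demo_kmeans_quantizer():
    quantizer = KmeansQuantizer(torch.randn(8, 4))   # 8 centroids of dimension 4
    feats = torch.randn(2, 10, 4)                    # B=2, T=10, D=4
    indices = quantizer(feats)                       # (2, 10) cluster indices
    recon = quantizer(indices)                       # (2, 10, 4) nearest-centroid vectors
    return indices, recon
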
class MERTwithKmeans(nn.Module):
    def __init__(self, pretrained_model_name_or_path, kmeans_path=None, sampling_rate=44100,
                 output_layer=-1, mean_pool=1) -> None:
        super().__init__()
        # assert pretrained_model_name_or_path in ["MERT-v1-95M", "MERT-v1-330M"]
        assert pretrained_model_name_or_path == "MERT-v1-330M"
        # Load the MERT model weights from a local snapshot.
        # self.model = AutoModel.from_pretrained("vocal2accmpl/model/.cache/models--m-a-p--MERT-v1-95M/snapshots/8881df140a93e2ea270235b5d7be802245e3d2c6", trust_remote_code=True)
        self.model = AutoModel.from_pretrained(
            'pretrained/models--m-a-p--MERT-v1-330M/snapshots/af10da70c94a0c849de9cc94b83e12769c4db499',
            trust_remote_code=True,
        )
        if kmeans_path is not None:
            centroids = joblib.load(kmeans_path).cluster_centers_
            self.kmeans = KmeansQuantizer(centroids)
        else:
            self.kmeans = None
        # The corresponding preprocessor is not used directly; its resampling and
        # normalization steps are re-implemented below.
        # self.processor = Wav2Vec2FeatureExtractor.from_pretrained(f"m-a-p/{pretrained_model_name_or_path}", trust_remote_code=True)
        # Make sure the sample rate is aligned with MERT's expected 24 kHz input.
        self.sampling_rate = sampling_rate
        self.resampler = T.Resample(sampling_rate, 24000) if sampling_rate != 24000 else lambda x: x
        # Zero-mean / unit-variance normalization is applied only for MERT-v1-95M.
        self.do_normalization = (pretrained_model_name_or_path == "MERT-v1-95M")
        self.output_layer = output_layer
        self.mean_pool = mean_pool
        assert self.mean_pool % 2 == 1
    def forward(self, input_audio, seq_len=None, apply_kmeans=True):
        '''
        input_audio: B, T
        seq_len: B,
        '''
        device = input_audio.device
        return_seq_len = True
        if seq_len is None:
            return_seq_len = False
            seq_len = [input_audio.shape[1] for _ in input_audio]
        # Trim each waveform to its valid length and resample to 24 kHz.
        input_audio = [self.resampler(x[:l]) for x, l in zip(input_audio, seq_len)]
        new_seq_len = torch.tensor([len(i) for i in input_audio], device=device)
        # std_inp = self.processor([x.numpy() for x in input_audio], sampling_rate=24000, return_tensors="pt", padding=True)
        if self.do_normalization:
            input_audio = self.zero_mean_unit_var_norm(input_audio, new_seq_len)
        padded_input = pad_sequence(input_audio, batch_first=True)
        attention_mask = ~make_pad_mask(new_seq_len)
        # assert (~(attention_mask == std_inp['attention_mask'])).sum() == 0, f"{attention_mask}, {std_inp['attention_mask']}"
        # assert (~(padded_input.to(dtype=std_inp['input_values'].dtype) == std_inp['input_values'])).sum() == 0, f"{torch.sum((padded_input - std_inp['input_values']))}"
        outputs = self.model(input_values=padded_input, attention_mask=attention_mask, output_hidden_states=True)
        output = outputs['hidden_states'][self.output_layer]
        # MERT produces 75 feature frames per second of 24 kHz audio.
        output_len = torch.round(new_seq_len.float() / 24000 * 75).long()
        # output_len = output_len.masked_fill(output_len > output.shape[1], output.shape[1]).long()
        output = nn.functional.interpolate(output.transpose(-1, -2), output_len.max().item()).transpose(-1, -2)
        if self.mean_pool > 1:
            # Pooling shortens both the features and their lengths by the same factor.
            output_len = output_len // self.mean_pool
            output = nn.functional.avg_pool1d(output.transpose(-1, -2), kernel_size=self.mean_pool, stride=self.mean_pool)
            output = output.transpose(-1, -2)
        if apply_kmeans:
            assert self.kmeans is not None, "apply_kmeans=True requires kmeans_path at construction"
            output = self.kmeans.feat2indice(output)
        if return_seq_len:
            return output, output_len
        return output
    # Adapted from transformers.models.wav2vec2.feature_extraction_wav2vec2,
    # rewritten with PyTorch tensors.
    @staticmethod
    def zero_mean_unit_var_norm(
        input_values: torch.Tensor, seq_len: torch.Tensor = None, padding_value: float = 0.0
    ) -> torch.Tensor:
        """
        Normalize every waveform to zero mean and unit variance over its valid length.
        """
        if seq_len is not None:
            normed_input_values = []
            for vector, length in zip(input_values, seq_len):
                normed_slice = (vector - vector[:length].mean()) / torch.sqrt(vector[:length].var() + 1e-7)
                if length < normed_slice.shape[0]:
                    normed_slice[length:] = padding_value
                normed_input_values.append(normed_slice)
            # normed_input_values = torch.stack(normed_input_values, dim=0)
        else:
            normed_input_values = (input_values - input_values.mean(dim=-1, keepdim=True)) / torch.sqrt(
                input_values.var(dim=-1, keepdim=True) + 1e-7
            )
        return normed_input_values
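
# Minimal end-to-end sketch, assuming the local MERT-v1-330M snapshot referenced
# above exists and that a fitted k-means model is available at the path below
# (the kmeans path is a placeholder/assumption; adjust to your environment).
# It tokenizes one second of random 44.1 kHz "audio" into discrete units.
if __name__ == "__main__":
    model = MERTwithKmeans(
        "MERT-v1-330M",
        kmeans_path="pretrained/mert_kmeans.joblib",  # assumed path to a joblib-dumped sklearn KMeans
        sampling_rate=44100,
        output_layer=-1,
        mean_pool=1,
    )
    model.eval()
    wav = torch.randn(1, 44100)                      # B=1, one second at 44.1 kHz
    with torch.no_grad():
        units = model(wav, apply_kmeans=True)        # (1, T') cluster indices
    print(units.shape)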