Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from audiotools import AudioSignal | |
| from audiotools.ml import BaseModel | |
| from encodec import EncodecModel | |
| class Encodec(BaseModel): | |
| def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0): | |
| super().__init__() | |
| if sample_rate == 24000: | |
| self.model = EncodecModel.encodec_model_24khz() | |
| else: | |
| self.model = EncodecModel.encodec_model_48khz() | |
| self.model.set_target_bandwidth(bandwidth) | |
| self.sample_rate = 44100 | |
| def forward( | |
| self, | |
| audio_data: torch.Tensor, | |
| sample_rate: int = 44100, | |
| n_quantizers: int = None, | |
| ): | |
| signal = AudioSignal(audio_data, sample_rate) | |
| signal.resample(self.model.sample_rate) | |
| recons = self.model(signal.audio_data) | |
| recons = AudioSignal(recons, self.model.sample_rate) | |
| recons.resample(sample_rate) | |
| return {"audio": recons.audio_data} | |
| if __name__ == "__main__": | |
| import numpy as np | |
| from functools import partial | |
| model = Encodec() | |
| for n, m in model.named_modules(): | |
| o = m.extra_repr() | |
| p = sum([np.prod(p.size()) for p in m.parameters()]) | |
| fn = lambda o, p: o + f" {p/1e6:<.3f}M params." | |
| setattr(m, "extra_repr", partial(fn, o=o, p=p)) | |
| print(model) | |
| print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) | |
| length = 88200 * 2 | |
| x = torch.randn(1, 1, length).to(model.device) | |
| x.requires_grad_(True) | |
| x.retain_grad() | |
| # Make a forward pass | |
| out = model(x)["audio"] | |
| print(x.shape, out.shape) | |