Spaces:
Build error
Build error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # author: adefossez | |
| import math | |
| import time | |
| import torch as th | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from .resample import downsample2, upsample2 | |
| from .utils import capture_init | |
class BLSTM(nn.Module):
    """
    LSTM stack, optionally bidirectional, whose output keeps `dim` features.

    Args:
        - dim (int): size of the input features and of the LSTM hidden state.
        - layers (int): number of stacked LSTM layers.
        - bi (bool): if true, the LSTM is bidirectional and a linear layer
            maps its `2 * dim` output features back to `dim`.
    """

    def __init__(self, dim, layers=2, bi=True):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=dim,
            hidden_size=dim,
            num_layers=layers,
            bidirectional=bi,
        )
        # Only the bidirectional variant needs a projection back to `dim`.
        self.linear = nn.Linear(2 * dim, dim) if bi else None

    def forward(self, x, hidden=None):
        """
        Run the LSTM on `x`, starting from `hidden` when provided.

        Returns the output (projected by the linear layer when there is one)
        together with the updated LSTM state.
        """
        out, state = self.lstm(x, hidden)
        if self.linear is None:
            return out, state
        return self.linear(out), state
def rescale_conv(conv, reference):
    """
    Rescale `conv` in place: its weight, and its bias when it has one, are
    divided by `sqrt(std / reference)` where `std` is the standard deviation
    of the weight before rescaling.
    """
    weight_std = conv.weight.std().detach()
    factor = (weight_std / reference) ** 0.5
    for param in (conv.weight, conv.bias):
        if param is not None:
            param.data /= factor
def rescale_module(module, reference):
    """
    Apply `rescale_conv` with the given `reference` to every `nn.Conv1d` and
    `nn.ConvTranspose1d` found in `module` (including `module` itself).
    """
    conv_types = (nn.Conv1d, nn.ConvTranspose1d)
    convs = [layer for layer in module.modules() if isinstance(layer, conv_types)]
    for conv in convs:
        rescale_conv(conv, reference)
class Demucs(nn.Module):
    """
    Demucs speech enhancement model.

    Args:
        - chin (int): number of input channels.
        - chout (int): number of output channels.
        - hidden (int): number of initial hidden channels.
        - depth (int): number of layers.
        - kernel_size (int): kernel size for each layer.
        - stride (int): stride for each layer.
        - causal (bool): if false, uses BiLSTM instead of LSTM.
        - resample (int): amount of resampling to apply to the input/output.
            Can be one of 1, 2 or 4.
        - growth (float): number of channels is multiplied by this for every layer.
        - max_hidden (int): maximum number of channels. Can be useful to
            control the size/speed of the model.
        - normalize (bool): if true, normalize the input.
        - glu (bool): if true uses GLU instead of ReLU in 1x1 convolutions.
        - rescale (float): controls custom weight initialization.
            See https://arxiv.org/abs/1911.13254.
        - floor (float): stability flooring when normalizing.

    Raises:
        ValueError: if `resample` is not 1, 2 or 4.
    """

    # NOTE(review): `capture_init` is imported at the top of this file but is
    # not applied to `__init__` here; confirm whether that is intentional.
    def __init__(self,
                 chin=1,
                 chout=1,
                 hidden=48,
                 depth=5,
                 kernel_size=8,
                 stride=4,
                 causal=True,
                 resample=4,
                 growth=2,
                 max_hidden=10_000,
                 normalize=True,
                 glu=True,
                 rescale=0.1,
                 floor=1e-3):
        super().__init__()
        if resample not in [1, 2, 4]:
            raise ValueError("Resample should be 1, 2 or 4.")

        self.chin = chin
        self.chout = chout
        self.hidden = hidden
        self.depth = depth
        self.kernel_size = kernel_size
        self.stride = stride
        self.causal = causal
        self.floor = floor
        self.resample = resample
        self.normalize = normalize

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        activation = nn.GLU(1) if glu else nn.ReLU()
        # GLU halves the channel count, so the 1x1 convolution feeding it
        # must produce twice as many channels.
        ch_scale = 2 if glu else 1

        for index in range(depth):
            encode = [
                nn.Conv1d(chin, hidden, kernel_size, stride),
                nn.ReLU(),
                nn.Conv1d(hidden, hidden * ch_scale, 1),
                activation,
            ]
            self.encoder.append(nn.Sequential(*encode))

            decode = [
                nn.Conv1d(hidden, ch_scale * hidden, 1),
                activation,
                nn.ConvTranspose1d(hidden, chout, kernel_size, stride),
            ]
            # Every decoder layer except the one producing the final output
            # ends with a ReLU.
            if index > 0:
                decode.append(nn.ReLU())
            # Decoder layers are stored innermost first, mirroring the encoder.
            self.decoder.insert(0, nn.Sequential(*decode))
            chout = hidden
            chin = hidden
            hidden = min(int(growth * hidden), max_hidden)

        # After the loop `chin` is the channel count of the innermost layer.
        self.lstm = BLSTM(chin, bi=not causal)
        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there is no time steps left over in a convolutions, e.g. for all
        layers, size of the input - kernel_size % stride = 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length.
        """
        length = math.ceil(length * self.resample)
        # Length after each encoder convolution, rounded up and at least 1.
        for _ in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(length, 1)
        # Length after each transposed convolution of the decoder.
        for _ in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size
        length = int(math.ceil(length / self.resample))
        return int(length)

    @property
    def total_stride(self):
        """
        Combined stride of all the layers divided by the resampling factor,
        i.e. `stride ** depth // resample`.

        This is a property because the rest of this file reads it as a plain
        attribute (`demucs.total_stride`).
        """
        return self.stride ** self.depth // self.resample

    def forward(self, mix):
        """
        Apply the model to `mix` and return a tensor with the same length
        along the last dimension.

        A two dimensional `mix` gets a channel dimension inserted at index 1.
        """
        if mix.dim() == 2:
            mix = mix.unsqueeze(1)

        if self.normalize:
            # Normalize by the standard deviation of the channel average;
            # the output is scaled back by the same value on return.
            mono = mix.mean(dim=1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            mix = mix / (self.floor + std)
        else:
            std = 1
        length = mix.shape[-1]
        x = mix
        # Right-pad with zeros up to a length the convolutions fit exactly.
        x = F.pad(x, (0, self.valid_length(length) - length))
        if self.resample == 2:
            x = upsample2(x)
        elif self.resample == 4:
            x = upsample2(x)
            x = upsample2(x)

        skips = []
        for encode in self.encoder:
            x = encode(x)
            skips.append(x)
        # The LSTM expects time first, the convolutions expect time last.
        x = x.permute(2, 0, 1)
        x, _ = self.lstm(x)
        x = x.permute(1, 2, 0)
        for decode in self.decoder:
            skip = skips.pop(-1)
            x = x + skip[..., :x.shape[-1]]
            x = decode(x)

        if self.resample == 2:
            x = downsample2(x)
        elif self.resample == 4:
            x = downsample2(x)
            x = downsample2(x)

        # Drop the samples that only exist because of the padding.
        x = x[..., :length]
        return std * x
def fast_conv(conv, x):
    """
    Faster convolution evaluation if either kernel size is 1
    or length of sequence is 1.

    Args:
        - conv (nn.Conv1d): convolution to evaluate.
        - x (Tensor): input with 3 dimensions, the first of size 1.

    Returns:
        The result of the convolution viewed as `[1, out_channels, -1]`.

    The matrix product shortcuts are only taken when they are exact, i.e.
    for a convolution without padding, dilation or groups; any other
    configuration goes through the regular `conv(x)` call.
    """
    batch, _, length = x.shape
    chout, chin, kernel = conv.weight.shape
    assert batch == 1, "fast_conv only supports a batch size of 1"

    plain = (
        conv.groups == 1 and conv.padding == (0,) and conv.dilation == (1,)
    )
    if plain and kernel == 1 and conv.stride == (1,):
        # A 1x1 convolution is a matrix product over the channels.
        weight = conv.weight.view(chout, chin)
        columns = x.reshape(chin, length)
    elif plain and length == kernel:
        # A single output step: the whole input is one flattened window.
        weight = conv.weight.view(chout, chin * kernel)
        # `reshape` rather than `view` so sliced inputs are accepted too.
        columns = x.reshape(chin * kernel, 1)
    else:
        return conv(x).view(batch, chout, -1)

    if conv.bias is None:
        out = th.mm(weight, columns)
    else:
        out = th.addmm(conv.bias.view(-1, 1), weight, columns)
    return out.view(batch, chout, -1)
class DemucsStreamer:
    """
    Streaming implementation for Demucs. It supports being fed with any amount
    of audio at a time. You will get back as much audio as possible at that
    point.

    Args:
        - demucs (Demucs): Demucs model.
        - dry (float): amount of dry (e.g. input) signal to keep. 0 is maximum
            noise removal, 1 just returns the input signal. Small values > 0
            allows to limit distortions.
        - num_frames (int): number of frames to process at once. Higher values
            will increase overall latency but improve the real time factor.
        - resample_lookahead (int): extra lookahead used for the resampling.
        - resample_buffer (int): size of the buffer of previous inputs/outputs
            kept for resampling.
    """

    def __init__(self, demucs,
                 dry=0,
                 num_frames=1,
                 resample_lookahead=64,
                 resample_buffer=256):
        device = next(iter(demucs.parameters())).device
        self.demucs = demucs
        self.lstm_state = None
        self.conv_state = None
        self.dry = dry
        self.resample_lookahead = resample_lookahead

        total_stride = demucs.total_stride
        if callable(total_stride):
            # Accept models that expose `total_stride` as a plain method
            # rather than as a property.
            total_stride = total_stride()

        resample_buffer = min(total_stride, resample_buffer)
        self.resample_buffer = resample_buffer
        self.frame_length = demucs.valid_length(1) + \
            total_stride * (num_frames - 1)
        self.total_length = self.frame_length + self.resample_lookahead
        self.stride = total_stride * num_frames
        self.resample_in = th.zeros(demucs.chin, resample_buffer, device=device)
        # This buffer stores model output, so it has `chout` channels.
        self.resample_out = th.zeros(
            demucs.chout, resample_buffer, device=device
        )

        self.frames = 0
        self.total_time = 0
        self.variance = 0
        self.pending = th.zeros(demucs.chin, 0, device=device)

        # Rearranged copies of the parameters of the innermost transposed
        # convolution. Nothing in this class reads them; kept as is.
        bias = demucs.decoder[0][2].bias
        weight = demucs.decoder[0][2].weight
        kernel = weight.shape[-1]
        self._bias = bias.view(-1, 1).repeat(1, kernel).view(-1, 1)
        self._weight = weight.permute(1, 2, 0).contiguous()

    def reset_time_per_frame(self):
        """Reset the counters used to compute `time_per_frame`."""
        self.total_time = 0
        self.frames = 0

    @property
    def time_per_frame(self):
        """
        Average time in seconds spent in `feed` per processed frame.

        Raises:
            ZeroDivisionError: if no frame has been processed since the
                creation of the streamer or the last reset.
        """
        return self.total_time / self.frames

    def flush(self):
        """
        Flush remaining audio by padding it with zero. Call this
        when you have no more input and want to get back the last chunk of audio.
        """
        pending_length = self.pending.shape[1]
        padding = th.zeros(
            self.demucs.chin, self.total_length, device=self.pending.device
        )
        out = self.feed(padding)
        # Only keep what corresponds to real, not yet returned, input.
        return out[:, :pending_length]

    def feed(self, wav):
        """
        Apply the model to mix using true real time evaluation.
        Normalization is done online as is the resampling.

        Args:
            - wav (Tensor): new audio with 2 dimensions, the first one
                being of size `demucs.chin`.

        Returns:
            The audio that could be produced so far, concatenated along
            dimension 1. It has length 0 when there is not yet enough
            pending input to process a frame.

        Raises:
            ValueError: if `wav` is not two dimensional or its first
                dimension differs from `demucs.chin`.
        """
        begin = time.time()
        demucs = self.demucs
        resample_buffer = self.resample_buffer
        stride = self.stride
        resample = demucs.resample

        if wav.dim() != 2:
            raise ValueError("input wav should be two dimensional.")
        chin, _ = wav.shape
        if chin != demucs.chin:
            raise ValueError(f"Expected {demucs.chin} channels, got {chin}")

        self.pending = th.cat([self.pending, wav], dim=1)
        outs = []
        while self.pending.shape[1] >= self.total_length:
            self.frames += 1
            frame = self.pending[:, :self.total_length]
            dry_signal = frame[:, :stride]
            if demucs.normalize:
                # Running average, over all frames seen so far, of the mean
                # square of the channel-averaged frame.
                mono = frame.mean(0)
                variance = (mono**2).mean()
                self.variance = variance / self.frames + \
                    (1 - 1 / self.frames) * self.variance
                frame = frame / (demucs.floor + math.sqrt(self.variance))
            # Prepend previous input so the resampling has some context.
            frame = th.cat([self.resample_in, frame], dim=-1)
            self.resample_in[:] = frame[:, stride - resample_buffer:stride]

            if resample == 4:
                frame = upsample2(upsample2(frame))
            elif resample == 2:
                frame = upsample2(frame)
            # remove pre sampling buffer
            frame = frame[:, resample * resample_buffer:]
            # remove extra samples after window
            frame = frame[:, :resample * self.frame_length]

            out, extra = self._separate_frame(frame)
            padded_out = th.cat([self.resample_out, out, extra], 1)
            self.resample_out[:] = out[:, -resample_buffer:]
            if resample == 4:
                out = downsample2(downsample2(padded_out))
            elif resample == 2:
                out = downsample2(padded_out)
            else:
                out = padded_out

            out = out[:, resample_buffer // resample:]
            out = out[:, :stride]

            if demucs.normalize:
                # Undo the normalization applied to the input frame.
                out *= math.sqrt(self.variance)
            out = self.dry * dry_signal + (1 - self.dry) * out
            outs.append(out)
            self.pending = self.pending[:, stride:]

        self.total_time += time.time() - begin
        if outs:
            out = th.cat(outs, 1)
        else:
            # Same channel count as the non-empty outputs.
            out = th.zeros(demucs.chout, 0, device=wav.device)
        return out

    def _separate_frame(self, frame):
        """
        Run the model on one (already normalized and upsampled) frame,
        reusing and updating the state kept from the previous frame in
        `self.conv_state` and `self.lstm_state`.

        Returns a pair `(out, extra)`, see the comment before the decoder
        loop for the meaning of `extra`.
        """
        demucs = self.demucs
        skips = []
        next_state = []
        first = self.conv_state is None
        stride = self.stride * demucs.resample
        x = frame[None]
        for idx, encode in enumerate(demucs.encoder):
            stride //= demucs.stride
            length = x.shape[2]
            if idx == demucs.depth - 1:
                # This is slightly faster for the last conv
                x = fast_conv(encode[0], x)
                x = encode[1](x)
                x = fast_conv(encode[2], x)
                x = encode[3](x)
            else:
                if not first:
                    # Reuse the part of the previous output of this layer
                    # that is still valid and only compute what is missing.
                    prev = self.conv_state.pop(0)
                    prev = prev[..., stride:]
                    tgt = (length - demucs.kernel_size) // demucs.stride + 1
                    missing = tgt - prev.shape[-1]
                    offset = length - demucs.kernel_size - \
                        demucs.stride * (missing - 1)
                    x = x[..., offset:]
                x = encode[1](encode[0](x))
                x = fast_conv(encode[2], x)
                x = encode[3](x)
                if not first:
                    x = th.cat([prev, x], -1)
                next_state.append(x)
            skips.append(x)

        x = x.permute(2, 0, 1)
        x, self.lstm_state = demucs.lstm(x, self.lstm_state)
        x = x.permute(1, 2, 0)
        # In the following, x contains only correct samples, i.e. the one
        # for which each time position is covered by two window of the upper
        # layer. extra contains extra samples to the right, and is used only as
        # a better padding for the online resampling.
        extra = None
        for idx, decode in enumerate(demucs.decoder):
            skip = skips.pop(-1)
            x += skip[..., :x.shape[-1]]
            x = fast_conv(decode[0], x)
            x = decode[1](x)

            if extra is not None:
                skip = skip[..., x.shape[-1]:]
                extra += skip[..., :extra.shape[-1]]
                extra = decode[2](decode[1](decode[0](extra)))
            x = decode[2](x)
            # Saved without the bias: it gets added to the first samples of
            # the next frame's output of this layer (and to `extra`), which
            # already include the bias.
            next_state.append(
                x[..., -demucs.stride:] - decode[2].bias.view(-1, 1)
            )
            if extra is None:
                extra = x[..., -demucs.stride:]
            else:
                extra[..., :demucs.stride] += next_state[-1]
            x = x[..., :-demucs.stride]

            if not first:
                prev = self.conv_state.pop(0)
                x[..., :demucs.stride] += prev
            if idx != demucs.depth - 1:
                x = decode[3](x)
                extra = decode[3](extra)
        self.conv_state = next_state
        return x[0], extra[0]
def test():
    """
    Command line benchmark of the streaming implementation.

    Builds a `Demucs` model from the command line options, runs it on 4
    seconds of random noise both offline and through `DemucsStreamer`, then
    prints the model size, the relative difference between the two outputs
    and timing information.
    """
    import argparse

    def resolve(accessor):
        # `total_stride` and `time_per_frame` are read as attributes by this
        # benchmark; also accept them when exposed as plain methods.
        return accessor() if callable(accessor) else accessor

    parser = argparse.ArgumentParser(
        "denoiser.demucs",
        description="Benchmark the streaming Demucs implementation, as well as "
                    "checking the delta with the offline implementation.")
    parser.add_argument("--depth", default=5, type=int)
    parser.add_argument("--resample", default=4, type=int)
    parser.add_argument("--hidden", default=48, type=int)
    parser.add_argument("--sample_rate", default=16000, type=float)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("-t", "--num_threads", type=int)
    parser.add_argument("-f", "--num_frames", type=int, default=1)
    args = parser.parse_args()
    if args.num_threads:
        th.set_num_threads(args.num_threads)
    sr = args.sample_rate
    sr_ms = sr / 1000
    demucs = Demucs(
        depth=args.depth, hidden=args.hidden, resample=args.resample
    ).to(args.device)
    x = th.randn(1, int(sr * 4)).to(args.device)
    streamer = DemucsStreamer(demucs, num_frames=args.num_frames)
    out_rt = []
    # The first chunk must hold a full frame plus the lookahead, afterwards
    # the audio is fed one model stride at a time.
    frame_size = streamer.total_length
    hop = resolve(streamer.demucs.total_stride)
    # Pure inference: no gradient is needed for either implementation.
    with th.no_grad():
        out = demucs(x[None])[0]
        while x.shape[1] > 0:
            out_rt.append(streamer.feed(x[:, :frame_size]))
            x = x[:, frame_size:]
            frame_size = hop
        out_rt.append(streamer.flush())
    out_rt = th.cat(out_rt, 1)
    model_size = sum(p.numel() for p in demucs.parameters()) * 4 / 2**20
    initial_lag = streamer.total_length / sr_ms
    tpf = 1000 * resolve(streamer.time_per_frame)
    print(f"model size: {model_size:.1f}MB, ", end='')
    print(f"delta batch/streaming: {th.norm(out - out_rt) / th.norm(out):.2%}")
    print(f"initial lag: {initial_lag:.1f}ms, ", end='')
    # `streamer.stride` already accounts for `num_frames`.
    print(f"stride: {streamer.stride / sr_ms:.1f}ms")
    print(f"time per frame: {tpf:.1f}ms, ", end='')
    rtf = tpf / (streamer.stride / sr_ms)
    print(f"RTF: {rtf:.2f}")
    print(f"Total lag with computation: {initial_lag + tpf:.1f}ms")
| if __name__ == "__main__": | |
| test() | |