Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import numpy as
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
import librosa
|
|
@@ -32,8 +32,8 @@ def load_model():
|
|
| 32 |
return session, onnx_model, input_names, output_names
|
| 33 |
|
| 34 |
def inference(re_im, session, onnx_model, input_names, output_names):
|
| 35 |
-
inputs = {input_names[i]:
|
| 36 |
-
dtype=
|
| 37 |
for i, _input in enumerate(onnx_model.graph.input)
|
| 38 |
}
|
| 39 |
|
|
@@ -46,25 +46,25 @@ def inference(re_im, session, onnx_model, input_names, output_names):
|
|
| 46 |
inputs[input_names[3]] = mlp_state
|
| 47 |
output_audio.append(out)
|
| 48 |
|
| 49 |
-
output_audio = torch.tensor(
|
| 50 |
output_audio = output_audio.permute(1, 0, 2).contiguous()
|
| 51 |
output_audio = torch.view_as_complex(output_audio)
|
| 52 |
output_audio = torch.istft(output_audio, window, stride, window=hann)
|
| 53 |
-
return output_audio.
|
| 54 |
|
| 55 |
def visualize(hr, lr, recon, sr):
|
| 56 |
sr = sr
|
| 57 |
window_size = 1024
|
| 58 |
-
window =
|
| 59 |
|
| 60 |
stft_hr = librosa.core.spectrum.stft(hr, n_fft=window_size, hop_length=512, window=window)
|
| 61 |
-
stft_hr = 2 *
|
| 62 |
|
| 63 |
stft_lr = librosa.core.spectrum.stft(lr, n_fft=window_size, hop_length=512, window=window)
|
| 64 |
-
stft_lr = 2 *
|
| 65 |
|
| 66 |
stft_recon = librosa.core.spectrum.stft(recon, n_fft=window_size, hop_length=512, window=window)
|
| 67 |
-
stft_recon = 2 *
|
| 68 |
|
| 69 |
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 12))
|
| 70 |
ax1.title.set_text('Оригинальный сигнал')
|
|
@@ -109,7 +109,7 @@ lossy_input = lossy_input.reshape(-1)
|
|
| 109 |
hann = torch.sqrt(torch.hann_window(window))
|
| 110 |
lossy_input_tensor = torch.tensor(lossy_input)
|
| 111 |
re_im = torch.stft(lossy_input_tensor, window, stride, window=hann, return_complex=False).permute(1, 0, 2).unsqueeze(
|
| 112 |
-
1).
|
| 113 |
session, onnx_model, input_names, output_names = load_model()
|
| 114 |
|
| 115 |
if st.button('Сгенерировать потери'):
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
import librosa
|
|
|
|
| 32 |
return session, onnx_model, input_names, output_names
|
| 33 |
|
| 34 |
def inference(re_im, session, onnx_model, input_names, output_names):
|
| 35 |
+
inputs = {input_names[i]: np.zeros([d.dim_value for d in _input.type.tensor_type.shape.dim],
|
| 36 |
+
dtype=np.float32)
|
| 37 |
for i, _input in enumerate(onnx_model.graph.input)
|
| 38 |
}
|
| 39 |
|
|
|
|
| 46 |
inputs[input_names[3]] = mlp_state
|
| 47 |
output_audio.append(out)
|
| 48 |
|
| 49 |
+
output_audio = torch.tensor(np.concatenate(output_audio, 0))
|
| 50 |
output_audio = output_audio.permute(1, 0, 2).contiguous()
|
| 51 |
output_audio = torch.view_as_complex(output_audio)
|
| 52 |
output_audio = torch.istft(output_audio, window, stride, window=hann)
|
| 53 |
+
return output_audio.np()
|
| 54 |
|
| 55 |
def visualize(hr, lr, recon, sr):
|
| 56 |
sr = sr
|
| 57 |
window_size = 1024
|
| 58 |
+
window = np.hanning(window_size)
|
| 59 |
|
| 60 |
stft_hr = librosa.core.spectrum.stft(hr, n_fft=window_size, hop_length=512, window=window)
|
| 61 |
+
stft_hr = 2 * np.abs(stft_hr) / np.sum(window)
|
| 62 |
|
| 63 |
stft_lr = librosa.core.spectrum.stft(lr, n_fft=window_size, hop_length=512, window=window)
|
| 64 |
+
stft_lr = 2 * np.abs(stft_lr) / np.sum(window)
|
| 65 |
|
| 66 |
stft_recon = librosa.core.spectrum.stft(recon, n_fft=window_size, hop_length=512, window=window)
|
| 67 |
+
stft_recon = 2 * np.abs(stft_recon) / np.sum(window)
|
| 68 |
|
| 69 |
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 12))
|
| 70 |
ax1.title.set_text('Оригинальный сигнал')
|
|
|
|
| 109 |
hann = torch.sqrt(torch.hann_window(window))
|
| 110 |
lossy_input_tensor = torch.tensor(lossy_input)
|
| 111 |
re_im = torch.stft(lossy_input_tensor, window, stride, window=hann, return_complex=False).permute(1, 0, 2).unsqueeze(
|
| 112 |
+
1).np().astype(np.float32)
|
| 113 |
session, onnx_model, input_names, output_names = load_model()
|
| 114 |
|
| 115 |
if st.button('Сгенерировать потери'):
|