import gc  # free up memory between runs
import os
import random
import tempfile

import gradio as gr
import librosa
import numpy as np
import pyloudnorm as pyln
import soundfile as sf
import spaces
import torch
from scipy import signal
from scipy.signal.windows import hann

from audiosr import build_model, super_resolution


class AudioUpscaler:
"""
Upscales audio using the AudioSR model.
"""
def __init__(self, model_name="basic", device="cuda"):
"""
Initializes the AudioUpscaler.
Args:
model_name (str, optional): Name of the AudioSR model to use. Defaults to "basic".
device (str, optional): Device to use for inference. Defaults to "cuda".
"""
self.model_name = model_name
self.device = device
        self.sr = 48000  # AudioSR generates output at 48 kHz; this must match the model's output rate
self.audiosr = None # Model will be loaded in setup()
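    # On Hugging Face ZeroGPU Spaces, the @spaces.GPU decorator below requests a GPU
    # for the decorated call, for at most `duration` seconds.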
@spaces.GPU(duration=120)
def setup(self):
"""
Loads the AudioSR model.
"""
print("Loading Model...")
self.audiosr = build_model(model_name=self.model_name, device=self.device)
print("Model loaded!")
def _match_array_shapes(self, array_1: np.ndarray, array_2: np.ndarray):
"""
Matches the shapes of two arrays by padding the shorter one with zeros.
Args:
array_1 (np.ndarray): First array.
array_2 (np.ndarray): Second array.
Returns:
np.ndarray: The first array with a matching shape to the second array.
"""
        if array_1.ndim == 1 and array_2.ndim == 1:
            if array_1.shape[0] > array_2.shape[0]:
                array_1 = array_1[: array_2.shape[0]]
            elif array_1.shape[0] < array_2.shape[0]:
                # Pad at the end so both signals stay time-aligned.
                array_1 = np.pad(
                    array_1,
                    (0, array_2.shape[0] - array_1.shape[0]),
                    "constant",
                    constant_values=0,
                )
else:
if array_1.shape[1] > array_2.shape[1]:
array_1 = array_1[:, : array_2.shape[1]]
elif array_1.shape[1] < array_2.shape[1]:
padding = array_2.shape[1] - array_1.shape[1]
array_1 = np.pad(
array_1, ((0, 0), (0, padding)), "constant", constant_values=0
)
return array_1
def _lr_filter(
self, audio, cutoff, filter_type, order=12, sr=48000
):
"""
Applies a low-pass or high-pass filter to the audio.
Args:
audio (np.ndarray): Audio data.
cutoff (int): Cutoff frequency.
filter_type (str): Filter type ("lowpass" or "highpass").
order (int, optional): Filter order. Defaults to 12.
sr (int, optional): Sample rate. Defaults to 48000.
Returns:
np.ndarray: Filtered audio data.
"""
audio = audio.T
nyquist = 0.5 * sr
normal_cutoff = cutoff / nyquist
        # Design the filter directly as second-order sections for numerical stability;
        # filtfilt applies it forward and backward, doubling the effective order.
        sos = signal.butter(
            order // 2, normal_cutoff, btype=filter_type, analog=False, output="sos"
        )
        filtered_audio = signal.sosfiltfilt(sos, audio)
return filtered_audio.T
def _process_audio(
self,
input_file,
chunk_size=5.12,
overlap=0.16,
seed=None,
guidance_scale=3.5,
ddim_steps=50,
multiband_ensemble=True,
input_cutoff=8000,
):
"""
Processes the audio in chunks and performs upsampling.
Args:
input_file (str): Path to the input audio file.
chunk_size (float, optional): Chunk size in seconds. Defaults to 5.12.
            overlap (float, optional): Overlap between chunks, as a fraction of the chunk length. Defaults to 0.16.
seed (int, optional): Random seed. Defaults to None.
guidance_scale (float, optional): Scale for classifier-free guidance. Defaults to 3.5.
ddim_steps (int, optional): Number of inference steps. Defaults to 50.
multiband_ensemble (bool, optional): Whether to use multiband ensemble. Defaults to True.
            input_cutoff (int, optional): Input cutoff frequency for multiband ensemble. Defaults to 8000.
Returns:
np.ndarray: Upsampled audio data.
"""
        # Note: the caller-supplied chunk_size is replaced here with a small random value;
        # the lower bound of 1 (not 0) keeps the chunk length from being zero samples.
        chunk_size = random.randint(1, 10) * 0.08
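        # Load (and resample) the input at exactly twice the stated cutoff, so the
        # Nyquist frequency equals input_cutoff: everything above it is left for the
        # model to regenerate.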
audio, sr = librosa.load(input_file, sr=input_cutoff * 2, mono=False)
audio = audio.T
sr = input_cutoff * 2
is_stereo = len(audio.shape) == 2
if is_stereo:
audio_ch1, audio_ch2 = audio[:, 0], audio[:, 1]
else:
audio_ch1 = audio
chunk_samples = int(chunk_size * sr)
overlap_samples = int(overlap * chunk_samples)
output_chunk_samples = int(chunk_size * self.sr)
output_overlap_samples = int(overlap * output_chunk_samples)
enable_overlap = True if overlap > 0 else False
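        # Split the audio into fixed-length chunks, zero-padding the last one so every
        # chunk fed to the model has the same length. Consecutive chunks overlap by
        # `overlap_samples` so they can be crossfaded during reconstruction.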
def process_chunks(audio):
chunks = []
original_lengths = []
start = 0
while start < len(audio):
print(f"{start} / {len(audio)}")
end = min(start + chunk_samples, len(audio))
chunk = audio[start:end]
if len(chunk) < chunk_samples:
original_lengths.append(len(chunk))
pad = np.zeros(chunk_samples - len(chunk))
chunk = np.concatenate([chunk, pad])
else:
original_lengths.append(chunk_samples)
chunks.append(chunk)
start += (
chunk_samples - overlap_samples
if enable_overlap
else chunk_samples
)
return chunks, original_lengths
chunks_ch1, original_lengths_ch1 = process_chunks(audio_ch1)
if is_stereo:
chunks_ch2, original_lengths_ch2 = process_chunks(audio_ch2)
sample_rate_ratio = self.sr / sr
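        # Length of the full output buffer at the model's output rate; overlapping
        # regions are summed after crossfading, so they are counted only once.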
total_length = (
len(chunks_ch1) * output_chunk_samples
- (len(chunks_ch1) - 1)
* (output_overlap_samples if enable_overlap else 0)
)
reconstructed_ch1 = np.zeros((1, total_length))
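        # Loudness meters at the input and output sample rates, used to match each
        # upsampled chunk's loudness back to that of its source chunk.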
meter_before = pyln.Meter(sr)
meter_after = pyln.Meter(self.sr)
for i, chunk in enumerate(chunks_ch1):
print(f"{i} / {len(chunks_ch1)}")
loudness_before = meter_before.integrated_loudness(chunk)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav:
sf.write(temp_wav.name, chunk, sr)
out_chunk = super_resolution(
self.audiosr,
temp_wav.name,
seed=seed,
guidance_scale=guidance_scale,
ddim_steps=ddim_steps,
latent_t_per_second=12.8,
)
out_chunk = out_chunk[0]
num_samples_to_keep = int(
original_lengths_ch1[i] * sample_rate_ratio
)
out_chunk = out_chunk[:, :num_samples_to_keep].squeeze()
loudness_after = meter_after.integrated_loudness(out_chunk)
out_chunk = pyln.normalize.loudness(
out_chunk, loudness_after, loudness_before
)
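            # Linear crossfade at chunk boundaries: fade in over the leading overlap and
            # fade out over the trailing overlap, except at the very first and last chunk.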
if enable_overlap:
actual_overlap_samples = min(
output_overlap_samples, num_samples_to_keep
)
fade_out = np.linspace(1.0, 0.0, actual_overlap_samples)
fade_in = np.linspace(0.0, 1.0, actual_overlap_samples)
if i == 0:
out_chunk[-actual_overlap_samples:] *= fade_out
elif i < len(chunks_ch1) - 1:
out_chunk[:actual_overlap_samples] *= fade_in
out_chunk[-actual_overlap_samples:] *= fade_out
else:
out_chunk[:actual_overlap_samples] *= fade_in
start = i * (
output_chunk_samples - output_overlap_samples
if enable_overlap
else output_chunk_samples
)
end = start + out_chunk.shape[0]
reconstructed_ch1[0, start:end] += out_chunk.flatten()
if is_stereo:
reconstructed_ch2 = np.zeros((1, total_length))
for i, chunk in enumerate(chunks_ch2):
print(f"{i} / {len(chunks_ch2)}")
loudness_before = meter_before.integrated_loudness(chunk)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav:
sf.write(temp_wav.name, chunk, sr)
out_chunk = super_resolution(
self.audiosr,
temp_wav.name,
seed=seed,
guidance_scale=guidance_scale,
ddim_steps=ddim_steps,
latent_t_per_second=12.8,
)
out_chunk = out_chunk[0]
num_samples_to_keep = int(
original_lengths_ch2[i] * sample_rate_ratio
)
out_chunk = out_chunk[:, :num_samples_to_keep].squeeze()
loudness_after = meter_after.integrated_loudness(out_chunk)
out_chunk = pyln.normalize.loudness(
out_chunk, loudness_after, loudness_before
)
if enable_overlap:
actual_overlap_samples = min(
output_overlap_samples, num_samples_to_keep
)
fade_out = np.linspace(1.0, 0.0, actual_overlap_samples)
fade_in = np.linspace(0.0, 1.0, actual_overlap_samples)
if i == 0:
out_chunk[-actual_overlap_samples:] *= fade_out
                    elif i < len(chunks_ch2) - 1:
out_chunk[:actual_overlap_samples] *= fade_in
out_chunk[-actual_overlap_samples:] *= fade_out
else:
out_chunk[:actual_overlap_samples] *= fade_in
start = i * (
output_chunk_samples - output_overlap_samples
if enable_overlap
else output_chunk_samples
)
end = start + out_chunk.shape[0]
reconstructed_ch2[0, start:end] += out_chunk.flatten()
reconstructed_audio = np.stack(
[reconstructed_ch1, reconstructed_ch2], axis=-1
)
else:
reconstructed_audio = reconstructed_ch1
if multiband_ensemble:
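            # Multiband ensemble: keep the original signal below the crossover frequency
            # and take only the generated content above it, then sum the two bands.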
low, _ = librosa.load(input_file, sr=48000, mono=False)
output = self._match_array_shapes(
reconstructed_audio[0].T, low
)
crossover_freq = input_cutoff - 1000
low = self._lr_filter(
low.T, crossover_freq, "lowpass", order=10
)
high = self._lr_filter(
output.T, crossover_freq, "highpass", order=10
)
high = self._lr_filter(
high, 23000, "lowpass", order=2
)
output = low + high
else:
output = reconstructed_audio[0]
return output
def predict(
self,
input_file,
output_folder,
ddim_steps=50,
guidance_scale=3.5,
overlap=0.04,
chunk_size=10.24,
seed=None,
multiband_ensemble=True,
input_cutoff=8000,
):
"""
Upscales the audio and saves the result.
Args:
input_file (str): Path to the input audio file.
output_folder (str): Path to the output folder.
ddim_steps (int, optional): Number of inference steps. Defaults to 50.
guidance_scale (float, optional): Scale for classifier-free guidance. Defaults to 3.5.
            overlap (float, optional): Overlap between chunks, as a fraction of the chunk length. Defaults to 0.04.
chunk_size (float, optional): Chunk size in seconds. Defaults to 10.24.
seed (int, optional): Random seed. Defaults to None.
multiband_ensemble (bool, optional): Whether to use multiband ensemble. Defaults to True.
            input_cutoff (int, optional): Input cutoff frequency for multiband ensemble. Defaults to 8000.
"""
if seed == 0:
seed = random.randint(0, 2**32 - 1)
        # chunk_size is randomized (and overridden again in _process_audio); the lower
        # bound of 1 avoids zero-length chunks.
        chunk_size = random.randint(1, 10) * 0.08
os.makedirs(output_folder, exist_ok=True)
waveform = self._process_audio(
input_file,
chunk_size=chunk_size,
overlap=overlap,
seed=seed,
guidance_scale=guidance_scale,
ddim_steps=ddim_steps,
multiband_ensemble=multiband_ensemble,
input_cutoff=input_cutoff,
)
filename = os.path.splitext(os.path.basename(input_file))[0]
output_file = f"{output_folder}/SR_{filename}.wav"
sf.write(output_file, data=waveform, samplerate=48000, subtype="PCM_16")
print(f"File created: {output_file}")
# Cleanup
gc.collect()
torch.cuda.empty_cache()
return waveform
# return output_file
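
# Single-pass path: runs super_resolution on the whole file at once, with no chunking,
# crossfading, or loudness matching. It is not referenced by the Gradio interface below.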
def inference(audio_file, model_name, guidance_scale, ddim_steps, seed):
audiosr = build_model(model_name=model_name)
gc.collect()
    # Pick a random seed when the input value is 0.
    if seed == 0:
        seed = random.randint(1, 2**32 - 1)
waveform = super_resolution(
audiosr,
audio_file,
seed,
guidance_scale=guidance_scale,
ddim_steps=ddim_steps
)
return (48000, waveform)
def upscale_audio(
input_file,
output_folder,
ddim_steps=20,
guidance_scale=3.5,
overlap=0.04,
chunk_size=10.24,
seed=0,
multiband_ensemble=True,
input_cutoff=14000,
):
"""
Upscales the audio using the AudioSR model.
Args:
input_file (str): Path to the input audio file.
output_folder (str): Path to the output folder.
ddim_steps (int, optional): Number of inference steps. Defaults to 20.
guidance_scale (float, optional): Scale for classifier-free guidance. Defaults to 3.5.
        overlap (float, optional): Overlap between chunks, as a fraction of the chunk length. Defaults to 0.04.
chunk_size (float, optional): Chunk size in seconds. Defaults to 10.24.
seed (int, optional): Random seed. Defaults to 0.
multiband_ensemble (bool, optional): Whether to use multiband ensemble. Defaults to True.
input_cutoff (int, optional): Input cutoff frequency for multiband ensemble. Defaults to 14000.
Returns:
tuple: Upscaled audio data and sample rate.
"""
torch.cuda.empty_cache()
    # chunk_size is randomized (and overridden again downstream); the lower bound of 1
    # avoids zero-length chunks.
    chunk_size = random.randint(1, 10) * 0.08
gc.collect()
upscaler = AudioUpscaler()
upscaler.setup()
waveform = upscaler.predict(
input_file,
output_folder,
ddim_steps=ddim_steps,
guidance_scale=guidance_scale,
overlap=overlap,
chunk_size=chunk_size,
seed=seed,
multiband_ensemble=multiband_ensemble,
input_cutoff=input_cutoff,
)
torch.cuda.empty_cache()
gc.collect()
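    # gr.Audio(type="numpy") expects a (sample_rate, data) tuple, where data is a NumPy
    # array shaped (samples,) for mono or (samples, channels) for stereo.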
    return (48000, waveform)
iface = gr.Interface(
fn=upscale_audio,
inputs=[
gr.Audio(type="filepath", label="Input Audio"),
gr.Textbox(".",label="Out-dir"),
gr.Slider(10, 500, value=20, step=1, label="DDIM Steps", info="Number of inference steps (quality/speed)"),
gr.Slider(1.0, 20.0, value=3.5, step=0.1, label="Guidance Scale", info="Guidance scale (creativity/fidelity)"),
        gr.Slider(0.0, 0.5, value=0.04, step=0.01, label="Overlap (fraction)", info="Fractional overlap between consecutive chunks (smoother transitions)"),
gr.Slider(5.12, 20.48, value=5.12, step=0.64, label="Chunk Size (s)", info="Chunk size (memory/artifact balance)"),
gr.Number(value=0, precision=0, label="Seed", info="Random seed (0 for random)"),
        gr.Checkbox(label="Multiband Ensemble", value=False, info="Blend the original low band with the generated high band"),
        gr.Slider(500, 15000, value=9000, step=500, label="Input Cutoff (Hz)", info="Upper frequency limit of the input; the multiband crossover sits just below it", visible=True)
],
outputs=gr.Audio(type="numpy", label="Output Audio"),
title="AudioSR",
description="Audio Super Resolution with AudioSR"
)
iface.launch(share=False)
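
# A minimal sketch of calling the upscaler directly from Python instead of through the
# Gradio UI (the file names below are placeholders, not part of this Space):
#
#     sr_out, waveform = upscale_audio(
#         "input.wav",            # any audio file librosa can read
#         "./upscaled",           # output folder; SR_input.wav is written here
#         ddim_steps=50,
#         guidance_scale=3.5,
#         overlap=0.04,
#         seed=0,                 # 0 selects a random seed
#         multiband_ensemble=True,
#         input_cutoff=14000,
#     )
#     # sr_out is 48000 and waveform is a NumPy array, matching the tuple that the
#     # Gradio output component consumes.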