import gradio as gr import torch import gc import os import random import numpy as np from scipy.signal.windows import hann import soundfile as sf import librosa from audiosr import build_model, super_resolution from scipy import signal import pyloudnorm as pyln import tempfile import spaces class AudioUpscaler: """ Upscales audio using the AudioSR model. """ def __init__(self, model_name="basic", device="cuda"): """ Initializes the AudioUpscaler. Args: model_name (str, optional): Name of the AudioSR model to use. Defaults to "basic". device (str, optional): Device to use for inference. Defaults to "cuda". """ self.model_name = model_name self.device = device self.sr = 44100 self.audiosr = None # Model will be loaded in setup() @spaces.GPU(duration=120) def setup(self): """ Loads the AudioSR model. """ print("Loading Model...") self.audiosr = build_model(model_name=self.model_name, device=self.device) print("Model loaded!") def _match_array_shapes(self, array_1: np.ndarray, array_2: np.ndarray): """ Matches the shapes of two arrays by padding the shorter one with zeros. Args: array_1 (np.ndarray): First array. array_2 (np.ndarray): Second array. Returns: np.ndarray: The first array with a matching shape to the second array. """ if (len(array_1.shape) == 1) & (len(array_2.shape) == 1): if array_1.shape[0] > array_2.shape[0]: array_1 = array_1[: array_2.shape[0]] elif array_1.shape[0] < array_2.shape[0]: array_1 = np.pad( array_1, ((array_2.shape[0] - array_1.shape[0], 0)), "constant", constant_values=0, ) else: if array_1.shape[1] > array_2.shape[1]: array_1 = array_1[:, : array_2.shape[1]] elif array_1.shape[1] < array_2.shape[1]: padding = array_2.shape[1] - array_1.shape[1] array_1 = np.pad( array_1, ((0, 0), (0, padding)), "constant", constant_values=0 ) return array_1 def _lr_filter( self, audio, cutoff, filter_type, order=12, sr=48000 ): """ Applies a low-pass or high-pass filter to the audio. Args: audio (np.ndarray): Audio data. cutoff (int): Cutoff frequency. filter_type (str): Filter type ("lowpass" or "highpass"). order (int, optional): Filter order. Defaults to 12. sr (int, optional): Sample rate. Defaults to 48000. Returns: np.ndarray: Filtered audio data. """ audio = audio.T nyquist = 0.5 * sr normal_cutoff = cutoff / nyquist b, a = signal.butter( order // 2, normal_cutoff, btype=filter_type, analog=False ) sos = signal.tf2sos(b, a) filtered_audio = signal.sosfiltfilt(sos, audio) return filtered_audio.T def _process_audio( self, input_file, chunk_size=5.12, overlap=0.16, seed=None, guidance_scale=3.5, ddim_steps=50, multiband_ensemble=True, input_cutoff=8000, ): """ Processes the audio in chunks and performs upsampling. Args: input_file (str): Path to the input audio file. chunk_size (float, optional): Chunk size in seconds. Defaults to 5.12. overlap (float, optional): Overlap between chunks in seconds. Defaults to 0.1. seed (int, optional): Random seed. Defaults to None. guidance_scale (float, optional): Scale for classifier-free guidance. Defaults to 3.5. ddim_steps (int, optional): Number of inference steps. Defaults to 50. multiband_ensemble (bool, optional): Whether to use multiband ensemble. Defaults to True. input_cutoff (int, optional): Input cutoff frequency for multiband ensemble. Defaults to 14000. Returns: np.ndarray: Upsampled audio data. """ chunk_size = random.randint(a=0, b=10)*0.08 audio, sr = librosa.load(input_file, sr=input_cutoff * 2, mono=False) audio = audio.T sr = input_cutoff * 2 is_stereo = len(audio.shape) == 2 if is_stereo: audio_ch1, audio_ch2 = audio[:, 0], audio[:, 1] else: audio_ch1 = audio chunk_samples = int(chunk_size * sr) overlap_samples = int(overlap * chunk_samples) output_chunk_samples = int(chunk_size * self.sr) output_overlap_samples = int(overlap * output_chunk_samples) enable_overlap = True if overlap > 0 else False def process_chunks(audio): chunks = [] original_lengths = [] start = 0 while start < len(audio): print(f"{start} / {len(audio)}") end = min(start + chunk_samples, len(audio)) chunk = audio[start:end] if len(chunk) < chunk_samples: original_lengths.append(len(chunk)) pad = np.zeros(chunk_samples - len(chunk)) chunk = np.concatenate([chunk, pad]) else: original_lengths.append(chunk_samples) chunks.append(chunk) start += ( chunk_samples - overlap_samples if enable_overlap else chunk_samples ) return chunks, original_lengths chunks_ch1, original_lengths_ch1 = process_chunks(audio_ch1) if is_stereo: chunks_ch2, original_lengths_ch2 = process_chunks(audio_ch2) sample_rate_ratio = self.sr / sr total_length = ( len(chunks_ch1) * output_chunk_samples - (len(chunks_ch1) - 1) * (output_overlap_samples if enable_overlap else 0) ) reconstructed_ch1 = np.zeros((1, total_length)) meter_before = pyln.Meter(sr) meter_after = pyln.Meter(self.sr) for i, chunk in enumerate(chunks_ch1): print(f"{i} / {len(chunks_ch1)}") loudness_before = meter_before.integrated_loudness(chunk) with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav: sf.write(temp_wav.name, chunk, sr) out_chunk = super_resolution( self.audiosr, temp_wav.name, seed=seed, guidance_scale=guidance_scale, ddim_steps=ddim_steps, latent_t_per_second=12.8, ) out_chunk = out_chunk[0] num_samples_to_keep = int( original_lengths_ch1[i] * sample_rate_ratio ) out_chunk = out_chunk[:, :num_samples_to_keep].squeeze() loudness_after = meter_after.integrated_loudness(out_chunk) out_chunk = pyln.normalize.loudness( out_chunk, loudness_after, loudness_before ) if enable_overlap: actual_overlap_samples = min( output_overlap_samples, num_samples_to_keep ) fade_out = np.linspace(1.0, 0.0, actual_overlap_samples) fade_in = np.linspace(0.0, 1.0, actual_overlap_samples) if i == 0: out_chunk[-actual_overlap_samples:] *= fade_out elif i < len(chunks_ch1) - 1: out_chunk[:actual_overlap_samples] *= fade_in out_chunk[-actual_overlap_samples:] *= fade_out else: out_chunk[:actual_overlap_samples] *= fade_in start = i * ( output_chunk_samples - output_overlap_samples if enable_overlap else output_chunk_samples ) end = start + out_chunk.shape[0] reconstructed_ch1[0, start:end] += out_chunk.flatten() if is_stereo: reconstructed_ch2 = np.zeros((1, total_length)) for i, chunk in enumerate(chunks_ch2): print(f"{i} / {len(chunks_ch2)}") loudness_before = meter_before.integrated_loudness(chunk) with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav: sf.write(temp_wav.name, chunk, sr) out_chunk = super_resolution( self.audiosr, temp_wav.name, seed=seed, guidance_scale=guidance_scale, ddim_steps=ddim_steps, latent_t_per_second=12.8, ) out_chunk = out_chunk[0] num_samples_to_keep = int( original_lengths_ch2[i] * sample_rate_ratio ) out_chunk = out_chunk[:, :num_samples_to_keep].squeeze() loudness_after = meter_after.integrated_loudness(out_chunk) out_chunk = pyln.normalize.loudness( out_chunk, loudness_after, loudness_before ) if enable_overlap: actual_overlap_samples = min( output_overlap_samples, num_samples_to_keep ) fade_out = np.linspace(1.0, 0.0, actual_overlap_samples) fade_in = np.linspace(0.0, 1.0, actual_overlap_samples) if i == 0: out_chunk[-actual_overlap_samples:] *= fade_out elif i < len(chunks_ch1) - 1: out_chunk[:actual_overlap_samples] *= fade_in out_chunk[-actual_overlap_samples:] *= fade_out else: out_chunk[:actual_overlap_samples] *= fade_in start = i * ( output_chunk_samples - output_overlap_samples if enable_overlap else output_chunk_samples ) end = start + out_chunk.shape[0] reconstructed_ch2[0, start:end] += out_chunk.flatten() reconstructed_audio = np.stack( [reconstructed_ch1, reconstructed_ch2], axis=-1 ) else: reconstructed_audio = reconstructed_ch1 if multiband_ensemble: low, _ = librosa.load(input_file, sr=48000, mono=False) output = self._match_array_shapes( reconstructed_audio[0].T, low ) crossover_freq = input_cutoff - 1000 low = self._lr_filter( low.T, crossover_freq, "lowpass", order=10 ) high = self._lr_filter( output.T, crossover_freq, "highpass", order=10 ) high = self._lr_filter( high, 23000, "lowpass", order=2 ) output = low + high else: output = reconstructed_audio[0] return output def predict( self, input_file, output_folder, ddim_steps=50, guidance_scale=3.5, overlap=0.04, chunk_size=10.24, seed=None, multiband_ensemble=True, input_cutoff=8000, ): """ Upscales the audio and saves the result. Args: input_file (str): Path to the input audio file. output_folder (str): Path to the output folder. ddim_steps (int, optional): Number of inference steps. Defaults to 50. guidance_scale (float, optional): Scale for classifier-free guidance. Defaults to 3.5. overlap (float, optional): Overlap between chunks. Defaults to 0.04. chunk_size (float, optional): Chunk size in seconds. Defaults to 10.24. seed (int, optional): Random seed. Defaults to None. multiband_ensemble (bool, optional): Whether to use multiband ensemble. Defaults to True. input_cutoff (int, optional): Input cutoff frequency for multiband ensemble. Defaults to 14000. """ if seed == 0: seed = random.randint(0, 2**32 - 1) chunk_size = random.randint(0, 10) * 0.08 os.makedirs(output_folder, exist_ok=True) waveform = self._process_audio( input_file, chunk_size=chunk_size, overlap=overlap, seed=seed, guidance_scale=guidance_scale, ddim_steps=ddim_steps, multiband_ensemble=multiband_ensemble, input_cutoff=input_cutoff, ) filename = os.path.splitext(os.path.basename(input_file))[0] output_file = f"{output_folder}/SR_{filename}.wav" sf.write(output_file, data=waveform, samplerate=48000, subtype="PCM_16") print(f"File created: {output_file}") # Cleanup gc.collect() torch.cuda.empty_cache() return waveform # return output_file @spaces.GPU def inference(audio_file, model_name, guidance_scale, ddim_steps, seed): audiosr = build_model(model_name=model_name) gc.collect() # set random seed when seed input value is 0 if seed == 0: import random seed = random.randint(1, 2**32-1) waveform = super_resolution( audiosr, audio_file, seed, guidance_scale=guidance_scale, ddim_steps=ddim_steps ) return (48000, waveform) def upscale_audio( input_file, output_folder, ddim_steps=20, guidance_scale=3.5, overlap=0.04, chunk_size=10.24, seed=0, multiband_ensemble=True, input_cutoff=14000, ): """ Upscales the audio using the AudioSR model. Args: input_file (str): Path to the input audio file. output_folder (str): Path to the output folder. ddim_steps (int, optional): Number of inference steps. Defaults to 20. guidance_scale (float, optional): Scale for classifier-free guidance. Defaults to 3.5. overlap (float, optional): Overlap between chunks. Defaults to 0.04. chunk_size (float, optional): Chunk size in seconds. Defaults to 10.24. seed (int, optional): Random seed. Defaults to 0. multiband_ensemble (bool, optional): Whether to use multiband ensemble. Defaults to True. input_cutoff (int, optional): Input cutoff frequency for multiband ensemble. Defaults to 14000. Returns: tuple: Upscaled audio data and sample rate. """ torch.cuda.empty_cache() chunk_size = random.randint(a=0, b=10)*0.08 gc.collect() upscaler = AudioUpscaler() upscaler.setup() waveform = upscaler.predict( input_file, output_folder, ddim_steps=ddim_steps, guidance_scale=guidance_scale, overlap=overlap, chunk_size=chunk_size, seed=seed, multiband_ensemble=multiband_ensemble, input_cutoff=input_cutoff, ) torch.cuda.empty_cache() gc.collect() return (48000,waveform) os.getcwd() gr.Textbox iface = gr.Interface( fn=upscale_audio, inputs=[ gr.Audio(type="filepath", label="Input Audio"), gr.Textbox(".",label="Out-dir"), gr.Slider(10, 500, value=20, step=1, label="DDIM Steps", info="Number of inference steps (quality/speed)"), gr.Slider(1.0, 20.0, value=3.5, step=0.1, label="Guidance Scale", info="Guidance scale (creativity/fidelity)"), gr.Slider(0.0, 0.5, value=0.04, step=0.01, label="Overlap (s)", info="Overlap between chunks (smooth transitions)"), gr.Slider(5.12, 20.48, value=5.12, step=0.64, label="Chunk Size (s)", info="Chunk size (memory/artifact balance)"), gr.Number(value=0, precision=0, label="Seed", info="Random seed (0 for random)"), gr.Checkbox(label="Multiband Ensemble", value=False, info="Enhance high frequencies"), gr.Slider(500, 15000, value=9000, step=500, label="Crossover Frequency (Hz)", info="For multiband processing", visible=True) ], outputs=gr.Audio(type="numpy", label="Output Audio"), title="AudioSR", description="Audio Super Resolution with AudioSR" ) iface.launch(share=False)