Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Distortion banks for the PS and PM computations. | |
| """ | |
| import librosa | |
| import numpy as np | |
| from numpy.fft import irfft, rfft, rfftfreq | |
| from scipy.signal import butter, filtfilt, lfilter | |
| from config import ENERGY_WIN_MS, EPS, SR | |
def sig_stats(x):
    """Return ``(peak, rms, p95)`` amplitude statistics of *x*.

    ``peak`` is the maximum absolute sample, ``rms`` the root-mean-square
    value and ``p95`` the 95th percentile of the absolute samples. Each value
    is floored at ``EPS`` so callers may divide by it without a zero check.
    """
    magnitude = np.abs(x)
    peak = max(np.max(magnitude), EPS)
    rms = max(np.sqrt(np.mean(x**2)), EPS)
    p95 = max(np.percentile(magnitude, 95), EPS)
    return peak, rms, p95
def frame_distortions(
    frame,
    sr,
    distortion_keys,
    notch_freqs=None,
    low_cutoffs=None,
    high_cutoffs=None,
    frame_start=0,
):
    """Generate the per-frame distortion bank used by the PM computation.

    Parameters
    ----------
    frame : np.ndarray
        Samples of one analysis frame (presumably 1-D float audio - TODO
        confirm against callers).
    sr : int or float
        Sample rate in Hz.
    distortion_keys : collection of str, or the string ``"all"``
        Distortion families to generate; ``"all"`` enables every family.
        NOTE: a plain string other than ``"all"`` is tested with substring
        semantics, so e.g. ``"noisegate"`` also enables ``"noise"``.
    notch_freqs : list of float, optional
        Centre frequencies (Hz) for the notch family; ``None`` means none.
    low_cutoffs, high_cutoffs : list of float, optional
        Cutoff frequencies (Hz) for the lowpass / highpass families. Cutoffs
        that a digital Butterworth design cannot realise (``Wn`` outside
        ``(0, 1)``) are skipped instead of raising.
    frame_start : int
        Index of the frame's first sample within the full signal. Used to put
        time-dependent effects (tremolo gain, added tones) on the global time
        axis so they line up across consecutive frames.

    Returns
    -------
    dict[str, np.ndarray]
        Mapping from distortion label to the distorted frame.
    """
    notch_freqs = [] if notch_freqs is None else notch_freqs
    low_cutoffs = [] if low_cutoffs is None else low_cutoffs
    high_cutoffs = [] if high_cutoffs is None else high_cutoffs

    def wants(key):
        # Membership test first, then the "all" wildcard.
        return (key in distortion_keys) or distortion_keys == "all"

    distortions = {}
    A_pk, A_rms, A_95 = sig_stats(frame)
    frame_len = len(frame)
    nyq = sr / 2
    X = rfft(frame)
    freqs = rfftfreq(frame_len, 1 / sr)

    if wants("notch"):
        bw = 60.0  # half-width (Hz) of the band zeroed around each centre
        for f0 in notch_freqs:
            Y = X.copy()
            Y[(freqs > f0 - bw) & (freqs < f0 + bw)] = 0
            distortions[f"Notch_{int(round(f0))}Hz"] = irfft(Y, n=frame_len)

    if wants("comb"):
        # Adds the sample D positions *ahead* of each output sample.
        # A 15 ms delay used to be listed without a matching decay, so zip()
        # never produced it; add a (15, decay) pair here to enable it.
        for d_ms, decay in ((2.5, 0.4), (5, 0.5), (7.5, 0.6), (10, 0.7), (12.5, 0.9)):
            D = int(sr * d_ms / 1000)
            if D <= 0 or D >= frame_len:
                continue  # D == 0 would make out[:-0] an empty slice
            out = frame.copy()
            out[:-D] += decay * frame[D:]
            distortions[f"Comb_{int(d_ms)}ms"] = out

    if wants("tremolo"):
        depth = 1.0
        # One gain per frame, sampled from the LFO at the frame's centre time.
        t_centre = (frame_start + 0.5 * frame_len) / sr
        for r_hz in (1, 2, 4, 6):
            mod = (1 - depth) + depth * 0.5 * (1 + np.sin(2 * np.pi * r_hz * t_centre))
            distortions[f"Tremolo_{r_hz}Hz"] = frame * mod

    if wants("noise"):
        low_norm = 20 / nyq
        high_norm = min(min(20_000, 0.45 * sr) / nyq, 0.99)
        b_band, a_band = butter(5, [low_norm, high_norm], btype="band")
        rms_sig = np.sqrt(np.mean(frame**2))
        # NOTE(review): noise is never quieter than the 15 dB SNR level, so
        # the 20 dB and 25 dB entries get the same noise level as 15 dB.
        # Confirm this cap is intended.
        min_noise_rms = rms_sig / np.sqrt(10 ** (15 / 10))

        def add_noise(snr_db, color):
            n = np.random.randn(frame_len)
            # NOTE(review): "pink" integrates white noise once (a random
            # walk, ~1/f^2 spectrum) and "brown" integrates it twice; the
            # labels do not follow the usual 1/f / 1/f^2 naming.
            if color == "pink":
                n = np.cumsum(n)
                n /= max(np.max(np.abs(n)), 1e-12)
            elif color == "brown":
                n = np.cumsum(np.cumsum(n))
                n /= max(np.max(np.abs(n)), 1e-12)
            n = lfilter(b_band, a_band, n)
            rms_n = np.sqrt(np.mean(n**2)) + 1e-12
            noise_rms = max(rms_sig / np.sqrt(10 ** (snr_db / 10)), min_noise_rms)
            n *= noise_rms / rms_n
            return frame + n

        for snr in (-15, -10, -5, 0, 5, 10, 15, 20, 25):
            for clr in ("white", "pink", "brown"):
                if snr in (-15, -10, -5) and clr == "white":
                    continue
                distortions[f"{clr.capitalize()}Noise_{snr}dB"] = add_noise(snr, clr)

    if wants("harmonic"):
        # Global time axis: keeps each added tone phase-continuous across
        # frames, so stitching frames together does not create clicks.
        t = (frame_start + np.arange(frame_len)) / sr
        for f_h, rel_amp in ((100, 0.4), (500, 0.6), (1000, 0.8), (4000, 1.0)):
            tone = (rel_amp * A_rms) * np.sin(2 * np.pi * f_h * t)
            distortions[f"Harmonic_{f_h}Hz"] = frame + tone

    if wants("reverb"):
        for tail_ms, decay in ((50, 0.3), (100, 0.5), (200, 0.7), (400, 0.9)):
            L = int(sr * tail_ms / 1000)
            if L <= 0 or L >= frame_len:
                continue  # an empty impulse response cannot be convolved
            irv = np.exp(-np.linspace(0, 6, L)) * decay
            distortions[f"Reverb_{tail_ms}ms"] = np.convolve(frame, irv)[:frame_len]

    if wants("noisegate"):
        for pct in (0.05, 0.10, 0.20, 0.40):
            thr = pct * A_95  # threshold relative to the frame's 95th percentile
            g = frame.copy()
            g[np.abs(g) < thr] = 0
            distortions[f"NoiseGate_{int(pct * 100)}pct"] = g

    if wants("pitch_shift"):
        n_fft = min(2048, frame_len // 2)
        for shift in (-4, -2, 2, 4):
            y = librosa.effects.pitch_shift(frame, sr=sr, n_steps=shift, n_fft=n_fft)
            distortions[f"PitchShift_{shift}st"] = y[:frame_len]

    if wants("lowpass"):
        for fc in low_cutoffs:
            # butter() requires 0 < Wn < 1; a silent or DC-dominated signal
            # can yield a 0 Hz cutoff.
            if fc <= 0 or fc >= nyq * 0.99:
                continue
            b, a = butter(6, fc / nyq, btype="low")
            distortions[f"Lowpass_{fc}Hz"] = filtfilt(b, a, frame)

    if wants("highpass"):
        for fc in high_cutoffs:
            # Cutoffs at or above Nyquist would give Wn >= 1 and raise.
            if fc <= 20 or fc >= nyq:
                continue
            b, a = butter(6, fc / nyq, btype="high")
            distortions[f"Highpass_{fc}Hz"] = filtfilt(b, a, frame)

    if wants("echo"):
        for delay_ms, amp in ((50, 0.4), (100, 0.5), (150, 0.7)):
            D = int(sr * delay_ms / 1000)
            if D <= 0 or D >= frame_len:
                continue  # [:-0] would produce an empty echo
            echo = np.pad(frame, (D, 0), "constant")[:-D] * amp
            distortions[f"Echo_{delay_ms}ms"] = frame + echo

    if wants("clipping"):
        for frac in (0.70, 0.50, 0.30):
            thr = frac * A_95
            distortions[f"Clipping_{frac:.2f}p95"] = np.clip(frame, -thr, thr)

    if wants("vibrato"):
        n_fft = min(2048, frame_len // 2)
        base_depth = 0.03 * (A_rms / A_pk)
        for rate_hz, scale in ((3, 1.0), (5, 1.3), (7, 1.6)):
            # rate_hz only labels the variant; a constant stretch factor
            # 1 + depth is applied to the whole frame.
            depth = np.clip(base_depth * scale, 0.01, 0.05)
            y = librosa.effects.time_stretch(frame, rate=1 + depth, n_fft=n_fft)
            distortions[f"Vibrato_{rate_hz}Hz"] = librosa.util.fix_length(
                y, size=frame_len
            )

    return distortions
def _select_notch_freqs(freqs, mag, n_candidates=60, min_sep_hz=300, max_notches=20):
    """Greedily pick strong spectral bins that are more than ``min_sep_hz`` apart.

    The ``n_candidates`` largest-magnitude bins are visited in descending
    magnitude order. A bin is kept when it is farther than ``min_sep_hz`` from
    every bin kept so far. Selection stops after ``max_notches`` bins.
    """
    top = np.argsort(mag)[::-1][:n_candidates]
    selected = []
    for f0 in freqs[top]:
        if all(abs(f0 - f_sel) > min_sep_hz for f_sel in selected):
            selected.append(float(f0))
            if len(selected) >= max_notches:
                break
    return selected


def _energy_cutoffs(freqs, power, quantiles, from_top=False):
    """Return sorted unique cutoffs (Hz, rounded to 100 Hz) at energy quantiles.

    For each ``q`` the cutoff is the first bin at which the cumulative energy
    reaches ``q`` of the total. Energy is accumulated upward from 0 Hz, or
    downward from the top bin when ``from_top`` is true.
    """
    total = power.sum()
    cum = np.cumsum(power[::-1] if from_top else power)
    last = len(freqs) - 1
    cutoffs = set()
    for q in quantiles:
        # Clamp so a NaN or rounding-inflated target cannot index past the end.
        idx = min(int(np.searchsorted(cum, q * total)), last)
        f_c = float(freqs[last - idx] if from_top else freqs[idx])
        cutoffs.add(round(f_c / 100.0) * 100)
    return sorted(cutoffs)


def apply_pm_distortions(ref, distortion_keys, sr=SR):
    """Apply the PM distortion bank frame by frame and stitch the results.

    The notch centres and the lowpass/highpass cutoffs are derived once from
    the whole zero-padded signal and shared by every frame. ``frame_distortions``
    is then run on consecutive non-overlapping frames of ``ENERGY_WIN_MS``
    milliseconds.

    Parameters
    ----------
    ref : np.ndarray
        Reference signal (presumably 1-D float audio - TODO confirm).
    distortion_keys : collection of str, or ``"all"``
        Forwarded unchanged to ``frame_distortions``.
    sr : int
        Sample rate in Hz.

    Returns
    -------
    list of np.ndarray
        One stitched signal per distortion label, in first-seen label order.
        Each signal has ``n_frames * frame_len`` samples, i.e. ``ref``
        zero-padded to a whole number of frames. NOTE(review): the output is
        not trimmed back to ``len(ref)``; confirm that callers expect the
        padded length. An empty ``ref`` yields an empty list.
    """
    if len(ref) == 0:
        return []  # rfft() cannot transform an empty signal

    frame_len = int(ENERGY_WIN_MS * sr / 1000)
    n_frames = int(np.ceil(len(ref) / frame_len))
    pad_len = n_frames * frame_len - len(ref)
    ref_padded = (
        np.concatenate([ref, np.zeros(pad_len, dtype=ref.dtype)]) if pad_len else ref
    )

    X_full = rfft(ref_padded)
    freqs_f = rfftfreq(len(ref_padded), 1 / sr)
    mag_full = np.abs(X_full)

    # Notch centres: strong peaks between 80 Hz and 0.45 * sr.
    valid = (freqs_f > 80) & (freqs_f < 0.45 * sr)
    selected_notch_freqs = _select_notch_freqs(freqs_f[valid], mag_full[valid])

    # Cutoffs at cumulative-energy quantiles. Drop values no digital
    # Butterworth design can realise: a silent input puts the lowpass
    # cutoff at 0 Hz and the highpass cutoff at Nyquist.
    power = mag_full**2
    lowpass_cutoffs = [
        fc
        for fc in _energy_cutoffs(freqs_f, power, (0.50, 0.70, 0.85, 0.95))
        if fc > 0
    ]
    highpass_cutoffs = [
        fc
        for fc in _energy_cutoffs(
            freqs_f, power, (0.05, 0.15, 0.30, 0.50), from_top=True
        )
        if fc < sr / 2
    ]

    out = {}
    for f in range(n_frames):
        start, end = f * frame_len, (f + 1) * frame_len
        frame_dists = frame_distortions(
            ref_padded[start:end],
            sr,
            distortion_keys,
            notch_freqs=selected_notch_freqs,
            low_cutoffs=lowpass_cutoffs,
            high_cutoffs=highpass_cutoffs,
            frame_start=start,
        )
        for lbl, sig in frame_dists.items():
            if lbl not in out:
                out[lbl] = np.zeros_like(ref_padded)
            out[lbl][start:end] = sig
    return list(out.values())
def apply_ps_distortions(ref, distortion_keys, sr=SR):
    """Generate the whole-signal distortion bank used by the PS computation.

    Unlike the PM bank, every distortion is applied to the full signal at
    once, with fixed (mostly absolute) parameters.

    Parameters
    ----------
    ref : np.ndarray
        Reference signal (presumably 1-D float audio - TODO confirm).
    distortion_keys : collection of str, or the string ``"all"``
        Distortion families to generate; ``"all"`` enables every family.
        NOTE: a plain string other than ``"all"`` is tested with substring
        semantics, so e.g. ``"noisegate"`` also enables ``"noise"``.
    sr : int
        Sample rate in Hz.

    Returns
    -------
    list of np.ndarray
        Distorted copies of *ref*, in family order (labels are discarded).
        An empty *ref* yields an empty list.
    """
    if len(ref) == 0:
        return []  # rfft() cannot transform an empty signal

    def wants(key):
        # Membership test first, then the "all" wildcard.
        return (key in distortion_keys) or distortion_keys == "all"

    n_samples = len(ref)
    distortions = {}
    X = rfft(ref)
    freqs = rfftfreq(n_samples, 1 / sr)
    t = np.arange(n_samples) / sr

    if wants("notch"):
        for c in (500, 1000, 2000, 4000, 8000):
            Y = X.copy()
            Y[(freqs > c - 50) & (freqs < c + 50)] = 0
            distortions[f"Notch_{c}Hz"] = irfft(Y, n=n_samples)

    if wants("comb"):
        # A 15 ms delay used to be listed without a matching decay, so zip()
        # never produced it; add a (15, decay) pair here to enable it.
        for d, decay in ((2.5, 0.4), (5, 0.5), (7.5, 0.6), (10, 0.7), (12.5, 0.9)):
            D = int(sr * d / 1000)
            if D <= 0 or D >= n_samples:
                continue  # D == 0 would make cpy[:-0] an empty slice
            cpy = ref.copy()
            cpy[:-D] += decay * ref[D:]
            distortions[f"Comb_{int(d)}ms"] = cpy

    if wants("tremolo"):
        for r, depth in ((1, 0.3), (2, 0.5), (4, 0.8), (6, 1.0)):
            mod = (1 - depth) + depth * 0.5 * (1 + np.sin(2 * np.pi * r * t))
            distortions[f"Tremolo_{r}Hz"] = ref * mod

    if wants("noise"):
        rms_sig = np.sqrt(np.mean(ref**2))

        def add_noise(snr_db, color):
            n = np.random.randn(n_samples)
            # NOTE(review): "pink" integrates white noise once (a random
            # walk, ~1/f^2 spectrum) and "brown" integrates it twice; the
            # labels do not follow the usual 1/f / 1/f^2 naming.
            if color == "pink":
                n = np.cumsum(n)
            elif color == "brown":
                n = np.cumsum(np.cumsum(n))
            # Scale by the noise's measured RMS so that the achieved SNR
            # matches the dB label for every colour. Peak-normalised
            # integrated noise has RMS < 1, which overstated the SNR.
            rms_n = np.sqrt(np.mean(n**2)) + 1e-12
            noise_rms = rms_sig / np.sqrt(10 ** (snr_db / 10))
            return ref + n * (noise_rms / rms_n)

        for snr in (-15, -10, -5, 0, 5, 10, 15, 20, 25):
            for clr in ("white", "pink", "brown"):
                if snr in (-15, -10, -5) and clr == "white":
                    continue
                distortions[f"{clr.capitalize()}Noise_{snr}dB"] = add_noise(snr, clr)

    if wants("harmonic"):
        for f_h, amp in ((100, 0.02), (500, 0.03), (1000, 0.05), (4000, 0.08)):
            tone = amp * np.sin(2 * np.pi * f_h * t)
            distortions[f"Harmonic_{f_h}Hz"] = ref + tone

    if wants("reverb"):
        # A fifth decay (1.1) used to be listed without a matching tail.
        for tail_ms, decay in ((5, 0.3), (10, 0.5), (15, 0.7), (20, 0.9)):
            L = int(sr * tail_ms / 1000)
            if L <= 0 or L >= n_samples:
                continue  # an empty impulse response cannot be convolved
            irv = np.exp(-np.linspace(0, 3, L)) * decay
            distortions[f"Reverb_{tail_ms}ms"] = np.convolve(ref, irv)[:n_samples]

    if wants("noisegate"):
        for thr in (0.005, 0.01, 0.02, 0.04):
            g = ref.copy()
            g[np.abs(g) < thr] = 0
            distortions[f"NoiseGate_{thr}"] = g

    if wants("pitch_shift"):
        n_fft = min(2048, n_samples // 2)
        for shift in (-4, -2, 2, 4):
            shifted = librosa.effects.pitch_shift(
                y=ref, sr=sr, n_steps=shift, n_fft=n_fft
            )
            distortions[f"PitchShift_{shift}st"] = shifted[:n_samples]

    if wants("lowpass"):
        for freq in (2000, 3000, 4000, 6000):
            if freq >= (sr / 2):
                continue
            b, a = butter(4, freq / (sr / 2), "low")
            distortions[f"Lowpass_{freq}Hz"] = filtfilt(b, a, ref)

    if wants("highpass"):
        for freq in (100, 300, 500, 800):
            if freq >= (sr / 2):
                continue
            b, a = butter(4, freq / (sr / 2), "high")
            distortions[f"Highpass_{freq}Hz"] = filtfilt(b, a, ref)

    if wants("echo"):
        # A 20 ms delay used to be listed without a matching amplitude, so
        # zip() never produced it; add a (20, amp) pair here to enable it.
        for delay_ms, amp in ((5, 0.3), (10, 0.5), (15, 0.7)):
            delay = int(sr * delay_ms / 1000)
            if delay <= 0 or delay >= n_samples:
                continue  # [:-0] would produce an empty echo
            echo = np.pad(ref, (delay, 0), "constant")[:-delay] * amp
            distortions[f"Echo_{delay_ms}ms"] = ref + echo

    if wants("clipping"):
        for thr in (0.3, 0.5, 0.7):
            distortions[f"Clipping_{thr}"] = np.clip(ref, -thr, thr)

    if wants("vibrato"):
        n_fft = min(2048, n_samples // 2)
        for rate, depth in ((3, 0.001), (5, 0.002), (7, 0.003)):
            # NOTE(review): the stretch factor is 1 + the *mean* of a
            # sinusoidal LFO. That mean is ~0 whenever ref spans whole LFO
            # periods, so this is close to an identity stretch rather than a
            # time-varying vibrato. Confirm intent.
            lfo = np.sin(2 * np.pi * rate * t) * depth
            stretched = librosa.effects.time_stretch(
                ref, rate=1 + float(lfo.mean()), n_fft=n_fft
            )
            distortions[f"Vibrato_{rate}Hz"] = librosa.util.fix_length(
                stretched, size=n_samples
            )

    return list(distortions.values())