MAPSS-measures / distortions.py
AIvry's picture
Upload 11 files
b759ccc verified
"""
Distortion banks for the PS and PM computations.
"""
import librosa
import numpy as np
from numpy.fft import irfft, rfft, rfftfreq
from scipy.signal import butter, filtfilt, lfilter
from config import ENERGY_WIN_MS, EPS, SR
def sig_stats(x):
    """Return the peak, RMS and 95th-percentile magnitude of ``x``.

    Each statistic is clamped from below by ``EPS`` so that callers can
    divide by it or scale with it without hitting zero.

    Args:
        x: Array of samples.

    Returns:
        A 3-tuple ``(peak, rms, p95)``.
    """
    magnitude = np.abs(x)
    raw_stats = (
        np.max(magnitude),
        np.sqrt(np.mean(x**2)),
        np.percentile(magnitude, 95),
    )
    # Apply the EPS floor to every statistic in one pass.
    return tuple(max(value, EPS) for value in raw_stats)
def frame_distortions(
    frame,
    sr,
    distortion_keys,
    notch_freqs=None,
    low_cutoffs=None,
    high_cutoffs=None,
    frame_start=0,
):
    """Build the bank of distorted copies of a single frame.

    Args:
        frame: 1-D array of samples for one frame (``.copy()`` and NumPy
            arithmetic are applied to it).
        sr: Sampling rate in Hz.
        distortion_keys: The string ``"all"`` to generate every family, a
            collection of family names, or a single family name as a string.
            Recognised names: ``notch``, ``comb``, ``tremolo``, ``noise``,
            ``harmonic``, ``reverb``, ``noisegate``, ``pitch_shift``,
            ``lowpass``, ``highpass``, ``echo``, ``clipping``, ``vibrato``.
        notch_freqs: Centre frequencies in Hz for the notch family.
        low_cutoffs: Cutoff frequencies in Hz for the lowpass family.
        high_cutoffs: Cutoff frequencies in Hz for the highpass family.
        frame_start: Index of the frame's first sample within the full
            signal; only the tremolo family reads it.

    Returns:
        A dict mapping a distortion label to the distorted frame, in the
        order the distortions were generated.
    """
    notch_freqs = [] if notch_freqs is None else notch_freqs
    low_cutoffs = [] if low_cutoffs is None else low_cutoffs
    high_cutoffs = [] if high_cutoffs is None else high_cutoffs

    # A bare string names exactly one family (or "all"). Wrapping it in a
    # tuple prevents substring hits such as "noise" matching "noisegate".
    if isinstance(distortion_keys, str):
        use_all = distortion_keys == "all"
        requested = (distortion_keys,)
    else:
        use_all = False
        requested = distortion_keys

    def wanted(family):
        return use_all or family in requested

    distortions = {}
    A_pk, A_rms, A_95 = sig_stats(frame)
    frame_len = len(frame)
    nyq = sr / 2
    X = rfft(frame)
    freqs = rfftfreq(frame_len, 1 / sr)
    t = np.arange(frame_len) / sr

    if wanted("notch"):
        bw = 60.0  # half-width in Hz of the zeroed band around each centre
        for f0 in notch_freqs:
            Y = X.copy()
            Y[(freqs > f0 - bw) & (freqs < f0 + bw)] = 0
            distortions[f"Notch_{int(round(f0))}Hz"] = irfft(Y, n=frame_len)

    if wanted("comb"):
        # NOTE: zip() stops at the shorter list, so the 15 ms delay is never
        # generated. Kept as-is so the bank does not change.
        for d_ms, decay in zip(
            [2.5, 5, 7.5, 10, 12.5, 15], [0.4, 0.5, 0.6, 0.7, 0.9]
        ):
            D = int(sr * d_ms / 1000)
            # D == 0 would make out[:-D] empty and break the in-place add.
            if not 0 < D < frame_len:
                continue
            out = frame.copy()
            out[:-D] += decay * frame[D:]
            distortions[f"Comb_{int(d_ms)}ms"] = out

    if wanted("tremolo"):
        depth = 1.0
        # The modulator is sampled once, at the frame centre, so the gain is
        # a single scalar applied to the whole frame.
        t_centre = (frame_start + 0.5 * len(frame)) / sr
        for r_hz in [1, 2, 4, 6]:
            mod = (1 - depth) + depth * 0.5 * (
                1 + np.sin(2 * np.pi * r_hz * t_centre)
            )
            distortions[f"Tremolo_{r_hz}Hz"] = frame * mod

    if wanted("noise"):
        low_norm = 20 / nyq
        high_freq = min(20_000, 0.45 * sr)
        high_norm = min(high_freq / nyq, 0.99)
        b_band, a_band = butter(5, [low_norm, high_norm], btype="band")

        def add_noise(sig, snr_db, color="white"):
            """Add band-limited noise of the given colour at ``snr_db``."""
            nl_target = 10 ** (snr_db / 10)
            n = np.random.randn(len(sig))
            if color == "pink":
                n = np.cumsum(n)
                n /= max(np.max(np.abs(n)), 1e-12)
            elif color == "brown":
                n = np.cumsum(np.cumsum(n))
                n /= max(np.max(np.abs(n)), 1e-12)
            n = lfilter(b_band, a_band, n)
            rms_sig = np.sqrt(np.mean(sig**2))
            rms_n = np.sqrt(np.mean(n**2)) + 1e-12
            noise_rms = rms_sig / np.sqrt(nl_target)
            # Floor the noise level at the 15 dB-SNR level: requests above
            # 15 dB therefore get the same noise RMS as the 15 dB request.
            noise_rms = max(noise_rms, rms_sig / np.sqrt(10 ** (15 / 10)))
            n *= noise_rms / rms_n
            return sig + n

        for snr in [-15, -10, -5, 0, 5, 10, 15, 20, 25]:
            for clr in ["white", "pink", "brown"]:
                # White noise is not generated at negative SNRs.
                if (snr in [-15, -10, -5]) and (clr == "white"):
                    continue
                distortions[f"{clr.capitalize()}Noise_{snr}dB"] = add_noise(
                    frame, snr, clr
                )

    if wanted("harmonic"):
        # Tone amplitude is relative to the frame's RMS.
        for f_h, rel_amp in zip([100, 500, 1000, 4000], [0.4, 0.6, 0.8, 1.0]):
            tone = (rel_amp * A_rms) * np.sin(2 * np.pi * f_h * t)
            distortions[f"Harmonic_{f_h}Hz"] = frame + tone

    if wanted("reverb"):
        for tail_ms, decay in zip([50, 100, 200, 400], [0.3, 0.5, 0.7, 0.9]):
            L = int(sr * tail_ms / 1000)
            # An empty impulse response would make np.convolve raise.
            if not 0 < L < frame_len:
                continue
            irv = np.exp(-np.linspace(0, 6, L)) * decay
            distortions[f"Reverb_{tail_ms}ms"] = np.convolve(frame, irv)[
                :frame_len
            ]

    if wanted("noisegate"):
        # Thresholds are fractions of the frame's 95th-percentile magnitude.
        for pct in [0.05, 0.10, 0.20, 0.40]:
            thr = pct * A_95
            g = frame.copy()
            g[np.abs(g) < thr] = 0
            distortions[f"NoiseGate_{int(pct * 100)}pct"] = g

    if wanted("pitch_shift"):
        n_fft = min(2048, frame_len // 2)
        for shift in [-4, -2, 2, 4]:
            y = librosa.effects.pitch_shift(
                frame, sr=sr, n_steps=shift, n_fft=n_fft
            )
            distortions[f"PitchShift_{shift}st"] = y[:frame_len]

    if wanted("lowpass"):
        for fc in low_cutoffs:
            # butter() needs 0 < Wn < 1: skip a 0 Hz cutoff (would raise)
            # as well as cutoffs too close to Nyquist.
            if fc <= 0 or fc >= nyq * 0.99:
                continue
            b, a = butter(6, fc / nyq, btype="low")
            distortions[f"Lowpass_{fc}Hz"] = filtfilt(b, a, frame)

    if wanted("highpass"):
        for fc in high_cutoffs:
            # Skip near-DC cutoffs and cutoffs at/above Nyquist, where
            # butter() raises because the normalised frequency is >= 1.
            if fc <= 20 or fc >= nyq:
                continue
            b, a = butter(6, fc / nyq, btype="high")
            distortions[f"Highpass_{fc}Hz"] = filtfilt(b, a, frame)

    if wanted("echo"):
        for delay_ms, amp in zip([50, 100, 150], [0.4, 0.5, 0.7]):
            D = int(sr * delay_ms / 1000)
            # D == 0 would make the [:-D] slice empty.
            if not 0 < D < frame_len:
                continue
            echo = np.pad(frame, (D, 0), "constant")[:-D] * amp
            distortions[f"Echo_{delay_ms}ms"] = frame + echo

    if wanted("clipping"):
        # Clip levels are fractions of the 95th-percentile magnitude.
        for frac in [0.70, 0.50, 0.30]:
            thr = frac * A_95
            distortions[f"Clipping_{frac:.2f}p95"] = np.clip(frame, -thr, thr)

    if wanted("vibrato"):
        n_fft = min(2048, frame_len // 2)
        base_depth = 0.03 * (A_rms / A_pk)
        # rate_hz only labels the output; the frame is time-stretched by a
        # constant factor of 1 + depth and then fixed back to frame_len.
        for rate_hz, scale in zip([3, 5, 7], [1.0, 1.3, 1.6]):
            depth = np.clip(base_depth * scale, 0.01, 0.05)
            y = librosa.effects.time_stretch(frame, rate=1 + depth, n_fft=n_fft)
            distortions[f"Vibrato_{rate_hz}Hz"] = librosa.util.fix_length(
                y, size=frame_len
            )

    return distortions
def apply_pm_distortions(ref, distortion_keys, sr=SR):
    """Apply the frame-wise (PM) distortion bank to a reference signal.

    The signal is zero-padded to a whole number of frames of
    ``ENERGY_WIN_MS`` milliseconds. Notch centres and lowpass/highpass
    cutoffs are derived once from the spectrum of the whole padded signal,
    then every frame is distorted independently by ``frame_distortions`` and
    the results are written back at the frame's position.

    Args:
        ref: 1-D NumPy array with the reference signal.
        distortion_keys: ``"all"`` or the distortion families to generate;
            forwarded unchanged to ``frame_distortions``.
        sr: Sampling rate in Hz.

    Returns:
        A list of distorted signals, one per distortion label, in the order
        the labels were first produced. Each has the zero-padded length
        (a multiple of the frame length), not ``len(ref)``. An empty ``ref``
        yields an empty list.
    """
    # No samples means no frames; return early instead of failing in rfft().
    if len(ref) == 0:
        return []

    frame_len = int(ENERGY_WIN_MS * sr / 1000)
    n_frames = int(np.ceil(len(ref) / frame_len))
    pad_len = n_frames * frame_len - len(ref)
    ref_padded = (
        np.concatenate([ref, np.zeros(pad_len, dtype=ref.dtype)])
        if pad_len
        else ref
    )

    X_full = rfft(ref_padded)
    freqs_f = rfftfreq(len(ref_padded), 1 / sr)
    mag_full = np.abs(X_full)

    # Notch centres: take the 60 strongest bins between 80 Hz and 0.45 * sr,
    # walk them from strongest to weakest and keep up to 20 that are more
    # than 300 Hz away from every centre already kept.
    valid = (freqs_f > 80) & (freqs_f < 0.45 * sr)
    cand_indices = np.argsort(mag_full[valid])[-60:]
    cand_freqs = freqs_f[valid][cand_indices]
    cand_freqs = cand_freqs[np.argsort(mag_full[valid][cand_indices])[::-1]]
    selected_notch_freqs = []
    for f0 in cand_freqs:
        if all(abs(f0 - f_sel) > 300 for f_sel in selected_notch_freqs):
            selected_notch_freqs.append(float(f0))
        if len(selected_notch_freqs) >= 20:
            break

    # Filter cutoffs sit at fixed quantiles of the cumulative spectral
    # energy, rounded to the nearest 100 Hz.
    mag2 = mag_full**2
    total_p = mag2.sum()

    cum_low = np.cumsum(mag2)  # energy accumulated from DC upwards
    lowpass_cutoffs = []
    for q in [0.50, 0.70, 0.85, 0.95]:
        idx = np.searchsorted(cum_low, q * total_p)
        f_c = float(freqs_f[idx])
        lowpass_cutoffs.append(round(f_c / 100.0) * 100)

    cum_high = np.cumsum(mag2[::-1])  # energy accumulated from Nyquist down
    highpass_cutoffs = []
    for q in [0.05, 0.15, 0.30, 0.50]:
        idx = np.searchsorted(cum_high, q * total_p)
        f_c = float(freqs_f[-1 - idx])
        highpass_cutoffs.append(round(f_c / 100.0) * 100)

    # Rounding (or a silent / very low-frequency signal) can give 0 Hz, and
    # energy near the top of the band can round up to Nyquist. Both make the
    # Butterworth design raise, so drop them; also de-duplicate and sort.
    nyq = sr / 2
    lowpass_cutoffs = sorted({fc for fc in lowpass_cutoffs if 0 < fc < nyq})
    highpass_cutoffs = sorted({fc for fc in highpass_cutoffs if 0 < fc < nyq})

    out = {}
    for f in range(n_frames):
        start, end = f * frame_len, (f + 1) * frame_len
        frame_dists = frame_distortions(
            ref_padded[start:end],
            sr,
            distortion_keys,
            notch_freqs=selected_notch_freqs,
            low_cutoffs=lowpass_cutoffs,
            high_cutoffs=highpass_cutoffs,
            frame_start=start,
        )
        for lbl, sig in frame_dists.items():
            # Allocate the full-length buffer the first time a label shows up.
            if lbl not in out:
                out[lbl] = np.zeros_like(ref_padded)
            out[lbl][start:end] = sig
    return list(out.values())
def apply_ps_distortions(ref, distortion_keys, sr=SR):
    """Apply the whole-signal (PS) distortion bank to a reference signal.

    Unlike the frame-wise bank, every distortion here uses fixed parameters
    and is applied to the entire signal at once.

    Args:
        ref: 1-D NumPy array with the reference signal.
        distortion_keys: The string ``"all"`` to generate every family, a
            collection of family names, or a single family name as a string.
            Recognised names: ``notch``, ``comb``, ``tremolo``, ``noise``,
            ``harmonic``, ``reverb``, ``noisegate``, ``pitch_shift``,
            ``lowpass``, ``highpass``, ``echo``, ``clipping``, ``vibrato``.
        sr: Sampling rate in Hz.

    Returns:
        A list of distorted signals in generation order.
    """
    # A bare string names exactly one family (or "all"). Wrapping it in a
    # tuple prevents substring hits such as "noise" matching "noisegate".
    if isinstance(distortion_keys, str):
        use_all = distortion_keys == "all"
        requested = (distortion_keys,)
    else:
        use_all = False
        requested = distortion_keys

    def wanted(family):
        return use_all or family in requested

    n_samples = len(ref)
    distortions = {}
    X = rfft(ref)
    freqs = rfftfreq(n_samples, 1 / sr)
    t = np.arange(n_samples) / sr

    if wanted("notch"):
        # Zero a band of +/- 50 Hz around each fixed centre frequency.
        for c in [500, 1000, 2000, 4000, 8000]:
            Y = X.copy()
            Y[(freqs > c - 50) & (freqs < c + 50)] = 0
            distortions[f"Notch_{c}Hz"] = irfft(Y, n=n_samples)

    if wanted("comb"):
        # NOTE: zip() stops at the shorter list, so the 15 ms delay is never
        # generated. Kept as-is so the bank does not change.
        for d, decay in zip(
            [2.5, 5, 7.5, 10, 12.5, 15], [0.4, 0.5, 0.6, 0.7, 0.9]
        ):
            D = int(sr * d / 1000)
            # D == 0 would make cpy[:-D] empty and break the in-place add.
            if not 0 < D < n_samples:
                continue
            cpy = ref.copy()
            cpy[:-D] += decay * ref[D:]
            distortions[f"Comb_{int(d)}ms"] = cpy

    if wanted("tremolo"):
        for r, depth in zip([1, 2, 4, 6], [0.3, 0.5, 0.8, 1.0]):
            mod = (1 - depth) + depth * 0.5 * (1 + np.sin(2 * np.pi * r * t))
            distortions[f"Tremolo_{r}Hz"] = ref * mod

    if wanted("noise"):

        def add_noise(signal, snr_db, color):
            """Add noise of the given colour scaled for ``snr_db``."""
            rms = np.sqrt(np.mean(signal**2))
            nl = 10 ** (snr_db / 10)
            noise_rms = rms / np.sqrt(nl)
            n = np.random.randn(len(signal))
            # Pink/brown noise is normalised to unit peak, so after scaling
            # its peak magnitude (not its RMS) equals noise_rms.
            if color == "pink":
                n = np.cumsum(n)
                n /= max(np.max(np.abs(n)), 1e-12)
            elif color == "brown":
                n = np.cumsum(np.cumsum(n))
                n /= max(np.max(np.abs(n)), 1e-12)
            return signal + noise_rms * n

        for snr in [-15, -10, -5, 0, 5, 10, 15, 20, 25]:
            for clr in ["white", "pink", "brown"]:
                # White noise is not generated at negative SNRs.
                if snr in [-15, -10, -5] and clr in ["white"]:
                    continue
                distortions[f"{clr.capitalize()}Noise_{snr}dB"] = add_noise(
                    ref, snr, clr
                )

    if wanted("harmonic"):
        # Tone amplitudes are absolute, not relative to the signal level.
        for f_h, amp in zip([100, 500, 1000, 4000], [0.02, 0.03, 0.05, 0.08]):
            tone = amp * np.sin(2 * np.pi * f_h * t)
            distortions[f"Harmonic_{f_h}Hz"] = ref + tone

    if wanted("reverb"):
        # The fifth decay value (1.1) has no matching tail and is unused.
        for tail_ms, decay in zip([5, 10, 15, 20], [0.3, 0.5, 0.7, 0.9, 1.1]):
            L = int(sr * tail_ms / 1000)
            # An empty impulse response would make np.convolve raise.
            if not 0 < L < n_samples:
                continue
            irv = np.exp(-np.linspace(0, 3, L)) * decay
            distortions[f"Reverb_{tail_ms}ms"] = np.convolve(ref, irv)[
                :n_samples
            ]

    if wanted("noisegate"):
        for thr in [0.005, 0.01, 0.02, 0.04]:
            g = ref.copy()
            g[np.abs(g) < thr] = 0
            distortions[f"NoiseGate_{thr}"] = g

    if wanted("pitch_shift"):
        n_fft = min(2048, n_samples // 2)
        for shift in [-4, -2, 2, 4]:
            shifted = librosa.effects.pitch_shift(
                y=ref, sr=sr, n_steps=shift, n_fft=n_fft
            )
            distortions[f"PitchShift_{shift}st"] = shifted[:n_samples]

    if wanted("lowpass"):
        for freq in [2000, 3000, 4000, 6000]:
            # Cutoffs at or above Nyquist cannot be designed; skip them.
            if freq >= (sr / 2):
                continue
            b, a = butter(4, freq / (sr / 2), "low")
            distortions[f"Lowpass_{freq}Hz"] = filtfilt(b, a, ref)

    if wanted("highpass"):
        for freq in [100, 300, 500, 800]:
            if freq >= (sr / 2):
                continue
            b, a = butter(4, freq / (sr / 2), "high")
            distortions[f"Highpass_{freq}Hz"] = filtfilt(b, a, ref)

    if wanted("echo"):
        # NOTE: zip() stops at the shorter list, so the 20 ms delay is never
        # generated. Kept as-is so the bank does not change.
        for delay_ms, amp in zip([5, 10, 15, 20], [0.3, 0.5, 0.7]):
            delay = int(sr * delay_ms / 1000)
            # delay == 0 would make the [:-delay] slice empty.
            if not 0 < delay < n_samples:
                continue
            echo = np.pad(ref, (delay, 0), "constant")[:-delay] * amp
            distortions[f"Echo_{delay_ms}ms"] = ref + echo

    if wanted("clipping"):
        for thr in [0.3, 0.5, 0.7]:
            distortions[f"Clipping_{thr}"] = np.clip(ref, -thr, thr)

    if wanted("vibrato"):
        n_fft = min(2048, n_samples // 2)
        for rate, depth in zip([3, 5, 7], [0.001, 0.002, 0.003]):
            # The stretch factor is 1 plus the mean of the sinusoidal
            # modulator taken over the whole signal.
            vibrato = np.sin(2 * np.pi * rate * t) * depth
            vibrato_signal = librosa.effects.time_stretch(
                ref, rate=1 + float(vibrato.mean()), n_fft=n_fft
            )
            distortions[f"Vibrato_{rate}Hz"] = librosa.util.fix_length(
                vibrato_signal, size=n_samples
            )

    return list(distortions.values())