import json
import random
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime

import librosa
import pandas as pd

from audio import (
    loudness_normalize,
    compute_speaker_activity_masks,
)
# NOTE(review): `torch`, `np`, `os`, `gc`, `warnings`, `Path` and the
# upper-case constants used below are not imported explicitly in this file;
# they presumably arrive through the two star imports -- confirm.
from config import *
from distortions import apply_pm_distortions, apply_ps_distortions
from metrics import (
    compute_pm,
    compute_ps,
    diffusion_map_torch,
    pm_ci_components_full,
    ps_ci_components_full,
)
from models import embed_batch, load_model
from utils import *


def _load_normalized(path):
    """Load the audio file at ``path`` (resampled to ``SR``), loudness-normalize it and return a torch tensor."""
    wav, _ = librosa.load(str(path), sr=SR)
    return torch.from_numpy(loudness_normalize(wav))


def _nan_frame_table(models, speakers, total_frames):
    """Return ``{model: {speaker: [nan] * total_frames}}`` with an independent list per speaker."""
    return {m: {s: [np.nan] * total_frames for s in speakers} for m in models}


def compute_mapss_measures(
    models,
    mixtures,
    *,
    systems=None,
    algos=None,
    experiment_id=None,
    layer=DEFAULT_LAYER,
    add_ci=DEFAULT_ADD_CI,
    alpha=DEFAULT_ALPHA,
    seed=42,
    on_missing="skip",
    verbose=False,
    max_gpus=None,
):
    """
    Compute MAPSS measures (PM, PS, and their errors). Data is saved to csv files.

    For every mixture, a directory ``<exp_root>/<mixture_id>`` is created and, per
    model, ``ps_scores_<model>.csv`` and ``pm_scores_<model>.csv`` are written
    (plus ``ci_<model>.csv`` when ``add_ci`` is true). ``params.json`` and
    ``manifest_canonical.json`` are written to the experiment root.

    :param models: backbone self-supervised models.
    :param mixtures: data to process from _read_manifest
    :param systems: specific systems (algos and data)
    :param algos: specific algorithms to use. When None, the algorithms listed
        for the first canonical mixture are used (sorted).
    :param experiment_id: user-specified name for experiment. When falsy, a
        ``%Y%m%d_%H%M%S`` timestamp is used.
    :param layer: transformer layer of model to consider
    :param add_ci: True will compute error radius and tail bounds. False will not.
    :param alpha: normalization factor of the diffusion maps. Lives in [0, 1].
    :param seed: random seed number.
    :param on_missing: "skip" when missing values or throw an "error".
    :param verbose: True will print process info to console during runtime. False will minimize it.
    :param max_gpus: maximal amount of GPUs the program tries to utilize in parallel.
    :return: path of the experiment directory the results were written to.
    :raises ValueError: if ``on_missing`` is not "skip"/"error", or (with
        ``on_missing="error"``) an algorithm's output count differs from the
        number of references of a mixture.
    :raises FileNotFoundError: with ``on_missing="error"``, if a speaker has no
        output assigned for an algorithm.
    """
    # Validate cheap arguments before touching any GPU resources.
    if on_missing not in {"skip", "error"}:
        raise ValueError("on_missing must be 'skip' or 'error'.")

    gpu_distributor = GPUWorkDistributor(max_gpus)
    ngpu = get_gpu_count(max_gpus)

    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    canon_mix = canonicalize_mixtures(mixtures, systems=systems)

    # One entry per (mixture, speaker); "outs" maps algorithm name -> output path.
    mixture_entries = []
    for m in canon_mix:
        entries = [
            {
                "id": m.speaker_ids[i],
                "ref": Path(refp),
                "mixture": m.mixture_id,
                "outs": {},
            }
            for i, refp in enumerate(m.refs)
        ]
        mixture_entries.append(entries)

    # Attach system outputs to speakers by position (outputs must follow the
    # reference order, see the NOTE printed below).
    for m, mix_entries in zip(canon_mix, mixture_entries):
        for algo, out_list in (m.systems or {}).items():
            if len(out_list) != len(mix_entries):
                msg = f"[{algo}] Number of outputs ({len(out_list)}) does not match number of references ({len(mix_entries)}) for mixture {m.mixture_id}"
                if on_missing == "error":
                    raise ValueError(msg)
                if verbose:
                    warnings.warn(msg + " Skipping this algorithm.")
                continue
            for idx, e in enumerate(mix_entries):
                e["outs"][algo] = out_list[idx]

    if algos is None:
        if canon_mix and canon_mix[0].systems:
            algos_to_run = sorted(canon_mix[0].systems)
        else:
            algos_to_run = []
    else:
        algos_to_run = list(algos)

    exp_id = experiment_id or datetime.now().strftime("%Y%m%d_%H%M%S")
    exp_root = os.path.join(RESULTS_ROOT, f"experiment_{exp_id}")
    os.makedirs(exp_root, exist_ok=True)

    params = {
        "models": models,
        "layer": layer,
        "add_ci": add_ci,
        "alpha": alpha,
        "seed": seed,
        "batch_size": BATCH_SIZE,
        "ngpu": ngpu,
        "max_gpus": max_gpus,
    }
    with open(os.path.join(exp_root, "params.json"), "w") as f:
        json.dump(params, f, indent=2)

    canon_struct = [
        {
            "mixture_id": m.mixture_id,
            "references": [str(p) for p in m.refs],
            "systems": {
                a: [str(p) for p in outs] for a, outs in (m.systems or {}).items()
            },
            "speaker_ids": m.speaker_ids,
        }
        for m in canon_mix
    ]
    with open(os.path.join(exp_root, "manifest_canonical.json"), "w") as f:
        json.dump(canon_struct, f, indent=2)

    print(f"Starting experiment {exp_id} with {ngpu} GPUs")
    print(f"Results will be saved to: {exp_root}")
    print("NOTE: Output files must be provided in the same order as reference files.")

    clear_gpu_memory()
    get_gpu_memory_info(verbose)

    flat_entries = [e for mix in mixture_entries for e in mix]

    if verbose:
        print("Loading reference signals...")
    all_refs = {}
    for e in flat_entries:
        all_refs[e["id"]] = _load_normalized(e["ref"])

    if verbose:
        print("Computing speaker activity masks...")
    win = int(ENERGY_WIN_MS * SR / 1000)
    hop = int(ENERGY_HOP_MS * SR / 1000)

    multi_speaker_masks_mix = []
    individual_speaker_masks_mix = []
    total_frames_per_mix = []
    for i, mix in enumerate(mixture_entries):
        if verbose:
            print(f" Computing masks for mixture {i + 1}/{len(mixture_entries)}")
        if ngpu > 0:
            with torch.cuda.device(0):
                refs_for_mix = [all_refs[e["id"]].cuda() for e in mix]
                multi_mask, individual_masks = compute_speaker_activity_masks(
                    refs_for_mix, win, hop
                )
                multi_mask_cpu = multi_mask.cpu()
                individual_masks_cpu = [m.cpu() for m in individual_masks]
            # Drop every reference to the GPU copies so empty_cache() can
            # actually return the memory (deleting a loop variable does not).
            del refs_for_mix, multi_mask, individual_masks
            torch.cuda.empty_cache()
        else:
            refs_for_mix = [all_refs[e["id"]].cpu() for e in mix]
            multi_mask, individual_masks = compute_speaker_activity_masks(
                refs_for_mix, win, hop
            )
            multi_mask_cpu = multi_mask.cpu()
            individual_masks_cpu = [m.cpu() for m in individual_masks]
        multi_speaker_masks_mix.append(multi_mask_cpu)
        individual_speaker_masks_mix.append(individual_masks_cpu)
        total_frames_per_mix.append(multi_mask_cpu.shape[0])

    all_mixture_results = {}

    for mix_idx, (mix_canon, mix_entries) in enumerate(
        zip(canon_mix, mixture_entries)
    ):
        mixture_id = mix_canon.mixture_id
        all_mixture_results[mixture_id] = {}
        total_frames = total_frames_per_mix[mix_idx]
        mixture_speakers = [e["id"] for e in mix_entries]

        # Bound up front so the cleanup after the loop is safe even when
        # there is no algorithm to run.
        all_outs = {}

        for algo_idx, algo in enumerate(algos_to_run):
            if verbose:
                print(
                    f"\nProcessing Mixture {mixture_id}, Algorithm {algo_idx + 1}/{len(algos_to_run)}: {algo}"
                )

            all_outs = {}
            missing = []
            for e in mix_entries:
                assigned_path = e.get("outs", {}).get(algo)
                if assigned_path is None:
                    missing.append((e["mixture"], e["id"]))
                    continue
                all_outs[e["id"]] = _load_normalized(assigned_path)

            if missing:
                msg = f"[{algo}] missing outputs for {len(missing)} speaker(s) in mixture {mixture_id}"
                if on_missing == "error":
                    raise FileNotFoundError(msg)
                if verbose:
                    warnings.warn(msg + " Skipping those speakers.")

            if not all_outs:
                if verbose:
                    warnings.warn(
                        f"[{algo}] No outputs for mixture {mixture_id}. Skipping."
                    )
                continue

            if algo not in all_mixture_results[mixture_id]:
                all_mixture_results[mixture_id][algo] = {}

            ps_frames = _nan_frame_table(models, mixture_speakers, total_frames)
            pm_frames = _nan_frame_table(models, mixture_speakers, total_frames)
            ps_bias_frames = _nan_frame_table(models, mixture_speakers, total_frames)
            ps_prob_frames = _nan_frame_table(models, mixture_speakers, total_frames)
            pm_bias_frames = _nan_frame_table(models, mixture_speakers, total_frames)
            pm_prob_frames = _nan_frame_table(models, mixture_speakers, total_frames)

            # Speakers that actually have an output for this algorithm, together
            # with THEIR activity masks. Masks are selected by the speaker's
            # position in the full mixture so they stay aligned when a speaker
            # is skipped (assumes compute_speaker_activity_masks returns one
            # mask per reference, in input order -- TODO confirm).
            present_positions = [
                pos for pos, e in enumerate(mix_entries) if e["id"] in all_outs
            ]
            speakers_this_mix = [mix_entries[pos] for pos in present_positions]
            speaker_ids = [e["id"] for e in speakers_this_mix]
            multi_speaker_mask = multi_speaker_masks_mix[mix_idx]
            individual_masks = [
                individual_speaker_masks_mix[mix_idx][pos]
                for pos in present_positions
            ]
            valid_frame_indices = torch.where(multi_speaker_mask)[0].tolist()

            for model_idx, mname in enumerate(models):
                if verbose:
                    print(f" Processing Model {model_idx + 1}/{len(models)}: {mname}")

                for metric_type in ["PS", "PM"]:
                    clear_gpu_memory()
                    gc.collect()
                    model_wrapper, layer_eff = load_model(mname, layer, max_gpus)
                    get_gpu_memory_info(verbose)

                    if verbose:
                        print(f" Processing {metric_type} for mixture {mixture_id}")

                    # Per speaker: reference, system output, then the
                    # metric-specific distortions of the reference.
                    speaker_signals = {}
                    speaker_labels = {}
                    for e in speakers_this_mix:
                        s = e["id"]
                        ref_np = all_refs[s].numpy()
                        if metric_type == "PS":
                            raw_dists = apply_ps_distortions(ref_np, "all")
                        else:
                            raw_dists = apply_pm_distortions(ref_np, "all")
                        dists = [loudness_normalize(d) for d in raw_dists]
                        lbls = ["ref", "out"] + [f"d{i}" for i in range(len(dists))]
                        speaker_signals[s] = [
                            all_refs[s].numpy(),
                            all_outs[s].numpy(),
                        ] + dists
                        speaker_labels[s] = [f"{s}-{l}" for l in lbls]

                    all_embeddings = {}
                    batch_size = min(2, BATCH_SIZE)
                    for s, sigs in speaker_signals.items():
                        masks = [multi_speaker_mask] * len(sigs)
                        embeddings_list = []
                        for start in range(0, len(sigs), batch_size):
                            batch_embs = embed_batch(
                                sigs[start:start + batch_size],
                                masks[start:start + batch_size],
                                model_wrapper,
                                layer_eff,
                                use_mlm=False,
                            )
                            if batch_embs.numel() > 0:
                                embeddings_list.append(batch_embs.cpu())
                            torch.cuda.empty_cache()
                        if embeddings_list:
                            all_embeddings[s] = torch.cat(embeddings_list, dim=0)
                        else:
                            all_embeddings[s] = torch.empty(0, 0, 0)

                    if not all_embeddings or all(
                        emb.numel() == 0 for emb in all_embeddings.values()
                    ):
                        if verbose:
                            print(
                                f"WARNING: mixture {mixture_id} produced 0 frames after masking; skipping."
                            )
                        continue

                    # Frame count is taken from the first speaker's embeddings.
                    L = next(iter(all_embeddings.values())).shape[1]
                    if L == 0:
                        if verbose:
                            print(
                                f"WARNING: mixture {mixture_id} produced 0 frames after masking; skipping."
                            )
                        continue

                    if verbose:
                        print(f"Computing {metric_type} scores for {mname}...")

                    def process_frame(
                        f,
                        frame_idx,
                        all_embeddings_dict,
                        speaker_labels_dict,
                        individual_masks_list,
                        speaker_indices,
                    ):
                        # Scores one frame. `f` indexes the embeddings,
                        # `frame_idx` the mask timeline. Returns
                        # (frame_idx, metric, score, bias, prob), or None on
                        # failure. `metric_type` is read from the enclosing
                        # loop; this is safe because every future is consumed
                        # before the loop advances.
                        try:
                            active_speakers = [
                                spk_id
                                for spk_idx, spk_id in enumerate(speaker_indices)
                                if individual_masks_list[spk_idx][frame_idx]
                            ]
                            if len(active_speakers) < 2:
                                return frame_idx, metric_type, {}, None, None

                            frame_embeddings = []
                            frame_labels = []
                            for spk_id in active_speakers:
                                frame_embeddings.append(
                                    all_embeddings_dict[spk_id][:, f, :]
                                )
                                frame_labels.extend(speaker_labels_dict[spk_id])
                            frame_emb = (
                                torch.cat(frame_embeddings, dim=0)
                                .detach()
                                .cpu()
                                .numpy()
                            )

                            if add_ci:
                                coords_d, coords_c, eigvals, k_sub_gauss = (
                                    gpu_distributor.execute_on_gpu(
                                        diffusion_map_torch,
                                        frame_emb,
                                        frame_labels,
                                        alpha=alpha,
                                        eig_solver="full",
                                        return_eigs=True,
                                        return_complement=True,
                                        return_cval=add_ci,
                                    )
                                )
                            else:
                                coords_d = gpu_distributor.execute_on_gpu(
                                    diffusion_map_torch,
                                    frame_emb,
                                    frame_labels,
                                    alpha=alpha,
                                    eig_solver="full",
                                    return_eigs=False,
                                    return_complement=False,
                                    return_cval=False,
                                )
                                coords_c = None
                                eigvals = None
                                k_sub_gauss = 1

                            bias = prob = None
                            if metric_type == "PS":
                                score = compute_ps(coords_d, frame_labels, max_gpus)
                                if add_ci:
                                    bias, prob = ps_ci_components_full(
                                        coords_d,
                                        coords_c,
                                        eigvals,
                                        frame_labels,
                                        delta=DEFAULT_DELTA_CI,
                                    )
                                return frame_idx, "PS", score, bias, prob

                            score = compute_pm(
                                coords_d, frame_labels, "gamma", max_gpus
                            )
                            if add_ci:
                                bias, prob = pm_ci_components_full(
                                    coords_d,
                                    coords_c,
                                    eigvals,
                                    frame_labels,
                                    delta=DEFAULT_DELTA_CI,
                                    K=k_sub_gauss,
                                )
                            return frame_idx, "PM", score, bias, prob
                        except Exception as ex:
                            # Best effort: a failing frame stays NaN in the output.
                            if verbose:
                                print(f"ERROR frame {frame_idx}: {ex}")
                            return None

                    with ThreadPoolExecutor(
                        max_workers=min(2, ngpu if ngpu > 0 else 1)
                    ) as executor:
                        futures = [
                            executor.submit(
                                process_frame,
                                f,
                                valid_frame_indices[f],
                                all_embeddings,
                                speaker_labels,
                                individual_masks,
                                speaker_ids,
                            )
                            for f in range(L)
                        ]

                        for fut in futures:
                            result = fut.result()
                            if result is None:
                                continue
                            frame_idx, metric, score, bias, prob = result
                            if metric == "PS":
                                score_store = ps_frames[mname]
                                bias_store = ps_bias_frames[mname]
                                prob_store = ps_prob_frames[mname]
                            else:
                                score_store = pm_frames[mname]
                                bias_store = pm_bias_frames[mname]
                                prob_store = pm_prob_frames[mname]
                            for sp in mixture_speakers:
                                if sp not in score:
                                    continue
                                score_store[sp][frame_idx] = score[sp]
                                if add_ci and bias is not None and sp in bias:
                                    bias_store[sp][frame_idx] = bias[sp]
                                    prob_store[sp][frame_idx] = prob[sp]

                    clear_gpu_memory()
                    gc.collect()

                del model_wrapper
                clear_gpu_memory()
                gc.collect()

                all_mixture_results[mixture_id][algo][mname] = {
                    'ps_frames': ps_frames[mname],
                    'pm_frames': pm_frames[mname],
                    'ps_bias_frames': ps_bias_frames[mname] if add_ci else None,
                    'ps_prob_frames': ps_prob_frames[mname] if add_ci else None,
                    'pm_bias_frames': pm_bias_frames[mname] if add_ci else None,
                    'pm_prob_frames': pm_prob_frames[mname] if add_ci else None,
                    'total_frames': total_frames,
                }

        if verbose:
            print(f"Saving results for mixture {mixture_id}...")

        timestamps_ms = [i * hop * 1000 / SR for i in range(total_frames)]
        mixture_results = all_mixture_results[mixture_id]

        for model in models:
            ps_data = {'timestamp_ms': timestamps_ms}
            pm_data = {'timestamp_ms': timestamps_ms}
            ci_data = {'timestamp_ms': timestamps_ms} if add_ci else None

            for algo in algos_to_run:
                if algo not in mixture_results:
                    continue
                if model not in mixture_results[algo]:
                    continue
                model_data = mixture_results[algo][model]
                for speaker in mixture_speakers:
                    col_name = f"{algo}_{speaker}"
                    ps_data[col_name] = model_data['ps_frames'][speaker]
                    pm_data[col_name] = model_data['pm_frames'][speaker]
                    if add_ci and ci_data is not None:
                        ci_data[f"{algo}_{speaker}_ps_bias"] = model_data['ps_bias_frames'][speaker]
                        ci_data[f"{algo}_{speaker}_ps_prob"] = model_data['ps_prob_frames'][speaker]
                        ci_data[f"{algo}_{speaker}_pm_bias"] = model_data['pm_bias_frames'][speaker]
                        ci_data[f"{algo}_{speaker}_pm_prob"] = model_data['pm_prob_frames'][speaker]

            mixture_dir = os.path.join(exp_root, mixture_id)
            os.makedirs(mixture_dir, exist_ok=True)
            pd.DataFrame(ps_data).to_csv(
                os.path.join(mixture_dir, f"ps_scores_{model}.csv"), index=False
            )
            pd.DataFrame(pm_data).to_csv(
                os.path.join(mixture_dir, f"pm_scores_{model}.csv"), index=False
            )
            if add_ci and ci_data is not None:
                pd.DataFrame(ci_data).to_csv(
                    os.path.join(mixture_dir, f"ci_{model}.csv"), index=False
                )

        del all_outs
        clear_gpu_memory()
        gc.collect()

    print("\nEXPERIMENT COMPLETED")
    print(f"Results saved to: {exp_root}")

    del all_refs, multi_speaker_masks_mix, individual_speaker_masks_mix

    from models import cleanup_all_models

    cleanup_all_models()
    clear_gpu_memory()
    get_gpu_memory_info(verbose)
    gc.collect()

    return exp_root