Spaces:
Running
Running
| import os | |
| from concurrent.futures import ProcessPoolExecutor | |
| from pathlib import Path | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from pymcd.mcd import Calculate_MCD | |
| from tqdm import tqdm | |
def calculate_mcd_for_wav(wav, target):
    """Compute the MCD between a generated wav and its reference wav.

    The computation is delegated to the module-level ``mcd_toolbox``.

    Args:
        wav: Path of the generated audio file.
        target: Path of the reference audio file.

    Returns:
        The value returned by ``mcd_toolbox.calculate_mcd(target, wav)``,
        or ``0`` when the reference file does not exist or the computation
        raises. The caller filters out ``0`` results.
    """
    if os.path.exists(target):
        try:
            return mcd_toolbox.calculate_mcd(target, wav)
        except Exception as e:
            # Best effort: report the failing pair and fall through to the
            # sentinel so one bad file does not abort the whole batch.
            print(f"Error in {target, wav}, {e}")
            return 0

    print("not exist", target)
    return 0
import sys

# Command-line arguments: <test_lst> <output_path> <mode>.
if len(sys.argv) < 4:
    sys.exit(f"usage: {sys.argv[0]} <test_lst> <output_path> <mode>")

test_lst = sys.argv[1]
output_path = sys.argv[2]
mode = sys.argv[3]  # forwarded to Calculate_MCD, e.g. "dtw" or "dtw_sl"

# Deliberately kept at module level: calculate_mcd_for_wav reads this global,
# and worker processes that re-import this module (spawn/forkserver start
# methods) must find it defined outside the __main__ guard below.
mcd_toolbox = Calculate_MCD(MCD_mode=mode)

# Guard the driver so that worker processes re-importing this module do not
# re-run the whole evaluation (and try to start pools of their own).
if __name__ == "__main__":
    # Only the number of lines is used: entry i of the list corresponds to
    # gen/<i zero-padded to 8 digits>.wav and tgt/<same name>.wav.
    with open(test_lst, "r") as fr:
        lines = fr.readlines()

    path = output_path
    names = [f"{idx:08d}.wav" for idx in range(len(lines))]
    # os.path.join works whether or not output_path ends with a separator.
    gen_wavs = [os.path.join(path, "gen", name) for name in names]
    targets = [os.path.join(path, "tgt", name) for name in names]

    with ProcessPoolExecutor(max_workers=64) as executor:
        results = list(
            tqdm(
                executor.map(calculate_mcd_for_wav, gen_wavs, targets),
                total=len(gen_wavs),
            )
        )

    # calculate_mcd_for_wav returns 0 for missing or failed pairs; drop those
    # so they do not drag the average down.
    mcd_values = [it for it in results if it != 0]

    if mcd_values:
        mcd_avg = np.mean(mcd_values)
    else:
        # Avoid numpy's "mean of empty slice" warning; still report nan.
        print("No valid MCD values were computed", file=sys.stderr)
        mcd_avg = float("nan")

    if mode == "dtw":
        print(f"Average MCD: {mcd_avg:.3f}")
    elif mode == "dtw_sl":
        print(f"Average MCD_SL: {mcd_avg:.3f}")
    else:
        # Any other mode previously produced no output at all.
        print(f"Average MCD ({mode}): {mcd_avg:.3f}")