Spaces:
Build error
Build error
OFA-Generic_Interface
/
fairseq
/examples
/speech_synthesis
/preprocessing
/get_speaker_embedding.py
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| from collections import defaultdict | |
| from itertools import chain | |
| from pathlib import Path | |
| import numpy as np | |
| import torchaudio | |
| import torchaudio.sox_effects as ta_sox | |
| import yaml | |
| from tqdm import tqdm | |
| from examples.speech_to_text.data_utils import load_tsv_to_dicts | |
| from examples.speech_synthesis.preprocessing.speaker_embedder import SpkrEmbedder | |
def extract_embedding(audio_path, embedder):
    """Compute a speaker embedding for one audio file.

    The audio is resampled with SoX to ``embedder.RATE`` when its native
    sample rate differs. Only the first channel is fed to the embedder, as a
    CUDA float tensor wrapped in a single-element list.

    Args:
        audio_path: path of an audio file readable by ``torchaudio.load``.
        embedder: speaker embedder exposing a ``RATE`` attribute and callable
            on a list of 1-D float tensors.

    Returns:
        The embedding converted to a NumPy array, or ``None`` if a
        ``RuntimeError`` was raised while moving the audio to the GPU,
        running the embedder, or converting the result.
    """
    # 2-D tensor; row 0 is used below as the first channel.
    waveform, sample_rate = torchaudio.load(audio_path)
    target_rate = embedder.RATE
    if sample_rate != target_rate:
        waveform, sample_rate = ta_sox.apply_effects_tensor(
            waveform, sample_rate, [["rate", str(target_rate)]]
        )
    try:
        first_channel = waveform[0].cuda().float()
        return embedder([first_channel]).cpu().numpy()
    except RuntimeError:
        # Best-effort: the caller in this script skips samples for which
        # None is returned instead of aborting the whole run.
        return None
def process(args):
    """Build a per-speaker mean-embedding matrix and write an updated config.

    Loads the TSV manifests ``<raw_manifest_root>/<split>.tsv`` for each split
    in ``args.splits``. Each sample's ``audio`` file is embedded with the
    speaker embedder loaded from ``args.ckpt``. Embeddings are then averaged
    per ``speaker``. Samples whose extraction returns ``None`` are skipped.

    Row ``i`` of the resulting matrix belongs to the speaker on line ``i`` of
    the config's speaker-set file. Rows of speakers with no extracted
    embedding stay all-zero, and a warning is printed for them. The matrix is
    saved as ``speaker_emb.npy`` under ``config["audio_root"]``. The YAML
    config, with ``speaker_emb_filename`` added, is written to
    ``args.new_config``.

    Args:
        args: parsed command-line namespace providing ``raw_manifest_root``,
            ``splits``, ``config``, ``new_config`` and ``ckpt``.

    Raises:
        RuntimeError: if no embedding could be extracted for any sample.
        KeyError: if a manifest speaker does not appear in the speaker-set
            file.
    """
    print("Fetching data...")
    raw_manifest_root = Path(args.raw_manifest_root).absolute()
    samples = list(
        chain.from_iterable(
            load_tsv_to_dicts(raw_manifest_root / (s + ".tsv"))
            for s in args.splits
        )
    )
    # NOTE(review): FullLoader can construct some Python objects; the config
    # is assumed to be a trusted local file (yaml.safe_load would be stricter).
    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    # A speaker's row index is its 0-based line number in the speaker set.
    with open(f"{config['audio_root']}/{config['speaker_set_filename']}") as f:
        speaker_to_id = {r.strip(): i for i, r in enumerate(f)}

    embedder = SpkrEmbedder(args.ckpt).cuda()
    speaker_to_cnt = defaultdict(float)
    # 0.0 + first array yields a fresh array that later `+=` updates in place.
    speaker_to_emb = defaultdict(float)
    for sample in tqdm(samples, desc="extract emb"):
        emb = extract_embedding(sample["audio"], embedder)
        if emb is not None:
            speaker_to_cnt[sample["speaker"]] += 1
            speaker_to_emb[sample["speaker"]] += emb

    if not speaker_to_emb:
        raise RuntimeError(
            "No speaker embeddings could be extracted; "
            "cannot build the speaker embedding matrix."
        )
    unknown = set(speaker_to_emb) - set(speaker_to_id)
    if unknown:
        raise KeyError(f"speakers not in the speaker set: {sorted(unknown)}")
    missed = set(speaker_to_id) - set(speaker_to_emb)
    if missed:
        print(
            f"WARNING: missing embeddings for {len(missed)} speaker:\n{missed}"
        )

    # Take the embedding size from an accumulated embedding, not from the
    # loop variable `emb`: that is None when the last sample failed.
    emb_dim = len(next(iter(speaker_to_emb.values())))
    speaker_emb_mat = np.zeros((len(speaker_to_id), emb_dim), float)
    for speaker, emb_sum in speaker_to_emb.items():
        idx = speaker_to_id[speaker]
        speaker_emb_mat[idx, :] = emb_sum / speaker_to_cnt[speaker]

    speaker_emb_name = "speaker_emb.npy"
    speaker_emb_path = f"{config['audio_root']}/{speaker_emb_name}"
    np.save(speaker_emb_path, speaker_emb_mat)
    config["speaker_emb_filename"] = speaker_emb_name
    with open(args.new_config, "w") as f:
        yaml.dump(config, f)
def main():
    """Command-line entry point: parse the arguments and run ``process``."""
    arg_parser = argparse.ArgumentParser()

    # Where the raw TSV manifests live and which splits to read.
    arg_parser.add_argument(
        "--raw-manifest-root", "-m", type=str, required=True
    )
    arg_parser.add_argument(
        "--splits", "-s", type=str, nargs="+", default=["train"]
    )

    # Input YAML config and the path of the updated copy to write.
    arg_parser.add_argument("--config", "-c", type=str, required=True)
    arg_parser.add_argument("--new-config", "-n", type=str, required=True)

    # Model used to compute the embeddings.
    arg_parser.add_argument(
        "--ckpt", type=str, required=True, help="speaker embedder checkpoint"
    )

    process(arg_parser.parse_args())
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()