Spaces:
Build error
Build error
| #!/usr/bin/env python3 -u | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| import gc | |
| import os | |
| import os.path as osp | |
| import random | |
| import numpy as np | |
| import tqdm | |
| import torch | |
| from collections import namedtuple | |
| import faiss | |
| import fairseq | |
| import soundfile as sf | |
def get_parser():
    """Build the command-line argument parser for this script.

    Returns:
        argparse.ArgumentParser: parser with the positional ``data``
        argument, the required ``--save-dir`` and ``--checkpoint`` options,
        and the optional ``--sample-pct``, ``--layer`` and ``--faiss-specs``
        options.
    """
    parser = argparse.ArgumentParser(
        description="compute kmeans codebook from kaldi-computed feats"
    )
    # fmt: off
    parser.add_argument('data',
                        help='path to the tsv manifest: first line is the root directory, '
                             'each following line starts with a file path relative to it')
    parser.add_argument('--save-dir', help='where to save the output', required=True)
    parser.add_argument('--checkpoint', type=str, help='checkpoint for wav2vec model (if using wav2vec features)', required=True)
    parser.add_argument('--sample-pct', '-r', type=float,
                        help='fraction of the listed files to sample, in (0, 1]; '
                             '0 (the default) uses every file',
                        default=0)
    parser.add_argument('--layer', '-l', type=int, help='which layer to read', default=14)
    # NOTE(review): the default 'l2' contains no CLUSx component, so the spec
    # parser rejects it; in practice --faiss-specs must always be given.
    parser.add_argument('--faiss-specs', '-f', type=str,
                        help='faiss index specs; separated by space '
                             'format is: PCAx_NORM_CLUSx_SPHERICAL -> '
                             'PCAx if exists first apply PCA '
                             'NORM if exists, normalize the vector by L2 norm '
                             'CLUSx must exist, cluster to x clusters '
                             'SPHERICAL if exists, apply spherical kmeans',
                        default='l2')
    # fmt: on

    return parser
# Parsed form of one faiss spec string:
#   pca      -- PCA output dimension (0 when no PCAx component is present)
#   norm     -- True when the NORM component is present
#   n_clus   -- number of clusters taken from the CLUSx component
#   sphere   -- True when the SPHERICAL component is present
#   spec_str -- the original, unparsed spec string
faiss_spec = namedtuple("faiss_spec", ["pca", "norm", "n_clus", "sphere", "spec_str"])


def parse_faiss_specs(specs_str):
    """Parse a whitespace-separated list of faiss spec strings.

    Each spec is a "_"-joined list of components: ``PCAx``, ``NORM``,
    ``CLUSx`` and ``SPHERICAL`` (for example ``PCA512_NORM_CLUS128``).
    Components that match none of these are ignored.

    Args:
        specs_str: string holding zero or more specs separated by whitespace.

    Returns:
        list of ``faiss_spec`` tuples, one per spec, in input order.

    Raises:
        ValueError: if a spec has no ``CLUSx`` component with x > 0, or if
            the number following ``PCA``/``CLUS`` is not an integer.
    """
    specs = []
    for ss in specs_str.split():
        pca = 0
        norm = False
        n_clus = 0
        sphere = False
        for c in ss.split("_"):
            if c.startswith("PCA"):
                pca = int(c[3:])
            elif c == "NORM":
                norm = True
            elif c.startswith("CLUS"):
                n_clus = int(c[4:])
            elif c == "SPHERICAL":
                sphere = True
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if n_clus <= 0:
            raise ValueError(
                "faiss spec {!r} must contain a CLUSx component with x > 0".format(ss)
            )
        specs.append(
            faiss_spec(pca=pca, norm=norm, n_clus=n_clus, sphere=sphere, spec_str=ss)
        )
    return specs
class Wav2VecFeatureReader(object):
    """Read audio files and return the output of one layer of a wav2vec model.

    The checkpoint is loaded once in the constructor; the model is switched
    to eval mode and moved to the GPU with ``.cuda()``.
    """

    def __init__(self, cp_file, layer):
        """Load the model stored in ``cp_file``.

        Args:
            cp_file: checkpoint path handed to
                ``fairseq.checkpoint_utils.load_checkpoint_to_cpu``.
            layer: index used both as the ``layer`` argument of the model
                call and to select an entry of its ``"layer_results"``.
        """
        state = fairseq.checkpoint_utils.load_checkpoint_to_cpu(cp_file)

        self.layer = layer

        # The task/model configuration lives under "cfg" (with .task and
        # .model attributes) when present, otherwise under a flat "args".
        if "cfg" in state:
            w2v_args = state["cfg"]
            task = fairseq.tasks.setup_task(w2v_args.task)
            model = task.build_model(w2v_args.model)
        else:
            w2v_args = state["args"]
            task = fairseq.tasks.setup_task(w2v_args)
            model = task.build_model(w2v_args)
        model.load_state_dict(state["model"], strict=True)
        model.eval()
        model.cuda()
        self.model = model

    def read_audio(self, fname):
        """Load an audio file and return its samples.

        Only the sample array is returned; the sample rate is validated and
        then discarded.

        Args:
            fname: path of the audio file, read with ``soundfile.read``.

        Returns:
            The sample array produced by ``soundfile.read``.

        Raises:
            ValueError: if the file's sample rate is not 16000 Hz.
        """
        wav, sr = sf.read(fname)
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if sr != 16e3:
            raise ValueError(
                "{!r}: expected a 16000 Hz sample rate, got {!r}".format(fname, sr)
            )
        return wav

    def get_feats(self, loc):
        """Run the model on one audio file and return the selected layer's output.

        Args:
            loc: path of the audio file.

        Returns:
            torch.Tensor (on the GPU): first element of
            ``res["layer_results"][self.layer]`` with dimension 1 squeezed.
        """
        x = self.read_audio(loc)
        with torch.no_grad():
            # view(1, -1) flattens everything into a single row; this assumes
            # mono audio -- TODO confirm inputs are single-channel.
            source = torch.from_numpy(x).view(1, -1).float().cuda()
            res = self.model(
                source=source, mask=False, features_only=True, layer=self.layer
            )
            # presumably squeeze(1) drops a size-1 batch axis -- verify
            # against the model's layer_results layout.
            return res["layer_results"][self.layer][0].squeeze(1)
def get_iterator(args):
    """Prepare a lazy feature generator over the files listed in a manifest.

    The manifest at ``args.data`` is read in full: its first line (stripped)
    is the root directory, and the first tab-separated field of every other
    non-empty line is a path joined onto that root. When ``args.sample_pct``
    is greater than 0, a random subset of
    ``int(args.sample_pct * number_of_files)`` files is kept.

    Args:
        args: namespace providing ``data``, ``checkpoint``, ``layer`` and,
            optionally, ``sample_pct``.

    Returns:
        tuple: ``(iterate, num)`` where ``iterate`` is a zero-argument
        generator function yielding one numpy feature array per file and
        ``num`` is the number of files it will visit.
    """
    with open(args.data, "r") as manifest:
        rows = manifest.read().split("\n")
    base_dir = rows.pop(0).strip()

    audio_paths = []
    for row in rows:
        if not row:
            continue
        relative_path = row.split("\t")[0]
        audio_paths.append(osp.join(base_dir, relative_path))

    if getattr(args, "sample_pct", 0) > 0:
        keep = int(args.sample_pct * len(audio_paths))
        audio_paths = random.sample(audio_paths, keep)

    num = len(audio_paths)
    reader = Wav2VecFeatureReader(args.checkpoint, args.layer)

    def iterate():
        for path in audio_paths:
            # Move each result off the GPU before handing it to the caller.
            yield reader.get_feats(path).cpu().numpy()

    return iterate, num
def main():
    """Extract (or load cached) features and train one k-means codebook per spec.

    Features are loaded from ``<save-dir>/features.npy`` when that file
    exists; otherwise they are computed with the wav2vec reader and kept in
    memory only. For every faiss spec a directory ``<save-dir>/<spec>`` is
    created, holding ``centroids.npy`` and, when PCA is requested,
    ``pca_A.npy`` and ``pca_b.npy``.
    """
    parser = get_parser()
    args = parser.parse_args()

    faiss_specs = parse_faiss_specs(args.faiss_specs)
    print("Faiss Specs:", faiss_specs)

    feat_path = osp.join(args.save_dir, "features")
    feat_file = feat_path + ".npy"
    # Remember whether an on-disk copy exists: it decides below how the
    # effect of in-place normalization is undone between specs.
    have_cached_feats = osp.exists(feat_file)
    if have_cached_feats:
        feats = np.load(feat_file)
    else:
        generator, num = get_iterator(args)
        iterator = generator()

        feats = []
        for f in tqdm.tqdm(iterator, total=num):
            feats.append(f)

        # Drop the generator (and the reader it closes over) so the
        # gc.collect() / empty_cache() calls below can release it.
        del iterator
        del generator

        feats = np.concatenate(feats)

        print(feats.shape)

        os.makedirs(args.save_dir, exist_ok=True)
        # NOTE: the features are deliberately not written to disk here
        # (`np.save(feat_path, feats)` is disabled), so no features.npy
        # exists to reload from unless it was provided beforehand.

        gc.collect()
        torch.cuda.empty_cache()

    reload = False
    last_idx = len(faiss_specs) - 1
    for idx, spec in enumerate(faiss_specs):
        print("Processing spec", spec)

        if reload:
            print("Reloading...")
            del feats
            gc.collect()
            feats = np.load(feat_file)
            # One reload restores the pristine features; do not repeat it
            # for later specs unless they are modified again.
            reload = False

        save_path = osp.join(args.save_dir, spec.spec_str)
        os.makedirs(save_path, exist_ok=True)
        d = feats.shape[-1]
        x = feats
        if spec.pca > 0:
            print("Computing PCA")
            pca = faiss.PCAMatrix(d, spec.pca)
            pca.train(x)
            d = spec.pca
            b = faiss.vector_to_array(pca.b)
            A = faiss.vector_to_array(pca.A).reshape(pca.d_out, pca.d_in)
            np.save(osp.join(save_path, "pca_A"), A.T)
            np.save(osp.join(save_path, "pca_b"), b)
            print("Applying PCA")
            x = pca.apply_py(x)

        if spec.norm:
            if spec.pca <= 0 and idx < last_idx:
                # Without PCA, x is feats itself and normalize_L2 modifies it
                # in place, which would corrupt the input of later specs.
                if have_cached_feats:
                    # Cheap on memory: mutate now, restore from disk later.
                    reload = True
                else:
                    # No file to restore from: normalize a private copy.
                    x = feats.copy()
            print("Normalizing")
            faiss.normalize_L2(x)

        print("Computing kmeans")
        kmeans = faiss.Kmeans(
            d,
            spec.n_clus,
            niter=50,
            verbose=True,
            spherical=spec.sphere,
            max_points_per_centroid=feats.shape[0],
            gpu=True,
            nredo=3,
        )
        kmeans.train(x)
        np.save(osp.join(save_path, "centroids"), kmeans.centroids)
        del kmeans
        del x
        gc.collect()


if __name__ == "__main__":
    main()