# @ hwang258@jhu.edu
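# Dataset for CapSpeech-style training data: pairs phoneme sequences (g2p)
# with raw audio, precomputed T5 text embeddings, and CLAP tag embeddings.
# (Summary inferred from the code in this file, not from upstream docs.)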
import os
import sys
import json
import logging
import typing as tp

import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

def read_json(path):
    with open(path, "r") as f:
        return json.load(f)

class CapSpeech(Dataset):
    def __init__(
        self,
        dataset_dir: tp.Optional[str] = None,
        clap_emb_dir: tp.Optional[str] = None,
        t5_folder_name: str = "t5",
        phn_folder_name: str = "g2p",
        manifest_name: str = "manifest",
        json_name: str = "jsons",
        dynamic_batching: bool = True,
        text_pad_token: int = -1,
        audio_pad_token: float = 0.0,
        split: str = "val",
        sr: int = 24000,
        norm_audio: bool = False,
        vocab_file: tp.Optional[str] = None,
    ):
        super().__init__()
        self.dataset_dir = dataset_dir
        self.clap_emb_dir = clap_emb_dir
        self.t5_folder_name = t5_folder_name
        self.phn_folder_name = phn_folder_name
        self.manifest_name = manifest_name
        self.json_name = json_name
        self.dynamic_batching = dynamic_batching
        self.text_pad_token = text_pad_token
        # keep the pad value as a plain float: pad_sequence expects a number,
        # not a tensor, for padding_value
        self.audio_pad_token = float(audio_pad_token)
        self.split = split
        self.sr = sr
        self.norm_audio = norm_audio
        assert self.split in ["train", "train_small", "val", "test"]

        manifest_fn = os.path.join(self.dataset_dir, self.manifest_name, self.split + ".txt")
        meta = read_json(os.path.join(self.dataset_dir, self.json_name, self.split + ".json"))
        self.meta = {item["segment_id"]: item["audio_path"] for item in meta}
        with open(manifest_fn, "r") as rf:
            data = [l.strip().split("\t") for l in rf.readlines()]
        # data = [item for item in data if item[2] == 'none']  # remove sound effects
        self.data = [item[0] for item in data]      # segment ids
        self.tag_list = [item[1] for item in data]  # per-segment tag names
        logging.info(f"number of data points for {self.split} split: {len(self.data)}")

        # phoneme vocabulary: each line of vocab.txt is "<id> <phoneme>"
        if vocab_file is None:
            vocab_fn = os.path.join(self.dataset_dir, "vocab.txt")
        else:
            vocab_fn = vocab_file
        with open(vocab_fn, "r") as f:
            # note: readlines() keeps the trailing newline, so filter on the
            # stripped line rather than the raw one
            temp = [l.strip().split(" ") for l in f.readlines() if len(l.strip()) != 0]
            self.phn2num = {item[1]: int(item[0]) for item in temp}

    def __len__(self):
        return len(self.data)

    def _load_audio(self, audio_path):
        try:
            y, sr = torchaudio.load(audio_path)
            if y.shape[0] > 1:  # downmix multi-channel audio to mono
                y = y.mean(dim=0, keepdim=True)
            if sr != self.sr:
                resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sr)
                y = resampler(y)
            if self.norm_audio:  # peak-normalize to [-1, 1]
                eps = 1e-9
                max_val = torch.max(torch.abs(y))
                y = y / (max_val + eps)
            if torch.isnan(y.mean()):  # guard against corrupted files
                return None
            return y
        except Exception:
            logging.warning(f"failed to load audio: {audio_path}")
            return None

    def _load_phn_enc(self, index):
        try:
            seg_id = self.data[index]
            pf = os.path.join(self.dataset_dir, self.phn_folder_name, seg_id + ".txt")
            audio_path = self.meta[seg_id]
            cf = os.path.join(self.dataset_dir, self.t5_folder_name, seg_id + ".npz")
            tagf = os.path.join(self.clap_emb_dir, self.tag_list[index] + ".npz")
            with open(pf, "r") as p:
                phns = [l.strip() for l in p.readlines()]
            assert len(phns) == 1, phns
            x = [self.phn2num[item] for item in phns[0].split(" ")]
            c = np.load(cf)["arr_0"]      # precomputed T5 text embedding
            c = torch.tensor(c).squeeze()
            tag = np.load(tagf)["arr_0"]  # precomputed CLAP embedding for this tag
            tag = torch.tensor(tag).squeeze()
            y = self._load_audio(audio_path)
            if y is not None:
                return x, y, c, tag
            return None, None, None, None
        except Exception:
            return None, None, None, None

    def __getitem__(self, index):
        empty = {
            "x": None, "x_len": None,
            "y": None, "y_len": None,
            "c": None, "c_len": None,
            "tag": None,
        }
        x, y, c, tag = self._load_phn_enc(index)
        if x is None:  # load failure; collate() filters these out
            return empty
        x_len, y_len, c_len = len(x), len(y[0]), len(c)
        y_len = y_len / self.sr  # audio length in seconds
        # drop items whose frame count (hop size 256) cannot cover the phoneme sequence
        if y_len * self.sr / 256 <= x_len:
            return empty
        x = torch.LongTensor(x)
        return {
            "x": x,
            "x_len": x_len,
            "y": y,
            "y_len": y_len,
            "c": c,
            "c_len": c_len,
            "tag": tag,
        }

    def collate(self, batch):
        out = {key: [] for key in batch[0]}
        for item in batch:
            if item["x"] is None:  # skip items that failed to load
                continue
            if item["c"].ndim != 2:  # expect a 2-D (T, D) text embedding
                continue
            for key, val in item.items():
                out[key].append(val)
        res = {}
        res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.text_pad_token)
        res["x_lens"] = torch.LongTensor(out["x_len"])
        if self.dynamic_batching:
            # pad variable-length waveforms along time, then restore channel-first layout
            res["y"] = torch.nn.utils.rnn.pad_sequence(
                [item.transpose(1, 0) for item in out["y"]],
                padding_value=self.audio_pad_token,
            )
            res["y"] = res["y"].permute(1, 2, 0)  # T B K -> B K T
        else:
            res["y"] = torch.stack(out["y"], dim=0)
        res["y_lens"] = torch.Tensor(out["y_len"])
        res["c"] = torch.nn.utils.rnn.pad_sequence(out["c"], batch_first=True)
        res["c_lens"] = torch.LongTensor(out["c_len"])
        res["tag"] = torch.stack(out["tag"], dim=0)
        return res


if __name__ == "__main__":
    # debug
    from torch.utils.data import DataLoader

    dataset = CapSpeech(
        dataset_dir="./data/capspeech",
        clap_emb_dir="./data/clap_embs/",
        split="val",
    )
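    # Minimal smoke test, assuming the placeholder paths above point at
    # prepared data; batch_size and num_workers are illustrative choices,
    # not values taken from the original script.
    loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=False,
        num_workers=0,
        collate_fn=dataset.collate,  # use the pad_sequence-based collation above
    )
    batch = next(iter(loader))
    for key, val in batch.items():
        print(key, tuple(val.shape))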