import os
import time

import json5
import librosa
import numpy as np
import torch
import safetensors
from huggingface_hub import hf_hub_download
from transformers import SeamlessM4TFeatureExtractor, Wav2Vec2BertModel

from indextts.utils.maskgct.models.codec.kmeans.repcodec_model import RepCodec
from indextts.utils.maskgct.models.tts.maskgct.maskgct_s2a import MaskGCT_S2A
from indextts.utils.maskgct.models.codec.amphion_codec.codec import (
    CodecEncoder,
    CodecDecoder,
)

def _load_config(config_fn, lowercase=False):
    """Load configurations into a dictionary.

    Args:
        config_fn (str): path to configuration file
        lowercase (bool, optional): whether to change keys to lower case. Defaults to False.

    Returns:
        dict: dictionary that stores configurations
    """
    with open(config_fn, "r") as f:
        data = f.read()
    config_ = json5.loads(data)
    if "base_config" in config_:
        # Load the parent configuration (WORK_DIR must be set) and override it
        # with this one. override_config is a helper from the original Amphion
        # utils, not defined in this file.
        p_config_path = os.path.join(os.getenv("WORK_DIR"), config_["base_config"])
        p_config_ = _load_config(p_config_path)
        config_ = override_config(p_config_, config_)
    if lowercase:
        # Change keys in config_ to lower case (get_lowercase_keys_config is
        # also an Amphion utils helper).
        config_ = get_lowercase_keys_config(config_)
    return config_

def load_config(config_fn, lowercase=False):
    """Load configurations into a JsonHParams object.

    Args:
        config_fn (str): path to configuration file
        lowercase (bool, optional): whether to change keys to lower case. Defaults to False.

    Returns:
        JsonHParams: an object that stores configurations
    """
    config_ = _load_config(config_fn, lowercase=lowercase)
    # Create a JsonHParams object from the configuration dict.
    cfg = JsonHParams(**config_)
    return cfg

class JsonHParams:
    """Dict-like wrapper that exposes configuration keys as attributes.

    Nested dicts are wrapped recursively, so both cfg.model.encoder and
    cfg["model"]["encoder"] work.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, dict):
                v = JsonHParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
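
# A minimal usage sketch for load_config / JsonHParams. The config path is
# hypothetical; any json5 file with the expected maskgct fields works:
#
#     cfg = load_config("./models/tts/maskgct/config/maskgct.json")
#     cfg.model.semantic_codec        # attribute-style access to nested sections
#     cfg["model"]["semantic_codec"]  # dict-style access works too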

def build_semantic_model(path_='./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt'):
    # w2v-bert-2.0 provides the semantic features; the stats file holds the
    # per-dimension mean/variance used to normalize its hidden states.
    semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
    semantic_model.eval()
    stat_mean_var = torch.load(path_)
    semantic_mean = stat_mean_var["mean"]
    semantic_std = torch.sqrt(stat_mean_var["var"])
    return semantic_model, semantic_mean, semantic_std

def build_semantic_codec(cfg):
    semantic_codec = RepCodec(cfg=cfg)
    semantic_codec.eval()
    return semantic_codec


def build_s2a_model(cfg, device):
    soundstorm_model = MaskGCT_S2A(cfg=cfg)
    soundstorm_model.eval()
    soundstorm_model.to(device)
    return soundstorm_model


def build_acoustic_codec(cfg, device):
    codec_encoder = CodecEncoder(cfg=cfg.encoder)
    codec_decoder = CodecDecoder(cfg=cfg.decoder)
    codec_encoder.eval()
    codec_decoder.eval()
    codec_encoder.to(device)
    codec_decoder.to(device)
    return codec_encoder, codec_decoder
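
# Hedged sketch: one way the builders above could be wired to checkpoints via
# hf_hub_download + safetensors (both imported at the top). The repo id,
# filenames, and config keys follow the public amphion/MaskGCT layout and are
# assumptions here; this Space may bundle its own weights instead.
def _load_maskgct_checkpoints_sketch(cfg, device):
    import safetensors.torch  # explicit, in case safetensors.torch isn't loaded yet

    semantic_codec = build_semantic_codec(cfg.model.semantic_codec)
    codec_encoder, codec_decoder = build_acoustic_codec(cfg.model.acoustic_codec, device)
    s2a_1layer = build_s2a_model(cfg.model.s2a_model.s2a_1layer, device)
    s2a_full = build_s2a_model(cfg.model.s2a_model.s2a_full, device)

    repo = "amphion/MaskGCT"  # assumed checkpoint repo
    for model, filename in [
        (semantic_codec, "semantic_codec/model.safetensors"),
        (codec_encoder, "acoustic_codec/model.safetensors"),
        (codec_decoder, "acoustic_codec/model_1.safetensors"),
        (s2a_1layer, "s2a_model/s2a_model_1layer/model.safetensors"),
        (s2a_full, "s2a_model/s2a_model_full/model.safetensors"),
    ]:
        safetensors.torch.load_model(model, hf_hub_download(repo, filename=filename))
    return semantic_codec.to(device), codec_encoder, codec_decoder, s2a_1layer, s2a_full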

class Inference_Pipeline:
    def __init__(
        self,
        semantic_model,
        semantic_codec,
        semantic_mean,
        semantic_std,
        codec_encoder,
        codec_decoder,
        s2a_model_1layer,
        s2a_model_full,
    ):
        self.semantic_model = semantic_model
        self.semantic_codec = semantic_codec
        self.semantic_mean = semantic_mean
        self.semantic_std = semantic_std
        self.codec_encoder = codec_encoder
        self.codec_decoder = codec_decoder
        self.s2a_model_1layer = s2a_model_1layer
        self.s2a_model_full = s2a_model_full

    def get_emb(self, input_features, attention_mask):
        vq_emb = self.semantic_model(
            input_features=input_features,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # Hidden layer 17 of w2v-bert-2.0 serves as the semantic representation,
        # normalized with the precomputed mean/std statistics.
        feat = vq_emb.hidden_states[17]  # (B, T, C)
        feat = (feat - self.semantic_mean.to(feat)) / self.semantic_std.to(feat)
        return feat
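
    # Hedged sketch of how the inputs to get_emb could be produced (not part of
    # the original pipeline; the 16 kHz rate matches what w2v-bert-2.0 expects):
    #
    #     processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
    #     speech16k = librosa.load(wav_path, sr=16000)[0]
    #     inputs = processor(speech16k, sampling_rate=16000, return_tensors="pt")
    #     feat = pipeline.get_emb(inputs["input_features"], inputs["attention_mask"])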

    def extract_acoustic_code(self, speech):
        # Encode the waveform, then quantize it with the decoder-side RVQ to
        # obtain discrete acoustic codes.
        vq_emb = self.codec_encoder(speech.unsqueeze(1))
        _, vq, _, _, _ = self.codec_decoder.quantizer(vq_emb)
        acoustic_code = vq.permute(1, 2, 0)
        return acoustic_code

    def get_scode(self, inputs):
        semantic_code, _ = self.semantic_codec.quantize(inputs)
        # vq = self.semantic_codec.quantizer.vq2emb(semantic_code.unsqueeze(1))
        # vq = vq.transpose(1, 2)
        return semantic_code

    def semantic2acoustic(
        self,
        combine_semantic_code,
        acoustic_code,
        n_timesteps=[25, 10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        cfg=2.5,
        rescale_cfg=0.75,
    ):
        semantic_code = combine_semantic_code
        # Stage 1: predict only the first (coarse) codec layer.
        cond = self.s2a_model_1layer.cond_emb(semantic_code)
        prompt = acoustic_code[:, :, :]
        predict_1layer = self.s2a_model_1layer.reverse_diffusion(
            cond=cond,
            prompt=prompt,
            temp=1.5,
            filter_thres=0.98,
            n_timesteps=n_timesteps[:1],
            cfg=cfg,
            rescale_cfg=rescale_cfg,
        )
        # Stage 2: predict all 12 codec layers, conditioning the full model on
        # the coarse prediction via gt_code.
        cond = self.s2a_model_full.cond_emb(semantic_code)
        prompt = acoustic_code[:, :, :]
        predict_full = self.s2a_model_full.reverse_diffusion(
            cond=cond,
            prompt=prompt,
            temp=1.5,
            filter_thres=0.98,
            n_timesteps=n_timesteps,
            cfg=cfg,
            rescale_cfg=rescale_cfg,
            gt_code=predict_1layer,
        )
        # Decode the predicted codes, and re-decode the prompt codes so the two
        # can be concatenated for listening.
        vq_emb = self.codec_decoder.vq2emb(
            predict_full.permute(2, 0, 1), n_quantizers=12
        )
        recovered_audio = self.codec_decoder(vq_emb)
        prompt_vq_emb = self.codec_decoder.vq2emb(
            prompt.permute(2, 0, 1), n_quantizers=12
        )
        recovered_prompt_audio = self.codec_decoder(prompt_vq_emb)
        recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy()
        recovered_audio = recovered_audio[0][0].cpu().numpy()
        combine_audio = np.concatenate([recovered_prompt_audio, recovered_audio])
        return combine_audio, recovered_audio

    def s2a_inference(
        self,
        prompt_speech_path,
        combine_semantic_code,
        cfg=2.5,  # unused here; cfg_s2a below is what is actually passed through
        n_timesteps_s2a=[25, 10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        cfg_s2a=2.5,
        rescale_cfg_s2a=0.75,
    ):
        # Load the prompt at the codec sample rate (24 kHz) and turn it into
        # acoustic codes that prompt the generation.
        speech = librosa.load(prompt_speech_path, sr=24000)[0]
        acoustic_code = self.extract_acoustic_code(
            torch.tensor(speech).unsqueeze(0).to(combine_semantic_code.device)
        )
        _, recovered_audio = self.semantic2acoustic(
            combine_semantic_code,
            acoustic_code,
            n_timesteps=n_timesteps_s2a,
            cfg=cfg_s2a,
            rescale_cfg=rescale_cfg_s2a,
        )
        return recovered_audio

    def gt_inference(
        self,
        prompt_speech_path,
        combine_semantic_code,
    ):
        # Round-trip the prompt through encoder and decoder only (bypassing the
        # quantizer), giving a codec-quality reference signal.
        speech = librosa.load(prompt_speech_path, sr=24000)[0]
        '''
        acoustic_code = self.extract_acoustic_code(
            torch.tensor(speech).unsqueeze(0).to(combine_semantic_code.device)
        )
        prompt = acoustic_code[:, :, :]
        prompt_vq_emb = self.codec_decoder.vq2emb(
            prompt.permute(2, 0, 1), n_quantizers=12
        )
        '''
        prompt_vq_emb = self.codec_encoder(
            torch.tensor(speech).unsqueeze(0).unsqueeze(1).to(combine_semantic_code.device)
        )
        recovered_prompt_audio = self.codec_decoder(prompt_vq_emb)
        recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy()
        return recovered_prompt_audio
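
# A hedged end-to-end sketch; paths, config keys, and the source of
# combine_semantic_code are assumptions (the Space wires these up elsewhere):
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     cfg = load_config("./models/tts/maskgct/config/maskgct.json")
#     semantic_model, mean, std = build_semantic_model()
#     semantic_codec, codec_enc, codec_dec, s2a_1layer, s2a_full = \
#         _load_maskgct_checkpoints_sketch(cfg, device)
#     pipeline = Inference_Pipeline(
#         semantic_model.to(device), semantic_codec, mean.to(device), std.to(device),
#         codec_enc, codec_dec, s2a_1layer, s2a_full,
#     )
#     # combine_semantic_code: (B, T) long tensor of semantic tokens, typically
#     # the prompt's tokens (get_emb + get_scode) concatenated with a
#     # text-to-semantic model's output.
#     audio = pipeline.s2a_inference("prompt.wav", combine_semantic_code)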