liumaolin
Add GPT-SoVITS training pipeline with audio slicing, transcription, and model training modules.
1b05d02
| import os | |
| import random | |
| import sys | |
| import soundfile as sf | |
| from TTS_infer_pack.TTS_mid import TTS, TTS_Config | |
| from utils import HParams | |
# Compatibility shim: when nothing is registered under the name "utils" in
# sys.modules, register a stand-in object exposing HParams under that name so
# that later lookups of `utils.HParams` resolve.
# NOTE(review): presumably needed so checkpoints pickled with a reference to
# `utils.HParams` can be loaded -- confirm against the checkpoint loader.
# NOTE(review): the preceding `from utils import HParams` normally registers
# "utils" in sys.modules already, so this branch looks unreachable in an
# ordinary run -- confirm whether it is still needed.
if "utils" not in sys.modules:
    class GPTSoVITSFixedUtilsModule:
        # Re-export the imported HParams as a class attribute.
        HParams = HParams
    sys.modules['utils'] = GPTSoVITSFixedUtilsModule
# Default sentences to synthesize, keyed by language code; used when the
# caller does not supply its own example text.
texts = dict(
    zh="今天天气很好,我想出去逛一逛,可是不知道有什么好玩的地方。",
    en="Today the weather is good, I want to go out and explore, but I don't know what to do.",
)
def init_model(
    gpt_path,
    sovits_path,
    ref_audio_path,
    ref_prompt_path,
    lan,
    bert_path,
    cnhubert_base_path,
    tts_config_file,
    example_text_content=None,
    example_audio_path=None,
):
    """Build a CPU TTS pipeline, synthesize one utterance and write it as a WAV.

    Args:
        gpt_path: Path assigned to ``tts_config.t2s_weights_path``; ``None``
            keeps whatever the config file provides.
        sovits_path: Path assigned to ``tts_config.vits_weights_path``;
            ``None`` keeps the config value.
        ref_audio_path: Reference audio passed to the pipeline.
        ref_prompt_path: Text file whose whole content is used as the prompt
            text for the reference audio. Read as UTF-8.
        lan: ``"zh"`` selects ``all_zh``/``zh`` as text/prompt language; any
            other value selects ``en``/``en``. Also the key into ``texts``
            when no example text is given.
        bert_path: Path assigned to ``tts_config.bert_base_path``; ``None``
            keeps the config value.
        cnhubert_base_path: Path assigned to
            ``tts_config.cnhuhbert_base_path``; ``None`` keeps the config
            value.
        tts_config_file: Value handed to ``TTS_Config``.
        example_text_content: Text to synthesize. When empty or ``None`` the
            built-in sentence ``texts[lan]`` is used.
        example_audio_path: Destination WAV path. Only honoured when
            ``example_text_content`` is also given; otherwise the audio is
            written to ``./output.wav``.

    Raises:
        KeyError: If no example text is given and ``lan`` is not a key of
            ``texts``.
        ValueError: If the pipeline does not yield exactly one result.
    """

    def inference(text, text_lang,
                  ref_audio_path, prompt_text,
                  prompt_lang, top_k,
                  top_p, temperature,
                  text_split_method, batch_size,
                  speed_factor, ref_text_free,
                  split_bucket, fragment_interval,
                  seed):
        """Run the pipeline and yield ``(item, seed_used)`` for each result."""
        # -1, "" and None all mean "pick a random 32-bit seed".
        actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32)
        inputs = {
            "text": text,
            "text_lang": text_lang,
            "ref_audio_path": ref_audio_path,
            # In reference-text-free mode the prompt text is blanked out.
            "prompt_text": prompt_text if not ref_text_free else "",
            "prompt_lang": prompt_lang,
            "top_k": top_k,
            "top_p": top_p,
            "temperature": temperature,
            "text_split_method": text_split_method,
            "batch_size": int(batch_size),
            "speed_factor": float(speed_factor),
            "split_bucket": split_bucket,
            "return_fragment": False,
            "fragment_interval": fragment_interval,
            "seed": actual_seed,
        }
        print(inputs)
        # tts_pipline is bound in the enclosing scope before this generator
        # is first advanced.
        for item in tts_pipline.run(inputs):
            yield item, actual_seed

    device = "cpu"
    is_half = False

    # Explicit UTF-8 so the transcript (typically Chinese) is not decoded
    # with the platform's locale encoding.
    with open(ref_prompt_path, "r", encoding="utf-8") as f:
        prompt_text = f.read()

    tts_config = TTS_Config(tts_config_file)
    tts_config.device = device
    tts_config.is_half = is_half
    # Only override the paths the caller actually supplied.
    if gpt_path is not None:
        tts_config.t2s_weights_path = gpt_path
    if sovits_path is not None:
        tts_config.vits_weights_path = sovits_path
    if cnhubert_base_path is not None:
        # Attribute name (including its spelling) is kept exactly as the
        # config object is addressed in the original code.
        tts_config.cnhuhbert_base_path = cnhubert_base_path
    if bert_path is not None:
        tts_config.bert_base_path = bert_path

    tts_pipline = TTS(tts_config)

    text = example_text_content or texts[lan]
    if lan == "zh":
        text_language = "all_zh"
        prompt_language = "zh"
    else:
        text_language = "en"
        prompt_language = "en"

    batch_size = 100  # inference batch size
    speed_factor = 1.0  # control speed of output audio
    top_k = 5  # gpt
    top_p = 1
    temperature = 1
    how_to_cut = "cut4"  # cut method
    ref_text_free = False
    split_bucket = True  # suggest on
    fragment_interval = 0.07  # interval between every sentence
    seed = 233333  # seed

    # The single-element unpacking drains the generator and insists on
    # exactly one (item, seed) pair.
    [output] = inference(text, text_language, ref_audio_path,
                         prompt_text, prompt_language,
                         top_k, top_p, temperature,
                         how_to_cut, batch_size,
                         speed_factor, ref_text_free,
                         split_bucket, fragment_interval,
                         seed)
    result = output[0]
    # result[0] is used as the sample rate and result[1] as the audio data.
    sample_rate, audio = result[0], result[1]

    # Fall back to the default file when no destination was supplied, rather
    # than writing to a file literally named "None".
    if example_text_content and example_audio_path:
        output_path = f'{example_audio_path}'
    else:
        output_path = './output.wav'
    sf.write(output_path, audio,
             samplerate=sample_rate, subtype='PCM_16')
if __name__ == '__main__':
    # Demo entry point: synthesize the built-in sample sentence for one
    # trained voice and write it to ./output.wav.
    task_name = "yao"
    lan = "zh"

    # Paths are hard-coded relative to the working directory (they are not
    # read from the command line).
    gpt_path = f"GPT_weights/{task_name}_best_gpt.ckpt"
    sovits_path = f"SoVITS_weights/{task_name}_best_sovits.pth"
    bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
    cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
    inp_ref = f"ref/{task_name}/ref.wav"
    ref_prompt_path = f"ref/{task_name}/ref.txt"
    # NOTE(review): assumed location of the TTS inference config -- confirm
    # it matches what TTS_Config expects in this repository.
    tts_config_file = "GPT_SoVITS/configs/tts_infer.yaml"

    # init_model takes ten arguments; the last two are passed as None so the
    # built-in sample text and the default output path are used.
    init_model(
        gpt_path,
        sovits_path,
        inp_ref,
        ref_prompt_path,
        lan,
        bert_path,
        cnhubert_base_path,
        tts_config_file,
        example_text_content=None,
        example_audio_path=None,
    )