Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| import traceback | |
| from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt | |
| import torch | |
| import torchaudio | |
| from TTS.tts.configs.xtts_config import XttsConfig | |
| from TTS.tts.models.xtts import Xtts | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| def load_fine_tuned_xtts_v2(config_path, checkpoint_path, reference_audio_path): | |
| """ | |
| Load the fine-tuned XTTS v2 model and compute speaker latents. | |
| Args: | |
| config_path (str): Path to the configuration file. | |
| Example: "path/to/config.json" | |
| checkpoint_path (str): Path to the checkpoint directory. | |
| Example: "path/to/checkpoint/" | |
| reference_audio_path (str): Path to the reference audio file. | |
| Example: "path/to/reference.wav" | |
| Returns: | |
| tuple: A tuple containing the model, gpt_cond_latent, and speaker_embedding. | |
| Example: (model, gpt_cond_latent, speaker_embedding) | |
| """ | |
| print("Loading model...") | |
| config = XttsConfig() | |
| config.load_json(config_path) | |
| model = Xtts.init_from_config(config) | |
| model.load_checkpoint(config, checkpoint_dir=checkpoint_path, use_deepspeed=True) | |
| model.cuda() | |
| print("Computing speaker latents...") | |
| gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=[reference_audio_path]) | |
| return model, gpt_cond_latent, speaker_embedding | |
| def Inference(model, gpt_cond_latent, speaker_embedding,path_to_save,text, temperature=0.7): | |
| """ | |
| Perform inference using the fine-tuned XTTS v2 model. | |
| Args: | |
| model (Xtts): The XTTS v2 model. | |
| Example: model, gpt_cond_latent, speaker_embedding = load_fine_tuned_xtts_v2(config_path, checkpoint_path, reference_audio_path) | |
| gpt_cond_latent (torch.Tensor): GPT conditioning latent vectors. | |
| speaker_embedding (torch.Tensor): Speaker embedding vectors. | |
| path_to_save (str): Path to save the generated audio. | |
| Example: "path/to/output.wav" | |
| text (str): The input text for synthesis. | |
| Example: "Hello, world!" | |
| temperature (float, optional): Sampling temperature. Default is 0.7. | |
| Example: 0.7 | |
| Returns: | |
| None | |
| """ | |
| print("Inference...") | |
| out = model.inference( | |
| text, | |
| gpt_cond_latent, | |
| speaker_embedding, | |
| temperature, # Add custom parameters here # 3 | |
| ) | |
| torchaudio.save(path_to_save, torch.tensor(out["wav"]).unsqueeze(0), 24000) | |
| #model, gpt_cond_latent, speaker_embedding = load_fine_tuned_xtts_v2("C:/tmp/xtts_ft/run/training/GPT_XTTS_FT-April-02-2024_05+08PM-0000000/config.json", "C:/tmp/xtts_ft/run/training/GPT_XTTS_FT-April-02-2024_05+08PM-0000000/best_model_72.pth", "old_man_segments/wavs/segment_10.wav") | |
| class xtts_v2_Model(): | |
| """ | |
| A class to handle training of the XTTS v2 model. | |
| Args: | |
| train_csv_path (str): Path to the training CSV file. | |
| Example: "path/to/train.csv" | |
| eval_csv_path (str): Path to the evaluation CSV file. | |
| Example: "path/to/eval.csv" | |
| num_epochs (int): Number of training epochs. | |
| Example: 10 | |
| batch_size (int): Size of each training batch. | |
| Example: 4 | |
| grad_acumm (int): Gradient accumulation steps. | |
| Example: 1 | |
| output_path (str): Path to save the trained model outputs. | |
| Example: "path/to/output/" | |
| max_audio_length (int): Maximum allowed length of audio for training in seconds. | |
| Example: 10 | |
| language (str, optional): Language of the audio files, either 'en' for English or 'es' for Spanish. Default is "en". | |
| Example: "en" | |
| """ | |
| def __init__(self, train_csv_path, eval_csv_path, num_epochs, batch_size, grad_acumm, output_path, max_audio_length, language="en"): | |
| self.train_csv_path = train_csv_path | |
| self.eval_csv_path = eval_csv_path | |
| self.num_epochs = num_epochs | |
| self.batch_size = batch_size | |
| self.grad_acumm = grad_acumm | |
| self.output_path = output_path | |
| self.max_audio_length = max_audio_length | |
| self.language = language | |
| self.config_path = None | |
| self.original_xtts_checkpoint = None | |
| self.vocab_file = None | |
| self.exp_path = None | |
| self.speaker_wav = None | |
| def train_model(self): | |
| """ | |
| Train the XTTS v2 model. | |
| Returns: | |
| tuple: A tuple containing a status message, config_path, vocab_file, fine-tuned XTTS checkpoint, and speaker wav file. | |
| Example: ("Model training done!", "path/to/config.json", "path/to/vocab.json", "path/to/best_model.pth", "path/to/speaker.wav") | |
| """ | |
| #clear_gpu_cache() | |
| if not self.train_csv_path or not self.eval_csv_path: | |
| return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", "" | |
| try: | |
| # convert seconds to waveform frames | |
| max_audio_length = int(max_audio_length * 22050) | |
| self.config_path, self.original_xtts_checkpoint, self.vocab_file, self.exp_path, self.speaker_wav = train_gpt(self.language, self.num_epochs, self.batch_size, self.grad_acumm, self.train_csv_path, self.eval_csv_path, output_path=self.output_path, max_audio_length=max_audio_length) | |
| except: | |
| traceback.print_exc() | |
| error = traceback.format_exc() | |
| return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", "" | |
| # copy original files to avoid parameters changes issues | |
| os.system(f"cp {self.config_path} {self.exp_path}") | |
| os.system(f"cp {self.vocab_file} {self.exp_path}") | |
| ft_xtts_checkpoint = os.path.join(self.exp_path, "best_model.pth") | |
| print("Model training done!") | |
| #clear_gpu_cache() | |
| return "Model training done!", self.config_path, self.vocab_file, ft_xtts_checkpoint, self.speaker_wav | |
| # example | |
| #train_meta = "C:/tmp/xtts_ft/run/training/GPT_XTTS_FT-April-02-2024_05+08PM-0000000/train.csv" | |
| #eval_meta = "C:/tmp/xtts_ft/run/training/GPT_XTTS_FT-April-02-2024_05+08PM-0000000/eval.csv" | |
| #num_epochs = 10 | |
| #batch_size = 4 | |
| #grad_acumm = 1 | |
| #out_path = "C:/tmp/xtts_ft/run/training/GPT_XTTS_FT-April-02-2024_05+08PM-0000000" | |
| #max_audio_length = 10 | |
| #lang = "en" | |
| #xtts_v2 = xtts_v2_Model(train_meta, eval_meta, num_epochs, batch_size, grad_acumm, out_path, max_audio_length, lang) | |
| #_, config_path, vocab_path, ft_xtts_checkpoint, speaker_wav = xtts_v2_Model.train_model() |