import os
import warnings
from typing import Dict, List, Optional, Union

import gradio as gr
import numpy as np
import torch
from gradio.processing_utils import convert_to_16_bit_wav

import utils
from infer import get_net_g, infer
from models import SynthesizerTrn
from models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra

from .constants import (
    DEFAULT_ASSIST_TEXT_WEIGHT,
    DEFAULT_LENGTH,
    DEFAULT_LINE_SPLIT,
    DEFAULT_NOISE,
    DEFAULT_NOISEW,
    DEFAULT_SDP_RATIO,
    DEFAULT_SPLIT_INTERVAL,
    DEFAULT_STYLE,
    DEFAULT_STYLE_WEIGHT,
)
from .log import logger


class Model:
    """Holds a single TTS model: weights, hyperparameters, and style vectors."""

    def __init__(
        self, model_path: str, config_path: str, style_vec_path: str, device: str
    ):
        self.model_path: str = model_path
        self.config_path: str = config_path
        self.device: str = device
        self.style_vec_path: str = style_vec_path
        self.hps: utils.HParams = utils.get_hparams_from_file(self.config_path)
        self.spk2id: Dict[str, int] = self.hps.data.spk2id
        self.id2spk: Dict[int, str] = {v: k for k, v in self.spk2id.items()}
        self.num_styles: int = self.hps.data.num_styles
        if hasattr(self.hps.data, "style2id"):
            self.style2id: Dict[str, int] = self.hps.data.style2id
        else:
            self.style2id: Dict[str, int] = {str(i): i for i in range(self.num_styles)}
        if len(self.style2id) != self.num_styles:
            raise ValueError(
                f"Number of styles ({self.num_styles}) does not match the number of entries in style2id ({len(self.style2id)})"
            )
        self.style_vectors: np.ndarray = np.load(self.style_vec_path)
        if self.style_vectors.shape[0] != self.num_styles:
            raise ValueError(
                f"Number of styles ({self.num_styles}) does not match the number of style vectors ({self.style_vectors.shape[0]})"
            )
        # The generator network is loaded lazily on first inference.
        self.net_g: Union[SynthesizerTrn, SynthesizerTrnJPExtra, None] = None

    def load_net_g(self):
        self.net_g = get_net_g(
            model_path=self.model_path,
            version=self.hps.version,
            device=self.device,
            hps=self.hps,
        )

    def get_style_vector(self, style_id: int, weight: float = 1.0) -> np.ndarray:
        # Interpolate between the mean style vector (index 0) and the target style:
        # weight=0.0 yields the mean, weight=1.0 yields the target style unchanged,
        # and weights above 1.0 exaggerate the style relative to the mean.
        mean = self.style_vectors[0]
        style_vec = self.style_vectors[style_id]
        style_vec = mean + (style_vec - mean) * weight
        return style_vec

    def get_style_vector_from_audio(
        self, audio_path: str, weight: float = 1.0
    ) -> np.ndarray:
        # Extract a style embedding from a reference audio file and blend it
        # with the mean style vector using the same weighting as above.
        from style_gen import get_style_vector

        xvec = get_style_vector(audio_path)
        mean = self.style_vectors[0]
        xvec = mean + (xvec - mean) * weight
        return xvec

    def infer(
        self,
        text: str,
        language: str = "JP",
        sid: int = 0,
        reference_audio_path: Optional[str] = None,
        sdp_ratio: float = DEFAULT_SDP_RATIO,
        noise: float = DEFAULT_NOISE,
        noisew: float = DEFAULT_NOISEW,
        length: float = DEFAULT_LENGTH,
        line_split: bool = DEFAULT_LINE_SPLIT,
        split_interval: float = DEFAULT_SPLIT_INTERVAL,
        assist_text: Optional[str] = None,
        assist_text_weight: float = DEFAULT_ASSIST_TEXT_WEIGHT,
        use_assist_text: bool = False,
        style: str = DEFAULT_STYLE,
        style_weight: float = DEFAULT_STYLE_WEIGHT,
        given_tone: Optional[list[int]] = None,
    ) -> tuple[int, np.ndarray]:
        """Synthesize `text` and return (sampling_rate, audio) with 16-bit audio data."""
        logger.info(f"Start generating audio data from text:\n{text}")
        if language != "JP" and self.hps.version.endswith("JP-Extra"):
            raise ValueError(
                "The model is trained with JP-Extra, but the language is not JP"
            )
        if reference_audio_path == "":
            reference_audio_path = None
        if assist_text == "" or not use_assist_text:
            assist_text = None
        if self.net_g is None:
            self.load_net_g()
        # Pick the style vector either from a named style or from a reference audio file.
        if reference_audio_path is None:
            style_id = self.style2id[style]
            style_vector = self.get_style_vector(style_id, style_weight)
        else:
            style_vector = self.get_style_vector_from_audio(
                reference_audio_path, style_weight
            )
        if not line_split:
            with torch.no_grad():
                audio = infer(
                    text=text,
                    sdp_ratio=sdp_ratio,
                    noise_scale=noise,
                    noise_scale_w=noisew,
                    length_scale=length,
                    sid=sid,
                    language=language,
                    hps=self.hps,
                    net_g=self.net_g,
                    device=self.device,
                    assist_text=assist_text,
                    assist_text_weight=assist_text_weight,
                    style_vec=style_vector,
                    given_tone=given_tone,
                )
        else:
            # Synthesize each non-empty line separately and join them with
            # `split_interval` seconds of silence in between.
            texts = [t for t in text.split("\n") if t != ""]
            audios = []
            with torch.no_grad():
                for i, t in enumerate(texts):
                    audios.append(
                        infer(
                            text=t,
                            sdp_ratio=sdp_ratio,
                            noise_scale=noise,
                            noise_scale_w=noisew,
                            length_scale=length,
                            sid=sid,
                            language=language,
                            hps=self.hps,
                            net_g=self.net_g,
                            device=self.device,
                            assist_text=assist_text,
                            assist_text_weight=assist_text_weight,
                            style_vec=style_vector,
                        )
                    )
                    if i != len(texts) - 1:
                        # Silence at the model's output sampling rate
                        # (previously hardcoded to 44100).
                        audios.append(
                            np.zeros(int(self.hps.data.sampling_rate * split_interval))
                        )
            audio = np.concatenate(audios)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            audio = convert_to_16_bit_wav(audio)
        logger.info("Audio data generated successfully")
        return (self.hps.data.sampling_rate, audio)
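

# Illustrative usage of Model (a minimal sketch, kept as comments so nothing runs
# at import time). The file paths below are hypothetical placeholders; the actual
# directory layout depends on your setup.
#
#   model = Model(
#       model_path="model_assets/example/example.safetensors",    # hypothetical path
#       config_path="model_assets/example/config.json",           # hypothetical path
#       style_vec_path="model_assets/example/style_vectors.npy",  # hypothetical path
#       device="cuda" if torch.cuda.is_available() else "cpu",
#   )
#   sr, audio = model.infer("こんにちは。", style=DEFAULT_STYLE)
#   # `sr` is hps.data.sampling_rate; `audio` is a NumPy array of 16-bit samples.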


class ModelHolder:
    """Tracks model directories under `root_dir` and the currently loaded Model."""

    def __init__(self, root_dir: str, device: str):
        self.root_dir: str = root_dir
        self.device: str = device
        self.model_files_dict: Dict[str, List[str]] = {}
        self.current_model: Optional[Model] = None
        self.model_names: List[str] = []
        self.models: List[Model] = []
        self.refresh()

    def refresh(self):
        # Rescan `root_dir` for model directories containing weight files.
        self.model_files_dict = {}
        self.model_names = []
        self.current_model = None
        model_dirs = [
            d
            for d in os.listdir(self.root_dir)
            if os.path.isdir(os.path.join(self.root_dir, d))
        ]
        for model_name in model_dirs:
            model_dir = os.path.join(self.root_dir, model_name)
            model_files = [
                os.path.join(model_dir, f)
                for f in os.listdir(model_dir)
                if f.endswith((".pth", ".pt", ".safetensors"))
            ]
            if len(model_files) == 0:
                logger.warning(
                    f"No model files found in {self.root_dir}/{model_name}, skipping it"
                )
                continue
            self.model_files_dict[model_name] = model_files
            self.model_names.append(model_name)

    def load_model_gr(
        self, model_name: str, model_path: str
    ) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]:
        if model_name not in self.model_files_dict:
            raise ValueError(f"Model `{model_name}` is not found")
        if model_path not in self.model_files_dict[model_name]:
            raise ValueError(f"Model file `{model_path}` is not found")
        if (
            self.current_model is not None
            and self.current_model.model_path == model_path
        ):
            # Already loaded; just rebuild the dropdown/button states.
            speakers = list(self.current_model.spk2id.keys())
            styles = list(self.current_model.style2id.keys())
            return (
                gr.Dropdown(choices=styles, value=styles[0]),
                gr.Button(interactive=True, value="音声合成"),  # "Synthesize"
                gr.Dropdown(choices=speakers, value=speakers[0]),
            )
        self.current_model = Model(
            model_path=model_path,
            config_path=os.path.join(self.root_dir, model_name, "config.json"),
            style_vec_path=os.path.join(
                self.root_dir, model_name, "style_vectors.npy"
            ),
            device=self.device,
        )
        speakers = list(self.current_model.spk2id.keys())
        styles = list(self.current_model.style2id.keys())
        return (
            gr.Dropdown(choices=styles, value=styles[0]),
            gr.Button(interactive=True, value="音声合成"),  # "Synthesize"
            gr.Dropdown(choices=speakers, value=speakers[0]),
        )

    def update_model_files_gr(self, model_name: str) -> gr.Dropdown:
        model_files = self.model_files_dict[model_name]
        return gr.Dropdown(choices=model_files, value=model_files[0])

    def update_model_names_gr(self) -> tuple[gr.Dropdown, gr.Dropdown, gr.Button]:
        self.refresh()
        initial_model_name = self.model_names[0]
        initial_model_files = self.model_files_dict[initial_model_name]
        return (
            gr.Dropdown(choices=self.model_names, value=initial_model_name),
            gr.Dropdown(choices=initial_model_files, value=initial_model_files[0]),
            gr.Button(interactive=False),  # For tts_button
        )
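

# A minimal smoke-test sketch for ModelHolder. The directory name "model_assets"
# and the device selection are assumptions, not part of this module's contract;
# adjust them to your layout before running.
if __name__ == "__main__":
    holder = ModelHolder(
        root_dir="model_assets",  # hypothetical root containing one folder per model
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    # refresh() has already populated model_names / model_files_dict in __init__.
    for name in holder.model_names:
        logger.info(f"Found model '{name}' with files: {holder.model_files_dict[name]}")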