Commit da8d589 · Parent(s): c4c6bff
zhzluke96 committed

update
- modules/Denoiser/AudioDenoiser.py +140 -0
- modules/Denoiser/AudioNosiseModel.py +66 -0
- modules/Denoiser/__init__.py +0 -0
- modules/Enhancer/ResembleEnhance.py +116 -0
- modules/Enhancer/__init__.py +0 -0
- modules/SynthesizeSegments.py +147 -185
- modules/api/impl/google_api.py +0 -1
- modules/api/impl/speaker_api.py +7 -3
- modules/api/impl/ssml_api.py +11 -24
- modules/api/utils.py +0 -2
- modules/denoise.py +46 -2
- modules/generate_audio.py +1 -1
- modules/models.py +1 -9
- modules/speaker.py +30 -17
- modules/ssml_parser/SSMLParser.py +178 -0
- modules/ssml_parser/__init__.py +0 -0
- modules/ssml_parser/test_ssml_parser.py +104 -0
- modules/utils/JsonObject.py +19 -0
- modules/utils/constants.py +1 -1
- modules/webui/app.py +11 -9
- modules/webui/speaker_tab.py +250 -4
- modules/webui/spliter_tab.py +2 -1
- modules/webui/system_tab.py +15 -0
- modules/webui/tts_tab.py +98 -82
- modules/webui/webui_config.py +4 -0
- modules/webui/webui_utils.py +72 -31
- webui.py +3 -1
    	
modules/Denoiser/AudioDenoiser.py  ADDED
@@ -0,0 +1,140 @@

import logging
import math
from typing import Union
import torch
import torchaudio
from torch import nn
from audio_denoiser.helpers.torch_helper import batched_apply
from modules.Denoiser.AudioNosiseModel import load_audio_denosier_model
from audio_denoiser.helpers.audio_helper import (
    create_spectrogram,
    reconstruct_from_spectrogram,
)

_expected_t_std = 0.23
_recommended_backend = "soundfile"


# ref: https://github.com/jose-solorzano/audio-denoiser
class AudioDenoiser:
    def __init__(
        self,
        local_dir: str,
        device: Union[str, torch.device] = None,
        num_iterations: int = 100,
    ):
        super().__init__()
        if device is None:
            is_cuda = torch.cuda.is_available()
            if not is_cuda:
                logging.warning("CUDA not available. Will use CPU.")
            device = torch.device("cuda:0") if is_cuda else torch.device("cpu")
        self.device = device
        self.model = load_audio_denosier_model(dir_path=local_dir, device=device)
        self.model.eval()
        self.model_sample_rate = self.model.sample_rate
        self.scaler = self.model.scaler
        self.n_fft = self.model.n_fft
        self.segment_num_frames = self.model.num_frames
        self.num_iterations = num_iterations

    @staticmethod
    def _sp_log(spectrogram: torch.Tensor, eps=0.01):
        return torch.log(spectrogram + eps)

    @staticmethod
    def _sp_exp(log_spectrogram: torch.Tensor, eps=0.01):
        return torch.clamp(torch.exp(log_spectrogram) - eps, min=0)

    @staticmethod
    def _trimmed_dev(waveform: torch.Tensor, q: float = 0.90) -> float:
        # Expected for training data is ~0.23
        abs_waveform = torch.abs(waveform)
        quantile_value = torch.quantile(abs_waveform, q).item()
        trimmed_values = waveform[abs_waveform >= quantile_value]
        return torch.std(trimmed_values).item()

    def process_waveform(
        self,
        waveform: torch.Tensor,
        sample_rate: int,
        return_cpu_tensor: bool = False,
        auto_scale: bool = False,
    ) -> torch.Tensor:
        """
        Denoises a waveform.
        @param waveform: A waveform tensor. Use torchaudio structure.
        @param sample_rate: The sample rate of the waveform in Hz.
        @param return_cpu_tensor: Whether the returned tensor must be a CPU tensor.
        @param auto_scale: Normalize the scale of the waveform before processing. Recommended for low-volume audio.
        @return: A denoised waveform.
        """
        waveform = waveform.cpu()
        if auto_scale:
            w_t_std = self._trimmed_dev(waveform)
            waveform = waveform * _expected_t_std / w_t_std
        if sample_rate != self.model_sample_rate:
            transform = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.model_sample_rate
            )
            waveform = transform(waveform)
        hop_len = self.n_fft // 2
        spectrogram = create_spectrogram(waveform, n_fft=self.n_fft, hop_length=hop_len)
        spectrogram = spectrogram.to(self.device)
        num_a_channels = spectrogram.size(0)
        with torch.no_grad():
            results = []
            for c in range(num_a_channels):
                c_spectrogram = spectrogram[c]
                # c_spectrogram: (257, num_frames)
                fft_size, num_frames = c_spectrogram.shape
                num_segments = math.ceil(num_frames / self.segment_num_frames)
                adj_num_frames = num_segments * self.segment_num_frames
                if adj_num_frames > num_frames:
                    c_spectrogram = nn.functional.pad(
                        c_spectrogram, (0, adj_num_frames - num_frames)
                    )
                c_spectrogram = c_spectrogram.view(
                    fft_size, num_segments, self.segment_num_frames
                )
                # c_spectrogram: (257, num_segments, 32)
                c_spectrogram = torch.permute(c_spectrogram, (1, 0, 2))
                # c_spectrogram: (num_segments, 257, 32)
                log_c_spectrogram = self._sp_log(c_spectrogram)
                scaled_log_c_sp = self.scaler(log_c_spectrogram)
                pred_noise_log_sp = batched_apply(
                    self.model, scaled_log_c_sp, detached=True
                )
                log_denoised_sp = log_c_spectrogram - pred_noise_log_sp
                denoised_sp = self._sp_exp(log_denoised_sp)
                # denoised_sp: (num_segments, 257, 32)
                denoised_sp = torch.permute(denoised_sp, (1, 0, 2))
                # denoised_sp: (257, num_segments, 32)
                denoised_sp = denoised_sp.contiguous().view(1, fft_size, adj_num_frames)
                # denoised_sp: (1, 257, adj_num_frames)
                denoised_sp = denoised_sp[:, :, :num_frames]
                denoised_sp = denoised_sp.cpu()
                denoised_waveform = reconstruct_from_spectrogram(
                    denoised_sp, num_iterations=self.num_iterations
                )
                # denoised_waveform: (1, num_samples)
                results.append(denoised_waveform)
            cpu_results = torch.cat(results)
            return cpu_results if return_cpu_tensor else cpu_results.to(self.device)

    def process_audio_file(
        self, in_audio_file: str, out_audio_file: str, auto_scale: bool = False
    ):
        """
        Denoises an audio file.
        @param in_audio_file: An input audio file with a format supported by torchaudio.
        @param out_audio_file: An output audio file with a format supported by torchaudio.
        @param auto_scale: Whether the input waveform scale should be normalized before processing. Recommended for low-volume audio.
        """
        waveform, sample_rate = torchaudio.load(in_audio_file)
        denoised_waveform = self.process_waveform(
            waveform, sample_rate, return_cpu_tensor=True, auto_scale=auto_scale
        )
        torchaudio.save(
            out_audio_file, denoised_waveform, sample_rate=self.model_sample_rate
        )
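A minimal usage sketch for the class above; the model directory and audio file names are placeholders, not part of this commit:

from modules.Denoiser.AudioDenoiser import AudioDenoiser

# local_dir is assumed to hold the denoiser weights (config.json + pytorch_model.bin,
# per load_audio_denosier_model in AudioNosiseModel.py below)
denoiser = AudioDenoiser(local_dir="models/audio-denoiser")
denoiser.process_audio_file("noisy.wav", "denoised.wav", auto_scale=True)

auto_scale=True rescales low-volume input toward the trimmed deviation the model was trained on (~0.23) before denoising.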
    	
modules/Denoiser/AudioNosiseModel.py  ADDED
@@ -0,0 +1,66 @@

import torch
import torch.nn as nn

from audio_denoiser.modules.Permute import Permute
from audio_denoiser.modules.SimpleRoberta import SimpleRoberta
from audio_denoiser.modules.SpectrogramScaler import SpectrogramScaler

import json


class AudioNoiseModel(nn.Module):
    def __init__(self, config: dict):
        super(AudioNoiseModel, self).__init__()

        # Encoder layers
        self.config = config
        scaler_dict = config["scaler"]
        self.scaler = SpectrogramScaler.from_dict(scaler_dict)
        self.in_channels = config.get("in_channels", 257)
        self.roberta_hidden_size = config.get("roberta_hidden_size", 768)
        self.model1 = nn.Sequential(
            nn.Conv1d(self.in_channels, 1024, kernel_size=1),
            nn.ELU(),
            nn.Conv1d(1024, 1024, kernel_size=1),
            nn.ELU(),
            nn.Conv1d(1024, self.in_channels, kernel_size=1),
        )
        self.model2 = nn.Sequential(
            Permute(0, 2, 1),
            nn.Linear(self.in_channels, self.roberta_hidden_size),
            SimpleRoberta(num_hidden_layers=5, hidden_size=self.roberta_hidden_size),
            nn.Linear(self.roberta_hidden_size, self.in_channels),
            Permute(0, 2, 1),
        )

    @property
    def sample_rate(self) -> int:
        return self.config.get("sample_rate", 16000)

    @property
    def n_fft(self) -> int:
        return self.config.get("n_fft", 512)

    @property
    def num_frames(self) -> int:
        return self.config.get("num_frames", 32)

    def forward(self, x, use_scaler: bool = False, out_scale: float = 1.0):
        if use_scaler:
            x = self.scaler(x)
        x1 = self.model1(x)
        x2 = self.model2(x)
        x = x1 + x2
        return x * out_scale


def load_audio_denosier_model(dir_path: str, device) -> AudioNoiseModel:
    config = json.load(open(f"{dir_path}/config.json", "r"))
    model = AudioNoiseModel(config)
    model.load_state_dict(torch.load(f"{dir_path}/pytorch_model.bin"))

    model.to(device)
    model.model1.to(device)
    model.model2.to(device)

    return model
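A short loading sketch under the same assumption about the directory layout (both file names come from load_audio_denosier_model above; the directory path is hypothetical):

import torch
from modules.Denoiser.AudioNosiseModel import load_audio_denosier_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# dir_path/config.json supplies "scaler" plus optional "in_channels", "sample_rate",
# "n_fft", "num_frames"; dir_path/pytorch_model.bin holds the state_dict
model = load_audio_denosier_model(dir_path="models/audio-denoiser", device=device)
noise_pred = model(torch.randn(4, 257, 32).to(device))  # (batch, in_channels, frames)

The forward pass predicts a noise log-spectrogram for each (257, 32) segment, which AudioDenoiser subtracts from the input log-spectrogram before reconstruction.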
    	
modules/Denoiser/__init__.py  ADDED
(empty file)
    	
modules/Enhancer/ResembleEnhance.py  ADDED
@@ -0,0 +1,116 @@

import os
from typing import List
from resemble_enhance.enhancer.enhancer import Enhancer
from resemble_enhance.enhancer.hparams import HParams
from resemble_enhance.inference import inference

import torch

from modules.utils.constants import MODELS_DIR
from pathlib import Path

from threading import Lock

resemble_enhance = None
lock = Lock()


def load_enhancer(device: torch.device):
    global resemble_enhance
    with lock:
        if resemble_enhance is None:
            resemble_enhance = ResembleEnhance(device)
            resemble_enhance.load_model()
    return resemble_enhance


class ResembleEnhance:
    hparams: HParams
    enhancer: Enhancer

    def __init__(self, device: torch.device):
        self.device = device

        self.enhancer = None
        self.hparams = None

    def load_model(self):
        hparams = HParams.load(Path(MODELS_DIR) / "resemble-enhance")
        enhancer = Enhancer(hparams)
        state_dict = torch.load(
            Path(MODELS_DIR) / "resemble-enhance" / "mp_rank_00_model_states.pt",
            map_location="cpu",
        )["module"]
        enhancer.load_state_dict(state_dict)
        enhancer.eval()
        enhancer.to(self.device)
        enhancer.denoiser.to(self.device)

        self.hparams = hparams
        self.enhancer = enhancer

    @torch.inference_mode()
    def denoise(self, dwav, sr, device) -> tuple[torch.Tensor, int]:
        assert self.enhancer is not None, "Model not loaded"
        assert self.enhancer.denoiser is not None, "Denoiser not loaded"
        enhancer = self.enhancer
        return inference(model=enhancer.denoiser, dwav=dwav, sr=sr, device=device)

    @torch.inference_mode()
    def enhance(
        self,
        dwav,
        sr,
        device,
        nfe=32,
        solver="midpoint",
        lambd=0.5,
        tau=0.5,
    ) -> tuple[torch.Tensor, int]:
        assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
        assert solver in (
            "midpoint",
            "rk4",
            "euler",
        ), f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
        assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
        assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"
        assert self.enhancer is not None, "Model not loaded"
        enhancer = self.enhancer
        enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
        return inference(model=enhancer, dwav=dwav, sr=sr, device=device)


if __name__ == "__main__":
    import torchaudio
    from modules.models import load_chat_tts

    load_chat_tts()

    device = torch.device("cuda")
    ench = ResembleEnhance(device)
    ench.load_model()

    wav, sr = torchaudio.load("test.wav")

    print(wav.shape, type(wav), sr, type(sr))
    exit()

    wav = wav.squeeze(0).cuda()

    print(wav.device)

    denoised, d_sr = ench.denoise(wav.cpu(), sr, device)
    denoised = denoised.unsqueeze(0)
    print(denoised.shape)
    torchaudio.save("denoised.wav", denoised, d_sr)

    for solver in ("midpoint", "rk4", "euler"):
        for lambd in (0.1, 0.5, 0.9):
            for tau in (0.1, 0.5, 0.9):
                enhanced, e_sr = ench.enhance(
                    wav.cpu(), sr, device, solver=solver, lambd=lambd, tau=tau, nfe=128
                )
                enhanced = enhanced.unsqueeze(0)
                print(enhanced.shape)
                torchaudio.save(f"enhanced_{solver}_{lambd}_{tau}.wav", enhanced, e_sr)
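A hypothetical end-to-end sketch of the API added here, assuming the resemble-enhance checkpoint already sits under MODELS_DIR/resemble-enhance and that "input.wav" exists:

import torch
import torchaudio
from modules.Enhancer.ResembleEnhance import load_enhancer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
enhancer = load_enhancer(device)  # lock-guarded; loads the model once and caches it

dwav, sr = torchaudio.load("input.wav")
dwav = dwav.squeeze(0)  # inference() takes a 1-D waveform, as in the __main__ demo
enhanced, out_sr = enhancer.enhance(dwav, sr, device, nfe=64, solver="midpoint")
torchaudio.save("enhanced.wav", enhanced.unsqueeze(0).cpu(), out_sr)

The asserts above bound nfe to (0, 128], restrict solver to midpoint/rk4/euler, and clamp lambd and tau to [0, 1].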
    	
modules/Enhancer/__init__.py  ADDED
(empty file)
    	
        modules/SynthesizeSegments.py
    CHANGED
    
    | @@ -1,17 +1,18 @@ | |
|  | |
| 1 | 
             
            from pydub import AudioSegment
         | 
| 2 | 
            -
            from typing import  | 
| 3 | 
             
            from scipy.io.wavfile import write
         | 
| 4 | 
             
            import io
         | 
|  | |
|  | |
| 5 | 
             
            from modules.utils import rng
         | 
| 6 | 
             
            from modules.utils.audio import time_stretch, pitch_shift
         | 
| 7 | 
             
            from modules import generate_audio
         | 
| 8 | 
             
            from modules.normalization import text_normalize
         | 
| 9 | 
             
            import logging
         | 
| 10 | 
             
            import json
         | 
| 11 | 
            -
            import copy
         | 
| 12 | 
            -
            import numpy as np
         | 
| 13 |  | 
| 14 | 
            -
            from modules.speaker import Speaker
         | 
| 15 |  | 
| 16 | 
             
            logger = logging.getLogger(__name__)
         | 
| 17 |  | 
| @@ -24,7 +25,7 @@ def audio_data_to_segment(audio_data, sr): | |
| 24 | 
             
                return AudioSegment.from_file(byte_io, format="wav")
         | 
| 25 |  | 
| 26 |  | 
| 27 | 
            -
            def combine_audio_segments(audio_segments: list) -> AudioSegment:
         | 
| 28 | 
             
                combined_audio = AudioSegment.empty()
         | 
| 29 | 
             
                for segment in audio_segments:
         | 
| 30 | 
             
                    combined_audio += segment
         | 
| @@ -54,230 +55,191 @@ def to_number(value, t, default=0): | |
| 54 | 
             
                    return default
         | 
| 55 |  | 
| 56 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 57 | 
             
            class SynthesizeSegments:
         | 
| 58 | 
             
                def __init__(self, batch_size: int = 8):
         | 
| 59 | 
             
                    self.batch_size = batch_size
         | 
| 60 | 
             
                    self.batch_default_spk_seed = rng.np_rng()
         | 
| 61 | 
             
                    self.batch_default_infer_seed = rng.np_rng()
         | 
| 62 |  | 
| 63 | 
            -
                def segment_to_generate_params( | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 64 | 
             
                    if segment.get("params", None) is not None:
         | 
| 65 | 
            -
                        return segment | 
| 66 |  | 
| 67 | 
             
                    text = segment.get("text", "")
         | 
| 68 | 
             
                    is_end = segment.get("is_end", False)
         | 
| 69 |  | 
| 70 | 
             
                    text = str(text).strip()
         | 
| 71 |  | 
| 72 | 
            -
                    attrs = segment. | 
| 73 | 
            -
                    spk = attrs. | 
| 74 | 
            -
                     | 
| 75 | 
            -
             | 
| 76 | 
            -
                     | 
| 77 | 
            -
             | 
| 78 | 
            -
                     | 
| 79 | 
            -
             | 
| 80 | 
            -
             | 
| 81 | 
            -
                     | 
| 82 | 
            -
                     | 
| 83 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 84 | 
             
                    disable_normalize = attrs.get("normalize", "") == "False"
         | 
| 85 |  | 
| 86 | 
            -
                     | 
| 87 | 
            -
                        " | 
| 88 | 
            -
                         | 
| 89 | 
            -
                         | 
| 90 | 
            -
                         | 
| 91 | 
            -
                         | 
| 92 | 
            -
                         | 
| 93 | 
            -
                         | 
| 94 | 
            -
                         | 
| 95 | 
            -
                         | 
| 96 | 
            -
             | 
|  | |
| 97 |  | 
| 98 | 
             
                    if not disable_normalize:
         | 
| 99 | 
            -
                         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 100 |  | 
| 101 | 
            -
             | 
| 102 | 
            -
                     | 
| 103 | 
            -
             | 
| 104 | 
            -
                     | 
| 105 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 106 |  | 
| 107 | 
            -
             | 
|  | |
|  | |
|  | |
| 108 |  | 
| 109 | 
             
                def bucket_segments(
         | 
| 110 | 
            -
                    self, segments: List[ | 
| 111 | 
            -
                ) -> List[List[ | 
| 112 | 
            -
                     | 
| 113 | 
            -
                    buckets = {}
         | 
| 114 | 
             
                    for segment in segments:
         | 
|  | |
|  | |
|  | |
|  | |
| 115 | 
             
                        params = self.segment_to_generate_params(segment)
         | 
| 116 |  | 
| 117 | 
            -
                         | 
| 118 | 
            -
             | 
| 119 | 
            -
             | 
| 120 | 
             
                        key = json.dumps(
         | 
| 121 | 
            -
                            {k: v for k, v in  | 
| 122 | 
             
                        )
         | 
| 123 | 
             
                        if key not in buckets:
         | 
| 124 | 
             
                            buckets[key] = []
         | 
| 125 | 
             
                        buckets[key].append(segment)
         | 
| 126 |  | 
| 127 | 
            -
                     | 
| 128 | 
            -
                    bucket_list = list(buckets.values())
         | 
| 129 | 
            -
                    return bucket_list
         | 
| 130 |  | 
| 131 | 
            -
                def synthesize_segments( | 
| 132 | 
            -
                     | 
| 133 | 
            -
             | 
| 134 | 
            -
                     | 
| 135 | 
             
                    buckets = self.bucket_segments(segments)
         | 
| 136 | 
            -
                    logger.debug(f"segments len: {len(segments)}")
         | 
| 137 | 
            -
                    logger.debug(f"bucket pool size: {len(buckets)}")
         | 
| 138 | 
            -
                    for bucket in buckets:
         | 
| 139 | 
            -
                        for i in range(0, len(bucket), self.batch_size):
         | 
| 140 | 
            -
                            batch = bucket[i : i + self.batch_size]
         | 
| 141 | 
            -
                            param_arr = [
         | 
| 142 | 
            -
                                self.segment_to_generate_params(segment) for segment in batch
         | 
| 143 | 
            -
                            ]
         | 
| 144 | 
            -
                            texts = [params["text"] for params in param_arr]
         | 
| 145 | 
            -
             | 
| 146 | 
            -
                            params = param_arr[0]  # Use the first segment to get the parameters
         | 
| 147 | 
            -
                            audio_datas = generate_audio.generate_audio_batch(
         | 
| 148 | 
            -
                                texts=texts,
         | 
| 149 | 
            -
                                temperature=params["temperature"],
         | 
| 150 | 
            -
                                top_P=params["top_P"],
         | 
| 151 | 
            -
                                top_K=params["top_K"],
         | 
| 152 | 
            -
                                spk=params["spk"],
         | 
| 153 | 
            -
                                infer_seed=params["infer_seed"],
         | 
| 154 | 
            -
                                prompt1=params["prompt1"],
         | 
| 155 | 
            -
                                prompt2=params["prompt2"],
         | 
| 156 | 
            -
                                prefix=params["prefix"],
         | 
| 157 | 
            -
                            )
         | 
| 158 | 
            -
                            for idx, segment in enumerate(batch):
         | 
| 159 | 
            -
                                (sr, audio_data) = audio_datas[idx]
         | 
| 160 | 
            -
                                rate = float(segment.get("rate", "1.0"))
         | 
| 161 | 
            -
                                volume = float(segment.get("volume", "0"))
         | 
| 162 | 
            -
                                pitch = float(segment.get("pitch", "0"))
         | 
| 163 | 
            -
             | 
| 164 | 
            -
                                audio_segment = audio_data_to_segment(audio_data, sr)
         | 
| 165 | 
            -
                                audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
         | 
| 166 | 
            -
                                original_index = segments.index(
         | 
| 167 | 
            -
                                    segment
         | 
| 168 | 
            -
                                )  # Get the original index of the segment
         | 
| 169 | 
            -
                                audio_segments[original_index] = (
         | 
| 170 | 
            -
                                    audio_segment  # Place the audio_segment in the correct position
         | 
| 171 | 
            -
                                )
         | 
| 172 | 
            -
             | 
| 173 | 
            -
                    return audio_segments
         | 
| 174 |  | 
|  | |
|  | |
| 175 |  | 
| 176 | 
            -
             | 
| 177 | 
            -
                text: str,
         | 
| 178 | 
            -
                spk: int = -1,
         | 
| 179 | 
            -
                seed: int = -1,
         | 
| 180 | 
            -
                top_p: float = 0.5,
         | 
| 181 | 
            -
                top_k: int = 20,
         | 
| 182 | 
            -
                temp: float = 0.3,
         | 
| 183 | 
            -
                prompt1: str = "",
         | 
| 184 | 
            -
                prompt2: str = "",
         | 
| 185 | 
            -
                prefix: str = "",
         | 
| 186 | 
            -
                enable_normalize=True,
         | 
| 187 | 
            -
                is_end: bool = False,
         | 
| 188 | 
            -
            ) -> AudioSegment:
         | 
| 189 | 
            -
                if enable_normalize:
         | 
| 190 | 
            -
                    text = text_normalize(text, is_end=is_end)
         | 
| 191 | 
            -
             | 
| 192 | 
            -
                logger.debug(f"generate segment: {text}")
         | 
| 193 | 
            -
             | 
| 194 | 
            -
                sample_rate, audio_data = generate_audio.generate_audio(
         | 
| 195 | 
            -
                    text=text,
         | 
| 196 | 
            -
                    temperature=temp if temp is not None else 0.3,
         | 
| 197 | 
            -
                    top_P=top_p if top_p is not None else 0.5,
         | 
| 198 | 
            -
                    top_K=top_k if top_k is not None else 20,
         | 
| 199 | 
            -
                    spk=spk if spk else -1,
         | 
| 200 | 
            -
                    infer_seed=seed if seed else -1,
         | 
| 201 | 
            -
                    prompt1=prompt1 if prompt1 else "",
         | 
| 202 | 
            -
                    prompt2=prompt2 if prompt2 else "",
         | 
| 203 | 
            -
                    prefix=prefix if prefix else "",
         | 
| 204 | 
            -
                )
         | 
| 205 | 
            -
             | 
| 206 | 
            -
                byte_io = io.BytesIO()
         | 
| 207 | 
            -
                write(byte_io, sample_rate, audio_data)
         | 
| 208 | 
            -
                byte_io.seek(0)
         | 
| 209 |  | 
| 210 | 
            -
             | 
| 211 | 
            -
             | 
| 212 | 
            -
             | 
| 213 | 
            -
            def synthesize_segment(segment: Dict[str, Any]) -> Union[AudioSegment, None]:
         | 
| 214 | 
            -
                if "break" in segment:
         | 
| 215 | 
            -
                    pause_segment = AudioSegment.silent(duration=segment["break"])
         | 
| 216 | 
            -
                    return pause_segment
         | 
| 217 | 
            -
             | 
| 218 | 
            -
                attrs = segment.get("attrs", {})
         | 
| 219 | 
            -
                text = segment.get("text", "")
         | 
| 220 | 
            -
                is_end = segment.get("is_end", False)
         | 
| 221 | 
            -
             | 
| 222 | 
            -
                text = str(text).strip()
         | 
| 223 | 
            -
             | 
| 224 | 
            -
                if text == "":
         | 
| 225 | 
            -
                    return None
         | 
| 226 | 
            -
             | 
| 227 | 
            -
                spk = attrs.get("spk", "")
         | 
| 228 | 
            -
                if isinstance(spk, str):
         | 
| 229 | 
            -
                    spk = int(spk)
         | 
| 230 | 
            -
                seed = to_number(attrs.get("seed", ""), int, -1)
         | 
| 231 | 
            -
                top_k = to_number(attrs.get("top_k", ""), int, None)
         | 
| 232 | 
            -
                top_p = to_number(attrs.get("top_p", ""), float, None)
         | 
| 233 | 
            -
                temp = to_number(attrs.get("temp", ""), float, None)
         | 
| 234 | 
            -
             | 
| 235 | 
            -
                prompt1 = attrs.get("prompt1", "")
         | 
| 236 | 
            -
                prompt2 = attrs.get("prompt2", "")
         | 
| 237 | 
            -
                prefix = attrs.get("prefix", "")
         | 
| 238 | 
            -
                disable_normalize = attrs.get("normalize", "") == "False"
         | 
| 239 | 
            -
             | 
| 240 | 
            -
                audio_segment = generate_audio_segment(
         | 
| 241 | 
            -
                    text,
         | 
| 242 | 
            -
                    enable_normalize=not disable_normalize,
         | 
| 243 | 
            -
                    spk=spk,
         | 
| 244 | 
            -
                    seed=seed,
         | 
| 245 | 
            -
                    top_k=top_k,
         | 
| 246 | 
            -
                    top_p=top_p,
         | 
| 247 | 
            -
                    temp=temp,
         | 
| 248 | 
            -
                    prompt1=prompt1,
         | 
| 249 | 
            -
                    prompt2=prompt2,
         | 
| 250 | 
            -
                    prefix=prefix,
         | 
| 251 | 
            -
                    is_end=is_end,
         | 
| 252 | 
            -
                )
         | 
| 253 | 
            -
             | 
| 254 | 
            -
                rate = float(attrs.get("rate", "1.0"))
         | 
| 255 | 
            -
                volume = float(attrs.get("volume", "0"))
         | 
| 256 | 
            -
                pitch = float(attrs.get("pitch", "0"))
         | 
| 257 | 
            -
             | 
| 258 | 
            -
                audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
         | 
| 259 |  | 
| 260 | 
            -
             | 
| 261 |  | 
| 262 |  | 
| 263 | 
             
            # 示例使用
         | 
| 264 | 
             
            if __name__ == "__main__":
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 265 | 
             
                ssml_segments = [
         | 
| 266 | 
            -
                     | 
| 267 | 
            -
             | 
| 268 | 
            -
             | 
| 269 | 
            -
                     | 
| 270 | 
            -
                    {
         | 
| 271 | 
            -
                        "text": "大🍉,一个大🍉,嘿,你的感觉真的很奇妙  [lbreak]",
         | 
| 272 | 
            -
                        "attrs": {"spk": 2, "temp": 0.1, "seed": 42},
         | 
| 273 | 
            -
                    },
         | 
| 274 | 
            -
                    {
         | 
| 275 | 
            -
                        "text": "大🍌,一条大🍌,嘿,你的感觉真的很奇妙  [lbreak]",
         | 
| 276 | 
            -
                        "attrs": {"spk": 2, "temp": 0.3, "seed": 42},
         | 
| 277 | 
            -
                    },
         | 
| 278 | 
             
                ]
         | 
| 279 |  | 
| 280 | 
             
                synthesizer = SynthesizeSegments(batch_size=2)
         | 
| 281 | 
             
                audio_segments = synthesizer.synthesize_segments(ssml_segments)
         | 
|  | |
| 282 | 
             
                combined_audio = combine_audio_segments(audio_segments)
         | 
| 283 | 
             
                combined_audio.export("output.wav", format="wav")
         | 
|  | |
| 1 | 
            +
            from box import Box
         | 
| 2 | 
             
            from pydub import AudioSegment
         | 
| 3 | 
            +
            from typing import List, Union
         | 
| 4 | 
             
            from scipy.io.wavfile import write
         | 
| 5 | 
             
            import io
         | 
| 6 | 
            +
            from modules.api.utils import calc_spk_style
         | 
| 7 | 
            +
            from modules.ssml_parser.SSMLParser import SSMLSegment, SSMLBreak, SSMLContext
         | 
| 8 | 
             
            from modules.utils import rng
         | 
| 9 | 
             
            from modules.utils.audio import time_stretch, pitch_shift
         | 
| 10 | 
             
            from modules import generate_audio
         | 
| 11 | 
             
            from modules.normalization import text_normalize
         | 
| 12 | 
             
            import logging
         | 
| 13 | 
             
            import json
         | 
|  | |
|  | |
| 14 |  | 
| 15 | 
            +
            from modules.speaker import Speaker, speaker_mgr
         | 
| 16 |  | 
| 17 | 
             
            logger = logging.getLogger(__name__)
         | 
| 18 |  | 
|  | |
| 25 | 
             
                return AudioSegment.from_file(byte_io, format="wav")
         | 
| 26 |  | 
| 27 |  | 
| 28 | 
            +
            def combine_audio_segments(audio_segments: list[AudioSegment]) -> AudioSegment:
         | 
| 29 | 
             
                combined_audio = AudioSegment.empty()
         | 
| 30 | 
             
                for segment in audio_segments:
         | 
| 31 | 
             
                    combined_audio += segment
         | 
|  | |
| 55 | 
             
                    return default
         | 
| 56 |  | 
| 57 |  | 
| 58 | 
            +
            class TTSAudioSegment(Box):
         | 
| 59 | 
            +
                text: str
         | 
| 60 | 
            +
                temperature: float
         | 
| 61 | 
            +
                top_P: float
         | 
| 62 | 
            +
                top_K: int
         | 
| 63 | 
            +
                spk: int
         | 
| 64 | 
            +
                infer_seed: int
         | 
| 65 | 
            +
                prompt1: str
         | 
| 66 | 
            +
                prompt2: str
         | 
| 67 | 
            +
                prefix: str
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                _type: str
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                def __init__(self, *args, **kwargs):
         | 
| 72 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 73 | 
            +
             | 
| 74 | 
            +
             | 
| 75 | 
             
            class SynthesizeSegments:
         | 
| 76 | 
             
                def __init__(self, batch_size: int = 8):
         | 
| 77 | 
             
                    self.batch_size = batch_size
         | 
| 78 | 
             
                    self.batch_default_spk_seed = rng.np_rng()
         | 
| 79 | 
             
                    self.batch_default_infer_seed = rng.np_rng()
         | 
| 80 |  | 
| 81 | 
            +
                def segment_to_generate_params(
         | 
| 82 | 
            +
                    self, segment: Union[SSMLSegment, SSMLBreak]
         | 
| 83 | 
            +
                ) -> TTSAudioSegment:
         | 
| 84 | 
            +
                    if isinstance(segment, SSMLBreak):
         | 
| 85 | 
            +
                        return TTSAudioSegment(_type="break")
         | 
| 86 | 
            +
             | 
| 87 | 
             
                    if segment.get("params", None) is not None:
         | 
| 88 | 
            +
                        return TTSAudioSegment(**segment.get("params"))
         | 
| 89 |  | 
| 90 | 
             
                    text = segment.get("text", "")
         | 
| 91 | 
             
                    is_end = segment.get("is_end", False)
         | 
| 92 |  | 
| 93 | 
             
                    text = str(text).strip()
         | 
| 94 |  | 
| 95 | 
            +
                    attrs = segment.attrs
         | 
| 96 | 
            +
                    spk = attrs.spk
         | 
| 97 | 
            +
                    style = attrs.style
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    ss_params = calc_spk_style(spk, style)
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                    if "spk" in ss_params:
         | 
| 102 | 
            +
                        spk = ss_params["spk"]
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    seed = to_number(attrs.seed, int, ss_params.get("seed") or -1)
         | 
| 105 | 
            +
                    top_k = to_number(attrs.top_k, int, None)
         | 
| 106 | 
            +
                    top_p = to_number(attrs.top_p, float, None)
         | 
| 107 | 
            +
                    temp = to_number(attrs.temp, float, None)
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    prompt1 = attrs.prompt1 or ss_params.get("prompt1")
         | 
| 110 | 
            +
                    prompt2 = attrs.prompt2 or ss_params.get("prompt2")
         | 
| 111 | 
            +
                    prefix = attrs.prefix or ss_params.get("prefix")
         | 
| 112 | 
             
                    disable_normalize = attrs.get("normalize", "") == "False"
         | 
| 113 |  | 
| 114 | 
            +
                    seg = TTSAudioSegment(
         | 
| 115 | 
            +
                        _type="voice",
         | 
| 116 | 
            +
                        text=text,
         | 
| 117 | 
            +
                        temperature=temp if temp is not None else 0.3,
         | 
| 118 | 
            +
                        top_P=top_p if top_p is not None else 0.5,
         | 
| 119 | 
            +
                        top_K=top_k if top_k is not None else 20,
         | 
| 120 | 
            +
                        spk=spk if spk else -1,
         | 
| 121 | 
            +
                        infer_seed=seed if seed else -1,
         | 
| 122 | 
            +
                        prompt1=prompt1 if prompt1 else "",
         | 
| 123 | 
            +
                        prompt2=prompt2 if prompt2 else "",
         | 
| 124 | 
            +
                        prefix=prefix if prefix else "",
         | 
| 125 | 
            +
                    )
         | 
| 126 |  | 
| 127 | 
             
                    if not disable_normalize:
         | 
| 128 | 
            +
                        seg.text = text_normalize(text, is_end=is_end)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    # NOTE 每个batch的默认seed保证前后一致即使是没设置spk的情况
         | 
| 131 | 
            +
                    if seg.spk == -1:
         | 
| 132 | 
            +
                        seg.spk = self.batch_default_spk_seed
         | 
| 133 | 
            +
                    if seg.infer_seed == -1:
         | 
| 134 | 
            +
                        seg.infer_seed = self.batch_default_infer_seed
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    return seg
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                def process_break_segments(
         | 
| 139 | 
            +
                    self,
         | 
| 140 | 
            +
                    src_segments: List[SSMLBreak],
         | 
| 141 | 
            +
                    bucket_segments: List[SSMLBreak],
         | 
| 142 | 
            +
                    audio_segments: List[AudioSegment],
         | 
| 143 | 
            +
                ):
         | 
| 144 | 
            +
                    for segment in bucket_segments:
         | 
| 145 | 
            +
                        index = src_segments.index(segment)
         | 
| 146 | 
            +
                        audio_segments[index] = AudioSegment.silent(
         | 
| 147 | 
            +
                            duration=int(segment.attrs.duration)
         | 
| 148 | 
            +
                        )
         | 
| 149 |  | 
| 150 | 
            +
                def process_voice_segments(
         | 
| 151 | 
            +
                    self,
         | 
| 152 | 
            +
                    src_segments: List[SSMLSegment],
         | 
| 153 | 
            +
                    bucket: List[SSMLSegment],
         | 
| 154 | 
            +
                    audio_segments: List[AudioSegment],
         | 
| 155 | 
            +
                ):
         | 
| 156 | 
            +
                    for i in range(0, len(bucket), self.batch_size):
         | 
| 157 | 
            +
                        batch = bucket[i : i + self.batch_size]
         | 
| 158 | 
            +
                        param_arr = [self.segment_to_generate_params(segment) for segment in batch]
         | 
| 159 | 
            +
                        texts = [params.text for params in param_arr]
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                        params = param_arr[0]
         | 
| 162 | 
            +
                        audio_datas = generate_audio.generate_audio_batch(
         | 
| 163 | 
            +
                            texts=texts,
         | 
| 164 | 
            +
                            temperature=params.temperature,
         | 
| 165 | 
            +
                            top_P=params.top_P,
         | 
| 166 | 
            +
                            top_K=params.top_K,
         | 
| 167 | 
            +
                            spk=params.spk,
         | 
| 168 | 
            +
                            infer_seed=params.infer_seed,
         | 
| 169 | 
            +
                            prompt1=params.prompt1,
         | 
| 170 | 
            +
                            prompt2=params.prompt2,
         | 
| 171 | 
            +
                            prefix=params.prefix,
         | 
| 172 | 
            +
                        )
         | 
| 173 | 
            +
                        for idx, segment in enumerate(batch):
         | 
| 174 | 
            +
                            sr, audio_data = audio_datas[idx]
         | 
| 175 | 
            +
                            rate = float(segment.get("rate", "1.0"))
         | 
| 176 | 
            +
                            volume = float(segment.get("volume", "0"))
         | 
| 177 | 
            +
                            pitch = float(segment.get("pitch", "0"))
         | 
| 178 |  | 
| 179 | 
            +
                            audio_segment = audio_data_to_segment(audio_data, sr)
         | 
| 180 | 
            +
                            audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
         | 
| 181 | 
            +
                            original_index = src_segments.index(segment)
         | 
| 182 | 
            +
                            audio_segments[original_index] = audio_segment
         | 
| 183 |  | 
| 184 | 
             
                def bucket_segments(
         | 
| 185 | 
            +
                    self, segments: List[Union[SSMLSegment, SSMLBreak]]
         | 
| 186 | 
            +
                ) -> List[List[Union[SSMLSegment, SSMLBreak]]]:
         | 
| 187 | 
            +
                    buckets = {"<break>": []}
         | 
|  | |
| 188 | 
             
                    for segment in segments:
         | 
| 189 | 
            +
                        if isinstance(segment, SSMLBreak):
         | 
| 190 | 
            +
                            buckets["<break>"].append(segment)
         | 
| 191 | 
            +
                            continue
         | 
| 192 | 
            +
             | 
| 193 | 
             
                        params = self.segment_to_generate_params(segment)
         | 
| 194 |  | 
| 195 | 
            +
                        if isinstance(params.spk, Speaker):
         | 
| 196 | 
            +
                            params.spk = str(params.spk.id)
         | 
| 197 | 
            +
             | 
| 198 | 
             
                        key = json.dumps(
         | 
| 199 | 
            +
                            {k: v for k, v in params.items() if k != "text"}, sort_keys=True
         | 
| 200 | 
             
                        )
         | 
| 201 | 
             
                        if key not in buckets:
         | 
| 202 | 
             
                            buckets[key] = []
         | 
| 203 | 
             
                        buckets[key].append(segment)
         | 
| 204 |  | 
| 205 | 
            +
                    return buckets
         | 
|  | |
|  | |
| 206 |  | 
| 207 | 
            +
                def synthesize_segments(
         | 
| 208 | 
            +
                    self, segments: List[Union[SSMLSegment, SSMLBreak]]
         | 
| 209 | 
            +
                ) -> List[AudioSegment]:
         | 
| 210 | 
            +
                    audio_segments = [None] * len(segments)
         | 
| 211 | 
             
                    buckets = self.bucket_segments(segments)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 212 |  | 
| 213 | 
            +
                    break_segments = buckets.pop("<break>")
         | 
| 214 | 
            +
                    self.process_break_segments(segments, break_segments, audio_segments)
         | 
| 215 |  | 
| 216 | 
            +
                    buckets = list(buckets.values())
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 217 |  | 
| 218 | 
            +
                    for bucket in buckets:
         | 
| 219 | 
            +
                        self.process_voice_segments(segments, bucket, audio_segments)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 220 |  | 
| 221 | 
            +
                    return audio_segments
         | 
| 222 |  | 
| 223 |  | 
| 224 | 
             
            # 示例使用
         | 
| 225 | 
             
            if __name__ == "__main__":
         | 
| 226 | 
            +
                ctx1 = SSMLContext()
         | 
| 227 | 
            +
                ctx1.spk = 1
         | 
| 228 | 
            +
                ctx1.seed = 42
         | 
| 229 | 
            +
                ctx1.temp = 0.1
         | 
| 230 | 
            +
                ctx2 = SSMLContext()
         | 
| 231 | 
            +
                ctx2.spk = 2
         | 
| 232 | 
            +
                ctx2.seed = 42
         | 
| 233 | 
            +
                ctx2.temp = 0.1
         | 
| 234 | 
             
                ssml_segments = [
         | 
| 235 | 
            +
                    SSMLSegment(text="大🍌,一条大🍌,嘿,你的感觉真的很奇妙", attrs=ctx1.copy()),
         | 
| 236 | 
            +
                    SSMLBreak(duration_ms=1000),
         | 
| 237 | 
            +
                    SSMLSegment(text="大🍉,一个大🍉,嘿,你的感觉真的很奇妙", attrs=ctx1.copy()),
         | 
| 238 | 
            +
                    SSMLSegment(text="大🍊,一个大🍊,嘿,你的感觉真的很奇妙", attrs=ctx2.copy()),
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 239 | 
             
                ]
         | 
| 240 |  | 
| 241 | 
             
                synthesizer = SynthesizeSegments(batch_size=2)
         | 
| 242 | 
             
                audio_segments = synthesizer.synthesize_segments(ssml_segments)
         | 
| 243 | 
            +
                print(audio_segments)
         | 
| 244 | 
             
                combined_audio = combine_audio_segments(audio_segments)
         | 
| 245 | 
             
                combined_audio.export("output.wav", format="wav")
         | 
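Worth noting in the hunk above: `bucket_segments` serializes every generation parameter except the text into a JSON key, so segments that differ only in text land in one bucket and share a single `generate_audio_batch` call. A minimal standalone sketch of that keying idea (plain dicts standing in for `segment_to_generate_params` output; not project code):

import json

params_list = [
    {"text": "hello", "spk": 1, "seed": 42},  # hypothetical params
    {"text": "world", "spk": 1, "seed": 42},  # same params, other text
    {"text": "hi",    "spk": 2, "seed": 42},  # different spk
]

buckets = {}
for p in params_list:
    # Same trick as the diff: key on everything except "text".
    key = json.dumps({k: v for k, v in p.items() if k != "text"}, sort_keys=True)
    buckets.setdefault(key, []).append(p["text"])

print(list(buckets.values()))  # [['hello', 'world'], ['hi']]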
    	
modules/api/impl/google_api.py CHANGED

@@ -18,7 +18,6 @@ from modules.ssml import parse_ssml
 from modules.SynthesizeSegments import (
     SynthesizeSegments,
     combine_audio_segments,
-    synthesize_segment,
 )
 
 from modules.api import utils as api_utils
    	
modules/api/impl/speaker_api.py CHANGED

@@ -7,11 +7,11 @@ from modules.api.Api import APIManager
 
 
 class CreateSpeaker(BaseModel):
-    seed: int
     name: str
     gender: str
     describe: str
-    tensor: list
+    tensor: list = None
+    seed: int = None
 
 
 class UpdateSpeaker(BaseModel):
@@ -76,7 +76,7 @@ def setup(app: APIManager):
                 gender=request.gender,
                 describe=request.describe,
             )
-
+        elif request.seed:
             # from seed
             speaker = speaker_mgr.create_speaker_from_seed(
                 seed=request.seed,
@@ -84,6 +84,10 @@ def setup(app: APIManager):
                 gender=request.gender,
                 describe=request.describe,
             )
+        else:
+            raise HTTPException(
+                status_code=400, detail="Missing tensor or seed in request"
+            )
         return {"message": "ok", "data": speaker.to_json()}
 
     @app.post("/v1/speaker/refresh", response_model=api_utils.BaseResponse)
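With `tensor` and `seed` now optional on `CreateSpeaker`, a client supplies exactly one of them and gets a 400 otherwise. A sketch of the two request shapes (host, port, and the create route path are assumptions; only `/v1/speaker/refresh` is visible in this hunk):

import requests

BASE = "http://localhost:7870"  # assumed host/port

# Create from a seed (tensor omitted):
requests.post(f"{BASE}/v1/speaker/create", json={  # route path assumed
    "name": "alice", "gender": "female", "describe": "demo", "seed": 42,
})

# Create from a raw embedding (seed omitted):
requests.post(f"{BASE}/v1/speaker/create", json={
    "name": "bob", "gender": "male", "describe": "demo",
    "tensor": [0.1, 0.2, 0.3],  # illustrative values, not a real embedding
})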
    	
modules/api/impl/ssml_api.py CHANGED

@@ -10,7 +10,6 @@ from modules.normalization import text_normalize
 from modules.ssml import parse_ssml
 from modules.SynthesizeSegments import (
     SynthesizeSegments,
-    synthesize_segment,
     combine_audio_segments,
 )
 
@@ -23,6 +22,8 @@ from modules.api.Api import APIManager
 class SSMLRequest(BaseModel):
     ssml: str
     format: str = "mp3"
+
+    # NOTE: 🤔 maybe this should be a system-level setting? passing it in feels odd
     batch_size: int = 4
 
 
@@ -48,29 +49,15 @@ async def synthesize_ssml(
         for seg in segments:
            seg["text"] = text_normalize(seg["text"], is_end=True)
 
-        …
-            return StreamingResponse(buffer, media_type=f"audio/{format}")
-        else:
-
-            def audio_streamer():
-                for segment in segments:
-                    audio_segment = synthesize_segment(segment=segment)
-                    buffer = io.BytesIO()
-                    audio_segment.export(buffer, format="wav")
-                    buffer.seek(0)
-                    if format == "mp3":
-                        buffer = api_utils.wav_to_mp3(buffer)
-                    yield buffer.read()
-
-            return StreamingResponse(audio_streamer(), media_type=f"audio/{format}")
+        synthesize = SynthesizeSegments(batch_size)
+        audio_segments = synthesize.synthesize_segments(segments)
+        combined_audio = combine_audio_segments(audio_segments)
+        buffer = io.BytesIO()
+        combined_audio.export(buffer, format="wav")
+        buffer.seek(0)
+        if format == "mp3":
+            buffer = api_utils.wav_to_mp3(buffer)
+        return StreamingResponse(buffer, media_type=f"audio/{format}")
 
     except Exception as e:
         import logging
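After this change every request goes through the batched path; the old per-segment streaming branch is gone. A sketch of a client call against the new shape of `SSMLRequest` (the route path and port are assumptions):

import requests

resp = requests.post(
    "http://localhost:7870/v1/ssml",  # path assumed for illustration
    json={
        "ssml": '<speak version="0.1"><voice seed="42">你好</voice></speak>',
        "format": "mp3",
        "batch_size": 4,
    },
)
with open("out.mp3", "wb") as f:
    f.write(resp.content)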
    	
modules/api/utils.py CHANGED

@@ -52,7 +52,6 @@ def to_number(value, t, default=0):
 def calc_spk_style(spk: Union[str, int], style: Union[str, int]):
     voice_attrs = {
         "spk": None,
-        "seed": None,
         "prompt1": None,
         "prompt2": None,
         "prefix": None,
@@ -85,7 +84,6 @@ def calc_spk_style(spk: Union[str, int], style: Union[str, int]):
     merge_prompt(voice_attrs, params)
 
     voice_attrs["spk"] = params.get("spk", voice_attrs.get("spk", None))
-    voice_attrs["seed"] = params.get("seed", voice_attrs.get("seed", None))
     voice_attrs["temperature"] = params.get(
         "temp", voice_attrs.get("temperature", None)
     )
    	
modules/denoise.py CHANGED

@@ -1,7 +1,51 @@
-…
+import os
+from typing import Union
+
 import torch
 import torchaudio
+from modules.Denoiser.AudioDenoiser import AudioDenoiser
+
+from modules.utils.constants import MODELS_DIR
+
+from modules.devices import devices
+
+import soundfile as sf
+
+ad: Union[AudioDenoiser, None] = None
 
 
 class TTSAudioDenoiser:
-    …
+
+    def load_ad(self):
+        global ad
+        if ad is None:
+            ad = AudioDenoiser(
+                os.path.join(
+                    MODELS_DIR,
+                    "Denoise",
+                    "audio-denoiser-512-32-v1",
+                ),
+                device=devices.device,
+            )
+            ad.model.to(devices.device)
+        return ad
+
+    def denoise(self, audio_data, sample_rate, auto_scale=False):
+        ad = self.load_ad()
+        sr = ad.model_sample_rate
+        return sr, ad.process_waveform(audio_data, sample_rate, auto_scale)
+
+
+if __name__ == "__main__":
+    tts_deno = TTSAudioDenoiser()
+    data, sr = sf.read("test.wav")
+    audio_tensor = torch.from_numpy(data).unsqueeze(0).float()
+    print(audio_tensor)
+
+    # data, sr = torchaudio.load("test.wav")
+    # print(data)
+    # data = data.to(devices.device)
+
+    sr, denoised = tts_deno.denoise(audio_data=audio_tensor, sample_rate=sr)
+    denoised = denoised.cpu()
+    torchaudio.save("denoised.wav", denoised, sample_rate=sr)
    	
modules/generate_audio.py CHANGED

@@ -79,7 +79,7 @@ def generate_audio_batch(
         params_infer_code["spk_emb"] = spk.emb
         logger.info(("spk", spk.name))
     else:
-        raise ValueError("spk must be int or Speaker")
+        raise ValueError(f"spk must be int or Speaker, but: <{type(spk)}> {spk}")
 
     logger.info(
         {
        modules/models.py
    CHANGED
    
    | @@ -37,17 +37,9 @@ def load_chat_tts_in_thread(): | |
| 37 | 
             
                logger.info("ChatTTS models loaded")
         | 
| 38 |  | 
| 39 |  | 
| 40 | 
            -
            def  | 
| 41 | 
             
                with lock:
         | 
| 42 | 
             
                    if chat_tts is None:
         | 
| 43 | 
            -
                        model_thread = threading.Thread(target=load_chat_tts_in_thread)
         | 
| 44 | 
            -
                        model_thread.start()
         | 
| 45 | 
            -
                        model_thread.join()
         | 
| 46 | 
            -
             | 
| 47 | 
            -
             | 
| 48 | 
            -
            def load_chat_tts():
         | 
| 49 | 
            -
                if chat_tts is None:
         | 
| 50 | 
            -
                    with lock:
         | 
| 51 | 
             
                        load_chat_tts_in_thread()
         | 
| 52 | 
             
                if chat_tts is None:
         | 
| 53 | 
             
                    raise Exception("Failed to load ChatTTS models")
         | 
|  | |
| 37 | 
             
                logger.info("ChatTTS models loaded")
         | 
| 38 |  | 
| 39 |  | 
| 40 | 
            +
            def load_chat_tts():
         | 
| 41 | 
             
                with lock:
         | 
| 42 | 
             
                    if chat_tts is None:
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 43 | 
             
                        load_chat_tts_in_thread()
         | 
| 44 | 
             
                if chat_tts is None:
         | 
| 45 | 
             
                    raise Exception("Failed to load ChatTTS models")
         | 
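The rewrite collapses two entry points into one and, more importantly, moves the `None` check inside the lock: the old check-then-lock order let two threads both observe `chat_tts is None` and both load. The shape of the fix in isolation:

import threading

lock = threading.Lock()
model = None

def load_model():
    global model
    with lock:             # acquire first...
        if model is None:  # ...then test, so only one thread ever loads
            model = object()  # stand-in for the real loading work
    return model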
    	
modules/speaker.py CHANGED

@@ -1,5 +1,6 @@
 import os
 from typing import Union
+from box import Box
 import torch
 
 from modules import models
@@ -16,6 +17,18 @@ def create_speaker_from_seed(seed):
 
 
 class Speaker:
+    @staticmethod
+    def from_file(file_like):
+        speaker = torch.load(file_like, map_location=torch.device("cpu"))
+        speaker.fix()
+        return speaker
+
+    @staticmethod
+    def from_tensor(tensor):
+        speaker = Speaker(seed=-2)
+        speaker.emb = tensor
+        return speaker
+
     def __init__(self, seed, name="", gender="", describe=""):
         self.id = uuid.uuid4()
         self.seed = seed
@@ -24,15 +37,20 @@ class Speaker:
         self.describe = describe
         self.emb = None
 
+        # TODO replace emb => tokens
+        self.tokens = []
+
     def to_json(self, with_emb=False):
-        return …
+        return Box(
+            **{
+                "id": str(self.id),
+                "seed": self.seed,
+                "name": self.name,
+                "gender": self.gender,
+                "describe": self.describe,
+                "emb": self.emb.tolist() if with_emb else None,
+            }
+        )
 
     def fix(self):
         is_update = False
@@ -78,14 +96,9 @@ class SpeakerManager:
         self.speakers = {}
         for speaker_file in os.listdir(self.speaker_dir):
             if speaker_file.endswith(".pt"):
-                …
+                self.speakers[speaker_file] = Speaker.from_file(
                     self.speaker_dir + speaker_file
                 )
-                self.speakers[speaker_file] = speaker
-
-                is_update = speaker.fix()
-                if is_update:
-                    torch.save(speaker, self.speaker_dir + speaker_file)
 
     def list_speakers(self):
         return list(self.speakers.values())
@@ -103,8 +116,8 @@ class SpeakerManager:
     def create_speaker_from_tensor(
         self, tensor, filename="", name="", gender="", describe=""
     ):
-        if …
-            …
+        if filename == "":
+            filename = name
         speaker = Speaker(seed=-2, name=name, gender=gender, describe=describe)
         if isinstance(tensor, torch.Tensor):
             speaker.emb = tensor
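The two new static factories give callers a uniform way to obtain a `Speaker` without touching `torch.load` or the `-2` sentinel seed directly. A sketch of both (the file path and embedding size below are illustrative, not from the diff):

import torch
from modules.speaker import Speaker

# Deserialize a saved speaker; from_file() also runs fix() on older files.
spk = Speaker.from_file("./data/speakers/alice.pt")  # illustrative path

# Wrap a raw embedding; seed is set to the -2 "from tensor" sentinel.
spk2 = Speaker.from_tensor(torch.randn(768))  # 768 is a guess at the dim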
    	
modules/ssml_parser/SSMLParser.py ADDED

@@ -0,0 +1,178 @@
+from lxml import etree
+
+
+from typing import Any, List, Dict, Union
+import logging
+
+from modules.data import styles_mgr
+from modules.speaker import speaker_mgr
+from box import Box
+import copy
+
+
+class SSMLContext(Box):
+    def __init__(self, parent=None):
+        self.parent: Union[SSMLContext, None] = parent
+
+        self.style = None
+        self.spk = None
+        self.volume = None
+        self.rate = None
+        self.pitch = None
+        # temperature
+        self.temp = None
+        self.top_p = None
+        self.top_k = None
+        self.seed = None
+        self.normalize = None
+        self.prompt1 = None
+        self.prompt2 = None
+        self.prefix = None
+
+
+class SSMLSegment(Box):
+    def __init__(self, text: str, attrs=SSMLContext()):
+        self.attrs = attrs
+        self.text = text
+        self.params = None
+
+
+class SSMLBreak:
+    def __init__(self, duration_ms: Union[str, int, float]):
+        # TODO support units other than ms
+        duration_ms = int(str(duration_ms).replace("ms", ""))
+        self.attrs = Box(**{"duration": duration_ms})
+
+
+class SSMLParser:
+    def __init__(self):
+        self.logger = logging.getLogger(__name__)
+        self.logger.debug("SSMLParser.__init__()")
+        self.resolvers = []
+
+    def resolver(self, tag: str):
+        def decorator(func):
+            self.resolvers.append((tag, func))
+            return func
+
+        return decorator
+
+    def parse(self, ssml: str) -> List[Union[SSMLSegment, SSMLBreak]]:
+        root = etree.fromstring(ssml)
+
+        root_ctx = SSMLContext()
+        segments = []
+        self.resolve(root, root_ctx, segments)
+
+        return segments
+
+    def resolve(
+        self, element: etree.Element, context: SSMLContext, segments: List[SSMLSegment]
+    ):
+        resolver = [resolver for tag, resolver in self.resolvers if tag == element.tag]
+        if len(resolver) == 0:
+            raise NotImplementedError(f"Tag {element.tag} not supported.")
+        else:
+            resolver = resolver[0]
+
+        resolver(element, context, segments, self)
+
+
+def create_ssml_parser():
+    parser = SSMLParser()
+
+    @parser.resolver("speak")
+    def tag_speak(element, context, segments, parser):
+        ctx = copy.deepcopy(context)
+
+        version = element.get("version")
+        if version != "0.1":
+            raise ValueError(f"Unsupported SSML version {version}")
+
+        for child in element:
+            parser.resolve(child, ctx, segments)
+
+    @parser.resolver("voice")
+    def tag_voice(element, context, segments, parser):
+        ctx = copy.deepcopy(context)
+
+        ctx.spk = element.get("spk", ctx.spk)
+        ctx.style = element.get("style", ctx.style)
+        ctx.volume = element.get("volume", ctx.volume)
+        ctx.rate = element.get("rate", ctx.rate)
+        ctx.pitch = element.get("pitch", ctx.pitch)
+        # temperature
+        ctx.temp = element.get("temp", ctx.temp)
+        ctx.top_p = element.get("top_p", ctx.top_p)
+        ctx.top_k = element.get("top_k", ctx.top_k)
+        ctx.seed = element.get("seed", ctx.seed)
+        ctx.normalize = element.get("normalize", ctx.normalize)
+        ctx.prompt1 = element.get("prompt1", ctx.prompt1)
+        ctx.prompt2 = element.get("prompt2", ctx.prompt2)
+        ctx.prefix = element.get("prefix", ctx.prefix)
+
+        # handle text at the start of the voice element
+        if element.text and element.text.strip():
+            segments.append(SSMLSegment(element.text.strip(), ctx))
+
+        for child in element:
+            parser.resolve(child, ctx, segments)
+
+            # handle trailing text after each child
+            if child.tail and child.tail.strip():
+                segments.append(SSMLSegment(child.tail.strip(), ctx))
+
+    @parser.resolver("break")
+    def tag_break(element, context, segments, parser):
+        time_ms = int(element.get("time", "0").replace("ms", ""))
+        segments.append(SSMLBreak(time_ms))
+
+    @parser.resolver("prosody")
+    def tag_prosody(element, context, segments, parser):
+        ctx = copy.deepcopy(context)
+
+        ctx.spk = element.get("spk", ctx.spk)
+        ctx.style = element.get("style", ctx.style)
+        ctx.volume = element.get("volume", ctx.volume)
+        ctx.rate = element.get("rate", ctx.rate)
+        ctx.pitch = element.get("pitch", ctx.pitch)
+        # temperature
+        ctx.temp = element.get("temp", ctx.temp)
+        ctx.top_p = element.get("top_p", ctx.top_p)
+        ctx.top_k = element.get("top_k", ctx.top_k)
+        ctx.seed = element.get("seed", ctx.seed)
+        ctx.normalize = element.get("normalize", ctx.normalize)
+        ctx.prompt1 = element.get("prompt1", ctx.prompt1)
+        ctx.prompt2 = element.get("prompt2", ctx.prompt2)
+        ctx.prefix = element.get("prefix", ctx.prefix)
+
+        if element.text and element.text.strip():
+            segments.append(SSMLSegment(element.text.strip(), ctx))
+
+    return parser
+
+
+if __name__ == "__main__":
+    parser = create_ssml_parser()
+
+    ssml = """
+    <speak version="0.1">
+        <voice spk="xiaoyan" style="news">
+            <prosody rate="fast">你好</prosody>
+            <break time="500ms"/>
+            <prosody rate="slow">你好</prosody>
+        </voice>
+    </speak>
+    """
+
+    segments = parser.parse(ssml)
+    for segment in segments:
+        if isinstance(segment, SSMLBreak):
+            print("<break>", segment.attrs)
+        elif isinstance(segment, SSMLSegment):
+            print(segment.text, segment.attrs)
+        else:
+            raise ValueError("Unknown segment type")
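Because resolvers are looked up by tag name, any tag without one raises `NotImplementedError`; the decorator makes the parser extensible without touching its core. A sketch of registering a hypothetical `<sub>` tag that emits its `alias` attribute instead of its text:

import copy
from modules.ssml_parser.SSMLParser import create_ssml_parser, SSMLSegment

parser = create_ssml_parser()

@parser.resolver("sub")  # hypothetical tag, not part of this commit
def tag_sub(element, context, segments, parser):
    ctx = copy.deepcopy(context)
    segments.append(SSMLSegment(element.get("alias", element.text or ""), ctx))

segments = parser.parse(
    '<speak version="0.1"><voice spk="1"><sub alias="你好">ni hao</sub></voice></speak>'
)
print([s.text for s in segments])  # ['你好']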
    	
modules/ssml_parser/__init__.py ADDED

File without changes
    	
        modules/ssml_parser/test_ssml_parser.py
    ADDED
    
    | @@ -0,0 +1,104 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import pytest
         | 
| 2 | 
            +
            from lxml import etree
         | 
| 3 | 
            +
            from modules.ssml_parser.SSMLParser import (
         | 
| 4 | 
            +
                create_ssml_parser,
         | 
| 5 | 
            +
                SSMLSegment,
         | 
| 6 | 
            +
                SSMLBreak,
         | 
| 7 | 
            +
                SSMLContext,
         | 
| 8 | 
            +
            )
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            @pytest.fixture
         | 
| 12 | 
            +
            def parser():
         | 
| 13 | 
            +
                return create_ssml_parser()
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            @pytest.mark.ssml_parser
         | 
| 17 | 
            +
            def test_speak_tag(parser):
         | 
| 18 | 
            +
                ssml = """
         | 
| 19 | 
            +
                <speak version="0.1">
         | 
| 20 | 
            +
                    <voice spk="xiaoyan" style="news">
         | 
| 21 | 
            +
                        <prosody rate="fast">你好</prosody>
         | 
| 22 | 
            +
                        <break time="500ms"/>
         | 
| 23 | 
            +
                        <prosody rate="slow">你好</prosody>
         | 
| 24 | 
            +
                    </voice>
         | 
| 25 | 
            +
                </speak>
         | 
| 26 | 
            +
                """
         | 
| 27 | 
            +
                segments = parser.parse(ssml)
         | 
| 28 | 
            +
                assert len(segments) == 3
         | 
| 29 | 
            +
                assert isinstance(segments[0], SSMLSegment)
         | 
| 30 | 
            +
                assert segments[0].text == "你好"
         | 
| 31 | 
            +
                assert segments[0].params.rate == "fast"
         | 
| 32 | 
            +
                assert isinstance(segments[1], SSMLBreak)
         | 
| 33 | 
            +
                assert segments[1].duration == 500
         | 
| 34 | 
            +
                assert isinstance(segments[2], SSMLSegment)
         | 
| 35 | 
            +
                assert segments[2].text == "你好"
         | 
| 36 | 
            +
                assert segments[2].params.rate == "slow"
         | 
| 37 | 
            +
             | 
| 38 | 
            +
+@pytest.mark.ssml_parser
+def test_voice_tag(parser):
+    ssml = """
+    <speak version="0.1">
+        <voice spk="xiaoyan" style="news">你好</voice>
+    </speak>
+    """
+    segments = parser.parse(ssml)
+    assert len(segments) == 1
+    assert isinstance(segments[0], SSMLSegment)
+    assert segments[0].text == "你好"
+    assert segments[0].params.spk == "xiaoyan"
+    assert segments[0].params.style == "news"
+
+
+@pytest.mark.ssml_parser
+def test_break_tag(parser):
+    ssml = """
+    <speak version="0.1">
+        <break time="500ms"/>
+    </speak>
+    """
+    segments = parser.parse(ssml)
+    assert len(segments) == 1
+    assert isinstance(segments[0], SSMLBreak)
+    assert segments[0].duration == 500
+
+
+@pytest.mark.ssml_parser
+def test_prosody_tag(parser):
+    ssml = """
+    <speak version="0.1">
+        <prosody rate="fast">你好</prosody>
+    </speak>
+    """
+    segments = parser.parse(ssml)
+    assert len(segments) == 1
+    assert isinstance(segments[0], SSMLSegment)
+    assert segments[0].text == "你好"
+    assert segments[0].params.rate == "fast"
+
+
+@pytest.mark.ssml_parser
+def test_unsupported_version(parser):
+    ssml = """
+    <speak version="0.2">
+        <voice spk="xiaoyan" style="news">你好</voice>
+    </speak>
+    """
+    with pytest.raises(ValueError, match=r"Unsupported SSML version 0.2"):
+        parser.parse(ssml)
+
+
+@pytest.mark.ssml_parser
+def test_unsupported_tag(parser):
+    ssml = """
+    <speak version="0.1">
+        <unsupported>你好</unsupported>
+    </speak>
+    """
+    with pytest.raises(NotImplementedError, match=r"Tag unsupported not supported."):
+        parser.parse(ssml)
+
+
+if __name__ == "__main__":
+    pytest.main()
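
Custom marks like `ssml_parser` should be registered with pytest, or every run emits PytestUnknownMarkWarning. A minimal sketch, assuming the repo does not already register the mark elsewhere:

    # conftest.py (hypothetical placement; any conftest.py on the rootdir path works)
    def pytest_configure(config):
        config.addinivalue_line("markers", "ssml_parser: SSML parser unit tests")

With the mark registered, this file can be run selectively with `pytest -m ssml_parser`.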
    	
modules/utils/JsonObject.py CHANGED

@@ -8,6 +8,9 @@ class JsonObject:
         # If no initial dictionary is provided, use an empty dictionary
         self._dict_obj = initial_dict if initial_dict is not None else {}
 
+        if self._dict_obj is self:
+            raise ValueError("JsonObject cannot be initialized with itself")
+
     def __getattr__(self, name):
         """
         Get an attribute value. If the attribute does not exist,
@@ -111,3 +114,19 @@ class JsonObject:
         :return: A list of values.
         """
         return self._dict_obj.values()
+
+    def clone(self):
+        """
+        Clone the JsonObject.
+
+        :return: A new JsonObject with the same internal dictionary.
+        """
+        return JsonObject(self._dict_obj.copy())
+
+    def merge(self, other):
+        """
+        Merge the internal dictionary with another dictionary.
+
+        :param other: The other dictionary to merge.
+        """
+        self._dict_obj.update(other)
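
A minimal usage sketch of the two new methods (the key/value pairs are invented for illustration):

    params = JsonObject({"spk": "xiaoyan"})
    override = params.clone()          # shallow copy of the internal dict
    override.merge({"style": "news"})  # override gains "style"; params is untouched

Note that `clone()` is shallow: nested dicts or lists are still shared between the original and the copy.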
    	
modules/utils/constants.py CHANGED

@@ -10,4 +10,4 @@ DATA_DIR = os.path.join(ROOT_DIR, "data")
 
 MODELS_DIR = os.path.join(ROOT_DIR, "models")
 
-
+SPEAKERS_DIR = os.path.join(DATA_DIR, "speakers")
    	
modules/webui/app.py CHANGED

@@ -5,7 +5,9 @@ import torch
 import gradio as gr
 
 from modules import config
+from modules.webui import webui_config
 
+from modules.webui.system_tab import create_system_tab
 from modules.webui.tts_tab import create_tts_interface
 from modules.webui.ssml_tab import create_ssml_interface
 from modules.webui.spliter_tab import create_spliter_tab
@@ -93,15 +95,15 @@ def create_interface():
             with gr.TabItem("Spilter"):
                 create_spliter_tab(ssml_input, tabs=tabs)
 
-            [old lines 96–104 removed; their content is cut off in this page rendering]
+            with gr.TabItem("Speaker"):
+                create_speaker_panel()
+            with gr.TabItem("Inpainting", visible=webui_config.experimental):
+                gr.Markdown("🚧 Under construction")
+            with gr.TabItem("ASR", visible=webui_config.experimental):
+                gr.Markdown("🚧 Under construction")
+
+            with gr.TabItem("System"):
+                create_system_tab()
 
             with gr.TabItem("README"):
                 create_readme_tab()
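
A note on the `visible=webui_config.experimental` tabs: Gradio evaluates `visible` once, while the Blocks graph is being built, so the flag must be set before `create_interface()` runs; flipping it afterwards does nothing for an already-launched UI. A minimal launcher sketch under that assumption (the real flag wiring lives in webui.py, which this section does not show, and `create_interface()` returning the Blocks app is itself an assumption):

    from modules.webui import webui_config
    from modules.webui.app import create_interface

    webui_config.experimental = True  # hypothetical flag source; must run before the UI is built
    demo = create_interface()
    demo.launch()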
    	
modules/webui/speaker_tab.py CHANGED

@@ -1,13 +1,259 @@
+import io
 import gradio as gr
+import torch
 
-from modules.…
+from modules.hf import spaces
+from modules.webui.webui_utils import get_speakers, tts_generate
+from modules.speaker import speaker_mgr, Speaker
+
+import tempfile
+
+
+def spk_to_tensor(spk):
+    spk = spk.split(" : ")[1].strip() if " : " in spk else spk
+    if spk == "None" or spk == "":
+        return None
+    return speaker_mgr.get_speaker(spk).emb
+
+
+def get_speaker_show_name(spk):
+    if spk.gender == "*" or spk.gender == "":
+        return spk.name
+    return f"{spk.gender} : {spk.name}"
+
+
+def merge_spk(
+    spk_a,
+    spk_a_w,
+    spk_b,
+    spk_b_w,
+    spk_c,
+    spk_c_w,
+    spk_d,
+    spk_d_w,
+):
+    tensor_a = spk_to_tensor(spk_a)
+    tensor_b = spk_to_tensor(spk_b)
+    tensor_c = spk_to_tensor(spk_c)
+    tensor_d = spk_to_tensor(spk_d)
+
+    assert (
+        tensor_a is not None
+        or tensor_b is not None
+        or tensor_c is not None
+        or tensor_d is not None
+    ), "At least one speaker should be selected"
+
+    merge_tensor = torch.zeros_like(
+        tensor_a
+        if tensor_a is not None
+        else (
+            tensor_b
+            if tensor_b is not None
+            else tensor_c if tensor_c is not None else tensor_d
+        )
+    )
+
+    total_weight = 0
+    if tensor_a is not None:
+        merge_tensor += spk_a_w * tensor_a
+        total_weight += spk_a_w
+    if tensor_b is not None:
+        merge_tensor += spk_b_w * tensor_b
+        total_weight += spk_b_w
+    if tensor_c is not None:
+        merge_tensor += spk_c_w * tensor_c
+        total_weight += spk_c_w
+    if tensor_d is not None:
+        merge_tensor += spk_d_w * tensor_d
+        total_weight += spk_d_w
+
+    if total_weight > 0:
+        merge_tensor /= total_weight
+
+    merged_spk = Speaker.from_tensor(merge_tensor)
+    merged_spk.name = "<MIX>"
+
+    return merged_spk
+
+
+@torch.inference_mode()
+@spaces.GPU
+def merge_and_test_spk_voice(
+    spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w, test_text
+):
+    merged_spk = merge_spk(
+        spk_a,
+        spk_a_w,
+        spk_b,
+        spk_b_w,
+        spk_c,
+        spk_c_w,
+        spk_d,
+        spk_d_w,
+    )
+    return tts_generate(
+        spk=merged_spk,
+        text=test_text,
+    )
+
+
+@torch.inference_mode()
+@spaces.GPU
+def merge_spk_to_file(
+    spk_a,
+    spk_a_w,
+    spk_b,
+    spk_b_w,
+    spk_c,
+    spk_c_w,
+    spk_d,
+    spk_d_w,
+    speaker_name,
+    speaker_gender,
+    speaker_desc,
+):
+    merged_spk = merge_spk(
+        spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w
+    )
+    merged_spk.name = speaker_name
+    merged_spk.gender = speaker_gender
+    merged_spk.desc = speaker_desc
+
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
+        torch.save(merged_spk, tmp_file)
+        tmp_file_path = tmp_file.name
+
+    return tmp_file_path
+
+
+merge_desc = """
+## Speaker Merger
+
+在本面板中,您可以选择多个说话人并指定他们的权重,合成新的语音并进行测试。以下是各个功能的详细说明:
+
+### 1. 选择说话人
+您可以从下拉菜单中选择最多四个说话人(A、B、C、D),每个说话人都有一个对应的权重滑块,范围从0到10。权重决定了每个说话人在合成语音中的影响程度。
+
+### 2. 合成语音
+在选择好说话人和设置好权重后,您可以在“测试文本”框中输入要测试的文本,然后点击“测试语音”按钮来生成并播放合成的语音。
+
+### 3. 保存说话人
+您还可以在右侧的“说话人信息”部分填写新的说话人的名称、性别和描述,并点击“保存说话人”按钮来保存合成的说话人。保存后的说话人文件将显示在“合成说话人”栏中,供下载使用。
+"""
 
 
 # Show four dropdowns (a b c d); pick one or more speakers, audition the mix, then export it
 def create_speaker_panel():
     speakers = get_speakers()
 
-    [old lines 10–11 removed; their content is cut off in this page rendering]
+    speaker_names = ["None"] + [get_speaker_show_name(speaker) for speaker in speakers]
+
+    with gr.Tabs():
+        with gr.TabItem("Merger"):
+            gr.Markdown(merge_desc)
+
+            with gr.Row():
+                with gr.Column(scale=5):
+                    with gr.Row():
+                        with gr.Group():
+                            spk_a = gr.Dropdown(
+                                choices=speaker_names, value="None", label="Speaker A"
+                            )
+                            spk_a_w = gr.Slider(
+                                value=1, minimum=0, maximum=10, step=1, label="Weight A"
+                            )
+
+                        with gr.Group():
+                            spk_b = gr.Dropdown(
+                                choices=speaker_names, value="None", label="Speaker B"
+                            )
+                            spk_b_w = gr.Slider(
+                                value=1, minimum=0, maximum=10, step=1, label="Weight B"
+                            )
+
+                        with gr.Group():
+                            spk_c = gr.Dropdown(
+                                choices=speaker_names, value="None", label="Speaker C"
+                            )
+                            spk_c_w = gr.Slider(
+                                value=1, minimum=0, maximum=10, step=1, label="Weight C"
+                            )
+
+                        with gr.Group():
+                            spk_d = gr.Dropdown(
+                                choices=speaker_names, value="None", label="Speaker D"
+                            )
+                            spk_d_w = gr.Slider(
+                                value=1, minimum=0, maximum=10, step=1, label="Weight D"
+                            )
+
+                    with gr.Row():
+                        with gr.Column(scale=3):
+                            with gr.Group():
+                                gr.Markdown("🎤Test voice")
+                                with gr.Row():
+                                    test_voice_btn = gr.Button(
+                                        "Test Voice", variant="secondary"
+                                    )
+
+                                    with gr.Column(scale=4):
+                                        test_text = gr.Textbox(
+                                            label="Test Text",
+                                            placeholder="Please input test text",
+                                            value="说话人合并测试 123456789 [uv_break] ok, test done [lbreak]",
+                                        )
+
+                                        output_audio = gr.Audio(label="Output Audio")
+
+                with gr.Column(scale=1):
+                    with gr.Group():
+                        gr.Markdown("🗃️Save to file")
+
+                        speaker_name = gr.Textbox(
+                            label="Name", value="forge_speaker_merged"
+                        )
+                        speaker_gender = gr.Textbox(label="Gender", value="*")
+                        speaker_desc = gr.Textbox(
+                            label="Description", value="merged speaker"
+                        )
+
+                        save_btn = gr.Button("Save Speaker", variant="primary")
+
+                        merged_spker = gr.File(
+                            label="Merged Speaker", interactive=False, type="binary"
+                        )
+
+            test_voice_btn.click(
+                merge_and_test_spk_voice,
+                inputs=[
+                    spk_a,
+                    spk_a_w,
+                    spk_b,
+                    spk_b_w,
+                    spk_c,
+                    spk_c_w,
+                    spk_d,
+                    spk_d_w,
+                    test_text,
+                ],
+                outputs=[output_audio],
+            )
 
-[old line 13 removed; its content is cut off in this page rendering]
+            save_btn.click(
+                merge_spk_to_file,
+                inputs=[
+                    spk_a,
+                    spk_a_w,
+                    spk_b,
+                    spk_b_w,
+                    spk_c,
+                    spk_c_w,
+                    spk_d,
+                    spk_d_w,
+                    speaker_name,
+                    speaker_gender,
+                    speaker_desc,
+                ],
+                outputs=[merged_spker],
+            )
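
Stripped of the UI plumbing, `merge_spk` computes a weighted average of speaker embeddings: merged = (Σᵢ wᵢ·eᵢ) / Σᵢ wᵢ. A standalone sketch of the same arithmetic, using random stand-ins for the real `Speaker.emb` tensors (the 768 embedding size is an assumption for illustration):

    import torch

    embs = [torch.randn(768) for _ in range(3)]  # stand-ins for Speaker.emb tensors
    weights = [1, 2, 1]                          # same role as the weight sliders

    merged = torch.zeros_like(embs[0])
    for w, e in zip(weights, embs):
        merged += w * e
    merged /= sum(weights)  # same normalization as merge_spk's total_weight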
    	
modules/webui/spliter_tab.py CHANGED

@@ -9,6 +9,7 @@ from modules.webui.webui_utils import (
 from modules.hf import spaces
 
 
+# NOTE: decorated because text_normalize needs to use the tokenizer
 @torch.inference_mode()
 @spaces.GPU
 def merge_dataframe_to_ssml(dataframe, spk, style, seed):
@@ -31,7 +32,7 @@ def merge_dataframe_to_ssml(dataframe, spk, style, seed):
         if seed:
             ssml += f' seed="{seed}"'
         ssml += ">\n"
-        ssml += f"{indent}{indent}{text_normalize(row[1])}\n"
+        ssml += f"{indent}{indent}{text_normalize(row.iloc[1])}\n"
         ssml += f"{indent}</voice>\n"
     return f"<speak version='0.1'>\n{ssml}</speak>"
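
The `row[1]` → `row.iloc[1]` change is the real fix in this hunk: indexing a pandas Series with an integer is treated as a label lookup first, and the positional fallback is deprecated, so `.iloc` makes the intent explicit. A quick demonstration (the DataFrame shape is invented; only the indexing behaviour matters):

    import pandas as pd

    df = pd.DataFrame([["xiaoyan", "你好"]], columns=["speaker", "text"])
    for _, row in df.iterrows():
        print(row.iloc[1])  # positional: always the second column
        # row[1] looks for the *label* 1 and only falls back to position,
        # with a FutureWarning on recent pandas versions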
    	
modules/webui/system_tab.py ADDED

@@ -0,0 +1,15 @@
+import gradio as gr
+from modules.webui import webui_config
+
+
+def create_system_tab():
+    with gr.Row():
+        with gr.Column(scale=1):
+            gr.Markdown(f"info")
+
+        with gr.Column(scale=5):
+            toggle_experimental = gr.Checkbox(
+                label="Enable Experimental Features",
+                value=webui_config.experimental,
+                interactive=False,
+            )
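
Since the checkbox is created with `interactive=False`, it only displays the current flag; nothing in this commit writes it back. If it were ever made togglable, the wiring might look roughly like the sketch below (hypothetical, not part of the commit), though as noted above, tabs whose `visible=` was fixed at build time would still need a restart to appear:

    # hypothetical wiring: push checkbox changes back into the module-level flag
    toggle_experimental.change(
        fn=lambda v: setattr(webui_config, "experimental", v),
        inputs=[toggle_experimental],
    )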
    	
modules/webui/tts_tab.py CHANGED

@@ -3,6 +3,7 @@ import torch
 from modules.webui.webui_utils import (
     get_speakers,
     get_styles,
+    load_spk_info,
     refine_text,
     tts_generate,
 )
@@ -10,6 +11,13 @@ from modules.webui import webui_config
 from modules.webui.examples import example_texts
 from modules import config
 
+default_text_content = """
+chat T T S 是一款强大的对话式文本转语音模型。它有中英混读和多说话人的能力。
+chat T T S 不仅能够生成自然流畅的语音,还能控制[laugh]笑声啊[laugh],
+停顿啊[uv_break]语气词啊等副语言现象[uv_break]。这个韵律超越了许多开源模型[uv_break]。
+请注意,chat T T S 的使用应遵守法律和伦理准则,避免滥用的安全风险。[uv_break]
+"""
+
 
 def create_tts_interface():
     speakers = get_speakers()
@@ -90,15 +98,18 @@ def create_tts_interface():
                                 outputs=[spk_input_text],
                             )
 
-                        [old lines 93–101 removed; their content is cut off in this page rendering]
+                        with gr.Tab(label="Upload"):
+                            spk_file_upload = gr.File(label="Speaker (Upload)")
+
+                            gr.Markdown("📝Speaker info")
+                            infos = gr.Markdown("empty")
+
+                            spk_file_upload.change(
+                                fn=load_spk_info,
+                                inputs=[spk_file_upload],
+                                outputs=[infos],
+                            ),
+
             with gr.Group():
                 gr.Markdown("💃Inference Seed")
                 infer_seed_input = gr.Number(
@@ -122,85 +133,62 @@ def create_tts_interface():
                 prompt2_input = gr.Textbox(label="Prompt 2")
                 prefix_input = gr.Textbox(label="Prefix")
 
-                [old lines 125–126 removed; their content is cut off in this page rendering]
+                prompt_audio = gr.File(
+                    label="prompt_audio", visible=webui_config.experimental
+                )
 
             infer_seed_rand_button.click(
                 lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
                 inputs=[infer_seed_input],
                 outputs=[infer_seed_input],
             )
-        with gr.Column(scale=…
-            with gr.…
-        [old lines 135–179 removed; their content is cut off in this page rendering]
-                            lambda text, tk=tk: text + " " + tk,
-                            inputs=[text_input],
-                            outputs=[text_input],
-                        )
-            with gr.Column(scale=1):
-                with gr.Group():
-                    gr.Markdown("🎶Refiner")
-                    refine_prompt_input = gr.Textbox(
-                        label="Refine Prompt",
-                        value="[oral_2][laugh_0][break_6]",
-                    )
-                    refine_button = gr.Button("✍️Refine Text")
-                    # TODO: split sentences, stitch them into SSML with the current config, then send to the SSML tab
-                    # send_button = gr.Button("📩Split and send to SSML")
-
-                with gr.Group():
-                    gr.Markdown("🔊Generate")
-                    disable_normalize_input = gr.Checkbox(
-                        value=False, label="Disable Normalize"
-                    )
-                    tts_button = gr.Button(
-                        "🔊Generate Audio",
-                        variant="primary",
-                        elem_classes="big-button",
+        with gr.Column(scale=4):
+            with gr.Group():
+                input_title = gr.Markdown(
+                    "📝Text Input",
+                    elem_id="input-title",
+                )
+                gr.Markdown(f"- 字数限制{webui_config.tts_max:,}字,超过部分截断")
+                gr.Markdown("- 如果尾字吞字不读,可以试试结尾加上 `[lbreak]`")
+                gr.Markdown(
+                    "- If the input text is all in English, it is recommended to check disable_normalize"
+                )
+                text_input = gr.Textbox(
+                    show_label=False,
+                    label="Text to Speech",
+                    lines=10,
+                    placeholder="输入文本或选择示例",
+                    elem_id="text-input",
+                    value=default_text_content,
+                )
+                # TODO character counter: trivial to write, but it triggers the loading spinner and needs a backend round trip
+                # text_input.change(
+                #     fn=lambda x: (
+                #         f"📝Text Input ({len(x)} char)"
+                #         if x
+                #         else (
+                #             "📝Text Input (0 char)"
+                #             if not x
+                #             else "📝Text Input (0 char)"
+                #         )
+                #     ),
+                #     inputs=[text_input],
+                #     outputs=[input_title],
+                # )
+                with gr.Row():
+                    contorl_tokens = [
+                        "[laugh]",
+                        "[uv_break]",
+                        "[v_break]",
+                        "[lbreak]",
+                    ]
+
+                    for tk in contorl_tokens:
+                        t_btn = gr.Button(tk)
+                        t_btn.click(
+                            lambda text, tk=tk: text + " " + tk,
+                            inputs=[text_input],
+                            outputs=[text_input],
                         )
 
             with gr.Group():
@@ -220,6 +208,31 @@ def create_tts_interface():
         with gr.Group():
             gr.Markdown("🎨Output")
             tts_output = gr.Audio(label="Generated Audio")
+        with gr.Column(scale=1):
+            with gr.Group():
+                gr.Markdown("🎶Refiner")
+                refine_prompt_input = gr.Textbox(
+                    label="Refine Prompt",
+                    value="[oral_2][laugh_0][break_6]",
+                )
+                refine_button = gr.Button("✍️Refine Text")
+
+            with gr.Group():
+                gr.Markdown("🔊Generate")
+                disable_normalize_input = gr.Checkbox(
+                    value=False, label="Disable Normalize"
+                )
+
+                # FIXME: unclear why, but this is very slow here; running the script standalone is fast
+                with gr.Group(visible=webui_config.experimental):
+                    gr.Markdown("💪🏼Enhance")
+                    enable_enhance = gr.Checkbox(value=False, label="Enable Enhance")
+                    enable_de_noise = gr.Checkbox(value=False, label="Enable De-noise")
+                tts_button = gr.Button(
+                    "🔊Generate Audio",
+                    variant="primary",
+                    elem_classes="big-button",
+                )
 
     refine_button.click(
         refine_text,
@@ -243,6 +256,9 @@ def create_tts_interface():
             style_input_dropdown,
             disable_normalize_input,
             batch_size_input,
+            enable_enhance,
+            enable_de_noise,
+            spk_file_upload,
         ],
         outputs=tts_output,
    )
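
One subtlety in the control-token loop above: `lambda text, tk=tk: ...` binds `tk` as a default argument, so each button keeps its own token. A plain closure would capture the loop variable itself, and every button would append the last token. A tiny demonstration:

    tokens = ["[laugh]", "[uv_break]"]
    bound = [lambda text, tk=tk: text + " " + tk for tk in tokens]
    free = [lambda text: text + " " + tk for tk in tokens]

    print(bound[0]("hi"))  # hi [laugh]
    print(free[0]("hi"))   # hi [uv_break] — both "free" lambdas see the final tk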
    	
modules/webui/webui_config.py CHANGED

@@ -1,4 +1,8 @@
+from typing import Literal
+
+
 tts_max = 1000
 ssml_max = 1000
 spliter_threshold = 100
 max_batch_size = 8
+experimental = False
    	
        modules/webui/webui_utils.py
    CHANGED
    
    | @@ -1,37 +1,26 @@ | |
| 1 | 
            -
            import  | 
| 2 | 
            -
            import logging
         | 
| 3 | 
            -
            import sys
         | 
| 4 | 
            -
             | 
| 5 | 
             
            import numpy as np
         | 
| 6 |  | 
|  | |
| 7 | 
             
            from modules.devices import devices
         | 
| 8 | 
             
            from modules.synthesize_audio import synthesize_audio
         | 
| 9 | 
             
            from modules.hf import spaces
         | 
| 10 | 
             
            from modules.webui import webui_config
         | 
| 11 |  | 
| 12 | 
            -
            logging.basicConfig(
         | 
| 13 | 
            -
                level=os.getenv("LOG_LEVEL", "INFO"),
         | 
| 14 | 
            -
                format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
         | 
| 15 | 
            -
            )
         | 
| 16 | 
            -
             | 
| 17 | 
            -
             | 
| 18 | 
            -
            import gradio as gr
         | 
| 19 | 
            -
             | 
| 20 | 
             
            import torch
         | 
| 21 |  | 
| 22 | 
            -
            from modules. | 
| 23 | 
             
            from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
         | 
| 24 |  | 
| 25 | 
            -
            from modules.speaker import speaker_mgr
         | 
| 26 | 
             
            from modules.data import styles_mgr
         | 
| 27 |  | 
| 28 | 
             
            from modules.api.utils import calc_spk_style
         | 
| 29 | 
            -
            import modules.generate_audio as generate
         | 
| 30 |  | 
| 31 | 
             
            from modules.normalization import text_normalize
         | 
| 32 | 
            -
            from modules import refiner | 
| 33 |  | 
| 34 | 
            -
            from modules.utils import  | 
| 35 | 
             
            from modules.SentenceSplitter import SentenceSplitter
         | 
| 36 |  | 
| 37 |  | 
| @@ -43,11 +32,30 @@ def get_styles(): | |
| 43 | 
             
                return styles_mgr.list_items()
         | 
| 44 |  | 
| 45 |  | 
| 46 | 
            -
            def  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 47 | 
             
                ret_segments = []
         | 
| 48 | 
             
                total_len = 0
         | 
| 49 | 
             
                for seg in segments:
         | 
| 50 | 
            -
                    if  | 
|  | |
| 51 | 
             
                        continue
         | 
| 52 | 
             
                    total_len += len(seg["text"])
         | 
| 53 | 
             
                    if total_len > total_max:
         | 
@@ -56,6 +64,28 @@ def segments_length_limit(segments, total_max: int):
     return ret_segments
 
 
+@torch.inference_mode()
+@spaces.GPU
+def apply_audio_enhance(audio_data, sr, enable_denoise, enable_enhance):
+    audio_data = torch.from_numpy(audio_data).float().squeeze().cpu()
+    if enable_denoise or enable_enhance:
+        enhancer = load_enhancer(devices.device)
+        if enable_denoise:
+            audio_data, sr = enhancer.denoise(audio_data, sr, devices.device)
+        if enable_enhance:
+            audio_data, sr = enhancer.enhance(
+                audio_data,
+                sr,
+                devices.device,
+                tau=0.9,
+                nfe=64,
+                solver="euler",
+                lambd=0.5,
+            )
+    audio_data = audio_data.cpu().numpy()
+    return audio_data, int(sr)
+
+
 @torch.inference_mode()
 @spaces.GPU
 def synthesize_ssml(ssml: str, batch_size=4):
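
Note: `apply_audio_enhance` funnels the synthesized waveform through Resemble Enhance; `tau`, `nfe`, `solver`, and `lambd` are that library's inference knobs (prior temperature, number of function evaluations, ODE solver, and denoise strength). A hedged usage sketch with a silent stand-in waveform:

    # Sketch: post-processing one utterance (function as defined above).
    import numpy as np

    sr = 24000                            # assumed output rate of the synthesizer
    wav = np.zeros(sr, dtype=np.float32)  # stand-in for synthesized audio
    wav, sr = apply_audio_enhance(wav, sr, enable_denoise=True, enable_enhance=True)
    # Denoise runs first, then enhancement; with both flags False the tensor
    # simply round-trips through torch and comes back unchanged.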
@@ -69,7 +99,8 @@ def synthesize_ssml(ssml: str, batch_size=4):
     if ssml == "":
         return None
 
-    segments = …
+    parser = create_ssml_parser()
+    segments = parser.parse(ssml)
     max_len = webui_config.ssml_max
     segments = segments_length_limit(segments, max_len)
 
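
Note: the old one-shot parse call is replaced by a parser object from the new modules/ssml_parser package. A minimal call sketch; the toy SSML string and its attribute names are illustrative, not a spec:

    # Sketch: parse first, then clamp total text length for the web UI.
    parser = create_ssml_parser()
    segments = parser.parse(
        '<speak><voice spk="female2">Hello there.</voice><break time="500"/></speak>'
    )
    segments = segments_length_limit(segments, webui_config.ssml_max)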
@@ -87,18 +118,21 @@ def tts_generate(
 @spaces.GPU
 def tts_generate(
     text,
-    temperature,
-    top_p,
-    top_k,
-    spk,
-    infer_seed,
-    use_decoder,
-    prompt1,
-    prompt2,
-    prefix,
-    style,
+    temperature=0.3,
+    top_p=0.7,
+    top_k=20,
+    spk=-1,
+    infer_seed=-1,
+    use_decoder=True,
+    prompt1="",
+    prompt2="",
+    prefix="",
+    style="",
     disable_normalize=False,
     batch_size=4,
+    enable_enhance=False,
+    enable_denoise=False,
+    spk_file=None,
 ):
     try:
         batch_size = int(batch_size)
@@ -126,12 +160,15 @@ def tts_generate(
     prompt1 = prompt1 or params.get("prompt1", "")
     prompt2 = prompt2 or params.get("prompt2", "")
 
-    infer_seed = np.clip(infer_seed, -1, 2**32 - 1, out=None, dtype=np.…)
+    infer_seed = np.clip(infer_seed, -1, 2**32 - 1, out=None, dtype=np.float64)
     infer_seed = int(infer_seed)
 
     if not disable_normalize:
         text = text_normalize(text)
 
+    if spk_file:
+        spk = Speaker.from_file(spk_file)
+
     sample_rate, audio_data = synthesize_audio(
         text=text,
         temperature=temperature,
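
Note: the seed is clamped in float64 on purpose. The upper bound 2**32 - 1 does not fit a 32-bit integer, so clipping with a default integer dtype can overflow on platforms where the native int is 32-bit; clipping as float and casting afterwards sidesteps that. For instance:

    import numpy as np

    seed = -5
    seed = np.clip(seed, -1, 2**32 - 1, out=None, dtype=np.float64)  # -> -1.0
    seed = int(seed)                                                 # -> -1 sentinel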
@@ -146,6 +183,10 @@ def tts_generate(
         batch_size=batch_size,
     )
 
+    audio_data, sample_rate = apply_audio_enhance(
+        audio_data, sample_rate, enable_denoise, enable_enhance
+    )
+
     audio_data = audio.audio_to_int16(audio_data)
     return sample_rate, audio_data
 
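Note: with every generation knob now defaulted, `tts_generate` can be called with only the arguments a caller cares about, and the three new parameters bolt the enhancement pass and file-based speakers onto the same entry point. A hedged call sketch (the speaker file name is hypothetical):

    # Sketch: minimal call relying on the new defaults.
    sample_rate, pcm = tts_generate(
        "Hello from the web UI.",
        infer_seed=42,               # fix the seed for reproducible output
        enable_denoise=True,         # run the Resemble Enhance denoiser afterwards
        spk_file="my_speaker.spkv1", # hypothetical path; overrides `spk` when set
    )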
    	
webui.py
CHANGED
@@ -93,8 +93,10 @@ if __name__ == "__main__":
     device_id = get_and_update_env(args, "device_id", None, str)
     use_cpu = get_and_update_env(args, "use_cpu", [], list)
     compile = get_and_update_env(args, "compile", False, bool)
-    webui_experimental = get_and_update_env(args, "webui_experimental", False, bool)
 
+    webui_config.experimental = get_and_update_env(
+        args, "webui_experimental", False, bool
+    )
     webui_config.tts_max = get_and_update_env(args, "tts_max_len", 1000, int)
     webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
     webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)
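
Note: `get_and_update_env` follows a read-cast-mirror pattern: take the parsed CLI value, apply the cast, and publish the result so config modules and the process environment agree on one source of truth. The project's real implementation is not shown in this diff; a self-contained illustration of the pattern:

    import os

    def get_and_update_env(args, name, default, cast):
        # Prefer the parsed CLI value, then an existing env var, then the default.
        value = getattr(args, name, None)
        if value is None:
            raw = os.environ.get(name.upper())
            value = cast(raw) if raw is not None else default
        os.environ[name.upper()] = str(value)
        return value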
 
			
