zhzluke96 committed
Commit 02e90e4 · Parent(s): d6fe286

update
Files changed:

- modules/ChatTTS/ChatTTS/core.py +21 -5
- modules/ChatTTS/ChatTTS/model/dvae.py +3 -3
- modules/ChatTTS/ChatTTS/model/gpt.py +2 -3
- modules/ChatTTS/ChatTTS/utils/gpu_utils.py +3 -1
- modules/ChatTTS/ChatTTS/utils/infer_utils.py +5 -5
- modules/SynthesizeSegments.py +2 -2
- modules/api/Api.py +12 -1
- modules/api/impl/google_api.py +16 -3
- modules/api/impl/models_api.py +11 -0
- modules/api/impl/openai_api.py +19 -7
- modules/api/impl/ping_api.py +8 -0
- modules/api/utils.py +6 -2
- modules/config.py +2 -8
- modules/devices/__init__.py +0 -0
- modules/devices/devices.py +160 -0
- modules/devices/mac_devices.py +42 -0
- modules/generate_audio.py +33 -5
- modules/models.py +24 -20
- modules/normalization.py +47 -2
- modules/refiner.py +1 -1
- modules/speaker.py +11 -2
- modules/synthesize_audio.py +2 -1
- modules/utils/JsonObject.py +113 -0
- modules/utils/cache.py +92 -0
- modules/utils/zh_normalization/text_normlization.py +3 -3
- webui.py +49 -22
    	
modules/ChatTTS/ChatTTS/core.py  CHANGED

@@ -101,13 +101,27 @@ class Chat:
         tokenizer_path: str = None,
         device: str = None,
         compile: bool = True,
+        dtype: torch.dtype = torch.float32,
+        dtype_vocos: torch.dtype = None,
+        dtype_dvae: torch.dtype = None,
+        dtype_gpt: torch.dtype = None,
+        dtype_decoder: torch.dtype = None,
     ):
         if not device:
             device = select_device(4096)
             self.logger.log(logging.INFO, f"use {device}")

+        dtype_vocos = dtype_vocos or dtype
+        dtype_dvae = dtype_dvae or dtype
+        dtype_gpt = dtype_gpt or dtype
+        dtype_decoder = dtype_decoder or dtype
+
         if vocos_config_path:
-            vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
+            vocos = (
+                Vocos.from_hparams(vocos_config_path)
+                .to(device=device, dtype=dtype_vocos)
+                .eval()
+            )
             assert vocos_ckpt_path, "vocos_ckpt_path should not be None"
             vocos.load_state_dict(torch.load(vocos_ckpt_path))
             self.pretrain_models["vocos"] = vocos
@@ -115,7 +129,7 @@ class Chat:

         if dvae_config_path:
             cfg = OmegaConf.load(dvae_config_path)
-            dvae = DVAE(**cfg).to(device).eval()
+            dvae = DVAE(**cfg).to(device=device, dtype=dtype_dvae).eval()
             assert dvae_ckpt_path, "dvae_ckpt_path should not be None"
             dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location=device))
             self.pretrain_models["dvae"] = dvae
@@ -123,7 +137,7 @@ class Chat:

         if gpt_config_path:
             cfg = OmegaConf.load(gpt_config_path)
-            gpt = GPT_warpper(**cfg).to(device).eval()
+            gpt = GPT_warpper(**cfg).to(device=device, dtype=dtype_gpt).eval()
             assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
             gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=device))
             if compile and "cuda" in str(device):
@@ -136,12 +150,14 @@ class Chat:
             assert os.path.exists(
                 spk_stat_path
             ), f"Missing spk_stat.pt: {spk_stat_path}"
-            self.pretrain_models["spk_stat"] = torch.load(spk_stat_path).to(device)
+            self.pretrain_models["spk_stat"] = torch.load(spk_stat_path).to(
+                device=device, dtype=dtype
+            )
             self.logger.log(logging.INFO, "gpt loaded.")

         if decoder_config_path:
             cfg = OmegaConf.load(decoder_config_path)
-            decoder = DVAE(**cfg).to(device).eval()
+            decoder = DVAE(**cfg).to(device=device, dtype=dtype_decoder).eval()
             assert decoder_ckpt_path, "decoder_ckpt_path should not be None"
             decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location=device))
             self.pretrain_models["decoder"] = decoder
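The new `dtype_*` arguments default to the umbrella `dtype` through an `x or dtype` cascade, so passing a single `dtype=torch.float16` moves every submodule to half precision while per-model overrides stay possible. A minimal sketch of the resolution pattern (the function name here is illustrative, not part of the commit):

import torch

def resolve_dtypes(
    dtype: torch.dtype = torch.float32,
    dtype_vocos: torch.dtype = None,
    dtype_gpt: torch.dtype = None,
):
    # same `x or dtype` fallback used in load_models above
    return dtype_vocos or dtype, dtype_gpt or dtype

print(resolve_dtypes(dtype=torch.float16))
# (torch.float16, torch.float16)
print(resolve_dtypes(dtype=torch.float16, dtype_vocos=torch.float32))
# (torch.float32, torch.float16)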
    	
modules/ChatTTS/ChatTTS/model/dvae.py  CHANGED

@@ -143,9 +143,9 @@ class DVAE(nn.Module):
         else:
             vq_feats = inp.detach().clone()

-        temp = torch.chunk(vq_feats, 2, dim=1) # flatten trick :)
-        temp = torch.stack(temp, -1)
-        vq_feats = temp.reshape(*temp.shape[:2], -1)
+        vq_feats = vq_feats.view(
+            (vq_feats.size(0), 2, vq_feats.size(1)//2, vq_feats.size(2)),
+        ).permute(0, 2, 3, 1).flatten(2)

         vq_feats = vq_feats.transpose(1, 2)
         dec_out = self.decoder(input=vq_feats)
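The rewrite collapses the channel interleave into a single view/permute/flatten chain. Assuming the removed lines were the original ChatTTS "flatten trick" (chunk, stack, reshape, as reconstructed above), a quick check on a dummy tensor shows the two formulations produce identical results:

import torch

vq_feats = torch.arange(2 * 8 * 5, dtype=torch.float32).view(2, 8, 5)

# old formulation (reconstructed)
temp = torch.chunk(vq_feats, 2, dim=1)
temp = torch.stack(temp, -1)
old = temp.reshape(*temp.shape[:2], -1)

# new formulation from the diff
new = vq_feats.view(
    (vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)),
).permute(0, 2, 3, 1).flatten(2)

print(torch.equal(old, new))  # True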
    	
modules/ChatTTS/ChatTTS/model/gpt.py  CHANGED

@@ -190,6 +190,8 @@ class GPT_warpper(nn.Module):
                 attention_mask_cache[:, :attention_mask.shape[1]] = attention_mask

             for i in tqdm(range(max_new_token)):
+                if finish.all():
+                    continue

                 model_input = self.prepare_inputs_for_generation(inputs_ids,
                     outputs.past_key_values if i!=0 else None,
@@ -250,9 +252,6 @@ class GPT_warpper(nn.Module):

                 end_idx = end_idx + (~finish).int()

-                if finish.all():
-                    break
-
             inputs_ids = [inputs_ids[idx, start_idx: start_idx+i] for idx, i in enumerate(end_idx.int())]
             inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids
    	
modules/ChatTTS/ChatTTS/utils/gpu_utils.py  CHANGED

@@ -16,8 +16,10 @@ def select_device(min_memory = 2048):
         if free_memory_mb < min_memory:
             logger.log(logging.WARNING, f'GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left.')
             device = torch.device('cpu')
+    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        device = torch.device('mps')
     else:
         logger.log(logging.WARNING, f'No GPU found, use CPU instead')
         device = torch.device('cpu')

-    return device
+    return device
    	
modules/ChatTTS/ChatTTS/utils/infer_utils.py  CHANGED

@@ -101,8 +101,8 @@ character_map = {
     "!": ".",
     "(": ",",
     ")": ",",
+    "[": ",",
+    "]": ",",
     ">": ",",
     "<": ",",
     "-": ",",
@@ -131,11 +131,11 @@ halfwidth_2_fullwidth_map = {
     ">": ">",
     "?": "?",
     "@": "@",
+    "[": "[",
     "\\": "\\",
+    "]": "]",
     "^": "^",
+    "_": "_",
     "`": "`",
     "{": "{",
     "|": "|",
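These tables substitute or widen punctuation one character at a time. The function that applies them is not shown in this diff; a minimal sketch of how such a map is typically consumed (the helper name is illustrative):

character_map = {"[": ",", "]": ",", "(": ",", ")": ","}

def apply_character_map(text: str, table: dict) -> str:
    # replace each mapped character, pass everything else through
    return "".join(table.get(ch, ch) for ch in text)

print(apply_character_map("hello (world) [ok]", character_map))
# hello ,world, ,ok,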
    	
modules/SynthesizeSegments.py  CHANGED

@@ -1,6 +1,6 @@
 import numpy as np
 from pydub import AudioSegment
-from typing import Any, List, Dict
+from typing import Any, List, Dict, Union
 from scipy.io.wavfile import write
 import io
 from modules.utils.audio import time_stretch, pitch_shift
@@ -211,7 +211,7 @@ def generate_audio_segment(
     return AudioSegment.from_file(byte_io, format="wav")


-def synthesize_segment(segment: Dict[str, Any]) -> AudioSegment | None:
+def synthesize_segment(segment: Dict[str, Any]) -> Union[AudioSegment, None]:
     if "break" in segment:
         pause_segment = AudioSegment.silent(duration=segment["break"])
         return pause_segment
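Swapping `AudioSegment | None` for `Union[AudioSegment, None]` matters because PEP 604's `X | Y` annotation syntax requires Python 3.10+, while `typing.Union` evaluates on older interpreters too. For example:

from typing import Union

def pick(flag: bool) -> Union[str, None]:  # works on Python 3.7+
    return "yes" if flag else None

# `def pick(flag: bool) -> str | None:` raises TypeError at definition
# time on Python < 3.10 (unless `from __future__ import annotations`).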
    	
modules/api/Api.py  CHANGED

@@ -27,7 +27,18 @@ class APIManager:
     def __init__(self, no_docs=False, exclude_patterns=[]):
         self.app = FastAPI(
             title="ChatTTS Forge API",
-            description="…",
+            description="""
+ChatTTS-Forge is a powerful text-to-speech generation tool. It supports generating rich, long-form audio through SSML-like syntax and provides a comprehensive API service for a wide range of scenarios. <br/>
+
+Project repository: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
+
+> None of the audio-generating POST APIs can be debugged on this page; use the playground for debugging instead. <br/>
+
+> If you are not familiar with this system, start with the one-click script and try it in Colab: <br/>
+> [https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb](https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb)
+            """,
             version="0.1.0",
             redoc_url=None if no_docs else "/redoc",
             docs_url=None if no_docs else "/docs",
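FastAPI renders the `description` string as Markdown in the generated /docs and /redoc pages, which is why the text above embeds links and `<br/>` tags. A minimal standalone example:

from fastapi import FastAPI

app = FastAPI(
    title="demo",
    description="""
A **Markdown** description with a link:
[lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
""",
    version="0.1.0",
)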
    	
modules/api/impl/google_api.py  CHANGED

@@ -30,6 +30,7 @@ class SynthesisInput(BaseModel):

 class VoiceSelectionParams(BaseModel):
     languageCode: str = "ZH-CN"
+
     name: str = "female2"
     style: str = ""
     temperature: float = 0.3
@@ -160,6 +161,18 @@

 def setup(app: APIManager):
-    app.post("/v1/text:synthesize", response_model=GoogleTextSynthesizeResponse)(
-        google_text_synthesize
-    )
+    app.post(
+        "/v1/text:synthesize",
+        response_model=GoogleTextSynthesizeResponse,
+        description="""
+Google API reference: <br/>
+[https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize](https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize)
+
+- Several properties are unused in this system and exist only for Google API compatibility
+- topP, topK, and temperature in voice are parameters of this system
+- voice.name is the speaker name (or speaker seed)
+- voice.seed is the infer seed (its exact effect can be tested in the webui)
+
+- The encoding format only affects the binary format of audioContent, so every format returns JSON carrying base64 data
+        """,
+    )(google_text_synthesize)
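A hedged client sketch for the route registered above. The host/port and the exact request envelope are assumptions (the `GoogleTextSynthesizeRequest` body is not shown in this diff), modeled on Google's text:synthesize layout:

import base64
import requests

resp = requests.post(
    "http://localhost:8000/v1/text:synthesize",  # assumed local server address
    json={
        "input": {"text": "Hello, ChatTTS-Forge"},
        "voice": {"languageCode": "ZH-CN", "name": "female2", "temperature": 0.3},
        "audioConfig": {"audioEncoding": "mp3"},
    },
)
# per the description above, every format returns JSON with base64 audio
audio = base64.b64decode(resp.json()["audioContent"])
open("out.mp3", "wb").write(audio)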
    	
modules/api/impl/models_api.py  ADDED

from modules.api import utils as api_utils
from modules.api.Api import APIManager
from modules.models import reload_chat_tts


def setup(app: APIManager):
    @app.get("/v1/models/reload", response_model=api_utils.BaseResponse)
    async def reload_models():
        # Reload models
        reload_chat_tts()
        return api_utils.success_response("Models reloaded")
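Calling the new route (server address assumed):

import requests

# GET /v1/models/reload drops and re-creates the ChatTTS models.
# success_response("Models reloaded") puts that string in `data`
# and leaves `message` at its default of "Success".
print(requests.get("http://localhost:8000/v1/models/reload").json())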
    	
modules/api/impl/openai_api.py  CHANGED

@@ -28,11 +28,11 @@ class AudioSpeechRequest(BaseModel):
     model: str = "chattts-4w"
     voice: str = "female2"
     response_format: Literal["mp3", "wav"] = "mp3"
-    speed: …
+    speed: float = Field(1, ge=0.1, le=10, description="Speed of the audio")
     style: str = ""
     # whether to enable batch synthesis; <= 1 means batching is not used
     # batch synthesis automatically splits sentences
-    batch_size: int = Field(1, ge=1, le=…
+    batch_size: int = Field(1, ge=1, le=20, description="Batch size")
     spliter_threshold: float = Field(
         100, ge=10, le=1024, description="Threshold for sentence spliter"
     )
@@ -64,8 +64,8 @@ async def openai_speech_api(
         params = api_utils.calc_spk_style(spk=voice, style=style)

         spk = params.get("spk", -1)
-        seed = params.get("seed", 42)
-        temperature = params.get("temperature", 0.3)
+        seed = params.get("seed", request.seed or 42)
+        temperature = params.get("temperature", request.temperature or 0.3)
         prompt1 = params.get("prompt1", "")
         prompt2 = params.get("prompt2", "")
         prefix = params.get("prefix", "")
@@ -107,6 +107,18 @@

 def setup(api_manager: APIManager):
-    api_manager.post("/v1/audio/speech", response_class=FileResponse)(
-        openai_speech_api
-    )
+    api_manager.post(
+        "/v1/audio/speech",
+        response_class=FileResponse,
+        description="""
+OpenAI API reference:
+[https://platform.openai.com/docs/guides/text-to-speech](https://platform.openai.com/docs/guides/text-to-speech)
+
+The following properties are custom to this system and are not part of the OpenAI documentation:
+- batch_size: whether to enable batch synthesis; <= 1 means batching is not used (not recommended)
+- spliter_threshold: the sentence-splitting threshold when batch synthesis is enabled
+- style: the style to use
+
+> model may be any value
+        """,
+    )(openai_speech_api)
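A hedged usage sketch against the OpenAI-compatible route. The server address is an assumption, and the text field name is taken from the OpenAI spec rather than shown in this diff:

import requests

resp = requests.post(
    "http://localhost:8000/v1/audio/speech",  # assumed local server address
    json={
        "input": "Hello from ChatTTS-Forge",  # field name assumed from the OpenAI spec
        "model": "chattts-4w",  # per the description, any value is accepted
        "voice": "female2",
        "response_format": "mp3",
        "batch_size": 1,
    },
)
with open("speech.mp3", "wb") as f:
    f.write(resp.content)  # FileResponse streams the encoded audio back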
    	
modules/api/impl/ping_api.py  ADDED

from modules.api import utils as api_utils
from modules.api.Api import APIManager


def setup(app: APIManager):
    @app.get("/v1/ping", response_model=api_utils.BaseResponse)
    async def ping():
        return {"message": "ok", "data": "pong"}
    	
modules/api/utils.py  CHANGED

@@ -1,5 +1,5 @@
 from pydantic import BaseModel
-from typing import Any
+from typing import Any, Union

 import torch

@@ -36,6 +36,10 @@ class BaseResponse(BaseModel):
     }


+def success_response(data: Any, message: str = "Success") -> BaseResponse:
+    return BaseResponse(message=message, data=data)
+
+
 def wav_to_mp3(wav_data, bitrate="48k"):
     audio = AudioSegment.from_wav(
         wav_data,
@@ -51,7 +55,7 @@ def to_number(value, t, default=0):
         return default


-def calc_spk_style(spk: str | int, style: str | int):
+def calc_spk_style(spk: Union[str, int], style: Union[str, int]):
     voice_attrs = {
         "spk": None,
         "seed": None,
    	
modules/config.py  CHANGED

@@ -1,11 +1,5 @@
+from modules.utils.JsonObject import JsonObject

-args = {}
+runtime_env_vars = JsonObject({})

 api = None
-
-model_config = {"half": False}
-
-disable_tqdm = False
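modules/utils/JsonObject.py (+113) is part of this commit but its body is not shown on this page. Judging from how `runtime_env_vars` is consumed in devices.py below (`config.runtime_env_vars.device_id is not None` works before any assignment), it behaves like attribute access over a dict that reads missing keys as None. A hedged sketch of that behavior, not the actual implementation:

class JsonObject:
    def __init__(self, data: dict):
        self._data = data

    def __getattr__(self, name):
        # only called when normal lookup fails; missing keys read as None
        return self._data.get(name)

    def __setattr__(self, name, value):
        if name == "_data":
            super().__setattr__(name, value)
        else:
            self._data[name] = value


runtime_env_vars = JsonObject({})
assert runtime_env_vars.device_id is None  # no KeyError/AttributeError
runtime_env_vars.device_id = "0"
assert runtime_env_vars.device_id == "0"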
    	
modules/devices/__init__.py  ADDED

(empty file)
    	
modules/devices/devices.py  ADDED

from functools import lru_cache
import sys
import torch
from modules import config

import logging

logger = logging.getLogger(__name__)

if sys.platform == "darwin":
    from modules.devices import mac_devices


def has_mps() -> bool:
    if sys.platform != "darwin":
        return False
    else:
        return mac_devices.has_mps


def get_cuda_device_id():
    return (
        int(config.runtime_env_vars.device_id)
        if config.runtime_env_vars.device_id is not None
        and config.runtime_env_vars.device_id.isdigit()
        else 0
    ) or torch.cuda.current_device()


def get_cuda_device_string():
    if config.runtime_env_vars.device_id is not None:
        return f"cuda:{config.runtime_env_vars.device_id}"

    return "cuda"


def get_available_gpus() -> list[tuple[int, int]]:
    """
    Get the list of available GPUs and their free memory.

    :return: A list of tuples where each tuple contains (GPU index, free memory in bytes).
    """
    available_gpus = []
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        free_memory = props.total_memory - torch.cuda.memory_reserved(i)
        available_gpus.append((i, free_memory))
    return available_gpus


def get_memory_available_gpus(min_memory=2048):
    available_gpus = get_available_gpus()
    memory_available_gpus = [
        gpu for gpu, free_memory in available_gpus if free_memory > min_memory
    ]
    return memory_available_gpus


def get_target_device_id_or_memory_available_gpu():
    memory_available_gpus = get_memory_available_gpus()
    device_id = get_cuda_device_id()
    if device_id not in memory_available_gpus:
        if len(memory_available_gpus) != 0:
            logger.warning(
                f"Device {device_id} is not available or does not have enough memory. will try to use {memory_available_gpus}"
            )
            config.runtime_env_vars.device_id = str(memory_available_gpus[0])
        else:
            logger.warning(
                f"Device {device_id} is not available or does not have enough memory. Using CPU instead."
            )
            return "cpu"
    return get_cuda_device_string()


def get_optimal_device_name():
    if config.runtime_env_vars.use_cpu:
        return "cpu"

    if torch.cuda.is_available():
        return get_target_device_id_or_memory_available_gpu()

    if has_mps():
        return "mps"

    return "cpu"


def get_optimal_device():
    return torch.device(get_optimal_device_name())


def get_device_for(task):
    if task in config.cmd_opts.use_cpu or "all" in config.cmd_opts.use_cpu:
        return cpu

    return get_optimal_device()


def torch_gc():
    try:
        if torch.cuda.is_available():
            with torch.cuda.device(get_cuda_device_string()):
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()

        if has_mps():
            mac_devices.torch_mps_gc()
    except Exception:
        logger.error("Error in torch_gc", exc_info=True)


cpu: torch.device = torch.device("cpu")
device: torch.device = get_optimal_device()
dtype: torch.dtype = torch.float32
dtype_dvae: torch.dtype = torch.float32
dtype_vocos: torch.dtype = torch.float32
dtype_gpt: torch.dtype = torch.float32
dtype_decoder: torch.dtype = torch.float32


def reset_device():
    if config.runtime_env_vars.half:
        global dtype
        global dtype_dvae
        global dtype_vocos
        global dtype_gpt
        global dtype_decoder
        dtype = torch.float16
        dtype_dvae = torch.float16
        dtype_vocos = torch.float16
        dtype_gpt = torch.float16
        dtype_decoder = torch.float16

        logger.info("Using half precision: torch.float16")

    if (
        config.runtime_env_vars.device_id is not None
        or config.runtime_env_vars.use_cpu is not None
    ):
        global device
        device = get_optimal_device()

        logger.info(f"Using device: {device}")


@lru_cache
def first_time_calculation():
    """
    Just do any calculation with pytorch layers - the first time this is done
    it allocates about 700MB of memory and spends about 2.7 seconds doing that,
    at least with NVidia.
    """
    x = torch.zeros((1, 1)).to(device, dtype)
    linear = torch.nn.Linear(1, 1).to(device, dtype)
    linear(x)

    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
    conv2d(x)
    	
modules/devices/mac_devices.py  ADDED

import torch
import logging
from packaging import version
import torch.backends
import torch.backends.mps

logger = logging.getLogger(__name__)


def check_for_mps() -> bool:
    if version.parse(torch.__version__) <= version.parse("2.0.1"):
        if not getattr(torch, "has_mps", False):
            return False
        try:
            torch.zeros(1).to(torch.device("mps"))
            return True
        except Exception:
            return False
    else:
        try:
            return torch.backends.mps.is_available() and torch.backends.mps.is_built()
        except:
            logger.warning("MPS garbage collection failed", exc_info=True)
            return False


has_mps = check_for_mps()


def torch_mps_gc() -> None:
    try:
        from torch.mps import empty_cache

        empty_cache()
    except Exception:
        logger.warning("MPS garbage collection failed", exc_info=True)


if __name__ == "__main__":
    print(torch.__version__)
    print(has_mps)
    torch_mps_gc()
    	
        modules/generate_audio.py
    CHANGED
    
    | @@ -8,18 +8,20 @@ from modules import models, config | |
| 8 |  | 
| 9 | 
             
            import logging
         | 
| 10 |  | 
| 11 | 
            -
            from modules import devices
         | 
|  | |
|  | |
|  | |
| 12 |  | 
| 13 | 
             
            logger = logging.getLogger(__name__)
         | 
| 14 |  | 
| 15 |  | 
| 16 | 
            -
            @torch.inference_mode()
         | 
| 17 | 
             
            def generate_audio(
         | 
| 18 | 
             
                text: str,
         | 
| 19 | 
             
                temperature: float = 0.3,
         | 
| 20 | 
             
                top_P: float = 0.7,
         | 
| 21 | 
             
                top_K: float = 20,
         | 
| 22 | 
            -
                spk: int  | 
| 23 | 
             
                infer_seed: int = -1,
         | 
| 24 | 
             
                use_decoder: bool = True,
         | 
| 25 | 
             
                prompt1: str = "",
         | 
| @@ -48,7 +50,7 @@ def generate_audio_batch( | |
| 48 | 
             
                temperature: float = 0.3,
         | 
| 49 | 
             
                top_P: float = 0.7,
         | 
| 50 | 
             
                top_K: float = 20,
         | 
| 51 | 
            -
                spk: int  | 
| 52 | 
             
                infer_seed: int = -1,
         | 
| 53 | 
             
                use_decoder: bool = True,
         | 
| 54 | 
             
                prompt1: str = "",
         | 
| @@ -65,7 +67,7 @@ def generate_audio_batch( | |
| 65 | 
             
                    "prompt2": prompt2 or "",
         | 
| 66 | 
             
                    "prefix": prefix or "",
         | 
| 67 | 
             
                    "repetition_penalty": 1.0,
         | 
| 68 | 
            -
                    "disable_tqdm": config. | 
| 69 | 
             
                }
         | 
| 70 |  | 
| 71 | 
             
                if isinstance(spk, int):
         | 
| @@ -103,6 +105,32 @@ def generate_audio_batch( | |
| 103 | 
             
                return [(sample_rate, np.array(wav).flatten().astype(np.float32)) for wav in wavs]
         | 
| 104 |  | 
| 105 |  | 
|  | |
|  | |
|  | |
 import logging
 
+from modules.devices import devices
+from typing import Union
+
+from modules.utils.cache import conditional_cache
 
 logger = logging.getLogger(__name__)
 
 
 def generate_audio(
     text: str,
     temperature: float = 0.3,
     top_P: float = 0.7,
     top_K: float = 20,
+    spk: Union[int, Speaker] = -1,
     infer_seed: int = -1,
     use_decoder: bool = True,
     prompt1: str = "",
@@ ... @@
     temperature: float = 0.3,
     top_P: float = 0.7,
     top_K: float = 20,
+    spk: Union[int, Speaker] = -1,
     infer_seed: int = -1,
     use_decoder: bool = True,
     prompt1: str = "",
@@ ... @@
         "prompt2": prompt2 or "",
         "prefix": prefix or "",
         "repetition_penalty": 1.0,
+        "disable_tqdm": config.runtime_env_vars.off_tqdm,
     }
 
     if isinstance(spk, int):
@@ ... @@
     return [(sample_rate, np.array(wav).flatten().astype(np.float32)) for wav in wavs]
 
 
+lru_cache_enabled = False
+
+
+def setup_lru_cache():
+    global generate_audio_batch
+    global lru_cache_enabled
+
+    if lru_cache_enabled:
+        return
+    lru_cache_enabled = True
+
+    def should_cache(*args, **kwargs):
+        spk_seed = kwargs.get("spk", -1)
+        infer_seed = kwargs.get("infer_seed", -1)
+        return spk_seed != -1 and infer_seed != -1
+
+    lru_size = config.runtime_env_vars.lru_size
+    if isinstance(lru_size, int):
+        generate_audio_batch = conditional_cache(lru_size, should_cache)(
+            generate_audio_batch
+        )
+        logger.info(f"LRU cache enabled with size {lru_size}")
+    else:
+        logger.debug(f"LRU cache failed to enable, invalid size {lru_size}")
+
+
 if __name__ == "__main__":
     import soundfile as sf
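The `setup_lru_cache` hook above rebinds `generate_audio_batch` behind a predicate: only reproducible requests, where both `spk` and `infer_seed` are pinned to non-negative values, ever touch the LRU cache. A minimal self-contained sketch of that behavior, with a hypothetical `render` function standing in for the real batch generator (the project's actual helper lives in modules/utils/cache.py and additionally knows how to hash dict and list arguments):

from functools import lru_cache


def conditional_cache(maxsize, condition):
    def decorator(func):
        cached = lru_cache(maxsize=maxsize)(func)

        def wrapper(*args, **kwargs):
            # Consult the cache only when the predicate says the call is deterministic.
            if condition(*args, **kwargs):
                return cached(*args, **kwargs)
            return func(*args, **kwargs)

        return wrapper

    return decorator


def should_cache(*args, **kwargs):
    # A result is reproducible only when both seeds are pinned.
    return kwargs.get("spk", -1) != -1 and kwargs.get("infer_seed", -1) != -1


@conditional_cache(64, should_cache)
def render(text, spk=-1, infer_seed=-1):
    print("computed")
    return (text, spk, infer_seed)


render("hi", spk=5, infer_seed=42)  # prints "computed", result cached
render("hi", spk=5, infer_seed=42)  # cache hit, nothing printed
render("hi")                        # seeds left at -1, always recomputed

Requests with `spk=-1` or `infer_seed=-1` are intentionally uncacheable, since each such call samples a fresh random speaker or seed.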
    	
        modules/models.py
    CHANGED
    
@@ -1,15 +1,11 @@
-from modules.ChatTTS import ChatTTS
 import torch
-
+from modules.ChatTTS import ChatTTS
 from modules import config
+from modules.devices import devices
 
 import logging
 
 logger = logging.getLogger(__name__)
-
-device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-print(f"device use {device}")
-
 chat_tts = None
 
 
@@ -17,25 +13,33 @@ def load_chat_tts():
     global chat_tts
     if chat_tts:
        return chat_tts
+
     chat_tts = ChatTTS.Chat()
     chat_tts.load_models(
-        compile=config.
+        compile=config.runtime_env_vars.compile,
         source="local",
         local_path="./models/ChatTTS",
-        device=device,
+        device=devices.device,
+        dtype=devices.dtype,
+        dtype_vocos=devices.dtype_vocos,
+        dtype_dvae=devices.dtype_dvae,
+        dtype_gpt=devices.dtype_gpt,
+        dtype_decoder=devices.dtype_decoder,
     )
 
-
-        logging.info("half precision enabled")
-        for model_name, model in chat_tts.pretrain_models.items():
-            if isinstance(model, torch.nn.Module):
-                model.cpu()
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
-                model.half()
-                if torch.cuda.is_available():
-                    model.cuda()
-                model.eval()
-                logger.log(logging.INFO, f"{model_name} converted to half precision.")
+    devices.torch_gc()
 
     return chat_tts
+
+
+def reload_chat_tts():
+    logging.info("Reloading ChatTTS models")
+    global chat_tts
+    if chat_tts:
+        if torch.cuda.is_available():
+            for model_name, model in chat_tts.pretrain_models.items():
+                if isinstance(model, torch.nn.Module):
+                    model.cpu()
+            torch.cuda.empty_cache()
+    chat_tts = None
+    return load_chat_tts()
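`load_chat_tts` now defers all device and dtype selection to the new modules/devices package, and `reload_chat_tts` parks every `torch.nn.Module` on the CPU before dropping the global reference, so the subsequent cache flush can actually release the GPU allocation. `devices.torch_gc()` itself is added elsewhere in this commit (modules/devices/devices.py) and is not shown in this hunk; the following is only a sketch of what such a helper typically looks like, not the committed implementation:

import gc

import torch


def torch_gc():
    # Sketch only (assumption): see modules/devices/devices.py for the real helper.
    gc.collect()  # drop dangling Python references first
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached CUDA blocks to the driver
        torch.cuda.ipc_collect()  # reclaim memory held by expired IPC handles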
    	
        modules/normalization.py
    CHANGED
    
@@ -1,6 +1,15 @@
 from modules.utils.zh_normalization.text_normlization import *
 import emojiswitch
 from modules.utils.markdown import markdown_to_text
+from modules import models
+import re
+
+
+def is_chinese(text):
+    # Chinese characters fall in the Unicode range \u4e00-\u9fff
+    chinese_pattern = re.compile(r"[\u4e00-\u9fff]")
+    return bool(chinese_pattern.search(text))
+
 
 post_normalize_pipeline = []
 pre_normalize_pipeline = []
@@ -87,12 +96,17 @@ character_map = {
     ">": ",",
     "<": ",",
     "-": ",",
+    "~": " ",
+    "~": " ",
+    "/": " ",
 }
 
 character_to_word = {
     " & ": " and ",
 }
 
+## ---------- post normalize ----------
+
 
 @post_normalize()
 def apply_character_to_word(text):
@@ -109,7 +123,8 @@ def apply_character_map(text):
 
 @post_normalize()
 def apply_emoji_map(text):
-
+    lang = "zh" if is_chinese(text) else "en"
+    return emojiswitch.demojize(text, delimiters=("", ""), lang=lang)
 
 
 @post_normalize()
@@ -122,6 +137,26 @@ def insert_spaces_between_uppercase(s):
     )
 
 
+@post_normalize()
+def replace_unk_tokens(text):
+    """
+    Replace characters that are not in the tokenizer vocabulary with " , "
+    """
+    chat_tts = models.load_chat_tts()
+    tokenizer = chat_tts.pretrain_models["tokenizer"]
+    vocab = tokenizer.get_vocab()
+    vocab_set = set(vocab.keys())
+    # add all English letters
+    vocab_set.update(set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"))
+    vocab_set.update(set(" \n\r\t"))
+    replaced_chars = [char if char in vocab_set else " , " for char in text]
+    output_text = "".join(replaced_chars)
+    return output_text
+
+
+## ---------- pre normalize ----------
+
+
 @pre_normalize()
 def apply_markdown_to_text(text):
     if is_markdown(text):
@@ -186,7 +221,7 @@ def sentence_normalize(sentence_text: str):
     pattern = re.compile(r"(\[.+?\])|([^[]+)")
 
     def normalize_part(part):
-        sentences = tx.normalize(part)
+        sentences = tx.normalize(part) if is_chinese(part) else [part]
         dest_text = ""
         for sentence in sentences:
             sentence = apply_post_normalize(sentence)
@@ -244,6 +279,16 @@ console.log('1')
 “我们是玫瑰花。”花儿们说道。
 “啊!”小王子说……。
         """,
+        """
+State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
+
+🤗 Transformers provides APIs and tools to easily download and train state-of-the-art pretrained models. Using pretrained models can reduce your compute costs, carbon footprint, and save you the time and resources required to train a model from scratch. These models support common tasks in different modalities, such as:
+
+📝 Natural Language Processing: text classification, named entity recognition, question answering, language modeling, summarization, translation, multiple choice, and text generation.
+🖼️ Computer Vision: image classification, object detection, and segmentation.
+🗣️ Audio: automatic speech recognition and audio classification.
+🐙 Multimodal: table question answering, optical character recognition, information extraction from scanned documents, video classification, and visual question answering.
+        """,
     ]
 
     for i, test_case in enumerate(test_cases):
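The new `is_chinese` gate keeps English input out of the Chinese-specific passes: emoji are now spelled out in the language of the surrounding text, and `tx.normalize` (the zh_normalization pipeline) only runs on parts that actually contain Chinese characters. A small runnable check (the exact wording of the output depends on emojiswitch's language tables):

import re

import emojiswitch


def is_chinese(text):
    # any CJK unified ideograph marks the text as Chinese
    return bool(re.search(r"[\u4e00-\u9fff]", text))


for text in ["今天真开心😊", "so happy 😊"]:
    lang = "zh" if is_chinese(text) else "en"
    print(emojiswitch.demojize(text, delimiters=("", ""), lang=lang))
# the first line spells the emoji out in Chinese, the second in English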
    	
        modules/refiner.py
    CHANGED
    
@@ -29,7 +29,7 @@ def refine_text(
             "temperature": temperature,
             "repetition_penalty": repetition_penalty,
             "max_new_token": max_new_token,
-            "disable_tqdm": config.disable_tqdm,
+            "disable_tqdm": config.runtime_env_vars.off_tqdm,
         },
         do_text_normalization=False,
     )
    	
        modules/speaker.py
    CHANGED
    
@@ -1,4 +1,5 @@
 import os
+from typing import Union
 import torch
 
 from modules import models
@@ -53,6 +54,14 @@ class Speaker:
 
         return is_update
 
+    def __hash__(self):
+        return hash(str(self.id))
+
+    def __eq__(self, other):
+        if not isinstance(other, Speaker):
+            return False
+        return str(self.id) == str(other.id)
+
 
 # Each speaker is a single emb file (.pt).
 # Managing speakers means managing the speaker files under ./data/speaker/.
@@ -105,13 +114,13 @@ class SpeakerManager:
         self.refresh_speakers()
         return speaker
 
-    def get_speaker(self, name) -> Speaker | None:
+    def get_speaker(self, name) -> Union[Speaker, None]:
         for speaker in self.speakers.values():
             if speaker.name == name:
                 return speaker
         return None
 
-    def get_speaker_by_id(self, id) -> Speaker | None:
+    def get_speaker_by_id(self, id) -> Union[Speaker, None]:
         for speaker in self.speakers.values():
             if str(speaker.id) == str(id):
                 return speaker
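`__hash__` and `__eq__` are what let a `Speaker` instance travel through the new LRU cache in modules/generate_audio.py: cache keys are built from argument hashes, and two handles to the same saved speaker should collapse to one key. A hypothetical sketch (a speaker named "Alice" is assumed to exist under ./data/speaker/):

from modules.speaker import speaker_mgr

a = speaker_mgr.get_speaker("Alice")     # hypothetical speaker name
b = speaker_mgr.get_speaker_by_id(a.id)  # a second handle to the same speaker
assert a == b
assert hash(a) == hash(b)                # identical LRU-cache slots in generate_audio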
    	
        modules/synthesize_audio.py
    CHANGED
    
@@ -1,4 +1,5 @@
 import io
+from typing import Union
 from modules.SentenceSplitter import SentenceSplitter
 from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
 
@@ -14,7 +15,7 @@ def synthesize_audio(
     temperature: float = 0.3,
     top_P: float = 0.7,
     top_K: float = 20,
-    spk: int = -1,
+    spk: Union[int, Speaker] = -1,
     infer_seed: int = -1,
     use_decoder: bool = True,
     prompt1: str = "",
    	
        modules/utils/JsonObject.py
    ADDED
    
@@ -0,0 +1,113 @@
+class JsonObject:
+    def __init__(self, initial_dict=None):
+        """
+        Initialize the JsonObject with an optional initial dictionary.
+
+        :param initial_dict: A dictionary to initialize the JsonObject.
+        """
+        # If no initial dictionary is provided, use an empty dictionary
+        self._dict_obj = initial_dict if initial_dict is not None else {}
+
+    def __getattr__(self, name):
+        """
+        Get an attribute value. If the attribute does not exist,
+        look it up in the internal dictionary.
+
+        :param name: The name of the attribute.
+        :return: The value of the attribute, or None if it is not found.
+        """
+        try:
+            return self._dict_obj[name]
+        except KeyError:
+            return None
+
+    def __setattr__(self, name, value):
+        """
+        Set an attribute value. If the attribute name is '_dict_obj',
+        set it directly as an instance attribute. Otherwise,
+        store it in the internal dictionary.
+
+        :param name: The name of the attribute.
+        :param value: The value to set.
+        """
+        if name == "_dict_obj":
+            super().__setattr__(name, value)
+        else:
+            self._dict_obj[name] = value
+
+    def __delattr__(self, name):
+        """
+        Delete an attribute by removing it from the internal dictionary.
+        Missing attributes are silently ignored.
+
+        :param name: The name of the attribute.
+        """
+        try:
+            del self._dict_obj[name]
+        except KeyError:
+            return
+
+    def __getitem__(self, key):
+        """
+        Get an item value from the internal dictionary.
+
+        :param key: The key of the item.
+        :return: The value of the item, or None if the key is not found.
+        """
+        if key not in self._dict_obj:
+            return None
+        return self._dict_obj[key]
+
+    def __setitem__(self, key, value):
+        """
+        Set an item value in the internal dictionary.
+
+        :param key: The key of the item.
+        :param value: The value to set.
+        """
+        self._dict_obj[key] = value
+
+    def __delitem__(self, key):
+        """
+        Delete an item from the internal dictionary.
+
+        :param key: The key of the item.
+        :raises KeyError: If the key is not found in the dictionary.
+        """
+        del self._dict_obj[key]
+
+    def to_dict(self):
+        """
+        Convert the JsonObject back to a regular dictionary.
+
+        :return: The internal dictionary.
+        """
+        return self._dict_obj
+
+    def has_key(self, key):
+        """
+        Check if the key exists in the internal dictionary.
+
+        :param key: The key to check.
+        :return: True if the key exists, False otherwise.
+        """
+        return key in self._dict_obj
+
+    def keys(self):
+        """
+        Get a list of keys in the internal dictionary.
+
+        :return: A list of keys.
+        """
+        return self._dict_obj.keys()
+
+    def values(self):
+        """
+        Get a list of values in the internal dictionary.
+
+        :return: A list of values.
+        """
+        return self._dict_obj.values()
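`JsonObject` mirrors attribute and item access onto a single backing dict and returns None for missing keys instead of raising, which matches the mixed `config.runtime_env_vars[key] = val` / `config.runtime_env_vars.off_tqdm` pattern used elsewhere in this commit. A short usage sketch:

cfg = JsonObject({"half": False})

cfg.off_tqdm = True     # attribute write lands in the backing dict
print(cfg["off_tqdm"])  # True: item access sees the same entry
print(cfg.lru_size)     # None: missing keys do not raise
print(cfg.to_dict())    # {'half': False, 'off_tqdm': True}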
    	
        modules/utils/cache.py
    ADDED
    
@@ -0,0 +1,92 @@
+from typing import Callable, TypeVar, Any
+from typing_extensions import ParamSpec
+
+from functools import lru_cache, _CacheInfo
+
+
+def conditional_cache(maxsize: int, condition: Callable):
+    def decorator(func):
+        @lru_cache_ext(maxsize=maxsize)
+        def cached_func(*args, **kwargs):
+            return func(*args, **kwargs)
+
+        def wrapper(*args, **kwargs):
+            if condition(*args, **kwargs):
+                return cached_func(*args, **kwargs)
+            else:
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def hash_list(l: list) -> int:
+    __hash = 0
+    for i, e in enumerate(l):
+        __hash = hash((__hash, i, hash_item(e)))
+    return __hash
+
+
+def hash_dict(d: dict) -> int:
+    __hash = 0
+    for k, v in d.items():
+        __hash = hash((__hash, k, hash_item(v)))
+    return __hash
+
+
+def hash_item(e) -> int:
+    if hasattr(e, "__hash__") and callable(e.__hash__):
+        try:
+            return hash(e)
+        except TypeError:
+            pass
+    if isinstance(e, (list, set, tuple)):
+        return hash_list(list(e))
+    elif isinstance(e, (dict)):
+        return hash_dict(e)
+    else:
+        raise TypeError(f"unhashable type: {e.__class__}")
+
+
+PT = ParamSpec("PT")
+RT = TypeVar("RT")
+
+
+def lru_cache_ext(
+    *opts, hashfunc: Callable[..., int] = hash_item, **kwopts
+) -> Callable[[Callable[PT, RT]], Callable[PT, RT]]:
+    def decorator(func: Callable[PT, RT]) -> Callable[PT, RT]:
+        class _lru_cache_ext_wrapper:
+            args: tuple
+            kwargs: dict[str, Any]
+
+            def cache_info(self) -> _CacheInfo: ...
+            def cache_clear(self) -> None: ...
+
+            @classmethod
+            @lru_cache(*opts, **kwopts)
+            def cached_func(cls, args_hash: int) -> RT:
+                return func(*cls.args, **cls.kwargs)
+
+            @classmethod
+            def __call__(cls, *args: PT.args, **kwargs: PT.kwargs) -> RT:
+                __hash = hashfunc(
+                    (
+                        id(func),
+                        *[hashfunc(a) for a in args],
+                        *[(hashfunc(k), hashfunc(v)) for k, v in kwargs.items()],
+                    )
+                )
+
+                cls.args = args
+                cls.kwargs = kwargs
+
+                cls.cache_info = cls.cached_func.cache_info
+                cls.cache_clear = cls.cached_func.cache_clear
+
+                return cls.cached_func(__hash)
+
+        return _lru_cache_ext_wrapper()
+
+    return decorator
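What `lru_cache_ext` adds over a plain `functools.lru_cache` is the `hash_item` reduction: every argument, including otherwise-unhashable dicts and lists, is folded into one stable integer before the lookup. Usage sketch with a hypothetical `score` function:

from modules.utils.cache import lru_cache_ext


@lru_cache_ext(maxsize=32)
def score(params: dict) -> int:
    print("computed")
    return len(params)


score({"temperature": 0.3, "top_P": 0.7})  # prints "computed"
score({"temperature": 0.3, "top_P": 0.7})  # cache hit: equal dicts hash alike

Note the design trade-off: entries are keyed by the computed hash alone, so a hash collision would silently alias two different calls; that is accepted here in exchange for being able to cache by value.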
    	
        modules/utils/zh_normalization/text_normlization.py
    CHANGED
    
@@ -72,9 +72,9 @@ class TextNormalizer():
         return sentences
 
     def _post_replace(self, sentence: str) -> str:
-        sentence = sentence.replace('/', '每')
-        sentence = sentence.replace('~', '至')
-        sentence = sentence.replace('~', '至')
+        # sentence = sentence.replace('/', '每')
+        # sentence = sentence.replace('~', '至')
+        # sentence = sentence.replace('~', '至')
         sentence = sentence.replace('①', '一')
         sentence = sentence.replace('②', '二')
         sentence = sentence.replace('③', '三')
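Commenting these rules out fixes mixed-language input, since they rewrote every slash and tilde regardless of context:

"PyTorch/TensorFlow".replace("/", "每")  # -> 'PyTorch每TensorFlow'

Width-insensitive handling of ~, ~ and / now lives in character_map in modules/normalization.py, which maps them to plain spaces instead.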
    	
        webui.py
    CHANGED
    
@@ -14,9 +14,11 @@ except:
 import os
 import logging
 
-
+import numpy as np
 
+from modules.devices import devices
 from modules.synthesize_audio import synthesize_audio
+from modules.utils.cache import conditional_cache
 
 logging.basicConfig(
     level=os.getenv("LOG_LEVEL", "INFO"),
@@ -25,20 +27,17 @@
 
 
 import gradio as gr
-import io
-import re
-import numpy as np
 
 import torch
 
 from modules.ssml import parse_ssml
 from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
-from modules.generate_audio import generate_audio, generate_audio_batch
 
 from modules.speaker import speaker_mgr
 from modules.data import styles_mgr
 
 from modules.api.utils import calc_spk_style
+import modules.generate_audio as generate
 
 from modules.normalization import text_normalize
 from modules import refiner, config
@@ -147,7 +146,7 @@
     prompt1 = prompt1 or params.get("prompt1", "")
     prompt2 = prompt2 or params.get("prompt2", "")
 
-    infer_seed = clip(infer_seed, -1, 2**32 - 1)
+    infer_seed = np.clip(infer_seed, -1, 2**32 - 1)
     infer_seed = int(infer_seed)
 
     if not disable_normalize:
@@ -869,31 +868,59 @@ if __name__ == "__main__":
         type=int,
         help="Max batch size for TTS",
     )
+    parser.add_argument(
+        "--lru_size",
+        type=int,
+        default=64,
+        help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
+    )
+    parser.add_argument(
+        "--device_id",
+        type=str,
+        help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
+        default=None,
+    )
+    parser.add_argument(
+        "--use_cpu",
+        nargs="+",
+        help="use CPU as torch device for specified modules",
+        default=[],
+        type=str.lower,
+    )
+    parser.add_argument("--compile", action="store_true", help="Enable model compile")
 
     args = parser.parse_args()
 
-
-
-
-
-
-
-
-
-
-
-
+    def get_and_update_env(*args):
+        val = env.get_env_or_arg(*args)
+        key = args[1]
+        config.runtime_env_vars[key] = val
+        return val
+
+    server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
+    server_port = get_and_update_env(args, "server_port", 7860, int)
+    share = get_and_update_env(args, "share", False, bool)
+    debug = get_and_update_env(args, "debug", False, bool)
+    auth = get_and_update_env(args, "auth", None, str)
+    half = get_and_update_env(args, "half", False, bool)
+    off_tqdm = get_and_update_env(args, "off_tqdm", False, bool)
+    lru_size = get_and_update_env(args, "lru_size", 64, int)
+    device_id = get_and_update_env(args, "device_id", None, str)
+    use_cpu = get_and_update_env(args, "use_cpu", [], list)
+    compile = get_and_update_env(args, "compile", False, bool)
+
+    webui_config["tts_max"] = get_and_update_env(args, "tts_max_len", 1000, int)
+    webui_config["ssml_max"] = get_and_update_env(args, "ssml_max_len", 5000, int)
+    webui_config["max_batch_size"] = get_and_update_env(args, "max_batch_size", 8, int)
 
     demo = create_interface()
 
     if auth:
         auth = tuple(auth.split(":"))
 
-
-
-
-    if off_tqdm:
-        config.disable_tqdm = True
+    generate.setup_lru_cache()
+    devices.reset_device()
+    devices.first_time_calculation()
 
     demo.queue().launch(
         server_name=server_name,
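How the new flag plumbing threads through the commit: `get_and_update_env` mirrors every resolved CLI or environment value into `config.runtime_env_vars`, and the call sites read the flags back by attribute (see modules/refiner.py and modules/generate_audio.py above). A minimal sketch of that round trip:

from modules import config

# what get_and_update_env does after resolving each flag
config.runtime_env_vars["off_tqdm"] = True
config.runtime_env_vars["lru_size"] = 64

# what the downstream call sites read back
assert config.runtime_env_vars.off_tqdm is True
assert config.runtime_env_vars.lru_size == 64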
 
			
